
See also: LLMs, embedding, visualisation from Brendan Bycroft

A multi-layer perceptron (MLP) architecture built on top of a multi-head attention mechanism (Vaswani et al., 2023), which amplifies high-entropy tokens and diminishes less important ones.

ELI5: Mom often writes a shopping list consisting of $n$ items to buy. Your job is to guess what the last item on this list will be.

Most implementations are autoregressive. Most major SOTA models are decoder-only, as encoder-decoder models have lagged behind due to their expensive encoding phase.

state-space models, which address the efficiency limitations of transformers' attention layers (Efficient Transformers: A Survey, arXiv (Tay et al., 2022)) on information-dense data

internals

memory limitations.

see also: AI and Memory Wall, arXiv (Gholami et al., 2024)

Arithmetic intensity can be determined with the following:

$$\text{Arithmetic Intensity} = \frac{\#\text{FLOPs}}{\#\text{MOPs}}$$
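As a rough illustration (an assumption here: counting bytes moved as the MOPs and ignoring caching effects), a square fp16 matmul makes the ratio concrete:

import math

def arithmetic_intensity(flops: float, bytes_moved: float) -> float:
  # Arithmetic intensity = FLOPs / memory operations (here: bytes moved).
  return flops / bytes_moved

# Illustrative fp16 matmul C = A @ B with square n x n matrices,
# counting only the mandatory reads/writes of A, B, C.
n = 4096
flops = 2 * n**3               # one multiply + one add per output contribution
bytes_moved = 3 * (n**2) * 2   # A, B, C in fp16 (2 bytes per element)
print(arithmetic_intensity(flops, bytes_moved))  # ~1365 FLOPs/byte for n = 4096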

inference.

Either compute-bound (batch inference, saturated usage) or memory-bound (latency-sensitive decoding)

Prefill/Decode


Speculative decoding

Idea: “draft-and-verify” — use smaller models to generate draft tokens ahead (quick explanation from karpathy)

Intuitively:

  • we generate a small set of lookahead tokens, typically 2-5 tokens, with smaller speculators
  • we use the larger model to “verify” the input sequence + draft tokens (then replace tokens that aren’t valid via the rejection sampler)

In a sense, we verify these tokens in parallel instead of decoding them autoregressively.

A few techniques, such as ngrams and EAGLE, are supported in vLLM.
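As a rough, framework-agnostic sketch of the intuition above (not vLLM's implementation; draft_model and target_model are hypothetical callables returning next-token logits for a whole sequence), a greedy draft-and-verify step could look like this:

import torch

def speculate_step_greedy(draft_model, target_model, input_ids, k=4):
  # One greedy draft-and-verify step: propose k tokens with the draft model,
  # then keep the longest prefix the target model agrees with (plus one target token).
  draft_ids = input_ids
  for _ in range(k):
    logits = draft_model(draft_ids)              # [1, seq_len, vocab]
    next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
    draft_ids = torch.cat([draft_ids, next_id], dim=-1)

  # Verify all k draft tokens with a single target forward pass.
  target_logits = target_model(draft_ids)        # [1, seq_len + k, vocab]
  target_pred = target_logits.argmax(dim=-1)     # target's greedy choice at every position

  n_prompt = input_ids.shape[1]
  accepted = 0
  for i in range(k):
    # Logits at position j predict token j+1, hence the -1 offset.
    if draft_ids[0, n_prompt + i] == target_pred[0, n_prompt + i - 1]:
      accepted += 1
    else:
      break

  # Append the target's own token at the first disagreement (or after all k).
  bonus = target_pred[:, n_prompt + accepted - 1 : n_prompt + accepted]
  return torch.cat([draft_ids[:, : n_prompt + accepted], bonus], dim=-1)

Sampling-based verification (rejection sampling) is covered in the speculative sampling section below.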

EAGLE

Extrapolation Algorithm for Greater Language-model Efficiency

Motivation:

  • speculative sampling relies on the draft model having a distribution similar to the target model’s.
    • use smaller models, e.g. Llama 3.2 3B as the draft for Llama 3.3 70B.
    • the high overhead of stepping through a whole separate model can outweigh the benefits

Difference between EAGLE-1 and EAGLE-3

  • EAGLE-1 is limited by its feature-prediction constraint, imposed via the LM head architecture
  • EAGLE-3 addresses this by using direct token prediction and relying on multi-layer feature fusion (“training-time test”), similar to MLP Speculator

distribution skew

EAGLE does not involve any fine-tuning of the target model, so preservation of the output distribution is theoretically guaranteed for both greedy and non-greedy sampling. This is not the case with Lookahead and Medusa.

EAGLE-1

Observations:

autoregression at the feature level 1 is simpler than at the token level, given that features exhibit more regularity.

uncertainty in the sampling process hinders performance when predicting the next feature.

features are high-dimensional and continuous, so sampling “am” versus “always” results in different feature sequences.

EAGLE addresses this by feeding the token sequence from one time step ahead, including the sampling outcomes, into the draft model:

  • predicting $f_{\text{always}}$ based on $f_{\text{I}}$ and $t_{\text{always}}$
  • predicting $f_{\text{am}}$ based on $f_{\text{I}}$ and $t_{\text{am}}$

notation.

  • “Features” refers to the second-to-top-layer features of the LLM, i.e. the hidden states before the LM head
  • Tokens are denoted by $t$, embeddings by $e$, features by $f$, distributions by $p$
  • Sequences are written as $T_{i:j}$ for $(t_i, t_{i+1},\ldots, t_j)$ 2

architecture

  • [feature_seq, token_seq] # [bs, seq_len, hidden_dim], [bs, seq_len]
  • token_seq -> token_emb # [bs, seq_len] -> [bs, seq_len, hidden_dim]
  • fused_seq = concat(feature_seq, token_emb) # [bs, seq_len, 2*hidden_dim] 3
  • autoregressive_head:
    • FC layer reduces to # [bs, seq_len, hidden_dim]
    • decoder layer predicts the next features
  • uses tree attention to generate a draft tree of depth $m$ with more than $m$ tokens in $m$ forward passes (see the sketch below). 4
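A minimal PyTorch sketch of this fusion step (an illustrative module, not EAGLE's actual code: layer names are made up, a stock nn.TransformerDecoderLayer stands in for the decoder layer, and tree attention/masking is omitted):

import torch
import torch.nn as nn

class DraftHead(nn.Module):
  # EAGLE-style autoregressive head: fuse shifted token embeddings with
  # target-model features, reduce back to hidden_dim, then run one decoder layer.
  def __init__(self, hidden_dim: int, vocab_size: int, n_heads: int = 8):
    super().__init__()
    self.token_emb = nn.Embedding(vocab_size, hidden_dim)
    self.fc = nn.Linear(2 * hidden_dim, hidden_dim)   # reduce fused features
    self.decoder = nn.TransformerDecoderLayer(
        d_model=hidden_dim, nhead=n_heads, batch_first=True
    )

  def forward(self, feature_seq: torch.Tensor, token_seq: torch.Tensor) -> torch.Tensor:
    # feature_seq: [bs, seq_len, hidden_dim], token_seq: [bs, seq_len]
    emb = self.token_emb(token_seq)                   # [bs, seq_len, hidden_dim]
    fused = torch.cat([feature_seq, emb], dim=-1)     # [bs, seq_len, 2*hidden_dim]
    h = self.fc(fused)                                # [bs, seq_len, hidden_dim]
    # Self-attention only; memory = h keeps the sketch self-contained.
    return self.decoder(h, h)                         # predicted next features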

training

  • Smooth L1 loss:

    $$L_\text{reg} = \text{Smooth L1}\bigl(f_{i+1}, \text{Draft}(T_{2:i+1}, F_{1:i})\bigr)$$
  • classification loss to optimize given objectives:

    $$\begin{aligned} p_{i+2} &= \text{Softmax}(\text{LM\_Head}(f_{i+1})) \\ \hat{p}_{i+2} &= \text{Softmax}(\text{LM\_Head}(\hat{f}_{i+1})) \\ L_{\text{cls}} &= \text{CrossEntropy}(p_{i+2}, \hat{p}_{i+2}) \end{aligned}$$
  • Autoregressive head trained with loss $L = L_{\text{reg}} + w_{\text{cls}} L_{\text{cls}}$

    • set $w_{\text{cls}}=0.1$, given that the classification loss is an order of magnitude larger than the regression loss
  • Dataset: ShareGPT, 68k dialogues

  • Hyperparameters:

    • LR: $3\times10^{-5}$
    • AdamW with $(\beta_1, \beta_2)=(0.9,0.95)$
    • gradient clipping: $0.5$
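A sketch of the combined objective under these settings (names are illustrative; lm_head is the frozen target model's LM head and target_features its second-to-top-layer hidden states):

import torch
import torch.nn.functional as F

def eagle_loss(draft_features, target_features, lm_head, w_cls: float = 0.1):
  # L = L_reg + w_cls * L_cls: Smooth L1 on features plus cross-entropy between the
  # token distributions obtained by pushing both features through the frozen LM head.
  l_reg = F.smooth_l1_loss(draft_features, target_features)

  with torch.no_grad():
    p_target = F.softmax(lm_head(target_features), dim=-1)   # teacher distribution
  logits_draft = lm_head(draft_features)
  # Cross-entropy against a soft target distribution.
  l_cls = -(p_target * F.log_softmax(logits_draft, dim=-1)).sum(-1).mean()

  return l_reg + w_cls * l_cls

Per the hyperparameters above, this would be optimized with AdamW (lr $3\times10^{-5}$, $(\beta_1,\beta_2)=(0.9,0.95)$) and gradient clipping at 0.5.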

EAGLE-2

tl/dr: improves on EAGLE-1 by introducing a context-aware dynamic draft tree into the drafting model.

EAGLE-3

HASS

Learning Harmonized Representations for Speculative Sampling, arXiv (Zhang et al., 2025)

HArmonizedSS/HASS

Falcon

Falcon: Faster and Parallel Inference of Large Language Models through Enhanced Semi-Autoregressive Drafting and Custom-Designed Decoding Tree, arXiv (Gao et al., 2025)

MLP Speculator

via combined tokens/embedding speculators

Accelerating Production LLMs with Combined Token/Embedding Speculators, arXiv (Wertheimer et al., 2024)

DistillSpec

DistillSpec: Improving Speculative Decoding via Knowledge Distillation, arXiv (Zhou et al., 2024)

Medusa

https://sites.google.com/view/medusa-llm

FasterDecoding/Medusa

ngrams

apoorvumang/prompt-lookup-decoding

also known as Prompt Lookup Decoding (PLD) or HF’s assisted generation

idea: use string matching against the prompt to generate candidate tokens, instead of using a draft model.

import torch

def find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=10):
  input_length = input_ids.size(1)

  for ngram_size in range(max_ngram_size, 0, -1):
    # Extract the last n tokens as our search ngram
    ngram = input_ids[0, -ngram_size:].tolist()

    # Create sliding windows of size ngram_size
    windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

    # Convert ngram to a tensor for comparison
    ngram_tensor = torch.tensor(ngram, device=input_ids.device).unsqueeze(0)

    # Find where the windows match the ngram
    matches = (windows == ngram_tensor).all(dim=2)

    # Get the indices of matches
    match_indices = matches.nonzero(as_tuple=True)[1]

    # Iterate through match indices to find a valid continuation
    for idx in match_indices:
      start_idx = idx + ngram_size
      end_idx = start_idx + num_pred_tokens
      # Ensure we don't go beyond the length of input_ids and avoid self-match
      if end_idx <= input_length and start_idx < input_length - ngram_size:
        return input_ids[0, start_idx:end_idx]

  # If no match is found, return an empty tensor
  return torch.tensor([], dtype=torch.long, device=input_ids.device)
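For example (illustrative token IDs only), the trailing ngram “1 2 3” also appears earlier in the prompt, so the tokens that followed it are proposed as draft candidates:

# "1 2 3" occurs earlier in the prompt; the tokens that followed it (4 5 6 7)
# become the draft candidates.
input_ids = torch.tensor([[9, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3]])
print(find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=4))
# tensor([4, 5, 6, 7])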

lookahead decoding

see also: LMSYS blog

SPiRE

MagicDec


optimization

Optimizing Speculative Decoding for Serving Large Language Models Using Goodput, arXiv (Liu et al., 2024) optimizes via goodput.

Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models, arXiv (Mamou et al., 2024) focuses on dynamic speculation length.

speculative sampling

aliases: SpS, speculative decoding.

Based on:

tl/dr

  • Latency is improved at the cost of increased arithmetic operations, yielding 2x-3x speedups with T5 at a scale of $\gamma=5$ (also referred to in practice as num_speculative_tokens)
  • This is not useful when compute resources are limited.
  • Walltime improvement of $\frac{1-\alpha^{\gamma +1}}{(1-\alpha)(\gamma c + 1)}$, where $\alpha$ is the approximation $E(\beta)$ (the natural measure of the acceptance rate $\beta$)
  • Note that this is different from rejection sampling 6, given that non-iterative rejection sampling would yield a lower acceptance rate
  • Lenience factor $l$ to trade speed against quality 7 when the draft model’s distribution differs from the target model’s. Note that we can’t use temperature=0 (i.e. argmax sampling).
    • Instead we allow some lenience before standardizing the distribution (accept token $x$ sampled from $M_q$ in case $p(x) \le l \cdot \max p$)
    • In this case, the empirical increases to $\alpha$ are similar to those at temperature=1

goal and algorithm

Let $M_p$ be the target model for task $X$, and $p(x_t \mid x_{<t})$ the distribution we get from the model for a prefix $x_{<t}$.

Let $M_q$ be the draft/approximation model for the same task, and $q(x_t \mid x_{<t})$ the distribution we get from that model for a prefix $x_{<t}$.

Objective: use $M_q$ to generate $\gamma \in \mathbb{Z}^{+}$ completions, and use $M_p$ to verify the $\gamma$ tokens in parallel.

  • Keep when $q(x) \le p(x)$
  • Reject when $q(x) > p(x)$ with probability $1-\frac{p(x)}{q(x)}$ and sample $x$ again from $p'(x) = \text{norm}(\max(0, p(x) - q(x)))$ 8
"\\begin{algorithm}\n\\caption{SpeculativeDecodingStep}\n\\begin{algorithmic}\n\n\\INPUT{$M_p,\\;M_q,\\;\\textit{prefix}$}\n\n\\State $\\triangleright$ Sample $\\gamma$ guesses $x_1,\\dots,x_\\gamma$ from $M_q$\n\\FOR{$i \\gets 1$ \\TO $\\gamma$}\n \\STATE $q_i(x) \\gets M_q\\!\\bigl(\\textit{prefix} + [x_1,\\dots,x_{i-1}]\\bigr)$\n \\STATE $x_i \\sim q_i(x)$\n\\ENDFOR\n\n\\State $\\triangleright$ Run $M_p$ in parallel\n\\STATE $p_1(x),\\dots,p_{\\gamma+1}(x) \\gets\n M_p(\\textit{prefix}),\\dots,\n M_p\\!\\bigl(\\textit{prefix} + [x_1,\\dots,x_\\gamma]\\bigr)$\n\n\\State $\\triangleright$ Determine the number of accepted guesses $n$\n\\STATE $r_1,\\dots,r_\\gamma \\sim U(0,1)$\n\\STATE $n \\gets \\min\\!\\bigl(\\{\\,i-1 \\mid\n 1\\le i\\le\\gamma,\\;\n r_i > \\frac{p_i(x)}{q_i(x)}\\,\\}\\cup\\{\\gamma\\}\\bigr)$\n\n\\State $\\triangleright$ Adjust $M_p$’s distribution if needed\n\\STATE $p'(x) \\gets p_{n+1}(x)$\n\\IF{$n < \\gamma$}\n \\STATE $p'(x) \\gets \\mathrm{norm}\\!\\bigl(\\max\\!\\bigl(0,\\;\n p_{n+1}(x)-q_{n+1}(x)\\bigr)\\bigr)$\n\\ENDIF\n\n\\State $\\triangleright$ Emit one token from $M_p$ and $n$ from $M_q$\n\\STATE $t \\sim p'(x)$\n\\RETURN $\\textit{prefix} + [x_1,\\dots,x_n,t]$\n\n\\end{algorithmic}\n\\end{algorithm}"

Algorithm 4 SpeculativeDecodingStep

Input: $M_p,\; M_q,\; \textit{prefix}$

$\triangleright$ Sample $\gamma$ guesses $x_1,\dots,x_\gamma$ from $M_q$

for $i \gets 1$ to $\gamma$ do
  $q_i(x) \gets M_q\bigl(\textit{prefix} + [x_1,\dots,x_{i-1}]\bigr)$
  $x_i \sim q_i(x)$
end for

$\triangleright$ Run $M_p$ in parallel

$p_1(x),\dots,p_{\gamma+1}(x) \gets M_p(\textit{prefix}),\dots,M_p\bigl(\textit{prefix} + [x_1,\dots,x_\gamma]\bigr)$

$\triangleright$ Determine the number of accepted guesses $n$

$r_1,\dots,r_\gamma \sim U(0,1)$
$n \gets \min\bigl(\{\,i-1 \mid 1\le i\le\gamma,\; r_i > \tfrac{p_i(x)}{q_i(x)}\,\}\cup\{\gamma\}\bigr)$

$\triangleright$ Adjust $M_p$'s distribution if needed

$p'(x) \gets p_{n+1}(x)$
if $n < \gamma$ then
  $p'(x) \gets \mathrm{norm}\bigl(\max\bigl(0,\; p_{n+1}(x)-q_{n+1}(x)\bigr)\bigr)$
end if

$\triangleright$ Emit one token from $M_p$ and $n$ from $M_q$

$t \sim p'(x)$
return $\textit{prefix} + [x_1,\dots,x_n,t]$
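A rough Python sketch of this step over explicit probability vectors (draft_step and target_probs_batch are hypothetical stand-ins for $M_q$ and $M_p$; a real implementation would batch the target pass and reuse the KV cache):

import torch

def speculative_decoding_step(target_probs_batch, draft_step, prefix, gamma: int):
  # Sample gamma guesses x_1..x_gamma from M_q, keeping each q_i.
  xs, qs = [], []
  seq = list(prefix)
  for _ in range(gamma):
    q = draft_step(seq)                 # q_i(x): [vocab] probabilities
    x = torch.multinomial(q, 1).item()
    qs.append(q)
    xs.append(x)
    seq.append(x)

  # Run M_p "in parallel": p_1..p_{gamma+1} for every prefix length.
  ps = target_probs_batch(prefix, xs)   # list of gamma+1 [vocab] probability vectors

  # Determine the number of accepted guesses n (0-indexed here).
  n = gamma
  for i, x in enumerate(xs):
    if torch.rand(()) > ps[i][x] / qs[i][x]:
      n = i
      break

  # Adjust M_p's distribution if a guess was rejected.
  p_next = ps[n]
  if n < gamma:
    p_next = torch.clamp(ps[n] - qs[n], min=0.0)
    p_next = p_next / p_next.sum()      # norm(max(0, p - q))

  t = torch.multinomial(p_next, 1).item()
  return list(prefix) + xs[:n] + [t]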

acceptance probability

alias: acceptance rate

definition 3.1

the acceptance rate $\beta_{x<t}$, given a prefix $x_{<t}$, is the probability of accepting $x_t \sim q(x_t\mid x_{<t})$ via speculative sampling.

$E(\beta)$ is the natural measure of how well $M_q$ approximates $M_p$.

With $\alpha = E(\beta)$ and assuming the $\beta$ are i.i.d., (1) is a capped geometric variable with success probability $1 - \alpha$ and cap $\gamma + 1$:

$$E(\text{\# generated tokens}) = \frac{1-\alpha^{\gamma +1}}{1-\alpha}$$
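For illustration, with $\alpha = 0.8$ and $\gamma = 5$:

$$E(\text{\# generated tokens}) = \frac{1-0.8^{6}}{1-0.8} \approx \frac{0.738}{0.2} \approx 3.69$$

i.e. roughly 3.7 tokens per call to the target model.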

calculating α\alpha

definition 3.2

Let the natural divergence $D_{LK}$ be:

$$D_{LK}(p,q) = \sum_{x} |p(x) - M(x)| = \sum_{x} |q(x) - M(x)|$$

where $M(x) = \frac{p(x) + q(x)}{2}$

Lemma 3.3

$D_{LK}(p,q) = 1 - \sum_{x} \min(p(x), q(x))$ 9

Corollary 3.4

$D_{LK}(p,q)$ is a symmetric divergence in $[0,1]$, where

$D_{LK}(p,q)=0 \Longleftrightarrow p=q$

$D_{LK}(p,q)=1 \Longleftrightarrow p \text{ and } q \text{ have disjoint support}$

Theorem 3.5

$\beta = 1 - D_{LK}(p,q)$ 10

Corollary 3.6

$\alpha = 1 - E(D_{LK}(p,q)) = E(\min(p,q))$
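Corollary 3.6 makes $\alpha$ straightforward to estimate empirically as the expected overlap between the two next-token distributions; a toy sketch with made-up distributions:

import torch

def acceptance_rate(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
  # beta = 1 - D_LK(p, q) = sum_x min(p(x), q(x)) for a single prefix.
  return torch.minimum(p, q).sum()

# Toy 4-token vocabulary; alpha would be the average of this quantity over prefixes.
p = torch.tensor([0.5, 0.3, 0.1, 0.1])   # target distribution
q = torch.tensor([0.4, 0.4, 0.1, 0.1])   # draft distribution
print(acceptance_rate(p, q))             # tensor(0.9000)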

walltime improvement

With the i.i.d. assumption, speculative sampling reduces the number of calls to the target model by a factor of $\frac{1-\alpha^{\gamma +1}}{1-\alpha}$, assuming we run on compute resources that support the increased concurrency (GPUs).

For walltime 11 analysis, assume we can run $\gamma +1$ concurrent evaluations of $M_p$:

cost coefficient

let $c$ be the ratio between the time for a single run of $M_q$ and the time for a single run of $M_p$

$c$ is highly dependent on the hardware. From the paper, $c \approx 0$ to avoid expectancy biases

Theorem 3.8

the expected improvement factor in total walltime is $\frac{1-\alpha^{\gamma +1}}{(1-\alpha)(\gamma c + 1)}$ 12

Note that this assumes the generated sequences are long enough.

Corollary 3.9

$\forall \alpha > c\ \exists\ \gamma$ such that we get an improvement by a factor of $\frac{1+\alpha }{1+c}$

If we get an improvement for $\gamma$, we’d also get an improvement for any $0 < \gamma^{*} < \gamma$; hence we can use (3.8) with $\gamma = 1$, which yields $\frac{1+\alpha}{1+c}$
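A small helper to plug numbers into (1) and Theorem 3.8 (the values below are purely illustrative; in practice $\alpha$ and $c$ are measured empirically):

def expected_tokens(alpha: float, gamma: int) -> float:
  # (1): expected tokens generated per call to the target model.
  return (1 - alpha ** (gamma + 1)) / (1 - alpha)

def walltime_improvement(alpha: float, gamma: int, c: float) -> float:
  # Theorem 3.8: expected walltime improvement factor.
  return (1 - alpha ** (gamma + 1)) / ((1 - alpha) * (gamma * c + 1))

# e.g. alpha=0.8, gamma=5, c=0.05 -> ~3.69 tokens per call and ~2.95x walltime improvement
print(expected_tokens(0.8, 5), walltime_improvement(0.8, 5, 0.05))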

arithmetic operations

arithmetic operations per token

let $\hat{c}$ be the ratio of arithmetic operations per token of $M_q$ to that of $M_p$

Note that the number of operations then grows by $\gamma +1$, given that we produce at most $\gamma +1$ tokens per run.

Theorem 3.11

The expected factor of increase in the number of operations is $\frac{(1-\alpha)(\gamma \hat{c} + \gamma + 1)}{1-\alpha^{\gamma +1}}$ 13
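With the same illustrative numbers ($\alpha = 0.8$, $\gamma = 5$) and, say, $\hat{c} = 0.05$:

$$\frac{(1-\alpha)(\gamma \hat{c} + \gamma + 1)}{1-\alpha^{\gamma +1}} = \frac{0.2 \times 6.25}{0.738} \approx 1.69$$

i.e. roughly 1.7x more arithmetic operations per token, consistent with trading compute for latency.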


KV

The core “retrieval” bag that contains all previously stored key-value pairs as well as newly added items.

Prefill disaggregation is pretty interesting in the sense that we can offload the prefill stage to separate nodes (Qin et al., 2024).

figure 1: KV-centric optimization

Question

Why do we need to use KV Cache?

next-token prediction.

Sampling: we essentially look ahead K tokens, then sample from the distribution of the next token.

multi-token prediction.

(Gloeckle et al., 2024)

figure 2: MTP implementation in DeepSeek, where the causal chain is kept for the prediction of each token at each depth

tl/dr: predict $n$ tokens at once, via a shared trunk and $n$ dedicated prediction heads 14

Note that during inference, we only employ the next-token prediction head

Byte-Latent Transformer

idea: learn from raw bytes and skip the tokenizer/detokenizer protocol.

Feynman-Kac

Let $\mathcal{V}$ be the vocabulary of a given transformer model, and $\mathcal{S} = \mathcal{V}^{*}$ the set of multi-token strings. Assume $\mathcal{V}$ contains the token EOS, and write $\mathcal{F} \subseteq \mathcal{S}$ for the set of EOS-terminated strings.

Feynman-Kac Transformer model

is a tuple $(s_{0}, \{M_t\}_{t\ge 1}, \{G_t\}_{t\ge 1})$ where:

  • $s_{0} \in \mathcal{S}$ is an initial state, which we take to be the empty string $\epsilon$
  • $M_t(s_t \mid s_{t-1}, f_\theta)$ is a Markov kernel from $s_{t-1} \in \mathcal{F}^c$ to $s_t \in \mathcal{S}$, parameterised by a transformer network $f_\theta: \mathcal{F}^c \to \mathbb{R}^{|\mathcal{V}|}$ mapping non-EOS-terminated strings to vectors of logits
  • $G_t(s_{t-1}, s_t, f_\theta)$ is a potential function, mapping a pair $(s_{t-1}, s_t) \in \mathcal{F}^c \times \mathcal{S}$ to a real-valued non-negative score.

Goal: generate from the distribution $\mathbb{P}$ that reweights the Markov chain $\mathbb{M}$ by the potential functions $G_t$. We define the step-$t$ filtering posteriors:

$$\mathbb{P}_t(s_t) = \frac{\mathbb{E}_\mathbb{M} \left[ \prod_{i=1}^{t \wedge T} G_i(S_{i-1}, S_i, f_\theta) \cdot [S_t = s_t] \right]}{\mathbb{E}_\mathbb{M} \left[ \prod_{i=1}^{t \wedge T} G_i(S_{i-1}, S_i, f_\theta) \right]}$$

Given that $T$ is almost surely finite, we can then define the overall posterior (Lew et al., 2023, see 2.2 for examples):

$$\mathbb{P}(s) = \lim_{t \to \infty} \mathbb{P}_t(s)$$
"\\begin{algorithm}\n\\caption{Sequential Monte Carlo Transformer Steering}\n\\begin{algorithmic}\n\\State \\textbf{Input:} $N$ (\\# particles), $K$ (factor), Feynman-Kac Transformer model $\\{s_0, \\{M_t\\}_{t \\geq 1}, \\{G_t\\}_{t \\geq 1}\\}$\n\\State \\textbf{Output:} Weighted particle approximation $\\{(x_i, w_i)\\}_{i=1,\\ldots,N}$ of the posterior $\\mathbb{P}$ \\\\\n\\State \\textbf{Output:} Unbiased estimate $\\hat{Z}$ of the partition function $Z = \\mathbb{E}_\\mathbb{M}[\\prod_{t=1}^T G_t(s_t, s_{t-1}, f_\\theta)]$ \\\\\n\\State Initialize $f_\\theta \\gets \\texttt{CachedTransformer}()$\n\\State Initialize $(x_i, w_i) \\gets (s_0, 1)$ for $i = 1, \\ldots, N$\n\\State Initialize $t \\gets 1$\n\\While{$x_i \\not\\in \\mathcal{F}$ for some $i \\in \\{1, \\ldots, N\\}$}\n \\State $K_i \\gets K (1 - \\mathbb{1}_{\\mathcal{F}}(x_i)) + \\mathbb{1}_{\\mathcal{F}}(x_i)$ for $i = 1, \\ldots, N$\n \\State $N' \\gets \\sum_{i=1}^N K_i$\n \\For{$i \\in \\{1, \\ldots, N\\}$}\n \\If{$x_i \\in \\mathcal{F}$}\n \\State Set $(x_{i,1}, w_{i,1}) \\gets (x_i, w_i \\cdot \\frac{N'}{N})$\n \\Else\n \\State Generate $x_{i,k} \\sim M_t(\\cdot \\mid x_i, f_\\theta)$ for $k = 1, \\ldots, K$\n \\State Set $w_{i,k} \\gets w_i \\cdot G_t(x_i, x_{i,k}, f_\\theta) \\cdot \\frac{N'}{K N}$ for $k = 1, \\ldots, K$\n \\EndIf\n \\EndFor\n \\State Set normalized weights $\\hat{w}_{i,k} \\gets \\frac{w_{(i,k)}}{\\sum_{j=1}^N \\sum_{l=1}^{K_j} w_{(j,l)}}$ for $i = 1, \\ldots, N$ and $k = 1, \\ldots, K_i$\n \\State Set $c^* \\gets \\inf\\{c \\in \\mathbb{R}_{> 0} \\mid \\sum_{i=1}^N \\sum_{k=1}^{K_i} (\\mathbb{1} \\wedge c \\hat{w}_{(i,k)}) > N\\}$\n \\State Set $(I_\\text{det}, I_\\text{stoch}, I_\\text{strat}) \\gets (\\{(i,k) \\mid c^{*} \\hat{w}_{i,k} \\geq 1\\}, \\{(i,k) \\mid c^{*} \\cdot \\hat{w}_{i,k} < 1\\}, \\{\\})$\n \\State Set $\\alpha \\gets \\frac{\\sum_{i \\in I_\\text{stoch}} \\hat{w}_i}{|I_\\text{det}|}$ and generate $U \\sim \\text{Uniform}([0, \\alpha])$\n \\For{$i \\in I_\\text{stoch}$}\n \\State Set $U \\gets U - \\hat{w}_i$\n \\If{$U < 0$}\n \\State Set $I_\\text{strat} \\gets I_\\text{strat} \\cup \\{i\\}$\n \\State Set $U \\gets U + \\alpha$\n \\EndIf\n \\EndFor\n \\State Set particles $\\{(x_i, w_i)\\}_{i=1,\\ldots,|I_\\text{det}|} \\gets \\{(x_j, w_j \\cdot \\frac{N}{N'}) \\mid j \\in I_\\text{det}\\}$\n \\State Set particles $\\{(x_i, w_i)\\}_{i=|I_\\text{det}|+1,\\ldots,N} \\gets \\{(x_j, \\frac{N}{c^* N'} \\sum_{l=1}^{N} \\sum_{k=1}^{K_l} w_{(j,k)}) \\mid j \\in I_\\text{strat}\\}$\n\\EndWhile\n\\State \\Return $\\left((x_i, w_i)_{i=1,\\ldots,N}, \\hat{Z} = \\frac{1}{N} \\sum_{i=1}^N w_i \\right)$\n\\end{algorithmic}\n\\end{algorithm}"

Algorithm 5 Sequential Monte Carlo Transformer Steering

Input: $N$ (# particles), $K$ (factor), Feynman-Kac Transformer model $(s_0, \{M_t\}_{t \geq 1}, \{G_t\}_{t \geq 1})$

Output: Weighted particle approximation $\{(x_i, w_i)\}_{i=1,\ldots,N}$ of the posterior $\mathbb{P}$

Output: Unbiased estimate $\hat{Z}$ of the partition function $Z = \mathbb{E}_\mathbb{M}[\prod_{t=1}^T G_t(s_t, s_{t-1}, f_\theta)]$

Initialize $f_\theta \gets \texttt{CachedTransformer}()$
Initialize $(x_i, w_i) \gets (s_0, 1)$ for $i = 1, \ldots, N$
Initialize $t \gets 1$

while $x_i \not\in \mathcal{F}$ for some $i \in \{1, \ldots, N\}$ do
  $K_i \gets K (1 - \mathbb{1}_{\mathcal{F}}(x_i)) + \mathbb{1}_{\mathcal{F}}(x_i)$ for $i = 1, \ldots, N$
  $N' \gets \sum_{i=1}^N K_i$
  for $i \in \{1, \ldots, N\}$ do
    if $x_i \in \mathcal{F}$ then
      Set $(x_{i,1}, w_{i,1}) \gets (x_i, w_i \cdot \frac{N'}{N})$
    else
      Generate $x_{i,k} \sim M_t(\cdot \mid x_i, f_\theta)$ for $k = 1, \ldots, K$
      Set $w_{i,k} \gets w_i \cdot G_t(x_i, x_{i,k}, f_\theta) \cdot \frac{N'}{K N}$ for $k = 1, \ldots, K$
    end if
  end for
  Set normalized weights $\hat{w}_{i,k} \gets \frac{w_{(i,k)}}{\sum_{j=1}^N \sum_{l=1}^{K_j} w_{(j,l)}}$ for $i = 1, \ldots, N$ and $k = 1, \ldots, K_i$
  Set $c^* \gets \inf\{c \in \mathbb{R}_{> 0} \mid \sum_{i=1}^N \sum_{k=1}^{K_i} (\mathbb{1} \wedge c \hat{w}_{(i,k)}) > N\}$
  Set $(I_\text{det}, I_\text{stoch}, I_\text{strat}) \gets (\{(i,k) \mid c^{*} \hat{w}_{i,k} \geq 1\}, \{(i,k) \mid c^{*} \hat{w}_{i,k} < 1\}, \{\})$
  Set $\alpha \gets \frac{\sum_{i \in I_\text{stoch}} \hat{w}_i}{|I_\text{det}|}$ and generate $U \sim \text{Uniform}([0, \alpha])$
  for $i \in I_\text{stoch}$ do
    Set $U \gets U - \hat{w}_i$
    if $U < 0$ then
      Set $I_\text{strat} \gets I_\text{strat} \cup \{i\}$
      Set $U \gets U + \alpha$
    end if
  end for
  Set particles $\{(x_i, w_i)\}_{i=1,\ldots,|I_\text{det}|} \gets \{(x_j, w_j \cdot \frac{N}{N'}) \mid j \in I_\text{det}\}$
  Set particles $\{(x_i, w_i)\}_{i=|I_\text{det}|+1,\ldots,N} \gets \{(x_j, \frac{N}{c^* N'} \sum_{l=1}^{N} \sum_{k=1}^{K_l} w_{(j,k)}) \mid j \in I_\text{strat}\}$
end while

return $\left((x_i, w_i)_{i=1,\ldots,N},\ \hat{Z} = \frac{1}{N} \sum_{i=1}^N w_i \right)$
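A simplified Python sketch of the same loop (plain multinomial resampling instead of Algorithm 5's adaptive stratified scheme; sample_next, potential, and is_eos are hypothetical stand-ins for $M_t$, $G_t$, and the EOS check):

import torch

def smc_steer(sample_next, potential, n_particles: int, is_eos, max_steps: int = 64):
  # Extend each particle, reweight by the potential, then resample.
  particles = [[] for _ in range(n_particles)]      # s_0 = empty string
  weights = torch.ones(n_particles)

  for _ in range(max_steps):
    if all(is_eos(p) for p in particles):
      break
    for i in range(n_particles):
      if is_eos(particles[i]):
        continue                                    # finished particles are kept as-is
      prev = particles[i]
      nxt = sample_next(prev)                       # x ~ M_t(. | s_{t-1}, f_theta)
      particles[i] = prev + [nxt]
      weights[i] *= potential(prev, particles[i])   # reweight by G_t

    # Multinomial resampling (a simplification of Algorithm 5's scheme).
    norm_w = weights / weights.sum()
    idx = torch.multinomial(norm_w, n_particles, replacement=True)
    particles = [list(particles[j]) for j in idx]
    weights = torch.full((n_particles,), weights.mean().item())

  z_hat = weights.mean().item()                     # rough estimate of Z, cf. (1/N) sum w_i
  return particles, weights, z_hat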

Remarks

  1. features here refer to the hidden states of the second-to-top decoder layer of the LLM, before the LM head. Not to be confused with features

  2. Vanilla autoregression at the token level is described by $T_{1:j} \rightarrow E_{1:j} \rightarrow f_j \rightarrow p_{j+1} \rightarrow t_{j+1}$:

    • input $T_{1:j}$ is transformed into embeddings $E_{1:j}$
    • then into features $F_{1:j}$,
    • the LM head maps $f_j$ to a distribution $p_{j+1} = \text{LM\_Head}(f_j)$
    • the next token $t_{j+1}$ is sampled
  3. See vllm-project/vllm#20078

  4. Aligns with DistillSpec and Medusa

  5. Note that by standard sampling we refer to methods such as argmax, top-k, nucleus, and temperature sampling; albeit each has a different way of processing logits, we consider all of them as standard sampling from an adjusted distribution

  6. Rejection sampling follows an iterative sampling procedure that might look superficially similar to speculative sampling:

    1. Sample $x \sim q(x)$ and draw $r \sim U(0,1)$
    2. If $r < \frac{p(x)}{M q(x)}$, return $x$
    3. Otherwise go to 1

    where $M = \max_{x} \frac{p(x)}{q(x)}$

    We could employ a non-iterative version of rejection sampling instead of speculative sampling here (go through steps 1 and 2, and otherwise sample from the unmodified $p(x)$ directly)

    • it is less efficient, given that the acceptance probability here is: $E_{x\sim q(x)} \frac{p(x)}{M q(x)} = \sum_{x} p(x) \min_{x^{'}}{\frac{q(x^{'})}{p(x^{'})}} \le \sum_{x} p(x) \min{(1, \frac{q(x)}{p(x)})} = \sum_{x} \min{(p(x), q(x))}$
  7. A lenience parameter $l \in [0,1]$ introduces a further trade-off. This is useful when the distribution of the draft model does not match the target model's exactly.

    Specifically we have:

    $$\begin{aligned} \alpha &= \mathbb{E}_{x\sim q(x)} \left[ \begin{cases} 1, & \text{if } l\,q(x) \le p(x),\\ \dfrac{p(x)}{l\,q(x)}, & \text{if } l\,q(x) > p(x) \end{cases} \right] \\ &= \mathbb{E}_{x\sim q(x)} \frac{p(x)}{\max\bigl(p(x),\,l\,q(x)\bigr)} \\ &= \sum_{x} \frac{p(x)\,q(x)}{\max\bigl(p(x),\,l\,q(x)\bigr)} \\ &= \frac{1}{l}\sum_{x} \min\bigl(p(x),\,l\,q(x)\bigr) \\ &= \sum_{x} \min\Bigl(\tfrac{p(x)}{l},\,q(x)\Bigr). \end{aligned}$$
    | $M_q$ | $l=1$ | $l=0.5$ | $l=0.3$ | $l=0.1$ |
    |---|---|---|---|---|
    | Unigram | 0.07 | 0.10 | 0.11 | 0.16 |
    | Bigram | 0.19 | 0.23 | 0.25 | 0.32 |
    | t5-small | 0.62 | 0.71 | 0.76 | 0.84 |
    | t5-base | 0.68 | 0.80 | 0.83 | 0.9 |

    this relies on $q$ being sampled from the given distribution; lowering $l$ increases $\alpha$

    In the case of greedy decoding (temperature=0), the draft essentially outputs $x^{'}_q = \operatorname{argmax} q(x)$, so scaling by $l\,q(x)$ becomes a no-op, given that the argmax is unchanged in this case.

  8. On Correctness of Speculative Sampling (SpS)

    We will show that $\forall p(x) \text{ and } q(x)$, tokens sampled via speculative sampling from $p(x)$ and $q(x)$ are distributed identically to those sampled from $p(x)$ alone.

    Let β\beta be the acceptance probability

    Note that

    $$p'(x) = \operatorname{norm}\bigl(\max(0,\;p(x)-q(x))\bigr) = \frac{p(x)-\min\bigl(q(x),\,p(x)\bigr)}{\sum_{x'}\bigl(p(x')-\min\bigl(q(x'),\,p(x')\bigr)\bigr)} = \frac{p(x)-\min\bigl(q(x),\,p(x)\bigr)}{1-\beta},$$

    so the normalising constant for the adjusted distribution p(x)p'(x) is 1β1-\beta; the last equality follows immediately from Lemma 3.3 and Theorem 3.5.

    Now

    $$P(x = x') = P(\text{guess accepted},\,x = x') + P(\text{guess rejected},\,x = x').$$

    Where

    $$P(\text{guess accepted},\,x = x') = q(x')\,\min\bigl(1,\tfrac{p(x')}{q(x')}\bigr) = \min\bigl(q(x'),\,p(x')\bigr),$$

    and

    $$P(\text{guess rejected},\,x = x') = (1-\beta)\,p'(x') = p(x') - \min\bigl(q(x'),\,p(x')\bigr).$$

    Overall

    $$\begin{aligned} P(x = x') &= \min\bigl(p(x'),\,q(x')\bigr) + p(x') - \min\bigl(p(x'),\,q(x')\bigr) \\ &= p(x'). \end{aligned}$$

    \boxed{}

  9. $$\begin{aligned} D_{LK}(p,q) &= \sum_{x} |p(x) - M(x)| = \sum_{x} \frac{|p-q|}{2} \\ &= 1- \sum_{x} \frac{p+q - |p-q|}{2} \\ &= 1 - \sum_{x} \min(p(x), q(x)) \end{aligned}$$

    \boxed{}

  10. $$\begin{aligned} \beta &= \mathbb{E}_{x \sim q(x)} \Biggl[ \begin{cases} 1 & \text{if } q(x) \le p(x), \\ \dfrac{p(x)}{q(x)} & \text{if } q(x) > p(x) \end{cases} \Biggr] \\ &= \sum_{x} \min\bigl(p(x),\,q(x)\bigr). \end{aligned} \qquad\square$$
  11. also known as elapsed real time (Wikipedia). This is different from CPU time: walltime measures the actual time taken from the start of the program, whereas CPU time only measures the time during which the processor is actively working on a given task or process

  12. Denote the cost of running a single step of $M_p$ by $T$.

    Each run then costs $T c \gamma + T = T(c \gamma +1)$ (running $M_q$ $\gamma$ times and running $M_p$ once).

    Given that (1) produces $\frac{1-\alpha^{\gamma +1}}{1-\alpha}$ tokens,

    the cost to produce a token with speculative sampling is $\frac{(c \gamma +1)(1-\alpha )}{1-\alpha^{\gamma +1}} T$

    \boxed{}

  13. Denote by $\hat{T}$ the number of arithmetic operations done by standard decoding per token; speculative sampling then costs $\hat{T} \hat{c} \gamma + \hat{T}(\gamma +1)$ operations. Dividing by the expected number of tokens gives the desired result. $\boxed{}$

  14. Gloeckle et al. (2024) employ $n=4$. The order of the forward and backward passes in an $n$-token prediction model with $n=4$ heads on the shared trunk works as follows: gradients from each head are accumulated at the detached trunk output, then propagated through the trunk in a single backward pass.

    z = model.shared(x)          # shared trunk forward
    d = z.detach()
    d.requires_grad = True       # accumulate the heads' gradients in d.grad

    for i in range(n):
      p = model.heads[i](d)      # head i forward
      loss(p, y[i]).backward()   # head i backward; grads accumulate in d.grad
    z.backward(gradient=d.grad)  # single backward pass through the shared trunk