Speculative Decoding: 2-3x LLM inference speedup without changing output
In one sentence: Chen et al. (Google Brain) published Speculative Decoding, in which a small model proposes tokens and the large model verifies them in parallel, producing the same output 2-3x faster with no change in quality.
Generating text with a large language model is slow because tokens are produced one at a time: the model must finish computing the previous token before starting the next. Generation cannot easily be parallelized across tokens, because each step depends on the one before it.
The idea behind Speculative Decoding is elegant. First, a small, fast model proposes a short run of tokens, typically 4-8. The large model then verifies the whole run in a single parallel pass, which is possible because all the proposed tokens are already known. Proposed tokens are accepted up to the first mismatch; from that position on, the proposals are discarded and the large model's own token is used instead.
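The propose-then-verify loop described above can be sketched in a few lines of plain Python. This is a minimal illustration for greedy decoding only, not the authors' implementation: the `Model` interface (a function from a token list to the next token), the two toy models, and the parameter `k` are all assumptions made for the example. A real transformer would obtain all the verification results from one batched forward pass; the toy version simulates that with a loop and counts it as a single pass.

```python
from typing import Callable, List, Tuple

# Hypothetical interface: a "model" maps a token context to its greedy next token.
Model = Callable[[List[int]], int]


def autoregressive(target: Model, prompt: List[int], n_new: int) -> List[int]:
    """Baseline: one target-model call per generated token."""
    out = list(prompt)
    for _ in range(n_new):
        out.append(target(out))
    return out


def speculative_decode(
    target: Model, draft: Model, prompt: List[int], n_new: int, k: int = 4
) -> Tuple[List[int], int]:
    """Greedy speculative decoding; returns (tokens, number of target passes)."""
    out = list(prompt)
    goal = len(prompt) + n_new
    target_passes = 0
    while len(out) < goal:
        # 1. The small model proposes k tokens, one after another (cheap).
        proposal: List[int] = []
        for _ in range(k):
            proposal.append(draft(out + proposal))
        # 2. The large model gives its own choice at each of the k + 1 positions.
        #    Assumption: a real model does this in ONE parallel forward pass.
        target_passes += 1
        checks = [target(out + proposal[:i]) for i in range(k + 1)]
        # 3. Accept the longest prefix on which both models agree.
        n_ok = 0
        while n_ok < k and proposal[n_ok] == checks[n_ok]:
            n_ok += 1
        out.extend(proposal[:n_ok])
        # 4. Add the large model's token at the first mismatch
        #    (or one extra token if every proposal was accepted).
        out.append(checks[n_ok])
    return out[:goal], target_passes


# Toy models, purely illustrative: the draft agrees with the target
# except when the context length is a multiple of 5.
def toy_target(ctx: List[int]) -> int:
    return (ctx[-1] * 7 + 3) % 50


def toy_draft(ctx: List[int]) -> int:
    guess = toy_target(ctx)
    return guess if len(ctx) % 5 else (guess + 1) % 50


if __name__ == "__main__":
    prompt = [1, 2, 3]
    baseline = autoregressive(toy_target, prompt, n_new=20)
    tokens, passes = speculative_decode(toy_target, toy_draft, prompt, n_new=20)
    print("same output:", tokens == baseline)
    print("target passes:", passes, "instead of 20")
```

Because every emitted token is either confirmed or supplied by the large model, the result matches the baseline token for token; the saving comes only from needing fewer large-model passes.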
The final output is identical to what the large model would produce on its own. But since most tokens proposed by the small model are accepted, each large-model pass yields several tokens, so the large model is called 2-3x less often. One of the smartest ideas in LLM inference in recent years.
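A back-of-the-envelope estimate shows how the acceptance rate translates into saved large-model calls. The sketch below assumes each proposed token is accepted independently with probability `alpha`, a simplification that is not taken from the source; `alpha`, `k`, and the function name are illustrative. Under that assumption, one verification pass yields 1 + alpha + alpha^2 + ... + alpha^k tokens on average, which is a geometric series.

```python
def expected_tokens_per_pass(alpha: float, k: int) -> float:
    """Average tokens produced per large-model pass.

    alpha: assumed per-token acceptance probability (0 <= alpha <= 1)
    k:     number of tokens the small model proposes per pass
    """
    if alpha == 1.0:
        # Every proposal is accepted, plus one extra token from the large model.
        return float(k + 1)
    # Closed form of the geometric series 1 + alpha + ... + alpha**k.
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


if __name__ == "__main__":
    for alpha in (0.5, 0.7, 0.8, 0.9):
        print(alpha, round(expected_tokens_per_pass(alpha, k=4), 2))
```

This counts large-model passes only. It ignores the time spent running the small model, so the real wall-clock speedup is lower than the number it returns.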
Companies
Google Brain
Tools
PyTorch, JAX