Article · Technical guide
Multi-Token Prediction — Meta Insegna ai LLM a Prevedere Più Token in Parallelo
Original source: Gloeckle et al. · Meta AI · "Better & Faster Large Language Models via Multi-token Prediction" · arXiv:2404.19737 — summary and rework in own words.
Cos'è: Multi-Token Prediction (MTP) è un cambio nell'obiettivo di training dei LLM proposto da Fabian Gloeckle, Badr Youbi Idrissi, Baptiste Rozière, David Lopez-Paz e Gabriel Synnaeve di Meta AI ad aprile 2024. Invece di predire solo il prossimo token t+1, il modello viene allenato con N teste di output in parallelo che predicono t+1, t+2, ..., t+N nello stesso forward pass. Il costo di training in parametri è minimo (N teste leggere sopra il backbone condiviso) e gli effetti sono duplici: miglioramenti su task complessi come coding (HumanEval +12% per Llama-style 7B addestrato su 1T token) e capacità di speculative decoding nativa che accelera l'inferenza fino a 3x senza modelli draft separati.
Il problema dell'obiettivo "next-token": granularità troppo fine
L'obiettivo di training dei language model autoregressivi è invariabilmente la next-token prediction: massimizzare la probabilità del token t+1 condizionata su tutti i token precedenti. È un obiettivo matematicamente naturale (cattura la cross-entropy sulla distribuzione del corpus) e pragmaticamente comodo (un solo termine di loss, gradienti puliti). Ma ha un effetto collaterale poco discusso: ottimizza esclusivamente per la coerenza locale a brevissimo termine. Il modello impara a fare scelte ottime per il prossimo token assumendo che tutti gli altri arriveranno dopo, "uno alla volta".
Gloeckle e colleghi articolano l'osservazione: per task che richiedono pianificazione (codice, ragionamento matematico, struttura argomentativa) la decisione del token t+1 dipende criticamente da scelte che verranno fatte ai token t+5, t+10, t+50. L'obiettivo next-token non costringe il modello a "vedere avanti": il gradiente arriva solo da t+1, e qualsiasi struttura più lunga deve emergere indirettamente attraverso il pattern statistico dei dati. Per testo narrativo questo basta. Per codice — dove l'indentazione, le parentesi e la firma di una funzione devono coordinarsi su scala lunga — è una limitazione tangibile.
L'architettura: un backbone, N teste di predizione
Il design è minimale per ragioni pratiche. Il transformer principale (backbone) resta identico: layer di self-attention e MLP nella struttura standard. Sopra il backbone, al posto della singola "lm_head" lineare che mappa l'embedding finale al vocabolario, si pongono N teste di predizione indipendenti, ciascuna composta da un transformer layer aggiuntivo seguito dalla proiezione al vocabolario. La testa i-esima è addestrata a predire il token in posizione t+i. Durante il training la loss totale è la somma delle N cross-entropy, una per ogni testa.
L'efficienza in parametri viene dal fatto che il backbone — la parte costosa — è condiviso. Per Llama-7B il backbone ha circa 6.7B parametri, e le 4 teste aggiuntive (per N=4) aggiungono insieme circa 300M parametri (4%). Il costo di FLOPs di training cresce di circa il 4-8% (le teste fanno il loro forward e backward), ma il numero di token effettivi del dataset visti dal backbone è lo stesso. Per modelli molto grandi (70B+) la frazione di overhead è ancora più piccola perché il backbone domina.
Risultati: scala matters, coding migliora di più
I risultati sperimentali del paper sono articolati per dimensione di modello e tipo di task. Su modelli piccoli (300M-1B parametri) l'effetto di MTP è neutro o leggermente negativo su benchmark linguistici generali. Sopra i 7B la curva si inverte: MTP comincia a superare next-token. Su benchmark di coding il guadagno è il più marcato e robusto: HumanEval migliora del 12% in pass@1 e del 17% in pass@10 per un Llama-style 13B addestrato su 200B token. MBPP migliora di circa il 9%. Su task linguistici puri (HellaSwag, ARC, MMLU) i miglioramenti sono più piccoli ma significativi a partire dai 7B.
Il paper spiega l'asimmetria con un'analisi qualitativa: il codice ha la proprietà di avere "token critici" — punti decisionali in cui scegliere una parola chiave o un nome di variabile vincola fortemente le centinaia di token successivi. Predire più token avanti costringe il modello a ottimizzare esplicitamente quelle scelte coordinate. Per testo naturale i token critici esistono ma sono più rari e il guadagno marginale è minore.
Speculative decoding nativo: la velocità senza modelli draft
Il sottoprodotto inaspettato di MTP è il più impattante in produzione. Lo speculative decoding classico (Leviathan et al., Google, 2023) accelera l'inferenza usando un modello piccolo "draft" che genera N token candidati velocemente, e poi un modello grande "target" verifica i candidati in un singolo forward pass batched. È efficace ma operativamente complicato: serve un draft model separato, addestrato compatibilmente, mantenuto e deployato accanto al target.
Un modello MTP ha già un "draft model" integrato: le sue N teste predicono i token successivi nel forward pass del modello grande, gratis. La procedura di inferenza diventa: usa le teste 2, 3, ..., N come predizioni speculative; verifica con un forward pass batched contro l'output della testa 1 (canonical); accetta il prefisso più lungo coerente; ripeti. Gloeckle et al. misurano speedup di 2.5-3x su coding e fino a 2x su testo naturale, su modelli da 7B-13B, senza modello draft separato e senza compromessi di qualità rispetto a decoding greedy standard.
Meta non ha rilasciato pubblicamente modelli MTP addestrati alla scala di Llama 3, ma il paper è chiaro nel collocare la tecnica come parte degli "internal models" della casa, e DeepSeek-V3 (dicembre 2024) ha integrato MTP nell'architettura ufficiale citando esplicitamente questo paper. La tecnica è quindi entrata nell'arsenale standard dei laboratori che vogliono inferenza più veloce senza il debito operativo dello speculative decoding tradizionale.
Link alla fonte originale
Paper Meta AI del 30 aprile 2024 di Gloeckle, Idrissi, Rozière, Lopez-Paz, Synnaeve. Letture complementari: Leviathan et al. "Fast Inference from Transformers via Speculative Decoding" (arXiv:2211.17192, 2023) per lo speculative decoding originale; DeepSeek-V3 technical report (dicembre 2024) per l'adozione MTP in un modello rilasciato pubblicamente.