Skip to content
AImpact
IT EN

Article · Technical guide

Multi-Token Prediction — Meta Insegna ai LLM a Prevedere Più Token in Parallelo

Original source: Gloeckle et al. · Meta AI · "Better & Faster Large Language Models via Multi-token Prediction" · arXiv:2404.19737 — summary and rework in own words.

ShareLinkedInX

Cos'è: Multi-Token Prediction (MTP) è un cambio nell'obiettivo di training dei LLM proposto da Fabian Gloeckle, Badr Youbi Idrissi, Baptiste Rozière, David Lopez-Paz e Gabriel Synnaeve di Meta AI ad aprile 2024. Invece di predire solo il prossimo token t+1, il modello viene allenato con N teste di output in parallelo che predicono t+1, t+2, ..., t+N nello stesso forward pass. Il costo di training in parametri è minimo (N teste leggere sopra il backbone condiviso) e gli effetti sono duplici: miglioramenti su task complessi come coding (HumanEval +12% per Llama-style 7B addestrato su 1T token) e capacità di speculative decoding nativa che accelera l'inferenza fino a 3x senza modelli draft separati.

Il problema dell'obiettivo "next-token": granularità troppo fine

L'obiettivo di training dei language model autoregressivi è invariabilmente la next-token prediction: massimizzare la probabilità del token t+1 condizionata su tutti i token precedenti. È un obiettivo matematicamente naturale (cattura la cross-entropy sulla distribuzione del corpus) e pragmaticamente comodo (un solo termine di loss, gradienti puliti). Ma ha un effetto collaterale poco discusso: ottimizza esclusivamente per la coerenza locale a brevissimo termine. Il modello impara a fare scelte ottime per il prossimo token assumendo che tutti gli altri arriveranno dopo, "uno alla volta".

Gloeckle e colleghi articolano l'osservazione: per task che richiedono pianificazione (codice, ragionamento matematico, struttura argomentativa) la decisione del token t+1 dipende criticamente da scelte che verranno fatte ai token t+5, t+10, t+50. L'obiettivo next-token non costringe il modello a "vedere avanti": il gradiente arriva solo da t+1, e qualsiasi struttura più lunga deve emergere indirettamente attraverso il pattern statistico dei dati. Per testo narrativo questo basta. Per codice — dove l'indentazione, le parentesi e la firma di una funzione devono coordinarsi su scala lunga — è una limitazione tangibile.

L'architettura: un backbone, N teste di predizione

Il design è minimale per ragioni pratiche. Il transformer principale (backbone) resta identico: layer di self-attention e MLP nella struttura standard. Sopra il backbone, al posto della singola "lm_head" lineare che mappa l'embedding finale al vocabolario, si pongono N teste di predizione indipendenti, ciascuna composta da un transformer layer aggiuntivo seguito dalla proiezione al vocabolario. La testa i-esima è addestrata a predire il token in posizione t+i. Durante il training la loss totale è la somma delle N cross-entropy, una per ogni testa.

L'efficienza in parametri viene dal fatto che il backbone — la parte costosa — è condiviso. Per Llama-7B il backbone ha circa 6.7B parametri, e le 4 teste aggiuntive (per N=4) aggiungono insieme circa 300M parametri (4%). Il costo di FLOPs di training cresce di circa il 4-8% (le teste fanno il loro forward e backward), ma il numero di token effettivi del dataset visti dal backbone è lo stesso. Per modelli molto grandi (70B+) la frazione di overhead è ancora più piccola perché il backbone domina.

Risultati: scala matters, coding migliora di più

I risultati sperimentali del paper sono articolati per dimensione di modello e tipo di task. Su modelli piccoli (300M-1B parametri) l'effetto di MTP è neutro o leggermente negativo su benchmark linguistici generali. Sopra i 7B la curva si inverte: MTP comincia a superare next-token. Su benchmark di coding il guadagno è il più marcato e robusto: HumanEval migliora del 12% in pass@1 e del 17% in pass@10 per un Llama-style 13B addestrato su 200B token. MBPP migliora di circa il 9%. Su task linguistici puri (HellaSwag, ARC, MMLU) i miglioramenti sono più piccoli ma significativi a partire dai 7B.

Il paper spiega l'asimmetria con un'analisi qualitativa: il codice ha la proprietà di avere "token critici" — punti decisionali in cui scegliere una parola chiave o un nome di variabile vincola fortemente le centinaia di token successivi. Predire più token avanti costringe il modello a ottimizzare esplicitamente quelle scelte coordinate. Per testo naturale i token critici esistono ma sono più rari e il guadagno marginale è minore.

Speculative decoding nativo: la velocità senza modelli draft

Il sottoprodotto inaspettato di MTP è il più impattante in produzione. Lo speculative decoding classico (Leviathan et al., Google, 2023) accelera l'inferenza usando un modello piccolo "draft" che genera N token candidati velocemente, e poi un modello grande "target" verifica i candidati in un singolo forward pass batched. È efficace ma operativamente complicato: serve un draft model separato, addestrato compatibilmente, mantenuto e deployato accanto al target.

Un modello MTP ha già un "draft model" integrato: le sue N teste predicono i token successivi nel forward pass del modello grande, gratis. La procedura di inferenza diventa: usa le teste 2, 3, ..., N come predizioni speculative; verifica con un forward pass batched contro l'output della testa 1 (canonical); accetta il prefisso più lungo coerente; ripeti. Gloeckle et al. misurano speedup di 2.5-3x su coding e fino a 2x su testo naturale, su modelli da 7B-13B, senza modello draft separato e senza compromessi di qualità rispetto a decoding greedy standard.

Meta non ha rilasciato pubblicamente modelli MTP addestrati alla scala di Llama 3, ma il paper è chiaro nel collocare la tecnica come parte degli "internal models" della casa, e DeepSeek-V3 (dicembre 2024) ha integrato MTP nell'architettura ufficiale citando esplicitamente questo paper. La tecnica è quindi entrata nell'arsenale standard dei laboratori che vogliono inferenza più veloce senza il debito operativo dello speculative decoding tradizionale.


Link alla fonte originale

arxiv.org/abs/2404.19737 →

Paper Meta AI del 30 aprile 2024 di Gloeckle, Idrissi, Rozière, Lopez-Paz, Synnaeve. Letture complementari: Leviathan et al. "Fast Inference from Transformers via Speculative Decoding" (arXiv:2211.17192, 2023) per lo speculative decoding originale; DeepSeek-V3 technical report (dicembre 2024) per l'adozione MTP in un modello rilasciato pubblicamente.