Articolo · Guida tecnica
Grouped-Query Attention — Come Ridurre il KV Cache Senza Perdere Qualità
Fonte originale: Ainslie et al. · Google · "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" · arXiv:2305.13245 — sintesi e rielaborazione in parole proprie.
Cos'è: Grouped-Query Attention (GQA) è una variante di attenzione introdotta da Joshua Ainslie e colleghi di Google nel maggio 2023. Si posiziona come compromesso intermedio tra Multi-Head Attention (MHA, ogni testa ha proprie matrici K e V) e Multi-Query Attention (MQA, tutte le teste condividono una sola K e V). In GQA le N teste di query vengono suddivise in G gruppi: ogni gruppo condivide una singola coppia K/V. Il risultato è una riduzione del KV cache di un fattore N/G — tipicamente 4-8x — con perdita di qualità trascurabile rispetto a MHA piena. È oggi lo standard de facto dell'attenzione nei modelli aperti moderni.
Il problema: MHA è preciso ma costoso, MQA è economico ma fragile
Nel multi-head attention originale (Vaswani et al., "Attention is All You Need", 2017) ogni testa di attenzione ha matrici Q, K e V indipendenti. Per un modello con h teste e dimensione embedding d, le matrici di proiezione K e V hanno ciascuna dimensione d×(d/h)×h = d×d. Durante l'inferenza autoregressiva il KV cache memorizza per ogni token e per ogni layer le h coppie complete di K e V. Su Llama 2 70B (80 layer, 64 teste, head_dim 128) un contesto da 4K token richiede circa 10.7 GB di KV cache per singola sequenza utente — il fattore dominante nel budget di memoria di inferenza.
Noam Shazeer aveva proposto già nel 2019 (paper "Fast Transformer Decoding: One Write-Head is All You Need") l'idea estrema di Multi-Query Attention: tutte le teste condividono una sola K e una sola V. La memoria KV cache crolla di un fattore h — su Llama 70B sarebbero circa 170 MB invece di 10.7 GB. Però MQA ha un costo in qualità: i benchmark di Shazeer mostravano regressioni di 1-2 punti su task linguistici, e instabilità durante il training a scala. Per modelli decoder grandi addestrati da zero MQA era percepito come compromesso troppo aggressivo.
Ainslie e colleghi hanno osservato che la maggior parte della perdita di MQA viene dal collasso totale della diversità tra teste: con una sola K/V condivisa, le teste perdono la capacità di "guardare in direzioni diverse" del contesto. GQA propone un compromesso parametrizzato: invece di 1 sola K/V (MQA) o N K/V indipendenti (MHA), si usano G gruppi con G compreso tra 1 e N.
La formulazione matematica: gruppi di teste che condividono K e V
Sia h il numero di teste di query, d_h la dimensione per testa, G il numero di gruppi (divisore di h). In GQA si mantengono h teste di query indipendenti Q_1, ..., Q_h, ma solo G coppie K e V indipendenti: K_1, ..., K_G e V_1, ..., V_G. Ogni gruppo g contiene h/G teste di query consecutive, e tutte queste teste usano lo stesso K_g e V_g per il calcolo dell'attenzione. Quando G = h ricadiamo in MHA classico; quando G = 1 ricadiamo in MQA. Il caso interessante è G compreso tra questi estremi: ad esempio G = 8 con h = 64 (Llama 3 70B).
Il calcolo dell'attenzione per la testa i appartenente al gruppo g è quindi attention(Q_i, K_g, V_g) = softmax(Q_i · K_g transpose / sqrt(d_h)) · V_g. La riduzione di memoria è esatta: il KV cache si dimensiona come 2 × num_layer × G × d_h × seq_len × precision_bytes invece di 2 × num_layer × h × d_h × seq_len × precision_bytes. Per Llama 3 70B (h = 64, G = 8) il KV cache passa da 10.7 GB a 1.34 GB per 4K token — fattore 8x. È la stessa riduzione che permette di servire 8x più utenti concorrenti sulla stessa GPU.
Uptraining: convertire un modello MHA esistente in GQA
Il contributo metodologico più pratico del paper è dimostrare che non serve riaddestrare un modello da zero per ottenere GQA. Ainslie et al. propongono una procedura di uptraining: partire da un checkpoint MHA pre-addestrato, comprimere le h coppie K/V in G coppie facendo la media dei pesi all'interno di ogni gruppo, poi fare fine-tuning leggero (tipicamente il 5% dei FLOP del pre-training originale). Con questa procedura T5-XXL convertito a GQA-8 raggiunge qualità entro 0.3 punti del modello MHA originale su una batteria di benchmark (CNN/DailyMail, WMT, SuperGLUE).
Questo è il punto che ha fatto adottare GQA dall'intera industria nel giro di dodici mesi. Convertire un Llama 2 70B MHA in una versione GQA-8 costa una frazione minima del budget originale, ma libera 8x di memoria KV — abbastanza per servire context window 8x più lunghi o 8x più utenti in parallelo. Per i provider cloud è una delle ottimizzazioni con miglior rapporto costo/beneficio mai pubblicate.
Adozione: lo standard del 2023-2025
Il paper esce a maggio 2023. A luglio 2023 Meta pubblica Llama 2 70B usando GQA con 8 KV head (la versione 7B e 13B usano ancora MHA standard). A settembre 2023 Mistral 7B esce con GQA-8, dimostrando che il design funziona benissimo anche a scala più piccola. Nel 2024 GQA diventa universale: Llama 3 (8B, 70B, 405B) la usa in tutte le varianti, Mixtral 8x7B, Falcon-180B, Qwen2, DeepSeek, Gemma 2, Phi-3 medium e large. Le rare eccezioni sono modelli che spingono ancora più in là (DeepSeek-V2 introduce Multi-Latent Attention con compressione ulteriore, GPT-4 secondo le indiscrezioni usa varianti proprietarie).
Il tradeoff misurato è notevolmente favorevole. Sui benchmark del paper la perplexity di T5-XXL con GQA-8 è 4.05 contro 4.03 di MHA pieno (differenza nel terzo decimale) e 4.27 di MQA. Su MMLU, GSM8K e HumanEval i modelli Llama 3 GQA-8 non mostrano regressioni rispetto a baseline MHA equivalenti addestrate con lo stesso budget. La velocità di inferenza migliora linearmente con la riduzione di KV cache: throughput su sequenze lunghe (32K+ token) può raddoppiare o triplicare rispetto a MHA pieno sullo stesso hardware.
Link alla fonte originale
Paper Google del maggio 2023 di Ainslie, Lee-Thorp, de Jong, Zemlyanskiy, Lebron, Sanghai. Letture complementari: Shazeer "Fast Transformer Decoding" (arXiv:1911.02150, 2019) per il paper originale MQA; DeepSeek-V2 (arXiv:2405.04434) per Multi-Latent Attention come evoluzione successiva. Implementazione di riferimento nel codice Llama 3 di Meta.