Skip to content
AImpact
IT EN

Paper · Foundational research

FlashAttention — L'Ottimizzazione che ha Reso Possibili i Context Window Lunghi

Original source: Tri Dao et al. · Stanford 2022 · arxiv.org/abs/2205.14135 — summary and rework in own words.

ShareLinkedInX

Cos'è: FlashAttention è un algoritmo di attenzione IO-aware pubblicato da Tri Dao, Dan Fu e colleghi di Stanford nel maggio 2022. Invece di modificare la matematica dell'attenzione, riprogetta il pattern di accesso alla memoria per adattarsi alla gerarchia hardware della GPU — SRAM on-chip vs HBM off-chip — riducendo le letture/scritture di memoria del 5–20×, accelerando il training di 3–8× e rendendo la complessità in memoria sublineare rispetto alla lunghezza della sequenza.

Il problema O(n²): non il calcolo, ma la memoria

L'attenzione standard dei transformer ha una complessità computazionale e di memoria O(n²) rispetto alla lunghezza della sequenza n. Per una sequenza di 512 token il problema è gestibile. Per 4.096 token diventa pesante. Per 32.768 token — che corrispondono a circa 25.000 parole, un racconto lungo — la matrice di attenzione occupa da sola oltre 8 GB di VRAM con precisione fp16.

Questo non è un problema teorico: fino al 2022 era la ragione pratica per cui i modelli linguistici avevano context window di 2.048 o al massimo 4.096 token. GPT-3 aveva 2.048 token. I paper di fine 2020 e 2021 erano pieni di proposte per "efficient attention" — Longformer, BigBird, Performer, Linformer — tutte basate su approssimazioni della matrice di attenzione: attenzione locale a finestra scorrevole, attenzione sparsa, proiezioni lineari.

Tri Dao ha identificato che il problema era mal posto. Non era un problema matematico da approssimare: era un problema di architettura hardware. L'attenzione standard non era lenta perché faceva troppi calcoli floating-point — i GPU hanno TFLOPS a sufficienza. Era lenta perché accedeva alla memoria nel modo sbagliato.

La gerarchia della memoria GPU: HBM vs SRAM

Una GPU moderna come l'A100 ha due tipi di memoria radicalmente diversi per velocità e capacità. La HBM (High Bandwidth Memory) è la VRAM principale: su un A100 SXM4 ne troviamo 40 o 80 GB, con una bandwidth di circa 2 TB/s. È dove risiedono i pesi del modello, le attivazioni, i gradienti. È molta e accessibile da tutti i core, ma in termini di latenza e bandwidth rispetto ai core di calcolo è relativamente lenta.

La SRAM on-chip è la memoria cache che vive fisicamente all'interno dei chip CUDA: su un A100 sono circa 20 MB totali, distribuiti tra i 108 Streaming Multiprocessor. È microscopicamente piccola rispetto alla HBM, ma è ordini di grandezza più veloce — circa 19 TB/s di bandwidth effettiva per i core che vi accedono direttamente.

L'attenzione standard lavora così: carica in HBM le matrici Q, K, V intere. Calcola la matrice di score QKᵀ e la scrive in HBM. Legge da HBM per applicare softmax. Riscrive il risultato in HBM. Legge di nuovo per moltiplicare per V. Ogni passaggio — e ce ne sono parecchi — comporta un round-trip completo tra i core di calcolo e la HBM lenta. Il vero collo di bottiglia non è il numero di operazioni floating-point: è il numero di byte che viaggiano avanti e indietro tra HBM e i core.

L'insight IO-aware: tiling e kernel fusion

L'idea centrale di FlashAttention è il tiling: invece di calcolare l'intera matrice di attenzione in un colpo, la si calcola a blocchi (tile) sufficientemente piccoli da stare interamente nella SRAM on-chip. Il procedimento aggiornato è questo:

  • Si divide Q in blocchi di righe, K e V in blocchi di colonne.
  • Per ogni coppia di blocchi si carica il blocco in SRAM, si esegue il calcolo parziale dell'attenzione, si accumula il risultato locale usando un'implementazione online del softmax (che non richiede di vedere tutti gli score prima di normalizzare).
  • Il risultato parziale viene aggiornato incrementalmente senza mai materializzare la matrice N×N completa in HBM.

Il trucco matematico che rende possibile il softmax online è noto come log-sum-exp trick: si può calcolare il softmax su un blocco e poi aggiornarlo correttamente quando si vedono nuovi valori, senza dover rileggere i blocchi precedenti. Tri Dao e collaboratori hanno dimostrato formalmente che l'algoritmo produce esattamente lo stesso output dell'attenzione standard — non è un'approssimazione.

Combinato con la kernel fusion — eseguire tutti i passaggi (score, softmax, moltiplicazione per V) in un unico kernel CUDA invece che in kernel separati — il risultato è un drastico taglio negli accessi HBM: da O(n²) read/write a O(n) read/write per la fase di forward pass.

Risultati: velocità e memoria

Le misurazioni del paper su A100 GPU sono precise e riproducibili. FlashAttention è 3–8× più veloce dell'attenzione standard PyTorch su sequenze da 1K a 16K token, con il guadagno che aumenta all'aumentare della lunghezza della sequenza (perché il problema HBM diventa proporzionalmente più dominante). La memoria necessaria per l'attenzione passa da O(n²) a O(n): una sequenza di 16K token che richiedeva 16 GB di VRAM solo per la matrice di attenzione ora ne richiede una frazione lineare.

Su GPT-2 da 1.3B parametri, il training con FlashAttention è 3× più veloce end-to-end rispetto alla baseline. Su BERT-large, 15% di speedup con sequenze da 512 token (dove il guadagno è minore perché la sequenza è corta e l'overhead HBM è proporzionalmente meno dominante). Il paper include anche risultati su Long Range Arena, benchmark per modelli su sequenze lunghe: FlashAttention è il primo metodo a completare Path-X (lunghezza 16.384) con un transformer standard, che prima non scalava abbastanza.

FlashAttention-2 e FlashAttention-3

A luglio 2023 Tri Dao — nel frattempo passato a lavorare ad Anthropic — pubblica FlashAttention-2. Le ottimizzazioni principali riguardano la distribuzione del lavoro tra i thread CUDA (warps): la versione originale aveva thread che aspettavano altri thread in modo non ottimale. FA2 riduce la sincronizzazione, parallelizza meglio su batch size e numero di teste, ed è circa 2× più veloce di FA1 — fino al 73% dell'utilizzo teorico massimo dell'A100.

Nel 2024 arriva FlashAttention-3, ottimizzato per l'architettura Hopper (H100): sfrutta le nuove istruzioni WGMMA (Warpgroup Matrix Multiply-Accumulate) e la possibilità di sovrapporre calcolo aritmetico e accessi a memoria in modo asincrono. Su H100 raggiunge fino a 740 TFLOPS, ovvero circa il 75% del picco teorico.

FlashAttention è oggi integrato nativamente in PyTorch (a partire da 2.0, come F.scaled_dot_product_attention), in HuggingFace Transformers, e nelle librerie di training di praticamente tutti i laboratori AI principali.

Impatto pratico: da 4K a 1M token di contesto

L'impatto di FlashAttention sul campo è difficile da sovrastimare. Prima del 2022, un context window di 4.096 token era considerato ambizioso. GPT-4 al lancio (marzo 2023) offriva 8.192 token, con una versione a 32.768 accessibile a un sottoinsieme di utenti. Claude 2 (luglio 2023) ha portato il context a 100.000 token — il primo modello commerciale a rompere la barriera dei 100K. Gemini 1.5 Pro (febbraio 2024) ha raggiunto 1 milione di token di context window in preview.

Nessuno di questi avanzamenti sarebbe stato praticamente realizzabile senza FlashAttention. Non è l'unico ingrediente — ci sono tecniche complementari come Group Query Attention (GQA) e sliding window attention — ma è la fondazione su cui tutto il resto è costruito. Quando Claude legge un intero documento di 150 pagine in un singolo prompt, o quando GPT-4 analizza una repository di codice completa, il meccanismo che rende questo possibile computazionalmente è discendente diretto del paper di Tri Dao del 2022.

FlashAttention è anche usato durante il training di quasi tutti i modelli moderni: GPT-4, Claude, LLaMA 2 e 3, Mistral, Falcon, Gemma. Non è un'ottimizzazione opzionale per chi vuole risparmio di memoria: è diventato il modo standard di implementare l'attenzione.


Link alla fonte originale

arxiv.org/abs/2205.14135 →

Paper originale EN, Tri Dao et al. Stanford, maggio 2022. Codice open-source su github.com/Dao-AILab/flash-attention. Integrato in PyTorch 2.0+.