FlashAttention-2: riscrittura con 2x speedup, MQA/GQA e head-dim 256
In una frase Tri Dao riscrive FlashAttention con 2x speedup su FA1: migliore parallelismo su seq-len, supporto head-dim fino a 256, query parallelism per MHA, MQA e GQA. Standard de facto per il training.
Un anno dopo aver pubblicato FlashAttention, Tri Dao torna con una riscrittura completa. Il problema della versione originale: non sfruttava al massimo come le GPU moderne parallelizzano il lavoro. In particolare, il parallelismo era limitato alla dimensione del batch e al numero di heads, lasciando sottoutilizzate le GPU per sequenze lunghe.
FlashAttention-2 riorganizza il calcolo per parallelizzare anche sulla lunghezza della sequenza, riducendo le sincronizzazioni tra thread. Aggiunge il supporto per head dimensions fino a 256 (necessario per modelli come Llama-2 70B) e per le varianti GQA e MQA di attenzione, diventate standard per ridurre la memoria KV cache.
Il risultato è circa il doppio della velocità rispetto a FlashAttention-1, con supporto nativo per tutte le varianti architetturali dei modelli moderni. Adottato immediatamente da PyTorch 2.0, Hugging Face, e praticamente ogni framework di training.
Aziende
Princeton University, Tri Dao Research
Tool
FlashAttention-2, CUDA, PyTorch, Triton
Tag
Fonti