Salta al contenuto
AImpact
IT EN
Alto Infrastruttura AI · 1 min lettura

FlashAttention-2: riscrittura con 2x speedup, MQA/GQA e head-dim 256

In una frase Tri Dao riscrive FlashAttention con 2x speedup su FA1: migliore parallelismo su seq-len, supporto head-dim fino a 256, query parallelism per MHA, MQA e GQA. Standard de facto per il training.

Verificato Fonte ufficiale
CondividiLinkedInX
Livello di lettura

Un anno dopo aver pubblicato FlashAttention, Tri Dao torna con una riscrittura completa. Il problema della versione originale: non sfruttava al massimo come le GPU moderne parallelizzano il lavoro. In particolare, il parallelismo era limitato alla dimensione del batch e al numero di heads, lasciando sottoutilizzate le GPU per sequenze lunghe.

FlashAttention-2 riorganizza il calcolo per parallelizzare anche sulla lunghezza della sequenza, riducendo le sincronizzazioni tra thread. Aggiunge il supporto per head dimensions fino a 256 (necessario per modelli come Llama-2 70B) e per le varianti GQA e MQA di attenzione, diventate standard per ridurre la memoria KV cache.

Il risultato è circa il doppio della velocità rispetto a FlashAttention-1, con supporto nativo per tutte le varianti architetturali dei modelli moderni. Adottato immediatamente da PyTorch 2.0, Hugging Face, e praticamente ogni framework di training.

Aziende

Princeton University, Tri Dao Research

Tool

FlashAttention-2, CUDA, PyTorch, Triton

Tag

FlashAttention-2AttentionTransformerCUDAMHAMQAGQATri Dao

Fonti