FlashAttention: attenzione IO-aware che rivoluziona il training dei transformer
In una frase Tri Dao (Stanford) pubblica FlashAttention: implementazione IO-aware che evita di materializzare la matrice di attenzione in HBM, con 2-4x speedup e 10x meno memoria GPU.
Addestrare modelli linguistici grandi richiede calcolare l'attenzione tra tutti i token di un testo: un'operazione che cresce quadraticamente con la lunghezza della sequenza. Fino al 2022, questa operazione scriveva e leggeva enormi matrici nella memoria del chip — lenta e costosa.
Tri Dao, dottorando a Stanford, capisce che il collo di bottiglia non è il numero di operazioni matematiche ma il traffico tra memoria HBM (lenta) e SRAM (veloce). FlashAttention riorganizza il calcolo in blocchi che rimangono interamente nella SRAM, evitando di materializzare la matrice di attenzione completa.
Il risultato è 2-4x più veloce e usa 10-20x meno memoria, senza cambiare l'output matematico. Da quel momento FlashAttention diventa il building block di quasi ogni framework di training AI, da PyTorch a JAX.
Aziende
Stanford University
Tool
FlashAttention, CUDA, PyTorch
Tag
Fonti