FlashAttention: IO-aware attention that revolutionizes transformer training
In one sentence: Tri Dao (Stanford) publishes FlashAttention, an IO-aware attention implementation that avoids materializing the attention matrix in HBM, achieving a 2-4x speedup while using 10-20x less GPU memory.
Training large language models requires computing attention between every pair of tokens in a text, an operation whose cost grows quadratically with sequence length. Until 2022, standard implementations wrote these massive matrices to the GPU's main memory (HBM) and read them back, which was slow and expensive.
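To make the quadratic growth concrete, here is a back-of-the-envelope sketch in Python. The function name, the head count, the batch size, and the 2-byte (fp16) element size are illustrative assumptions, not figures from the original work.

```python
def attention_matrix_bytes(seq_len, num_heads=16, batch_size=1, bytes_per_element=2):
    """Bytes needed to hold the full seq_len x seq_len score matrix for every head.

    All defaults are hypothetical values chosen only for illustration.
    """
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


if __name__ == "__main__":
    for n in (1024, 2048, 4096, 8192):
        mib = attention_matrix_bytes(n) / 2**20
        print(f"seq_len={n:>5}: {mib:,.0f} MiB for the score matrix alone")
```

Under these assumptions, each doubling of the sequence length quadruples the size of the matrix that a standard implementation has to move through memory.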
Tri Dao, a PhD student at Stanford, realizes that the bottleneck is not the number of math operations but the traffic between HBM (large but slow) and on-chip SRAM (small but fast). FlashAttention splits the computation into blocks small enough to fit in SRAM and combines their partial results incrementally, so the full attention matrix is never written to HBM.
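The block-by-block idea can be sketched in plain Python for a single query vector. This is only a CPU illustration of the running-softmax trick that makes blockwise processing possible, not the actual kernel: the real FlashAttention is a fused CUDA kernel that also tiles the queries and manages SRAM explicitly. The function names and the block size are hypothetical.

```python
import math
import random


def naive_attention(q, keys, values):
    """Reference version: computes and stores every score before normalizing."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def blockwise_attention(q, keys, values, block_size=4):
    """Same output, but only one block of scores exists at any time."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")      # largest score seen so far
    running_sum = 0.0                # softmax denominator, relative to running_max
    acc = [0.0] * len(values[0])     # unnormalized output, relative to running_max
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated under the previous maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


if __name__ == "__main__":
    random.seed(0)
    d, n = 8, 37  # arbitrary sizes; n is deliberately not a multiple of block_size
    q = [random.gauss(0, 1) for _ in range(d)]
    keys = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    values = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    ref = naive_attention(q, keys, values)
    out = blockwise_attention(q, keys, values)
    # The two results differ only by floating-point rounding.
    print(max(abs(a - b) for a, b in zip(ref, out)))
```

The key design point is the correction factor: whenever a new block raises the running maximum, the previously accumulated sum and output are rescaled, so the final division yields the same softmax-weighted average as the all-at-once version.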
The result runs 2-4x faster and uses 10-20x less memory without changing the mathematical output: FlashAttention computes exact attention, not an approximation. From that point on, it becomes a standard building block in major AI training frameworks, from PyTorch to JAX.
Companies
Stanford University
Tools
FlashAttention, CUDA, PyTorch
Sources