FlashAttention-2: rewrite with 2x speedup, MQA/GQA support, and head-dim 256
In one sentence
Tri Dao rewrites FlashAttention for roughly 2x the speed of FA1: parallelism across sequence length, head dimensions up to 256, and support for MQA and GQA alongside MHA. The de facto standard for training.
One year after publishing FlashAttention, Tri Dao returns with a complete rewrite. The problem with the original version: it did not fully exploit how modern GPUs parallelize work. Specifically, parallelism was limited to batch size and number of heads, leaving GPUs underutilized for long sequences.
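To make the underutilization concrete, here is a rough back-of-the-envelope sketch. It counts the independent thread blocks a kernel can launch when it parallelizes over batch and heads only, versus also over blocks of query rows. The function name `thread_blocks`, the block size of 128 rows, and the workload values are illustrative assumptions, not figures from the FlashAttention code; 108 is the SM count of an NVIDIA A100.

```python
import math

def thread_blocks(batch, heads, seq_len, block_m=None):
    """Count independent work items (thread blocks) for one attention call.

    block_m=None models parallelism over batch and heads only;
    otherwise each block of block_m query rows gets its own thread block.
    """
    per_head = 1 if block_m is None else math.ceil(seq_len / block_m)
    return batch * heads * per_head

NUM_SMS = 108  # streaming multiprocessors on an NVIDIA A100

# Long-context setup: small batch, long sequence (illustrative values).
batch, heads, seq_len = 1, 32, 8192

fa1_style = thread_blocks(batch, heads, seq_len)               # batch x heads
fa2_style = thread_blocks(batch, heads, seq_len, block_m=128)  # + query blocks

print(fa1_style, "blocks for", NUM_SMS, "SMs")  # 32: most SMs sit idle
print(fa2_style, "blocks for", NUM_SMS, "SMs")  # 2048: every SM has work
```

With a batch of 1 and 32 heads, the first scheme produces fewer thread blocks than the GPU has SMs, which is exactly the long-sequence case the rewrite targets.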
FlashAttention-2 reorganizes the computation to parallelize across sequence length as well, and repartitions work between the warps of each thread block to cut synchronization and shared-memory traffic. It raises the supported head dimension from 128 to 256 and adds the GQA and MQA attention variants, which have become standard for reducing KV cache memory (Llama-2 70B, for example, uses GQA with a head dimension of 128).
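The two ideas in this paragraph can be sketched in NumPy: a tiled attention loop in which every (head, query block) pair is independent, and a GQA lookup in which several query heads share one KV head. This is a minimal CPU illustration, not the CUDA kernel; the names `tiled_gqa_attention` and `reference_attention`, the block sizes, and the tensor shapes are all assumptions chosen for readability.

```python
import numpy as np

def reference_attention(q, k, v):
    """Plain softmax(QK^T / sqrt(d)) V for one head; q, k, v: (S, d)."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def tiled_gqa_attention(q, k, v, block_m=4, block_n=4):
    """q: (Hq, S, d); k, v: (Hkv, S, d), with Hq a multiple of Hkv.

    Each (head, query block) pair is independent of the others, so on a
    GPU it could be handed to its own thread block.
    """
    hq, seq, d = q.shape
    group = hq // k.shape[0]                     # query heads per KV head
    out = np.empty_like(q)
    for h in range(hq):
        kh, vh = k[h // group], v[h // group]    # GQA: shared KV head
        for i in range(0, seq, block_m):         # parallelizable: query blocks
            qi = q[h, i:i + block_m]
            m = np.full(qi.shape[0], -np.inf)    # running row maximum
            l = np.zeros(qi.shape[0])            # running softmax denominator
            acc = np.zeros_like(qi)              # unnormalized output
            for j in range(0, seq, block_n):     # sequential: KV blocks
                s = qi @ kh[j:j + block_n].T / np.sqrt(d)
                m_new = np.maximum(m, s.max(axis=-1))
                scale = np.exp(m - m_new)        # rescale earlier partial sums
                p = np.exp(s - m_new[:, None])
                l = l * scale + p.sum(axis=-1)
                acc = acc * scale[:, None] + p @ vh[j:j + block_n]
                m = m_new
            out[h, i:i + block_m] = acc / l[:, None]  # normalize once at the end
    return out

rng = np.random.default_rng(0)
q = rng.standard_normal((8, 16, 32))   # 8 query heads
k = rng.standard_normal((2, 16, 32))   # 2 KV heads: groups of 4 (MQA would use 1)
v = rng.standard_normal((2, 16, 32))

out = tiled_gqa_attention(q, k, v)
ref = np.stack([reference_attention(q[h], k[h // 4], v[h // 4]) for h in range(8)])
print(np.allclose(out, ref))
```

In this example the KV tensors hold 2 heads instead of 8, so a KV cache built from them would be a quarter of the MHA size, while the output still has one result per query head.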
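The result is roughly twice the speed of FlashAttention-1, with native support for the attention variants used in modern models. It was quickly adopted by Hugging Face and practically every training framework; PyTorch, which shipped the original FlashAttention in version 2.0, integrated FlashAttention-2 into scaled_dot_product_attention in version 2.2.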
Companies
Princeton University, Tri Dao Research
Tools
FlashAttention-2, CUDA, PyTorch, Triton