AImpact
High AI Infrastructure · 1 min read

FlashAttention-2: rewrite with 2x speedup, MQA/GQA support, and head-dim 256

In one sentence: Tri Dao rewrites FlashAttention for roughly a 2x speedup over FlashAttention-1, adding parallelism across sequence length, head dimensions up to 256, and support for MQA and GQA alongside standard MHA. It has become the de facto standard for training.

Verified Official source

About a year after publishing FlashAttention, Tri Dao returns with a complete rewrite. The problem with the original version was that it did not fully exploit how modern GPUs parallelize work. Its parallelism was limited to batch size and number of heads, so with long sequences, which usually go hand in hand with small batches, too few thread blocks were launched to keep the GPU busy.
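The underutilization can be made concrete by counting thread blocks. The sketch below is only illustrative arithmetic, not the actual kernel launch logic: the helper `thread_blocks`, the query tile size `block_m=128`, and the example shapes are assumptions, and 108 is the number of streaming multiprocessors on an NVIDIA A100.

```python
import math

NUM_SMS = 108  # streaming multiprocessors on an NVIDIA A100


def thread_blocks(batch, heads, seqlen, block_m=None):
    """Count thread blocks launched for one attention forward pass.

    With block_m=None, work is split only over batch and heads.
    With block_m set (hypothetical tile size), the query sequence is
    also split into tiles of block_m rows, one thread block per tile.
    """
    blocks = batch * heads
    if block_m is not None:
        blocks *= math.ceil(seqlen / block_m)
    return blocks


# Long-sequence setting: one sequence of 16k tokens, 32 heads.
batch_heads_only = thread_blocks(batch=1, heads=32, seqlen=16384)
with_seq_split = thread_blocks(batch=1, heads=32, seqlen=16384, block_m=128)

print(f"batch x heads only:     {batch_heads_only} blocks for {NUM_SMS} SMs")
print(f"plus sequence tiles:    {with_seq_split} blocks for {NUM_SMS} SMs")
```

With only 32 blocks, more than two thirds of the 108 SMs have nothing to do; splitting the query sequence as well yields thousands of blocks, enough to fill the device.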

FlashAttention-2 reorganizes the computation to parallelize across sequence length as well, and repartitions work among the warps of each thread block to reduce synchronization and shared-memory traffic. It adds support for head dimensions up to 256 (needed for models like GPT-J and CodeGen) and for the GQA and MQA attention variants, which have become standard for reducing KV cache memory (Llama-2 70B, for example, uses GQA).
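To illustrate why GQA and MQA reduce KV cache memory, the following sketch computes the cache size per token when fewer key/value heads are shared among the query heads. The function names are hypothetical; the default shape (80 layers, 64 query heads, head dimension 128) matches the published Llama-2 70B configuration, which uses 8 KV heads, and fp16 storage (2 bytes per element) is an assumption.

```python
def kv_cache_bytes_per_token(n_kv_heads, head_dim=128, n_layers=80, bytes_per_elem=2):
    """Bytes of KV cache needed for one token across all layers."""
    # The factor 2 covers the separate K and V tensors stored per layer.
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem


def kv_head_for_query(q_head, n_q_heads, n_kv_heads):
    """Index of the KV head that a given query head reads from."""
    group_size = n_q_heads // n_kv_heads  # query heads sharing one KV head
    return q_head // group_size


N_Q_HEADS = 64

mha = kv_cache_bytes_per_token(n_kv_heads=64)  # MHA: one KV head per query head
gqa = kv_cache_bytes_per_token(n_kv_heads=8)   # GQA: 8 query heads per KV head
mqa = kv_cache_bytes_per_token(n_kv_heads=1)   # MQA: all query heads share one KV head

for name, size in [("MHA", mha), ("GQA", gqa), ("MQA", mqa)]:
    print(f"{name}: {size / 1024:.0f} KiB per token")

# Query heads 0..7 share KV head 0, heads 8..15 share KV head 1, and so on.
print([kv_head_for_query(h, N_Q_HEADS, 8) for h in (0, 7, 8, 63)])
```

Under these assumptions the cache shrinks from 2560 KiB per token with MHA to 320 KiB with GQA and 40 KiB with MQA, which is why an attention kernel that handles shared KV heads natively matters for large models.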

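The result is roughly twice the speed of FlashAttention-1, with native support for the attention variants used by most modern models. It was quickly adopted by Hugging Face Transformers and practically every training framework, and PyTorch integrated it into scaled_dot_product_attention in version 2.2 (PyTorch 2.0 predates FlashAttention-2 and shipped the original FlashAttention).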

Companies

Princeton University, Tri Dao Research

Tools

FlashAttention-2, CUDA, PyTorch, Triton

Tags

FlashAttention-2, Attention, Transformer, CUDA, MHA, MQA, GQA, Tri Dao