AImpact
Landmark AI Infrastructure · 1 min read

FlashAttention: IO-aware attention that revolutionizes transformer training

In one sentence: Tri Dao (Stanford) and co-authors publish FlashAttention, an IO-aware implementation of exact attention that avoids materializing the attention matrix in HBM, achieving a 2-4x speedup and 10-20x less GPU memory.


Training large language models requires computing attention between every pair of tokens in a text, an operation whose cost grows quadratically with sequence length. Until 2022, standard implementations wrote the full matrix of attention scores to GPU main memory (HBM) and read it back, which is slow and memory-hungry.
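To make the quadratic growth concrete, here is a rough back-of-the-envelope sketch in Python. It assumes 2 bytes per score (16-bit floats) and counts a single attention head on a single sequence; the function name is illustrative, not part of any library.

```python
def attention_matrix_bytes(seq_len: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to store one seq_len x seq_len score matrix.

    Assumption: one head, one sequence, 16-bit scores by default.
    """
    return seq_len * seq_len * bytes_per_element


for n in (1024, 2048, 4096, 8192):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"{n:>5} tokens -> {mib:>4.0f} MiB")
```

Doubling the sequence length quadruples the size of the matrix, and a real model multiplies this figure by the number of heads and the batch size.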

Tri Dao, a PhD student at Stanford, realizes the bottleneck is not the number of math operations but the traffic between HBM (large but slow) and on-chip SRAM (small but fast). FlashAttention splits the computation into blocks small enough to fit in SRAM and combines the partial results with a running softmax, so the full attention matrix is never materialized.
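The block-by-block idea can be sketched in plain Python. This is a minimal CPU illustration, not the actual FlashAttention kernel: the real implementation is written in CUDA, tiles the queries as well as the keys and values, and manages SRAM explicitly. The function names and the `block_size` value are hypothetical choices for this example. The sketch only shows how a softmax can be computed one block at a time with a running maximum and a running sum, so that no full row of scores has to be stored.

```python
import math


def naive_attention(Q, K, V):
    """Reference version: builds each full row of scores before the softmax."""
    d = len(Q[0])
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, V)) / total
                    for j in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_size=2):
    """Visits K/V one block at a time, keeping only running statistics per query."""
    d = len(Q[0])
    dv = len(V[0])
    out = []
    for q in Q:
        running_max = -math.inf  # largest score seen so far
        running_sum = 0.0        # softmax denominator, relative to running_max
        acc = [0.0] * dv         # unnormalized output, relative to running_max
        for start in range(0, len(K), block_size):
            k_blk = K[start:start + block_size]
            v_blk = V[start:start + block_size]
            scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d)
                      for k in k_blk]
            new_max = max(running_max, max(scores))
            # Rescale everything accumulated under the previous maximum.
            scale = math.exp(running_max - new_max)
            weights = [math.exp(s - new_max) for s in scores]
            running_sum = running_sum * scale + sum(weights)
            acc = [a * scale + sum(w * v[j] for w, v in zip(weights, v_blk))
                   for j, a in enumerate(acc)]
            running_max = new_max
        out.append([a / running_sum for a in acc])
    return out


# Small made-up inputs: 3 queries, 5 keys/values, so the last block is partial.
Q = [[1.0, 0.0], [0.5, 0.5], [0.0, 2.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5], [2.0, 0.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]

reference = naive_attention(Q, K, V)
tiled = tiled_attention(Q, K, V, block_size=2)
matches = all(math.isclose(a, b, rel_tol=1e-9)
              for r, t in zip(reference, tiled) for a, b in zip(r, t))
print("tiled result matches naive result:", matches)
```

The tiled version holds at most `block_size` scores at a time, yet agrees with the reference up to floating-point rounding, which is why the technique counts as exact attention rather than an approximation.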

The result runs 2-4x faster and uses 10-20x less memory, and because it computes exact attention rather than an approximation, the mathematical output is unchanged. From that point on, FlashAttention becomes a building block of most major AI training frameworks, from PyTorch to JAX.

Companies

Stanford University

Tools

FlashAttention, CUDA, PyTorch

Tags

FlashAttention, Attention, Transformer, CUDA, HBM, IO-aware, Tri Dao, Stanford

Sources