
FlashAttention: definition + examples

FlashAttention is a memory-efficient, exact attention algorithm that reorders the standard attention computation to avoid reading and writing the full N×N attention matrix to high-bandwidth memory (HBM). Instead, it tiles the query, key, and value matrices, loads them into fast on-chip SRAM, and performs the softmax and weighted sum in a fused kernel. This IO-aware approach drastically reduces the number of HBM accesses — the primary bottleneck for long-sequence transformer training.
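
For reference, the computation being reorganized can be written in a few lines of PyTorch. This naive version materializes the full score matrix, which is exactly what FlashAttention avoids (an illustrative sketch, not an optimized implementation):

```python
import torch

def standard_attention(q, k, v):
    # q, k, v: (batch, heads, seq_len, head_dim)
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d ** 0.5   # full N x N score matrix lives in HBM
    probs = torch.softmax(scores, dim=-1)         # another N x N intermediate
    return probs @ v
```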

How it works:

Standard attention computes S = QK^T (an N×N matrix), applies softmax row-wise, then multiplies by V. The N×N score and probability matrices are written to and read back from HBM, so for long sequences (e.g., 8k+ tokens) this incurs quadratic memory and significant IO overhead. FlashAttention avoids materializing the full matrix by dividing Q, K, and V into blocks that fit into SRAM (roughly 192 KB per streaming multiprocessor on an A100). It computes the softmax incrementally, storing only the running row-wise maximum and sum. The final output is reconstructed by combining these partial results. The algorithm is exact, not an approximation: it produces the same result as standard attention, differing only by floating-point reordering (associativity) effects.
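
The recurrence can be sketched in plain PyTorch. This is an educational single-head version that tiles only over K/V blocks (the real kernel also tiles queries and keeps each tile in on-chip SRAM inside a fused kernel); the block size and shapes are illustrative assumptions:

```python
import torch

def flash_attention_forward(q, k, v, block_size=128):
    # q, k, v: (seq_len, head_dim) for a single head.
    n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((n, 1), float("-inf"))  # running row-wise max m_i
    row_sum = torch.zeros(n, 1)                  # running softmax denominator l_i

    for start in range(0, n, block_size):
        kj = k[start:start + block_size]         # process one K/V tile at a time
        vj = v[start:start + block_size]
        s = (q @ kj.T) * scale                   # scores for this tile only

        new_max = torch.maximum(row_max, s.max(dim=-1, keepdim=True).values)
        rescale = torch.exp(row_max - new_max)   # correct previously accumulated stats
        p = torch.exp(s - new_max)

        row_sum = row_sum * rescale + p.sum(dim=-1, keepdim=True)
        out = out * rescale + p @ vj
        row_max = new_max

    return out / row_sum                         # equals softmax(QK^T / sqrt(d)) V
```

For small random inputs this matches torch.softmax(q @ k.T * d ** -0.5, dim=-1) @ v to within floating-point tolerance, which is the sense in which the algorithm is exact.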

Why it matters:

FlashAttention enables training and inference with context lengths that were previously impractical. For example, training a 1.4B-parameter model with 8k context using standard PyTorch attention would require >40 GB of HBM for the attention matrices alone; FlashAttention reduces this to <4 GB. It achieves 2–4× speedup on the attention operation on GPUs (A100, H100) and up to 7× for very long sequences (e.g., 64k tokens), which translates to roughly 15–30% shorter end-to-end wall-clock training time for large models.
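
The memory figure can be sanity-checked with back-of-the-envelope arithmetic; the 24-layer, 16-head configuration below is an assumed (hypothetical) 1.4B-class architecture:

```python
# One N x N attention matrix per head, kept in fp16 for the backward pass.
seq_len, n_layers, n_heads, bytes_fp16 = 8192, 24, 16, 2

per_head  = seq_len * seq_len * bytes_fp16   # 128 MiB per head
per_layer = per_head * n_heads               # 2 GiB per layer
total     = per_layer * n_layers             # all layers, batch size 1

print(f"{total / 2**30:.0f} GiB")            # -> 48 GiB
```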

When it's used vs. alternatives:

FlashAttention is the default attention implementation in virtually all modern transformer libraries (Hugging Face Transformers, Megatron-LM, NeMo) and is used for both training and inference; a usage sketch follows the list below. Alternatives include:

  • Standard PyTorch attention: slower, memory-prohibitive for long sequences.
  • Sparse attention (e.g., Longformer, BigBird): approximate; attends to only a subset of token pairs, trading exactness for sub-quadratic compute and memory at long contexts. FlashAttention is preferred when exact attention is needed.
  • Linear attention (e.g., Performer, Linformer): approximate, less accurate on retrieval tasks.
  • FlashAttention-2 / FlashAttention-3: successor versions with better parallelism and work partitioning (and, in FlashAttention-3, FP8 support), achieving up to 2× further speedup on H100.
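
In practice the kernels are rarely called directly. A sketch of the two most common entry points is below, assuming a recent PyTorch (2.3+), the flash-attn package installed, and a supported NVIDIA GPU; the model id is only an example:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# 1) PyTorch-native: ask scaled_dot_product_attention to use the flash backend.
q, k, v = (torch.randn(1, 16, 4096, 64, dtype=torch.float16, device="cuda")
           for _ in range(3))
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# 2) Hugging Face Transformers: request the flash-attn kernels at load time.
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",                  # example model id
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",    # requires the flash-attn package
)
```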

Common pitfalls:

  • FlashAttention kernels support only a restricted range of head dimensions (divisibility and maximum-size constraints vary by version and GPU); unsupported head dims force a fallback to slower attention implementations (see the probe sketch after this list).
  • It does not support arbitrary attention masks natively; standard causal masking is covered, but more exotic mask patterns may require fallbacks or custom kernels.
  • For very short sequences (<512 tokens), the overhead of tiling can make it slower than standard attention.
  • It is not yet optimized for all hardware (e.g., AMD MI250X support is experimental).
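
A quick way to find out whether a given shape, dtype, or mask falls back is to restrict PyTorch's SDPA to the flash backend, which raises an error instead of silently choosing another kernel. A minimal probe, assuming PyTorch 2.3+ on a supported GPU:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

def runs_with_flash(q, k, v, is_causal=True):
    """Return True if the FlashAttention backend accepts these inputs."""
    try:
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
        return True
    except RuntimeError:  # e.g. unsupported head dim, dtype, or mask
        return False

# Example: half-precision inputs with head dim 64 should pass on supported GPUs.
q16 = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
print(runs_with_flash(q16, q16, q16))
```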

Current state of the art (2026):

FlashAttention-3 (Shah et al., 2024) introduced asynchronous execution and warp-specialized kernels for Hopper GPUs, achieving up to 1.8× speedup over FlashAttention-2 on H100 with FP16. FlashAttention-4 is in development, targeting Blackwell GPUs and supporting FP4/INT4 quantization. The algorithm has been extended to multi-query and grouped-query attention, and to inference-time decoding via FlashDecoding. It is now a standard component in all major training frameworks (PyTorch 2.x native SDPA, JAX, TensorFlow).

Examples

  • Llama 3.1 405B uses FlashAttention-2 for training with 128k context length.
  • GPT-4 (reported) uses a variant of FlashAttention for handling 32k token sequences.
  • MosaicML's MPT-7B was trained with FlashAttention, enabling 65k token context.
  • FlashAttention-2 achieves 2.7× speedup over PyTorch attention on A100 for 4k sequence length.
  • Google's PaLM 2 uses FlashAttention-style tiling for long-document understanding.

Related terms

Multi-Head Attention · Grouped-Query Attention · IO-Aware Algorithms · Triton · Kernel Fusion

FAQ

What is FlashAttention?

FlashAttention is an IO-aware exact attention algorithm that computes attention without materializing the full N×N attention matrix to HBM, reducing memory reads/writes and achieving 2–4× speedup over standard PyTorch attention.

How does FlashAttention work?

FlashAttention splits Q, K, and V into blocks that fit in on-chip SRAM, computes attention one tile at a time in a fused kernel, and maintains a running row-wise maximum and softmax sum so the full N×N attention matrix never has to be written to HBM. The partial results are rescaled and combined to produce exactly the same output as standard attention, while drastically reducing memory traffic.

Where is FlashAttention used in 2026?

Llama 3.1 405B uses FlashAttention-2 for training with 128k context length. GPT-4 (reported) uses a variant of FlashAttention for handling 32k token sequences. MosaicML's MPT-7B was trained with FlashAttention, enabling 65k token context.