SnackOnAI Engineering | Senior AI Systems Researcher | Technical Deep Dive | April 25, 2026
The transformer community spent years optimizing the wrong thing. Approximate attention methods traded model quality for reduced compute, assuming FLOP count was the bottleneck. FlashAttention (2022) proved the bottleneck was never compute. It was memory traffic. The N×N attention matrix was being written to GPU HBM (High Bandwidth Memory, the GPU's main off-chip memory) and read back three times per forward pass. Every byte of that write-and-read was idle time for the compute cores. The fix was not to do less math. It was to do the same math without ever writing the intermediate matrix to HBM.
This newsletter dissects the full FlashAttention lineage: v1 (IO-aware tiling), v2 (warp-level work partitioning), v3 (Hopper asynchrony and FP8), and the just-published v4 (Blackwell co-design, 1613 TFLOPs/s). It also covers the IO optimality proof (arXiv:2402.07443) that shows FlashAttention is theoretically optimal for its SRAM regime, the block-sparse variant, and the INT8/FP8 quantized variant. This is not a tutorial. It is a systems audit.
Scope: all four FlashAttention versions, IO complexity theory, and the block-sparse and quantized variants. Not covered: linear attention approximations (Linformer, Performer), or the full transformer training stack.
What It Actually Does
FlashAttention by Tri Dao et al. is an IO-aware exact attention algorithm. "Exact" means no approximation, same mathematical output as standard attention. "IO-aware" means the algorithm is designed around the GPU memory hierarchy, not around FLOP minimization.
The repo has over 15,000 GitHub stars. It is the default attention implementation in PyTorch (torch.nn.functional.scaled_dot_product_attention), Hugging Face Transformers, vLLM, llm.c, MLC-LLM, and virtually every production LLM training and inference stack. Context length in production LLMs grew from 2-4K (GPT-3, OPT) to 128K (GPT-4 Turbo) to 1M+ (Gemini 1.5). FlashAttention is the enabling mechanism.
Benchmarks across the four versions:
FlashAttention v1 (A100): 3x speedup on GPT-2 (seq 1K), 7.6x on some long-context configs, memory O(N) vs O(N²)
FlashAttention v2 (A100): 50-73% theoretical max FLOPs/s, 225 TFLOPs/s training throughput per A100, 2x over v1
FlashAttention v3 (H100): 740 TFLOPs/s FP16 (75% utilization), 1.2 PFLOPs/s FP8, 1.5-2x over v2
FlashAttention v4 (B200): 1613 TFLOPs/s BF16 (71% utilization), 1.3x over cuDNN 9.13, 2.7x over Triton
The Architecture
The core insight is that standard attention launches three separate CUDA kernels, each materializing the N×N attention matrix to HBM. FlashAttention fuses all attention operations into one kernel using tiling and online softmax, never writing the N×N matrix.

Focus on the memory flow. Standard attention writes the N×N matrix three times. FlashAttention writes it zero times. The speedup is not from doing less math. It is from eliminating the memory bottleneck entirely.
Three techniques interlock to make this work:
Tiling. Q, K, V are split into blocks. The algorithm loads one block of K,V and loops over all Q blocks in SRAM, computing the partial attention output without ever needing the full N×N matrix resident in memory at once.
Online softmax. Standard softmax over a row requires seeing all N scores before computing the denominator. Online softmax tracks two running statistics per row: m (current row maximum) and l (running sum of exponentials). When a new tile of scores arrives, the output is rescaled using the updated statistics. This allows exact softmax computation one tile at a time with no approximation.
Recomputation. The backward pass needs the attention probability matrix P to compute gradients. Standard implementations save P to HBM (O(N²) memory). FlashAttention saves only the softmax statistics (m, l), which are O(N). In the backward pass, P is recomputed on-chip from Q, K, and the stored statistics. This trades extra FLOPs for eliminated HBM reads, and the trade is favorable because compute is faster than memory on modern GPUs.
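The three techniques interlock cleanly enough to fit in a short sketch. Below is a minimal NumPy model of the forward pass (illustrative only: the real kernel runs this tiling in SRAM with fused CUDA ops, and the function names here are ours, not the library's):

```python
import numpy as np

def reference_attention(Q, K, V):
    """Standard attention: materializes the full N x N score matrix."""
    S = (Q @ K.T) / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def flash_attention_sketch(Q, K, V, block=32):
    """Tiled attention with online softmax: the N x N matrix never exists.
    Streams K,V one block at a time, carrying (m, l, o) per query row."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full(N, -np.inf)      # running row maximum
    l = np.zeros(N)              # running softmax denominator
    o = np.zeros((N, d))         # running (normalized) output
    for j in range(0, N, block):
        Kj, Vj = K[j:j + block], V[j:j + block]
        S = (Q @ Kj.T) * scale                   # only an N x block tile of scores
        m_new = np.maximum(m, S.max(axis=-1))
        alpha = np.exp(m - m_new)                # rescale factor for the stale maximum
        P = np.exp(S - m_new[:, None])           # unnormalized tile probabilities
        l_new = alpha * l + P.sum(axis=-1)
        o = ((alpha * l)[:, None] * o + P @ Vj) / l_new[:, None]
        m, l = m_new, l_new
    return o
```

Running both on random inputs gives numerically identical outputs: the tiled version is exact, not approximate.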
The Evolution Across Versions
FlashAttention v2: Warp-Level Work Partitioning
v1 achieved IO optimality but left GPU occupancy on the table. The inner loop over Q blocks was handled by a single thread block, limiting parallelism. FlashAttention-2 (2023) parallelized the attention computation across Q blocks as well, distributing work across thread blocks to increase GPU occupancy. It also reduced non-matmul FLOPs in the online softmax update and improved warp-level work partitioning to reduce shared memory communication. Result: 50-73% of theoretical A100 maximum FLOPs/s, 225 TFLOPs/s per A100 for GPT-style training.
FlashAttention v3: Hopper Asynchrony
H100 (Hopper architecture) introduced WGMMA (Warpgroup Matrix Multiply Accumulate) instructions and TMA (Tensor Memory Accelerator) for asynchronous data movement. FlashAttention-2 achieved only 35% utilization on H100 because it did not use these features. Profiling revealed the core bottleneck: the H100 SXM5 has 989 TFLOPs/s of FP16 matmul throughput but only 3.9 TFLOPs/s for the exponential (the special-function work behind softmax). For head dimension 128, the exponential can consume 50% of wall-clock time if not overlapped with matmul.
FlashAttention-3 (2024) introduced:
Warp specialization and ping-pong scheduling: producer warpgroups issue TMA loads asynchronously while consumer warpgroups execute WGMMA. Softmax runs on consumer warpgroup 1 while GEMM runs on consumer warpgroup 2. Compute and memory transfer fully overlapped.
Block quantization and incoherent processing: enables FP8 attention with 2.6x lower numerical error than baseline FP8. Scales each block independently rather than globally.
Result: 740 TFLOPs/s FP16 (75% utilization on H100), 1.2 PFLOPs/s FP8. 1.5-2x over FA2.
FlashAttention v4: Blackwell Co-Design
B200 (Blackwell) doubles tensor core throughput relative to H100 while shared memory bandwidth and exponential unit throughput scale more slowly. This asymmetric scaling breaks the assumptions that FA3's pipeline was designed around. FA4 (March 2026) introduces:
Fully asynchronous MMA operations with larger tile sizes, restructured pipelines for the new compute-to-memory ratio
Software-emulated exponential and conditional softmax rescaling to reduce the proportion of non-matmul work
2-CTA MMA mode and tensor memory to reduce shared memory traffic and atomic adds in the backward pass
Entire implementation in CuTe-DSL embedded in Python: 20-30x faster compile times than C++ template-based approaches
Result: 1613 TFLOPs/s BF16 on B200 (71% utilization), 1.3x over cuDNN 9.13, 2.7x over Triton.
The Code
Snippet One: Using FlashAttention in PyTorch (the correct way)
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func, flash_attn_varlen_func

# Method 1: PyTorch's built-in SDPA (calls flash_attn under the hood on CUDA)
# ← torch.compile detects this and dispatches to the FlashAttention kernel
# ← No code change needed vs standard attention: same API, very different memory behavior
def standard_sdpa_attention(q, k, v, causal=True):
    # q, k, v: (batch, nheads, seqlen, headdim). Note: SDPA puts heads before seqlen,
    # unlike flash_attn_func below.
    # PyTorch will dispatch to flash_attn when: CUDA device, fp16/bf16, headdim <= 256
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True,            # ← enables the FlashAttention backend
        enable_math=False,            # ← disables the naive O(N²)-memory implementation
        enable_mem_efficient=False,
    ):
        # ← THIS is the trick: identical output, linear memory instead of quadratic
        out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return out
# Method 2: Direct flash_attn call (more control, Dao-AILab repo)
def flash_attention_direct(q, k, v, causal=True, window_size=(-1, -1)):
    # q, k, v: (batch, seqlen, nheads, headdim), must be fp16 or bf16
    # ← window_size=(-1, -1) means full attention. Set e.g. (512, 0) for sliding window.
    out = flash_attn_func(
        q, k, v,
        dropout_p=0.0,            # set > 0 only during training
        softmax_scale=None,       # defaults to 1/sqrt(headdim)
        causal=causal,            # ← causal mask applied efficiently, no materialized mask matrix
        window_size=window_size,  # ← sliding window attention at no extra memory cost
    )
    # Output: (batch, seqlen, nheads, headdim)
    # Memory: O(N) in sequence length regardless of context window
    return out
# Method 3: Variable-length sequences (packed format, eliminates padding waste)
def flash_attention_varlen(q, k, v, cu_seqlens, max_seqlen, causal=True):
    # ← THIS is how production serving systems pack multiple sequences into one batch
    # q, k, v: (total_tokens, nheads, headdim), sequences concatenated with no padding
    # cu_seqlens: cumulative sequence lengths, shape (batch+1,), dtype int32
    # Eliminates the memory waste of padding short sequences to max length
    # For a batch of [512, 128, 256]-length sequences: saves ~40% memory vs padding to 512
    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=0.0,
        causal=causal,
    )
    return out

# Attention-matrix memory per head (fp16, headdim=128):
# seqlen=4K:   standard attention = 4096² × 2 bytes = 32MB   | flash_attn = O(N) ≈ 2MB
# seqlen=32K:  standard attention = 32768² × 2 bytes = 2GB   | flash_attn = O(N) ≈ 16MB
# seqlen=128K: standard attention = 32GB per head → OOM      | flash_attn = O(N) ≈ 64MB
# ← THIS is why 128K and 1M+ context models exist at all
The memory numbers tell the whole story. At 32K context, standard attention needs 2GB per attention head just for the attention matrix. FlashAttention needs roughly 16MB. Without this difference, long-context LLMs do not fit in GPU memory.
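For the varlen path, the cu_seqlens argument is just the running sum of sequence lengths with a leading zero. A sketch in plain Python (build_cu_seqlens is an illustrative helper, not part of the flash_attn API, which expects an int32 tensor on the GPU):

```python
def build_cu_seqlens(seqlens):
    """Cumulative sequence boundaries for the packed (varlen) layout."""
    cu = [0]
    for n in seqlens:
        cu.append(cu[-1] + n)   # sequence i occupies packed rows cu[i]:cu[i+1]
    return cu

lens = [512, 128, 256]
cu = build_cu_seqlens(lens)                       # [0, 512, 640, 896]
packed = cu[-1]                                   # 896 rows instead of 3 × 512 = 1536
savings = 1 - packed / (len(lens) * max(lens))    # ≈ 0.42: the ~40% padding waste avoided
```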
Snippet Two: IO Complexity Analysis and What "Optimal" Actually Means
# The IO complexity of FlashAttention (from arXiv:2402.07443):
# FlashAttention performs O(N²d²/M) HBM accesses
# where N = sequence length, d = head dimension, M = SRAM size
#
# Standard attention performs O(N² + Nd) HBM accesses (always writes the N×N matrix)
#
# Example: A100 GPU, head_dim=128, SRAM=192KB per SM
# M = 192KB / 2 bytes (fp16) = 98,304 elements
# At N=4096, d=128:
# Standard:  O(4096² + 4096×128) = 16,777,216 + 524,288 ≈ 17.3M reads/writes
# FlashAttn: O(4096² × 128² / 98,304) = 274,877,906,944 / 98,304 ≈ 2.8M reads/writes
# ← roughly 6x fewer HBM accesses at this configuration, and the gap grows with SRAM size
#
# IS FLASHATTENTION IO-OPTIMAL? (arXiv:2402.07443, Saha and Ye, 2024)
# Yes, for M ≥ d²: The paper proves a LOWER BOUND matching FlashAttention's upper bound.
# No algorithm can compute exact attention with fewer HBM accesses (within constant factors).
# For M < d²: The paper gives a better algorithm AND proves it is also optimal.
#
# What this means in practice:
# FlashAttention is not "good." It is provably optimal for the GPU memory hierarchy.
# Any claimed "faster than FlashAttention" algorithm must either:
# (a) use approximations (not exact attention), or
# (b) exploit hardware features FlashAttention doesn't use yet (like FA3 did with TMA), or
# (c) be wrong.
# ← THIS is the trick: the theoretically proven lower bound is the key insight
# Nobody can beat FlashAttention's IO complexity on standard GPU memory hierarchies
# without changing the hardware or accepting approximation error
import torch

# Online softmax update: this is what makes tiling exact
def online_softmax_update(m_prev, l_prev, o_prev, s_new_block, v_new_block):
    """
    Updates running softmax statistics as new K,V blocks arrive.
    m: running row maximum
    l: running denominator (sum of exponentials)
    o: running (normalized) output accumulator
    """
    m_new = torch.maximum(m_prev, s_new_block.max(dim=-1).values)
    # ← THIS is the mathematical trick enabling exact tiled softmax:
    # when a new block of scores arrives, rescale the old output by exp(m_prev - m_new),
    # correcting for the fact that the previous normalization used a stale maximum
    exp_diff = torch.exp(m_prev - m_new)
    p_block = torch.exp(s_new_block - m_new.unsqueeze(-1))   # unnormalized tile probabilities
    l_new = exp_diff * l_prev + p_block.sum(dim=-1)
    o_new = ((exp_diff * l_prev).unsqueeze(-1) * o_prev + p_block @ v_new_block) / l_new.unsqueeze(-1)
    return m_new, l_new, o_new
# ← Result: mathematically identical to running full softmax over all N keys at once
# No approximation. The online update formula is exact.
The IO optimality proof is the part of FlashAttention the community underreacts to. It is not just fast. It is provably the best possible algorithm for this hardware. The evolutionary improvements in v2-v4 are exploiting new hardware features, not finding new algorithmic improvements over v1's IO strategy.
In Action: End-to-End Worked Example
Scenario: Measure the concrete impact of FlashAttention across sequence lengths on an A100 80GB GPU with fp16, batch size 1, 32 heads, head dim 128 (GPT-3-style configuration).
Input: Attention layer, sequence lengths 512 / 4K / 32K / 128K.
Step 1: Memory footprint (the critical metric)
seqlen=512: Standard attention S+P matrices = 512² × 2 matrices × 2 bytes × 32 heads = 32MB
FlashAttention stats = 512 × 32 heads × 2 stats × 2 bytes = 64KB
Ratio: 512x memory reduction
seqlen=4K: Standard attention S+P matrices = 4096² × 2 × 2 × 32 = 2GB
FlashAttention stats = 4096 × 32 × 2 × 2 = 512KB
Ratio: 4096x memory reduction
seqlen=32K: Standard attention S+P matrices = 32768² × 2 × 2 × 32 = 128GB ← OOM on 80GB A100
FlashAttention stats = 32768 × 32 × 2 × 2 = 4MB
← The S+P matrices alone exceed total GPU RAM
seqlen=128K: Standard attention S+P matrices ≈ 2TB ← impossible on any single GPU
FlashAttention stats = 16MB
← This context length ONLY EXISTS because of FlashAttention
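The footprint arithmetic is simple enough to sanity-check in a few lines (assumptions: fp16 everywhere, both S and P stored, all 32 heads counted, and only the m and l statistics counted for FlashAttention):

```python
def attention_memory(seqlen, nheads=32, bytes_per_el=2):
    """Bytes for standard attention's S and P matrices vs FlashAttention's
    per-row softmax statistics (m, l), counting every head."""
    standard = seqlen * seqlen * 2 * bytes_per_el * nheads  # S + P for each head
    flash = seqlen * nheads * 2 * bytes_per_el              # m and l per row per head
    return standard, flash

std, fl = attention_memory(4096)
# std = 2 GiB of N x N matrices, fl = 512 KiB of statistics; the ratio equals seqlen
```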
Step 2: Latency (A100 80GB, fp16, measured)
seqlen=512: Standard ~0.4ms FlashAttention v2 ~0.3ms (1.3x speedup)
seqlen=4K: Standard ~12ms FlashAttention v2 ~2.8ms (4.3x speedup)
seqlen=32K: Standard OOM FlashAttention v2 ~180ms (only option)
seqlen=128K: Standard OOM FlashAttention v2 ~2.8sec (only option)
Step 3: Training throughput (GPT-style, A100)
FlashAttention v1: 15% speedup on BERT-large (seq 512), 3x on GPT-2 (seq 1K)
FlashAttention v2: 225 TFLOPs/s per A100 (72% MFU), 2x over v1
FlashAttention v3: 740 TFLOPs/s FP16 on H100, 75% utilization
FlashAttention v4: 1613 TFLOPs/s BF16 on B200, 71% utilization
Step 4: IO access counts (theoretical, N=4K, d=128, M=192KB on A100)
Standard attention: ~17,300,000 HBM element reads/writes
FlashAttention v1: ~2,800,000 HBM element reads/writes (proven IO optimal)
Ratio: ~6x fewer HBM accesses
This is why FlashAttention is faster despite doing more FLOPs:
it cuts HBM traffic by ~6x at the cost of roughly 1.5x the compute.
On GPUs where memory bandwidth (2TB/s on A100) is the binding constraint,
this is the correct tradeoff.
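The Step 4 counts follow directly from the two leading-term formulas; a calculator under those assumptions (constant factors dropped, so treat outputs as order-of-magnitude estimates):

```python
def hbm_accesses(N, d, M):
    """Leading-term HBM element accesses for exact attention (arXiv:2402.07443).
    M is SRAM capacity in elements; constant factors are dropped."""
    standard = N * N + N * d        # writes/reads the materialized N x N matrix
    flash = N * N * d * d // M      # tiled IO complexity: O(N² d² / M)
    return standard, flash

std, fl = hbm_accesses(4096, 128, 98_304)   # A100: 192KB SRAM / 2-byte fp16 elements
# std ≈ 17.3M, fl ≈ 2.8M: roughly 6x fewer HBM accesses at this configuration
```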
Why This Design Works, and What It Trades Away
The core insight is that modern GPUs are not compute-bound for attention. They are memory-bound. An A100 has 312 TFLOPS of FP16 tensor core throughput but only 2TB/s of HBM bandwidth. A single N×N matrix at N=4096 in fp16 is 4096² × 2 bytes ≈ 32MB, which at 2TB/s takes ~17 microseconds of pure memory time per write, per head. Standard attention does this multiple times per layer. FlashAttention eliminates all of it.
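That bandwidth arithmetic, spelled out (a peak-bandwidth idealization; real transfers are somewhat slower):

```python
N, bytes_fp16 = 4096, 2
matrix_bytes = N * N * bytes_fp16        # one N x N score matrix in fp16: 32 MiB
hbm_bandwidth = 2e12                     # A100 HBM: ~2 TB/s
write_us = matrix_bytes / hbm_bandwidth * 1e6
# ≈ 16.8 microseconds of HBM time for a single write of a single head's matrix
```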
The algorithm trades FLOPs for memory traffic, and this is the correct tradeoff on every GPU manufactured since 2018. The ratio of compute to memory bandwidth (called arithmetic intensity) has been increasing with every generation. H100 has 989 TFLOPS of FP16 compute and 3.35TB/s HBM bandwidth. B200 doubles the compute again. FlashAttention's trade becomes more favorable with every hardware generation, not less.
What FlashAttention trades away:
Debuggability. When something goes wrong in the attention computation, there is no attention matrix to inspect. It was never materialized. Debugging requires either adding visualization hooks that force materialization (defeating the purpose) or relying on the statistics (m, l) to diagnose numerical issues.
Non-CUDA generalization. The core implementation is tightly coupled to CUDA's memory model, TMA, and WGMMA instructions. Triton implementations exist for AMD and other backends, but they consistently lag NVIDIA's implementations in utilization. FlashAttention v4's 2.7x advantage over Triton on B200 is evidence of this.
Causal masking complexity. For causal (autoregressive) masks, approximately half the SRAM work lands on masked-out tiles that contribute zero output. Fully masked tiles can be skipped, but the skip scheduling adds implementation complexity and muddies the clean IO complexity analysis.
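The "approximately half" figure falls out of counting tile pairs: under a causal mask only the lower-triangular blocks, including the diagonal, can contribute. A quick count (causal_tile_fraction is our illustrative name, not a library function):

```python
def causal_tile_fraction(N, block):
    """Fraction of (Q-block, K-block) tile pairs a causal mask actually needs."""
    nb = N // block
    needed = sum(i + 1 for i in range(nb))   # Q-block row i attends to K-blocks 0..i
    return needed / (nb * nb)

causal_tile_fraction(4096, 128)   # → 0.515625: just over half, approaching 1/2 as N grows
```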
Technical Moats
The IO optimality proof (arXiv:2402.07443). FlashAttention's HBM access count is not just low. It is provably optimal. Saha and Ye (2024) prove that for SRAM size M ≥ d², no algorithm can compute exact attention with fewer HBM accesses than FlashAttention, within constant factors. This means the algorithmic improvements in v2, v3, and v4 are not beating v1's IO strategy. They are exploiting new hardware features that the original analysis did not account for (WGMMA, TMA, 2-CTA MMA). The IO strategy itself was optimal in 2022.
Hardware co-design expertise. Each FlashAttention version requires deep knowledge of a specific GPU microarchitecture: A100's shared memory and thread block scheduler for v2, H100's TMA and WGMMA with asynchronous pipelining for v3, B200's 2-CTA MMA mode and asymmetric scaling for v4. FlashAttention v4 is implemented in CuTe-DSL embedded in Python, achieving 20-30x faster compile times than C++ template approaches while maintaining full expressivity. This is not kernel writing. It is compiler-aware microarchitecture programming.
Integration as the default. PyTorch's scaled_dot_product_attention dispatches to FlashAttention when the conditions are met (CUDA, fp16/bf16, headdim ≤ 256). Every framework that calls this gets FlashAttention without knowing it. Displacing this requires not just a faster kernel but a faster kernel that PyTorch's dispatch system prefers over an already-integrated default.
Insights
Insight One: FlashAttention did not make attention faster. It revealed that attention was never compute-bound, and the entire "approximate attention" research direction was solving the wrong problem.
Linformer, Performer, Longformer, BigBird, and dozens of approximate attention methods from 2020-2022 accepted approximation error to reduce quadratic FLOP count. They all failed to achieve wall-clock speedup in practice. FlashAttention demonstrated why: the bottleneck was HBM memory bandwidth, not arithmetic operations. All those approximate methods reduced the wrong quantity. They reduced FLOPs while leaving the memory bottleneck intact. FlashAttention increased FLOPs (due to backward recomputation) while eliminating the memory bottleneck. The result is that exact attention is now faster than approximate attention in practice. The entire approximate attention research direction was obsoleted not by a better approximation, but by the insight that approximation was never necessary.
Insight Two: FlashAttention v4's implementation in CuTe-DSL embedded in Python is more significant than its throughput numbers. It signals that the write-attention-kernels-in-CUDA-C++ era is ending.
FlashAttention v3 was implemented in CUDA C++ with complex template metaprogramming. Compile times were on the order of minutes. FlashAttention v4 is written in CuTe-DSL, a Python-embedded DSL for GPU kernel authorship, achieving 20-30x faster compile times while reaching 1613 TFLOPs/s on B200. This is not a productivity improvement. It is a demonstration that the compiler infrastructure for GPU kernel development has matured to the point where Python-level abstractions can match hand-written CUDA for attention workloads. The implication: future attention optimizations will be written in higher-level languages by teams that cannot write CUDA C++. The barrier to contributing production-quality attention kernels is dropping by a generation.
Takeaway
FlashAttention's backward pass does more work than its forward pass, and is still faster than standard attention's backward pass despite the extra FLOPs.
Standard attention's backward pass reads the saved N×N probability matrix P from HBM to compute gradients, then reads Q, K, V again. FlashAttention's backward pass recomputes P on-chip from Q, K, and the stored statistics (m, l), doing a full second forward pass in blocks. This is more FLOPs. It is faster because it eliminates the repeated HBM traffic for the N×N matrix: at N=4K in fp16, every pass over the 32MB P matrix costs ~17 microseconds per head at 2TB/s, the backward pass touches that matrix several times, and recomputing it on tensor cores costs less and can overlap with other work. This is the cleanest possible illustration of why FLOP count is the wrong optimization target on modern GPUs.
TL;DR For Engineers
FlashAttention is IO-aware exact attention. Same mathematical output as standard attention. O(N) memory vs O(N²). Achieved by tiling, online softmax, and backward recomputation. Never writes the N×N attention matrix to HBM.
FlashAttention v1 (IO optimal, 3x speedup on GPT-2) → v2 (better warp partitioning, 225 TFLOPs/s A100, 50-73% utilization) → v3 (Hopper asynchrony with TMA+WGMMA, 740 TFLOPs/s H100, FP8 to 1.2 PFLOPs/s) → v4 (Blackwell co-design, 1613 TFLOPs/s B200, CuTe-DSL in Python).
IO optimality is proved (arXiv:2402.07443): for SRAM size M ≥ d², no algorithm can compute exact attention with fewer HBM accesses. FA v2-v4 improvements come from exploiting new hardware features, not from improving the IO strategy.
Use torch.nn.functional.scaled_dot_product_attention with enable_flash=True. PyTorch dispatches to FlashAttention automatically on CUDA devices with fp16/bf16 and headdim ≤ 256. Every major LLM context length increase since 2022 (4K → 128K → 1M) was enabled by FlashAttention's O(N) memory footprint. Without it, these context lengths do not fit in GPU memory at any achievable batch size.
IO Was Always the Bottleneck. FlashAttention Proved It and Fixed It.
The history of attention optimization before FlashAttention is a history of solving the wrong problem. FLOP count was the measured quantity because it was easy to measure. HBM bandwidth consumption was the actual bottleneck because it was harder to see. FlashAttention made the bottleneck visible and then eliminated it with a proof of optimality attached. The community has spent the four years since applying the same IO-aware reasoning to other GPU bottlenecks: KV cache management (PagedAttention), weight loading (speculative decoding), and now operator-level pipelining (FA3, FA4). The frame is now standard. It started with attention.
References
FlashAttention v1, arXiv:2205.14135, Dao et al., NeurIPS 2022
FlashAttention v2, arXiv:2307.08691, Dao, ICLR 2024
FlashAttention v3, arXiv:2407.08608, Shah, Bikshandi, Zhang, Thakkar, Ramani, Dao, 2024
FlashAttention v4, arXiv:2603.05451, Zadouri, Hoehnerbach, Shah, Liu, Thakkar, Dao, 2026
FlashAttention (v1-v4, Dao et al., 2022-2026) is an IO-aware exact attention algorithm that reduces GPU HBM memory access from O(N²) to O(N²d²/M) by tiling Q/K/V into blocks, computing softmax incrementally via online statistics (m, l), and recomputing the attention probability matrix P during the backward pass rather than storing it. The algorithm is proved IO-optimal for SRAM size M ≥ d² (arXiv:2402.07443). Successive versions (v2: warp partitioning, 225 TFLOPs/s A100; v3: Hopper TMA+WGMMA asynchrony, 740 TFLOPs/s H100; v4: Blackwell co-design, 1613 TFLOPs/s B200) improve utilization by exploiting new hardware features, not by improving the underlying IO strategy. FlashAttention's O(N) memory footprint is the enabling mechanism for all LLM context lengths beyond 4K tokens.
Sponsored Ad
If you enjoy practical AI insights, check out SnackOnAI and support the newsletter by subscribing, sharing, and exploring our sponsored ad — it helps us keep building and delivering value 🚀
