Workshop

FlashAttention-4: CuTeDSL-Powered Attention for Hopper & Blackwell GPUs

Install, benchmark, and integrate FlashAttention-4 into your inference and training pipelines
35 min · flash-attention · cuda · performance · transformers · inference

What's happening

The attention mechanism is the computational bottleneck of every transformer model. Standard self-attention scales quadratically in both time and memory with sequence length, which makes long-context inference and training prohibitively expensive on commodity hardware. FlashAttention, introduced by Tri Dao et al. in 2022, addressed this by restructuring the attention computation to be IO-aware — tiling the operation to exploit the GPU memory hierarchy (SRAM vs. HBM) rather than materializing the full N×N attention matrix.

FlashAttention-2 improved parallelism and work partitioning. FlashAttention-3 targeted Hopper GPUs (H100) specifically, adding FP8 forward pass support and leveraging the Tensor Memory Accelerator (TMA) asynchronous data movement features unique to the Hopper architecture.

Now, FlashAttention-4 arrives with a significant architectural shift: it is written entirely in CuTeDSL, NVIDIA's new domain-specific language for composing CUDA kernels at the CuTE (CUDA Templates) abstraction level. This is consequential for two reasons. First, CuTeDSL enables a single codebase to target both Hopper (H100) and Blackwell (B200) GPUs without maintaining separate kernel implementations. Second, the pip install flash-attn-4 installation path eliminates the notoriously painful compilation step that plagued earlier versions — a change that meaningfully lowers the barrier to adoption.

In this session, we will install the package, construct benchmarking scripts to measure latency and memory across varying sequence lengths, integrate FlashAttention-4 into a transformer training loop, and explore FP8 mixed-precision inference.

1

Understanding the Architecture: IO-Awareness and CuTeDSL

Before installing anything, it is worth understanding why FlashAttention produces mathematically identical results to standard attention while using dramatically less memory.

The Memory Hierarchy Problem

A GPU has two relevant tiers of memory: HBM (High Bandwidth Memory) — large but relatively slow — and SRAM (on-chip shared memory) — fast but extremely small. Standard attention computes Q·Kᵀ, writes the full N×N matrix to HBM, reads it back for softmax, writes the result to HBM again, then reads it for the final multiplication with V. Each of these read/write operations is a memory-bound bottleneck.

FlashAttention restructures this computation using a tiling strategy: it processes the attention computation in blocks that fit entirely in SRAM, never materializing the full attention matrix in HBM. The key mathematical insight is an online softmax algorithm that accumulates the softmax normalization statistics incrementally across tiles, producing numerically exact results.

What CuTeDSL Changes

Previous versions of FlashAttention were written in CUDA C++ with hand-tuned code paths for specific GPU architectures. CuTeDSL — a Python-embedded DSL built on top of NVIDIA's CuTE (CUDA Templates) library — allows kernel authors to express the same tiling, data movement, and MMA (Matrix Multiply-Accumulate) operations at a higher abstraction level while still generating architecture-specific code. The practical consequence: a single CuTeDSL kernel can be compiled for both Hopper's TMA units and Blackwell's extended tensor cores.

Ask your agent
Ask your agent to explain the memory access pattern difference between standard attention and FlashAttention, including a concrete numerical example of HBM reads/writes for a given sequence length.
Think about it
  • What are the dimensions of the intermediate matrices in standard attention (the Q·Kᵀ product, the softmax output)?
  • If the sequence length is 4096 and the head dimension is 128, how many bytes does the full attention matrix occupy in BF16?
  • How does tiling eliminate the need to store this matrix, and what must the algorithm track instead?
  • What is the difference in total HBM I/O between the two approaches?
What the agent gives back

The agent should produce a clear comparison showing that for sequence length N=4096 and head dimension d=128, standard attention writes an N×N = 16M-element attention matrix to HBM (32 MB in BF16), reads it back for softmax, then writes and reads again — totaling roughly 128 MB of HBM traffic for that matrix alone. FlashAttention, by contrast, processes in SRAM tiles (e.g., 128×128 blocks), never writes the attention matrix to HBM, and tracks only two running statistics per row (the softmax maximum and denominator). Total HBM traffic drops to O(N·d) rather than O(N²). The agent should note that the results are mathematically identical — this is not an approximation.
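The arithmetic in that comparison is easy to verify by hand. A minimal back-of-the-envelope sketch (pure Python, no GPU required):

```python
# HBM traffic estimate for the attention matrix in standard attention.
# BF16 = 2 bytes per element.
N, d, bytes_per_el = 4096, 128, 2

attn_matrix_bytes = N * N * bytes_per_el            # one N x N score matrix
print(f"attention matrix: {attn_matrix_bytes / 2**20:.0f} MiB")       # 32 MiB

# Standard attention: write scores, read for softmax, write probs, read for P@V.
traffic = 4 * attn_matrix_bytes
print(f"HBM traffic for that matrix alone: {traffic / 2**20:.0f} MiB")  # 128 MiB

# FlashAttention: the matrix never touches HBM; traffic is O(N*d) for Q, K, V, O.
qkvo_bytes = 4 * N * d * bytes_per_el
print(f"O(N*d) tensors (Q, K, V, O): {qkvo_bytes / 2**20:.1f} MiB")    # 4.0 MiB
```

The 32× gap between 128 MiB of matrix traffic and 4 MiB of O(N·d) traffic is, per head, exactly the headroom FlashAttention exploits.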

Tip
FlashAttention is exact attention — it computes the same mathematical result as the naive implementation (outputs can differ in the last few bits because the floating-point summation order changes, but no approximation is involved). The optimization is purely in memory access patterns. This distinguishes it from methods like sparse attention or linear attention, which trade accuracy for speed.
How does online softmax work across tiles?

The standard softmax requires knowing the maximum value across the entire row before computing any exponentials (for numerical stability). The online softmax algorithm processes tiles sequentially, maintaining a running maximum m and a running sum of exponentials l. When a new tile introduces a larger maximum, all previously accumulated exponentials are rescaled by exp(m_old − m_new). This rescaling is exact — no precision is lost. The key insight from Milakov & Gimelshein (2018) is that this online correction can be fused with the matrix multiplication itself, eliminating extra memory passes. FlashAttention extends this to the backward pass as well, recomputing the attention matrix from Q, K, V blocks rather than storing it — trading modest extra computation for dramatic memory savings.
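The rescaling trick is easy to verify in a few lines of NumPy. Here is a minimal sketch of the online softmax statistics over tiles (in the real kernel the output accumulator is rescaled the same way and fused with the matmul; this shows only the running max and sum):

```python
import numpy as np

def online_softmax(row, tile=4):
    """Softmax over `row` computed tile by tile, tracking only a running
    maximum m and a running sum of exponentials l."""
    m, l = -np.inf, 0.0
    for start in range(0, len(row), tile):
        chunk = row[start:start + tile]
        m_new = max(m, chunk.max())
        # Rescale previously accumulated exponentials to the new maximum.
        l = l * np.exp(m - m_new) + np.exp(chunk - m_new).sum()
        m = m_new
    return np.exp(row - m) / l  # final normalization with the global stats

x = np.random.default_rng(0).normal(size=17)
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax(x), ref)   # exact up to float rounding
```

Note that the rescale factor exp(m_old − m_new) is ≤ 1, so the correction never overflows — this is why the method is numerically stable even when processed tile by tile.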

2

Installation and Environment Verification

FlashAttention-4 introduces a much simpler installation path compared to its predecessors. Earlier versions required compiling CUDA kernels from source (often taking 30+ minutes and consuming significant RAM), but FlashAttention-4's CuTeDSL foundation enables distribution as prebuilt wheels.

Prerequisites

The requirements are straightforward: a Hopper or Blackwell GPU (H100, H800, B200, or B200A), CUDA toolkit 12.3 or later (12.8 recommended for best performance), and PyTorch 2.2+. The package is Linux-only at present.

The Installation Itself

The installation is a single pip command: pip install flash-attn-4. This stands in contrast to flash-attn (versions 1–2), which required --no-build-isolation and often the MAX_JOBS environment variable to prevent OOM during compilation.

Once installed, the module is imported as flash_attn.cute — note the .cute submodule, which distinguishes the CuTeDSL implementation from the legacy CUDA path.

Ask your agent
Ask your agent to generate a concise environment-verification script that confirms GPU compatibility, CUDA version, PyTorch version, and successful import of flash_attn.cute — printing a clear pass/fail summary.
Think about it
  • What specific GPU architectures (compute capabilities) correspond to Hopper and Blackwell?
  • How do you query the CUDA toolkit version from within Python, versus the CUDA runtime version PyTorch was compiled against?
  • What should the script report if the GPU is an A100 (Ampere) — is that a soft warning or a hard failure?
  • How can you verify that the `flash_attn_func` callable is actually accessible from the imported module?
What the agent gives back

The agent should produce a short diagnostic script (under 15 lines) that checks torch.cuda.get_device_capability() for sm_90 (Hopper) or sm_100 (Blackwell), verifies the CUDA toolkit version is ≥12.3, confirms PyTorch ≥2.2, attempts to import flash_attn.cute.flash_attn_func, and prints a summary table of pass/fail results. For unsupported GPUs, it should print a clear message explaining which architectures are supported rather than failing silently.
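The core of that diagnostic is a mapping from compute capability to a verdict. A minimal sketch of that logic (pure Python — on real hardware the `(major, minor)` tuple would come from `torch.cuda.get_device_capability()`; the sm numbers follow the lesson's Hopper/Blackwell assumptions):

```python
def classify_gpu(cc):
    """Map a (major, minor) compute capability to (architecture, supported).
    Hopper is sm_90; Blackwell data-center parts are sm_100 (major 10);
    Ampere (sm_80/sm_86) is explicitly unsupported by FlashAttention-4."""
    major, minor = cc
    if major == 9:
        return "Hopper", True
    if major == 10:
        return "Blackwell", True
    if major == 8:
        return "Ampere (unsupported — use FlashAttention-2)", False
    return f"sm_{major}{minor} (unsupported)", False

for cc in [(9, 0), (10, 0), (8, 0)]:
    name, ok = classify_gpu(cc)
    print(f"{cc} -> {name}: {'PASS' if ok else 'FAIL'}")
```

The full script would add the CUDA toolkit, PyTorch version, and `flash_attn.cute` import checks around this, printing one row per check.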

Warning
FlashAttention-4 does not support Ampere GPUs (A100, A10G). If you are running on Ampere hardware, you must use FlashAttention-2 (pip install flash-attn --no-build-isolation). The CuTeDSL kernels rely on Hopper-specific hardware features (TMA, wgmma instructions) that have no Ampere equivalent.
Tip
CUDA 12.8 is strongly recommended even though 12.3 is the minimum. The performance gap between 12.3 and 12.8 on Hopper can be substantial due to compiler improvements in ptxas that better schedule TMA operations.
Why did earlier versions require --no-build-isolation?

FlashAttention 1 and 2 ship as a source distribution that compiles CUDA kernels at install time using PyTorch's JIT compilation infrastructure. The --no-build-isolation flag tells pip to use the current Python environment rather than creating a fresh virtual environment for the build — which is necessary because the build process needs access to the already-installed PyTorch and CUDA headers. Without this flag, pip creates an isolated build environment that lacks these dependencies, causing compilation to fail. FlashAttention-4's switch to CuTeDSL enables ahead-of-time compilation into prebuilt wheels, eliminating this pain point entirely.

At this point, you should have a clear mental model of *why* FlashAttention is fast (IO-aware tiling that avoids materializing the N×N attention matrix in HBM) and *what changed* in version 4 (CuTeDSL enabling a single codebase for Hopper and Blackwell, plus pip-installable wheels). You should also have a diagnostic script ready to verify your environment.

Quick Check

You are benchmarking FlashAttention-4 and notice that at sequence length 512 with head dimension 64, FlashAttention-4 is actually *slower* than a standard PyTorch `F.scaled_dot_product_attention` call. What is the most likely explanation?
✗ Not quite. This is unlikely. The algorithm is correct at all sequence lengths. The issue is one of performance tradeoffs, not correctness.
✓ Correct! FlashAttention's advantage comes from avoiding HBM traffic for the attention matrix. When N is small enough that the matrix fits comfortably in SRAM (or its HBM traffic is negligible relative to kernel launch overhead), the tiling and online softmax bookkeeping add cost without corresponding savings. PyTorch's built-in SDPA may dispatch to a simpler cuBLAS-based path that has lower overhead at these sizes. FlashAttention's crossover point is typically around N=1024–2048, depending on head dimension and batch size.
✗ Not quite. While PyTorch's SDPA *can* dispatch to FlashAttention-2 as one of its backends, it can also dispatch to a memory-efficient attention implementation or a standard math path. Moreover, `flash-attn-4` is a separate package — PyTorch's built-in SDPA does not use it. The two are distinct code paths.
✗ Not quite. FP8 is an optional forward-pass feature, not the default. Using BF16 or FP16 does not involve any quantization. Moreover, FP8 is designed to be *faster* (not slower) due to higher tensor core throughput — any slowdown would come from quantization overhead, not rounding errors themselves.
3

Benchmarking: Latency and Memory Across Sequence Lengths

The central claim of FlashAttention is that it provides sub-quadratic memory scaling and significant speedups at long sequence lengths. Rather than accepting these claims at face value, we will construct a systematic benchmark.

What to Measure

Two metrics matter: wall-clock latency (measured via CUDA events, not Python time.time(), to avoid CPU-GPU synchronization artifacts) and peak GPU memory (measured via torch.cuda.max_memory_allocated()). We want to sweep across:

  • Sequence lengths: 512, 1024, 2048, 4096, 8192, 16384
  • Batch sizes: 1, 4, 16
  • Head dimension: 128 (the standard for most modern LLMs)
  • Number of heads: 32

The comparison should include three implementations: standard PyTorch F.scaled_dot_product_attention, FlashAttention-2 (if installed), and FlashAttention-4.

Warmup and Statistical Rigor

GPU benchmarking requires care. The first kernel launch incurs JIT compilation and initialization overhead. A proper benchmark discards several warmup iterations, then averages across many timed iterations, reporting both mean and standard deviation.

Ask your agent
Ask your agent to build a benchmarking harness that compares FlashAttention-4 against PyTorch's built-in SDPA across the parameter sweep described above, measuring latency (via CUDA events) and peak memory, and outputting results as a formatted table.
Think about it
  • How should CUDA events be used to time GPU operations accurately? Why is `torch.cuda.synchronize()` important before recording events?
  • What is the correct way to reset peak memory statistics between benchmark runs to get isolated measurements?
  • How many warmup iterations and timed iterations provide a reasonable tradeoff between accuracy and total benchmark runtime?
  • Should the benchmark use `torch.no_grad()` for forward-only measurement? What changes if we also want to benchmark the backward pass?
What the agent gives back

The agent should produce a benchmark script that: (1) creates random Q, K, V tensors for each configuration in BF16 on CUDA, (2) runs 10 warmup iterations followed by 50 timed iterations using CUDA event pairs for each implementation, (3) calls torch.cuda.reset_peak_memory_stats() before each configuration, (4) collects latency mean/std and peak memory, (5) prints a table with columns for sequence length, batch size, implementation, mean latency (ms), std latency, and peak memory (MB). The key structural insight is that CUDA events must bracket only the kernel execution, with a torch.cuda.synchronize() after the final event to ensure all operations complete before reading elapsed time.
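The bookkeeping half of that harness — the configuration sweep and the warmup/statistics handling — can be sketched without a GPU. A minimal sketch, with the hardware-dependent CUDA-event calls indicated in comments:

```python
import itertools
import statistics

def summarize(latencies_ms, warmup=10):
    """Discard warmup iterations; report mean/std of the timed remainder."""
    timed = latencies_ms[warmup:]
    return statistics.mean(timed), statistics.stdev(timed)

def sweep_configs():
    seq_lens = [512, 1024, 2048, 4096, 8192, 16384]
    batch_sizes = [1, 4, 16]
    impls = ["sdpa", "flash-attn-4"]
    return list(itertools.product(seq_lens, batch_sizes, impls))

# On real hardware, each latency sample comes from a CUDA event pair:
#   start.record(); fn(); end.record(); torch.cuda.synchronize()
#   latencies_ms.append(start.elapsed_time(end))
# and torch.cuda.reset_peak_memory_stats() runs before each configuration.
fake = [5.0] * 10 + [1.0, 1.1, 0.9, 1.0, 1.0]  # warmup spikes, then steady state
mean, std = summarize(fake)
print(f"{len(sweep_configs())} configs; example: {mean:.3f} ms ± {std:.3f}")
```

Separating the bookkeeping from the timed kernel call also makes it trivial to swap implementations while keeping the measurement code identical.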

Tip
When prompting your agent for benchmarking code, explicitly request CUDA event timing rather than wall-clock timing. Many agents default to time.time() or time.perf_counter(), which include CPU-GPU synchronization delays and produce misleading results for GPU-bound operations.
Expected performance characteristics

At sequence length 512, FlashAttention-4 and SDPA should perform similarly (within ~10%), as the attention matrix is small enough that HBM traffic is not the bottleneck. At 2048, FlashAttention-4 should show a 1.5–2× speedup. At 8192 and above, the speedup grows to 2–4× and the memory advantage becomes dramatic: standard attention requires storing an 8192×8192 matrix (128 MB in BF16 per head), while FlashAttention's memory usage scales linearly. At 16384, standard attention will likely OOM on a single GPU even at batch size 1 with 32 heads, while FlashAttention-4 handles it comfortably. These are the numbers that should motivate the integration work in the next steps.

4

Integration: Drop-In Replacement in a Transformer Training Loop

With performance characteristics established, the next task is integrating FlashAttention-4 into an actual model. The API is designed as a drop-in replacement: the function signature accepts Q, K, V tensors and returns the attention output, with the same semantics as standard attention.

The Interface

The core function is:

```python
from flash_attn.cute import flash_attn_func

out = flash_attn_func(q, k, v, causal=True)
```

The key detail is tensor layout. FlashAttention-4 expects Q, K, V in the shape (batch, seqlen, nheads, headdim) — note that nheads and seqlen are in a different order than some libraries expect (batch, nheads, seqlen, headdim). A transpose or einops rearrangement may be necessary.
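The conversion itself is a single axis swap. A NumPy sketch of the shape bookkeeping (small dimensions for illustration):

```python
import numpy as np

batch, nheads, seqlen, headdim = 2, 4, 8, 16

# PyTorch convention: (batch, nheads, seqlen, headdim)
q_torch_layout = np.zeros((batch, nheads, seqlen, headdim), dtype=np.float16)

# FlashAttention convention: (batch, seqlen, nheads, headdim)
q_flash_layout = q_torch_layout.transpose(0, 2, 1, 3)  # swap axes 1 and 2
assert q_flash_layout.shape == (batch, seqlen, nheads, headdim)

# The transpose is a view (stride change, no data copy) — note the strides:
print(q_flash_layout.strides)
# The torch equivalent is q.transpose(1, 2); kernels that require physically
# contiguous sequence positions additionally need .contiguous().
```

Because the transpose changes only metadata, getting it wrong produces no error — just silently wrong attention outputs, which is exactly the failure mode the warning below describes.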

Causal vs. Non-Causal

The causal=True flag applies a causal mask (lower-triangular), which is essential for autoregressive language models. For bidirectional models (e.g., BERT-style encoders), set causal=False. The causal variant is slightly faster because it skips computation for the upper-triangular portion of the attention matrix.

Gradient Computation

FlashAttention-4 supports both forward and backward passes in BF16/FP16. The backward pass uses the recomputation strategy: rather than storing the attention matrix from the forward pass (which would defeat the memory savings), it recomputes it from Q, K, V during the backward pass. This trades a modest amount of extra computation (~25% more FLOPs in the backward pass) for dramatic memory savings.

Ask your agent
Ask your agent to modify a standard multi-head attention module to use FlashAttention-4, handling the tensor layout conversion and supporting both causal and non-causal modes.
Think about it
  • What is the expected input shape for `flash_attn_func`, and how does it differ from PyTorch's standard `(batch, nheads, seqlen, headdim)` convention?
  • Should the conversion logic live inside the attention module, or should the model's architecture be refactored to use FlashAttention's native layout throughout?
  • How should the module handle the case where FlashAttention-4 is not installed — should it fall back gracefully to standard attention?
  • What happens to attention dropout during training — does `flash_attn_func` support a dropout parameter?
What the agent gives back

The agent should produce a FlashMultiHeadAttention module (a PyTorch nn.Module) that: (1) accepts Q, K, V in the standard (batch, nheads, seqlen, headdim) format, (2) transposes them to (batch, seqlen, nheads, headdim) for FlashAttention-4, (3) calls flash_attn_func with the causal flag, (4) transposes the output back to the caller's expected format, and (5) includes a fallback to F.scaled_dot_product_attention if flash_attn.cute is not importable. The module should be no more than 20 lines. The agent should note that FlashAttention supports a dropout_p parameter for training-time attention dropout.
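A minimal sketch of that wrapper, written as a function rather than a full module. It assumes the `flash_attn.cute` import path described in this lesson, and falls back to PyTorch's SDPA when FlashAttention-4 is not installed (e.g., on non-Hopper hardware):

```python
import torch
import torch.nn.functional as F

try:
    # Import path as described in this lesson; an assumption if your
    # installed version exposes a different module layout.
    from flash_attn.cute import flash_attn_func
    HAS_FA4 = True
except ImportError:
    HAS_FA4 = False

def attention(q, k, v, causal=True):
    """q, k, v: (batch, nheads, seqlen, headdim) — the PyTorch convention."""
    if HAS_FA4:
        # FlashAttention expects (batch, seqlen, nheads, headdim): swap 1 and 2.
        out = flash_attn_func(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            causal=causal,
        )
        return out.transpose(1, 2)
    # Fallback: PyTorch's built-in scaled dot-product attention.
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)

q = k = v = torch.randn(1, 4, 16, 8)
print(attention(q, k, v).shape)   # torch.Size([1, 4, 16, 8])
```

Wrapping the layout conversion and the fallback in one place keeps the rest of the model oblivious to which backend is in use.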

Warning
The tensor layout difference between (batch, nheads, seqlen, headdim) and (batch, seqlen, nheads, headdim) is the single most common source of silent correctness bugs when integrating FlashAttention. The operation will run with the wrong layout — tensors are just blocks of memory — but produce incorrect attention outputs. Always verify with a small test case against the reference implementation.
Why does FlashAttention use a different tensor layout?

The (batch, seqlen, nheads, headdim) layout is chosen because it makes the data for each sequence position contiguous in memory. When tiling across the sequence dimension (which is what FlashAttention does), contiguous access patterns maximize memory bandwidth utilization. In the (batch, nheads, seqlen, headdim) layout, accessing a tile of consecutive sequence positions for a single head requires strided memory access, which wastes bandwidth. This is a concrete example of how the algorithm's tiling strategy dictates the optimal data layout — and why systems-level understanding of memory access patterns matters for high-performance ML.

You should now have a working FlashAttention-4 integration inside a multi-head attention module, complete with layout conversion and a fallback path. The benchmarking results from Step 3 should show clear latency and memory advantages at sequence lengths ≥2048.
5

FP8 Forward Pass for Mixed-Precision Inference

FlashAttention-3 introduced FP8 forward pass support on Hopper GPUs, and FlashAttention-4 extends this to Blackwell. FP8 (8-bit floating point) is significant for inference scenarios because it doubles the effective throughput of tensor cores compared to FP16/BF16: each tensor core cycle processes twice as many elements.

FP8 Formats: E4M3 vs. E5M2

Two FP8 formats exist. E4M3 has 4 exponent bits and 3 mantissa bits, providing higher precision but a narrower dynamic range. E5M2 has 5 exponent bits and 2 mantissa bits, offering wider range but lower precision. For attention, E4M3 is generally preferred because the softmax output is bounded between 0 and 1, requiring precision rather than range.
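The range difference follows directly from the bit layout. A sketch that computes the largest finite value of each format (assuming the "fn" variant of E4M3 used by PyTorch's torch.float8_e4m3fn, which has no infinities and reserves only the all-ones encoding for NaN):

```python
def fp8_max(exp_bits, man_bits, fn_variant=False):
    """Largest finite value for a float with the given exponent/mantissa bits."""
    bias = 2 ** (exp_bits - 1) - 1
    if fn_variant:
        # e4m3fn: the top exponent is used for values; only all-ones mantissa
        # at the top exponent encodes NaN, so the max mantissa is 2 - 2^(1-m).
        top_exp = (2 ** exp_bits - 1) - bias
        top_mantissa = 2 - 2 ** (1 - man_bits)
    else:
        # IEEE-style (e5m2): the top exponent is reserved for inf/NaN.
        top_exp = (2 ** exp_bits - 2) - bias
        top_mantissa = 2 - 2 ** (-man_bits)
    return top_mantissa * 2 ** top_exp

print(fp8_max(4, 3, fn_variant=True))   # 448.0   (torch.float8_e4m3fn)
print(fp8_max(5, 2))                    # 57344.0 (torch.float8_e5m2)
```

E5M2 reaches 57344 but steps in increments of 2⁻² of each binade, while E4M3 tops out at 448 with 2⁻³ granularity — precisely the precision-versus-range tradeoff described above.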

When to Use FP8 Attention

FP8 forward is appropriate for inference and the forward pass of training with FP8-aware training recipes (e.g., combined with loss scaling). It is not appropriate as a standalone change in a BF16 training pipeline — the reduced precision can destabilize training if the rest of the pipeline does not account for it.

The backward pass remains in BF16/FP16 even when the forward pass uses FP8, because gradient computation is more sensitive to precision.

Ask your agent
Ask your agent to create a comparison script that runs FlashAttention-4 forward passes in both BF16 and FP8 (E4M3), measuring throughput (TFLOPS) and maximum absolute error relative to a full-precision reference.
Think about it
  • How do you create FP8 E4M3 tensors in PyTorch? What `dtype` corresponds to this format?
  • What is the appropriate reference for measuring numerical error — FP32 standard attention, or BF16 FlashAttention?
  • How do you compute TFLOPS for attention? What is the FLOP count for a forward attention pass as a function of N, d, and batch size?
  • Should the Q, K, V tensors be cast to FP8 *before* calling `flash_attn_func`, or does the function handle internal casting?
What the agent gives back

The agent should produce a script that: (1) generates reference Q, K, V in FP32 and computes reference attention output, (2) runs FlashAttention-4 in BF16 and measures latency, (3) runs FlashAttention-4 in FP8 E4M3 (torch.float8_e4m3fn) and measures latency, (4) computes max absolute error of both against the FP32 reference, (5) reports throughput in TFLOPS using the formula 4 · batch · nheads · seqlen² · headdim / (latency_seconds · 10¹²). Expected results: FP8 should show roughly 1.5–2× throughput improvement over BF16, with max absolute error in the range of 1e-2 to 1e-3 (compared to ~1e-3 to 1e-4 for BF16).

Warning
FP8 attention is a forward-only feature. Do not attempt to call .backward() on FP8 attention outputs — the backward pass requires BF16 or FP16 precision. For training, use FP8 forward within a mixed-precision recipe that handles gradient scaling appropriately.
Tip
The FLOP count for attention forward is approximately 4·N²·d per head (two matrix multiplications involving N×d and d×N operands, each costing 2·N²·d FLOPs). This is the standard formula for computing TFLOPS from measured latency.
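That formula is compact enough to package as a helper for the comparison script. A sketch (the causal halving is an approximation, since only roughly half the score matrix is computed):

```python
def attn_forward_flops(batch, nheads, seqlen, headdim, causal=False):
    """FLOPs for an attention forward pass: two matmuls (Q@K^T and P@V),
    each 2 * N^2 * d multiply-adds -> 4 * N^2 * d per head."""
    flops = 4 * batch * nheads * seqlen ** 2 * headdim
    return flops // 2 if causal else flops

def tflops(flops, latency_ms):
    """Achieved throughput in TFLOPS given a measured latency."""
    return flops / (latency_ms * 1e-3) / 1e12

f = attn_forward_flops(batch=8, nheads=32, seqlen=4096, headdim=128)
print(f"{f / 1e12:.1f} TFLOP")                     # 2.2 TFLOP
print(f"{tflops(f, latency_ms=5.0):.0f} TFLOPS")   # 440 TFLOPS at 5 ms
```

Comparing the achieved TFLOPS against the GPU's datasheet peak gives the kernel's utilization, which is the most portable way to judge whether a given configuration is compute- or memory-bound.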
FP8 on Blackwell vs. Hopper

Blackwell GPUs (B200) offer roughly 2× the FP8 tensor core throughput of Hopper (H100) — approximately 4.5 PFLOPS vs. 2 PFLOPS for dense FP8 operations. This means the speedup from using FP8 attention on Blackwell is even more pronounced than on Hopper. However, the same precision considerations apply: the attention softmax distribution must remain representable in E4M3's dynamic range. For very long sequences where the attention distribution becomes extremely peaked (high-confidence attention patterns), the precision loss in FP8 can be more noticeable. Monitoring the max absolute error at your target sequence length is essential before deploying FP8 attention in production.

Your Turn

Build an end-to-end throughput benchmark that measures tokens-per-second for a full GPT-2–scale transformer forward pass (12 layers, 12 heads, head dimension 64, vocabulary size 50257) using FlashAttention-4 versus PyTorch SDPA, across sequence lengths 1024, 2048, and 4096.
Individual attention kernel benchmarks (Step 3) tell only part of the story. In a real transformer, attention competes for GPU resources with feedforward layers, layer norms, and embedding lookups. The end-to-end throughput measurement reveals whether FlashAttention-4's kernel-level improvements translate into meaningful model-level speedups — or whether other layers become the new bottleneck.
Think about it
  • What does 'tokens per second' mean in this context — is it `batch_size × seqlen / latency`, or something else?
  • Should the model use the standard PyTorch `nn.TransformerEncoder`, or a custom stack? What are the tradeoffs?
  • How do you isolate the attention implementation as the only variable while keeping everything else constant?
  • What batch size should you use — the largest that fits in memory, or a fixed value across all configurations?
One way you could prompt it
Build a PyTorch benchmarking script that measures end-to-end forward-pass throughput (tokens/second) for a GPT-2-sized transformer (12 layers, 12 attention heads, head dim 64, hidden dim 768, vocab 50257). Compare two configurations: one using PyTorch's built-in F.scaled_dot_product_attention, and one using flash_attn.cute.flash_attn_func with causal=True. Sweep sequence lengths [1024, 2048, 4096] with batch size 8. Use CUDA event timing with 5 warmup and 30 timed iterations. The model should be identical in both cases except for the attention function. Report a table with columns: seq_len, implementation, mean_latency_ms, tokens_per_second. Use BF16 autocast and torch.no_grad().
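The throughput metric from the first question above is worth pinning down before writing the harness. A one-function sketch of the bookkeeping (the 40 ms latency is an illustrative number, not a measurement):

```python
def tokens_per_second(batch_size, seqlen, latency_ms):
    """Throughput for one forward pass that processes batch_size sequences
    of length seqlen in latency_ms milliseconds."""
    return batch_size * seqlen / (latency_ms / 1000.0)

# e.g., batch 8, seqlen 2048, 40 ms per forward pass:
print(f"{tokens_per_second(8, 2048, 40.0):,.0f} tokens/s")  # 409,600 tokens/s
```

Keeping batch size fixed across configurations (rather than maximizing it per configuration) makes these numbers directly comparable across sequence lengths.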

Quick Check

You are deploying FlashAttention-4 in a production inference service for a 7B-parameter language model. The service handles both short prompts (128–512 tokens) and long documents (8K–32K tokens). Which deployment strategy is most appropriate?
✗ Not quite. As discussed in the earlier decision point, FlashAttention-4 incurs overhead at short sequence lengths due to its tiling machinery. For 128-token prompts, standard attention or PyTorch's SDPA (which can dispatch to an optimized short-sequence path) may be faster.
✓ Correct! This is the pragmatic choice. By profiling both implementations on your specific hardware and model to find the crossover point, you get optimal performance across the full range of input lengths. The crossover typically falls between 512 and 2048 tokens depending on head dimension and batch size. PyTorch's own SDPA uses a similar dispatch strategy internally.
✗ Not quite. FP8 introduces measurable precision loss (max absolute error ~1e-2). For a production language model, this can degrade output quality — particularly for tasks requiring precise reasoning or numerical answers. FP8 attention should be validated against quality benchmarks for your specific use case before deployment, not applied as a blanket optimization.
✗ Not quite. If you have Hopper or Blackwell GPUs (which the question implies), FlashAttention-4 provides better performance than FlashAttention-2 due to architecture-specific optimizations. FlashAttention-2 is the right choice only if you need Ampere support.

Recap

In this session, we examined FlashAttention-4 from its theoretical foundations through practical deployment considerations. We began with the IO-awareness principle — the insight that restructuring attention to work within the GPU's SRAM, rather than repeatedly reading and writing the full attention matrix through HBM, yields dramatic performance gains without any mathematical approximation.

We then explored the architectural shift to CuTeDSL, which provides a unified kernel codebase targeting both Hopper and Blackwell GPUs while enabling simpler pip-based installation. We built benchmarking infrastructure using proper GPU timing methodology (CUDA events, warmup iterations, memory reset), integrated FlashAttention-4 into a transformer module with appropriate tensor layout handling, and examined the FP8 forward pass as an inference optimization with well-defined precision tradeoffs.

The overarching lesson extends beyond this specific library: understanding the memory hierarchy of your compute hardware is often more important than algorithmic cleverness. FlashAttention performs exactly the same mathematical operations as standard attention — it simply moves data more efficiently.

Where to go next

  • Explore FlashAttention-4's support for grouped-query attention (GQA) and multi-query attention (MQA), used in Llama-2, Mistral, and other modern architectures — in the flash-attn API, grouping is expressed by passing K and V tensors with fewer heads than Q, and flash_attn_func infers the head grouping from the tensor shapes.
  • Investigate variable-length sequence batching (the flash_attn_varlen_func variant), which avoids padding waste when processing batches of sequences with different lengths — essential for production inference servers.
  • Profile FlashAttention-4's backward pass performance and memory savings in a full training run, comparing total training throughput (samples/second) against PyTorch's native SDPA on your target model scale.
