Project 2: Implement Flash Attention (Memory-Efficient Attention)

Project Metadata

Attribute | Value
----------|------
Difficulty | Level 5: Master
Time Estimate | 3-4 weeks
Prerequisites | Project 1, understanding of GPU architecture, PyTorch proficiency
Main Programming Language | Python (with CUDA optional)
Alternative Languages | CUDA C++, Triton
Primary Tool | PyTorch, CUDA (optional), Triton
Main Book | "Programming Massively Parallel Processors" by Hwu, Kirk, & Hajj
Knowledge Area | GPU Programming / Algorithmic Optimization

Learning Objectives

By completing this project, you will be able to:

  1. Explain why attention is memory-bound, not compute-bound, on modern GPUs
  2. Implement the tiling algorithm that enables O(N) memory usage instead of O(N²)
  3. Implement online softmax for incremental computation without full materialization
  4. Profile and benchmark attention implementations to verify memory and speed improvements
  5. Understand GPU memory hierarchy (HBM vs SRAM) and its impact on algorithm design
  6. Explain kernel fusion and why it dramatically improves performance

Theoretical Foundation

The Core Question You're Answering

"Why is the attention mechanism in transformers fundamentally limited by memory rather than computation, and how can we redesign the algorithm to optimize for the actual hardware memory hierarchy?"

This project forces you to confront:

  • Why does O(N²) memory kill long sequences when we have teraflops of compute?
  • What's the difference between compute-bound and memory-bound operations?
  • How does GPU memory hierarchy (registers → SRAM → HBM → CPU RAM) determine performance?
  • Why does fusing operations provide massive speedups?
  • How can an algorithm with MORE FLOPs (Flash Attention) be faster than standard attention?

Most people think GPUs are all about parallel compute. After this project, you'll understand that memory bandwidth is the real bottleneck in modern AI workloads.

Concepts You Must Understand First

1. GPU Memory Hierarchy

Modern GPUs have a dramatic hierarchy of memory:

Memory Type | Capacity | Bandwidth | Latency
------------|----------|-----------|--------
Registers   | ~256 KB  | Fastest   | ~1 cycle
L1/SRAM     | ~20 MB   | ~19 TB/s  | ~10 cycles
L2 Cache    | ~40 MB   | ~4 TB/s   | ~30 cycles
HBM (VRAM)  | 40-80 GB | ~1.6 TB/s | ~200 cycles
CPU RAM     | 64+ GB   | ~50 GB/s  | ~1000+ cycles

Key insight: SRAM is 12x faster than HBM, but 2000x smaller. Algorithm design must exploit this.

Book Reference: "Programming Massively Parallel Processors" by Hwu, Kirk, & Hajj - Ch. 5

2. Memory-Bound vs Compute-Bound

Memory-bound: Kernel is limited by data transfer speed (most DL operations)

  • Waiting for data to arrive from HBM
  • GPU cores idle while memory transfers happen

Compute-bound: Kernel is limited by arithmetic throughput (rare: only large GEMM)

  • All data in fast memory, GPU cores fully utilized

Arithmetic Intensity = FLOPs / bytes transferred

For standard attention (sequence length N, head dimension d):

  • Load Q, K: 2Nd elements
  • Compute QK^T: 2N²d FLOPs
  • Load/store the attention matrix: 2N² elements

Arithmetic intensity ≈ 2N²d / (2Nd + 2N²) ≈ d FLOPs per element for large N

With d = 64, we perform only about 64 FLOPs per element moved to or from HBM (roughly 32 FLOPs per byte in fp16). This is very low: attention is severely memory-bound.
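
To make these numbers concrete, here is a minimal back-of-the-envelope sketch in plain Python (the values are illustrative, not a benchmark):

# Rough arithmetic-intensity estimate for standard attention (fp16, illustrative)
N, d = 4096, 64                               # sequence length, head dimension
bytes_per_el = 2                              # fp16

flops = 2 * N * N * d                         # the QK^T matmul
elements_moved = 2 * N * d + 2 * N * N        # load Q, K; load/store the N x N score matrix
bytes_moved = elements_moved * bytes_per_el

print(f"FLOPs:          {flops:.2e}")
print(f"Bytes moved:    {bytes_moved:.2e}")
print(f"FLOPs per byte: {flops / bytes_moved:.1f}")   # about 31.5 for d = 64 in fp16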

Book Reference: "Computer Architecture: A Quantitative Approach" by Hennessy & Patterson - Ch. 4

3. Tiling (Blocking) Algorithms

Tiling decomposes large computations into smaller "tiles" that fit in fast memory.

For attention:

  1. Divide Q into blocks of size B
  2. Divide K, V into blocks of size B
  3. For each Q block, iterate through all K, V blocks
  4. Compute partial attention, accumulate results
  5. Only final output written to HBM

Memory reduction:

  • Standard: O(N²) for the full attention matrix
  • Tiled: O(B²) per tile, never materialized in HBM
  • Total: O(Nd) for inputs/outputs only
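
As a quick illustration (a sketch with example numbers, not part of the original spec), compare the full score matrix with a single tile:

# Full attention matrix vs. one tile (fp16, per batch and head)
N, B, bytes_per_el = 8192, 256, 2

full_matrix_mib = N * N * bytes_per_el / 2**20   # the O(N^2) matrix: ~128 MiB per head
tile_kib = B * B * bytes_per_el / 2**10          # one O(B^2) tile: ~128 KiB, fits in SRAM

print(f"Full N x N matrix: {full_matrix_mib:.0f} MiB per head")
print(f"One B x B tile:    {tile_kib:.0f} KiB")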

Book Reference: "Programming Massively Parallel Processors" Ch. 4

4. Online (Incremental) Softmax

Problem: Softmax requires max and sum over the entire row. How can we compute it when we only see K in blocks?

Solution: Maintain running statistics and rescale as new blocks arrive.

Algorithm:
1. Initialize: m = -∞, l = 0, O = 0
2. For each K, V block:
   a. Compute scores = Q_block @ K_block^T
   b. Find block_max = max(scores)
   c. Update global max: m_new = max(m, block_max)
   d. Rescale previous: O = O × exp(m - m_new)
   e. Add new contribution: O += exp(scores - m_new) @ V_block
   f. Update running sum: l = l ร— exp(m - m_new) + sum(exp(scores - m_new))
   g. Update m = m_new
3. Final: O = O / l

This is mathematically identical to standard softmax but works incrementally.
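
The following self-contained PyTorch sketch (names are illustrative) checks that the incremental update reproduces the ordinary softmax-weighted sum on a single query row split into blocks:

import torch

torch.manual_seed(0)
scores = torch.randn(1, 12)                  # one query row, 12 keys
V = torch.randn(12, 4)                       # values, d = 4

# Reference: ordinary softmax-weighted sum
ref = torch.softmax(scores, dim=-1) @ V

# Online computation over blocks of 4 keys
m = torch.tensor(float('-inf'))
l = torch.tensor(0.0)
O = torch.zeros(1, 4)
for start in range(0, 12, 4):
    s = scores[:, start:start + 4]
    m_new = torch.maximum(m, s.max())
    scale = torch.exp(m - m_new)             # rescale everything accumulated so far
    p = torch.exp(s - m_new)                 # unnormalized weights for this block
    l = l * scale + p.sum()
    O = O * scale + p @ V[start:start + 4]
    m = m_new
O = O / l

print(torch.allclose(O, ref, atol=1e-6))     # True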

Reference: "Online normalizer calculation for softmax" - Milakov & Gimelshein, 2018

5. IO Complexity Analysis

Standard Attention HBM accesses:

  • Load Q, K: 2Nd elements
  • Write QK^T: N² elements
  • Load QK^T for softmax: N² elements
  • Write softmax result: N² elements
  • Load softmax weights + V: N² + Nd elements
  • Write output: Nd elements
  • Total: O(N² + Nd) = O(N²)

Flash Attention HBM accesses:

  • Load Q, K, V once each: 3Nd elements
  • Write output once: Nd elements
  • Total: O(Nd)

Reduction: from O(N²) to O(Nd), i.e. linear in the sequence length.
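
Plugging in concrete numbers (a back-of-the-envelope sketch: fp16, per batch and head, N = 8192, d = 64):

# HBM traffic estimates from the access counts above (illustrative only)
N, d, b = 8192, 64, 2                # sequence length, head dim, bytes per fp16 element

standard = (2 * N * d                # load Q, K
            + 3 * N * N              # write QK^T, load it for softmax, write softmax result
            + N * N + N * d          # load softmax weights and V
            + N * d) * b             # write output
flash = (3 * N * d + N * d) * b      # load Q, K, V once, write output once

print(f"Standard:  {standard / 1e9:.2f} GB of HBM traffic")
print(f"Flash:     {flash / 1e6:.1f} MB of HBM traffic")
print(f"Reduction: {standard / flash:.0f}x")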

6. Kernel Fusion

Standard approach: Separate CUDA kernels

matmul_kernel(Q, K, scores)  // Load Q, K from HBM; write scores to HBM
softmax_kernel(scores, weights)  // Load scores from HBM; write weights to HBM
matmul_kernel(weights, V, output)  // Load weights, V from HBM; write output to HBM

Each kernel reads from HBM and writes back - maximum memory traffic.

Fused kernel: All operations in one kernel

flash_attention_kernel(Q, K, V, output)
// Load Q, K, V once; compute in SRAM; write output once

Intermediate results (scores, softmax weights) stay in registers/SRAM.
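
For a reference point while developing, recent PyTorch releases (2.0+) expose a fused attention path via torch.nn.functional.scaled_dot_product_attention, which can dispatch to a FlashAttention-style kernel on supported GPUs. A minimal usage sketch, useful as a baseline for your own benchmarks:

import torch
import torch.nn.functional as F

Q = torch.randn(1, 8, 2048, 64, device='cuda', dtype=torch.float16)
K = torch.randn(1, 8, 2048, 64, device='cuda', dtype=torch.float16)
V = torch.randn(1, 8, 2048, 64, device='cuda', dtype=torch.float16)

# One fused kernel: the score matrix and softmax weights never round-trip through HBM
out = F.scaled_dot_product_attention(Q, K, V)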


Project Specification

What Youโ€™ll Build

An implementation of Flash Attention that:

  1. Uses tiling to process sequences in blocks
  2. Implements online softmax for incremental computation
  3. Achieves O(N) memory instead of O(N²)
  4. Produces outputs mathematically identical to standard attention
  5. Includes comprehensive benchmarks showing memory and speed improvements

Input/Output Specification

def flash_attention(Q, K, V, block_size=256):
    """
    Memory-efficient attention using tiling and online softmax.

    Args:
        Q: Queries, shape (batch, heads, seq_len, d_k)
        K: Keys, same shape
        V: Values, same shape
        block_size: Size of tiles for processing

    Returns:
        output: Attention output, same shape as Q
        (m, l): Softmax statistics for backward pass
    """
    pass

Expected Benchmarks

=== Memory Comparison (32 heads, d=64) ===

Sequence Length | Standard Attention | Flash Attention | Reduction
----------------|--------------------| ----------------|----------
2,048           | 3.1 GB             | 0.5 GB          | 6.2x
4,096           | 12.4 GB            | 1.0 GB          | 12.4x
8,192           | OOM (>40GB)        | 2.0 GB          | N/A
16,384          | OOM                | 4.1 GB          | N/A
32,768          | OOM                | 8.3 GB          | N/A
65,536          | OOM                | 16.8 GB         | N/A

=== Maximum Sequence Length (A100 40GB) ===
Standard Attention: 4,096 tokens
Flash Attention: 65,536 tokens (16x improvement)

Solution Architecture

Algorithm Flow

Flash Attention

  Input: Q, K, V, each of shape (batch, heads, N, d)

  Outer loop: for each K, V block (j = 0 .. num_blocks - 1)
    1. Load K_block, V_block from HBM into SRAM
    2. Compute scores = Q @ K_block^T  (B × B tile, stays in SRAM)
    3. Online softmax update:
       - compute block_max
       - update global max m
       - rescale previous O by exp(m_old - m_new)
       - add new contribution: O += exp(scores - m_new) @ V_block
       - update running sum l

  Final: O = O / l  (normalize)
  Output: O, shape (batch, heads, N, d)

File Structure

flash_attention/
├── flash_attention.py      # Core implementation
│   ├── standard_attention()       # Baseline for comparison
│   ├── flash_attention_forward()  # Memory-efficient forward pass
│   └── online_softmax_update()    # Incremental softmax helper
├── benchmarks.py           # Performance measurement
│   ├── benchmark_memory()      # Profile GPU memory usage
│   ├── benchmark_speed()       # Measure latency
│   └── benchmark_scaling()     # Test across sequence lengths
├── verification.py         # Correctness tests
│   ├── verify_numerical_equivalence()
│   └── stress_test()
├── visualizations.py       # Generate charts
│   ├── plot_memory_comparison()
│   └── plot_speedup_chart()
├── demo.py                 # Interactive demonstration
└── outputs/
    ├── memory_comparison.png
    ├── speedup_chart.png
    └── benchmark_results.json

Implementation Guide

Phase 1: Profile Standard Attention (Days 1-2)

Goal: Understand the memory bottleneck

import torch
import torch.nn.functional as F

def standard_attention(Q, K, V):
    """Standard attention - O(Nยฒ) memory."""
    d_k = Q.size(-1)

    # This creates an NxN matrix - memory killer!
    scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)

    # Another O(N²) operation
    attn_weights = F.softmax(scores, dim=-1)

    output = attn_weights @ V
    return output

# Profile memory
def profile_memory(func, *args):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

    result = func(*args)

    torch.cuda.synchronize()
    peak_memory = torch.cuda.max_memory_allocated() / 1e9
    return result, peak_memory

# Test
Q = torch.randn(1, 32, 4096, 64, device='cuda', dtype=torch.float16)
K = torch.randn(1, 32, 4096, 64, device='cuda', dtype=torch.float16)
V = torch.randn(1, 32, 4096, 64, device='cuda', dtype=torch.float16)

_, memory = profile_memory(standard_attention, Q, K, V)
print(f"Standard attention memory: {memory:.2f} GB")

Phase 2: Implement Tiled Attention (Days 3-5)

Goal: Process in blocks, but not yet memory-efficient

def tiled_attention(Q, K, V, block_size=256):
    """Tiled attention - shows the algorithm structure."""
    batch, heads, seq_len, d_k = Q.shape
    num_blocks = (seq_len + block_size - 1) // block_size

    output = torch.zeros_like(Q)

    # Outer loop over Q blocks
    for i in range(num_blocks):
        q_start = i * block_size
        q_end = min((i + 1) * block_size, seq_len)
        Q_block = Q[:, :, q_start:q_end, :]

        block_output = torch.zeros_like(Q_block)

        # Inner loop over K, V blocks
        for j in range(num_blocks):
            k_start = j * block_size
            k_end = min((j + 1) * block_size, seq_len)
            K_block = K[:, :, k_start:k_end, :]
            V_block = V[:, :, k_start:k_end, :]

            # Compute attention for this tile
            scores = Q_block @ K_block.transpose(-2, -1) / (d_k ** 0.5)
            weights = F.softmax(scores, dim=-1)
            block_output += weights @ V_block

        output[:, :, q_start:q_end, :] = block_output

    return output

Note: This version is NOT correct: the softmax is normalized independently within each K block, so the combined result does not equal true attention. Phase 3 fixes this with online softmax.
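
A quick standalone check (a minimal sketch) makes the problem concrete: normalizing each block separately does not reproduce the full-row softmax.

import torch
import torch.nn.functional as F

scores = torch.randn(4, 8)                       # 4 queries, 8 keys
full = F.softmax(scores, dim=-1)                 # correct row-wise softmax

# "Per-block" softmax as in the tiled code above: each half normalized on its own
blockwise = torch.cat([F.softmax(scores[:, :4], dim=-1),
                       F.softmax(scores[:, 4:], dim=-1)], dim=-1)

print(torch.allclose(full, blockwise))           # False: each half sums to 1 by itself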

Phase 3: Implement Online Softmax (Days 6-10)

Goal: Correct incremental softmax computation

def flash_attention_forward(Q, K, V, block_size=256):
    """
    Flash Attention with online softmax.
    Memory: O(Nd) instead of O(N²)
    """
    batch, heads, seq_len, d_k = Q.shape
    device = Q.device
    dtype = Q.dtype

    num_blocks = (seq_len + block_size - 1) // block_size

    # Output accumulator
    O = torch.zeros_like(Q)

    # Softmax statistics: running max (m) and sum (l)
    m = torch.full((batch, heads, seq_len, 1), float('-inf'),
                   device=device, dtype=dtype)
    l = torch.zeros((batch, heads, seq_len, 1), device=device, dtype=dtype)

    for j in range(num_blocks):  # Iterate over K, V blocks
        k_start = j * block_size
        k_end = min((j + 1) * block_size, seq_len)
        K_block = K[:, :, k_start:k_end, :]
        V_block = V[:, :, k_start:k_end, :]

        # Compute scores for all Q against this K block
        # Shape: (batch, heads, seq_len, block_size)
        scores = Q @ K_block.transpose(-2, -1) / (d_k ** 0.5)

        # Online softmax update
        # Current block max
        block_max = scores.max(dim=-1, keepdim=True)[0]

        # Update global max
        m_new = torch.maximum(m, block_max)

        # Rescaling factors
        exp_m_diff = torch.exp(m - m_new)  # Rescale old
        exp_block = torch.exp(scores - m_new)  # Current block

        # Update running sum
        l = l * exp_m_diff + exp_block.sum(dim=-1, keepdim=True)

        # Update output: rescale old + add new contribution
        O = O * exp_m_diff + exp_block @ V_block

        # Update max tracker
        m = m_new

    # Final normalization
    O = O / l

    return O, (m, l)

Phase 4: Verification and Benchmarking (Days 11-14)

Goal: Prove correctness and measure improvements

def verify_correctness(seq_len=2048):
    """Verify Flash Attention matches standard attention."""
    Q = torch.randn(1, 8, seq_len, 64, device='cuda', dtype=torch.float16)
    K = torch.randn(1, 8, seq_len, 64, device='cuda', dtype=torch.float16)
    V = torch.randn(1, 8, seq_len, 64, device='cuda', dtype=torch.float16)

    out_standard = standard_attention(Q, K, V)
    out_flash, _ = flash_attention_forward(Q, K, V)

    diff = (out_standard - out_flash).abs()
    print(f"Max difference: {diff.max().item():.2e}")
    print(f"Mean difference: {diff.mean().item():.2e}")

    # Should be < 1e-3 for float16, < 1e-5 for float32
    assert diff.max() < 1e-3, "Outputs don't match!"
    print("Verification PASSED")

def benchmark_memory_scaling():
    """Compare memory usage across sequence lengths."""
    results = {}

    for seq_len in [1024, 2048, 4096, 8192]:
        Q = torch.randn(1, 32, seq_len, 64, device='cuda', dtype=torch.float16)
        K = torch.randn(1, 32, seq_len, 64, device='cuda', dtype=torch.float16)
        V = torch.randn(1, 32, seq_len, 64, device='cuda', dtype=torch.float16)

        try:
            _, std_mem = profile_memory(standard_attention, Q, K, V)
        except RuntimeError:  # CUDA out of memory
            std_mem = float('inf')
            torch.cuda.empty_cache()  # release partially allocated buffers before continuing

        _, flash_mem = profile_memory(flash_attention_forward, Q, K, V)

        results[seq_len] = {
            'standard': std_mem,
            'flash': flash_mem,
            'reduction': std_mem / flash_mem if std_mem != float('inf') else 'N/A'
        }

        print(f"seq_len={seq_len}: Standard={std_mem:.2f}GB, Flash={flash_mem:.2f}GB")

    return results

Testing Strategy

Correctness Tests

def test_numerical_equivalence():
    """Flash outputs should match standard within FP tolerance."""
    for seq_len in [256, 512, 1024, 2048]:
        Q = torch.randn(2, 8, seq_len, 64, device='cuda', dtype=torch.float32)
        K = torch.randn(2, 8, seq_len, 64, device='cuda', dtype=torch.float32)
        V = torch.randn(2, 8, seq_len, 64, device='cuda', dtype=torch.float32)

        out_std = standard_attention(Q, K, V)
        out_flash, _ = flash_attention_forward(Q, K, V)

        assert torch.allclose(out_std, out_flash, atol=1e-5, rtol=1e-4), \
            f"Mismatch at seq_len={seq_len}"

def test_online_softmax_correctness():
    """Verify online softmax produces correct probabilities."""
    # Create known scores
    scores_full = torch.randn(1, 1, 4, 8)  # 4 queries, 8 keys

    # Standard softmax
    weights_standard = F.softmax(scores_full, dim=-1)

    # Simulate online computation (2 blocks of 4)
    m = torch.full((1, 1, 4, 1), float('-inf'))
    l = torch.zeros((1, 1, 4, 1))

    for block_start in [0, 4]:
        block = scores_full[:, :, :, block_start:block_start+4]
        block_max = block.max(dim=-1, keepdim=True)[0]
        m_new = torch.maximum(m, block_max)

        exp_diff = torch.exp(m - m_new)
        exp_block = torch.exp(block - m_new)

        l = l * exp_diff + exp_block.sum(dim=-1, keepdim=True)
        m = m_new

    # Verify: reconstructing the softmax from (m, l) matches the standard result
    weights_online = torch.exp(scores_full - m) / l
    assert torch.allclose(weights_online, weights_standard, atol=1e-6)

def test_memory_reduction():
    """Verify Flash uses less memory than standard."""
    Q = torch.randn(1, 32, 4096, 64, device='cuda')
    K = torch.randn(1, 32, 4096, 64, device='cuda')
    V = torch.randn(1, 32, 4096, 64, device='cuda')

    _, std_mem = profile_memory(standard_attention, Q, K, V)
    _, flash_mem = profile_memory(flash_attention_forward, Q, K, V)

    assert flash_mem < std_mem / 2, "Flash should use less than half the memory of standard attention"

Common Pitfalls

1. Incorrect Online Softmax Rescaling

Problem: Forgetting to rescale previous output when max changes

# Wrong: Not rescaling
O = O + exp_block @ V_block

# Right: Rescale by change in max
O = O * exp_m_diff + exp_block @ V_block

2. Using Normalized Instead of Unnormalized Values

Problem: Storing normalized softmax weights instead of raw exponentials

# Wrong: Can't properly combine across blocks
weights = F.softmax(scores, dim=-1)
O = O + weights @ V_block  # Weights from different blocks aren't compatible

# Right: Use unnormalized exponentials, normalize at end
exp_scores = torch.exp(scores - m_new)
O = O + exp_scores @ V_block
# ... later: O = O / l

3. Block Size Too Large

Problem: Block doesn't fit in SRAM, spills to L2/HBM

# Problematic for large d_k
block_size = 1024  # With d_k=128: 1024 * 128 * 2 = 256KB per tensor

# Safe for most GPUs
block_size = 256  # With d_k=128: 256 * 128 * 2 = 64KB per tensor
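
A small helper can sanity-check the footprint before choosing a block size (a sketch; the per-SM budget varies by GPU, so treat the numbers as illustrative):

def tile_bytes(block_size, d_k, bytes_per_el=2, num_tiles=3):
    """Approximate on-chip footprint of the Q, K, V tiles for one thread block."""
    return num_tiles * block_size * d_k * bytes_per_el

# Compare two candidate block sizes at d_k = 64 (96 KB vs. 384 KB)
for B in (256, 1024):
    print(B, tile_bytes(B, d_k=64) // 1024, "KB")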

4. Forgetting Causal Masking

Problem: Standard causal mask approach doesn't work with tiling

# Need block-aware masking
def get_block_causal_mask(q_start, q_end, k_start, k_end):
    # Only mask within the relevant block region
    q_pos = torch.arange(q_start, q_end)
    k_pos = torch.arange(k_start, k_end)
    mask = q_pos[:, None] < k_pos[None, :]
    return mask.float() * -1e9

Extensions and Challenges

Challenge 1: Implement Causal Flash Attention

Add causal masking support:

def flash_attention_causal(Q, K, V, block_size=256):
    """Flash Attention with causal masking."""
    # For each Q block at position i:
    # - Only process K, V blocks at positions <= i
    # - Apply causal mask within the diagonal block
    pass
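
One possible structure (a sketch built on the Phase 3 routine, not the reference implementation: it also tiles Q so the causal condition j <= i can be applied block-by-block):

import torch

def flash_attention_causal_sketch(Q, K, V, block_size=256):
    """Causal Flash Attention sketch: tile Q and K/V, skip blocks above the diagonal."""
    batch, heads, seq_len, d_k = Q.shape
    scale = d_k ** -0.5
    num_blocks = (seq_len + block_size - 1) // block_size
    O = torch.zeros_like(Q)

    for i in range(num_blocks):                              # Q blocks
        q_start, q_end = i * block_size, min((i + 1) * block_size, seq_len)
        Q_blk = Q[:, :, q_start:q_end, :]

        # Per-Q-block softmax statistics, accumulated in fp32 for stability
        m = torch.full((batch, heads, q_end - q_start, 1), float('-inf'),
                       device=Q.device, dtype=torch.float32)
        l = torch.zeros_like(m)
        O_blk = torch.zeros_like(Q_blk, dtype=torch.float32)

        for j in range(i + 1):                               # only blocks at or before the diagonal
            k_start, k_end = j * block_size, min((j + 1) * block_size, seq_len)
            K_blk = K[:, :, k_start:k_end, :]
            V_blk = V[:, :, k_start:k_end, :]

            scores = (Q_blk @ K_blk.transpose(-2, -1)).float() * scale
            if j == i:                                       # diagonal block: mask within the tile
                q_pos = torch.arange(q_start, q_end, device=Q.device)
                k_pos = torch.arange(k_start, k_end, device=Q.device)
                scores = scores.masked_fill(q_pos[:, None] < k_pos[None, :], float('-inf'))

            block_max = scores.max(dim=-1, keepdim=True)[0]
            m_new = torch.maximum(m, block_max)
            exp_m_diff = torch.exp(m - m_new)
            exp_block = torch.exp(scores - m_new)
            l = l * exp_m_diff + exp_block.sum(dim=-1, keepdim=True)
            O_blk = O_blk * exp_m_diff + exp_block @ V_blk.float()
            m = m_new

        O[:, :, q_start:q_end, :] = (O_blk / l).to(Q.dtype)
    return O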

Challenge 2: Implement Backward Pass

Flash Attention saves only O(N) statistics for backward:

def flash_attention_backward(grad_O, Q, K, V, O, m, l):
    """
    Backward pass recomputes attention on-the-fly.

    Key insight: We saved m and l (O(N) memory), not the
    attention matrix (O(N²) memory). During backward, we
    recompute attention scores per-block.
    """
    pass

Challenge 3: Triton Implementation

Write a Triton kernel for real speedup:

import triton
import triton.language as tl

@triton.jit
def flash_attention_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    seq_len, d_k,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Triton handles SRAM allocation automatically
    # Focus on the algorithm, not memory management
    pass

Challenge 4: Multi-Query Attention (MQA)

Implement Flash Attention for shared K, V across heads:

def flash_mqa(Q, K, V, block_size=256):
    """
    Multi-Query Attention: multiple Q heads share one K, V head.

    Q: (batch, num_heads, seq, d_k)
    K, V: (batch, 1, seq, d_k)  # Shared across heads

    This further reduces memory by (num_heads)x for K, V.
    """
    pass
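
Because the block-wise matmuls broadcast over the head dimension, one simple approach (a sketch that reuses the Phase 3 flash_attention_forward and relies on PyTorch broadcasting) is to pass the shared K, V through unchanged:

def flash_mqa_sketch(Q, K, V, block_size=256):
    """MQA sketch: Q has num_heads heads, K and V have a single shared head.

    Q @ K_block^T and exp_block @ V_block broadcast over the head dimension,
    so the Phase 3 routine works unchanged and K, V are never expanded in memory.
    """
    assert K.size(1) == 1 and V.size(1) == 1, "K and V must have one shared head"
    return flash_attention_forward(Q, K, V, block_size=block_size)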

Interview Questions

Conceptual Questions

  1. "Why is Flash Attention faster even though it does more FLOPs?"

    Answer: "Modern GPUs are memory-bound, not compute-bound. Standard attention spends most of its time moving data between HBM (slow) and the compute cores. Flash Attention reduces HBM accesses by up to 9x through tiling and kernel fusion. Even with the extra FLOPs for online softmax bookkeeping, the reduction in memory traffic dominates."

  2. "Explain the memory complexity difference."

    Answer: "Standard attention materializes the N×N attention matrix: O(N²) memory. Flash Attention uses tiling to compute attention in blocks of size B. Each tile produces a partial result that's accumulated. Total: O(Nd) for Q, K, V, and the output, plus O(N) for softmax statistics = O(Nd)."

  3. "How does online softmax maintain numerical equivalence?"

    Answer: "Online softmax tracks a running max m and sum l. When a new block arrives with larger values, we rescale previous outputs by exp(m_old - m_new). This is mathematically equivalent to computing the max/sum over the full row. The final O / l gives the correct softmax-weighted sum."

  4. "What determines optimal block size?"

    Answer: "Block size is a trade-off: larger blocks mean more data reuse but need more SRAM. Optimal: fit 3 blocks (Q, K, V tiles) in SRAM. For an A100 (about 20 MB SRAM total), d_k=64, fp16: B=256 uses 3×256×64×2 = 96 KB for the three tiles, which fits easily. Too large a block spills to L2, degrading performance."

  5. "How does Flash Attention enable 100K+ context windows?"

    Answer: "By reducing memory from O(N²) to O(N). For N=100K, d=64: standard attention needs 100K² × 2 bytes ≈ 20 GB just for the attention matrix. Flash needs 100K × 64 × 2 bytes ≈ 12 MB. Combined with sliding-window attention and compression, modern models handle 1M+ tokens."


Resources

Primary Reading

Topic | Book/Paper | Section
------|------------|--------
Flash Attention algorithm | "FlashAttention: Fast and Memory-Efficient Exact Attention" - Dao et al. | Full paper
GPU memory hierarchy | "Programming Massively Parallel Processors" - Hwu et al. | Ch. 5
Tiling algorithms | "Programming Massively Parallel Processors" - Hwu et al. | Ch. 4
Flash Attention 2 | "FlashAttention-2: Faster Attention with Better Parallelism" - Dao | Full paper
Online softmax | "Online normalizer calculation for softmax" - Milakov & Gimelshein | Full paper

Code References

  • Official Flash Attention: https://github.com/Dao-AILab/flash-attention
  • Triton Tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
  • PyTorch Memory Profiler: https://pytorch.org/docs/stable/cuda.html#memory-management

Interactive Resources

  • NVIDIA A100 Whitepaper: Architecture details for memory hierarchy
  • Roofline Model Tutorial: Understanding compute vs memory bounds

Self-Assessment Checklist

Before moving to Project 3, verify you can:

  • Explain why attention is memory-bound, not compute-bound
  • Describe the GPU memory hierarchy (HBM vs SRAM vs registers)
  • Implement the online softmax algorithm from scratch
  • Calculate IO complexity for standard vs Flash Attention
  • Write tiled attention with correct softmax handling
  • Benchmark memory usage and verify reductions
  • Prove numerical equivalence between implementations
  • Explain kernel fusion and its benefits
  • Answer all interview questions confidently

Previous Project: P01 - Build Attention from Scratch

Next Project: P03 - Build a Full Transformer