Project 2: Implement Flash Attention (Memory-Efficient Attention)
Project Metadata
| Attribute | Value |
|---|---|
| Difficulty | Level 5: Master |
| Time Estimate | 3-4 weeks |
| Prerequisites | Project 1, understanding of GPU architecture, PyTorch proficiency |
| Main Programming Language | Python (with CUDA optional) |
| Alternative Languages | CUDA C++, Triton |
| Primary Tool | PyTorch, CUDA (optional), Triton |
| Main Book | "Programming Massively Parallel Processors" by Hwu, Kirk, & Hajj |
| Knowledge Area | GPU Programming / Algorithmic Optimization |
Learning Objectives
By completing this project, you will be able to:
- Explain why attention is memory-bound, not compute-bound, on modern GPUs
- Implement the tiling algorithm that enables O(N) memory usage instead of O(N²)
- Implement online softmax for incremental computation without full materialization
- Profile and benchmark attention implementations to verify memory and speed improvements
- Understand GPU memory hierarchy (HBM vs SRAM) and its impact on algorithm design
- Explain kernel fusion and why it dramatically improves performance
Theoretical Foundation
The Core Question You're Answering
"Why is the attention mechanism in transformers fundamentally limited by memory rather than computation, and how can we redesign the algorithm to optimize for the actual hardware memory hierarchy?"
This project forces you to confront:
- Why does O(N²) memory kill long sequences when we have teraflops of compute?
- What's the difference between compute-bound and memory-bound operations?
- How does the GPU memory hierarchy (registers → SRAM → HBM → CPU RAM) determine performance?
- Why does fusing operations provide massive speedups?
- How can an algorithm with MORE FLOPs (Flash Attention) be faster than standard attention?
Most people think GPUs are all about parallel compute. After this project, you'll understand that memory bandwidth is the real bottleneck in modern AI workloads.
Concepts You Must Understand First
1. GPU Memory Hierarchy
Modern GPUs have a dramatic hierarchy of memory:
| Memory Type | Capacity | Bandwidth | Latency |
|---|---|---|---|
| Registers | ~256 KB | Fastest | ~1 cycle |
| L1/SRAM | ~20 MB | ~19 TB/s | ~10 cycles |
| L2 Cache | ~40 MB | ~4 TB/s | ~30 cycles |
| HBM (VRAM) | 40-80 GB | ~1.6 TB/s | ~200 cycles |
| CPU RAM | 64+ GB | ~50 GB/s | ~1000+ cycles |
Key insight: SRAM is 12x faster than HBM, but 2000x smaller. Algorithm design must exploit this.
Book Reference: "Programming Massively Parallel Processors" by Hwu, Kirk, & Hajj - Ch. 5
2. Memory-Bound vs Compute-Bound
Memory-bound: Kernel is limited by data transfer speed (most DL operations)
- Waiting for data to arrive from HBM
- GPU cores idle while memory transfers happen
Compute-bound: Kernel is limited by arithmetic throughput (rare: only large GEMM)
- All data in fast memory, GPU cores fully utilized
Arithmetic Intensity = FLOPs / bytes transferred
For standard attention:
- Load Q, K: 2Nd bytes
- Compute QK^T: 2N²d FLOPs
- Load/store attention matrix: 2N² bytes
Arithmetic intensity = 2N²d / (2Nd + 2N²) ≈ d for large N
With d=64, we compute only 64 FLOPs per byte of memory access. This is very low - attention is severely memory-bound.
Book Reference: "Computer Architecture: A Quantitative Approach" by Hennessy & Patterson - Ch. 4
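To get a feel for these numbers, here is a small sanity-check script that mirrors the byte accounting above (illustrative element counts, not exact hardware counters):

def attention_arithmetic_intensity(N, d):
    """Arithmetic intensity of Q @ K^T using the counts from the text above."""
    flops = 2 * N * N * d                 # multiply-accumulates in Q @ K^T
    bytes_moved = 2 * N * d + 2 * N * N   # load Q and K; read/write the score matrix
    return flops / bytes_moved

for N in (1024, 4096, 16384):
    print(f"N={N}: ~{attention_arithmetic_intensity(N, d=64):.1f} FLOPs per byte")

The intensity approaches d = 64 as N grows, far below the ratio of peak FP16 throughput to HBM bandwidth on modern accelerators, which is why the operation is memory-bound.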
3. Tiling (Blocking) Algorithms
Tiling decomposes large computations into smaller "tiles" that fit in fast memory.
For attention:
- Divide Q into blocks of size B
- Divide K, V into blocks of size B
- For each Q block, iterate through all K, V blocks
- Compute partial attention, accumulate results
- Only final output written to HBM
Memory reduction:
- Standard: O(N²) for full attention matrix
- Tiled: O(B²) per tile, never materialized in HBM
- Total: O(Nd) for inputs/outputs only
Book Reference: "Programming Massively Parallel Processors" Ch. 4
4. Online (Incremental) Softmax
Problem: Softmax requires max and sum over the entire row. How can we compute it when we only see K in blocks?
Solution: Maintain running statistics and rescale as new blocks arrive.
Algorithm:
1. Initialize: m = -∞, l = 0, O = 0
2. For each K, V block:
a. Compute scores = Q_block @ K_block^T
b. Find block_max = max(scores)
c. Update global max: m_new = max(m, block_max)
d. Rescale previous: O = O × exp(m - m_new)
e. Add new contribution: O += exp(scores - m_new) @ V_block
f. Update running sum: l = l × exp(m - m_new) + sum(exp(scores - m_new))
g. Update m = m_new
3. Final: O = O / l
This is mathematically identical to standard softmax but works incrementally.
Reference: "Online normalizer calculation for softmax" - Milakov & Gimelshein, 2018
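To make the rescaling concrete, here is a minimal 1-D sketch (plain PyTorch, illustrative only) that processes one score row in two chunks and recovers exactly the same result as a full softmax-weighted sum:

import torch

scores = torch.randn(8)       # one attention row (8 keys)
values = torch.randn(8, 4)    # matching value vectors (d=4)

# Reference: full softmax-weighted sum over the whole row
ref = torch.softmax(scores, dim=-1) @ values

# Online version: process the row in two chunks of 4 keys
m = torch.tensor(float('-inf'))   # running max
l = torch.tensor(0.0)             # running sum of exponentials
O = torch.zeros(4)                # running (unnormalized) output
for s_chunk, v_chunk in zip(scores.split(4), values.split(4)):
    m_new = torch.maximum(m, s_chunk.max())
    scale = torch.exp(m - m_new)      # rescale old contributions
    e = torch.exp(s_chunk - m_new)    # exponentials for the new chunk
    l = l * scale + e.sum()
    O = O * scale + e @ v_chunk
    m = m_new

print(torch.allclose(O / l, ref, atol=1e-6))   # True: identical to the full softmax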
5. IO Complexity Analysis
Standard Attention HBM accesses:
- Load Q, K: 2Nd bytes
- Write QK^T: N² bytes
- Load QK^T for softmax: N² bytes
- Write softmax result: N² bytes
- Load softmax + V: N² + Nd bytes
- Write output: Nd bytes
- Total: O(N² + Nd) = O(N²)
Flash Attention HBM accesses:
- Load Q, K, V once each: 3Nd bytes
- Write output once: Nd bytes
- Total: O(Nd)
Reduction: From O(N²) to O(Nd) - linear in sequence length!
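The same counts, written as a small calculator so you can plug in your own N and d (element counts converted to GB at 2 bytes per fp16 element; these are idealized totals, not profiler numbers):

def hbm_traffic(N, d, bytes_per_elem=2):
    """Per-head HBM traffic (in GB), mirroring the element counts above."""
    standard = (2 * N * d        # load Q, K
                + 3 * N * N      # write scores, read scores, write softmax
                + N * N + N * d  # read softmax and V for the second matmul
                + N * d)         # write output
    flash = 3 * N * d + N * d    # load Q, K, V once; write output once
    scale = bytes_per_elem / 1e9
    return standard * scale, flash * scale

std, fl = hbm_traffic(N=8192, d=64)
print(f"standard ~ {std:.2f} GB, flash ~ {fl:.3f} GB, ratio ~ {std / fl:.0f}x")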
6. Kernel Fusion
Standard approach: Separate CUDA kernels
matmul_kernel(Q, K, scores) // Load Q, K from HBM; write scores to HBM
softmax_kernel(scores, weights) // Load scores from HBM; write weights to HBM
matmul_kernel(weights, V, output) // Load weights, V from HBM; write output to HBM
Each kernel reads from HBM and writes back - maximum memory traffic.
Fused kernel: All operations in one kernel
flash_attention_kernel(Q, K, V, output)
// Load Q, K, V once; compute in SRAM; write output once
Intermediate results (scores, softmax weights) stay in registers/SRAM.
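If you are on PyTorch 2.x, torch.nn.functional.scaled_dot_product_attention gives you a fused attention kernel out of the box, which makes a handy reference point while you build your own (which backend it dispatches to depends on your GPU and PyTorch build):

import torch
import torch.nn.functional as F

Q = torch.randn(1, 8, 4096, 64, device='cuda', dtype=torch.float16)
K = torch.randn(1, 8, 4096, 64, device='cuda', dtype=torch.float16)
V = torch.randn(1, 8, 4096, 64, device='cuda', dtype=torch.float16)

# Unfused: three separate kernels; the N x N intermediates round-trip through HBM
scores = Q @ K.transpose(-2, -1) / (Q.size(-1) ** 0.5)
out_unfused = F.softmax(scores, dim=-1) @ V

# Fused: a single kernel; on recent GPUs this uses a Flash-Attention-style path
out_fused = F.scaled_dot_product_attention(Q, K, V)

print((out_unfused - out_fused).abs().max())  # small fp16-level difference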
Project Specification
What You'll Build
An implementation of Flash Attention that:
- Uses tiling to process sequences in blocks
- Implements online softmax for incremental computation
- Achieves O(N) memory instead of O(N²)
- Produces outputs mathematically identical to standard attention
- Includes comprehensive benchmarks showing memory and speed improvements
Input/Output Specification
def flash_attention(Q, K, V, block_size=256):
"""
Memory-efficient attention using tiling and online softmax.
Args:
Q: Queries, shape (batch, heads, seq_len, d_k)
K: Keys, same shape
V: Values, same shape
block_size: Size of tiles for processing
Returns:
output: Attention output, same shape as Q
(m, l): Softmax statistics for backward pass
"""
pass
Expected Benchmarks
=== Memory Comparison (32 heads, d=64) ===
Sequence Length | Standard Attention | Flash Attention | Reduction
----------------|--------------------| ----------------|----------
2,048 | 3.1 GB | 0.5 GB | 6.2x
4,096 | 12.4 GB | 1.0 GB | 12.4x
8,192 | OOM (>40GB) | 2.0 GB | N/A
16,384 | OOM | 4.1 GB | N/A
32,768 | OOM | 8.3 GB | N/A
65,536 | OOM | 16.8 GB | N/A
=== Maximum Sequence Length (A100 40GB) ===
Standard Attention: 4,096 tokens
Flash Attention: 65,536 tokens (16x improvement)
Solution Architecture
Algorithm Flow
Flash Attention (forward pass)

Input: Q, K, V, each of shape (batch, heads, N, d)

Outer loop: for each K, V block (j = 0 to num_blocks - 1):
  1. Load K_block, V_block from HBM into SRAM.
  2. Compute scores = Q @ K_block^T (a block-sized tile that stays in SRAM).
  3. Online softmax update:
     - Compute block_max
     - Update global max m
     - Rescale previous O by exp(m_old - m)
     - Add new contribution: O += exp(scores - m) @ V_block
     - Update running sum l

Final: O = O / l (normalize)
Output: O, shape (batch, heads, N, d)
File Structure
flash_attention/
├── flash_attention.py            # Core implementation
│   ├── standard_attention()          # Baseline for comparison
│   ├── flash_attention_forward()     # Memory-efficient forward pass
│   └── online_softmax_update()       # Incremental softmax helper
├── benchmarks.py                 # Performance measurement
│   ├── benchmark_memory()            # Profile GPU memory usage
│   ├── benchmark_speed()             # Measure latency
│   └── benchmark_scaling()           # Test across sequence lengths
├── verification.py               # Correctness tests
│   ├── verify_numerical_equivalence()
│   └── stress_test()
├── visualizations.py             # Generate charts
│   ├── plot_memory_comparison()
│   └── plot_speedup_chart()
├── demo.py                       # Interactive demonstration
└── outputs/
    ├── memory_comparison.png
    ├── speedup_chart.png
    └── benchmark_results.json
Implementation Guide
Phase 1: Profile Standard Attention (Days 1-2)
Goal: Understand the memory bottleneck
import torch
import torch.nn.functional as F
def standard_attention(Q, K, V):
"""Standard attention - O(Nยฒ) memory."""
d_k = Q.size(-1)
# This creates an NxN matrix - memory killer!
scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)
# Another O(Nยฒ) operation
attn_weights = F.softmax(scores, dim=-1)
output = attn_weights @ V
return output
# Profile memory
def profile_memory(func, *args):
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
result = func(*args)
torch.cuda.synchronize()
peak_memory = torch.cuda.max_memory_allocated() / 1e9
return result, peak_memory
# Test
Q = torch.randn(1, 32, 4096, 64, device='cuda', dtype=torch.float16)
K = torch.randn(1, 32, 4096, 64, device='cuda', dtype=torch.float16)
V = torch.randn(1, 32, 4096, 64, device='cuda', dtype=torch.float16)
_, memory = profile_memory(standard_attention, Q, K, V)
print(f"Standard attention memory: {memory:.2f} GB")
Phase 2: Implement Tiled Attention (Days 3-5)
Goal: Process in blocks, but not yet memory-efficient
def tiled_attention(Q, K, V, block_size=256):
"""Tiled attention - shows the algorithm structure."""
batch, heads, seq_len, d_k = Q.shape
num_blocks = (seq_len + block_size - 1) // block_size
output = torch.zeros_like(Q)
# Outer loop over Q blocks
for i in range(num_blocks):
q_start = i * block_size
q_end = min((i + 1) * block_size, seq_len)
Q_block = Q[:, :, q_start:q_end, :]
block_output = torch.zeros_like(Q_block)
# Inner loop over K, V blocks
for j in range(num_blocks):
k_start = j * block_size
k_end = min((j + 1) * block_size, seq_len)
K_block = K[:, :, k_start:k_end, :]
V_block = V[:, :, k_start:k_end, :]
# Compute attention for this tile
scores = Q_block @ K_block.transpose(-2, -1) / (d_k ** 0.5)
weights = F.softmax(scores, dim=-1)
block_output += weights @ V_block
output[:, :, q_start:q_end, :] = block_output
return output
Note: This version is NOT correct: softmax is computed independently within each block, so the per-block weights are not on a common scale. Phase 3 fixes this with the online softmax.
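To see the error concretely, a quick comparison against the baseline (small shapes, so the standard version is cheap) should show a gap far larger than floating-point noise:

Q = torch.randn(1, 2, 512, 64, device='cuda', dtype=torch.float32)
K = torch.randn(1, 2, 512, 64, device='cuda', dtype=torch.float32)
V = torch.randn(1, 2, 512, 64, device='cuda', dtype=torch.float32)

diff = (standard_attention(Q, K, V) - tiled_attention(Q, K, V, block_size=128)).abs().max()
print(f"max difference: {diff.item():.4f}")  # large, because per-block softmaxes are summed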
Phase 3: Implement Online Softmax (Days 6-10)
Goal: Correct incremental softmax computation
def flash_attention_forward(Q, K, V, block_size=256):
"""
Flash Attention with online softmax.
    Memory: O(Nd) instead of O(N²)
"""
batch, heads, seq_len, d_k = Q.shape
device = Q.device
dtype = Q.dtype
num_blocks = (seq_len + block_size - 1) // block_size
# Output accumulator
O = torch.zeros_like(Q)
# Softmax statistics: running max (m) and sum (l)
m = torch.full((batch, heads, seq_len, 1), float('-inf'),
device=device, dtype=dtype)
l = torch.zeros((batch, heads, seq_len, 1), device=device, dtype=dtype)
for j in range(num_blocks): # Iterate over K, V blocks
k_start = j * block_size
k_end = min((j + 1) * block_size, seq_len)
K_block = K[:, :, k_start:k_end, :]
V_block = V[:, :, k_start:k_end, :]
# Compute scores for all Q against this K block
# Shape: (batch, heads, seq_len, block_size)
scores = Q @ K_block.transpose(-2, -1) / (d_k ** 0.5)
# Online softmax update
# Current block max
block_max = scores.max(dim=-1, keepdim=True)[0]
# Update global max
m_new = torch.maximum(m, block_max)
# Rescaling factors
exp_m_diff = torch.exp(m - m_new) # Rescale old
exp_block = torch.exp(scores - m_new) # Current block
# Update running sum
l = l * exp_m_diff + exp_block.sum(dim=-1, keepdim=True)
# Update output: rescale old + add new contribution
O = O * exp_m_diff + exp_block @ V_block
# Update max tracker
m = m_new
# Final normalization
O = O / l
return O, (m, l)
Phase 4: Verification and Benchmarking (Days 11-14)
Goal: Prove correctness and measure improvements
def verify_correctness(seq_len=2048):
"""Verify Flash Attention matches standard attention."""
Q = torch.randn(1, 8, seq_len, 64, device='cuda', dtype=torch.float16)
K = torch.randn(1, 8, seq_len, 64, device='cuda', dtype=torch.float16)
V = torch.randn(1, 8, seq_len, 64, device='cuda', dtype=torch.float16)
out_standard = standard_attention(Q, K, V)
out_flash, _ = flash_attention_forward(Q, K, V)
diff = (out_standard - out_flash).abs()
print(f"Max difference: {diff.max().item():.2e}")
print(f"Mean difference: {diff.mean().item():.2e}")
# Should be < 1e-3 for float16, < 1e-5 for float32
assert diff.max() < 1e-3, "Outputs don't match!"
print("Verification PASSED")
def benchmark_memory_scaling():
"""Compare memory usage across sequence lengths."""
results = {}
for seq_len in [1024, 2048, 4096, 8192]:
Q = torch.randn(1, 32, seq_len, 64, device='cuda', dtype=torch.float16)
K = torch.randn(1, 32, seq_len, 64, device='cuda', dtype=torch.float16)
V = torch.randn(1, 32, seq_len, 64, device='cuda', dtype=torch.float16)
try:
_, std_mem = profile_memory(standard_attention, Q, K, V)
except RuntimeError: # OOM
std_mem = float('inf')
_, flash_mem = profile_memory(flash_attention_forward, Q, K, V)
results[seq_len] = {
'standard': std_mem,
'flash': flash_mem,
'reduction': std_mem / flash_mem if std_mem != float('inf') else 'N/A'
}
print(f"seq_len={seq_len}: Standard={std_mem:.2f}GB, Flash={flash_mem:.2f}GB")
return results
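The file layout also lists benchmark_speed(); a minimal sketch using CUDA events follows. Note that this pure-PyTorch flash implementation is not expected to beat the baseline on wall-clock time; the speed gains reported in the paper come from a fused kernel (see the Triton challenge), so treat this as measurement infrastructure.

def benchmark_speed(func, *args, warmup=3, iters=10):
    """Median latency in milliseconds, measured with CUDA events."""
    for _ in range(warmup):                    # warm up kernels and the allocator
        func(*args)
    torch.cuda.synchronize()
    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        func(*args)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))  # milliseconds
    return sorted(times)[len(times) // 2]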
Testing Strategy
Correctness Tests
def test_numerical_equivalence():
"""Flash outputs should match standard within FP tolerance."""
for seq_len in [256, 512, 1024, 2048]:
Q = torch.randn(2, 8, seq_len, 64, device='cuda', dtype=torch.float32)
K = torch.randn(2, 8, seq_len, 64, device='cuda', dtype=torch.float32)
V = torch.randn(2, 8, seq_len, 64, device='cuda', dtype=torch.float32)
out_std = standard_attention(Q, K, V)
out_flash, _ = flash_attention_forward(Q, K, V)
assert torch.allclose(out_std, out_flash, atol=1e-5, rtol=1e-4), \
f"Mismatch at seq_len={seq_len}"
def test_online_softmax_correctness():
"""Verify online softmax produces correct probabilities."""
# Create known scores
scores_full = torch.randn(1, 1, 4, 8) # 4 queries, 8 keys
# Standard softmax
weights_standard = F.softmax(scores_full, dim=-1)
# Simulate online computation (2 blocks of 4)
m = torch.full((1, 1, 4, 1), float('-inf'))
l = torch.zeros((1, 1, 4, 1))
for block_start in [0, 4]:
block = scores_full[:, :, :, block_start:block_start+4]
block_max = block.max(dim=-1, keepdim=True)[0]
m_new = torch.maximum(m, block_max)
exp_diff = torch.exp(m - m_new)
exp_block = torch.exp(block - m_new)
l = l * exp_diff + exp_block.sum(dim=-1, keepdim=True)
m = m_new
    # Verify the recovered probabilities match the standard softmax
    weights_online = torch.exp(scores_full - m) / l
    assert torch.allclose(weights_online, weights_standard, atol=1e-6)
def test_memory_reduction():
"""Verify Flash uses less memory than standard."""
Q = torch.randn(1, 32, 4096, 64, device='cuda')
K = torch.randn(1, 32, 4096, 64, device='cuda')
V = torch.randn(1, 32, 4096, 64, device='cuda')
_, std_mem = profile_memory(standard_attention, Q, K, V)
_, flash_mem = profile_memory(flash_attention_forward, Q, K, V)
    assert flash_mem < std_mem / 2, "Flash should use less than half the memory of standard attention"
Common Pitfalls
1. Incorrect Online Softmax Rescaling
Problem: Forgetting to rescale previous output when max changes
# Wrong: Not rescaling
O = O + exp_block @ V_block
# Right: Rescale by change in max
O = O * exp_m_diff + exp_block @ V_block
2. Using Normalized Instead of Unnormalized Values
Problem: Storing normalized softmax weights instead of raw exponentials
# Wrong: Can't properly combine across blocks
weights = F.softmax(scores, dim=-1)
O = O + weights @ V_block # Weights from different blocks aren't compatible
# Right: Use unnormalized exponentials, normalize at end
exp_scores = torch.exp(scores - m_new)
O = O + exp_scores @ V_block
# ... later: O = O / l
3. Block Size Too Large
Problem: Block doesn't fit in SRAM, spills to L2/HBM
# Problematic for large d_k
block_size = 1024 # With d_k=128: 1024 * 128 * 2 = 256KB per tensor
# Safe for most GPUs
block_size = 256 # With d_k=128: 256 * 128 * 2 = 64KB per tensor
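A quick helper for this back-of-the-envelope check (the three-tile count and 2 bytes per fp16 element follow the arithmetic above; compare the result against your GPU's shared-memory size per SM):

def tile_sram_bytes(block_size, d_k, bytes_per_elem=2, num_tiles=3):
    """Approximate SRAM needed for the Q, K, and V tiles of one block."""
    return num_tiles * block_size * d_k * bytes_per_elem

for b in (128, 256, 512, 1024):
    print(f"block_size={b}, d_k=128: ~{tile_sram_bytes(b, 128) // 1024} KB")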
4. Forgetting Causal Masking
Problem: Standard causal mask approach doesn't work with tiling
# Need block-aware masking
def get_block_causal_mask(q_start, q_end, k_start, k_end):
# Only mask within the relevant block region
q_pos = torch.arange(q_start, q_end)
k_pos = torch.arange(k_start, k_end)
mask = q_pos[:, None] < k_pos[None, :]
return mask.float() * -1e9
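Inside the inner loop of a version that also tiles over Q (as in Phase 2; q_start, q_end, k_start, k_end are the loop variables), the mask is added to the raw scores before exponentiation, roughly like this sketch:

# Inside the K/V block loop, right after computing `scores` for the current tile:
mask = get_block_causal_mask(q_start, q_end, k_start, k_end).to(scores.device)
scores = scores + mask   # masked positions get ~-1e9, so exp(...) underflows to ~0
# ... then continue with the online softmax update as before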
Extensions and Challenges
Challenge 1: Implement Causal Flash Attention
Add causal masking support:
def flash_attention_causal(Q, K, V, block_size=256):
"""Flash Attention with causal masking."""
# For each Q block at position i:
# - Only process K, V blocks at positions <= i
# - Apply causal mask within the diagonal block
pass
Challenge 2: Implement Backward Pass
Flash Attention saves only O(N) statistics for backward:
def flash_attention_backward(grad_O, Q, K, V, O, m, l):
"""
Backward pass recomputes attention on-the-fly.
Key insight: We saved m and l (O(N) memory), not the
attention matrix (O(Nยฒ) memory). During backward, we
recompute attention scores per-block.
"""
pass
Challenge 3: Triton Implementation
Write a Triton kernel for real speedup:
import triton
import triton.language as tl
@triton.jit
def flash_attention_kernel(
Q_ptr, K_ptr, V_ptr, O_ptr,
seq_len, d_k,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
# Triton handles SRAM allocation automatically
# Focus on the algorithm, not memory management
pass
Challenge 4: Multi-Query Attention (MQA)
Implement Flash Attention for shared K, V across heads:
def flash_mqa(Q, K, V, block_size=256):
"""
Multi-Query Attention: multiple Q heads share one K, V head.
Q: (batch, num_heads, seq, d_k)
K, V: (batch, 1, seq, d_k) # Shared across heads
This further reduces memory by (num_heads)x for K, V.
"""
pass
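One low-effort way to prototype this (a sketch of the interface, not the memory-optimal kernel): broadcast the single K/V head across the query heads with expand(), which creates a view rather than a copy, and reuse flash_attention_forward from Phase 3. The batched matmuls may still materialize contiguous copies internally, so this demonstrates correctness more than the K/V memory savings.

def flash_mqa_prototype(Q, K, V, block_size=256):
    """MQA by viewing the shared K/V head across all query heads."""
    num_heads = Q.size(1)
    K_exp = K.expand(-1, num_heads, -1, -1)  # view: no extra HBM allocation here
    V_exp = V.expand(-1, num_heads, -1, -1)
    return flash_attention_forward(Q, K_exp, V_exp, block_size=block_size)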
Interview Questions
Conceptual Questions
- "Why is Flash Attention faster even though it does more FLOPs?"

  Answer: "Modern GPUs are memory-bound, not compute-bound. Standard attention spends most of its time moving data between HBM (slow) and the compute cores. Flash Attention reduces HBM accesses by 9x through tiling and kernel fusion. Even with 2x more FLOPs for online softmax bookkeeping, the 9x reduction in memory traffic dominates."

- "Explain the memory complexity difference."

  Answer: "Standard attention materializes the N×N attention matrix: O(N²) memory. Flash Attention uses tiling to compute attention in blocks of size B. Each tile produces a partial result that's accumulated. Total: O(Nd) for Q, K, V, output + O(N) for softmax statistics = O(Nd)."

- "How does online softmax maintain numerical equivalence?"

  Answer: "Online softmax tracks a running max m and sum l. When a new block arrives with larger values, we rescale previous outputs by exp(m_old - m_new). This is mathematically equivalent to computing the max/sum over the full row. The final O / l gives the correct softmax-weighted sum."

- "What determines optimal block size?"

  Answer: "Block size is a trade-off: larger blocks mean more data reuse but more SRAM needed. Optimal: fit the Q, K, V tiles of a block in SRAM. For an A100 (20 MB SRAM), d_k=64, fp16: B=256 uses 3×256×64×2 = 96 KB for the three tiles, which fits easily. Too large a block spills to L2, degrading performance."

- "How does Flash Attention enable 100K+ context windows?"

  Answer: "By reducing memory from O(N²) to O(N). For N=100K, d=64: standard needs 100K² × 2 bytes ≈ 20 GB just for the attention matrix. Flash needs 100K × 64 × 2 bytes ≈ 12 MB. Combined with sliding window attention and compression, modern models handle 1M+ tokens."
Resources
Primary Reading
| Topic | Book/Paper | Section |
|---|---|---|
| Flash Attention algorithm | "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" - Dao et al. | Full paper |
| GPU memory hierarchy | "Programming Massively Parallel Processors" - Hwu et al. | Ch. 5 |
| Tiling algorithms | "Programming Massively Parallel Processors" - Hwu et al. | Ch. 4 |
| Flash Attention 2 | "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" - Dao | Full paper |
| Online softmax | "Online normalizer calculation for softmax" - Milakov & Gimelshein | Full paper |
Code References
- Official Flash Attention: https://github.com/Dao-AILab/flash-attention
- Triton Tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
- PyTorch Memory Profiler: https://pytorch.org/docs/stable/cuda.html#memory-management
Interactive Resources
- NVIDIA A100 Whitepaper: Architecture details for memory hierarchy
- Roofline Model Tutorial: Understanding compute vs memory bounds
Self-Assessment Checklist
Before moving to Project 3, verify you can:
- Explain why attention is memory-bound, not compute-bound
- Describe the GPU memory hierarchy (HBM vs SRAM vs registers)
- Implement the online softmax algorithm from scratch
- Calculate IO complexity for standard vs Flash Attention
- Write tiled attention with correct softmax handling
- Benchmark memory usage and verify reductions
- Prove numerical equivalence between implementations
- Explain kernel fusion and its benefits
- Answer all interview questions confidently
Previous Project: P01 - Build Attention from Scratch
Next Project: P03 - Build a Full Transformer