Prologue
The Toy Model
The Machine
Machine diagram
Instruction Set
Instruction set
| Instruction | Operation | Cycles |
|---|---|---|
| LOAD src, Rd | Rd = DRAM[src] | 1 |
| STORE Rs, dst | DRAM[dst] = Rs | 1 |
| MAC_Z R0, R1, Rd | Rd = R0 × R1 | 1 |
| MAC R0, R1, Ra, Rd | Rd = R0 × R1 + Ra | 1 |
Problem 1: Compute c₀₀ — write the instruction trace
Write the instruction trace for c₀₀ = a₀₀b₀₀ + a₀₁b₁₀:
| Step | Instruction | R0 | R1 | R2 |
|---|---|---|---|---|
| 1 | LOAD a[0][0], R0 | a₀₀ | | |
| 2 | | | | |
| 3 | | | | |
| 4 | | | | |
| 5 | | | | |
| 6 | | | | |
| 7 | STORE R2, c[0][0] | | | c₀₀ |
Full trace for c₀₀
| Step | Instruction | R0 | R1 | R2 |
|---|---|---|---|---|
| 1 | LOAD a[0][0], R0 | a₀₀ | — | — |
| 2 | LOAD b[0][0], R1 | a₀₀ | b₀₀ | — |
| 3 | MAC_Z R0, R1, R2 | a₀₀ | b₀₀ | a₀₀b₀₀ |
| 4 | LOAD a[0][1], R0 | a₀₁ | b₀₀ | a₀₀b₀₀ |
| 5 | LOAD b[1][0], R1 | a₀₁ | b₁₀ | a₀₀b₀₀ |
| 6 | MAC R0, R1, R2, R2 | a₀₁ | b₁₀ | c₀₀ |
| 7 | STORE R2, c[0][0] | a₀₁ | b₁₀ | c₀₀ |
Cost per output element: 4 LOADs + 2 MACs + 1 STORE = 7 cycles
The first MAC uses MAC_Z (multiply only, no accumulate) because there is no prior partial sum. The second uses MAC to accumulate into R2. The partial sum lives in R2 across both MACs — this is why we need three registers: two for operands (overwritten each iteration) and one for the running accumulator.
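As a sanity check, the seven-step trace can be run on a tiny interpreter for the toy ISA. This is an illustrative sketch: the dict-based DRAM and string addresses like "a00" are conventions introduced here, not part of the model.

```python
# Hypothetical mini-interpreter for the toy ISA. DRAM is a dict keyed by
# address strings; the register file has exactly three slots.
def run(trace, dram):
    regs = {"R0": None, "R1": None, "R2": None}
    for op, *args in trace:
        if op == "LOAD":              # LOAD src, Rd : Rd = DRAM[src]
            src, rd = args
            regs[rd] = dram[src]
        elif op == "STORE":           # STORE Rs, dst : DRAM[dst] = Rs
            rs, dst = args
            dram[dst] = regs[rs]
        elif op == "MAC_Z":           # Rd = R0 * R1 (multiply, no accumulate)
            r0, r1, rd = args
            regs[rd] = regs[r0] * regs[r1]
        elif op == "MAC":             # Rd = R0 * R1 + Ra
            r0, r1, ra, rd = args
            regs[rd] = regs[r0] * regs[r1] + regs[ra]
    return dram

dram = {"a00": 1.0, "a01": 2.0, "b00": 3.0, "b10": 4.0}
trace = [
    ("LOAD", "a00", "R0"),
    ("LOAD", "b00", "R1"),
    ("MAC_Z", "R0", "R1", "R2"),
    ("LOAD", "a01", "R0"),
    ("LOAD", "b10", "R1"),
    ("MAC", "R0", "R1", "R2", "R2"),
    ("STORE", "R2", "c00"),
]
run(trace, dram)
# c00 = 1*3 + 2*4 = 11, in 7 instructions = 7 cycles
```

Each tuple is one cycle, so the cost model is simply the length of the trace.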
Problem 2: Full matmul — Locality!
Now compute all four elements of C = AB. Naively that is 4 × 7 = 28 cycles … or less?
Hint: which elements share operands?
c₀₀ and c₀₁ both use row 0 of A. After finishing c₀₀, R0 holds a₀₁ — if you compute c₀₁ next, you can reuse it.
Discussion: minimum registers and modern GPUs
This is the smallest machine that can compute a matrix multiplication. The MAC unit needs two operand registers (R0, R1), and any dot product longer than one term needs a third register to hold the partial sum — because loading the next pair of operands overwrites R0 and R1, and there is no way to accumulate without a register (MAC takes a register, not a RAM address, as its addend). Three registers is the minimum.
In modern GPU architectures, the compute unit (e.g., a Tensor Core) sometimes has its own dedicated memory that it manages internally — a small, private buffer separate from the shared register file. In that case, the accumulator lives inside the compute unit itself, and the register file only needs to supply operands. This is a design choice in NVIDIA’s Blackwell architecture: by giving the Tensor Core its own accumulator storage, you reduce pressure on the register file — effectively shrinking the “register file” in our toy model back toward 2 registers, with the third hidden inside the MAC unit.
Enlarging the Register File
Problem 3: Outer product matmul — write the trace with 6 registers
We double the register file from 3 to 6 registers: how many cycles does the outer product approach take?
Hint: how to count LOADs in outer product
Each outer product loads one column of A and one row of B, then performs 4 MACs (one per output element). For the first outer product, use MAC_Z (no prior accumulator). For the second, use MAC to accumulate into the same registers. Count the LOADs carefully — can you reuse any operands within one outer product?
Full trace: outer product matmul
Outer product 1 — column 0 of A with row 0 of B:
| Step | Instruction | R0 | R1 | R2 | R3 | R4 | R5 |
|---|---|---|---|---|---|---|---|
| 1 | LOAD a[0][0], R0 | a₀₀ | — | — | — | — | — |
| 2 | LOAD b[0][0], R1 | a₀₀ | b₀₀ | — | — | — | — |
| 3 | MAC_Z R0, R1, R2 | a₀₀ | b₀₀ | a₀₀b₀₀ | — | — | — |
| 4 | LOAD b[0][1], R1 | a₀₀ | b₀₁ | a₀₀b₀₀ | — | — | — |
| 5 | MAC_Z R0, R1, R3 | a₀₀ | b₀₁ | a₀₀b₀₀ | a₀₀b₀₁ | — | — |
| 6 | LOAD a[1][0], R0 | a₁₀ | b₀₁ | a₀₀b₀₀ | a₀₀b₀₁ | — | — |
| 7 | MAC_Z R0, R1, R5 | a₁₀ | b₀₁ | a₀₀b₀₀ | a₀₀b₀₁ | — | a₁₀b₀₁ |
| 8 | LOAD b[0][0], R1 | a₁₀ | b₀₀ | a₀₀b₀₀ | a₀₀b₀₁ | — | a₁₀b₀₁ |
| 9 | MAC_Z R0, R1, R4 | a₁₀ | b₀₀ | a₀₀b₀₀ | a₀₀b₀₁ | a₁₀b₀₀ | a₁₀b₀₁ |
Outer product 2 — column 1 of A with row 1 of B:
| Step | Instruction | R0 | R1 | R2 | R3 | R4 | R5 |
|---|---|---|---|---|---|---|---|
| 10 | LOAD a[0][1], R0 | a₀₁ | b₀₀ | a₀₀b₀₀ | a₀₀b₀₁ | a₁₀b₀₀ | a₁₀b₀₁ |
| 11 | LOAD b[1][0], R1 | a₀₁ | b₁₀ | a₀₀b₀₀ | a₀₀b₀₁ | a₁₀b₀₀ | a₁₀b₀₁ |
| 12 | MAC R0, R1, R2, R2 | a₀₁ | b₁₀ | c₀₀ | a₀₀b₀₁ | a₁₀b₀₀ | a₁₀b₀₁ |
| 13 | LOAD b[1][1], R1 | a₀₁ | b₁₁ | c₀₀ | a₀₀b₀₁ | a₁₀b₀₀ | a₁₀b₀₁ |
| 14 | MAC R0, R1, R3, R3 | a₀₁ | b₁₁ | c₀₀ | c₀₁ | a₁₀b₀₀ | a₁₀b₀₁ |
| 15 | LOAD a[1][1], R0 | a₁₁ | b₁₁ | c₀₀ | c₀₁ | a₁₀b₀₀ | a₁₀b₀₁ |
| 16 | MAC R0, R1, R5, R5 | a₁₁ | b₁₁ | c₀₀ | c₀₁ | a₁₀b₀₀ | c₁₁ |
| 17 | LOAD b[1][0], R1 | a₁₁ | b₁₀ | c₀₀ | c₀₁ | a₁₀b₀₀ | c₁₁ |
| 18 | MAC R0, R1, R4, R4 | a₁₁ | b₁₀ | c₀₀ | c₀₁ | c₁₀ | c₁₁ |
Store:
| Step | Instruction |
|---|---|
| 19 | STORE R2, c[0][0] |
| 20 | STORE R3, c[0][1] |
| 21 | STORE R4, c[1][0] |
| 22 | STORE R5, c[1][1] |
Cost: 10 LOADs + 8 MACs + 4 STOREs = 22 cycles with 6 registers
| | Inner product | Outer product |
|---|---|---|
| Registers | 3 | 6 |
| Cycles (full matmul) | 25 | 22 |
| LOADs | 13 | 10 |
More registers → each loaded value gets reused across multiple output elements before being replaced → fewer total LOADs.
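The 22-step outer-product schedule above can be replayed end to end on a 6-register version of the toy interpreter, checking both the outputs and the 10-LOAD count. A sketch (the dict DRAM and address keys are illustrative conventions, not part of the model):

```python
# Execute a trace on an n-register toy machine and count LOAD cycles.
def run(trace, dram, nregs=6):
    regs = {f"R{i}": None for i in range(nregs)}
    loads = 0
    for op, *args in trace:
        if op == "LOAD":
            loads += 1
            regs[args[1]] = dram[args[0]]
        elif op == "STORE":
            dram[args[1]] = regs[args[0]]
        elif op == "MAC_Z":                      # Rd = Ra * Rb
            r0, r1, rd = args
            regs[rd] = regs[r0] * regs[r1]
        elif op == "MAC":                        # Rd = Ra * Rb + Racc
            r0, r1, ra, rd = args
            regs[rd] = regs[r0] * regs[r1] + regs[ra]
    return loads

# A = [[1,2],[3,4]], B = [[5,6],[7,8]]  =>  C = [[19,22],[43,50]]
dram = {"a00": 1.0, "a01": 2.0, "a10": 3.0, "a11": 4.0,
        "b00": 5.0, "b01": 6.0, "b10": 7.0, "b11": 8.0}
outer = [
    ("LOAD", "a00", "R0"), ("LOAD", "b00", "R1"), ("MAC_Z", "R0", "R1", "R2"),
    ("LOAD", "b01", "R1"), ("MAC_Z", "R0", "R1", "R3"),
    ("LOAD", "a10", "R0"), ("MAC_Z", "R0", "R1", "R5"),
    ("LOAD", "b00", "R1"), ("MAC_Z", "R0", "R1", "R4"),
    ("LOAD", "a01", "R0"), ("LOAD", "b10", "R1"), ("MAC", "R0", "R1", "R2", "R2"),
    ("LOAD", "b11", "R1"), ("MAC", "R0", "R1", "R3", "R3"),
    ("LOAD", "a11", "R0"), ("MAC", "R0", "R1", "R5", "R5"),
    ("LOAD", "b10", "R1"), ("MAC", "R0", "R1", "R4", "R4"),
    ("STORE", "R2", "c00"), ("STORE", "R3", "c01"),
    ("STORE", "R4", "c10"), ("STORE", "R5", "c11"),
]
n_loads = run(outer, dram)   # 10 LOADs, 22 instructions = 22 cycles
```

The trace length (22) and LOAD count (10) match the cost line above.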
Quiz: inner product with 6 registers
Adding three more registers also accelerates the inner product. With 6 registers, what is the best inner product scheme? How many cycles does it take, and where do the savings come from?
Softmax
The Machine
Machine diagram
Instruction Set
Instruction set
| Instruction | Operation |
|---|---|
| LOAD src, Rd | Rd = DRAM[src] |
| STORE Rs, dst | DRAM[dst] = Rs |
| MAX Ra, Rb, Rd | Rd = max(Ra, Rb) |
| SUB Ra, Rb, Rd | Rd = Ra − Rb |
| EXP Ra, Rd | Rd = exp(Ra) |
| ADD Ra, Rb, Rd | Rd = Ra + Rb |
| MUL Ra, Rb, Rd | Rd = Ra × Rb |
| DIV Ra, Rb, Rd | Rd = Ra / Rb |
| MOV Rs, Rd | Rd = Rs |
Safe Softmax
Problem 4: Safe softmax on the machine
With 3 registers, compute the safe softmax. How many DRAM reads and writes total?
Hint: Pass 1 — Get max
Find m = max(x₀, …, xₙ₋₁).
| Instruction | R0 | r1 | r2 |
|---|---|---|---|
| LOAD x[0], r1 | — | x₀ | — |
| LOAD x[1], R0 | x₁ | x₀ | — |
| MAX R0, r1, r1 | x₁ | max(x₀, x₁) | — |
| LOAD x[2], R0 | x₂ | max(x₀, x₁) | — |
| MAX R0, r1, r1 | x₂ | max(x₀, x₁, x₂) | — |
| … | … | … | — |
| LOAD x[n-1], R0 | xₙ₋₁ | max(x₀, …, xₙ₋₂) | — |
| MAX R0, r1, r1 | xₙ₋₁ | m | — |
Cost: n reads, 0 writes. r1 = m.
Hint: Pass 2 — Get sum of exponents
Compute S = Σᵢ e^(xᵢ − m) using the known max m. Each exponent is computed and immediately consumed — never written to DRAM.
| Instruction | R0 | r1 | r2 |
|---|---|---|---|
| (r1 = m from pass 1) | — | m | — |
| LOAD x[0], R0 | x₀ | m | — |
| SUB R0, r1, R0 | x₀ − m | m | — |
| EXP R0, R0 | e^(x₀−m) | m | — |
| MOV R0, r2 | e^(x₀−m) | m | e^(x₀−m) |
| LOAD x[1], R0 | x₁ | m | e^(x₀−m) |
| SUB R0, r1, R0 | x₁ − m | m | e^(x₀−m) |
| EXP R0, R0 | e^(x₁−m) | m | e^(x₀−m) |
| ADD r2, R0, r2 | e^(x₁−m) | m | e^(x₀−m) + e^(x₁−m) |
| … | … | … | … |
Cost: n reads, 0 writes. r1 = m, r2 = S.
Hint: Pass 3 — Normalize
Compute out[i] = e^(xᵢ − m) / S. We recompute e^(xᵢ − m) (since it was never stored) and divide by S.
| Instruction | R0 | r1 | r2 |
|---|---|---|---|
| (r1 = m, r2 = S from pass 2) | — | m | S |
| LOAD x[0], R0 | x₀ | m | S |
| SUB R0, r1, R0 | x₀ − m | m | S |
| EXP R0, R0 | e^(x₀−m) | m | S |
| DIV R0, r2, R0 | e^(x₀−m)/S | m | S |
| STORE R0, out[0] | e^(x₀−m)/S | m | S |
| LOAD x[1], R0 | x₁ | m | S |
| SUB R0, r1, R0 | x₁ − m | m | S |
| EXP R0, R0 | e^(x₁−m) | m | S |
| DIV R0, r2, R0 | e^(x₁−m)/S | m | S |
| STORE R0, out[1] | e^(x₁−m)/S | m | S |
| … | … | … | … |
Cost: n reads, n writes. The exponent is recomputed from xᵢ and m — we trade extra compute for avoiding DRAM storage of the intermediate array.
Hint: Total cost
| Pass | DRAM reads | DRAM writes |
|---|---|---|
| 1. Get max | n | 0 |
| 2. Get sum of exponents | n | 0 |
| 3. Normalize | n | n |
| Total | 3n | n |
No intermediate array in DRAM — the exponents are computed on the fly in both pass 2 and pass 3.
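The three passes can be sketched directly in Python, with counters standing in for the DRAM traffic in the cost table (illustrative sketch; one read per x[i] access, one write per out[i] store):

```python
import math

def safe_softmax(x):
    reads = writes = 0
    # Pass 1: global max
    m = -math.inf
    for xi in x:
        reads += 1
        m = max(m, xi)
    # Pass 2: sum of shifted exponents (never written to DRAM)
    s = 0.0
    for xi in x:
        reads += 1
        s += math.exp(xi - m)
    # Pass 3: recompute each exponent and normalize
    out = []
    for xi in x:
        reads += 1
        out.append(math.exp(xi - m) / s)
        writes += 1
    return out, reads, writes

out, reads, writes = safe_softmax([1.0, 2.0, 3.0, 4.0])
# For n = 4: 3n = 12 reads, n = 4 writes; the outputs sum to 1.
```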
One More Register
Two Paths to the Same Sum
Two paths to the same sum
We want S = Σⱼ₌₁ⁿ e^(xⱼ − m) where m = maxⱼ xⱼ. Let mᵢ = max(x₁, …, xᵢ) be the running max after seeing i elements.
Path 1: Global max first (two passes). Compute m in one pass. Then define Sᵢ = Σⱼ₌₁ⁱ e^(xⱼ − m) — the partial sum using the known global max. The recurrence is: Sᵢ = Sᵢ₋₁ + e^(xᵢ − m).
Path 2: Running max (single pass). Define sᵢ = Σⱼ₌₁ⁱ e^(xⱼ − mᵢ) — the partial sum using the running max mᵢ. Update both together: mᵢ = max(mᵢ₋₁, xᵢ), then sᵢ = sᵢ₋₁ · e^(mᵢ₋₁ − mᵢ) + e^(xᵢ − mᵢ).
The rescaling factor e^(mᵢ₋₁ − mᵢ) corrects all previously accumulated terms for the new max. Since mᵢ ≥ mᵢ₋₁, this factor is always ≤ 1 — no overflow.
At i = n: mₙ = m, so both definitions of the partial sum agree: sₙ = Sₙ = S.
This is the online softmax algorithm (Milakov & Gimelshein, 2018).
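A sketch of the Path 2 recurrence in Python (assuming IEEE floats, with −inf as the initial running max so the first rescaling factor is exactly 0):

```python
import math

def online_sum_exp(x):
    m = -math.inf   # running max  m_i
    s = 0.0         # running sum  s_i = sum_j exp(x_j - m_i)
    for xi in x:
        m_new = max(m, xi)
        # rescale the old partial sum, then add the new term
        s = s * math.exp(m - m_new) + math.exp(xi - m_new)
        m = m_new
    return m, s

x = [3.0, 1.0, 4.0, 1.0, 5.0]
m, s = online_sum_exp(x)

# Path 1 reference: global max first, then one accumulation pass.
m_ref = max(x)
S_ref = sum(math.exp(xi - m_ref) for xi in x)
```

Both paths produce the same (m, S) pair, as the argument above promises.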
Numerical safety
Writing the rescaling as e^(mᵢ₋₁ − mᵢ) is safe because the exponent is ≤ 0. Writing the equivalent e^(mᵢ₋₁) / e^(mᵢ) is not — e^(mᵢ₋₁) alone can overflow before the division happens. Same math, different numerical behavior. The choice of equivalent representation matters.
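The two algebraically equal forms behave differently in floats, which a short sketch makes concrete (Python's math.exp raises OverflowError past roughly exp(709)):

```python
import math

m_prev, m_new = 800.0, 801.0

# Safe form: the exponent is <= 0, so the result is in [0, 1].
safe = math.exp(m_prev - m_new)          # exp(-1)

# "Equivalent" form: exp(800) overflows before the division happens.
try:
    unsafe = math.exp(m_prev) / math.exp(m_new)
    overflowed = False
except OverflowError:
    overflowed = True
```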
Nothing new under the sun?
The trick you just saw — maintaining a running state and correcting it when a reference point changes — sits at the intersection of statistics (the mean), numerical analysis (overflow avoidance), and algorithms (online computation). No single course teaches it, because each considers it someone else’s job, or too simple to bother with.
Architecture courses assume you know the math. Numerical analysis courses assume you know the systems. Algorithm courses assume the problem is abstract and the numbers are exact. The student is left holding pieces from three different boxes, with nobody having shown them that the pieces fit together.
Online Softmax with Four Registers
Problem 5: Write the instruction trace for online softmax with 4 registers
Add one register to the machine (4 total: R0, r1, r2, r3). Using the online softmax recurrence from Path 2, write the instruction trace for the steady-state step — processing one new element xᵢ and updating the running state (m, s).
How many DRAM reads and writes does the full softmax take with 4 registers?
Hint: full instruction trace
Save the old max m into r3 before updating r1 with the new max — the rescaling needs both.
Initialization (first element):
| Instruction | R0 | r1 | r2 | r3 |
|---|---|---|---|---|
| LOAD x[0], R0 | x₀ | — | — | — |
| MOV R0, r1 | x₀ | x₀ | — | — |
| SUB R0, r1, R0 | 0 | x₀ | — | — |
| EXP R0, R0 | 1 | x₀ | — | — |
| MOV R0, r2 | 1 | x₀ | 1 | — |
After init: r1 = m₁ = x₀, r2 = s₁ = 1.
Steady state (each subsequent element xᵢ):
| Step | Instruction | R0 | r1 | r2 | r3 |
|---|---|---|---|---|---|
| 1 | LOAD x[i], R0 | xᵢ | m | s | — |
| 2 | MOV r1, r3 | xᵢ | m | s | m |
| 3 | MAX r1, R0, r1 | xᵢ | m' | s | m |
| 4 | SUB r3, r1, r3 | xᵢ | m' | s | m − m' |
| 5 | SUB R0, r1, R0 | xᵢ − m' | m' | s | m − m' |
| 6 | EXP r3, r3 | xᵢ − m' | m' | s | e^(m−m') |
| 7 | EXP R0, R0 | e^(xᵢ−m') | m' | s | e^(m−m') |
| 8 | MUL r2, r3, r2 | e^(xᵢ−m') | m' | s·e^(m−m') | e^(m−m') |
| 9 | ADD r2, R0, r2 | e^(xᵢ−m') | m' | s' | e^(m−m') |
Trace explanation and cost analysis
After each iteration: r1 = mᵢ, r2 = sᵢ. After processing all n values: r1 = m, r2 = S.
Step 2 saves the old max m into r3 before step 3 overwrites r1 with m′. This is necessary because the rescaling in step 4 needs both the old and new max. After that, r3 is reused for the rescaling factor e^(m−m′), and R0 is reused for the new term e^(xᵢ−m′) — neither xᵢ nor the old m is needed after step 5.
But we’re not done. This pass only computes m and S — the denominator. To produce the actual softmax values, we still need a second pass — same as pass 3 in the 3-register case.
Total cost (4 registers, 2 passes):
| Pass | Operation | DRAM reads | DRAM writes |
|---|---|---|---|
| 1. Online softmax | compute m and S | n | 0 |
| 2. Normalize | compute e^(xᵢ − m)/S for all i | n | n |
| Total | | 2n | n |
Cost Comparison
3 registers vs 4 registers
| Property | 3 registers (3 passes) | 4 registers (2 passes) |
|---|---|---|
| Passes | 3 (max + sum + normalize) | 2 (online + normalize) |
| DRAM reads | 3n | 2n |
| DRAM writes | n | n |
| Compute per element | simple (SUB, EXP, ADD) | more (MAX, 2×SUB, 2×EXP, MUL, ADD) |
One extra register saves one pass and n DRAM reads by merging max-finding into the sum-accumulation. The cost is more compute per step, but for any nontrivial n, the DRAM savings dominate.
Intuition and Exploration
Connections to other online algorithms
Online softmax replaces a global dependency (the max m) with a running state that self-corrects. The same structural pattern appears in other settings:
- Cauchy sequences: convergence defined by terms getting close to each other (local), not to a known limit (global)
- Abel summation: individual terms rewritten as partial sums + differences — the accumulated quantity becomes primary
- Kahan summation: a compensation variable tracks rounding error, correcting the running sum at each step
- Welford’s online variance: a running state corrects for a shifting mean, just as online softmax corrects for a shifting max
See the mathematical foundations appendix for full definitions, derivations, and references.
Problem 6: Other scenarios for the running trick
Online softmax maintains a running state and corrects it when a reference point changes. Can you think of other scenarios — from any domain — where the same pattern applies?
Quiz: Online variance (Welford’s algorithm)
The variance of n values requires the global mean μ — a global dependency similar to softmax’s need for the global max m. Using the same principles (maintain a running state, apply corrections when the reference point changes), design an online algorithm that computes the variance in a single pass. What running state do you maintain? What is the correction step when a new element arrives?
Solution
Definitions: μₙ = (1/n) Σᵢ₌₁ⁿ xᵢ (running mean) and Mₙ = Σᵢ₌₁ⁿ (xᵢ − μₙ)² (running sum of squared deviations). The variance is Mₙ/n.
Derivation. Compute Mₙ − Mₙ₋₁: Mₙ − Mₙ₋₁ = Σᵢ₌₁ⁿ (xᵢ − μₙ)² − Σᵢ₌₁ⁿ⁻¹ (xᵢ − μₙ₋₁)²
Using μₙ − μₙ₋₁ = (xₙ − μₙ₋₁)/n and Σᵢ₌₁ⁿ⁻¹ (xᵢ − μₙ₋₁) = 0, the cross terms collapse.
So: Mₙ = Mₙ₋₁ + (xₙ − μₙ₋₁)(xₙ − μₙ)
Result: maintain (n, μₙ, Mₙ). On each new element, update the mean first, then M — a single pass, with the shifting mean corrected exactly as online softmax corrects for the shifting max.
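Welford’s update can be sketched in a few lines (the (n, mean, M) state is one common formulation):

```python
# Welford's online variance: running state (n, mean, M) where
# M is the sum of squared deviations from the current mean.
def online_variance(xs):
    n, mean, M = 0, 0.0, 0.0
    for x in xs:
        n += 1
        delta = x - mean           # deviation from the old mean
        mean += delta / n          # correct the reference point
        M += delta * (x - mean)    # (x - old mean)(x - new mean)
    return M / n if n else 0.0

xs = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
var = online_variance(xs)          # mean is 5, variance is 4
```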
From Softmax to Attention
The Connection
Attention as softmax + matrix multiply
The scaled dot-product attention is: Attention(Q, K, V) = softmax(QKᵀ/√d) V
The scaling 1/√d is a precomputed constant — we fold it into Q ahead of time and write: O = softmax(QKᵀ) V
Let Q, K, V ∈ ℝᴺˣᵈ, where N is the sequence length and d is the head dimension. The naive approach computes this in three steps:
- Compute S = QKᵀ — an N × N matrix of scores
- Apply row-wise softmax: P = softmax(S) — an N × N attention matrix
- Multiply: O = PV — the N × d output
The written expression suggests we must first compute the softmax, materialize the full attention matrix P, and then multiply by V. But do we actually need to?
Tracing the Dependencies
What does a single output row depend on?
What does a single output row actually depend on? Decomposing by rows:
oᵢ = softmax(qᵢKᵀ) V, where qᵢ ∈ ℝᵈ is row i of Q and oᵢ ∈ ℝᵈ is row i of O.
Each row’s softmax is independent — row i depends only on qᵢ and all of K and V. This connects directly to our online softmax: the softmax denominator for row i is exactly the sum S we computed in the previous section.
Expanding the output: oᵢ = ( Σⱼ e^(qᵢ·kⱼ) vⱼ ) / ( Σⱼ e^(qᵢ·kⱼ) )
This has the same structure as online softmax, except each term in the numerator carries a vector vⱼ instead of a scalar. The denominator is exactly the S from before. The numerator is a weighted sum of value vectors, with the same exponential weights.
Applying Online Softmax
Problem 7: Derive the update rules for attention
The output oᵢ is a ratio of a vector numerator and a scalar denominator. Both are weighted sums with the same exponential weights we saw in online softmax. Using the same running-max-and-rescale trick from Path 2, derive the update rules for the numerator and denominator when a new element xⱼ = qᵢ·kⱼ (with value vector vⱼ) arrives. What is the running state?
Hint: the rescaling factor is scalar
The rescaling factor e^(mⱼ₋₁ − mⱼ) is a scalar — it distributes over the vector numerator. The denominator update is identical to online softmax.
Full update rules and running state
When element j arrives (score xⱼ, value vⱼ): mⱼ = max(mⱼ₋₁, xⱼ), sⱼ = sⱼ₋₁ e^(mⱼ₋₁ − mⱼ) + e^(xⱼ − mⱼ), uⱼ = uⱼ₋₁ e^(mⱼ₋₁ − mⱼ) + e^(xⱼ − mⱼ) vⱼ. After processing all keys: oᵢ = uₙ / sₙ.
The running state at step j is:
| State | Shape | Description |
|---|---|---|
| mⱼ | scalar | running max |
| sⱼ | scalar | running denominator |
| uⱼ | vector (d) | running numerator — weighted sum of value vectors |
This is O(d) storage per row — compared to O(N) for a full row of the attention matrix, or O(N²) for the full matrix.
From Two Passes to One
Fusing softmax with V eliminates a pass
Recall that softmax alone required two passes even with online softmax: one pass to compute m and S, and a second pass to produce the actual softmax values e^(xᵢ − m)/S. The second pass existed because we needed to output individual values — each one requires reading xᵢ again.
But in attention, we don’t need the individual softmax values. We only need their weighted sum with V. The running numerator uⱼ already accumulates this weighted sum during the first pass. After processing all elements, the output is a single division: oᵢ = uₙ / sₙ.
No second pass. Each softmax value is produced, multiplied by vⱼ, accumulated into the numerator, and discarded — it never needs to exist as an individual value.
| | Passes | Why |
|---|---|---|
| Softmax alone (4 registers) | 2 | must output each softmax value individually → second pass to read xᵢ again |
| Attention (softmax fused with V) | 1 | only need the weighted sum with V → final division at the end, no second pass |
By connecting softmax to the multiplication by V, the normalize pass disappears entirely.
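The fused one-pass computation for a single query row can be sketched as follows (illustrative pure Python; the precomputed scores stand in for qᵢ·kⱼ):

```python
import math

# One-pass attention for a single query row: fuse softmax with the
# weighted sum over value vectors. State: running max m, denominator s,
# vector numerator u — all O(d) per row.
def attention_row(scores, values):
    d = len(values[0])
    m, s, u = -math.inf, 0.0, [0.0] * d
    for x, v in zip(scores, values):
        m_new = max(m, x)
        alpha = math.exp(m - m_new)      # rescale the old state
        w = math.exp(x - m_new)          # weight of the new element
        s = s * alpha + w
        u = [ui * alpha + w * vi for ui, vi in zip(u, v)]
        m = m_new
    return [ui / s for ui in u]          # single division at the end

scores = [0.5, 2.0, -1.0]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
o = attention_row(scores, values)

# Reference: materialize the softmax, then take the weighted sum.
mm = max(scores)
p = [math.exp(x - mm) for x in scores]
Z = sum(p)
o_ref = [sum(p[j] * values[j][k] for j in range(3)) / Z for k in range(2)]
```

No individual softmax value survives the loop — only the (m, s, u) state does.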
The Attention Matrix Is Never Materialized
P never exists in memory
Each attention score is computed, immediately consumed into the running numerator and denominator, and discarded. The attention matrix P never needs to exist in memory.
Intermediate Elimination
The general pattern: fuse production and consumption
The pattern here is general: a large intermediate (P, size N × N) is produced only to be immediately consumed by the next operation (multiplication by V). Because each element of P is used exactly once in a structured reduction (the weighted sum over columns of P), we can fuse production and consumption — compute each score, multiply by vⱼ, accumulate, and discard.
This same principle appears in many domains:
- Kernel fusion (GPU computing): avoid writing intermediates to HBM between operations
- Deforestation (functional programming): eliminate intermediate data structures when producer and consumer can be fused
- Loop fusion (compilers): merge loops that produce and consume the same array
- Our register file model: the attention score goes into a register, gets consumed into the running state, and the register is immediately reused — it never touches DRAM
The condition: if an intermediate is only ever used as part of a contraction/reduction, it does not need to exist as a full object. The softmax-then-multiply pattern in attention satisfies this exactly.
Element is Tile
The Tile Abstraction
Every element can be a tile — zero structural adjustment
Everything so far is described in terms of individual scalar elements xⱼ. But nothing in the formulation requires this — every element can be a tile, with zero structural adjustment.
Split a row of length N into tiles of size T. Tile t covers elements xⱼ for j = tT, …, (t+1)T − 1. Within each tile, we compute a local state:
- m⁽ᵗ⁾ = max over the tile of xⱼ — local tile max
- s⁽ᵗ⁾ = Σⱼ∈tile e^(xⱼ − m⁽ᵗ⁾) — local tile denominator
- u⁽ᵗ⁾ = Σⱼ∈tile e^(xⱼ − m⁽ᵗ⁾) vⱼ — local tile numerator
To merge two tiles with states (m⁽¹⁾, s⁽¹⁾, u⁽¹⁾) and (m⁽²⁾, s⁽²⁾, u⁽²⁾): m = max(m⁽¹⁾, m⁽²⁾), s = s⁽¹⁾ e^(m⁽¹⁾ − m) + s⁽²⁾ e^(m⁽²⁾ − m), u = u⁽¹⁾ e^(m⁽¹⁾ − m) + u⁽²⁾ e^(m⁽²⁾ − m).
This is exactly the same merge operation as the element-wise online softmax. The rescaling factor is still a scalar, it still distributes over vectors, and the merge is still associative. The formulation is scale-free — the merge only operates on the state and does not care about what is inside each tile.
This is why the toy model was worth building. The scalar version was not a simplified analogy — it is the algorithm, at a different granularity. Going from scalar to tiled is a change of mindset, not a change of structure.
Recall from our toy model that we already introduced this idea for matrix multiplication: the same instruction set works at both the scalar level (MAC on individual floats) and the tiled level (Tensor Core on sub-matrices). The same principle applies here.
At the tiled level, our toy model maps directly to a GPU:
| Toy model | GPU |
|---|---|
| Register file | SRAM (shared memory) |
| DRAM | HBM |
| LOAD / STORE | data movement between HBM and SRAM |
| Compute unit (MAC / ALU) | Tensor Core |
The register file in our model is the fast, small memory close to compute — that is SRAM on a GPU. The DRAM in our model is the large, slow memory — that is HBM. Every cost tradeoff we analyzed (fewer passes, fewer DRAM reads, more registers) translates directly: fewer HBM accesses, more SRAM usage.
Tiling Direction
Tiling across the key/value dimension
We are tiling across a row of the attention matrix — over the key/value dimension. Each tile processes a block of keys kⱼ and the corresponding values vⱼ, while keeping one query qᵢ fixed. This directly determines how K and V are blocked in memory.
Inner Product vs. Outer Product in Attention
Two loop orderings for attention
In our toy model, we saw two ways to compute a matrix multiplication: inner product (compute one output element fully, then move to the next) and outer product (load one pair of operands, update all output elements at once). The same choice appears in attention.
If we view the attention computation as conceptually a matrix operation, we can choose which dimension to iterate over in the outer loop:
Inner product style: fix a query, stream through all KV
For each query qᵢ, stream through all key-value pairs. The running state for one row — max m (scalar), denominator s (scalar), numerator u (vector of size d) — stays in fast memory. After processing all KV pairs, oᵢ is complete. Move to the next query.
- Running state in fast memory: O(d) per query
- K and V are reloaded for every query
Outer product style: fix a KV block, update all queries
Load a block of K and V into fast memory. Then stream through queries one at a time, updating each query’s running state with this KV block. After processing all queries, load the next KV block and repeat.
In the element-is-tile view, each “query” is one element (or one tile). Its running state — max, denominator, numerator — is O(d), same as the inner product style. The difference is not how much state a single query needs, but how many queries’ outputs must be stored to HBM during the process.
If we can keep only one query’s output in fast memory, we load/store each query’s output once per KV element — N stores per query. Each additional output slot we keep in fast memory saves those N stores for that entry, because it can stay resident across all KV elements. At the extreme, if all outputs fit, each is stored once — and we’ve recovered the inner product schedule.
- Running state per query: O(d) — same as inner product
- K and V are loaded once
- Output stores per query: N (once per KV element) — unless kept in fast memory across elements
- All query updates within one KV element are independent → parallelizable
Cost Comparison
HBM traffic: inner product vs outer product
| | Inner product | Outer product |
|---|---|---|
| Outer loop | over queries | over KV |
| Inner loop | over KV | over queries |
| Q loads | N (once per query) | N² (every query reloaded per KV) |
| KV loads | N² (all KV reloaded per query) | N (once per KV) |
| O loads | N (initialize once per query) | N² (reloaded per KV) |
| O stores | N (once per query) | N² (stored per KV) |
The cost is not about fast memory capacity per query — it’s the same in both cases. The cost is about which data gets reloaded from HBM: KV (inner product) or output (outer product). Each additional output slot in fast memory saves N stores, moving along the spectrum from outer product toward inner product.
This is the flash attention loop ordering question:
- Flash Attention v2 uses the inner product style (outer loop over queries, inner loop over KV). Each query’s output stays in SRAM across all KV blocks — stored once. KV is reloaded for every query block.
- Flash Attention v1 uses the outer product style (outer loop over KV, inner loop over queries). A KV block stays in SRAM. Query outputs are loaded/stored from HBM for each KV block — one store per KV block per query. More HBM traffic for outputs, but the inner loop over query blocks is independent and can be parallelized across GPU thread blocks.
The same tradeoff we saw in the matmul — 3 registers with inner product vs. 6 registers with outer product — plays out here at the scale of SRAM and HBM. But at the tiled level, it’s a continuous spectrum: the more output you keep in SRAM, the fewer HBM round-trips for outputs, at the cost of less SRAM available for KV blocks.
Quizzes
Arithmetic Intensity: Inner Product vs. Outer Product
Arithmetic Intensity: Flash Attention
Backward Pass
This section has not been validated and polished. The content below is a draft.
Backward pass (draft)
Why Study the Backward Pass?
In PyTorch, requires_grad=True tells the autograd engine to preserve the intermediate states of the forward pass — so they can be reused when computing gradients. During inference we skip this; during training we keep everything.
But flash attention never materializes the attention matrix P. If the intermediate doesn’t exist, what do we differentiate through?
The answer is activation checkpointing: rerun the forward pass to regenerate the intermediates on demand. And here the structure of flash attention pays off twice. The first time through, everything is new — we use the online softmax trick to avoid materializing . The second time through, things are cheaper: we already have the row-wise maxima and denominators stored from the first pass. We don’t recompute them. The computational cost of the recomputation drops significantly.
So is that enough? Run activation checkpointing, recompute block by block during the backward pass, and we’re done?
No. The backward pass itself involves intermediate matrices — dS and dP — that are just as large as P. If we materialize those, we’ve solved nothing. We need to avoid materializing any N × N matrix during the backward pass too. That is why we need to study the details of what happens in the backward pass — not just that gradients flow through softmax, but how to compute them without ever building the full attention-sized matrices.
But before we dive in — let’s pause and make sure we understand what exactly allowed us to avoid materializing the attention matrix in the forward pass.
Quiz: What is the key ingredient that allows us to avoid materializing the attention matrix in the forward pass? It’s not online softmax — that handles the denominator. What handles the matrix itself?
The answer is contraction: P only ever appears inside the sum oᵢ = Σⱼ pᵢⱼ vⱼ. Each element pᵢⱼ is produced, multiplied by vⱼ, accumulated into the output, and discarded. An intermediate that feeds directly into a linear contraction never needs to be fully materialized.
Now consider the backward pass. The intermediates dS and dP play the same role that P played in the forward pass — they are large intermediates we want to eliminate. In the forward pass, we could merge all the intermediate steps because P flowed directly into a linear contraction with V. Can we do the same here?
The chain is: dO → dP → dS → dQ, dK. If we can contract dS directly with the next step (dQ = dS K, dK = dSᵀ Q), then dS does not need to be fully materialized — we need to discover the same contraction pattern again. But the step from dP to dS passes through the softmax Jacobian — a nonlinear derivative. This is the critical step that makes the backward pass harder than the forward pass. We will see a clever trick (the row-wise Dᵢ identity) that handles this step without ever materializing the full matrices.
The Language of Tensors
A matrix is a (1,1) tensor — one upper index (row), one lower index (column). The matrix product C = AB can be written component-wise: Cⁱⱼ = Σₖ Aⁱₖ Bᵏⱼ
where i is the row, j is the column, and k is summed over. We can also write this in Einstein notation, where a repeated index — one upper, one lower — implies summation: Cⁱⱼ = Aⁱₖ Bᵏⱼ
No Σ symbol needed. The repeated k (lower in A, upper in B) tells you to sum over it. This is the same convention used by torch.einsum.
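A minimal concrete instance of the component formula, written as the explicit loops the index notation abbreviates (equivalent to einsum("ik,kj->ij", A, B)):

```python
# C^i_j = sum_k A^i_k B^k_j for 2x2 matrices, spelled out as loops.
A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
n = 2
C = [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(n)]
     for i in range(n)]
# C == [[19.0, 22.0], [43.0, 50.0]]
```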
Now consider the derivative of a matrix with respect to another matrix. If Y depends on X, we need to ask: how does each entry Yⁱⱼ change when we perturb each entry Xᵏₗ? That requires four indices — two for the output (i, j) and two for the input (k, l): ∂Yⁱⱼ/∂Xᵏₗ
This is a tensor with n⁴ entries — a Jacobian, but organized as a 4-dimensional object rather than a flattened matrix.
The simplest case: the derivative of a matrix with respect to itself. Each component is an independent variable, so the derivative of Xⁱⱼ with respect to Xᵏₗ is 1 when i = k and j = l, and 0 otherwise. In Kronecker delta notation: ∂Xⁱⱼ/∂Xᵏₗ = δⁱₖ δˡⱼ
This is the fundamental identity. The two deltas enforce i = k (same row) and j = l (same column).
Derivative of a matrix product. Given C = AB, we differentiate with respect to A: ∂Cⁱⱼ/∂Aᵏₗ = ∂(Aⁱₘ Bᵐⱼ)/∂Aᵏₗ
B does not depend on A, so we can pull it out: ∂Cⁱⱼ/∂Aᵏₗ = (∂Aⁱₘ/∂Aᵏₗ) Bᵐⱼ
Apply the fundamental identity: ∂Cⁱⱼ/∂Aᵏₗ = δⁱₖ δˡₘ Bᵐⱼ
The δˡₘ contracts with Bᵐⱼ (setting m = l): ∂Cⁱⱼ/∂Aᵏₗ = δⁱₖ Bˡⱼ
This says: row i of C only depends on row i of A.
Chain rule. For a composition L = f(C) with C = AB: ∂L/∂Aᵏₗ = (∂L/∂Cⁱⱼ)(∂Cⁱⱼ/∂Aᵏₗ) = (∂L/∂Cⁱⱼ) δⁱₖ Bˡⱼ
The Kronecker delta collapses the sum over i, leaving only i = k — we only need derivatives of L with respect to row k of C.
Contracting a delta with a denominator index. When a Kronecker delta contracts with an index that appears in the denominator of a partial derivative, the variance flips. For example, in the expression: (∂L/∂Cⁱⱼ) δˡⱼ
The index j is lower in Cⁱⱼ, but taking ∂L/∂Cⁱⱼ flips its variance — j becomes upper in the derivative. The δˡⱼ has j as lower. Upper meets lower → contraction, setting j = l: ∂L/∂Cⁱₗ
In Euclidean space the component values don’t change when we raise or lower indices, so this is purely bookkeeping — but it tells us which indices get summed.
Matrix-to-Matrix Derivatives
The derivative of a matrix with respect to a matrix is a 4-index object: ∂Yⁱⱼ/∂Xᵏₗ
This has n⁴ entries — a (2,2) tensor. For the attention forward pass S = QKᵀ → P = softmax(S) → O = PV, the full Jacobians at each step are 4D tensors that are expensive to compute and store.
The Collapse: (2,2) to (1,1)
In practice, backpropagation never computes the full 4D Jacobians. Instead, given a scalar loss L and the upstream gradient dC = ∂L/∂C (same shape as C, a matrix — a (1,1) tensor), we compute the vector-Jacobian product (VJP): dAᵏₗ = dCⁱⱼ (∂Cⁱⱼ/∂Aᵏₗ)
The upstream gradient contracts with the Jacobian, and the result collapses back to a (1,1) tensor — a matrix. The 4D Jacobian is never built. Every derivative in the chain can be represented as a matrix, not a 4D tensor.
This is the same intermediate elimination pattern from the forward pass. There, the attention matrix P was the large intermediate that was never materialized. Here, the Jacobian is the large intermediate — and it too is never materialized, because it is immediately contracted with the upstream gradient.
The Backward Pass Step by Step
We use S = QKᵀ for the scores and P = softmax(S) for the attention matrix, matching the flash attention paper notation. All d-prefixed matrices are gradients of the scalar loss L:
| Symbol | Shape | Description |
|---|---|---|
| Q, K, V | N × d | inputs |
| S = QKᵀ | N × N | pre-softmax scores |
| P = softmax(S) | N × N | attention weights |
| O = PV | N × d | output |
| dO | N × d | ∂L/∂O — upstream gradient (given) |
| dQ, dK, dV | N × d | what we want |
| dP | N × N | ∂L/∂P — intermediate |
| dS | N × N | ∂L/∂S — intermediate |
| D | N scalars | row-wise dot product of dO and O |
Forward was: S = QKᵀ → P = softmax(S) → O = PV
Backward reverses this: dO → dV, dP → dS → dQ, dK
Step 1: Through O = PV
We have dO. We want dV and dP.
Deriving dV:
First, the Jacobian. Using our derivative-of-a-product result (differentiating with respect to the second factor this time): ∂Oⁱⱼ/∂Vᵏₗ = Pⁱₖ δˡⱼ
Now contract with dO via the VJP — recall that dVᵏₗ = ∂L/∂Vᵏₗ, which by the chain rule is: dVᵏₗ = dOⁱⱼ (∂Oⁱⱼ/∂Vᵏₗ)
Substituting the Jacobian: dVᵏₗ = dOⁱⱼ Pⁱₖ δˡⱼ
The δˡⱼ contracts with the j in the denominator (variance flip: j is lower in Oⁱⱼ, so upper in dO), setting j = l: dVᵏₗ = dOⁱₗ Pⁱₖ
The repeated i is a contraction — this is a matrix product: dVᵏₗ = (Pᵀ dO)ᵏₗ
In matrix form: dV = Pᵀ dO.
Deriving dP:
The Jacobian with respect to the first factor: ∂Oⁱⱼ/∂Pᵏₗ = δⁱₖ Vˡⱼ
Contract with dO via the VJP: dPᵏₗ = dOⁱⱼ (∂Oⁱⱼ/∂Pᵏₗ)
Substituting the Jacobian: dPᵏₗ = dOⁱⱼ δⁱₖ Vˡⱼ
The δⁱₖ contracts with i in the denominator (variance flip), setting i = k: dPᵏₗ = dOᵏⱼ Vˡⱼ
The repeated j is a contraction: dPᵏₗ = (dO Vᵀ)ᵏₗ
In matrix form: dP = dO Vᵀ.
Extracting row 0.
The only part of dP = dO Vᵀ that depends on row 0 of dO is row 0 — the matrix Vᵀ is shared across all rows. So row 0 of dP is simply row 0 of dO times Vᵀ.
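The two matrix-form results, dV = Pᵀ dO and dP = dO Vᵀ, can be checked numerically with a finite-difference probe of L = ⟨dO, O⟩ (a sketch on a hand-picked 2×2 example; since O is linear in each factor, the match is exact up to rounding):

```python
def matmul(X, Y):
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y)))
             for j in range(len(Y[0]))] for i in range(len(X))]

def transpose(X):
    return [list(col) for col in zip(*X)]

P  = [[0.2, 0.8], [0.5, 0.5]]
V  = [[1.0, 2.0], [3.0, 4.0]]
dO = [[0.1, -0.3], [0.7, 0.2]]

dV = matmul(transpose(P), dO)   # closed form: dV = P^T dO
dP = matmul(dO, transpose(V))   # closed form: dP = dO V^T

def loss(P, V):                 # L = sum_ij dO_ij * (PV)_ij
    O = matmul(P, V)
    return sum(dO[i][j] * O[i][j] for i in range(2) for j in range(2))

eps = 1e-6
base = loss(P, V)
V[0][1] += eps                  # probe dL/dV[0][1]
num_dV01 = (loss(P, V) - base) / eps
V[0][1] -= eps
P[1][0] += eps                  # probe dL/dP[1][0]
num_dP10 = (loss(P, V) - base) / eps
P[1][0] -= eps
```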
Step 2: Through P = softmax(S)
Say Dᵢ = Σⱼ Pᵢⱼ dPᵢⱼ — the row-wise scalar from the symbol table. The softmax VJP for row i is then: dSᵢⱼ = Pᵢⱼ (dPᵢⱼ − Dᵢ)
where Dᵢ = dOᵢ · oᵢ, since substituting dPᵢⱼ = (dO Vᵀ)ᵢⱼ gives Σⱼ Pᵢⱼ (dO Vᵀ)ᵢⱼ = dOᵢ · oᵢ — transpose in Einstein notation just swaps the indices: (Vᵀ)ʲₗ = Vˡⱼ.
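The softmax backward rule dSⱼ = Pⱼ(dPⱼ − D) with D = Σⱼ Pⱼ dPⱼ can be checked against central finite differences for one row (illustrative sketch with hand-picked numbers):

```python
import math

def softmax(s):
    m = max(s)
    e = [math.exp(x - m) for x in s]
    Z = sum(e)
    return [x / Z for x in e]

s  = [0.3, -1.2, 2.0]          # one row of scores
dP = [0.5, 0.1, -0.4]          # upstream gradient w.r.t. softmax output
P  = softmax(s)
D  = sum(p * g for p, g in zip(P, dP))
dS = [p * (g - D) for p, g in zip(P, dP)]

def loss(s):                   # L = <dP, softmax(s)>
    return sum(g * p for g, p in zip(dP, softmax(s)))

# Central finite differences for each score.
eps = 1e-6
num = []
for j in range(len(s)):
    s[j] += eps
    up = loss(s)
    s[j] -= 2 * eps
    down = loss(s)
    s[j] += eps
    num.append((up - down) / (2 * eps))
```

A nice side property: each row of dS sums to zero, since Σⱼ Pⱼ(dPⱼ − D) = D − D = 0.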
Step 3: Through S = QKᵀ
By the same product-rule VJPs as step 1: dQ = dS K and dK = dSᵀ Q.
Step 4: Through the QKV projections
The inputs Q, K, V are all projections of the same input X: Q = XW_Q, K = XW_K, V = XW_V
Since all three depend on X, the gradient flows back through all three paths and sums: dX = dQ W_Qᵀ + dK W_Kᵀ + dV W_Vᵀ
Expanding dQ, dK, dV: dX = (dS K) W_Qᵀ + (dSᵀ Q) W_Kᵀ + (Pᵀ dO) W_Vᵀ
By associativity, the projection weights can be folded into the right operand first: dX = dS (K W_Qᵀ) + dSᵀ (Q W_Kᵀ) + Pᵀ (dO W_Vᵀ). The parenthesized products K W_Qᵀ, Q W_Kᵀ, and dO W_Vᵀ are each N × d_model — cheap to precompute and not N × N.
But dX is only the activation gradient. During training, we also need the weight gradients for the parameter update. Since Q = XW_Q, K = XW_K, V = XW_V: dW_Q = Xᵀ dQ, dW_K = Xᵀ dK, dW_V = Xᵀ dV
Each weight gradient is d_model × d — small. But computing them requires dQ, dK, dV explicitly. This conflicts with the associativity trick above: if we fold dQ directly into dX and discard it, we can’t also use it for dW_Q.
The resolution: each row of dQ, dK, dV can be consumed by two accumulators simultaneously. For example, when row i of dQ is produced, it contributes a row to dX (via dQᵢ W_Qᵀ) and a rank-1 update to dW_Q (via xᵢᵀ dQᵢ), then is discarded. The intermediates are still never materialized as full matrices — they’re just consumed twice instead of once before being discarded.
In practice, the flash attention kernel stops at dQ, dK, dV. The projection weight gradients dW_Q, dW_K, dW_V and the input gradient dX are handled by the framework’s standard linear layer backward pass — there is no N × N intermediate involved in those steps, so no special treatment is needed.
Quiz: Why do we need to compute , , at all? What are they used for, and why can’t we skip them?
Quiz: How much compute is required to compute the QKV projection weight updates (, , ), and how much compute is required to compute ? Express in terms of and .
Execution Schedule
Recall from the forward pass: having an expression is not the same as having an execution schedule. In attention, the expression admits multiple execution orders — and the choice determines whether we materialize $P$ or not. The same question applies here.

We want to avoid materializing both $P$ and $\bar{S}$ (both $N \times N$). Since softmax is row-wise, it’s natural to work row by row: compute a row of $\bar{S}$, then immediately use it.

Say we’ve computed row 0 of $\bar{S}$. Now consider the two consumers:

$\bar{Q} = \bar{S}K$: row 0 of $\bar{S}$ times all of $K$ gives row 0 of $\bar{Q}$. This is the inner product style — one row of the left matrix updates one row of the output completely. No problem.

$\bar{K} = \bar{S}^\top Q$: transposing turns row 0 into column 0. A column of the left matrix times a row of $Q$ doesn’t give a single row of $\bar{K}$ — it gives a rank-1 update to the entire matrix. This is the outer product style: column 0 of $\bar{S}^\top$ (which is row 0 of $\bar{S}$) times row 0 of $Q$ updates all of $\bar{K}$.
So the two gradients require different computation patterns:
| Gradient | Style | What happens per row of $\bar{S}$ |
|---|---|---|
| $\bar{Q} = \bar{S}K$ | inner product | row of $\bar{S}$ × all of $K$ → completes one row of $\bar{Q}$ |
| $\bar{K} = \bar{S}^\top Q$ | outer product | row of $\bar{S}$ (as column of $\bar{S}^\top$) × one row of $Q$ → rank-1 update to all of $\bar{K}$ |
For $\bar{K}$, we keep the full matrix in memory and accumulate into it as each row of $\bar{S}$ is produced. We don’t load all of $Q$ — just one row at a time, paired with the corresponding row of $\bar{S}$.

And as before, an “element” can be a tile: each element of the row can be a scalar (one entry) or a tile (a block of entries in the row). The same execution schedule works at both granularities.

Row-wise vs. column-wise. But wait — we chose to iterate row-wise because softmax is row-wise. What if we iterate column-wise instead? Consider a column of $\bar{S}$. We also need $\bar{V} = P^\top\bar{O}$, which we haven’t accounted for yet — with row-wise iteration, a row of $P$ becomes a column of $P^\top$, giving another outer product.
Count the inner vs. outer products for both strategies:
Row-wise (iterate over rows of $\bar{S}$, rows of $P$):

| Gradient | Style | Why |
|---|---|---|
| $\bar{Q} = \bar{S}K$ | inner product | row of $\bar{S}$ × $K$ → completes one row of $\bar{Q}$ |
| $\bar{K} = \bar{S}^\top Q$ | outer product | row of $\bar{S}$ = column of $\bar{S}^\top$, rank-1 update to all of $\bar{K}$ |
| $\bar{V} = P^\top\bar{O}$ | outer product | row of $P$ = column of $P^\top$, rank-1 update to all of $\bar{V}$ |
Score: 1 inner, 2 outer.
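The row-wise schedule can be written out directly. A NumPy sketch (names are mine; in a real kernel the rows of $\bar{S}$ and $P$ would be recomputed on the fly, here they’re random stand-ins):

```python
import numpy as np

rng = np.random.default_rng(4)
N, d = 6, 4
S_bar = rng.standard_normal((N, N))   # gradient w.r.t. scores
P = rng.standard_normal((N, N))       # attention weights (stand-in)
Q = rng.standard_normal((N, d))
K = rng.standard_normal((N, d))
O_bar = rng.standard_normal((N, d))

Q_bar = np.zeros((N, d))
K_bar = np.zeros((N, d))
V_bar = np.zeros((N, d))
for i in range(N):                     # one row of S_bar / P at a time
    Q_bar[i] = S_bar[i] @ K            # inner: completes one row
    K_bar += np.outer(S_bar[i], Q[i])  # outer: rank-1 update to all of K_bar
    V_bar += np.outer(P[i], O_bar[i])  # outer: rank-1 update to all of V_bar

assert np.allclose(Q_bar, S_bar @ K)
assert np.allclose(K_bar, S_bar.T @ Q)
assert np.allclose(V_bar, P.T @ O_bar)
```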
Column-wise (iterate over columns of $\bar{S}$, columns of $P$):

| Gradient | Style | Why |
|---|---|---|
| $\bar{Q} = \bar{S}K$ | outer product | column of $\bar{S}$ × row of $K$ → rank-1 update to all of $\bar{Q}$ |
| $\bar{K} = \bar{S}^\top Q$ | inner product | column of $\bar{S}$ = row of $\bar{S}^\top$, times $Q$ → completes one row of $\bar{K}$ |
| $\bar{V} = P^\top\bar{O}$ | inner product | column of $P$ = row of $P^\top$, times $\bar{O}$ → completes one row of $\bar{V}$ |
Score: 2 inner, 1 outer.
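The column-wise schedule, in the same NumPy sketch style (names are mine):

```python
import numpy as np

rng = np.random.default_rng(5)
N, d = 6, 4
S_bar = rng.standard_normal((N, N))
P = rng.standard_normal((N, N))
Q = rng.standard_normal((N, d))
K = rng.standard_normal((N, d))
O_bar = rng.standard_normal((N, d))

Q_bar = np.zeros((N, d))
K_bar = np.zeros((N, d))
V_bar = np.zeros((N, d))
for j in range(N):                        # one column of S_bar / P at a time
    K_bar[j] = S_bar[:, j] @ Q            # inner: column of S_bar = row of S_bar.T
    V_bar[j] = P[:, j] @ O_bar            # inner: column of P = row of P.T
    Q_bar += np.outer(S_bar[:, j], K[j])  # outer: rank-1 update to all of Q_bar

assert np.allclose(Q_bar, S_bar @ K)
assert np.allclose(K_bar, S_bar.T @ Q)
assert np.allclose(V_bar, P.T @ O_bar)
```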
Column-wise is better balanced. And for the weight gradients, it’s even cleaner: with $\bar{K}$ and $\bar{V}$ produced row-by-row (inner product style), each completed row immediately gives a rank-1 update to $\bar{W}_K$ and $\bar{W}_V$. Only $\bar{Q}$ requires the outer product accumulation.

But softmax is row-wise — can we actually compute column-wise? Yes. Column $j$ of $S$ is $Qk_j$ (all queries dotted with key $j$). Then column $j$ of $P$ is $\exp(S_{:,j} - m)/\ell$ using the precomputed per-row statistics $m$ and $\ell$. Column $j$ of $\bar{P}$ is $\bar{O}v_j$. Then $\bar{S}_{:,j} = P_{:,j} \odot (\bar{P}_{:,j} - D)$ with precomputed $D$. Everything works — the row-wise structure of softmax is captured in the stored scalars, and the iteration itself can proceed column by column.
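A sketch of the column-wise recomputation (NumPy; I’m taking $m$ as the per-row max and $\ell$ as the per-row sum of exponentials stored by the forward pass, $D$ as the per-row softmax-VJP statistic, and omitting any $1/\sqrt{d}$ scale):

```python
import numpy as np

rng = np.random.default_rng(6)
N, d = 6, 4
Q = rng.standard_normal((N, d))
K = rng.standard_normal((N, d))
S = Q @ K.T
m = S.max(axis=1)                          # per-row max (stored in forward)
l = np.exp(S - m[:, None]).sum(axis=1)     # per-row sum (stored in forward)
P = np.exp(S - m[:, None]) / l[:, None]    # reference softmax

P_bar = rng.standard_normal((N, N))        # upstream gradient w.r.t. P
D = (P_bar * P).sum(axis=1)                # per-row VJP statistic

j = 2
s_col = Q @ K[j]                           # column j of S: all queries · key j
p_col = np.exp(s_col - m) / l              # column j of P from row statistics
s_bar_col = p_col * (P_bar[:, j] - D)      # column j of S_bar

assert np.allclose(p_col, P[:, j])
S_bar_full = P * (P_bar - D[:, None])      # full row-wise softmax VJP, for reference
assert np.allclose(s_bar_col, S_bar_full[:, j])
```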
Can we push further? The gradients $\bar{Q}$, $\bar{K}$, $\bar{V}$ are all $N \times d$ — not $N \times N$, so less urgent. But for long sequences ($N$ in the millions), even $N \times d$ matrices are large. Do we need to materialize them, or can they be consumed immediately too?

Recall that the final target is $\bar{X}$. Each of $\bar{Q}$, $\bar{K}$, $\bar{V}$ feeds into a linear contraction with a projection weight matrix — exactly the pattern that allows intermediate elimination.
Term 1: $\bar{S}\,(KW_Q^\top)$. Inner product style, fully streamable. $KW_Q^\top$ is precomputed. Row $i$ of $\bar{S}$ times $KW_Q^\top$ produces row $i$ of the first contribution to $\bar{X}$. Produce, consume, discard — $\bar{Q}$ is never materialized.

Term 2: $\bar{S}^\top(QW_K^\top)$. $QW_K^\top$ is precomputed. For each row of $\bar{S}$ (which becomes a column of $\bar{S}^\top$): outer product of that column of $\bar{S}^\top$ with the corresponding row of $QW_K^\top$, accumulated into $\bar{X}$. Neither $\bar{K}$ nor $\bar{S}^\top$ is materialized.

Term 3: $P^\top(\bar{O}W_V^\top)$. $\bar{O}W_V^\top$ is precomputed. Now $P^\top$ needs columns of $P$ as its rows — but the outer product decomposition avoids this. A row of $P$ (which we already have from the row-wise softmax recomputation) becomes a column of $P^\top$. Outer product of that column of $P^\top$ with the corresponding row of $\bar{O}W_V^\top$, accumulated into $\bar{X}$. We never need a full column of $P$ — just one row at a time.

All three terms can be accumulated into a single $\bar{X}$ matrix as we stream row-by-row through $\bar{S}$ and $P$. The intermediates $\bar{Q}$, $\bar{K}$, $\bar{V}$ are eliminated by the same contraction principle that eliminated $P$ in the forward pass.
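Putting the three terms together, a streaming NumPy sketch (`A`, `B`, `C` are my names for the precomputed products; any $1/\sqrt{d}$ scale is omitted):

```python
import numpy as np

rng = np.random.default_rng(7)
N, d, dm = 6, 4, 5
S_bar = rng.standard_normal((N, N))
P = rng.standard_normal((N, N))
Q = rng.standard_normal((N, d))
K = rng.standard_normal((N, d))
O_bar = rng.standard_normal((N, d))
Wq = rng.standard_normal((dm, d))
Wk = rng.standard_normal((dm, d))
Wv = rng.standard_normal((dm, d))

A = K @ Wq.T        # N×dm, precomputed
B = Q @ Wk.T        # N×dm, precomputed
C = O_bar @ Wv.T    # N×dm, precomputed

X_bar = np.zeros((N, dm))
for i in range(N):                     # stream one row of S_bar and P at a time
    X_bar[i] += S_bar[i] @ A           # term 1: inner product, completes a row
    X_bar += np.outer(S_bar[i], B[i])  # term 2: outer product (S_bar.T @ B)
    X_bar += np.outer(P[i], C[i])      # term 3: outer product (P.T @ C)

ref = (S_bar @ K) @ Wq.T + (S_bar.T @ Q) @ Wk.T + (P.T @ O_bar) @ Wv.T
assert np.allclose(X_bar, ref)
```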
The Trick: Why It Matters
At the row level, $D_i = \sum_j \bar{P}_{ij}P_{ij}$ is just a dot product — trivial to compute from the recomputed row of $P$ and the row of $\bar{P}$. So why does flash attention go out of its way to precompute it as $D_i = \sum_a \bar{O}_{ia}O_{ia}$, the rowsum of $\bar{O} \odot O$?

The answer is arithmetic intensity — specifically, how many times $K$ and $V$ are loaded from HBM.
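First, the identity itself: since $\bar{P} = \bar{O}V^\top$ and $O = PV$, we have $\sum_j \bar{P}_{ij}P_{ij} = \sum_a \bar{O}_{ia}(PV)_{ia} = \sum_a \bar{O}_{ia}O_{ia}$ — the rowsum over $\bar{P} \odot P$ equals the rowsum over $\bar{O} \odot O$. A quick NumPy check:

```python
import numpy as np

rng = np.random.default_rng(8)
N, d = 6, 4
P = rng.standard_normal((N, N))
V = rng.standard_normal((N, d))
O_bar = rng.standard_normal((N, d))

O = P @ V
P_bar = O_bar @ V.T

D_from_P = (P_bar * P).sum(axis=1)   # needs the full row of P_bar
D_from_O = (O_bar * O).sum(axis=1)   # needs only O and O_bar — no P_bar

assert np.allclose(D_from_P, D_from_O)
```

The right-hand form is the one worth precomputing: $O$ and $\bar{O}$ are $N \times d$ and already on hand, so $D$ costs one cheap elementwise pass.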
Without the trick, the backward pass requires two loads of each $K$, $V$ block:

- Load a $K$ block into SRAM to recompute $P$ (via $QK^\top$). Load the matching $V$ block to compute $\bar{P}$ (via $\bar{O}V^\top$). Compute $D_i$ from the row of $\bar{P}$ and $P$. But $D_i$ needs the full row — so we must finish all blocks in this row before proceeding.
- Load each $K$, $V$ block into SRAM again to compute $\bar{S}$ and the gradient contributions.

$K$ and $V$ are loaded from HBM twice. The dependency on $D$ prevents fusing the recomputation of $P$ with the consumption of $\bar{P}$, because $\bar{S}$ requires $D$, which requires the full row of $\bar{P}$.
With the trick, $D$ is precomputed from quantities already in hand — no $\bar{P}$ needed. Now everything fuses into a single load per block:

- Load a $K$ block and $V$ block into SRAM
- Recompute the $P$ block (the $K$ block is in SRAM)
- Compute the $\bar{P}$ block (the $V$ block is in SRAM)
- Compute the $\bar{S}$ block (using precomputed $D$ — no waiting)
- Compute the $\bar{Q}$ contribution (the $K$ block is still in SRAM)
- Accumulate the $\bar{K}$ and $\bar{V}$ contributions
- Discard the block

$K$ and $V$ are loaded once instead of twice. The trick removes the last dependency that forced a second load, enabling the recomputation and gradient computation to be fused into a single pass over the data. This doubles the arithmetic intensity of the backward pass.
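The fused loop, end to end, as a NumPy sketch (block size and variable names are mine; the $1/\sqrt{d}$ scale, masking, and dropout are ignored; the "SRAM" is just local variables):

```python
import numpy as np

rng = np.random.default_rng(9)
N, d, Bc = 8, 4, 4                     # sequence length, head dim, K/V block size
Q = rng.standard_normal((N, d))
K = rng.standard_normal((N, d))
V = rng.standard_normal((N, d))
S = Q @ K.T
m = S.max(axis=1)                      # stored per-row statistics from forward
l = np.exp(S - m[:, None]).sum(axis=1)
P = np.exp(S - m[:, None]) / l[:, None]
O = P @ V
O_bar = rng.standard_normal((N, d))
D = (O_bar * O).sum(axis=1)            # the precomputed trick: rowsum(O_bar ⊙ O)

Q_bar = np.zeros((N, d)); K_bar = np.zeros((N, d)); V_bar = np.zeros((N, d))
for j0 in range(0, N, Bc):             # one load of each K/V block
    Kj, Vj = K[j0:j0+Bc], V[j0:j0+Bc]                 # load block into "SRAM"
    Pj = np.exp(Q @ Kj.T - m[:, None]) / l[:, None]   # recompute P block
    Pj_bar = O_bar @ Vj.T                             # P_bar block
    Sj_bar = Pj * (Pj_bar - D[:, None])               # S_bar block, no waiting
    Q_bar += Sj_bar @ Kj                              # Q_bar contribution
    K_bar[j0:j0+Bc] += Sj_bar.T @ Q                   # K_bar contribution
    V_bar[j0:j0+Bc] += Pj.T @ O_bar                   # V_bar contribution
    # block discarded here

S_bar_full = P * (O_bar @ V.T - D[:, None])
assert np.allclose(Q_bar, S_bar_full @ K)
assert np.allclose(K_bar, S_bar_full.T @ Q)
assert np.allclose(V_bar, P.T @ O_bar)
```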