Flash attention

Prologue

The Toy Model

The Machine

[Machine diagram: DRAM holds matrices A (a₀₀, a₀₁, a₁₀, a₁₁) and B (b₀₀, b₀₁, b₁₀, b₁₁); LOAD and STORE move values between DRAM and a register file (R0, R1, R2); a MAC compute unit computes a × b + c]

Instruction Set

Instruction set

| Instruction | Operation | Cycles |
|---|---|---|
| LOAD src, Rd | Rd = DRAM[src] | 1 |
| STORE Rs, dst | DRAM[dst] = Rs | 1 |
| MAC_Z R0, R1, Rd | Rd = R0 × R1 | 1 |
| MAC R0, R1, Ra, Rd | Rd = R0 × R1 + Ra | 1 |
Problem 1: Compute c₀₀ — write the instruction trace

$$A, B \in \mathbb{R}^{2 \times 2} \qquad C = A \times B \qquad c_{ij} = \sum_k a_{ik} \cdot b_{kj}$$

Write the instruction trace for $c_{00} = a_{00} b_{00} + a_{01} b_{10}$:

| Step | Instruction | R0 | R1 | R2 |
|---|---|---|---|---|
| 1 | LOAD a[0][0], R0 | $a_{00}$ | | |
| 2 | | | | |
| 3 | | | | |
| 4 | | | | |
| 5 | | | | |
| 6 | | | | |
| 7 | STORE R2, c[0][0] | | | $c_{00}$ |
Full trace for c₀₀

| Step | Instruction | R0 | R1 | R2 |
|---|---|---|---|---|
| 1 | LOAD a[0][0], R0 | $a_{00}$ | | |
| 2 | LOAD b[0][0], R1 | $a_{00}$ | $b_{00}$ | |
| 3 | MAC_Z R0, R1, R2 | $a_{00}$ | $b_{00}$ | $a_{00} b_{00}$ |
| 4 | LOAD a[0][1], R0 | $a_{01}$ | $b_{00}$ | $a_{00} b_{00}$ |
| 5 | LOAD b[1][0], R1 | $a_{01}$ | $b_{10}$ | $a_{00} b_{00}$ |
| 6 | MAC R0, R1, R2, R2 | $a_{01}$ | $b_{10}$ | $c_{00}$ |
| 7 | STORE R2, c[0][0] | $a_{01}$ | $b_{10}$ | $c_{00}$ |

Cost per output element: 4 LOADs + 2 MACs + 1 STORE = 7 cycles

The first MAC uses MAC_Z (multiply only, no accumulate) because there is no prior partial sum. The second uses MAC to accumulate into R2. The partial sum lives in R2 across both MACs — this is why we need three registers: two for operands (overwritten each iteration) and one for the running accumulator.
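To make the trace concrete, here is a minimal sketch of an interpreter for the toy ISA (the `run` helper and the string operand names are made up for this illustration; DRAM is modeled as a dict) that executes the seven instructions above and checks the result:

```python
# Minimal simulator for the toy machine (a sketch, not part of the text's ISA).
# DRAM is a dict from names like "a00" to floats; the register file is a dict.

def run(trace, dram):
    regs = {}
    for op, *args in trace:
        if op == "LOAD":          # LOAD src, Rd
            src, rd = args
            regs[rd] = dram[src]
        elif op == "STORE":       # STORE Rs, dst
            rs, dst = args
            dram[dst] = regs[rs]
        elif op == "MAC_Z":       # Rd = R0 * R1 (no accumulate)
            r0, r1, rd = args
            regs[rd] = regs[r0] * regs[r1]
        elif op == "MAC":         # Rd = R0 * R1 + Ra
            r0, r1, ra, rd = args
            regs[rd] = regs[r0] * regs[r1] + regs[ra]
    return dram

dram = {"a00": 1.0, "a01": 2.0, "b00": 3.0, "b10": 4.0}
trace = [
    ("LOAD", "a00", "R0"),
    ("LOAD", "b00", "R1"),
    ("MAC_Z", "R0", "R1", "R2"),
    ("LOAD", "a01", "R0"),
    ("LOAD", "b10", "R1"),
    ("MAC", "R0", "R1", "R2", "R2"),
    ("STORE", "R2", "c00"),
]
run(trace, dram)
assert dram["c00"] == 1.0 * 3.0 + 2.0 * 4.0  # 11.0
```

The seven entries in `trace` are exactly the seven rows of the table: 4 LOADs, 2 MACs, 1 STORE.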

Problem 2: Full matmul — Locality!

Now compute all four elements of $C$. Naturally $4 \times 7 = 28$ cycles … or less?

Hint: which elements share operands?

$c_{00}$ and $c_{01}$ both use row 0 of $A$. After finishing $c_{00}$, R0 holds $a_{01}$ — if you compute $c_{01}$ next, you can reuse it.

Discussion: minimum registers and modern GPUs

This is the smallest machine that can compute a matrix multiplication. The MAC unit needs two operand registers (R0, R1), and any dot product longer than one term needs a third register to hold the partial sum — because loading the next pair of operands overwrites R0 and R1, and there is no way to accumulate without a register (MAC takes a register, not a RAM address, as its addend). Three registers is the minimum.

In modern GPU architectures, the compute unit (e.g., a Tensor Core) sometimes has its own dedicated memory that it manages internally — a small, private buffer separate from the shared register file. In that case, the accumulator lives inside the compute unit itself, and the register file only needs to supply operands. This is a design choice in NVIDIA’s Blackwell architecture: by giving the Tensor Core its own accumulator storage, you reduce pressure on the register file — effectively shrinking the “register file” in our toy model back toward 2 registers, with the third hidden inside the MAC unit.

Enlarging the Register File

Problem 3: Outer product matmul — write the trace with 6 registers

$$C = \underbrace{\begin{bmatrix} a_{00} \\ a_{10} \end{bmatrix} \begin{bmatrix} b_{00} & b_{01} \end{bmatrix}}_{k=0} + \underbrace{\begin{bmatrix} a_{01} \\ a_{11} \end{bmatrix} \begin{bmatrix} b_{10} & b_{11} \end{bmatrix}}_{k=1}$$

We double the register file from 3 to 6 registers: how many cycles does the outer product approach take?

Hint: how to count LOADs in outer product

Each outer product $k$ loads one column of $A$ and one row of $B$, then performs 4 MACs (one per output element). For $k=0$, use MAC_Z (no prior accumulator). For $k=1$, use MAC to accumulate into the same registers. Count the LOADs carefully — can you reuse any operands within one outer product?

Full trace: outer product matmul

Outer product $k = 0$:

| Step | Instruction | R0 | R1 | R2 | R3 | R4 | R5 |
|---|---|---|---|---|---|---|---|
| 1 | LOAD a[0][0], R0 | $a_{00}$ | | | | | |
| 2 | LOAD b[0][0], R1 | $a_{00}$ | $b_{00}$ | | | | |
| 3 | MAC_Z R0, R1, R2 | $a_{00}$ | $b_{00}$ | $a_{00}b_{00}$ | | | |
| 4 | LOAD b[0][1], R1 | $a_{00}$ | $b_{01}$ | $a_{00}b_{00}$ | | | |
| 5 | MAC_Z R0, R1, R3 | $a_{00}$ | $b_{01}$ | $a_{00}b_{00}$ | $a_{00}b_{01}$ | | |
| 6 | LOAD a[1][0], R0 | $a_{10}$ | $b_{01}$ | $a_{00}b_{00}$ | $a_{00}b_{01}$ | | |
| 7 | MAC_Z R0, R1, R5 | $a_{10}$ | $b_{01}$ | $a_{00}b_{00}$ | $a_{00}b_{01}$ | | $a_{10}b_{01}$ |
| 8 | LOAD b[0][0], R1 | $a_{10}$ | $b_{00}$ | $a_{00}b_{00}$ | $a_{00}b_{01}$ | | $a_{10}b_{01}$ |
| 9 | MAC_Z R0, R1, R4 | $a_{10}$ | $b_{00}$ | $a_{00}b_{00}$ | $a_{00}b_{01}$ | $a_{10}b_{00}$ | $a_{10}b_{01}$ |

Outer product $k = 1$:

| Step | Instruction | R0 | R1 | R2 | R3 | R4 | R5 |
|---|---|---|---|---|---|---|---|
| 10 | LOAD a[0][1], R0 | $a_{01}$ | $b_{00}$ | $a_{00}b_{00}$ | $a_{00}b_{01}$ | $a_{10}b_{00}$ | $a_{10}b_{01}$ |
| 11 | LOAD b[1][0], R1 | $a_{01}$ | $b_{10}$ | $a_{00}b_{00}$ | $a_{00}b_{01}$ | $a_{10}b_{00}$ | $a_{10}b_{01}$ |
| 12 | MAC R0, R1, R2, R2 | $a_{01}$ | $b_{10}$ | $c_{00}$ | $a_{00}b_{01}$ | $a_{10}b_{00}$ | $a_{10}b_{01}$ |
| 13 | LOAD b[1][1], R1 | $a_{01}$ | $b_{11}$ | $c_{00}$ | $a_{00}b_{01}$ | $a_{10}b_{00}$ | $a_{10}b_{01}$ |
| 14 | MAC R0, R1, R3, R3 | $a_{01}$ | $b_{11}$ | $c_{00}$ | $c_{01}$ | $a_{10}b_{00}$ | $a_{10}b_{01}$ |
| 15 | LOAD a[1][1], R0 | $a_{11}$ | $b_{11}$ | $c_{00}$ | $c_{01}$ | $a_{10}b_{00}$ | $a_{10}b_{01}$ |
| 16 | MAC R0, R1, R5, R5 | $a_{11}$ | $b_{11}$ | $c_{00}$ | $c_{01}$ | $a_{10}b_{00}$ | $c_{11}$ |
| 17 | LOAD b[1][0], R1 | $a_{11}$ | $b_{10}$ | $c_{00}$ | $c_{01}$ | $a_{10}b_{00}$ | $c_{11}$ |
| 18 | MAC R0, R1, R4, R4 | $a_{11}$ | $b_{10}$ | $c_{00}$ | $c_{01}$ | $c_{10}$ | $c_{11}$ |

Store:

| Step | Instruction |
|---|---|
| 19 | STORE R2, c[0][0] |
| 20 | STORE R3, c[0][1] |
| 21 | STORE R4, c[1][0] |
| 22 | STORE R5, c[1][1] |

Cost: 10 LOADs + 8 MACs + 4 STOREs = 22 cycles with 6 registers

| | Inner product | Outer product |
|---|---|---|
| Registers | 3 | 6 |
| Cycles (full matmul) | 25 | 22 |
| LOADs | 13 | 10 |

More registers → each loaded value gets reused across multiple output elements before being replaced → fewer total LOADs.

Quiz: inner product with 6 registers

Adding three more registers also accelerates the inner product. With 6 registers, what is the best inner product scheme? How many cycles does it take, and where do the savings come from?

Softmax

The Machine

Machine diagram

[Machine diagram: DRAM holds the input x[0..n−1] and output out[0..n−1]; LOAD and STORE move values to a register file (R0, r1, r2); an ALU computes MAX, SUB, EXP, ADD, MUL, DIV]

Instruction Set

Instruction set

| Instruction | Operation |
|---|---|
| LOAD src, Rd | Rd = DRAM[src] |
| STORE Rs, dst | DRAM[dst] = Rs |
| MAX Ra, Rb, Rd | Rd = max(Ra, Rb) |
| SUB Ra, Rb, Rd | Rd = Ra - Rb |
| EXP Ra, Rd | Rd = exp(Ra) |
| ADD Ra, Rb, Rd | Rd = Ra + Rb |
| MUL Ra, Rb, Rd | Rd = Ra × Rb |
| DIV Ra, Rb, Rd | Rd = Ra / Rb |
| MOV Rs, Rd | Rd = Rs |

Safe Softmax

$$\text{softmax}(x_i) = \frac{e^{x_i - m}}{\sum_{j=0}^{n-1} e^{x_j - m}} \quad \text{where} \quad m = \max_j x_j$$

Problem 4: Safe softmax on the machine

With 3 registers, compute the safe softmax. How many DRAM reads and writes total?

Hint: Pass 1 — Get max

Find $m = \max(x_0, \ldots, x_{n-1})$.

| Instruction | R0 | r1 | r2 |
|---|---|---|---|
| LOAD x[0], r1 | | $x_0$ | |
| LOAD x[1], R0 | $x_1$ | $x_0$ | |
| MAX R0, r1, r1 | $x_1$ | $\max(x_0, x_1)$ | |
| LOAD x[2], R0 | $x_2$ | $\max(x_0, x_1)$ | |
| MAX R0, r1, r1 | $x_2$ | $\max(x_0, x_1, x_2)$ | |
| ⋮ | | | |
| LOAD x[n-1], R0 | $x_{n-1}$ | $\max(x_{0..n-2})$ | |
| MAX R0, r1, r1 | $x_{n-1}$ | $m$ | |

Cost: $n$ reads, 0 writes. r1 = $m$.

Hint: Pass 2 — Get sum of exponents

Compute $S = \sum e^{x_i - m}$ using the known max. The exponent is computed and immediately consumed — never written to DRAM.

| Instruction | R0 | r1 | r2 |
|---|---|---|---|
| (r1 = m from pass 1) | | $m$ | |
| LOAD x[0], R0 | $x_0$ | $m$ | |
| SUB R0, r1, R0 | $x_0 - m$ | $m$ | |
| EXP R0, R0 | $e^{x_0 - m}$ | $m$ | |
| MOV R0, r2 | $e^{x_0 - m}$ | $m$ | $e^{x_0 - m}$ |
| LOAD x[1], R0 | $x_1$ | $m$ | $e^{x_0 - m}$ |
| SUB R0, r1, R0 | $x_1 - m$ | $m$ | $e^{x_0 - m}$ |
| EXP R0, R0 | $e^{x_1 - m}$ | $m$ | $e^{x_0 - m}$ |
| ADD r2, R0, r2 | $e^{x_1 - m}$ | $m$ | $e^{x_0-m} + e^{x_1-m}$ |
| ⋮ | | $m$ | |

Cost: $n$ reads, 0 writes. r1 = $m$, r2 = $S$.

Hint: Pass 3 — Normalize

Compute $\text{softmax}(x_i) = e^{x_i - m} / S$. We recompute $e^{x_i - m}$ (since it was never stored) and divide by $S$.

| Instruction | R0 | r1 | r2 |
|---|---|---|---|
| (r1 = m, r2 = S from pass 2) | | $m$ | $S$ |
| LOAD x[0], R0 | $x_0$ | $m$ | $S$ |
| SUB R0, r1, R0 | $x_0 - m$ | $m$ | $S$ |
| EXP R0, R0 | $e^{x_0-m}$ | $m$ | $S$ |
| DIV R0, r2, R0 | $e^{x_0-m}/S$ | $m$ | $S$ |
| STORE R0, out[0] | $e^{x_0-m}/S$ | $m$ | $S$ |
| LOAD x[1], R0 | $x_1$ | $m$ | $S$ |
| SUB R0, r1, R0 | $x_1 - m$ | $m$ | $S$ |
| EXP R0, R0 | $e^{x_1-m}$ | $m$ | $S$ |
| DIV R0, r2, R0 | $e^{x_1-m}/S$ | $m$ | $S$ |
| STORE R0, out[1] | $e^{x_1-m}/S$ | $m$ | $S$ |
| ⋮ | | $m$ | $S$ |

Cost: $n$ reads, $n$ writes. The exponent $e^{x_i - m}$ is recomputed from $x_i$ and $m$ — we trade extra compute for avoiding DRAM storage of the intermediate array.

Hint: Total cost

| Pass | DRAM reads | DRAM writes |
|---|---|---|
| 1. Get max | $n$ | 0 |
| 2. Get sum of exponents | $n$ | 0 |
| 3. Normalize | $n$ | $n$ |
| Total | $3n$ | $n$ |

No intermediate array $e_0, \ldots, e_{n-1}$ in DRAM — the exponents are computed on the fly in both pass 2 and pass 3.
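The three passes translate directly into a short sketch. Here the list `x` plays the role of DRAM, and every element access is tallied so the $3n$ reads and $n$ writes can be checked (the counting scaffolding is ours, not part of the machine):

```python
# Three-pass safe softmax, with DRAM traffic counted explicitly.
import math

def safe_softmax(x):
    reads = writes = 0
    # Pass 1: global max
    m = x[0]; reads += 1
    for v in x[1:]:
        m = max(m, v); reads += 1
    # Pass 2: sum of shifted exponents (each exponent is consumed, never stored)
    s = 0.0
    for v in x:
        s += math.exp(v - m); reads += 1
    # Pass 3: recompute each exponent and normalize
    out = []
    for v in x:
        out.append(math.exp(v - m) / s); reads += 1; writes += 1
    return out, reads, writes

x = [1.0, 2.0, 3.0, 4.0]
out, reads, writes = safe_softmax(x)
assert reads == 3 * len(x) and writes == len(x)
assert abs(sum(out) - 1.0) < 1e-12
```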

One More Register

Two Paths to the Same Sum

Two paths to the same sum

We want $S = \sum_{i=0}^{n-1} e^{x_i - m}$ where $m = \max_j x_j$. Let $m_k = \max(x_0, \ldots, x_{k-1})$ be the running max after seeing $k$ elements.

Path 1: Global max first (two passes). Compute $m = m_n$ in one pass. Then define $s_k = \sum_{j=0}^{k-1} e^{x_j - m}$ — the partial sum using the known global max. The recurrence is:

$$s_k = s_{k-1} + e^{x_{k-1} - m}$$

Path 2: Running max (single pass). Define $s_k = \sum_{j=0}^{k-1} e^{x_j - m_k}$ — the partial sum using the running max $m_k$. Update both together:

$$m_k = \max(m_{k-1},\; x_{k-1}) \qquad s_k = s_{k-1} \cdot e^{m_{k-1} - m_k} + e^{x_{k-1} - m_k}$$

The rescaling factor $e^{m_{k-1} - m_k}$ corrects all previously accumulated terms for the new max. Since $m_k \geq m_{k-1}$, this factor is always $\leq 1$ — no overflow.

At $k = n$: $m_n = m$, so both definitions of $s_k$ agree: $s_n = S$.

This is the online softmax algorithm (Milakov & Gimelshein, 2018).
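Path 2 is a few lines of code. This sketch maintains the $(m, s)$ state exactly as in the recurrence and checks it against the two-pass answer:

```python
# Single-pass online softmax denominator (Path 2): maintain (m, s) and
# rescale s whenever the running max changes.
import math

def online_softmax_denominator(x):
    m = float("-inf")  # running max m_k
    s = 0.0            # running sum s_k, always relative to the current m
    for v in x:
        m_new = max(m, v)
        s = s * math.exp(m - m_new) + math.exp(v - m_new)
        m = m_new
    return m, s

x = [1.0, 3.0, 2.0, 5.0]
m, s = online_softmax_denominator(x)
assert m == 5.0
expected = sum(math.exp(v - m) for v in x)  # Path 1: global max first
assert abs(s - expected) < 1e-12
```

Note the first iteration works for free: `math.exp(-inf)` is 0.0, so the rescaling of the empty sum is a no-op.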

Numerical safety

Writing the rescaling as $e^{m_{k-1} - m_k}$ is safe because the exponent is $\leq 0$. Writing the equivalent $\frac{e^{m_{k-1}}}{e^{m_k}}$ is not — $e^{m_{k-1}}$ alone can overflow before the division happens. Same math, different numerical behavior. The choice of equivalent representation matters.
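A quick numerical check of this point: in float64, `exp` overflows around 709, so the two algebraically equal forms diverge once the maxima are large:

```python
# The two algebraically equal rescaling factors behave differently in float64.
import numpy as np

m_prev, m_new = np.float64(800.0), np.float64(900.0)

safe = np.exp(m_prev - m_new)  # exponent is -100: a tiny but exact value
with np.errstate(over="ignore", invalid="ignore"):
    unsafe = np.exp(m_prev) / np.exp(m_new)  # inf / inf -> nan

assert np.isfinite(safe) and safe <= 1.0
assert np.isnan(unsafe)
```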

Nothing new under the sun?

The trick you just saw — maintaining a running state and correcting it when a reference point changes — sits at the intersection of statistics (the mean), numerical analysis (overflow avoidance), and algorithms (online computation). No single course teaches it, because each considers it someone else’s job, or too simple to bother with.

Architecture courses assume you know the math. Numerical analysis courses assume you know the systems. Algorithm courses assume the problem is abstract and the numbers are exact. The student is left holding pieces from three different boxes, with nobody having shown them that the pieces fit together.

Online Softmax with Four Registers

Problem 5: Write the instruction trace for online softmax with 4 registers

Add one register to the machine (4 total: R0, r1, r2, r3). Using the online softmax recurrence from Path 2, write the instruction trace for the steady-state step — processing one new element $x_i$ and updating the running state $(m, s)$.

How many DRAM reads and writes does the full softmax take with 4 registers?

Hint: full instruction trace

Save old $m$ into r3 before updating r1 with the new max — the rescaling needs both.

Initialization (first element):

| Instruction | R0 | r1 | r2 | r3 |
|---|---|---|---|---|
| LOAD x[0], R0 | $x_0$ | | | |
| MOV R0, r1 | $x_0$ | $x_0$ | | |
| SUB R0, r1, R0 | $0$ | $x_0$ | | |
| EXP R0, R0 | $1$ | $x_0$ | | |
| MOV R0, r2 | $1$ | $x_0$ | $1$ | |

After init: r1 = $m = x_0$, r2 = $s = 1$.

Steady state (each subsequent element $x_i$):

| Step | Instruction | R0 | r1 | r2 | r3 |
|---|---|---|---|---|---|
| 1 | LOAD x[i], R0 | $x_i$ | $m$ | $s$ | |
| 2 | MOV r1, r3 | $x_i$ | $m$ | $s$ | $m$ |
| 3 | MAX r1, R0, r1 | $x_i$ | $m'$ | $s$ | $m$ |
| 4 | SUB r3, r1, r3 | $x_i$ | $m'$ | $s$ | $m - m'$ |
| 5 | SUB R0, r1, R0 | $x_i - m'$ | $m'$ | $s$ | $m - m'$ |
| 6 | EXP r3, r3 | $x_i - m'$ | $m'$ | $s$ | $e^{m - m'}$ |
| 7 | EXP R0, R0 | $e^{x_i - m'}$ | $m'$ | $s$ | $e^{m - m'}$ |
| 8 | MUL r2, r3, r2 | $e^{x_i - m'}$ | $m'$ | $s \cdot e^{m - m'}$ | $e^{m - m'}$ |
| 9 | ADD r2, R0, r2 | $e^{x_i - m'}$ | $m'$ | $s'$ | $e^{m - m'}$ |
Trace explanation and cost analysis

After each iteration: r1 = $m'$, r2 = $s'$. After processing all $n$ values: r1 = $m$, r2 = $S$.

Step 2 saves the old $m$ into r3 before step 3 overwrites r1 with $m'$. This is necessary because the rescaling in step 4 needs both the old and new max. After that, r3 is reused for the rescaling factor, and R0 is reused for the new term — neither $x_i$ nor the old $m$ is needed after step 5.

But we’re not done. This pass only computes $m$ and $S$ — the denominator. To produce the actual softmax values, we still need a second pass — same as pass 3 in the 3-register case.

Total cost (4 registers, 2 passes):

| Pass | Operation | DRAM reads | DRAM writes |
|---|---|---|---|
| 1. Online softmax | compute $m$ and $S$ | $n$ | 0 |
| 2. Normalize | compute $e^{x_i - m} / S$ for all $i$ | $n$ | $n$ |
| Total | | $2n$ | $n$ |

Cost Comparison

3 registers vs 4 registers

| Property | 3 registers (3 passes) | 4 registers (2 passes) |
|---|---|---|
| Passes | 3 (max + sum + normalize) | 2 (online + normalize) |
| DRAM reads | $3n$ | $2n$ |
| DRAM writes | $n$ | $n$ |
| Compute per element | simple (SUB, EXP, ADD) | more (MAX, 2×SUB, 2×EXP, MUL, ADD) |

One extra register saves one pass and $n$ DRAM reads by merging max-finding into the sum accumulation. The cost is more compute per step, but for any nontrivial $n$, the DRAM savings dominate.

Intuition and Exploration

Connections to other online algorithms

Online softmax replaces a global dependency (the max $m$) with a running state that self-corrects. The same structural pattern appears in other settings:

  • Cauchy sequences: convergence defined by terms getting close to each other (local), not to a known limit (global)
  • Abel summation: individual terms rewritten as partial sums + differences — the accumulated quantity becomes primary
  • Kahan summation: a compensation variable tracks rounding error, correcting the running sum at each step
  • Welford’s online variance: a running $(n, \bar{x}, M_2)$ state corrects for a shifting mean, just as online softmax corrects for a shifting max

See the mathematical foundations appendix for full definitions, derivations, and references.

Problem 6: Other scenarios for the running trick

Online softmax maintains a running state and corrects it when a reference point changes. Can you think of other scenarios — from any domain — where the same pattern applies?

Quiz: Online variance (Welford's algorithm)

The variance of $n$ values requires the global mean $\bar{x}$ — a global dependency analogous to softmax’s need for the global max $m$. Using the same principles (maintain a running state, apply corrections when the reference point changes), design an online algorithm that computes the variance in a single pass. What running state do you maintain? What is the correction step when a new element arrives?

Solution

Definitions:

$$\delta = x - \bar{x} \qquad \bar{x} \leftarrow \bar{x} + \frac{\delta}{n} = \bar{x} + \frac{x - \bar{x}}{n} \qquad S^2 = \frac{\sum_{i=1}^{N}(x_i - \bar{x})^2}{N - 1}$$

Derivation. Compute $(N-1)S_N^2 - (N-2)S_{N-1}^2$:

$$\begin{aligned}
&= \sum_{i=1}^{N}(x_i - \bar{x}_N)^2 - \sum_{i=1}^{N-1}(x_i - \bar{x}_{N-1})^2 \\
&= (x_N - \bar{x}_N)^2 + \sum_{i=1}^{N-1}\left[(x_i - \bar{x}_N)^2 - (x_i - \bar{x}_{N-1})^2\right] \\
&= (x_N - \bar{x}_N)^2 + \sum_{i=1}^{N-1}(2x_i - \bar{x}_N - \bar{x}_{N-1})(\bar{x}_{N-1} - \bar{x}_N)
\end{aligned}$$

Using $\sum_{i=1}^{N-1} x_i = N\bar{x}_N - x_N$ and $(N-1)\bar{x}_{N-1} = N\bar{x}_N - x_N$:

$$\sum_{i=1}^{N-1}(2x_i - \bar{x}_N - \bar{x}_{N-1}) = 2(N\bar{x}_N - x_N) - (N-1)\bar{x}_N - (N\bar{x}_N - x_N) = \bar{x}_N - x_N$$

So:

$$\begin{aligned}
&= (x_N - \bar{x}_N)^2 + (\bar{x}_N - x_N)(\bar{x}_{N-1} - \bar{x}_N) \\
&= (x_N - \bar{x}_N)\left[(x_N - \bar{x}_N) - (\bar{x}_{N-1} - \bar{x}_N)\right] \\
&= (x_N - \bar{x}_N)(x_N - \bar{x}_{N-1})
\end{aligned}$$

Result:

$$(N-1)S_N^2 = (N-2)S_{N-1}^2 + (x_N - \bar{x}_N)(x_N - \bar{x}_{N-1})$$
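The recurrence above is exactly Welford’s update. A sketch, checked against the ordinary two-pass computation:

```python
# Welford's online variance: keep (n, mean, M2), where M2 is the running
# sum of squared deviations from the current mean.

def welford(xs):
    n, mean, m2 = 0, 0.0, 0.0
    for x in xs:
        n += 1
        delta = x - mean          # x_N - mean_{N-1}
        mean += delta / n         # mean_N = mean_{N-1} + delta / N
        m2 += delta * (x - mean)  # M2 += (x_N - mean_{N-1}) * (x_N - mean_N)
    return mean, m2 / (n - 1)     # sample variance S^2

xs = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
mean, var = welford(xs)

# Two-pass reference: global mean first, then squared deviations.
ref_mean = sum(xs) / len(xs)
ref_var = sum((x - ref_mean) ** 2 for x in xs) / (len(xs) - 1)
assert abs(mean - ref_mean) < 1e-12
assert abs(var - ref_var) < 1e-12
```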

From Softmax to Attention

The Connection

Attention as softmax + matrix multiply

The scaled dot-product attention is:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V$$

The scaling $\frac{1}{\sqrt{d}}$ is a precomputed constant — we fold it into $Q$ ahead of time and write:

$$O = \text{softmax}(QK^T) \, V$$

Let $Q, K, V \in \mathbb{R}^{n \times d}$ where $n$ is the sequence length and $d$ is the head dimension. The naive approach computes this in three steps:

  1. Compute $S = QK^T$ — an $n \times n$ matrix of scores
  2. Apply row-wise softmax: $P = \text{softmax}(S)$ — an $n \times n$ attention matrix
  3. Multiply: $O = PV$ — the $n \times d$ output

The written expression $\text{softmax}(QK^T) \cdot V$ suggests we must first compute the softmax, materialize the full $n \times n$ attention matrix $P$, and then multiply by $V$. But do we actually need to?

Tracing the Dependencies

What does a single output row depend on?

What does a single output row $\mathbf{o}_i$ actually depend on? Decomposing by rows:

$$\mathbf{o}_i = \mathbf{p}_i \, V = \sum_j p_{ij} \, \mathbf{v}_j$$

where $p_{ij} = \frac{e^{x_j - m_i}}{\sum_k e^{x_k - m_i}}$, with $x_j = \mathbf{q}_i \cdot \mathbf{k}_j$ and $m_i = \max_j x_j$.

Each row’s softmax is independent — row $i$ depends only on $\mathbf{q}_i$ and all of $K$. This connects directly to our online softmax: the softmax denominator for row $i$ is exactly the sum $S$ we computed in the previous section.

Expanding the output:

$$\mathbf{o}_i = \frac{\sum_j e^{x_j - m_i} \, \mathbf{v}_j}{\sum_k e^{x_k - m_i}}$$

This has the same structure as online softmax, except each term in the numerator carries a vector $\mathbf{v}_j$ instead of a scalar. The denominator is exactly the $S$ from before. The numerator is a weighted sum of value vectors, with the same exponential weights.

Applying Online Softmax

Problem 7: Derive the update rules for attention

The output $\mathbf{o}_i$ is a ratio of a vector numerator and a scalar denominator. Both are weighted sums with the same exponential weights we saw in online softmax. Using the same running-max-and-rescale trick from Path 2, derive the update rules for the numerator and denominator when a new element $x_k$ (with value vector $\mathbf{v}_k$) arrives. What is the running state?

Hint: the rescaling factor is scalar

The rescaling factor $e^{m_k - m_{k+1}}$ is a scalar — it distributes over the vector numerator. The denominator update is identical to online softmax.

Full update rules and running state

$$\begin{aligned}
m_{k+1} &= \max(m_k, \; x_k) \\
\text{numerator}_{k+1} &= \text{numerator}_k \cdot e^{m_k - m_{k+1}} + e^{x_k - m_{k+1}} \, \mathbf{v}_k \\
\text{denominator}_{k+1} &= \text{denominator}_k \cdot e^{m_k - m_{k+1}} + e^{x_k - m_{k+1}}
\end{aligned}$$

After processing all $n$ keys:

$$\mathbf{o}_i = \frac{\text{numerator}_n}{\text{denominator}_n}$$

The running state at step $k$ is:

| State | Shape | Description |
|---|---|---|
| $m_k$ | scalar | running max |
| $\text{denominator}_k$ | scalar | $\sum_{j=0}^{k-1} e^{x_j - m_k}$ |
| $\text{numerator}_k$ | vector ($d$) | $\sum_{j=0}^{k-1} e^{x_j - m_k} \, \mathbf{v}_j$ |

This is $O(d)$ storage per row — compared to $O(n)$ for a full row of the attention matrix, or $O(n^2)$ for the full matrix.
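The update rules can be sketched for one query row (numpy only; `attention_row` is our illustrative name, and the reference computation materializes the full softmax row for comparison):

```python
# One-pass attention for a single query row, using the running
# (m, denominator, numerator) state. Shapes: q is (d,), K and V are (n, d).
import numpy as np

def attention_row(q, K, V):
    m = -np.inf                # running max of scores
    denom = 0.0                # running denominator
    num = np.zeros(V.shape[1]) # running numerator (a length-d vector)
    for k_j, v_j in zip(K, V):
        x = q @ k_j                   # attention score for this key
        m_new = max(m, x)
        scale = np.exp(m - m_new)     # scalar rescaling factor
        w = np.exp(x - m_new)         # weight of the new term
        num = num * scale + w * v_j
        denom = denom * scale + w
        m = m_new
    return num / denom

rng = np.random.default_rng(0)
n, d = 16, 4
q, K, V = rng.normal(size=d), rng.normal(size=(n, d)), rng.normal(size=(n, d))

# Reference: materialize the full softmax row, then contract with V.
scores = K @ q
p = np.exp(scores - scores.max())
p /= p.sum()
assert np.allclose(attention_row(q, K, V), p @ V)
```

The score `x` is consumed immediately; at no point does the loop hold more than the $O(d)$ running state.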

From Two Passes to One

Fusing softmax with V eliminates a pass

Recall that softmax alone required two passes even with online softmax: one pass to compute $m$ and $S$, and a second pass to produce the actual softmax values $e^{x_i - m}/S$ for each $i$. The second pass existed because we needed to output $n$ individual values — each one requires reading $x_i$ again.

But in attention, we don’t need the individual softmax values. We only need their weighted sum with $V$. The running numerator already accumulates this weighted sum during the first pass. After processing all $n$ elements, the output is a single division:

$$\mathbf{o}_i = \frac{\text{numerator}_n}{\text{denominator}_n}$$

No second pass. Each softmax value $p_{ij}$ is produced, multiplied by $\mathbf{v}_j$, accumulated into the numerator, and discarded — it never needs to exist as an individual value.

| | Passes | Why |
|---|---|---|
| Softmax alone (4 registers) | 2 | must output each $e^{x_i - m}/S$ individually → second pass to read $x_i$ again |
| Attention (softmax fused with $V$) | 1 | only need the weighted sum → final division at the end, no second pass |

By connecting softmax to the multiplication by $V$, the normalize pass disappears entirely.

The Attention Matrix Is Never Materialized

P never exists in memory

Each attention score $x_j = \mathbf{q}_i \cdot \mathbf{k}_j$ is computed, immediately consumed into the running numerator and denominator, and discarded. The $n \times n$ attention matrix $P$ never needs to exist in memory.

Intermediate Elimination

The general pattern: fuse production and consumption

The pattern here is general: a large intermediate ($P$, size $n \times n$) is produced only to be immediately consumed by the next operation (multiplication by $V$). Because each element of $P$ is used exactly once in a structured reduction (a weighted sum over the rows of $V$), we can fuse production and consumption — compute each $p_{ij}$, multiply by $\mathbf{v}_j$, accumulate, and discard.

This same principle appears in many domains:

  • Kernel fusion (GPU computing): avoid writing intermediates to HBM between operations
  • Deforestation (functional programming): eliminate intermediate data structures when producer and consumer can be fused
  • Loop fusion (compilers): merge loops that produce and consume the same array
  • Our register file model: the attention score goes into a register, gets consumed into the running state, and the register is immediately reused — it never touches DRAM

The condition: if an intermediate is only ever used as part of a contraction/reduction, it does not need to exist as a full object. The softmax-then-multiply pattern in attention satisfies this exactly.

Element is Tile

The Tile Abstraction

Every element can be a tile — zero structural adjustment

Everything so far is described in terms of individual scalar elements $x_j$. But nothing in the formulation requires this — every element can be a tile, with zero structural adjustment.

Split a row of length $n$ into tiles of size $B$. Tile $t$ covers elements $x[tB : (t+1)B]$. Within each tile, we compute a local state:

  • $m^{(t)} = \max(x[tB : (t+1)B])$ — local tile max
  • $d^{(t)} = \sum_{j=tB}^{(t+1)B-1} e^{x_j - m^{(t)}}$ — local tile denominator
  • $\mathbf{num}^{(t)} = \sum_{j=tB}^{(t+1)B-1} e^{x_j - m^{(t)}} \, \mathbf{v}_j$ — local tile numerator

To merge two tiles:

$$\begin{aligned}
m' &= \max(m^{(t)},\; m^{(t+1)}) \\
d' &= d^{(t)} \cdot e^{m^{(t)} - m'} + d^{(t+1)} \cdot e^{m^{(t+1)} - m'} \\
\mathbf{num}' &= \mathbf{num}^{(t)} \cdot e^{m^{(t)} - m'} + \mathbf{num}^{(t+1)} \cdot e^{m^{(t+1)} - m'}
\end{aligned}$$

This is exactly the same merge operation as the element-wise online softmax. The rescaling factor is still a scalar, it still distributes over vectors, and the merge is still associative. The formulation is scale-free — the merge only operates on the state $(m, d, \mathbf{num})$ and does not care about what is inside each tile.
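The merge is small enough to check directly. This sketch (our `tile_state`/`merge` names, numpy only) computes the local state of two tiles, merges them, and verifies the result equals the state of the whole row:

```python
# Tile-state merge for online softmax-with-values: each tile carries
# (m, d, num); merging rescales both partial sums by scalar factors.
import numpy as np

def tile_state(x, V):
    m = x.max()
    w = np.exp(x - m)
    return m, w.sum(), w @ V          # (local max, denominator, numerator)

def merge(s1, s2):
    (m1, d1, n1), (m2, d2, n2) = s1, s2
    m = max(m1, m2)
    a, b = np.exp(m1 - m), np.exp(m2 - m)  # scalar rescaling factors
    return m, d1 * a + d2 * b, n1 * a + n2 * b

rng = np.random.default_rng(1)
n, d, B = 8, 3, 4
x, V = rng.normal(size=n), rng.normal(size=(n, d))

merged = merge(tile_state(x[:B], V[:B]), tile_state(x[B:], V[B:]))
whole = tile_state(x, V)
assert np.isclose(merged[0], whole[0])
assert np.isclose(merged[1], whole[1])
assert np.allclose(merged[2], whole[2])
```

Because the merge is associative, the tiles can be combined in any grouping — which is what lets the same formula run at any granularity.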

This is why the toy model was worth building. The scalar version was not a simplified analogy — it is the algorithm, at a different granularity. Going from scalar to tiled is a change of mindset, not a change of structure.

Recall from our toy model that we already introduced this idea for matrix multiplication: the same instruction set works at both the scalar level (MAC on individual floats) and the tiled level (Tensor Core on sub-matrices). The same principle applies here.

At the tiled level, our toy model maps directly to a GPU:

| Toy model | GPU |
|---|---|
| Register file | SRAM (shared memory) |
| DRAM | HBM |
| LOAD / STORE | data movement between HBM and SRAM |
| Compute unit (MAC / ALU) | Tensor Core |
The register file in our model is the fast, small memory close to compute — that is SRAM on a GPU. The DRAM in our model is the large, slow memory — that is HBM. Every cost tradeoff we analyzed (fewer passes, fewer DRAM reads, more registers) translates directly: fewer HBM accesses, more SRAM usage.

Tiling Direction

Tiling across the key/value dimension

We are tiling across a row of the attention matrix — over the key/value dimension. Each tile processes a block of keys $K[tB:(t{+}1)B]$ and the corresponding values $V[tB:(t{+}1)B]$, while keeping one query $\mathbf{q}_i$ fixed. This directly determines how $K$ and $V$ are blocked in memory.

Inner Product vs. Outer Product in Attention

Two loop orderings for attention

In our toy model, we saw two ways to compute a matrix multiplication: inner product (compute one output element fully, then move to the next) and outer product (load one pair of operands, update all output elements at once). The same choice appears in attention.

If we view the attention computation $O = \text{softmax}(QK^T) \, V$ as conceptually a matrix operation, we can choose which dimension to iterate over in the outer loop:

Inner product style: fix a query, stream through all KV

For each query $\mathbf{q}_i$, stream through all key-value pairs. The running state for one row — $m$ (scalar), denominator (scalar), numerator (vector of size $d$) — stays in fast memory. After processing all KV pairs, $\mathbf{o}_i$ is complete. Move to the next query.

  • Running state in fast memory: $O(d)$ per query
  • K and V are reloaded for every query

Outer product style: fix a KV block, update all queries

Load a block of K and V into fast memory. Then stream through queries one at a time, updating each query’s running state with this KV block. After processing all queries, load the next KV block and repeat.

In the element-is-tile view, each “query” is one element (or one tile). Its running state — $m$, denominator, numerator — is $O(d)$, same as the inner product style. The difference is not how much state a single query needs, but how many queries’ outputs must be stored to HBM during the process.

If we can keep only one query’s output in fast memory, we load/store each query’s output once per KV element — $n$ stores per query. Each additional output slot we keep in fast memory saves $n - 1$ stores for that entry, because it can stay resident across all KV elements. At the extreme, if all $n$ outputs fit, each is stored once — and we’ve recovered the inner product schedule.

  • Running state per query: $O(d)$ — same as inner product
  • K and V are loaded once
  • Output stores per query: $n$ (once per KV element) — unless kept in fast memory across elements
  • All query updates within one KV element are independent → parallelizable

Cost Comparison

HBM traffic: inner product vs outer product

| | Inner product | Outer product |
|---|---|---|
| Outer loop | over queries | over KV |
| Inner loop | over KV | over queries |
| Q loads | $n$ (once per query) | $n^2$ (every query reloaded per KV) |
| KV loads | $n^2$ (all KV reloaded per query) | $n$ (once per KV) |
| O loads | $n$ (initialize once per query) | $n^2$ (reloaded per KV) |
| O stores | $n$ (once per query) | $n^2$ (stored per KV) |

The cost is not about fast memory capacity per query — it’s the same in both cases. The cost is about which data gets reloaded from HBM: KV (inner product) or output (outer product). Each additional output slot in fast memory saves $n - 1$ stores, moving along the spectrum from outer product toward inner product.

This is the flash attention loop ordering question:

  • Flash Attention v2 uses the inner product style (outer loop over queries, inner loop over KV). Each query’s output stays in SRAM across all KV blocks — stored once. KV is reloaded for every query block.
  • Flash Attention v1 uses the outer product style (outer loop over KV, inner loop over queries). A KV block stays in SRAM. Query outputs are loaded/stored from HBM for each KV block — $T_c$ stores per query, where $T_c$ is the number of KV blocks. More HBM traffic for outputs, but the inner loop over query blocks is independent and can be parallelized across GPU thread blocks.

The same tradeoff we saw in the $2 \times 2$ matmul — 3 registers with inner product vs. 6 registers with outer product — plays out here at the scale of SRAM and HBM. But at the tiled level, it’s a continuous spectrum: the more output you keep in SRAM, the fewer HBM round-trips for outputs, at the cost of less SRAM available for KV blocks.
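The traffic table above can be restated as a toy counter (a sketch in our own notation, not a kernel: it counts block transfers assuming a single output slot in fast memory for the outer product ordering and no other caching):

```python
# Toy HBM-traffic counter for the two loop orderings over n queries and
# n KV elements.

def inner_product_traffic(n):
    # Outer loop over queries, inner loop over KV.
    q_loads = n        # each query loaded once, stays resident
    kv_loads = n * n   # all KV streamed per query
    o_stores = n       # each output stored once, after its KV loop finishes
    return q_loads, kv_loads, o_stores

def outer_product_traffic(n):
    # Outer loop over KV, inner loop over queries (one output slot in fast memory).
    kv_loads = n       # each KV element loaded once, stays resident
    q_loads = n * n    # every query revisited per KV element
    o_stores = n * n   # every partial output written back per KV element
    return q_loads, kv_loads, o_stores

n = 64
assert inner_product_traffic(n) == (n, n * n, n)
assert outer_product_traffic(n) == (n * n, n, n * n)
```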

Quizzes

Arithmetic Intensity: Inner Product vs. Outer Product

Arithmetic Intensity: Flash Attention

Backward Pass

This section has not been validated and polished. The content below is a draft.

Backward pass (draft)

Why Study the Backward Pass?

In PyTorch, requires_grad=True tells the autograd engine to preserve the intermediate states of the forward pass — so they can be reused when computing gradients. During inference we skip this; during training we keep everything.

But flash attention never materializes the $n \times n$ attention matrix $P$. If the intermediate doesn’t exist, what do we differentiate through?

The answer is activation checkpointing: rerun the forward pass to regenerate the intermediates on demand. And here the structure of flash attention pays off twice. The first time through, everything is new — we use the online softmax trick to avoid materializing $P$. The second time through, things are cheaper: we already have the row-wise maxima $m_i$ and denominators $d_i$ stored from the first pass. We don’t recompute them. The computational cost of the recomputation drops significantly.

So is that enough? Run activation checkpointing, recompute $P$ block by block during the backward pass, and we’re done?

No. The backward pass itself involves $n \times n$ intermediate matrices — $dP$ and $dS$ — that are just as large as $P$. If we materialize those, we’ve solved nothing. We need to avoid materializing any $n \times n$ matrix during the backward pass too. That is why we need to study the details of what happens in the backward pass — not just that gradients flow through softmax, but how to compute them without ever building the full attention-sized matrices.

But before we dive in — let’s pause and make sure we understand what exactly allowed us to avoid materializing the attention matrix in the forward pass.

Quiz: What is the key ingredient that allows us to avoid materializing the attention matrix in the forward pass? It’s not online softmax — that handles the denominator. What handles the matrix itself?

The answer is contraction: $P$ only ever appears inside the sum $\sum_j P_{ij} \mathbf{v}_j$. Each element is produced, multiplied by $\mathbf{v}_j$, accumulated into the output, and discarded. An intermediate that feeds directly into a linear contraction never needs to be fully materialized.

Now consider the backward pass. The $n \times n$ intermediates $dP$ and $dS$ play the same role that $P$ played in the forward pass — they are large intermediates we want to eliminate. In the forward pass, we could merge all the intermediate steps because $P$ flowed directly into a linear contraction with $V$. Can we do the same here?

The chain is: $dP \to dS \to dQ, dK$. If we can contract $dS$ directly with the next step ($dQ = dS \cdot K$, $dK = dS^T \cdot Q$), then $dS$ does not need to be fully materialized — we need to discover the same contraction pattern again. But the step from $dP$ to $dS$ passes through the softmax Jacobian — a nonlinear derivative. This is the critical step that makes the backward pass harder than the forward pass. We will see a clever trick (the $D_i$ identity) that handles this step without ever materializing the full $n \times n$ matrices.

The Language of Tensors

A matrix is a $(1,1)$ tensor — one upper index (row), one lower index (column). The matrix product $C = AB$ can be written component-wise:

$$c_{ij} = \sum_k a_{ik} \, b_{kj}$$

where ii is the row, jj is the column, and kk is summed over. We can also write this in Einstein notation, where a repeated index — one upper, one lower — implies summation:

Cij=AikBkjC^i{}_j = A^i{}_k \, B^k{}_j

No \sum symbol needed. The repeated kk (lower in AA, upper in BB) tells you to sum over it. This is the same convention used by torch.einsum.
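As a quick check of the convention, numpy's `einsum` (the same subscript string works with `torch.einsum`) reproduces the matrix product: the repeated `k` in the string is summed, exactly as in the index notation above.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 5))

# 'ik,kj->ij': the repeated k is contracted, as in C^i_j = A^i_k B^k_j
C = np.einsum('ik,kj->ij', A, B)
assert np.allclose(C, A @ B)
```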

Now consider the derivative of a matrix with respect to another matrix. If CRm×nC \in \mathbb{R}^{m \times n} depends on ARp×qA \in \mathbb{R}^{p \times q}, we need to ask: how does each entry cijc_{ij} change when we perturb each entry akla_{kl}? That requires four indices — two for the output (i,ji, j) and two for the input (k,lk, l):

CijAkl\frac{\partial C^i{}_j}{\partial A^k{}_l}

This is a (2,2)(2,2) tensor with mnpqm \cdot n \cdot p \cdot q entries — a Jacobian, but organized as a 4-dimensional object rather than a flattened matrix.

The simplest case: the derivative of a matrix with respect to itself. Each component AisA^i{}_s is an independent variable, so its derivative with respect to AklA^k{}_l is 1 when i=ki = k and s=ls = l, and 0 otherwise. In Kronecker delta notation:

AisAkl=δkiδsl\frac{\partial A^i{}_s}{\partial A^k{}_l} = \delta^i_k \, \delta^l_s

This is the fundamental identity. The two deltas enforce i=ki = k (same row) and s=ls = l (same column).

Derivative of a matrix product. Given Cij=AisBsjC^i{}_j = A^i{}_s \, B^s{}_j, we differentiate with respect to AklA^k{}_l:

CijAkl=(AisBsj)Akl\frac{\partial C^i{}_j}{\partial A^k{}_l} = \frac{\partial (A^i{}_s \, B^s{}_j)}{\partial A^k{}_l}

BB does not depend on AA, so we can pull it out:

=AisAklBsj= \frac{\partial A^i{}_s}{\partial A^k{}_l} \, B^s{}_j

Apply the fundamental identity:

=δkiδslBsj= \delta^i_k \, \delta^l_s \, B^s{}_j

The δsl\delta^l_s contracts with BsjB^s{}_j (setting s=ls = l):

=δkiBlj= \delta^i_k \, B^l{}_j

This says: row ii of CC only depends on row ii of AA.
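We can verify this Jacobian numerically. The sketch below builds the 4-index object delta^i_k B^l_j explicitly and compares it against finite differences of C = AB — a sanity check on small matrices, not something one would ever do at scale.

```python
import numpy as np

rng = np.random.default_rng(0)
m, p, n = 3, 4, 5
A = rng.standard_normal((m, p))
B = rng.standard_normal((p, n))

# Analytic Jacobian: dC[i,j]/dA[k,l] = delta(i,k) * B[l,j]
J = np.einsum('ik,lj->ijkl', np.eye(m), B)

# Finite-difference Jacobian of C = A @ B (exact up to float error,
# since C is linear in A)
eps = 1e-6
J_fd = np.zeros((m, n, m, p))
for k in range(m):
    for l in range(p):
        Ap, Am = A.copy(), A.copy()
        Ap[k, l] += eps
        Am[k, l] -= eps
        J_fd[:, :, k, l] = (Ap @ B - Am @ B) / (2 * eps)

assert np.allclose(J, J_fd, atol=1e-5)
```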

Chain rule. For a composition D=f(C)=f(AB)D = f(C) = f(AB):

DαβAkl=DαβCijCijAkl=DαβCkjBlj\frac{\partial D^{\alpha}{}_{\beta}}{\partial A^k{}_l} = \frac{\partial D^{\alpha}{}_{\beta}}{\partial C^i{}_j} \cdot \frac{\partial C^i{}_j}{\partial A^k{}_l} = \frac{\partial D^{\alpha}{}_{\beta}}{\partial C^k{}_j} \, B^l{}_j

The Kronecker delta collapses the sum over ii, leaving only i=ki = k — we only need derivatives of DD with respect to row kk of CC.

Contracting a delta with a denominator index. When a Kronecker delta contracts with an index that appears in the denominator of a partial derivative, the variance flips. For example, in the expression:

LOijδjl\frac{\partial L}{\partial O^i{}_j} \cdot \delta^l_j

The index jj is lower in OijO^i{}_j, but taking Oij\frac{\partial}{\partial O^i{}_j} flips its variance — jj becomes upper in LOij\frac{\partial L}{\partial O^i{}_j}. The δjl\delta^l_j has jj as lower. Upper meets lower → contraction, setting j=lj = l:

=LOil= \frac{\partial L}{\partial O^i{}_l}

In Euclidean space the component values don’t change when we raise or lower indices, so this is purely bookkeeping — but it tells us which indices get summed.

Matrix-to-Matrix Derivatives

The derivative of a matrix CRm×nC \in \mathbb{R}^{m \times n} with respect to a matrix ARp×qA \in \mathbb{R}^{p \times q} is a 4-index object:

CijAkl\frac{\partial C^i{}_j}{\partial A^k{}_l}

This has mnpqm \cdot n \cdot p \cdot q entries — a (2,2)(2,2) tensor. For the attention forward pass Q,KSPOQ, K \to S \to P \to O, the full Jacobians at each step are 4D tensors that are expensive to compute and store.

The Collapse: (2,2) to (1,1)

In practice, backpropagation never computes the full 4D Jacobians. Instead, given a scalar loss LL and the upstream gradient dO=LOdO = \frac{\partial L}{\partial O} (same shape as OO, a matrix — a (1,1)(1,1) tensor), we compute the vector-Jacobian product (VJP):

dAkl=LAkl=LCijCijAkldA^k{}_l = \frac{\partial L}{\partial A^k{}_l} = \frac{\partial L}{\partial C^i{}_j} \cdot \frac{\partial C^i{}_j}{\partial A^k{}_l}

The (1,1)(1,1) upstream gradient contracts with the (2,2)(2,2) Jacobian, and the result collapses back to a (1,1)(1,1) tensor — a matrix. The 4D Jacobian is never built. Every derivative in the chain can be represented as a matrix, not a 4D tensor.

This is the same intermediate elimination pattern from the forward pass. There, the n×nn \times n attention matrix PP was the large intermediate that was never materialized. Here, the (2,2)(2,2) Jacobian is the large intermediate — and it too is never materialized, because it is immediately contracted with the upstream gradient.
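Here is the collapse in miniature: contracting the upstream gradient with the explicitly built (2,2) Jacobian gives exactly the same matrix as dA = dC B^T, which never touches the 4D object.

```python
import numpy as np

rng = np.random.default_rng(0)
m, p, n = 3, 4, 5
A = rng.standard_normal((m, p))
B = rng.standard_normal((p, n))
dC = rng.standard_normal((m, n))   # upstream gradient dL/dC, a (1,1) tensor

# The slow way: build the (2,2) Jacobian dC[i,j]/dA[k,l], then contract over i, j
J = np.einsum('ik,lj->ijkl', np.eye(m), B)
dA_slow = np.einsum('ij,ijkl->kl', dC, J)

# The collapse: dA = dC @ B.T -- the 4D Jacobian is never materialized
dA = dC @ B.T
assert np.allclose(dA, dA_slow)
```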

The Backward Pass Step by Step

We use PP for the attention matrix (called AA in our earlier sections) to match the flash attention paper notation. All dd-prefixed matrices are gradients of the scalar loss LL:

| Symbol | Shape | Description |
|---|---|---|
| $Q, K, V$ | $n \times d$ | inputs |
| $S = QK^T$ | $n \times n$ | pre-softmax scores |
| $P = \text{softmax}(S)$ | $n \times n$ | attention weights |
| $O = PV$ | $n \times d$ | output |
| $dO$ | $n \times d$ | $\frac{\partial L}{\partial O}$ — upstream gradient (given) |
| $dV, dQ, dK$ | $n \times d$ | what we want |
| $dP$ | $n \times n$ | $\frac{\partial L}{\partial P}$ — intermediate |
| $dS$ | $n \times n$ | $\frac{\partial L}{\partial S}$ — intermediate |
| $D_i$ | $n$ scalars | row-wise dot product of $dO$ and $O$ |

Forward was: Q,KS=QKTP=softmax(S)O=PVQ, K \to S = QK^T \to P = \text{softmax}(S) \to O = PV

Backward reverses this:

Step 1: Through O=PVO = PV

We have Oij=PisVsjO^i{}_j = P^i{}_s \, V^s{}_j. We want dVkl=LVkldV^k{}_l = \frac{\partial L}{\partial V^k{}_l}.

Deriving dVdV:

First, the Jacobian. Using our derivative-of-a-product result (differentiating with respect to the second factor this time):

OijVkl=Pikδjl\frac{\partial O^i{}_j}{\partial V^k{}_l} = P^i{}_k \, \delta^l_j

Now contract with dOdO via the VJP — recall that dVkl=LVkldV^k{}_l = \frac{\partial L}{\partial V^k{}_l}, which by the chain rule is:

dVkl=LVkl=LOijOijVkldV^k{}_l = \frac{\partial L}{\partial V^k{}_l} = \frac{\partial L}{\partial O^i{}_j} \cdot \frac{\partial O^i{}_j}{\partial V^k{}_l}

Substituting the Jacobian:

=LOijPikδjl= \frac{\partial L}{\partial O^i{}_j} \cdot P^i{}_k \cdot \delta^l_j

The δjl\delta^l_j contracts with the jj in the denominator (variance flip: jj is upper in LOij\frac{\partial L}{\partial O^i{}_j}, lower in δjl\delta^l_j), setting j=lj = l:

=LOilPik= \frac{\partial L}{\partial O^i{}_l} \cdot P^i{}_k

The repeated ii is a contraction — this is a matrix product:

=(PT)kiLOil=(PTLO)kl= (P^T)^k{}_i \cdot \frac{\partial L}{\partial O^i{}_l} = \left(P^T \, \frac{\partial L}{\partial O}\right)^k{}_l

In matrix form: dV=PTdOdV = P^T \, dO.

Deriving dPdP:

The Jacobian with respect to the first factor:

OijPkl=δkiVlj\frac{\partial O^i{}_j}{\partial P^k{}_l} = \delta^i_k \, V^l{}_j

Contract with LO\frac{\partial L}{\partial O} via the VJP:

dPkl=LPkl=LOijOijPkldP^k{}_l = \frac{\partial L}{\partial P^k{}_l} = \frac{\partial L}{\partial O^i{}_j} \cdot \frac{\partial O^i{}_j}{\partial P^k{}_l}

Substituting the Jacobian:

=LOijδkiVlj= \frac{\partial L}{\partial O^i{}_j} \cdot \delta^i_k \cdot V^l{}_j

The δki\delta^i_k contracts with ii in the denominator (variance flip), setting i=ki = k:

=LOkjVlj= \frac{\partial L}{\partial O^k{}_j} \cdot V^l{}_j

The repeated jj is a contraction:

=(LO)kj(VT)jl=(LOVT)kl= \left(\frac{\partial L}{\partial O}\right)^k{}_j \cdot (V^T)^j{}_l = \left(\frac{\partial L}{\partial O} \, V^T\right)^k{}_l

In matrix form: dP=dOVTdP = dO \, V^T.

Extracting row 0. (dP)0l=(dO)0β(VT)βl=βdO0βVβlT(dP)^0{}_l = (dO)^0{}_\beta \, (V^T)^\beta{}_l = \sum_{\beta} dO_{0\beta} \, V^T_{\beta l}

The only part that depends on row 0 is dO0βdO_{0\beta} — the matrix VTV^T is shared across all rows. So row 0 of dPdP is simply row 0 of dOdO times VTV^T.

Step 2: Through P=softmax(S)P = \text{softmax}(S)

\frac{\partial P_i}{\partial S_j} = P_i(\delta_{ij} - P_j) \qquad O = PV

dP = (dP)^k{}_l = \left(\frac{\partial L}{\partial P}\right)^k{}_l = (dO \, V^T)^k{}_l

Say k=0k = 0.

LP0l=(LP)0l=(dOVT)0l=(dO)0k(VT)kl=k(dO)0kVlk\frac{\partial L}{\partial P_{0l}} = \left(\frac{\partial L}{\partial P}\right)^0{}_l = (dO \, V^T)^0{}_l = (dO)^0{}_k \, (V^T)^k{}_l = \sum_k (dO)_{0k} \, V_{lk}

where (VT)kl=Vlk(V^T)^k{}_l = V_{lk} — transpose in Einstein notation swaps the indices.

\left(\frac{\partial L}{\partial S}\right)^0{}_j = \frac{\partial L}{\partial P^0{}_l} \frac{\partial P^0{}_l}{\partial S^0{}_j} = \sum_l \frac{\partial L}{\partial P_{0l}} \frac{\partial P_{0l}}{\partial S_{0j}}

Substituting the softmax Jacobian:

= \sum_l \frac{\partial L}{\partial P_{0l}} P_{0l} (\delta_{lj} - P_{0j}) = \frac{\partial L}{\partial P_{0j}} P_{0j} - \sum_l \frac{\partial L}{\partial P_{0l}} P_{0l} P_{0j}

Substituting dP_{0l} = \sum_k (dO)_{0k} V_{lk} in the second term:

= dP_{0j} \cdot P_{0j} - \sum_l \left(\sum_k (dO)_{0k} \, V_{lk}\right) P_{0l} \, P_{0j}

Swapping the order of summation:

= dP_{0j} \cdot P_{0j} - P_{0j} \sum_k (dO)_{0k} \sum_l P_{0l} \, V_{lk}

The inner sum over l is exactly O_{0k}:

= P_{0j} \, dP_{0j} - P_{0j} \sum_k (dO)_{0k} \, O_{0k}

= P_{0j} \, dP_{0j} - P_{0j} \, D_0 \qquad D_0 = \sum_k (dO)_{0k} \, O_{0k}

Nothing here was specific to row 0, so for a general row i:

\Rightarrow \left(\frac{\partial L}{\partial S}\right)^i{}_j = P_{ij} \, dP_{ij} - P_{ij} \, D_i \qquad D_i = \sum_k (dO)_{ik} \, O_{ik}
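It is worth checking this identity numerically: the collapsed form P_ij(dP_ij - D_i) must agree with contracting dP against the explicit per-row softmax Jacobian. A small numpy check (sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 3
S = rng.standard_normal((n, n))
V = rng.standard_normal((n, d))
dO = rng.standard_normal((n, d))

P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
O = P @ V
dP = dO @ V.T

# Explicit route: contract dP with the per-row softmax Jacobian
# J[l, j] = P_il * (delta_lj - P_ij)
dS_explicit = np.empty_like(dP)
for i in range(n):
    J = np.diag(P[i]) - np.outer(P[i], P[i])
    dS_explicit[i] = dP[i] @ J

# Collapsed route derived above: dS_ij = P_ij * (dP_ij - D_i)
D = (dO * O).sum(axis=1, keepdims=True)
dS = P * (dP - D)
assert np.allclose(dS, dS_explicit)

# D_i can equivalently be computed from the row of P and dP:
# sum_l P_il dP_il = sum_k dO_ik O_ik
assert np.allclose(D, (P * dP).sum(axis=1, keepdims=True))
```

The second assertion is the equivalence the D_i trick relies on later: the same scalar is available either from the row of P and dP, or directly from dO and O.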

Step 3: Through S=QKTS = QK^T

Both follow the product-derivative pattern of Step 1, applied to S = QK^T:

dQ = dS \cdot K \qquad dK = dS^T \cdot Q
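Putting Steps 1-3 together, the whole backward pass is a handful of matrix products. The sketch below implements it for the loss L = sum(G * O), so that dO = G, and checks every gradient against central finite differences.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 3
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
G = rng.standard_normal((n, d))          # fixed upstream gradient: L = sum(G * O)

def softmax_rows(S):
    e = np.exp(S - S.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def loss(Q, K, V):
    return float((G * (softmax_rows(Q @ K.T) @ V)).sum())

# Analytic backward pass, exactly as derived in Steps 1-3
S = Q @ K.T
P = softmax_rows(S)
O = P @ V
dO = G                                   # dL/dO for this loss
dV = P.T @ dO                            # Step 1
dP = dO @ V.T                            # Step 1
D = (dO * O).sum(axis=1, keepdims=True)  # D_i
dS = P * (dP - D)                        # Step 2
dQ = dS @ K                              # Step 3
dK = dS.T @ Q                            # Step 3

def fd_grad(f, X, eps=1e-6):
    """Central finite differences of scalar f with respect to matrix X."""
    g = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        Xp, Xm = X.copy(), X.copy()
        Xp[idx] += eps
        Xm[idx] -= eps
        g[idx] = (f(Xp) - f(Xm)) / (2 * eps)
    return g

assert np.allclose(dQ, fd_grad(lambda Z: loss(Z, K, V), Q), atol=1e-5)
assert np.allclose(dK, fd_grad(lambda Z: loss(Q, Z, V), K), atol=1e-5)
assert np.allclose(dV, fd_grad(lambda Z: loss(Q, K, Z), V), atol=1e-5)
```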

Step 4: Through the QKV projections

The inputs QQ, KK, VV are all projections of the same input XX:

Q=XWQK=XWKV=XWVQ = XW_Q \qquad K = XW_K \qquad V = XW_V

Since all three depend on XX, the gradient flows back through all three paths and sums:

dX=dQWQT+dKWKT+dVWVTdX = dQ \, W_Q^T + dK \, W_K^T + dV \, W_V^T

Expanding dQdQ, dKdK, dVdV:

dX = (dS \cdot K) \, W_Q^T + (dS^T \cdot Q) \, W_K^T + (P^T \, dO) \, W_V^T

= dS \cdot (K \, W_Q^T) + dS^T \cdot (Q \, W_K^T) + P^T \cdot (dO \, W_V^T)

By associativity, the projection weights can be folded into the right operand first. The parenthesized products KWQTKW_Q^T, QWKTQW_K^T, and dOWVTdO \, W_V^T are each n×dn \times d — cheap to precompute and not n×nn \times n.

But dXdX is only the activation gradient. During training, we also need the weight gradients for the parameter update. Since Q=XWQQ = XW_Q, K=XWKK = XW_K, V=XWVV = XW_V:

dW_Q = X^T \, dQ = X^T (dS \cdot K)

dW_K = X^T \, dK = X^T (dS^T \cdot Q)

dW_V = X^T \, dV = X^T (P^T \, dO)

Each weight gradient is d×dd \times d — small. But computing them requires dQdQ, dKdK, dVdV explicitly. This conflicts with the associativity trick above: if we fold dQdQ directly into dXdX and discard it, we can’t also use it for dWQdW_Q.

The resolution: each row of dQdQ, dKdK, dVdV can be consumed by two accumulators simultaneously. For example, when row ii of dQdQ is produced, it contributes a row to dXdX (via WQTW_Q^T) and a rank-1 update to dWQdW_Q (via xiT(dQ)ix_i^T \cdot (dQ)_i), then is discarded. The intermediates are still never materialized as full matrices — they’re just consumed twice instead of once before being discarded.

In practice, the flash attention kernel stops at dQdQ, dKdK, dVdV. The projection weight gradients and dXdX are handled by the framework’s standard linear layer backward pass — there is no n×nn \times n intermediate involved in those steps, so no special treatment is needed.

Quiz: Why do we need to compute dW_Q, dW_K, dW_V at all? What are they used for, and why can’t we skip them?
Quiz: How much compute is required to compute the QKV projection weight updates (dW_Q, dW_K, dW_V), and how much compute is required to compute dQ, dK, dV? Express in terms of n and d.

Execution Schedule

Recall from the forward pass: having an expression is not the same as having an execution schedule. In attention, the expression O=softmax(QKT)VO = \text{softmax}(QK^T)V admits multiple execution orders — and the choice determines whether we materialize PP or not. The same question applies here.

We want to avoid materializing both dPdP and dSdS (both n×nn \times n). Since softmax is row-wise, it’s natural to work row by row: compute a row of dSdS, then immediately use it.

Say we’ve computed row 0 of dSdS. Now consider the two consumers:

dQ=dSKdQ = dS \cdot K: row 0 of dSdS times all of KK gives row 0 of dQdQ. This is the inner product style — one row of the left matrix updates one row of the output completely. No problem.

dK=dSTQdK = dS^T \cdot Q: transposing dSdS turns row 0 into column 0. A column of the left matrix times a row of QQ doesn’t give a single row of dKdK — it gives a rank-1 update to the entire dKdK matrix. This is the outer product style: column 0 of dSTdS^T (which is row 0 of dSdS) times row 0 of QQ updates all of dKdK.

So the two gradients require different computation patterns:

| Gradient | Style | What happens per row of $dS$ |
|---|---|---|
| $dQ = dS \cdot K$ | inner product | row of $dS$ $\times$ all of $K$ → completes one row of $dQ$ |
| $dK = dS^T \cdot Q$ | outer product | row of $dS$ (as column) $\times$ one row of $Q$ → rank-1 update to all of $dK$ |

For dKdK, we keep the full dKdK matrix in memory and accumulate into it as each row of dSdS is produced. We don’t load all of QQ — just one row at a time, paired with the corresponding row of dSdS.
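A sketch of this schedule, assuming each row of dS arrives one at a time (here we simply index into a precomputed dS to stand in for the streamed row):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 6, 4
Q, K = rng.standard_normal((2, n, d))
dS = rng.standard_normal((n, n))   # stand-in: row i would be produced on the fly

dQ = np.zeros((n, d))
dK = np.zeros((n, d))              # full accumulator kept in memory
for i in range(n):
    row = dS[i]                    # produced, consumed, discarded
    dQ[i] = row @ K                # inner product: completes row i of dQ
    dK += np.outer(row, Q[i])      # outer product: rank-1 update to all of dK

assert np.allclose(dQ, dS @ K)
assert np.allclose(dK, dS.T @ Q)
```

Note that the loop body touches only one row of Q, matching the observation above.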

And as before, an element can be a tile: each element of dSdS can be a scalar (one entry) or a tile (a block of entries in the row). The same execution schedule works at both granularities.

Row-wise vs. column-wise. But wait — we chose to iterate row-wise because softmax is row-wise. What if we iterate column-wise instead? Before comparing, note a third gradient we haven’t accounted for yet: dV = P^T \, dO. With row-wise iteration, row i of P becomes column i of P^T, giving yet another outer product.

Count the inner vs. outer products for both strategies:

Row-wise (iterate over rows of dSdS, rows of PP):

| Gradient | Style | Why |
|---|---|---|
| $dQ = dS \cdot K$ | inner product | row of $dS$ $\times$ $K$ → completes one row of $dQ$ |
| $dK = dS^T \cdot Q$ | outer product | row of $dS$ = column of $dS^T$, rank-1 update to all of $dK$ |
| $dV = P^T \, dO$ | outer product | row of $P$ = column of $P^T$, rank-1 update to all of $dV$ |

Score: 1 inner, 2 outer.

Column-wise (iterate over columns of dSdS, columns of PP):

| Gradient | Style | Why |
|---|---|---|
| $dQ = dS \cdot K$ | outer product | column $j$ of $dS$ $\times$ row $j$ of $K$ → rank-1 update to all of $dQ$ |
| $dK = dS^T \cdot Q$ | inner product | column $j$ of $dS$ = row $j$ of $dS^T$, times $Q$ → completes row $j$ of $dK$ |
| $dV = P^T \, dO$ | inner product | column $j$ of $P$ = row $j$ of $P^T$, times $dO$ → completes row $j$ of $dV$ |

Score: 2 inner, 1 outer.

Column-wise is better balanced. And for the weight gradients, it’s even cleaner: with dKdK and dVdV produced row-by-row (inner product style), each row immediately gives a rank-1 update to dWKdW_K and dWVdW_V. Only dWQdW_Q requires the outer product accumulation.

But softmax is row-wise — can we actually compute dSdS column-wise? Yes. Column jj of SS is QkjQ \, k_j (all queries dotted with key jj). Then Pij=eSijmi/iP_{ij} = e^{S_{ij} - m_i} / \ell_i using the precomputed per-row statistics mim_i and i\ell_i. Column jj of dPdP is dOvjdO \cdot v_j. Then dSij=Pij(dPijDi)dS_{ij} = P_{ij}(dP_{ij} - D_i) with precomputed DiD_i. Everything works — the row-wise structure of softmax is captured in the stored scalars, and the iteration itself can proceed column by column.
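A column-wise sketch of exactly this procedure, using only the stored per-row statistics m_i, l_i, and D_i (the loop body sees one column at a time):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 6, 4
Q, K, V, dO = rng.standard_normal((4, n, d))

# Per-row statistics saved from the forward pass, plus the precomputed D_i
S = Q @ K.T
m = S.max(axis=1)
l = np.exp(S - m[:, None]).sum(axis=1)
P = np.exp(S - m[:, None]) / l[:, None]
O = P @ V
D = (dO * O).sum(axis=1)

dQ = np.zeros((n, d))
dK = np.zeros((n, d))
dV = np.zeros((n, d))
for j in range(n):                  # iterate over columns of S, P, dS
    s_col = Q @ K[j]                # column j of S: all queries vs key j
    p_col = np.exp(s_col - m) / l   # column j of P, from stored row stats
    dP_col = dO @ V[j]              # column j of dP
    dS_col = p_col * (dP_col - D)   # column j of dS, using precomputed D
    dK[j] = dS_col @ Q              # inner: completes row j of dK
    dV[j] = p_col @ dO              # inner: completes row j of dV
    dQ += np.outer(dS_col, K[j])    # outer: rank-1 update to all of dQ

dS = P * (dO @ V.T - D[:, None])
assert np.allclose(dQ, dS @ K)
assert np.allclose(dK, dS.T @ Q)
assert np.allclose(dV, P.T @ dO)
```

Two inner products, one outer product per column, matching the score above.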

Can we push further? The gradients dQdQ, dKdK, dVdV are all n×dn \times d — not n×nn \times n, so less urgent. But for long sequences (nn in the millions), even n×dn \times d matrices are large. Do we need to materialize them, or can they be consumed immediately too?

Recall that the final target is dX=dQWQT+dKWKT+dVWVTdX = dQ \, W_Q^T + dK \, W_K^T + dV \, W_V^T. Each of dQdQ, dKdK, dVdV feeds into a linear contraction with a projection weight matrix — exactly the pattern that allows intermediate elimination.

Term 1: dS(KWQT)dS \cdot (K \, W_Q^T). Inner product style, fully streamable. KWQTK \, W_Q^T is precomputed. Row ii of dSdS times KWQTK \, W_Q^T produces row ii of the first contribution to dXdX. Produce, consume, discard — dQdQ is never materialized.

Term 2: dST(QWKT)dS^T \cdot (Q \, W_K^T). QWKTQ \, W_K^T is precomputed. For each row ii of dSdS (which becomes column ii of dSTdS^T): outer product of column ii of dSTdS^T with row ii of QWKTQ \, W_K^T, accumulated into dXdX. Neither dKdK nor dSTdS^T is materialized.

Term 3: PT(dOWVT)P^T \cdot (dO \, W_V^T). dOWVTdO \, W_V^T is precomputed. Now PTP^T needs columns of PP — but the outer product decomposition avoids this. Row ii of PP (which we already have from the row-wise softmax recomputation) becomes column ii of PTP^T. Outer product of column ii of PTP^T with row ii of dOWVTdO \, W_V^T, accumulated into dXdX. We never need a full column of PP — just one row at a time.

All three terms can be accumulated into a single dXdX matrix as we stream row-by-row through dSdS and PP. The intermediates dQdQ, dKdK, dVdV are eliminated by the same contraction principle that eliminated PP in the forward pass.

The D_i Trick: Why It Matters

At the row level, DiD_i is just a dot product — trivial to compute from the recomputed row of PP and the row of dPdP. So why does flash attention go out of its way to precompute it as Di=k(dO)ikOikD_i = \sum_k (dO)_{ik} \, O_{ik}?

The answer is arithmetic intensity — specifically, how many times KK and VV are loaded from HBM.

Without the trick, the backward pass for each block requires two loads of KK:

  1. Load KK block into SRAM to recompute PP (via S=QKTS = QK^T). Load VV block to compute dPdP (via dOVTdO \cdot V^T). Compute DiD_i from the row of PP and dPdP. But DiD_i needs the full row — so we must finish all blocks in this row before proceeding.
  2. Load KK block into SRAM again to compute dQ=dSKdQ = dS \cdot K.

KK is loaded from HBM twice. The dependency on DiD_i prevents fusing the recomputation of PP with the consumption of dSdS, because dSij=Pij(dPijDi)dS_{ij} = P_{ij}(dP_{ij} - D_i) requires DiD_i which requires the full row.

With the trick, Di=k(dO)ikOikD_i = \sum_k (dO)_{ik} \, O_{ik} is precomputed from quantities already in hand — no PP needed. Now everything fuses into a single load per block:

  1. Load KK block, VV block into SRAM
  2. Recompute PP block (KK is in SRAM)
  3. Compute dPdP block (VV is in SRAM)
  4. Compute dSdS block (using precomputed DD — no waiting)
  5. Compute dQdQ contribution (KK is still in SRAM)
  6. Accumulate dKdK contribution
  7. Discard block

KK and VV are loaded once instead of twice. The DD trick removes the last dependency that forced a second load, enabling the recomputation and gradient computation to be fused into a single pass over the data. This doubles the arithmetic intensity of the backward pass.
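A block-level sketch of the fused schedule, with pure numpy standing in for SRAM tiles (the block size Bc is illustrative). Each K/V block is touched once, and every step inside the loop uses only that block plus the precomputed row statistics:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, Bc = 8, 4, 4                 # Bc = key/value block size
Q, K, V, dO = rng.standard_normal((4, n, d))

# Forward-pass per-row statistics, plus the precomputed D_i = sum_k dO_ik O_ik
S = Q @ K.T
m = S.max(axis=1, keepdims=True)
l = np.exp(S - m).sum(axis=1, keepdims=True)
P = np.exp(S - m) / l
O = P @ V
D = (dO * O).sum(axis=1, keepdims=True)

dQ = np.zeros((n, d))
dK = np.zeros((n, d))
dV = np.zeros((n, d))
for j0 in range(0, n, Bc):         # single pass: each K/V block "loaded" once
    Kb, Vb = K[j0:j0+Bc], V[j0:j0+Bc]
    Pb = np.exp(Q @ Kb.T - m) / l  # recompute P block from stored m, l
    dPb = dO @ Vb.T                # dP block (Vb in hand)
    dSb = Pb * (dPb - D)           # dS block -- D is precomputed, no waiting
    dQ += dSb @ Kb                 # dQ contribution while Kb is in hand
    dK[j0:j0+Bc] = dSb.T @ Q       # dK contribution
    dV[j0:j0+Bc] = Pb.T @ dO       # dV contribution
    # blocks Pb, dPb, dSb are now discarded

dS = P * (dO @ V.T - D)
assert np.allclose(dQ, dS @ K)
assert np.allclose(dK, dS.T @ Q)
assert np.allclose(dV, P.T @ dO)
```

In a real kernel the blocks over Q would also be tiled and dQ accumulated atomically or via a separate pass, but the single-load-per-K/V-block structure is the one the D trick enables.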