MLsys Note (2) - Flash Attention algo & code walk through


Kai-Jie Lin

Flash Attention

The background should need little introduction. In one sentence:

  1. FlashAttention uses the online softmax trick to relieve the memory pressure that softmax creates.
  2. On the systems side, it optimizes data movement via kernel fusion, achieving fast inference with a small memory footprint.

Algorithm first, then the systems side.

Flash Attention in Math

Naive Attention:

$$\mathrm{Attention}(Q,K,V)= \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Softmax:

$$\mathrm{softmax}(\{x_1,...,x_N\}) = \left\{\frac{e^{x_i}}{\sum_{j=1}^{N}{e^{x_j}}}\right\}_{i=1}^{N}$$

The naive softmax normalization depends on the full sum over all $N$ elements, so in a CUDA kernel it cannot be computed tile by tile in a single pass.

Online Softmax

Safe softmax: to avoid overflow in the exponential, we can exploit a convenient invariance of softmax ->

$$\frac{e^{x_i}}{\sum_{j=1}^{N}{e^{x_j}}} = \frac{e^{x_i-m}}{\sum_{j=1}^{N}{e^{x_j-m}}}$$

where $m = \max_{j=1}^N x_j$, which guarantees $(x_j - m) \le 0 \implies e^{x_j-m} \in (0, 1]$,
making fp16 arithmetic more accurate.

With this property we can derive a 3-pass algorithm that does not need to load all $N$ components into SRAM at once, but processes them iteratively:

  1. Compute $m_i \leftarrow \max(m_{i-1}, x_i)$ for $i$ in $1...N$
  2. $d_i \leftarrow d_{i-1} + e^{x_i - m_N}$ for $i$ in $1...N$
  3. $a_i \leftarrow \frac{e^{x_i - m_N}}{d_N}$ for $i$ in $1...N$

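The 3-pass procedure can be sketched in plain Python (an illustrative sketch over a Python list; the function name is made up, and real kernels stream tiles through SRAM instead):

```python
import math

def safe_softmax_3pass(x):
    # Pass 1: global maximum m_N via the running max m_i
    m = -math.inf
    for xi in x:
        m = max(m, xi)
    # Pass 2: denominator d_N using the global max
    d = 0.0
    for xi in x:
        d += math.exp(xi - m)
    # Pass 3: normalized outputs a_i
    return [math.exp(xi - m) / d for xi in x]
```

Even with inputs like `1000.0`, no exponential overflows, since every exponent is at most 0.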
Although this saves memory, we must iterate over the length-$N$ sequence three times, which is I/O-inefficient. Fusing the first two operations would reduce data movement, but the second computation depends on $m_N$, so it cannot be fused into the same loop. This is where online softmax comes in, via the recurrence for $d_i'$:

$$d_i' = \sum_{j=1}^i{e^{x_j-m_i}} = \left(\sum_{j=1}^{i-1}{e^{x_j-m_i}}\right) + e^{x_i-m_i} \\ = \left(\sum_{j=1}^{i-1}{e^{x_j-m_{i-1}}}\right)e^{m_{i-1}-m_{i}} + e^{x_i-m_i} = d_{i-1}'e^{m_{i-1}-m_{i}} + e^{x_i-m_i}$$

The 2-pass algorithm:

  1. $m_i = \max(m_{i-1}, x_i)$, $d_i' = d_{i-1}'e^{m_{i-1}-m_{i}} + e^{x_i-m_i}$ for $i$ in $1...N$
  2. $a_i \leftarrow \frac{e^{x_i - m_N}}{d_N'}$ for $i$ in $1...N$

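As a sanity check, the 2-pass algorithm can be written out in Python (hypothetical function name; the point is that pass 1 maintains $m_i$ and $d_i'$ together through the recurrence):

```python
import math

def online_softmax_2pass(x):
    # Pass 1: fused running max m_i and rescaled denominator d_i'
    m, d = -math.inf, 0.0
    for xi in x:
        m_new = max(m, xi)
        d = d * math.exp(m - m_new) + math.exp(xi - m_new)  # d_i' recurrence
        m = m_new
    # Pass 2: normalize with the final m_N and d_N'
    return [math.exp(xi - m) / d for xi in x]
```

Whenever a new maximum appears, the old partial denominator is rescaled by $e^{m_{i-1}-m_i}$, so the result is identical to the 3-pass version.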
Can we go one step further to 1-pass? For softmax alone the answer is no, at least mathematically.

But what we actually want is self-attention: we don't just need the attention scores, we also need to multiply by $V$. What if we bring $V$ into the picture?

To compute $o_i$, can we apply the same recurrence trick?

$$o_i' = o_{i-1}'\frac{d_{i-1}'e^{m_{i-1}-m_i}}{d_i'}+\frac{e^{x_i-m_i}}{d_i'}V[i,:]$$

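The $o_i'$ recurrence makes a single pass possible. A minimal Python sketch for one query row (hypothetical function name; `x` are the scaled logits, `V` a list of $D$-dimensional value rows):

```python
import math

def flash_attn_row(x, V):
    # x: attention logits for one query row (length N)
    # V: N value rows, each of dimension D
    D = len(V[0])
    m = -math.inf    # running max m_i
    d = 0.0          # running denominator d_i'
    o = [0.0] * D    # running output o_i'
    for xi, vi in zip(x, V):
        m_new = max(m, xi)
        d_new = d * math.exp(m - m_new) + math.exp(xi - m_new)
        p = math.exp(xi - m_new)
        # o_i' = o_{i-1}' * d_{i-1}' e^{m_{i-1}-m_i} / d_i'  +  (p / d_i') V[i,:]
        o = [ok * (d * math.exp(m - m_new) / d_new) + (p / d_new) * vk
             for ok, vk in zip(o, vi)]
        m, d = m_new, d_new
    return o
```

After the last step, `o` equals $\mathrm{softmax}(x)\,V$ for that row, computed without ever materializing the full softmax.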

This is the math behind flash attention; apply tiling to optimize the kernel and you have the complete flash attention.

Implementation walkthrough

Tiling: when computing a large matrix, split it into small blocks assigned to thread blocks, so the same piece of data is reused many times and data-movement time is reduced.

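A toy illustration of the idea in pure Python (the `tile` size is made up; on a GPU the tiles would live in shared memory):

```python
def tiled_matmul(A, B, tile=2):
    # C = A @ B computed block by block: each (tile x tile) block of A and B
    # is "loaded" once and reused across the whole inner loop nest.
    n, kdim, m = len(A), len(A[0]), len(B[0])
    C = [[0.0] * m for _ in range(n)]
    for i0 in range(0, n, tile):
        for j0 in range(0, m, tile):
            for k0 in range(0, kdim, tile):
                for i in range(i0, min(i0 + tile, n)):
                    for j in range(j0, min(j0 + tile, m)):
                        for k in range(k0, min(k0 + tile, kdim)):
                            C[i][j] += A[i][k] * B[k][j]
    return C
```

The arithmetic is unchanged; only the loop order is restructured so a small working set stays hot while it is being consumed.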

A CUDA implementation (illustrating the concept; correctness not guaranteed):

template<int HEAD_DIM, int BM, int BN, bool CAUSAL>
__global__ void flash_fwd(
    const half* __restrict__ Q, // [B, H, N, D]
    const half* __restrict__ K, // [B, H, N, D]
    const half* __restrict__ V, // [B, H, N, D]
    half* __restrict__ O,       // [B, H, N, D]
    float scale,
    int B, int H, int N
) {
    // block mapping: one block computes one (b, h, q_tile)
    int q_tile = blockIdx.x; // 0 .. ceil(N/BM)-1
    int bh = blockIdx.y;     // 0 .. B*H-1
    int b = bh / H;
    int h = bh % H;

    int q_row0 = q_tile * BM; // first query row index of this tile

    // Shared memory tiles (sizes picked to fit smem)
    extern __shared__ half smem[];
    half* sQ = smem;               // [BM, D]
    half* sK = sQ + BM * HEAD_DIM; // [BN, D]
    half* sV = sK + BN * HEAD_DIM; // [BN, D]

    // Per-row online softmax state (shown for one query row per thread;
    // real kernels keep these in registers / fragments per thread/warp)
    // m_row: running max; l_row: running denom; acc_row: running numerator (D-vector)
    float m_row;
    float l_row;
    float acc_row[HEAD_DIM_PER_THREAD]; // fragment accumulators

    init_states(m_row, l_row, acc_row); // m = -inf, l = 0, acc = 0

    // 1) Load Q tile once (reused across all K/V tiles)
    //    Cooperative load into shared memory
    load_Q_tile_to_smem<HEAD_DIM, BM>(Q, sQ, b, h, q_row0, N);
    __syncthreads();

    // 2) Loop over K/V tiles
    for (int k0 = 0; k0 < N; k0 += BN) {

        // For causal: skip future tiles entirely (block-level masking)
        if constexpr (CAUSAL) {
            if (k0 > q_row0 + BM - 1) break; // keys start beyond last query row
        }

        // Load K/V tile into shared memory
        load_KV_tile_to_smem<HEAD_DIM, BN>(K, V, sK, sV, b, h, k0, N);
        __syncthreads();

        // 3) Compute logits tile: S = Q_tile * K_tile^T => [BM, BN]
        //    Usually via tensor cores: MMA (Q in shared, K in shared);
        //    each warp computes a sub-tile of S.
        float S_sub[ROWS_PER_WARP][COLS_PER_WARP];
        mma_qk<HEAD_DIM, BM, BN>(sQ, sK, S_sub); // conceptual

        // 4) Apply scale + (optional) causal mask within the diagonal tile
        //    - Off-diagonal tiles are fully valid under causal, no per-element mask.
        //    - Only the diagonal tile needs a triangular mask.
        apply_scale(scale, S_sub);
        if constexpr (CAUSAL) {
            apply_causal_mask_if_diagonal(q_row0, k0, S_sub);
        }

        // 5) Online softmax update:
        //    m_new = max(m_old, rowmax(S_sub))
        //    alpha = exp(m_old - m_new)
        //    p     = exp(S_sub - m_new)
        //    l_new = l_old * alpha + rowsum(p)
        //    acc   = acc * alpha + p @ V_tile
        float rowmax = row_max(S_sub);
        float m_new = max(m_row, rowmax);
        float alpha = expf(m_row - m_new);

        // rescale old accumulators to the new max
        l_row *= alpha;
        scale_accumulator(acc_row, alpha);

        // compute p and its row sum
        float p_sub[ROWS_PER_WARP][COLS_PER_WARP]; // kept in regs/fragments
        exp_subtract_max(S_sub, m_new, p_sub);
        l_row += row_sum(p_sub);

        // 6) Accumulate output: acc += p @ V_tile
        //    Again via tensor cores / vectorized dot products
        mma_pv<HEAD_DIM, BM, BN>(p_sub, sV, acc_row); // conceptual

        // update m
        m_row = m_new;

        __syncthreads();
    }

    // 7) Final normalize: O = acc / l
    normalize_and_store<HEAD_DIM, BM>(acc_row, l_row, O, b, h, q_row0, N);
}
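To check the control flow of steps 1)-7) without a GPU, the same blocked loop can be written as a Python reference for a single (b, h) slice (a sketch with illustrative names and tile sizes `flash_fwd_ref`, `BM`, `BN`, not the kernel above):

```python
import math

def flash_fwd_ref(Q, K, V, scale, BM=2, BN=2, causal=False):
    # Q, K, V: N x D nested lists; returns O = softmax(scale * Q K^T) V, blocked.
    N, D = len(Q), len(Q[0])
    O = [[0.0] * D for _ in range(N)]
    for q0 in range(0, N, BM):                 # one "thread block" per query tile
        rows = range(q0, min(q0 + BM, N))
        m = {i: -math.inf for i in rows}       # running max per row
        l = {i: 0.0 for i in rows}             # running denominator per row
        acc = {i: [0.0] * D for i in rows}     # running numerator per row
        for k0 in range(0, N, BN):             # loop over K/V tiles
            if causal and k0 > q0 + BM - 1:
                break                          # skip fully-masked future tiles
            cols = range(k0, min(k0 + BN, N))
            for i in rows:
                # S tile row: scaled logits, with per-element causal mask
                s = [scale * sum(Q[i][d] * K[j][d] for d in range(D))
                     if not (causal and j > i) else -math.inf for j in cols]
                m_new = max(m[i], max(s))
                if m_new == -math.inf:
                    continue                   # whole tile masked for this row
                alpha = math.exp(m[i] - m_new)
                p = [math.exp(v - m_new) for v in s]
                l[i] = l[i] * alpha + sum(p)
                acc[i] = [a * alpha + sum(p[jj] * V[j][d] for jj, j in enumerate(cols))
                          for d, a in enumerate(acc[i])]
                m[i] = m_new
        for i in rows:
            O[i] = [a / l[i] for a in acc[i]]  # final normalize: O = acc / l
    return O
```

The result is independent of `BM`/`BN`, which is the whole point: tiling changes the schedule, not the math.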