MLsys Note (1) - FlashInfer Cascade Attention & KV Cache Layout

MLsys Note (1) - FlashInfer Cascade Attention & KV Cache Layout

Kai-Jie Lin Lv3

FlashInfer

FlashInfer 是一個集合各種 LLM inference kernel 的 library。提供了 attention, GEMM 和 MOE 的 API。也支援多種 nvidia gpu 架構:turing, ampere, hopper, blackwell。提供多種底層的高效 kernel 實現。

FlashInfer 提出有趣的兩個優化:Cascade Attention 和多種KV-Cache Layout。在許多 inference 場景下,多條 request 是可以被共用的,若是可以共用,即大大減少 KV Cache 存儲,可以在計算上獲得許多收益。一樣,這篇會看算法以及實作層面,若有錯誤,不吝賜教。

Cascade Attention

Reference

Flash Attention 1 利用 online softmax 的技巧,大大減少計算 attention 的 in-memory O(n^2) -> O(n)。FA2&3 在 kernel 上進行更合理的優化(tiling)。而 FlashInfer 基於此之上針對 GQA(Group Query Attention) 提出了 attention state 和 merge operator on the attention states。

Attention State

Naive Attention:

Attention(Q,K,V)=softmax(QKdk)V \mathrm{Attention}(Q,K,V)= \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V

Suppose sis_i is the pre-softmax attention score between query and the key at index ii:

si=qkiT s_i = qk_i^{T}

we can generalize the definition from single index to an index set:

s(I)=log(iIexp(si)) s(I) = \mathrm{log}(\sum_{i \in I}{\mathrm{exp}(s_i)})

value vector v\mathrm{v}:

v(I)=iIsoftmax(si)vi \mathrm{v}(I) = \sum_{i\in I}{\mathrm{softmax}(s_i)\mathrm{v}_i} v(1,2,...,n)\mathrm{v}({1,2,...,n}) 則是整個 sequence 的 self-attention output。

The attention state of the index set II can be defined as a tuple:

(s(I),v(I)) (s(I), \mathrm{v}(I))

Merge operator

定義兩個 state 的 merge \oplus:

[v(IJ)s(IJ)]=[v(I)s(I)][v(J)s(J)]=[v(I)exp(s(I))+v(J)exp(s(J))exp(s(I))+exp(s(J))logbigl(exp(s(I))+exp(s(J)))] \begin{bmatrix}\mathrm{v}(I \cup J)s(I \cup J)\end{bmatrix}=\begin{bmatrix}\mathrm{v}(I)s(I)\end{bmatrix}\oplus\begin{bmatrix}\mathrm{v}(J)s(J)\end{bmatrix}=\begin{bmatrix}\dfrac{\mathrm{v}(I)\exp(s(I)) + \mathrm{v}(J)\exp(s(J))}{\exp(s(I)) + \exp(s(J))}\log\\bigl(\exp(s(I)) + \exp(s(J))\bigr)\end{bmatrix}

對於任何長度的 attention state inputs:

[v(i=1nIi)s(i=1nIi)]=i=1n[v(Ii)s(Ii)]=[i=1nsoftmax(s(Ii)),v(Ii)logi=1nexp ⁣(s(Ii))]\begin{bmatrix}\mathrm{v}\left(\bigcup_{i=1}^n I_i\right)s\left(\bigcup_{i=1}^n I_i\right)\end{bmatrix}=\bigoplus_{i=1}^n\begin{bmatrix}\mathrm{v}(I_i)s(I_i)\end{bmatrix}=\begin{bmatrix}\displaystyle\sum_{i=1}^n\mathrm{softmax}(s(I_i)), \mathrm{v}(I_i)\displaystyle\log \sum_{i=1}^n \exp\!\bigl(s(I_i)\bigr)\end{bmatrix}

可以注意到這個 operator 是 communicative 和 associative 的,也就是說,對於任何長度的 attention sequence,可以以任何長度切割並合併計算,就是具有 devide & conquer 的性質。

Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
/*!
* \brief The flashattention state.
* \tparam vec_size The size of the vector used in o.
*/
// vec_size在compile time可知幫助compiler快速unroll
// 這邊通常對應attention head_dim / 8
template <size_t vec_size>
struct state_t {
/* the weighted sum of v: exp(pre-softmax logit - m) * v / d */
vec_t<float, vec_size> o;
/* maximum value of pre-softmax logits */
float m;
/* sum of exp(pre-softmax logits - m) */
float d;

__device__ __forceinline__ void init() {
o.fill(0.f);
m = -math::inf;
d = 1.f;
}

__device__ __forceinline__ state_t() { init(); }

// log-sum-exp
__device__ __forceinline__ float get_lse() const { return m + math::ptx_log2(d); }

/*!
* \brief Merge the state with another state.
* \param other_m The maximum value of pre-softmax logits of the other state.
* \param other_d The sum of exp(pre-softmax logits - m) of the other state.
* \param other_o The weighted sum of v of the other state.
*/
__device__ __forceinline__ void merge(const vec_t<float, vec_size>& other_o, float other_m,
float other_d) {
float m_prev = m, d_prev = d;
m = max(m_prev, other_m);
d = d_prev * math::ptx_exp2(m_prev - m) + other_d * math::ptx_exp2(other_m - m);
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] = o[i] * math::ptx_exp2(m_prev - m) + other_o[i] * math::ptx_exp2(other_m - m);
}
}

/*!
* \brief Merge the state with another state.
* \param other The other state.
*/
__device__ __forceinline__ void merge(const state_t<vec_size>& other) {
merge(other.o, other.m, other.d);
}

__device__ __forceinline__ void normalize() {
// only normalize by d when not normalized on the fly
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] = __fdividef(o[i], d);
}
}
};

Applications

merge是commutative and associative的,所以我們可以把不同subset of KV offload到不同的device上之後再merge起來。這在許多long-context的設定上有許多好處。

Cascade Inference: Shared-Prefix Batch Decoding

Workflow:

  1. Use multi-query (prefill/append) attention kernel to compute the attention state between queries and KV-Cache of shared prefix.
  2. Use batch decode attention kernel to compute the attention state between queries and KV-Cache of unique suffixes.
  3. Use merge operator to combine two attention states to get the final attention output.

    上格是naive作法,對於多個request是獨立處理的,並把KV Cache存在L2 Cache/Global Memory,相對低效。若把前綴相同的部分share起來,shared KV Cache存在SMEM/Registers,在移動shared KV Cache上可以memory efficient。最高可以達到30x的效率收益。

KV Cache Layout

Flasshinfer在vLLM的Page Attention基礎上更進一步優化KV Cache。首先先了解什麽是Paged KV Cache:

Page Attention

核心思想就是把KV Cache用page的方式管理。在OS中,我們知道不同process之間基本上是互相獨立的,每個process覺得自己在使用獨立的空間,但他們使用的memory其實是分享同一個物理記憶體。同理,每一個LLM request都是獨立的,但是因為長短不一,如果真的每個request都維護一個記憶空間是很浪費GPU RAM的,為了防止記憶體碎片化,我們用一個page table,把邏輯上KV的儲存位置mapping到physical KV memory上面。

圖中的一個block就是一個page,在vLLM裡一個block可以裝16個token的KV data。
Flashinfer基於此之上,在資料結構層面上做出更急止的優化:

Layout: NHD/HND

KV Cache的最後三個維度,可以分成兩種layout:

1
2
NHD: (seq_len, num_heads, head_dim).
HND: (num_heads, seq_len, head_dim).

NHD是最自然的形式,HND在low-precision KV Cache上比較友善,但在fp16時,兩者不會差太多。

Ragged Tensor

在prefill stage會用Ragged Tensor把不同長度的QKV pack成單一個data tensor,省去padding帶來的空間消耗。

indptr array用來記錄一個data tensor裡不同長度seq的訊息(indptr[i+1]-indptr[i] is the sequence length of request i)。
data tensor size = (indptr[-1], num_heads, head_dim) under NHD.

Page Table Layout

FlashInfer把Paged KV中的Page table當作一個block sparse matrix,並使用CSR Format去index每個page。

When length = num_requests+1, kv_indptr = [0, len(page_indices[0]), len(page_indices[0])+len(page_indices[1]), …].

1
2
kv_cache_nhd = torch.empty(max_num_pages, 2, page_size, num_heads, head_dim, dtype=torch.bfloat16) # NHD layout
kv_cache_hnd = torch.empty(max_num_pages, 2, num_heads, page_size, head_dim, dtype=torch.bfloat16) # HND layout

Multi-level Cascade Inference Data Layout

若把這些layout實際應用在cascade attention上面的話,我們可以做multi-level的prefix reuse: 下圖levels=3

Conclusion

FlashInfer提供了多樣的inference kernel,在attention和kv cache上有良好的效率基礎,利於後序Inference Engine(sglang, vllm)進行開發,接下來我會深入Flash Attention三代的實作與計算,之後就來看LLM Inference Engine是如何schedule request, kv cache達成高效throughtput。

Comments