MLsys Note (0) - Qwen3 Code Walk Through

MLsys Note (0) - Qwen3 Code Walk Through

Kai-Jie Lin Lv3

最近想要深入學習一些MLsys的開源框架,像Sglang, vLLM和Slime之類的,光閱讀code和文章記憶不深,想說自己來寫一下筆記。第一篇先以最簡單的huggingface transformers為主,順便複習一下以前沒注意到的東西。

Transformers Code Walk Through

平常都在用Transformers做事,但裡面的架構卻不是很熟悉,想用看code的方式,去解析一下不同開源LLM的實作細節,這邊以Qwen3Qwen2為主,參雜一些個人理解和對其他模型的比較,不會到超級仔細,若是說錯了,歡迎指正。
這邊主要以huggingface的實作為主。

CausalLM

“Causal” 就是因果關係的意思,CausalLM就是泛指目前最常見從左到右的LLM,最新的token看得到過去的token,但過去的token看不到未來的token。

1
2
3
class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin)
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin)
class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin)

這是第一層interface,還沒有真的進入model,主要是為了LLM生成logits之後,可以做採樣上的控制(GenerationMixin),若是有loss function的話也會在這裡計算loss。
因為這邊forward function很單純,而且大同小異,就直接貼上來:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
cache_position: torch.LongTensor | None = None,
logits_to_keep: int | torch.Tensor = 0,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)

hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states)

loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

lm_head 是一個 hidden_size X vocab_size的linear layer,從hidden states生成logits

Model

主要架構層,這裡會把input text從頭到尾生成last hidden state。

embed_tokens: Token embedding,從input_ids (1, 2, 3…) mapping 成hidden vector。

layers: config.num_hidden_layers定義了LLM需要幾層layers,layers越多model就越大。

norm: Qwen 2, 3 使用了 RMSnorm 取代傳統 layer norm,後面細講。

rotary_emb: Rotary Position Embedding. 這邊等到後面細講。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Qwen3Model(Qwen3PreTrainedModel):
def __init__(self, config: Qwen3Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen3RotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.has_sliding_layers = "sliding_attention" in self.config.layer_types

# Initialize weights and apply final processing
self.post_init()

Forward function接受自定義input_embeds, attention_maks, position_ids和KV cache,可以有flexible的masking和forward的操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
cache_position: torch.LongTensor | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast

# 1. Embedding forward
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

# 2. 取出過去的key value
# 關於KV Cache 之後用另一篇討論
if use_cache and past_key_values is None:
past_key_values = DynamicCache(config=self.config)

# KV Cache的position
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

# 3. masking的部分拿llama的代替比較簡潔
# 我們input的attention_mask是1D的,若要造出2D的attention mask,可以用這個function
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

hidden_states = inputs_embeds
# 4. Position embedding
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# 5. Decoder forwards
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)

# 6. Last hidden states.
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
)

Decoder layers

這邊會分開來講 [layer norm -> attention -> mlp]。

Attention

這邊就不細講 Attention 怎麼算的,之後看 FlashInfer 再一起看 Flash Attn 等計算和代碼。

看一些一般不會注意到的實作細節。
Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Qwen3Attention(nn.Module):
def __init__(self, config: Qwen3Config, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads # MHA & GQA
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True

self.q_proj = nn.Linear(
config.hidden_size,
config.num_attention_heads * self.head_dim,
bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim,
config.hidden_size,
bias=config.attention_bias
)
self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  1. 注意Qwen3的attention bias是可以被關閉的,而Qwen2是強制開啟。

    • Attn(Q,K,V)=softmax(QK/sqrt(d)​+bias)V
    • bias會讓llm有更強的instruction following,能夠更服從prompt,相對的關掉bias可以讓llm有更多exploration,這對reasoning model特別有用,因為CoT需要更廣的探索去獲取reasoning pattern。
  2. QK norm是Qwen3新加的,在Qwen2和llama等模型都沒有。

    • 原本attention score: sij=qi⊤​kjds_{ij}​=\frac{q_{i}⊤​k_{j}}{\sqrt{d}}.
    • 在沒有bias的之後,模型更新會直接由QK決定,若是訓練過於朝向某個token,QK值會變得很大,在這裏的RMSnorm是為了訓練穩定。
    • 除了防止訓練爆掉,還有一個是為了不讓RL訓練overfit到某些token,在RL訓練中,某些token可能導致獲得高reward,這屬於某種reward hacking,這邊的normalize是為了RL訓練的穩定。
    • 為什麼只對QK做norm? 注意這邊只對head dim做norm,意思就是說為了不消除不同head之間的訊息量,所以才只對head內做norm,若是V也做同樣的事,則會失去不同head之間的語意訊息。
  3. ​Group Query Attention / Multi-head Attention

    • 在Attention裡每個head代表學習不同的語意和邏輯,就是不同的sub-space,當然也可以被共享。
    • 就像上圖一樣,最naive的MHA每個QKV都是獨立的,也消耗最多的算力(更多的KV weight和KV Cache存取)。
    • MQA是另一極端,所有V享有一個head,最有效率的同時也犧牲表現力,只在一些專門解small task的模型中使用。
    • GQA是兼顧效率和表現,享有較少算力的同時兼顧表現,現代模型大多使用這種架構。

Rotary Positional Embedding (RoPE)

比較一下Absolute PE和Relative PE

  1. Absolute: x_t → x_t + pos_embedding[t]
    • 沒有相對位置的資訊
  2. Relative: QK^T + bias(i - j)
    • bias需要另外訓練

RoPE概念是利用旋轉角度給予位置資訊,RoPE 對每個 position t,把 QK 在 2D 子空間裡旋轉一個角度θt。賦予相對位置資訊的同時,也不用另外學習。當然,當 context length 拉得超級長的時候,短距離位置之間的資訊量會被壓平,這時候會用 yarn 之類的變體來處理(Deepseek)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Qwen3RotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`

def __init__(self, config: Qwen3Config, device=None):
super().__init__()
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config

self.rope_type = self.config.rope_parameters["rope_type"]
rope_init_fn: Callable = self.compute_default_rope_parameters
if self.rope_type != "default":
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

self.register_buffer("inv_freq", inv_freq, persistent=False)
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

inv_freq: 決定每一個維度的旋轉頻率。這邊 register_buffer 之後就是一個固定參數,不會被train。

rope_type: 這邊定義了不同種 rope,像是有處理過長 context 的 yarn。

計算 inv_freq (簡化):對於第i個位置,分配一個角度。

1
2
3
4
5
6
def compute_default_rope_parameters(config, device, seq_len=None):
base = config.rope_parameters["rope_theta"] # 通常是10000
dim = config.head_dim
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2) / dim)
)

Forward function: (簡化)

1
2
3
4
5
6
def forward(self, x, position_ids):
freqs = (inv_freq @ position_ids).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

這邊比較簡單,就是對每個position_ids賦予cos和sin作為positional embedding。

MLP

每個layer後面都會接一個MLP,之前一直不知道為什麼後面接一個這個,稍微研究一下才知道是為了引入「非線性」。

就像linear layer後面會有ReLU一樣,Decoder block最後也會有MLP。


Qwen2,3都採用Gated MLP

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Qwen3MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]

def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
  1. 注意到 intermediate_size = 4*hidden_size
    • 查資料才知道MLP是模型容量的主要來源 我原本一直以為把decoder層數scale up就會有容量 我錯了
  2. Gated MLP
    • 選擇性開關資訊流的MLP
    • Gated MLP = token-level、連續版的 MoE

RMSNorm

最後來談談 RMSNorm。

跟一般 layer norm 的主要不同就是只對 vector 長度正規化,不會對方向正規劃。換句話說不會引入其他 bias。

  • Layer Norm: [x-mean(x)]/std(x)
  • RMS Norm: x / ||x||
1
2
3
4
5
6
7
8
9
10
11
12
class Qwen3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)

Conlusion

原本想寫一下 MOE 的,但還是等下一篇好了,之後也會談談Flash Infer等 inference kernel,最後會去看 Sglang vLLM。可能也會看看 LoRA 等一些 MLsys 技巧。

Comments