Stanford CS336 深度实战:从零实现大语言模型——数据清洗、Transformer 架构、FlashAttention 系统优化到 RL 对齐的完全指南(2026)
"像操作系统课让你写一个 OS 一样,这门课让你从零写一个 LLM。"—— Percy Liang
引言:为什么你需要从零实现一次 LLM?
2026 年,大语言模型已经成为基础设施。每个人都用过 ChatGPT、Claude、Gemini,但真正理解它们为什么 work 的人,远少于使用它们的人。
绝大多数工程师对 LLM 的理解停留在调用 API 的层面——知道 prompt 怎么写,知道 temperature 调多少,知道哪个模型"更聪明"。但当你问"Transformer 的 FFN 层为什么是 4d?"、"FlashAttention 到底省了什么?"、"数据去重为什么对 pretraining 这么关键?"——大部分人只能含糊其辞。
Stanford CS336: Language Modeling from Scratch 的出现,就是要解决这个问题。这不是一门"调包课",而是一门让你从原始 Common Crawl 网页开始,一步步写出一个有意义的 Transformer LM,并让它跑在分布式 GPU 上的课程。
本文是对这门课的完整深度解读——不仅介绍课程本身,更重要的是:把课程每一条主线的技术本质讲清楚,配可运行的代码示例,让你读完不仅能跟上课程,还能真正理解 LLM 的工程全貌。
第一部分:课程全景——你到底要写什么?
CS336 共 5 个 Assignment,覆盖了 LLM 的完整生命周期:
| Assignment | 核心任务 | 技术深度 |
|---|---|---|
| A1: Basics | 实现 Tokenizer + Transformer + Optimizer,训练一个可运行的 LM | ⭐⭐⭐ |
| A2: Systems | FlashAttention-2(Triton)、内存优化、分布式 DDP/FSDP | ⭐⭐⭐⭐⭐ |
| A3: Scaling | 消融实验 + 拟合 Scaling Law,预测大模型表现 | ⭐⭐⭐⭐ |
| A4: Data | 从 Common Crawl 原始 HTML 清洗出高质量语料,去重 | ⭐⭐⭐⭐ |
| A5: Alignment | SFT + RL(GRPO/PPO),训练模型做数学推理 | ⭐⭐⭐⭐⭐ |
课程的官方建议 GPU 规格(2026 年定价):B200 单卡,Modal $6.25/h,Lambda $6.69/h,RunPod $4.99/h。
先说结论:完成这门课,你写过的代码量会比你上过的绝大多数 AI 课都多——这正是它的价值所在。
第二部分:A1 深度实战——从 Tokenizer 到 Transformer,搭建第一个 LM
2.1 Tokenizer:BPE 从原理到实现
现代 LLM 不使用字符级 tokenization(词表太大、序列太长),也不使用单词级(OOV 问题严重)。Byte-Pair Encoding(BPE) 是 GPT 系列、Llama、Claude 的事实标准。
BPE 的核心思想极其简单:
- 从字符级词表开始
- 统计所有相邻 token 对的频率
- 合并最高频的 pair,加入词表
- 重复直到词表达到目标大小(通常 32K–100K)
关键工程细节:bytes vs str。GPT-4 之前的模型使用 str 级别的 BPE(UTF-8 感知),但这样对多语言不友好。现代实现(包括 GPT-4 和 Llama)直接在 byte 空间做 BPE——词表固定 256 个初始 token(每个 byte 值一个),合并的是 byte pair。
# cs336_basics/tokenizer.py 核心逻辑(课程风格实现)
from collections import Counter
import re
class BytePairEncoding:
def __init__(self, vocab_size: int):
self.vocab_size = vocab_size
self.merges = {} # (int, int) -> int,合并规则
self.vocab = {} # int -> bytes,最终词表
def train(self, corpus: list[str]):
# 1. 将文本转为 byte 序列,每个 byte 是一个 token
data = [list(t.encode("utf-8")) for t in corpus]
# 2. 迭代合并,直到词表达到 vocab_size
num_merges = self.vocab_size - 256 # 256 个初始 byte token
next_token_id = 256
for i in range(num_merges):
# 统计所有相邻 pair 的频率
pairs = Counter()
for seq in data:
for j in range(len(seq) - 1):
pairs[(seq[j], seq[j+1])] += 1
if not pairs:
break
# 找到最高频的 pair
best_pair = max(pairs, key=pairs.get)
self.merges[best_pair] = next_token_id
# 合并所有该 pair
new_data = []
for seq in data:
new_seq = []
j = 0
while j < len(seq):
if j < len(seq) - 1 and (seq[j], seq[j+1]) == best_pair:
new_seq.append(next_token_id)
j += 2
else:
new_seq.append(seq[j])
j += 1
new_data.append(new_seq)
data = new_data
next_token_id += 1
# 构建最终 vocab
for i in range(256):
self.vocab[i] = bytes([i])
for (a, b), tid in self.merges.items():
self.vocab[tid] = self.vocab[a] + self.vocab[b]
def encode(self, text: str) -> list[int]:
# 推理时使用训练好的 merges 进行编码
tokens = list(text.encode("utf-8"))
# 用与训练相同的合并顺序应用 merges
for (a, b), tid in self.merges.items():
i = 0
new_tokens = []
while i < len(tokens):
if i < len(tokens) - 1 and tokens[i] == a and tokens[i+1] == b:
new_tokens.append(tid)
i += 2
else:
new_tokens.append(tokens[i])
i += 1
tokens = new_tokens
return tokens
def decode(self, tokens: list[int]) -> str:
byte_seq = b"".join(self.vocab[t] for t in tokens)
return byte_seq.decode("utf-8", errors="replace")
这门课要求你自己实现上述逻辑,而不是调用 tiktoken 或 sentencepiece。当你自己写过一遍 BPE,你会真正理解为什么 "hello world" 在某些模型里被分成 ["hello", " world"] 而不是 ['h','e','l','l','o',' ','w','o','r','l','d']。
2.2 Transformer 架构:从「能跑」到「正确」
A1 的核心是实现完整 Transformer LM(Decoder-only)。课程使用 PyTorch,但几乎不给脚手架代码——你需要自己实现:
MultiHeadAttention(含 causal mask)FeedForward(通常 4d hidden dim,SwiGLU 激活)TransformerBlock(Pre-LN 还是 Post-LN?CS336 用的是 Pre-LN)TransformerLM(完整模型,含 embedding、position encoding)
Causal Mask 的实现细节(这是最容易出 bug 的地方):
# cs336_basics/model.py 核心片段
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_head = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
"""
x: (batch, seq_len, d_model)
mask: (seq_len, seq_len) 的 causal mask,上三角为 -inf
"""
B, T, D = x.shape
q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
k = self.W_k(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
v = self.W_v(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
# q,k,v: (B, n_heads, T, d_head)
# 注意力分数
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
# scores: (B, n_heads, T, T)
if mask is not None:
scores = scores + mask # mask 上三角为 -inf
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
out = torch.matmul(attn, v) # (B, n_heads, T, d_head)
out = out.transpose(1, 2).contiguous().view(B, T, D)
return self.W_o(out)
Pre-LN vs Post-LN:原始 Transformer 论文使用 Post-LN(残差连接在外面,LayerNorm 在里面),但现代 LLM(GPT-2 之后)几乎全部使用 Pre-LN:
# Post-LN (原始 Transformer)
# x -> LayerNorm -> SubLayer -> dropout -> + x (残差)
# Pre-LN (现代 LLM 标准,训练更稳定)
# x -> SubLayer -> dropout -> + x (残差),然后 LayerNorm 在残差连接之后
# 实际上 PyTorch 的实现通常是:
# x = x + dropout(sublayer(F.layer_norm(x)))
课程作业中,你需要自己搞清楚这个问题,并解释为什么 Pre-LN 训练更稳定(hint:梯度流动的 analysis)。
2.3 优化器:AdamW 与权重衰减
CS336 要求你实现自己的 AdamW 优化器(而不是直接用 torch.optim.AdamW)。原因在于:你需要完全理解 Adam 的 bias correction 和 weight decay 与 L2 regularization 的区别。
# cs336_basics/optimizer.py
class AdamW(torch.optim.Optimizer):
def __init__(self, params, lr=3e-4, betas=(0.9, 0.95),
eps=1e-8, weight_decay=0.1):
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
for group in self.param_groups:
lr = group['lr']
beta1, beta2 = group['betas']
eps = group['eps']
wd = group['weight_decay']
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
# 初始化 momentum 和 variance 状态
state = self.state[p]
if 'step' not in state:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)
state['step'] += 1
step = state['step']
# Weight decay (解耦!不是 L2 reg)
p.mul_(1 - lr * wd)
# Adam 更新
state['exp_avg'].mul_(beta1).add_(grad, alpha=1-beta1)
state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1-beta2)
# Bias correction
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
denom = (state['exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) + eps
step_size = lr / bias_correction1
p.addcdiv_(state['exp_avg'], denom, value=-step_size)
AdamW 的核心要点:Weight Decay 是直接作用在参数上的 p *= (1 - lr * wd),而不是加到 loss 里。这是与 L2 regularization 的本质区别——L2 会让大梯度的参数受到更小的衰减,而 Weight Decay 对所有参数一视同仁。
2.4 训练你的第一个 LM
A1 最后一步:在 TinyStories 数据集上训练一个小型 Transformer LM(~1.25 亿参数),验证 loss 下降。
# 下载数据
mkdir -p data
cd data
wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt
wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-valid.txt
cd ..
# 训练(使用你自己的实现)
uv run python train.py \
--train_data data/TinyStoriesV2-GPT4-train.txt \
--val_data data/TinyStoriesV2-GPT4-valid.txt \
--d_model 768 \
--n_layers 12 \
--n_heads 12 \
--batch_size 32 \
--lr 3e-4 \
--weight_decay 0.1 \
--max_seq_len 1024
TinyStories 是专为小模型训练设计的数据集(由 GPT-3.5/4 生成的高质量短故事),~2M 条训练样本。在这上面训练,一个 768d 的 12 层 Transformer 可以在单卡 GPU 上约 4-6 小时内收敛到有意义的生成效果。
第三部分:A2 深度实战——系统优化,让模型跑得更快
A1 的模型能跑,但慢得离谱。A2 的目标是把 A1 的模型优化到能在大规模 GPU 上高效训练。
3.1 FlashAttention-2:用 Triton 手写注意力
FlashAttention 的核心思想:Tiling + Recomputation。
标准 Attention 的计算是:
Q, K, V ∈ (B, n_heads, T, d_head)
Attention = softmax(QK^T / √d) V
问题是 QK^T 的大小是 (T, T),当 T=2048 时是 4M 个元素,T=8192 时是 67M——这个矩阵放不到 SRAM 里,必须存在 HBM(高带宽内存)里,而 HBM 的带宽比 SRAM 慢 10-50 倍。
FlashAttention 的解法:分块计算(Tiling)。把 Q、K、V 分成小块,每次只加载一块到 SRAM,增量更新 softmax(在线算法),永远不把完整的 (T, T) 注意力矩阵存下来。
课程要求你用 Triton(Python-like GPU 编程语言)实现 FlashAttention-2 的前向传播:
# cs336_systems/flash_attention.py (Triton 实现核心逻辑)
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
Q, K, V, Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
sm_scale, # 1/√d
Z, H, N_CTX, D_HEAD,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
# 每个 program 处理一个 (batch, head, query_block) 组合
start_m = tl.program_id(0)
off_hz = tl.program_id(1) # batch * head 展平
# 计算 Q 块的指针
q_offset = off_hz * stride_qh + start_m * BLOCK_M * stride_qm
Q_block = tl.load(Q + q_offset + ...)
# ... 核心:分块加载 K、V,增量计算 softmax
# 关键技巧:online softmax(避免存储完整行)
为什么用 Triton 而不是 CUDA C++? Triton 是 OpenAI 开发的 Python-embedded GPU 编程语言,比 CUDA C++ 简洁 10 倍,同时性能接近手写 CUDA。PyTorch 2.x 的 scaled_dot_product_attention 底层就是 Triton 实现的 FlashAttention。
3.2 内存优化:Gradient Checkpointing
Transformer 的显存占用:假设 batch=16, T=2048, d_model=768, n_layers=12:
- 激活值(每层的输入/输出):
O(T * D * layers)≈ 非常大 - 梯度:与参数同大小
- Optimizer 状态(Adam):参数大小的 2 倍(momentum + variance)
Gradient Checkpointing(梯度检查点) 的核心 trade-off:用 30% 的额外计算,节省 50-70% 的显存。
原理:训练时只保存每层输入的激活值,反向传播时重新计算每层的激活值(而不是从前向传播保存的里面读)。CS336 要求你实现这个逻辑:
# cs336_systems/checkpoint.py
import torch
from torch.utils.checkpoint import checkpoint
# 方法1:PyTorch 内置(课程要求你自己实现逻辑)
def checkpoint_sequential_blocks(blocks, x):
"""
手动实现 gradient checkpointing:
前向时不保存 blocks 中间的激活值,反向时重新前向一次
"""
for block in blocks:
# 使用 torch.utils.checkpoint.checkpoint
# 它会在前向时"忘记"中间激活,反向时重新运行前向
x = checkpoint(block, x, use_reentrant=False)
return x
3.3 分布式训练:DDP 与 FSDP
当模型放不进单卡 GPU 时,需要分布式训练。CS336 覆盖两种主流范式:
DDP(Distributed Data Parallel):每卡放完整模型副本,数据切分,梯度 All-Reduce。
- 适用场景:单卡能放下模型(比如 < 30B 参数,FP16)
- 通信开销:每步需要 All-Reduce 梯度,通信量是
O(参数总量)
FSDP(Fully Sharded Data Parallel):每卡只放模型的一部分参数,前向/反向时临时加载需要的参数。
- 适用场景:单卡放不下模型(70B+ 参数)
- 通信开销:All-Gather 参数 + Reduce-Scatter 梯度
# cs336_systems/distributed.py
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
# FSDP 包装策略:每个 TransformerBlock 作为一个 sharding unit
def get_transformers_wrapper(model):
auto_wrap_policy = transformer_auto_wrap_policy(
transformer_layer_cls={TransformerBlock},
)
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
sharding_strategy=ShardingStrategy.FULL_SHARD,
device_id=torch.cuda.current_device(),
)
return model
FSDP 的核心理解:它不是"把模型切分到不同卡上"那么简单。它的运作方式是:前向传播时,对当前 layer 做 All-Gather(把所有卡上该 layer 的参数碎片聚合到每张卡),计算完立即释放。反向传播同理。这样每张卡只需要存当前 layer 的参数 + 激活值,而不是全部参数。
第四部分:A3 深度实战——Scaling Law,预测下一个突破
训练一个大模型(比如 70B 参数)需要数百万美元。在花钱之前,你需要知道:模型多大、数据多少、算力多少,才能达到目标性能?
这就是 Scaling Law 要回答的问题。
4.1 Kaplan 定律 vs Chinchilla 定律
2020 年 Kaplan 等人(OpenAI)发现:对于 decoder-only Transformer,Validation Loss 与模型参数量 N、数据量 D、算力 C 分别呈幂律关系:
L(N) ∝ N^{-α} (固定 D 和 C)
L(D) ∝ D^{-β} (固定 N 和 C)
L(C) ∝ C^{-γ} (固定 N 和 D)
但 2022 年 DeepMind 的 Chinchilla 论文推翻了 Kaplan 的结论:他们发现,之前的研究低估了数据的重要性。正确的 scaling 应该是:模型参数量和数据量应该 1:1 增长(每 1 个参数对应约 1-2 个 training token)。
Chinchilla 的核心发现:
- GPT-3(175B 参数,300B token)→ 数据不够!
- Chinchilla(70B 参数,1.4T token)→ 同等算力下,比 GPT-3 好得多
4.2 在 CS336 中拟合 Scaling Law
A3 的核心任务:在多个不同大小的模型上做短期训练,测量 loss,拟合 scaling curve,然后预测某个更大模型的最终 loss。
# cs336_basics/scaling.py
import numpy as np
from scipy.optimize import curve_fit
def power_law(N, A, alpha):
"""L(N) = A * N^{-alpha} + E_min"""
return A * (N ** (-alpha)) + E_min
# 在不同模型大小上训练,记录 final loss
model_sizes = np.array([125e6, 350e6, 760e6, 1.3e9, 2.7e9]) # 参数数量
val_losses = np.array([3.2, 2.8, 2.5, 2.3, 2.1])
# 拟合
popt, _ = curve_fit(power_law, model_sizes, val_losses)
A_fit, alpha_fit = popt[0], popt[1]
# 预测 7B 模型的 loss
predicted_loss = power_law(7e9, A_fit, alpha_fit)
print(f"Predicted loss for 7B model: {predicted_loss:.3f}")
课程的深层目标:让你亲手验证 scaling law 在你的训练设置下是否成立。绝大多数论文里的 scaling law 是在特定数据分布、特定模型架构下拟合的——换一个数据集,斜率可能完全不同。
第五部分:A4 深度实战——数据工程,LLM 的"隐形一半"
"Data is the new oil" 这句话在 LLM 时代格外正确。GPT-4、Claude、Gemini 的性能差异,有很大一部分来自数据质量和数据处理 pipeline,而不是架构创新。
5.1 Common Crawl:原始网页的泥潭
Common Crawl 是公开的网络爬虫快照,约 30-50TB 压缩数据,包含数百亿个网页。但它的质量……惨不忍睹:
- 大量重复内容(同一篇文章被几十个网站转载)
- 垃圾内容(SEO 农场、机器生成文本、色情、多语言混杂)
- 格式混乱(HTML 标签、JavaScript 代码、CSS)
- 个人信息(需要过滤)
CS336 A4 要求你实现一个完整的数据处理 pipeline:从原始 WARC 文件 → 清洗 → 去重 → 高质量语料。
5.2 数据清洗:Quality Filtering
课程介绍几种主流质量过滤方法:
1. Perplexity Filtering(困惑度过滤)
用一个小模型(比如 125M 参数的 GPT-2)对每段文本计算 perplexity。高质量文本(维基百科、书籍)的 perplexity 低;垃圾文本(机器生成、语法混乱)的 perplexity 高。设定阈值,过滤高 perplexity 的文本。
2. Classifier-based Filtering
训练一个二元分类器("高质量" vs "低质量"),用维基百科、书籍作为正例,垃圾网页作为负例。用这个分类器给 Common Crawl 文本打分。
3. 启发式规则
- 文本长度(太短的通常是菜单、导航)
- 特殊字符比例(JavaScript 代码片段)
- 语言识别(只用目标语言,比如英语)
- "诅咒词"比例(垃圾网站通常更多)
# cs336_basics/data_filter.py(课程风格)
import re
from langdetect import detect
def is_high_quality(text: str) -> bool:
# 1. 长度过滤
if len(text) < 100 or len(text) > 100_000:
return False
# 2. 语言检测(只保留英语)
try:
if detect(text) != "en":
return False
except:
return False
# 3. 特殊字符比例
special_char_ratio = len(re.findall(r'[^\w\s]', text)) / max(len(text), 1)
if special_char_ratio > 0.15:
return False
# 4. 重复 n-gram 检测(垃圾内容特征)
words = text.lower().split()
if len(words) < 50:
return False
bigrams = zip(words, words[1:])
bigram_counts = Counter(bigrams)
if bigram_counts.most_common(1)[0][1] > len(words) * 0.05:
return False # 太多重复 bigram,可能是垃圾
return True
5.3 去重:MinHash + LSH
即使经过质量过滤,Common Crawl 中仍有大量近似重复内容(同一篇文章略有改写的多个版本)。精确去重(比较 hash)不够,需要近似去重。
MinHash + LSH(Locality-Sensitive Hashing) 是工业标准方案:
- 对每段文本,提取 k-shingles(k 个连续词/字符的滑动窗口)
- 计算 MinHash 签名(用多个哈希函数,取每个函数的最小值)
- 用 LSH 将相似签名分到同一个 bucket
- 在同一个 bucket 内做精确去重
# cs336_basics/dedup.py(简化版 MinHash)
import hashlib
import numpy as np
def minhash_signature(text: str, num_hashes: int = 128) -> np.ndarray:
"""计算文本的 MinHash 签名"""
shingles = set(text.lower().split()) # 简化:用 word shingle
signature = np.inf * np.ones(num_hashes)
for i in range(num_hashes):
hash_func = lambda s: int(hashlib.md5(f"{i}_{s}".encode()).hexdigest(), 16)
for shingle in shingles:
h = hash_func(shingle)
signature[i] = min(signature[i], h)
return signature
def jaccard_estimate(sig1: np.ndarray, sig2: np.ndarray) -> float:
"""用 MinHash 签名估计 Jaccard 相似度"""
return np.mean(sig1 == sig2)
# 去重:如果两篇文章的 MinHash 相似度 > 0.5,认为是重复
实际工业应用:GPT-4 的训练数据去重使用了更精细的 scheme(包括 exact substring matching 和 MinHash LSH),并且对每个 domain 单独去重(避免跨 domain 误删)。
第六部分:A5 深度实战——对齐与推理 RL,让模型"会思考"
预训练后的模型只是一个"互联网文本压缩器"——它能续写文本,但不会"回答用户问题",也不会"逐步推理"。对齐(Alignment) 就是解决这个问题。
6.1 Supervised Fine-Tuning(SFT):从续写到对话
SFT 的核心:用"指令-回答"格式的数据集(比如 ShareGPT、OpenAssistant)继续训练预训练模型,让它学会对话格式。
# SFT 的训练格式(简化)
sft_format = """<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant
The capital of France is Paris.<|im_end|>"""
# 在预训练 checkpoint 上继续训练,loss 只计算 assistant 部分的 token
# (system/user 部分的 token 不参与 loss 计算)
关键细节:SFT 阶段通常只训练 1-3 个 epoch,学习率比预训练小 10-100 倍。过长的 SFT 会导致灾难性遗忘(模型忘记预训练时学到的世界知识)。
6.2 RL for Reasoning:GRPO 与 PPO
让模型做数学推理(比如 GSM8K 数据集)需要 Reinforcement Learning。核心思路:
- 模型生成多个候选答案(采样,temperature > 0)
- 用 verifier(或者人工标注)给每个答案打分
- 用 RL 更新模型参数,让高分答案的概率增大
PPO(Proximal Policy Optimization) 是 RLHF 的标准算法,但计算开销大(需要单独的 value network)。
GRPO(Group Relative Policy Optimization):DeepSeek 提出的改进方案,不需要 value network。核心思想:对同一 prompt 生成 G 个回答,用这 G 个回答的相对排名来计算 advantage。
# cs336_basics/grpo.py(极度简化版)
def grpo_loss(model, ref_model, prompts, G=8, eps=0.2):
"""
GRPO loss 计算
prompts: 一批数学问题
G: 每个 prompt 生成 G 个回答
"""
loss = 0.0
for prompt in prompts:
# 1. 生成 G 个回答
generations = []
for g in range(G):
gen = model.generate(prompt, temperature=0.8, max_new_tokens=512)
generations.append(gen)
# 2. 计算每个回答的 reward(用 verifier)
rewards = [math_verifier(prompt, gen) for gen in generations]
# rewards: 比如 [0, 1, 0, 1, 1, 0, 0, 1]
# 3. 计算相对 advantage(关键!)
mean_reward = np.mean(rewards)
std_reward = np.std(rewards) + 1e-8
advantages = [(r - mean_reward) / std_reward for r in rewards]
# 4. 计算 policy gradient(带 clipping,类似 PPO)
for gen, adv in zip(generations, advantages):
log_prob_new = model.log_prob(prompt + gen)
log_prob_old = ref_model.log_prob(prompt + gen) # 停止梯度
ratio = torch.exp(log_prob_new - log_prob_old)
clipped_ratio = torch.clamp(ratio, 1 - eps, 1 + eps)
loss += -torch.min(ratio * adv, clipped_ratio * adv)
return loss / (len(prompts) * G)
GRPO 的核心优势:不需要训练 value network,节省了约 50% 的 GPU 显存和计算。这也是 DeepSeek-Math 和 DeepSeek-R1 能够高效训练的重要原因。
6.3 推理时计算:Chain-of-Thought 与 Verifier
对齐后的模型可以"思考"——但这不是自发产生的,而是通过 Chain-of-Thought(CoT) 训练出来的。
CoT 的核心:在 SFT 数据中,让 assistant 的回答包含推理步骤,而不仅仅是最终答案。比如:
User: 小明有 3 个苹果,他又买了 2 倍的小红有的苹果数。小红有 4 个苹果。小明现在有几个苹果?
Assistant(无 CoT): 11个。
Assistant(有 CoT):
我们来逐步分析:
1. 小红有 4 个苹果
2. 小明买了 2 倍的小红有的苹果数 = 2 × 4 = 8 个
3. 小明原有 3 个,加上买的 8 个
4. 3 + 8 = 11
答案:11个。
Verifier(验证器):在推理时,让模型生成多个 CoT 回答,然后用一个训练好的 verifier 模型选出"最可信"的答案。这是 Self-Consistency 和 Process Reward Model(PRM) 的核心思想。
第七部分:课程之外——2026 年的 LLM 工程全貌
完成 CS336 后,你已经具备了从零实现 LLM 的能力。但 2026 年的工业界,还有一些课程未覆盖但极其重要的话题:
7.1 推理优化:KV Cache 与 Speculative Decoding
KV Cache:自回归生成时,每个新 token 的注意力计算只需要 Q 是新的,K 和 V 可以复用之前所有 token 的——这就是 KV Cache。没有它,生成速度会随着序列长度线性下降。
Speculative Decoding:用一个小的 draft 模型快速生成 K 个候选 token,然后用大模型并行验证这 K 个 token,接受其中前缀正确的部分。可以将推理速度提升 2-3 倍。
7.2 量化:INT4/INT8 推理
2026 年,几乎所有生产级 LLM 推理都使用 INT4 或 INT8 量化(GPTQ、AWQ、GPT-OSS 等方案)。原理:把 FP16 的权重压缩到 4-bit 整数,推理时反量化为 FP16 进行计算。
量化带来的精度损失通常很小(尤其是 INT8),但显存占用可以减少 50-75%。
7.3 多模态:LLM 的下一步
2026 年的前沿模型(GPT-4o、Claude Opus 4、Gemini 2.5)都是多模态模型——能同时理解图像、音频和文本。核心架构:用视觉 encoder(比如 ViT)把图像转为 embedding,然后把这些 embedding 当成"特殊 token"拼接到文本 token 序列里,一起送进 Transformer。
总结:从零实现 LLM,你学到了什么?
CS336 不是一门轻松的课程。你需要:
- 写 Tokenization(BPE 从零实现)
- 写 Transformer(包括因果注意力、位置编码、Pre-LN)
- 写 AdamW(理解 bias correction 和 weight decay)
- 用 Triton 实现 FlashAttention-2
- 实现 Gradient Checkpointing 和 FSDP
- 处理 Common Crawl(清洗、去重、质量过滤)
- 实现 SFT 和 GRPO(对齐与推理 RL)
但完成它之后,你对 LLM 的理解将从"会调 API"升级到"知道每个字节在干什么"。这种深度理解,才是你在 AI 时代的核心竞争力。
参考资源
- 课程官网: https://cs336.stanford.edu/
- YouTube 讲义: https://www.youtube.com/watch?v=JuoVZkPBiKk&list=PLoROMvodv4rMqXOcazWaTUHhq-yembLCV
- Assignment 1 (Basics): https://github.com/stanford-cs336/assignment1-basics
- Assignment 2 (Systems): https://github.com/stanford-cs336/assignment2-systems
- Assignment 3 (Scaling): https://github.com/stanford-cs336/assignment3-scaling
- Assignment 4 (Data): https://github.com/stanford-cs336/assignment4-data
- Assignment 5 (Alignment): https://github.com/stanford-cs336/assignment5-alignment
- FlashAttention 论文: https://arxiv.org/abs/2205.14135
- Chinchilla 论文: https://arxiv.org/abs/2203.15556
- GRPO 论文(DeepSeek-Math): https://arxiv.org/abs/2402.03300
本文基于 Stanford CS336 Spring 2026 课程资料撰写,所有代码示例为教学目的的简化实现,生产环境请参考课程官方代码和主流开源实现(litgpt、nanochat、llama2.c)。