编程 Stanford CS336 深度实战:从零实现大语言模型——数据清洗、Transformer 架构、FlashAttention 系统优化到 RL 对齐的完全指南(2026)

2026-06-02 20:14:38 +0800 CST views 9

Stanford CS336 深度实战:从零实现大语言模型——数据清洗、Transformer 架构、FlashAttention 系统优化到 RL 对齐的完全指南(2026)

"像操作系统课让你写一个 OS 一样,这门课让你从零写一个 LLM。"—— Percy Liang

引言:为什么你需要从零实现一次 LLM?

2026 年,大语言模型已经成为基础设施。每个人都用过 ChatGPT、Claude、Gemini,但真正理解它们为什么 work 的人,远少于使用它们的人

绝大多数工程师对 LLM 的理解停留在调用 API 的层面——知道 prompt 怎么写,知道 temperature 调多少,知道哪个模型"更聪明"。但当你问"Transformer 的 FFN 层为什么是 4d?"、"FlashAttention 到底省了什么?"、"数据去重为什么对 pretraining 这么关键?"——大部分人只能含糊其辞。

Stanford CS336: Language Modeling from Scratch 的出现,就是要解决这个问题。这不是一门"调包课",而是一门让你从原始 Common Crawl 网页开始,一步步写出一个有意义的 Transformer LM,并让它跑在分布式 GPU 上的课程。

本文是对这门课的完整深度解读——不仅介绍课程本身,更重要的是:把课程每一条主线的技术本质讲清楚,配可运行的代码示例,让你读完不仅能跟上课程,还能真正理解 LLM 的工程全貌。


第一部分:课程全景——你到底要写什么?

CS336 共 5 个 Assignment,覆盖了 LLM 的完整生命周期:

Assignment核心任务技术深度
A1: Basics实现 Tokenizer + Transformer + Optimizer,训练一个可运行的 LM⭐⭐⭐
A2: SystemsFlashAttention-2(Triton)、内存优化、分布式 DDP/FSDP⭐⭐⭐⭐⭐
A3: Scaling消融实验 + 拟合 Scaling Law,预测大模型表现⭐⭐⭐⭐
A4: Data从 Common Crawl 原始 HTML 清洗出高质量语料,去重⭐⭐⭐⭐
A5: AlignmentSFT + RL(GRPO/PPO),训练模型做数学推理⭐⭐⭐⭐⭐

课程的官方建议 GPU 规格(2026 年定价):B200 单卡,Modal $6.25/h,Lambda $6.69/h,RunPod $4.99/h。

先说结论:完成这门课,你写过的代码量会比你上过的绝大多数 AI 课都多——这正是它的价值所在。


第二部分:A1 深度实战——从 Tokenizer 到 Transformer,搭建第一个 LM

2.1 Tokenizer:BPE 从原理到实现

现代 LLM 不使用字符级 tokenization(词表太大、序列太长),也不使用单词级(OOV 问题严重)。Byte-Pair Encoding(BPE) 是 GPT 系列、Llama、Claude 的事实标准。

BPE 的核心思想极其简单:

  1. 从字符级词表开始
  2. 统计所有相邻 token 对的频率
  3. 合并最高频的 pair,加入词表
  4. 重复直到词表达到目标大小(通常 32K–100K)

关键工程细节bytes vs str。GPT-4 之前的模型使用 str 级别的 BPE(UTF-8 感知),但这样对多语言不友好。现代实现(包括 GPT-4 和 Llama)直接在 byte 空间做 BPE——词表固定 256 个初始 token(每个 byte 值一个),合并的是 byte pair。

# cs336_basics/tokenizer.py 核心逻辑(课程风格实现)
from collections import Counter
import re

class BytePairEncoding:
    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size
        self.merges = {}        # (int, int) -> int,合并规则
        self.vocab = {}         # int -> bytes,最终词表

    def train(self, corpus: list[str]):
        # 1. 将文本转为 byte 序列,每个 byte 是一个 token
        data = [list(t.encode("utf-8")) for t in corpus]
        
        # 2. 迭代合并,直到词表达到 vocab_size
        num_merges = self.vocab_size - 256  # 256 个初始 byte token
        next_token_id = 256
        
        for i in range(num_merges):
            # 统计所有相邻 pair 的频率
            pairs = Counter()
            for seq in data:
                for j in range(len(seq) - 1):
                    pairs[(seq[j], seq[j+1])] += 1
            
            if not pairs:
                break
            
            # 找到最高频的 pair
            best_pair = max(pairs, key=pairs.get)
            self.merges[best_pair] = next_token_id
            
            # 合并所有该 pair
            new_data = []
            for seq in data:
                new_seq = []
                j = 0
                while j < len(seq):
                    if j < len(seq) - 1 and (seq[j], seq[j+1]) == best_pair:
                        new_seq.append(next_token_id)
                        j += 2
                    else:
                        new_seq.append(seq[j])
                        j += 1
                new_data.append(new_seq)
            data = new_data
            next_token_id += 1
        
        # 构建最终 vocab
        for i in range(256):
            self.vocab[i] = bytes([i])
        for (a, b), tid in self.merges.items():
            self.vocab[tid] = self.vocab[a] + self.vocab[b]

    def encode(self, text: str) -> list[int]:
        # 推理时使用训练好的 merges 进行编码
        tokens = list(text.encode("utf-8"))
        # 用与训练相同的合并顺序应用 merges
        for (a, b), tid in self.merges.items():
            i = 0
            new_tokens = []
            while i < len(tokens):
                if i < len(tokens) - 1 and tokens[i] == a and tokens[i+1] == b:
                    new_tokens.append(tid)
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            tokens = new_tokens
        return tokens

    def decode(self, tokens: list[int]) -> str:
        byte_seq = b"".join(self.vocab[t] for t in tokens)
        return byte_seq.decode("utf-8", errors="replace")

这门课要求你自己实现上述逻辑,而不是调用 tiktokensentencepiece。当你自己写过一遍 BPE,你会真正理解为什么 "hello world" 在某些模型里被分成 ["hello", " world"] 而不是 ['h','e','l','l','o',' ','w','o','r','l','d']

2.2 Transformer 架构:从「能跑」到「正确」

A1 的核心是实现完整 Transformer LM(Decoder-only)。课程使用 PyTorch,但几乎不给脚手架代码——你需要自己实现:

  • MultiHeadAttention(含 causal mask)
  • FeedForward(通常 4d hidden dim,SwiGLU 激活)
  • TransformerBlock(Pre-LN 还是 Post-LN?CS336 用的是 Pre-LN)
  • TransformerLM(完整模型,含 embedding、position encoding)

Causal Mask 的实现细节(这是最容易出 bug 的地方):

# cs336_basics/model.py 核心片段
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        x: (batch, seq_len, d_model)
        mask: (seq_len, seq_len) 的 causal mask,上三角为 -inf
        """
        B, T, D = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        # q,k,v: (B, n_heads, T, d_head)
        
        # 注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
        # scores: (B, n_heads, T, T)
        
        if mask is not None:
            scores = scores + mask  # mask 上三角为 -inf
        
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        
        out = torch.matmul(attn, v)  # (B, n_heads, T, d_head)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.W_o(out)

Pre-LN vs Post-LN:原始 Transformer 论文使用 Post-LN(残差连接在外面,LayerNorm 在里面),但现代 LLM(GPT-2 之后)几乎全部使用 Pre-LN

# Post-LN (原始 Transformer)
# x -> LayerNorm -> SubLayer -> dropout -> + x (残差)

# Pre-LN (现代 LLM 标准,训练更稳定)
# x -> SubLayer -> dropout -> + x (残差),然后 LayerNorm 在残差连接之后
# 实际上 PyTorch 的实现通常是:
# x = x + dropout(sublayer(F.layer_norm(x)))

课程作业中,你需要自己搞清楚这个问题,并解释为什么 Pre-LN 训练更稳定(hint:梯度流动的 analysis)。

2.3 优化器:AdamW 与权重衰减

CS336 要求你实现自己的 AdamW 优化器(而不是直接用 torch.optim.AdamW)。原因在于:你需要完全理解 Adam 的 bias correctionweight decay 与 L2 regularization 的区别

# cs336_basics/optimizer.py
class AdamW(torch.optim.Optimizer):
    def __init__(self, params, lr=3e-4, betas=(0.9, 0.95),
                 eps=1e-8, weight_decay=0.1):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            wd = group['weight_decay']
            
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                
                # 初始化 momentum 和 variance 状态
                state = self.state[p]
                if 'step' not in state:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                
                state['step'] += 1
                step = state['step']
                
                # Weight decay (解耦!不是 L2 reg)
                p.mul_(1 - lr * wd)
                
                # Adam 更新
                state['exp_avg'].mul_(beta1).add_(grad, alpha=1-beta1)
                state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1-beta2)
                
                # Bias correction
                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step
                
                denom = (state['exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) + eps
                step_size = lr / bias_correction1
                
                p.addcdiv_(state['exp_avg'], denom, value=-step_size)

AdamW 的核心要点:Weight Decay 是直接作用在参数上的 p *= (1 - lr * wd),而不是加到 loss 里。这是与 L2 regularization 的本质区别——L2 会让大梯度的参数受到更小的衰减,而 Weight Decay 对所有参数一视同仁。

2.4 训练你的第一个 LM

A1 最后一步:在 TinyStories 数据集上训练一个小型 Transformer LM(~1.25 亿参数),验证 loss 下降。

# 下载数据
mkdir -p data
cd data
wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt
wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-valid.txt
cd ..

# 训练(使用你自己的实现)
uv run python train.py \
  --train_data data/TinyStoriesV2-GPT4-train.txt \
  --val_data data/TinyStoriesV2-GPT4-valid.txt \
  --d_model 768 \
  --n_layers 12 \
  --n_heads 12 \
  --batch_size 32 \
  --lr 3e-4 \
  --weight_decay 0.1 \
  --max_seq_len 1024

TinyStories 是专为小模型训练设计的数据集(由 GPT-3.5/4 生成的高质量短故事),~2M 条训练样本。在这上面训练,一个 768d 的 12 层 Transformer 可以在单卡 GPU 上约 4-6 小时内收敛到有意义的生成效果。


第三部分:A2 深度实战——系统优化,让模型跑得更快

A1 的模型能跑,但慢得离谱。A2 的目标是把 A1 的模型优化到能在大规模 GPU 上高效训练。

3.1 FlashAttention-2:用 Triton 手写注意力

FlashAttention 的核心思想:Tiling + Recomputation

标准 Attention 的计算是:

Q, K, V ∈ (B, n_heads, T, d_head)
Attention = softmax(QK^T / √d) V

问题是 QK^T 的大小是 (T, T),当 T=2048 时是 4M 个元素,T=8192 时是 67M——这个矩阵放不到 SRAM 里,必须存在 HBM(高带宽内存)里,而 HBM 的带宽比 SRAM 慢 10-50 倍

FlashAttention 的解法:分块计算(Tiling)。把 Q、K、V 分成小块,每次只加载一块到 SRAM,增量更新 softmax(在线算法),永远不把完整的 (T, T) 注意力矩阵存下来

课程要求你用 Triton(Python-like GPU 编程语言)实现 FlashAttention-2 的前向传播:

# cs336_systems/flash_attention.py (Triton 实现核心逻辑)
import triton
import triton.language as tl

@triton.jit
def _fwd_kernel(
    Q, K, V, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    sm_scale,  # 1/√d
    Z, H, N_CTX, D_HEAD,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # 每个 program 处理一个 (batch, head, query_block) 组合
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)  # batch * head 展平
    
    # 计算 Q 块的指针
    q_offset = off_hz * stride_qh + start_m * BLOCK_M * stride_qm
    Q_block = tl.load(Q + q_offset + ...)
    # ... 核心:分块加载 K、V,增量计算 softmax
    # 关键技巧:online softmax(避免存储完整行)

为什么用 Triton 而不是 CUDA C++? Triton 是 OpenAI 开发的 Python-embedded GPU 编程语言,比 CUDA C++ 简洁 10 倍,同时性能接近手写 CUDA。PyTorch 2.x 的 scaled_dot_product_attention 底层就是 Triton 实现的 FlashAttention

3.2 内存优化:Gradient Checkpointing

Transformer 的显存占用:假设 batch=16, T=2048, d_model=768, n_layers=12:

  • 激活值(每层的输入/输出):O(T * D * layers) ≈ 非常大
  • 梯度:与参数同大小
  • Optimizer 状态(Adam):参数大小的 2 倍(momentum + variance)

Gradient Checkpointing(梯度检查点) 的核心 trade-off:用 30% 的额外计算,节省 50-70% 的显存

原理:训练时只保存每层输入的激活值,反向传播时重新计算每层的激活值(而不是从前向传播保存的里面读)。CS336 要求你实现这个逻辑:

# cs336_systems/checkpoint.py
import torch
from torch.utils.checkpoint import checkpoint

# 方法1:PyTorch 内置(课程要求你自己实现逻辑)
def checkpoint_sequential_blocks(blocks, x):
    """
    手动实现 gradient checkpointing:
    前向时不保存 blocks 中间的激活值,反向时重新前向一次
    """
    for block in blocks:
        # 使用 torch.utils.checkpoint.checkpoint
        # 它会在前向时"忘记"中间激活,反向时重新运行前向
        x = checkpoint(block, x, use_reentrant=False)
    return x

3.3 分布式训练:DDP 与 FSDP

当模型放不进单卡 GPU 时,需要分布式训练。CS336 覆盖两种主流范式:

DDP(Distributed Data Parallel):每卡放完整模型副本,数据切分,梯度 All-Reduce。

  • 适用场景:单卡能放下模型(比如 < 30B 参数,FP16)
  • 通信开销:每步需要 All-Reduce 梯度,通信量是 O(参数总量)

FSDP(Fully Sharded Data Parallel):每卡只放模型的一部分参数,前向/反向时临时加载需要的参数。

  • 适用场景:单卡放不下模型(70B+ 参数)
  • 通信开销:All-Gather 参数 + Reduce-Scatter 梯度
# cs336_systems/distributed.py
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# FSDP 包装策略:每个 TransformerBlock 作为一个 sharding unit
def get_transformers_wrapper(model):
    auto_wrap_policy = transformer_auto_wrap_policy(
        transformer_layer_cls={TransformerBlock},
    )
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device(),
    )
    return model

FSDP 的核心理解:它不是"把模型切分到不同卡上"那么简单。它的运作方式是:前向传播时,对当前 layer 做 All-Gather(把所有卡上该 layer 的参数碎片聚合到每张卡),计算完立即释放。反向传播同理。这样每张卡只需要存当前 layer 的参数 + 激活值,而不是全部参数。


第四部分:A3 深度实战——Scaling Law,预测下一个突破

训练一个大模型(比如 70B 参数)需要数百万美元。在花钱之前,你需要知道:模型多大、数据多少、算力多少,才能达到目标性能?

这就是 Scaling Law 要回答的问题。

4.1 Kaplan 定律 vs Chinchilla 定律

2020 年 Kaplan 等人(OpenAI)发现:对于 decoder-only Transformer,Validation Loss 与模型参数量 N、数据量 D、算力 C 分别呈幂律关系

L(N) ∝ N^{-α}  (固定 D 和 C)
L(D) ∝ D^{-β}  (固定 N 和 C)
L(C) ∝ C^{-γ}  (固定 N 和 D)

但 2022 年 DeepMind 的 Chinchilla 论文推翻了 Kaplan 的结论:他们发现,之前的研究低估了数据的重要性。正确的 scaling 应该是:模型参数量和数据量应该 1:1 增长(每 1 个参数对应约 1-2 个 training token)。

Chinchilla 的核心发现:

  • GPT-3(175B 参数,300B token)→ 数据不够!
  • Chinchilla(70B 参数,1.4T token)→ 同等算力下,比 GPT-3 好得多

4.2 在 CS336 中拟合 Scaling Law

A3 的核心任务:在多个不同大小的模型上做短期训练,测量 loss,拟合 scaling curve,然后预测某个更大模型的最终 loss。

# cs336_basics/scaling.py
import numpy as np
from scipy.optimize import curve_fit

def power_law(N, A, alpha):
    """L(N) = A * N^{-alpha} + E_min"""
    return A * (N ** (-alpha)) + E_min

# 在不同模型大小上训练,记录 final loss
model_sizes = np.array([125e6, 350e6, 760e6, 1.3e9, 2.7e9])  # 参数数量
val_losses = np.array([3.2, 2.8, 2.5, 2.3, 2.1])

# 拟合
popt, _ = curve_fit(power_law, model_sizes, val_losses)
A_fit, alpha_fit = popt[0], popt[1]

# 预测 7B 模型的 loss
predicted_loss = power_law(7e9, A_fit, alpha_fit)
print(f"Predicted loss for 7B model: {predicted_loss:.3f}")

课程的深层目标:让你亲手验证 scaling law 在你的训练设置下是否成立。绝大多数论文里的 scaling law 是在特定数据分布、特定模型架构下拟合的——换一个数据集,斜率可能完全不同


第五部分:A4 深度实战——数据工程,LLM 的"隐形一半"

"Data is the new oil" 这句话在 LLM 时代格外正确。GPT-4、Claude、Gemini 的性能差异,有很大一部分来自数据质量和数据处理 pipeline,而不是架构创新。

5.1 Common Crawl:原始网页的泥潭

Common Crawl 是公开的网络爬虫快照,约 30-50TB 压缩数据,包含数百亿个网页。但它的质量……惨不忍睹

  • 大量重复内容(同一篇文章被几十个网站转载)
  • 垃圾内容(SEO 农场、机器生成文本、色情、多语言混杂)
  • 格式混乱(HTML 标签、JavaScript 代码、CSS)
  • 个人信息(需要过滤)

CS336 A4 要求你实现一个完整的数据处理 pipeline:从原始 WARC 文件 → 清洗 → 去重 → 高质量语料。

5.2 数据清洗:Quality Filtering

课程介绍几种主流质量过滤方法:

1. Perplexity Filtering(困惑度过滤)
用一个小模型(比如 125M 参数的 GPT-2)对每段文本计算 perplexity。高质量文本(维基百科、书籍)的 perplexity 低;垃圾文本(机器生成、语法混乱)的 perplexity 高。设定阈值,过滤高 perplexity 的文本。

2. Classifier-based Filtering
训练一个二元分类器("高质量" vs "低质量"),用维基百科、书籍作为正例,垃圾网页作为负例。用这个分类器给 Common Crawl 文本打分。

3. 启发式规则

  • 文本长度(太短的通常是菜单、导航)
  • 特殊字符比例(JavaScript 代码片段)
  • 语言识别(只用目标语言,比如英语)
  • "诅咒词"比例(垃圾网站通常更多)
# cs336_basics/data_filter.py(课程风格)
import re
from langdetect import detect

def is_high_quality(text: str) -> bool:
    # 1. 长度过滤
    if len(text) < 100 or len(text) > 100_000:
        return False
    
    # 2. 语言检测(只保留英语)
    try:
        if detect(text) != "en":
            return False
    except:
        return False
    
    # 3. 特殊字符比例
    special_char_ratio = len(re.findall(r'[^\w\s]', text)) / max(len(text), 1)
    if special_char_ratio > 0.15:
        return False
    
    # 4. 重复 n-gram 检测(垃圾内容特征)
    words = text.lower().split()
    if len(words) < 50:
        return False
    bigrams = zip(words, words[1:])
    bigram_counts = Counter(bigrams)
    if bigram_counts.most_common(1)[0][1] > len(words) * 0.05:
        return False  # 太多重复 bigram,可能是垃圾
    
    return True

5.3 去重:MinHash + LSH

即使经过质量过滤,Common Crawl 中仍有大量近似重复内容(同一篇文章略有改写的多个版本)。精确去重(比较 hash)不够,需要近似去重

MinHash + LSH(Locality-Sensitive Hashing) 是工业标准方案:

  1. 对每段文本,提取 k-shingles(k 个连续词/字符的滑动窗口)
  2. 计算 MinHash 签名(用多个哈希函数,取每个函数的最小值)
  3. 用 LSH 将相似签名分到同一个 bucket
  4. 在同一个 bucket 内做精确去重
# cs336_basics/dedup.py(简化版 MinHash)
import hashlib
import numpy as np

def minhash_signature(text: str, num_hashes: int = 128) -> np.ndarray:
    """计算文本的 MinHash 签名"""
    shingles = set(text.lower().split())  # 简化:用 word shingle
    signature = np.inf * np.ones(num_hashes)
    
    for i in range(num_hashes):
        hash_func = lambda s: int(hashlib.md5(f"{i}_{s}".encode()).hexdigest(), 16)
        for shingle in shingles:
            h = hash_func(shingle)
            signature[i] = min(signature[i], h)
    
    return signature

def jaccard_estimate(sig1: np.ndarray, sig2: np.ndarray) -> float:
    """用 MinHash 签名估计 Jaccard 相似度"""
    return np.mean(sig1 == sig2)

# 去重:如果两篇文章的 MinHash 相似度 > 0.5,认为是重复

实际工业应用:GPT-4 的训练数据去重使用了更精细的 scheme(包括 exact substring matchingMinHash LSH),并且对每个 domain 单独去重(避免跨 domain 误删)。


第六部分:A5 深度实战——对齐与推理 RL,让模型"会思考"

预训练后的模型只是一个"互联网文本压缩器"——它能续写文本,但不会"回答用户问题",也不会"逐步推理"。对齐(Alignment) 就是解决这个问题。

6.1 Supervised Fine-Tuning(SFT):从续写到对话

SFT 的核心:用"指令-回答"格式的数据集(比如 ShareGPT、OpenAssistant)继续训练预训练模型,让它学会对话格式

# SFT 的训练格式(简化)
sft_format = """<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant
The capital of France is Paris.<|im_end|>"""

# 在预训练 checkpoint 上继续训练,loss 只计算 assistant 部分的 token
# (system/user 部分的 token 不参与 loss 计算)

关键细节:SFT 阶段通常只训练 1-3 个 epoch,学习率比预训练小 10-100 倍。过长的 SFT 会导致灾难性遗忘(模型忘记预训练时学到的世界知识)。

6.2 RL for Reasoning:GRPO 与 PPO

让模型做数学推理(比如 GSM8K 数据集)需要 Reinforcement Learning。核心思路:

  1. 模型生成多个候选答案(采样,temperature > 0)
  2. 用 verifier(或者人工标注)给每个答案打分
  3. 用 RL 更新模型参数,让高分答案的概率增大

PPO(Proximal Policy Optimization) 是 RLHF 的标准算法,但计算开销大(需要单独的 value network)。

GRPO(Group Relative Policy Optimization):DeepSeek 提出的改进方案,不需要 value network。核心思想:对同一 prompt 生成 G 个回答,用这 G 个回答的相对排名来计算 advantage。

# cs336_basics/grpo.py(极度简化版)
def grpo_loss(model, ref_model, prompts, G=8, eps=0.2):
    """
    GRPO loss 计算
    prompts: 一批数学问题
    G: 每个 prompt 生成 G 个回答
    """
    loss = 0.0
    for prompt in prompts:
        # 1. 生成 G 个回答
        generations = []
        for g in range(G):
            gen = model.generate(prompt, temperature=0.8, max_new_tokens=512)
            generations.append(gen)
        
        # 2. 计算每个回答的 reward(用 verifier)
        rewards = [math_verifier(prompt, gen) for gen in generations]
        # rewards: 比如 [0, 1, 0, 1, 1, 0, 0, 1]
        
        # 3. 计算相对 advantage(关键!)
        mean_reward = np.mean(rewards)
        std_reward = np.std(rewards) + 1e-8
        advantages = [(r - mean_reward) / std_reward for r in rewards]
        
        # 4. 计算 policy gradient(带 clipping,类似 PPO)
        for gen, adv in zip(generations, advantages):
            log_prob_new = model.log_prob(prompt + gen)
            log_prob_old = ref_model.log_prob(prompt + gen)  # 停止梯度
            ratio = torch.exp(log_prob_new - log_prob_old)
            
            clipped_ratio = torch.clamp(ratio, 1 - eps, 1 + eps)
            loss += -torch.min(ratio * adv, clipped_ratio * adv)
    
    return loss / (len(prompts) * G)

GRPO 的核心优势:不需要训练 value network,节省了约 50% 的 GPU 显存和计算。这也是 DeepSeek-Math 和 DeepSeek-R1 能够高效训练的重要原因。

6.3 推理时计算:Chain-of-Thought 与 Verifier

对齐后的模型可以"思考"——但这不是自发产生的,而是通过 Chain-of-Thought(CoT) 训练出来的。

CoT 的核心:在 SFT 数据中,让 assistant 的回答包含推理步骤,而不仅仅是最终答案。比如:

User: 小明有 3 个苹果,他又买了 2 倍的小红有的苹果数。小红有 4 个苹果。小明现在有几个苹果?

Assistant(无 CoT): 11个。

Assistant(有 CoT):
我们来逐步分析:
1. 小红有 4 个苹果
2. 小明买了 2 倍的小红有的苹果数 = 2 × 4 = 8 个
3. 小明原有 3 个,加上买的 8 个
4. 3 + 8 = 11
答案:11个。

Verifier(验证器):在推理时,让模型生成多个 CoT 回答,然后用一个训练好的 verifier 模型选出"最可信"的答案。这是 Self-ConsistencyProcess Reward Model(PRM) 的核心思想。


第七部分:课程之外——2026 年的 LLM 工程全貌

完成 CS336 后,你已经具备了从零实现 LLM 的能力。但 2026 年的工业界,还有一些课程未覆盖但极其重要的话题:

7.1 推理优化:KV Cache 与 Speculative Decoding

KV Cache:自回归生成时,每个新 token 的注意力计算只需要 Q 是新的,K 和 V 可以复用之前所有 token 的——这就是 KV Cache。没有它,生成速度会随着序列长度线性下降。

Speculative Decoding:用一个小的 draft 模型快速生成 K 个候选 token,然后用大模型并行验证这 K 个 token,接受其中前缀正确的部分。可以将推理速度提升 2-3 倍。

7.2 量化:INT4/INT8 推理

2026 年,几乎所有生产级 LLM 推理都使用 INT4 或 INT8 量化(GPTQ、AWQ、GPT-OSS 等方案)。原理:把 FP16 的权重压缩到 4-bit 整数,推理时反量化为 FP16 进行计算。

量化带来的精度损失通常很小(尤其是 INT8),但显存占用可以减少 50-75%。

7.3 多模态:LLM 的下一步

2026 年的前沿模型(GPT-4o、Claude Opus 4、Gemini 2.5)都是多模态模型——能同时理解图像、音频和文本。核心架构:用视觉 encoder(比如 ViT)把图像转为 embedding,然后把这些 embedding 当成"特殊 token"拼接到文本 token 序列里,一起送进 Transformer。


总结:从零实现 LLM,你学到了什么?

CS336 不是一门轻松的课程。你需要:

  • Tokenization(BPE 从零实现)
  • Transformer(包括因果注意力、位置编码、Pre-LN)
  • AdamW(理解 bias correction 和 weight decay)
  • Triton 实现 FlashAttention-2
  • 实现 Gradient CheckpointingFSDP
  • 处理 Common Crawl(清洗、去重、质量过滤)
  • 实现 SFT 和 GRPO(对齐与推理 RL)

但完成它之后,你对 LLM 的理解将从"会调 API"升级到"知道每个字节在干什么"。这种深度理解,才是你在 AI 时代的核心竞争力。


参考资源


本文基于 Stanford CS336 Spring 2026 课程资料撰写,所有代码示例为教学目的的简化实现,生产环境请参考课程官方代码和主流开源实现(litgpt、nanochat、llama2.c)。

推荐文章

Manticore Search:高性能的搜索引擎
2024-11-19 03:43:32 +0800 CST
前端项目中图片的使用规范
2024-11-19 09:30:04 +0800 CST
PHP 允许跨域的终极解决办法
2024-11-19 08:12:52 +0800 CST
gin整合go-assets进行打包模版文件
2024-11-18 09:48:51 +0800 CST
windows下mysql使用source导入数据
2024-11-17 05:03:50 +0800 CST
PHP设计模式:单例模式
2024-11-18 18:31:43 +0800 CST
程序员茄子在线接单