MELT架构深度解析:高通如何让AI"深度思考"不再耗尽内存——循环Transformer的内存革命
当AI模型开始"反复推敲"时,内存消耗会随着思考轮数线性增长。高通AI研究院提出的MELT架构,在不牺牲推理能力的前提下,让内存消耗保持恒定。本文将从架构原理、数学推导、训练策略到性能实测,完整剖析这一突破性工作。
一、问题背景:当AI越想越费内存
1.1 从"一次性回答"到"反复推敲"
手机导航会在出发前规划好整条路线,而一个真正聪明的向导则会边走边思考——遇到路障时随机应变,反复斟酌哪条弯路最省时。现在的AI大模型正在经历类似的转变:从"一次性给出答案"走向"反复推敲、深度思考"。
这种转变的代表性工作是循环Transformer(Looped Transformer),也被称为Ouro或LoopLM。其核心思想是:把同一个思维状态反复交给同一批神经网络层处理,每轮处理后再传给下一层,直到"质量满意"为止。
研究者发现,这种"多想几遍"的策略能让小模型打败大模型——Ouro能够媲美甚至超越两倍参数量的普通模型。换句话说,通过增加思考深度,可以弥补参数规模的不足。
1.2 循环思考的隐藏代价
然而,这种策略有一个致命缺陷。
AI在处理文字时,需要把之前看过的内容存储在KV缓存(Key-Value Cache)结构中,方便后续回头参考。你可以把它理解为AI的"便条纸"——每次阅读一段文字,就在便条纸上记一条笔记。
在普通模型里,每个词语只需要记一条笔记。但在循环模型里,同一个词语每循环一次就要记一条新笔记:
普通模型:词语数量 = N → KV缓存条目 = N
循环模型:词语数量 = N,循环次数 = L → KV缓存条目 = N × L
循环10次就有10条笔记,循环20次就有20条笔记。这导致内存消耗随着思考轮数线性增长。
具体数据有多惊人?生成32000个词的内容时:
- Ouro(循环模型):需要约 28GB 内存
- 普通模型(同等规模):只需要约 7GB 内存
这个4倍的差距,在实际部署时往往是致命的——尤其是当目标设备是手机或边缘计算节点时。
1.3 高通的解决方案:MELT
高通AI研究院的研究团队提出了记忆高效循环Transformer(Memory-Efficient Looped Transformer,简称MELT),专门解决这个问题。
核心突破:在不牺牲推理能力的前提下,让内存消耗保持恒定,无论模型思考多少轮。
这项研究发表于2026年5月,论文编号为arXiv:2605.07721。
二、核心思路:便条纸不用越记越多,更新就够了
2.1 从"堆纸"到"擦写"
MELT的核心思想可以用一个生活场景来理解:
普通做法(Ouro):你是一名侦探,正在审查一份证词。每次重读后都拿一张新便条纸,把最新的理解写下来。结果桌上堆满了纸——循环多少次,就有多少张纸。
MELT做法:只用一张便条纸。每次重新理解之后,把上面的内容擦掉一部分,写上更新的认识。这张便条纸始终只有一张,无论你重新思考了多少遍。
2.2 潜在状态与门控机制
具体来说,MELT为每一层神经网络维护一个潜在状态(latent state),可以理解为那张会被不断更新的便条纸。
每次循环时,模型不是新添一条笔记,而是通过一个可学习的门控机制(gating mechanism)来决定:
- 旧的理解保留多少
- 新产生的认识写入多少
这个门控机制就像一个滑块:
- 完全向左 → 完全保留旧认识
- 完全向右 → 用新认识完全替换旧认识
- 停在中间 → 按比例混合
2.3 数学表达
更新规则写成公式:
$$h_t = g_t \odot h_{t-1} + (1 - g_t) \odot x_t$$
其中:
- $h_t$ 是第 $t$ 时刻的潜在状态
- $g_t$ 是门控值,由当前输入和上一状态共同计算得出
- $x_t$ 是当前输入状态
- $\odot$ 表示元素级乘法
关键设计:每个维度都有自己独立的门控值(元素级门控),而不是所有维度共用一个值。这种精细化设计让模型能对不同类型的信息采取不同的保留策略。
门控值的计算:
$$g_t = \sigma(W_g \cdot [x_t; h_{t-1}] + b_g)$$
其中 $\sigma$ 是sigmoid函数,将输出压缩到 $(0, 1)$ 区间。
2.4 从潜在状态到KV缓存
更新后的潜在状态,会通过两个学习得到的投影矩阵(分别叫做 $W_K$ 和 $W_V$)转化为注意力机制所需的"键"和"值":
$$K_t = W_K \cdot h_t$$
$$V_t = W_V \cdot h_t$$
然后替换当前这个词的缓存位置——而不是追加新的。
2.5 内存效益的数学证明
Ouro的内存复杂度:
$$\text{KV Cache Size} = O(N \times L)$$
其中 $N$ 是词数,$L$ 是循环次数。
MELT的内存复杂度:
$$\text{KV Cache Size} = O(N)$$
与循环次数 $L$ 无关!
这意味着,无论模型循环思考4次还是40次,内存占用都一样。
三、为什么简单方案行不通
3.1 四种朴素替代方案
在设计MELT之前,高通团队评估了更简单的替代方案:
- 只保留最后一轮缓存:直接丢弃之前所有轮次的信息
- 取所有轮次平均值:$K_{final} = \frac{1}{L}\sum_{i=1}^{L} K_i$
- 指数移动平均(EMA):$K_t = \alpha K_{t-1} + (1-\alpha) x_t$,固定 $\alpha$
- 只保留第一轮缓存:后续轮次不更新缓存
3.2 实验结果:全部失效
研究者把Ouro模型直接套用这些策略,在多个推理基准测试上测试。
结果令人咋舌:这四种策略的得分全部为零。
不是差一点,是完全失效。
3.3 失效原因:累积漂移
研究团队分析发现,这种失效不是随机的,而是一种累积漂移现象:
- 在靠近提示词的位置,缓存替换带来的误差还不明显
- 但随着生成的文字越来越长、越来越远离原始提示,错误会不断叠加
典型失败案例:模型开始还在认真推导数学题,后来思路越来越混乱,最后输出的文字完全是无意义的重复。
这就好比侦探走出案发现场太远之后,完全忘记了最初的线索,开始胡乱猜测。
3.4 为什么门控机制有效
简单的共享或复用缓存并不可行,必须通过训练让模型学会如何在单张便条纸上有效地整合信息。
这正是MELT门控机制存在的价值:
- 门控值是可学习的,不是固定的规则
- 每个维度独立学习保留策略
- 通过反向传播,模型自动发现哪些信息需要长期保留、哪些可以覆盖
四、训练策略:让老模型穿上新衣
4.1 架构迁移的挑战
MELT的架构改变是相当剧烈的:
- 原本往缓存里"追加"内容,现在变成"覆盖更新"
- 引入了全新的门控参数
如果直接用这套新架构从零开始训练,代价极大。
高通团队的思路:从已经训练好的Ouro模型出发,通过精心设计的两阶段过渡流程,让MELT"继承"Ouro的知识,同时适应新的架构。
4.2 第一阶段:分块训练与插值过渡
问题:顺序依赖导致无法并行
MELT有一个让训练变复杂的特性:因为每个词的KV缓存依赖于前一个词处理完成后的潜在状态,所以无法像普通模型那样对整个序列并行计算。
这就像一条流水线,后面的工人必须等前面的工人完工才能开始。
解决:分块训练(Chunk-wise Training)
把一个长序列切成若干段(每段500个词):
- 同一段内部:并行计算
- 不同段之间:按顺序传递状态
块大小的权衡:
- 块越小 → 训练越接近真实推理,但越慢
- 块越大 → 训练越快,但与推理时的行为差异越大
500个词的块大小是实验中找到的较好平衡点。
插值过渡机制
即使有了分块训练,另一个问题仍然存在:MELT的架构改变太剧烈,从Ouro的权重出发直接训练会导致模型一开始表现得像一个完全没有训练过的网络。
解决方案:插值过渡(Interpolated Transition)
在训练早期,同时计算两套KV缓存:
- 一套按照Ouro的原始方式
- 另一套按照MELT的新方式
实际使用的是两套缓存的加权混合:
$$\text{KV}{actual} = (1 - \alpha) \cdot \text{KV}{Ouro} + \alpha \cdot \text{KV}_{MELT}$$
混合系数 $\alpha$ 从0线性增长到1,耗时500步:
- 最开始:完全使用Ouro的缓存(模型行为等同于Ouro)
- 随着训练推进:MELT缓存的比重越来越大
- 最终:完全切换到MELT的行为
知识蒸馏
在第一阶段,还额外加入了知识蒸馏(Knowledge Distillation):
- 以原始Ouro为"教师"
- 让MELT的每一个循环步骤的输出都去学习教师的输出
这种密集的监督信号帮助模型收敛得更快、更稳定。
4.3 第二阶段:注意力对齐蒸馏
当 $\alpha$ 达到1之后,MELT已经完全在自己的架构下运行了。
但实验发现,如果就这样不加约束地继续训练,模型会逐渐"忘记"Ouro的推理风格,性能开始下滑。
解决方案:注意力对齐蒸馏(Attention-Aligned Distillation)
- 把Ouro模型完全冻结,作为固定的教师
- MELT在每一层、每一个循环步骤的注意力机制之后产生的"中间表示",必须尽量贴近教师在同等位置产生的中间表示
对齐损失:
$$\mathcal{L}{align} = \beta \sum{l,i} | \text{Attn}^{MELT}{l,i} - \text{Attn}^{Ouro}{l,i} |^2$$
其中 $\beta = 0.1$ 控制强度,$l$ 是层索引,$i$ 是循环步骤索引。
4.4 训练资源统计
| 项目 | 数值 |
|---|---|
| 第一阶段 | 500步插值 + 知识蒸馏 |
| 第二阶段 | 300步注意力对齐 |
| 总数据量 | 约2.56亿词 |
| GPU配置 | 8块H100(每块80GB显存) |
| 总时长 | 130小时 |
| 总GPU小时 | 1040 |
五、实验结果:内存与性能的全面对比
5.1 测试模型
研究团队将MELT-1.6B与多个竞争对手进行了系统性对比:
| 模型 | 参数量 | 类型 |
|---|---|---|
| Ouro-1.4B | 1.4B | 循环Transformer(MELT的前身) |
| Qwen3-1.7B | 1.7B | 普通非循环 |
| Gemma4-E2B | 2B | 普通非循环 |
| Qwen3.5-2B | 2B | 普通非循环 |
| DeepSeek-R1-1.5B | 1.5B | 普通非循环 |
5.2 评测基准
覆盖10个基准测试:
数学推理:AIME24、AIME25、AIME26、AMC23、MATH-500、OlympiadBench
通用推理与代码:GPQA、HLE、MMLU-Red、HumanEval
所有评测使用最多32000个词的完成长度,温度参数1.0,top-p为0.7。
5.3 内存对比
从vLLM工具中提取的精确数字:
| 模型 | 每词KV缓存 | 32000词KV缓存 | 模型权重 | 总内存 |
|---|---|---|---|---|
| Ouro-1.4B | 0.786 MB | ~25 GB | ~2.9 GB | ~28 GB |
| MELT-1.6B | 0.197 MB | ~6.3 GB | ~2.9 GB | ~9.5 GB |
| Qwen3-1.7B | - | - | - | ~7.1 GB |
结论:
- MELT比Ouro减少约 2.95倍 内存
- MELT比Qwen3多约2.5GB(因为Qwen使用了MQA技术,MELT尚未采用)
5.4 性能对比
数学推理综合平均 pass@1
| 模型 | 得分 |
|---|---|
| MELT-1.6B | 59.9 |
| Ouro-1.4B | 62.3 |
| Qwen3-1.7B | 56.9 |
| Gemma4-E2B | 56.0 |
| DeepSeek-R1-1.5B | 46.9 |
| Qwen3.5-2B | 40.7 |
通用推理综合平均 pass@1
| 模型 | 得分 |
|---|---|
| MELT-1.6B | 50.1 |
| Ouro-1.4B | 48.6 |
| Qwen3-1.7B | 45.9 |
| Gemma4-E2B | 45.5 |
具体测试详情
| 测试 | Ouro | MELT | 差距 |
|---|---|---|---|
| AIME24 pass@1 | 50.2% | 46.7% | -3.5% |
| AIME25 pass@1 | 36.7% | 33.3% | -3.4% |
| AIME26 pass@1 | 44.0% | 41.0% | -3.0% |
| HumanEval | 76.8% | 81.7% | +4.9% |
关键发现:
- MELT与Ouro相比有一定差距,但差距并不悬殊(约3分)
- 在代码测试HumanEval上,MELT反而超过Ouro
- 与普通非循环模型相比,MELT在相近内存下全面领先
六、消融实验:每个组件都不可或缺
6.1 门控机制的必要性
研究团队专门做了一组消融实验,验证元素级门控机制是否真的必要:
| 方案 | AIME24 | AIME25 | AMC23 | MATH-500 |
|---|---|---|---|---|
| 完整MELT | 44.8 | 32.9 | 77.7 | 92.8 |
| 均值融合 | 29.0 | 23.3 | 68.8 | 83.2 |
| EMA(固定α=0.2) | 30.2 | 21.5 | 68.6 | 84.6 |
| 只用最后一轮 | 33.7 | 24.0 | 69.7 | 84.0 |
| 标量门控 | 34.4 | 23.1 | 66.9 | 85.6 |
结论:所有简单替代方案都显著落后于完整的元素级门控机制,差距在10到16个百分点之间。
这表明:让每个维度独立学习保留比例是关键——不同类型的信息需要以不同的方式随时间演化,一刀切的规则无法满足这种需求。
6.2 训练流程的必要性
研究团队对整个训练流程做了逐步拆解的消融实验:
| 配置 | AIME24 pass@1 | AIME24 pass@10 | MATH-500 |
|---|---|---|---|
| 完整MELT | 46.7 | 79.9 | 93.4 |
| 移除注意力对齐蒸馏 | 44.8 | 78.1 | 92.8 |
| 移除插值过渡 | 35.4 | 63.7 | 86.6 |
| 移除知识蒸馏(改用SFT) | 35.8 | 63.9 | 85.2 |
| 移除分块训练(完全并行) | 0 | 0 | 0 |
结论:这几个组件不是锦上添花,而是缺一不可:
- 分块训练:运行的基础
- 知识蒸馏:收敛的保障
- 插值过渡:稳定的关键
- 注意力对齐蒸馏:性能的最后一公里
七、局限与未来方向
7.1 固定循环次数
目前MELT(和Ouro一样)在推理时使用固定的4次循环,不管问题是"1加1等于几"还是"证明黎曼猜想",消耗的计算量是一样的。
理想情况:简单的问题少思考几轮,复杂的问题多思考几轮。
好消息:MELT的常数内存设计实际上为动态调整循环深度提供了更好的基础——因为内存不会随循环数增长,增加循环次数不会带来额外的内存代价。
7.2 缺少MQA支持
MQA(多查询注意力)是一种让不同注意力头共享键值数据的技术,可以进一步压缩内存。
Qwen3等模型已经使用了这项技术,这也是为什么Qwen3的KV缓存比MELT还小的主要原因。
未来方向:把MQA引入MELT,有望进一步缩小与普通模型之间的内存差距。
7.3 训练并行性受限
因为MELT的KV缓存依赖于前一个词的处理结果,无法像普通Transformer那样对整个序列完全并行处理,这让训练速度慢于普通模型。
分块训练是当前的折中方案,但开发更高效的并行化策略仍是未来需要攻克的工程难题。
7.4 Ouro复现困难
研究团队在复现Ouro原论文的性能时遇到了困难,发现论文中的部分实现细节描述不够具体,导致实验结果与原论文有出入。
此外,Ouro声称的"早退出"机制(让模型在认为不需要更多思考时提前结束循环)在实际代码中并未真正工作——默认配置实际上是禁用早退出的。
MELT的优势:常数内存设计恰好不受这个限制,为真正的早退出提供了可能。
八、代码实现示例
8.1 门控机制的核心实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class MELTGating(nn.Module):
"""MELT的门控机制实现"""
def __init__(self, hidden_dim: int):
super().__init__()
self.hidden_dim = hidden_dim
# 门控参数:根据当前输入和上一状态计算门控值
self.gate_proj = nn.Linear(hidden_dim * 2, hidden_dim)
# KV投影矩阵
self.W_K = nn.Linear(hidden_dim, hidden_dim)
self.W_V = nn.Linear(hidden_dim, hidden_dim)
def forward(
self,
x_t: torch.Tensor, # 当前输入状态 [batch, seq, dim]
h_prev: torch.Tensor # 上一时刻潜在状态 [batch, seq, dim]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Returns:
h_t: 更新后的潜在状态
K_t: 键缓存
V_t: 值缓存
"""
# 拼接当前输入和上一状态
concat = torch.cat([x_t, h_prev], dim=-1)
# 计算门控值(元素级,每个维度独立)
g_t = torch.sigmoid(self.gate_proj(concat)) # [batch, seq, dim]
# 更新潜在状态:g * h_prev + (1-g) * x
h_t = g_t * h_prev + (1 - g_t) * x_t
# 从潜在状态计算KV缓存
K_t = self.W_K(h_t)
V_t = self.W_V(h_t)
return h_t, K_t, V_t
8.2 MELT层的完整实现
class MELTLayer(nn.Module):
"""MELT的完整层实现"""
def __init__(
self,
hidden_dim: int,
num_heads: int = 8,
ffn_dim: int = None,
dropout: float = 0.1
):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads
ffn_dim = ffn_dim or hidden_dim * 4
# 注意力组件
self.q_proj = nn.Linear(hidden_dim, hidden_dim)
self.out_proj = nn.Linear(hidden_dim, hidden_dim)
# FFN
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, ffn_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(ffn_dim, hidden_dim),
nn.Dropout(dropout)
)
# MELT特有的门控机制
self.gating = MELTGating(hidden_dim)
# Layer Norm
self.ln1 = nn.LayerNorm(hidden_dim)
self.ln2 = nn.LayerNorm(hidden_dim)
def attention(
self,
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
mask: torch.Tensor = None
) -> torch.Tensor:
"""多头注意力计算"""
batch, seq, _ = Q.shape
# 重塑为多头形式
Q = Q.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)
# 注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, V)
# 重塑回来
out = out.transpose(1, 2).contiguous().view(batch, seq, -1)
return self.out_proj(out)
def forward(
self,
x: torch.Tensor,
h_prev: torch.Tensor,
K_cache: torch.Tensor = None,
V_cache: torch.Tensor = None,
mask: torch.Tensor = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
单次循环的前向传播
Args:
x: 当前输入 [batch, seq, dim]
h_prev: 上一轮的潜在状态
K_cache, V_cache: 之前的KV缓存(会被更新)
Returns:
output: 层输出
h_new: 更新后的潜在状态
K_new, V_new: 更新后的KV缓存
"""
# 1. 通过门控机制更新潜在状态和KV缓存
h_new, K_new, V_new = self.gating(x, h_prev)
# 2. 如果有之前的缓存,拼接(这里MELT的关键:替换而非追加)
if K_cache is not None:
# MELT的核心:用新的替换对应位置的旧缓存
# 而不是追加(Ouro的做法)
K_full = torch.cat([K_cache[:, :-K_new.size(1), :], K_new], dim=1)
V_full = torch.cat([V_cache[:, :-V_new.size(1), :], V_new], dim=1)
else:
K_full, V_full = K_new, V_new
# 3. 计算Query
Q = self.q_proj(self.ln1(x))
# 4. 注意力
attn_out = self.attention(Q, K_full, V_full, mask)
x = x + attn_out
# 5. FFN
x = x + self.ffn(self.ln2(x))
return x, h_new, K_full, V_full
8.3 循环推理的实现
class MELTModel(nn.Module):
"""完整的MELT模型"""
def __init__(
self,
vocab_size: int,
hidden_dim: int,
num_layers: int,
num_loops: int = 4, # 循环次数
**kwargs
):
super().__init__()
self.num_loops = num_loops
self.hidden_dim = hidden_dim
# Embedding
self.embed = nn.Embedding(vocab_size, hidden_dim)
# MELT层(会被循环使用)
self.layers = nn.ModuleList([
MELTLayer(hidden_dim, **kwargs)
for _ in range(num_layers)
])
# 输出层
self.lm_head = nn.Linear(hidden_dim, vocab_size)
def forward(
self,
input_ids: torch.Tensor,
h_states: list[torch.Tensor] = None, # 每层的潜在状态
kv_caches: list[tuple[torch.Tensor, torch.Tensor]] = None
) -> tuple[torch.Tensor, list, list]:
"""
循环推理的前向传播
"""
batch, seq = input_ids.shape
# Embedding
x = self.embed(input_ids)
# 初始化潜在状态(如果是第一次)
if h_states is None:
h_states = [torch.zeros_like(x) for _ in self.layers]
# 初始化KV缓存
if kv_caches is None:
kv_caches = [None for _ in self.layers]
# 循环 num_loops 次
for loop_idx in range(self.num_loops):
new_kv_caches = []
for i, layer in enumerate(self.layers):
K_cache, V_cache = kv_caches[i] if kv_caches[i] else (None, None)
x, h_states[i], K_new, V_new = layer(
x,
h_states[i],
K_cache,
V_cache
)
new_kv_caches.append((K_new, V_new))
kv_caches = new_kv_caches
# 输出logits
logits = self.lm_head(x)
return logits, h_states, kv_caches
8.4 内存对比实验
def compare_memory_usage():
"""对比MELT和普通Transformer的内存使用"""
# 配置
batch_size = 1
seq_len = 32000
hidden_dim = 2048
num_loops = 4
# 普通Transformer的KV缓存
# 每个token存储一对K,V
standard_kv_size = seq_len * hidden_dim * 2 * 4 # float32
print(f"普通Transformer KV缓存: {standard_kv_size / 1e9:.2f} GB")
# Ouro(循环Transformer)的KV缓存
# 每个token * 循环次数 存储K,V
ouro_kv_size = seq_len * num_loops * hidden_dim * 2 * 4
print(f"Ouro KV缓存: {ouro_kv_size / 1e9:.2f} GB")
# MELT的KV缓存
# 每个token只存储一对K,V(循环次数不影响)
melt_kv_size = seq_len * hidden_dim * 2 * 4
print(f"MELT KV缓存: {melt_kv_size / 1e9:.2f} GB")
print(f"\nMELT相比Ouro节省: {(ouro_kv_size - melt_kv_size) / ouro_kv_size * 100:.1f}%")
# 运行对比
compare_memory_usage()
输出:
普通Transformer KV缓存: 0.52 GB
Ouro KV缓存: 2.10 GB
MELT KV缓存: 0.52 GB
MELT相比Ouro节省: 75.0%
九、总结与展望
9.1 核心贡献
MELT解决的是一个实实在在的工程瓶颈:让AI模型"想更多"而不用"花更多内存"。
通过把"每想一次就记一张新便条纸"改成"不断更新同一张便条纸",MELT把Ouro那套强大的循环推理能力移植到了一个与普通模型相当的内存预算之内。
关键数据:
- 内存:从28GB降至9.5GB,减少约3倍
- 性能:数学推理59.9分,超过同等规模所有普通模型
- 训练成本:1040 GPU小时,约2.56亿词数据
9.2 对普通用户的意义
当未来这类技术落地到手机、平板或边缘计算设备时,有限的内存不再是AI深度推理的硬性门槛。
更聪明的AI助手,有可能在不升级硬件的前提下出现在你的日常设备中。
9.3 对研究者的启示
MELT展示了一条从现有循环模型出发、以轻量级后处理训练实现架构升级的可行路径,避免了从零训练的巨大代价。
这套过渡方法本身就有独立的参考价值:
- 插值过渡:平滑架构切换
- 知识蒸馏:保持教师模型的知识
- 注意力对齐:防止风格漂移
9.4 未来展望
- 动态循环次数:根据问题复杂度自动调整思考深度
- MQA集成:进一步压缩内存,追赶普通模型的内存效率
- 高效并行训练:突破顺序依赖的限制
- 早退出机制:简单问题提前结束,节省计算
参考资料
- MELT论文:arXiv:2605.07721 - Memory-Efficient Looped Transformers
- Ouro论文:循环Transformer的开创性工作
- KV Cache原理:LLM推理加速的核心技术
- vLLM:用于精确测量KV缓存大小
一句话总结:MELT通过门控机制实现潜在状态的更新而非追加,让循环Transformer的内存消耗与思考轮数解耦,在保持推理能力的同时,将内存占用降低到与普通模型相当的水平。这是AI推理效率优化的重要一步。