论文信息
- 论文标题: MSA: Memory Sparse Attention for Efficient End-to-End Memory Model Scaling to 100M Tokens
- 作者: Yu Chen, Runkai Chen, Sheng Yi 等 (EverMind, Shanda Group, Peking University)
- GitHub: https://github.com/EverMind-AI/MSA
摘要
长期记忆是人类智能的基石。使AI能够处理终身规模的信息,达到数亿token的处理能力,一直是该领域的长期追求。由于全注意力架构的约束,大语言模型(LLM)的有效上下文长度通常限制在1M token以内。
本文提出Memory Sparse Attention (MSA),一种端到端可训练、高效且可大规模扩展的记忆模型框架。通过可扩展稀疏注意力架构和文档级RoPE等核心创新,MSA在训练和推理中均实现线性复杂度,同时保持卓越的精度稳定性——从16K扩展到100M tokens时,性能衰减小于9%。
此外,结合KV缓存压缩和Memory Parallel技术,MSA仅需2×A800 GPU即可完成100M tokens的推理。通过Memory Interleaving机制,MSA还能有效支持跨分散记忆段的多跳推理。
核心挑战
现有方法的局限性
| 方法类型 | 代表方案 | 优势 | 劣势 |
|---|---|---|---|
| 参数化记忆 | LoRA, Continual Pre-training | 语义整合度高 | 容量受限,遗忘严重 |
| 外部存储 | RAG, MemAgent | 容量可扩展 | 非端到端,精度受限 |
| 潜在状态 | Sparse Attention, Linear Attention | 高精度 | O(L²)复杂度或容量受限 |
现有方法面临两大根本限制:
- 高精度记忆的有限可扩展性:高精度方法受限于固定上下文或状态容量
- 缺乏端到端可训练性:无法同时保持架构兼容性、高精度和对灾难性遗忘的鲁棒性
人类记忆规模估算
认知科学研究估计人类记忆的功能信息容量约为10⁹ bits。按每token 3-5 bits的有效语义密度计算,这相当于约200-300 million tokens的终身容量。要真正弥合与人类规模记忆的差距,模型必须能够处理扩展到数亿token的上下文。
方法:Memory Sparse Attention
稀疏注意力机制
MSA将标准密集自注意力替换为基于文档的检索稀疏注意力机制。核心设计:
- Router K Projector:为每个文档生成专门的路由键矩阵
- 分块压缩:将每个文档分割为多个固定长度块,通过mean pooling压缩状态
- Top-k检索:基于余弦相似度计算查询与记忆的相关性分数,选择Top-k最相关文档
- 选择性应用:仅在后半层应用MSA路由,保留层次表示对齐
文档级RoPE
为确保跨不同记忆规模的稳健泛化,MSA为每个文档独立应用RoPE。这解决了训练-推理上下文不匹配问题:模型通常在有限文档数量上训练(train-on-short),但推理时需操作大规模文档库(infer-on-long)。
通过为每个文档分配独立的起始位置ID(从0开始),MSA将位置语义与文档总数解耦,使模型能够有效外推到训练范围之外。
训练策略
连续预训练阶段(158.95B tokens):
- 标准生成损失 $\mathcal{L}{LLM}$ + 辅助路由损失 $\mathcal{L}{aux}$
- $\mathcal{L}_{aux}$:监督对比损失,分离相关与无关文档块
两阶段后训练:
- 阶段一:8K上下文长度的SFT,建立基础指令遵循能力
- 阶段二:扩展到64K上下文长度,提高长度外推鲁棒性
三阶段推理流程
- 全局记忆编码(离线):预计算所有文档的K、V、K^R,压缩并缓存
- 路由与上下文组装(在线):计算查询路由向量,匹配Top-k文档
- 稀疏生成(在线):在组装好的稀疏上下文上自回归生成
Memory Interleaving
为处理需要多跳推理的复杂查询,MSA引入Memory Interleaving机制。推理过程在”生成检索”和”上下文扩展”之间交替进行:
- 模型基于查询自回归生成文档ID序列
- 系统获取对应原文并追加到原始查询
- 重复直到模型判断累积文档足够
Memory Parallel
为在标准单节点上支持极长上下文推理,MSA设计了专门的推理引擎:
分层内存存储:
- GPU驻留路由键(K̄^R):分布在多卡VRAM中,确保低延迟检索
- CPU卸载内容KV(K̄, V̄):存储在主机DRAM,按需异步获取
分布式评分:查询hidden states广播到所有GPU,每卡独立计算相似度分数。
实验结果
QA任务
MSA在9个问答基准上显著优于基于相同backbone的RAG系统:
| 基准 | MSA vs 最佳RAG基线 | 相对提升 |
|---|---|---|
| 平均分数 | 3.760 | +16.0% |
对比SOTA RAG系统(KaLMv2 + Qwen3-235B),MSA在9个数据集中的4个取得最佳性能,平均提升7.2%。
NIAH任务
在RULER “大海捞针”基准上,MSA展现出从32K到1M tokens的卓越稳定性:
| 模型 | 32K | 256K | 1M |
|---|---|---|---|
| Qwen3-4B-Instruct | 99% | 48% | 25% |
| Qwen3-Next-80B-A3B | 100% | 97% | 81% |
| RL-MemoryAgent-14B | 98% | 95% | 93% |
| MSA (Ours) | 99% | 98% | 95% |
100M Token可扩展性
MSA在MS MARCO数据集上展示前所未有的可扩展性:从16K扩展到100M tokens,性能衰减小于9%。
结论
本文提出Memory Sparse Attention (MSA),一种端到端可训练的稀疏注意力架构,实现了接近人类记忆规模的100M tokens处理能力:
- 核心创新:可扩展稀疏注意力 + 文档级RoPE + KV缓存压缩 + Memory Interleaving
- 关键突破:从16K到100M tokens仅<9%性能衰减
- 实用部署:仅需2×A800 GPU即可完成100M tokens推理
- 性能领先:在长上下文QA和NIAH基准上显著超越前沿LLM和SOTA RAG系统
MSA证明了通过将记忆容量与推理解耦,可以为通用模型提供内在的终身规模记忆,标志着向人类认知尺度记忆迈出的重要一步。