IndexCache-Accelerating Sparse Attention via Cross-Layer Index Reuse

 
Category: Paper

摘要

随着长上下文智能体工作流成为大语言模型的关键应用场景,注意力效率对推理速度和部署成本变得至关重要。稀疏注意力(Sparse Attention)能够有效解决这一问题,其中 DeepSeek Sparse Attention(DSA)是一个具有代表性的生产级解决方案:其轻量级索引器(lightning indexer)为每个查询选择 top-k 个最相关的 token,将核心注意力复杂度从 O(L²) 降低到 O(Lk)。然而,索引器本身仍保持 O(L²) 复杂度,且必须在每一层独立运行,尽管相邻层产生的 top-k 选择高度相似。

本文提出 IndexCache,通过跨层索引复用来解决这一效率瓶颈。该方法将层划分为少量保留索引器的 Full 层(F)和大多数复用邻近 Full 层 top-k 索引的 Shared 层(S)。作者提出了两种互补的方法来确定和优化这种配置:

  1. 无训练 IndexCache(Training-free IndexCache):使用贪心搜索算法,通过在校准集上直接最小化语言建模损失来选择保留索引器的层,无需权重更新。
  2. 有训练 IndexCache(Training-aware IndexCache):引入多层蒸馏损失,训练每个保留的索引器针对其服务的所有层的平均注意力分布,使即使是简单的均匀交错模式也能达到与完整索引器相当的精度。

在 30B DSA 模型上的实验表明,IndexCache 可以移除 75% 的索引器计算,同时几乎不损失模型质量,在 200K 上下文长度下实现高达 1.82× 的预填充加速和 1.48× 的解码加速。这些积极结果在生产级 GLM-5 模型(744B 参数)上也得到了初步验证。

Indexer 时间占比随上下文长度变化

1. 引言

自注意力机制是现代大语言模型的核心,但其序列长度的二次复杂度已成为长上下文推理的基本瓶颈。随着 LLMs 越来越多地部署在需要扩展上下文的场景中(如长链式思考推理、多步智能体工作流和基于检索增强生成的网络规模数据处理),在不牺牲模型质量的前提下降低注意力成本已成为关键研究问题。

稀疏注意力提供了一种原则性的解决方案:每个查询不再关注所有前面的 token,而是只选择最相关的子集。在近期方法中,DeepSeek Sparse Attention(DSA)作为一个生产级可训练的稀疏注意力机制脱颖而出。DSA 引入了轻量级索引器(lightning indexer),对所有前面的 token 进行评分并选择 top-k 个最高分的 token 用于后续的核心注意力。这将每层核心注意力从 O(L²) 减少到 O(Lk),同时通过持续预训练保持模型质量。

然而,索引器本身仍在每一层以 O(L²) 运行:虽然每 FLOP 比主要注意力计算便宜,但跨 N 层的总成本是 O(NL²),随上下文长度二次增长,在总注意力预算中占据不可忽视的比例。如上图所示,对 30B DSA 模型的性能分析表明,索引器在总延迟中的份额随上下文长度急剧上升,尤其是在预填充阶段。

一个关键观察是相邻层的 top-k 选择高度相似——这是全注意力模型中更广泛的跨层 token 选择稳定性的体现。虽然之前的方法利用这种稳定性复用全注意力 anchor 层的索引,但它们不直接适用于稀疏注意力,因为 DSA 中全注意力已被轻量级索引器取代。作者通过计算所有层的成对 top-k 索引重叠来验证这一点:相邻层共享 70-100% 的所选 token,热力图显示具有相互高重叠的明显层簇,表明大多数索引器计算是冗余的。

IndexCache 通过跨层索引复用消除了 DSA 中高达 75% 的索引器计算。IndexCache 将层划分为保留索引器的 F 层和从最近的前一个 F 层继承 top-k 索引的 S 层,在推理中只需添加一个条件分支。

2. 背景知识

2.1 DeepSeek Sparse Attention

DSA 将每个注意力层分解为两个阶段:选择和计算。轻量级索引器首先使用多头 ReLU 门控点积对所有前面的 token 相对于当前查询进行评分,然后选择得分最高的 top-k 位置。主注意力只在这些稀疏子集上计算,将每层核心注意力从 O(L²) 降低到 O(Lk),其中 k=2048 远小于 L。索引器设计注重效率:使用少量头、低秩投影和 FP8 算术,使其每 FLOP 比主多头潜在注意力(MLA)便宜一个数量级。

DSA 在 MLA 下实例化,并通过两阶段持续预训练引入。首先是简短的密集热身阶段,仅通过 KL 散度蒸馏针对每层的聚合全注意力分布训练索引器,同时冻结所有其他参数。然后是更长的稀疏训练阶段,激活 top-k 选择并联合优化整个模型,索引器在分离的计算图上接收蒸馏梯度。

尽管有这些效率提升,索引器本身仍以 O(L²) 运行:在每一层,它必须独立地对所有前面的 token 进行评分以确定其自己的 top-k 集合。跨 N 层模型,总索引器成本是 O(NL²),在长上下文长度下这成为注意力预算的重要部分。

2.2 跨层 Token 选择的稳定性

答案来自更广泛的实证发现:一组重要的 token 在连续的 transformer 层中非常稳定。Kascade 和 HySparse 都观察到相邻层共享其 top-k 注意力质量的绝大部分,并通过指定少数计算全注意力的 anchor 层来利用这一点,让中间层复用 anchor 的 top-k 索引。

关键的是,这两种方法都依赖于全注意力作为识别重要 token 的 oracle。在 DSA 中,全注意力已被完全消除——被轻量级索引器取代。这提出了一个尚未解决的问题:索引器的输出是否也表现出跨层稳定性?如果是这样,我们可以应用相同的共享原则来消除冗余的索引器计算,而不需要任何全注意力 oracle。

3. 方法

3.1 概述

IndexCache 通过将 N 层划分到两种角色来修改 DSA,编码为二进制模式字符串 c = c₁c₂…c_N,其中 c_ℓ ∈ {F, S}:

  • F(Full):该层保留其索引器,在所有前面的 token 上计算新的 top-k 索引,并在选择的子集上执行稀疏核心注意力,与标准 DSA 相同。
  • S(Shared):该层没有索引器。它从最近的 F 层继承索引集合,即 T^(ℓ)_t ← T^(f(ℓ))_t,其中 f(ℓ) = max{j < ℓ : c_j = F},并直接使用这些继承的索引应用稀疏核心注意力。

第一层始终为 F 以种下初始索引。在推理中,S 层只需跳过索引器前向传播,从其 F 前驱复用缓存的索引张量。

关键设计问题是如何选择模式 c。如果大多数层可以安全地共享索引,则可以消除 O(NL²) 总量索引器成本的很大一部分,而 O(NLk) 核心注意力保持不变。作者提出了两种方法:一种是无训练方法,通过贪心搜索在已训练的 DSA 模型上确定 c;另一种是有训练方法,通过多层蒸馏损失联合优化索引器参数以进行跨层共享。

3.2 无训练 IndexCache

给定预训练的 DSA 模型,目标是找到最大化 S 层数量同时最小化对模型质量影响的模式 c。首先讨论为什么最明显的方法会失败,然后提出贪心搜索算法。

3.2.1 为什么均匀交错是次优的

最简单的策略是均匀交错:保留每 r 层的一个索引器并跳过其余的(例如,r=4 时为 FSSSFSSS…)。然而,这忽略了索引器重要性在不同层之间存在显著差异这一事实。作者观察到某些层,特别是网络的早期和过渡区域,对索引器移除更为敏感。均匀交错可能移除一个关键索引器而保留一个冗余的索引器,导致明显的质量下降。

3.2.2 层选择算法

作者提出了一种贪心搜索,逐步将 F 层转换为 S 层,使用小型校准集上的语言建模损失作为下游质量的代理。

校准集:从训练数据中缓存 B 个小批次。所有候选模式都在完全相同的批次上评估,确保损失差异仅反映模式变化的影响。

搜索过程:从全 F 基线开始(c_ℓ = F 对所有 ℓ),算法进行 K 步,其中 K 是 S 层的数量(例如,K = 3N/4 以保留 1/4 的索引器)。在每一步,遍历所有当前为 F 的层(排除第一层),暂时将其翻转为 S,评估 resulting LM 损失,并提交产生最低损失的翻转。

复杂度:从全 F 到全 S 的完整搜索执行 N(N-1)/2 次前向传播。当流水线并行将模型划分为 P 个阶段时,作者通过将层分割成 P 个块来加速搜索。

贪心解的性质:虽然贪心搜索不能保证全局最优性,但作者一致观察到三个令人满意的特性:

  1. 搜索的模式在相同保留比下优于均匀交错
  2. 每步 LM 验证损失曲线显示“容易”层(前 20 步)和“关键”层(35 步之后)之间的明显分离,表明索引器重要性存在自然排序
  3. 结果在不同校准集上稳定,表明这种重要性排序是模型的固有属性而非数据伪影

贪心搜索过程

3.3 有训练 IndexCache 与多层蒸馏

无训练 IndexCache 不需要权重更新,但受到每个索引器最初只为其自己的层服务这一事实的限制。在从头训练 DSA 模型或通过持续预训练时,可以做得更好:明确训练每个保留的索引器同时为多个层服务。

从单层到多层蒸馏:在标准 DSA 训练中,每层 ℓ 的索引器通过 KL 散度针对其自己的层聚合注意力分布 p^(ℓ)_t 进行蒸馏。作者将其推广到多层目标。设层 ℓ 是保留的 F 层,层 ℓ+1, …, ℓ+m 是将复用其索引集合的后续 S 层。多层蒸馏损失为:

\[\mathcal{L}^{\mathrm{I}}_{\mathrm{multi}} = \sum_{j=0}^{m} \frac{1}{m+1}\sum_{t} D_{\mathrm{KL}}\left(\mathbf{p}^{(\ell+j)}_{t} \| \mathbf{q}^{(\ell)}_t\right)\]

直观上,这鼓励索引器预测一个对其服务的所有层共同有用的 top-k 集合,而不是仅适应单层。

梯度等价于针对平均分布的蒸馏:作者证明多层损失产生与针对单个平均目标蒸馏完全相同的梯度。定义平均目标 p̄t = Σ{j=0}^{m} (1/(m+1))p^{(\ell+j)}_t 和对应的单目标损失,命题表明两个梯度完全等价。这表明多层蒸馏不仅仅是启发式正则化——它本质上等同于将索引器蒸馏到目标层注意力分布的质心。

训练:按照标准 DSA 训练程序进行两阶段。在热身阶段,使用 L^I_multi 训练 F 层中的索引器,同时保持所有其他参数固定。在稀疏训练阶段,继续使用 L^I_multi 训练索引器,计算仅在选定的 top-k token 上的 KL 散度,并额外包含 LM 损失来训练其余参数。

4. 实验

4.1 设置

模型:DSA 模型通过两阶段训练过程获得,从 GLM-4.7-Flash(30B-A3B MoE 模型,47 层)的基座模型开始。其评估性能与原始 GLM-4.7-Flash 相当。

无训练 IndexCache:贪心模式搜索由在 SFT 数据上使用批次大小 768 和上下文长度 200K 计算的每 token 验证损失指导。

有训练 IndexCache:从 GLM-4.7-Flash 模型直接初始化,在 SFT 数据上训练为 DSA 模型,上下文长度为 200K。训练由 1,000 步密集热身阶段和 4,000 步稀疏训练阶段组成。

评估:包含五个长上下文基准:MRCR v2、GraphWalks、LongBench v2、RULER 和 AA-LCR;四个通用和推理基准:AIME 2025、GPQA-Diamond、LiveCodeBench v6 和 IFBench。

4.2 端到端推理加速

作者使用 30B DSA 模型在 SGLang 中启用 dp_attention(dp_size=8)在 NVIDIA H100 节点上进行端到端推理性能测量。将原始 DSA 基线与两种保留比的 IndexCache 进行比较:1/2(保留一半索引器层)和 1/4(保留四分之一)。

预填充:IndexCache 提供了随上下文长度增长而增加的实质性预填充加速。在 200K token 时,IndexCache(1/4)将预填充延迟从 19.5 秒减少到 10.7 秒,实现了 1.82× 的加速。即使在 10K,索引器占总计算的比例较小,也观察到 1.27× 的加速。

解码:每请求解码吞吐量改进在长上下文中显著。在 200K 时,DSA 的解码速度为 58 tok/s,而 IndexCache(1/4)达到 86 tok/s,1.48× 的加速。当 KV 缓存完全饱和时,IndexCache(1/4)在不同上下文长度下将总解码吞吐量提高 22-51%,在 200K 时最大增益(197→297 tok/s,1.51× 增加)。

相对加速对比

在更大的 GLM-5 模型(744B 参数)上也观察到类似趋势,其中 IndexCache(1/4)在超过 100K 的上下文长度下产生至少 1.3× 的预填充延迟和解码吞吐量改进。

4.3 无训练 IndexCache 结果

在 30B DSA 模型上测试无训练 IndexCache,比较三种保留比:1/2、1/4 和 1/8,每种在均匀交错基线和贪心搜索模式之间进行比较。

搜索模式弥补了长上下文任务的差距:在激进保留比下,均匀交错导致显著的长上下文退化:1/2 和 1/4 均匀交错使 Long Avg 分别下降 2.8 和 7.2 点。贪心搜索模式在很大程度上消除了这一 deficit,在 1/4 保留时将 Long Avg 恢复到 49.9,在 1/2 保留时恢复到 50.3,两者都与原始 DSA 相当。这证实了哪个索引器层被保留远比有多少被保留更重要。

长链式思考推理能力得到保留:在除 1/8 保留比的均匀交错外的所有配置中,G&R Avg 保持在基线 74.6 的 1 点以内(73.7-74.9 vs. 74.6)。值得注意的是,1/4 搜索模式在 AIME 2025(92.6 vs. 91.0)和 GPQA-Diamond(78.6 vs. 77.6)上较 DSA 有所改进,表明移除冗余索引器计算可能作为推理中的轻度正则化。

4.4 有训练 IndexCache 结果

使用多层蒸馏损失在两种保留比(1/2 和 1/4)上进行有训练 IndexCache,均使用均匀交错。

有训练 IndexCache 匹配 DSA 基线:1/2 均匀 IndexCache 实现了 51.6 的 Long Avg,超过了基线(51.0),而 G&R Avg 保持相当(74.5 vs. 74.2)。在 1/4 保留时,Long Avg 和 G&R Avg 都在基线的 0.4% 以内。这些结果确认 DSA 可以训练以适应共享模式。

训练中观察到的模式敏感性消失:与无训练 IndexCache 形成显著对比:1/2 保留时均匀交错与贪心搜索模式表现相当甚至略高(Long Avg 51.6 vs. 50.6;G&R Avg 74.5 vs. 73.6)。回想在无训练设置中,搜索模式对于在激进保留比下恢复质量至关重要。然而,当使用共享感知目标重新训练模型时,S 层学习适应继承的索引,保留的索引器同时学习产生泛化到其服务层的选择。这种联合适应完全消除了层特异性敏感性,允许简单的均匀模式匹配全索引器基线。

跨层蒸馏提供了有意义的益处:移除跨层损失使 Long Avg 从 51.6 下降到 49.8,AA-LCR 从 49.8 下降到 44.0。这确认了多层蒸馏目标实际上是有益的:通过针对其服务层注意力分布的质心训练每个索引器,它学习跨层泛化的共识 top-k,而不是过度拟合到单个层。

4.5 扩展实验

将无训练 IndexCache 应用于 GLM-5,这是一个默认使用 DSA 的 744B 参数(40B 活跃)模型。总体趋势与 30B 结果一致:均匀交错在激进保留下退化,而搜索模式恢复质量。

有趣的是,1/2 保留的均匀交错恰好保留了 Long Avg(78.1 vs. 78.4),但这可能是偶然的,其中固定交替模式恰好避开了最关键的索引器层。搜索模式提供了一致稳定的结果:在 1/2 保留时略超过基线(78.7 vs. 78.4),在 1/4 保留时保持在 0.4 点以内(78.0 vs. 78.4)。

5. 相关工作

高效注意力

减少自注意力的二次成本是中心研究主题。训练-free 稀疏方法通过固定模式、启发式驱逐策略或轻量级重要性估计在推理中引入稀疏性。然而,训练-推理不匹配可能导致长上下文设置中的错误累积。相比之下,训练稀疏方法将稀疏性直接纳入训练阶段。DSA 是本文工作的基础,它从全注意力蒸馏轻量级索引器以选择每个查询的 top-k token,将核心注意力复杂度降低到 O(Lk)。除了稀疏性,混合架构通过将昂贵的二次层与滑动窗口注意力、线性注意力或状态空间层交错来减少数量。

跨层共享

近期研究表明表示在相邻层之间表现出很强的一致性。这种结构特性常被利用来减少计算冗余和加速推理。TidalDecode、LessIsMore、OmniKV 和 DELTA 复用来自周期性 anchor 层的 top-k 索引用于稀疏解码。Kascade 通过动态规划在跨层相似性矩阵上形式化 anchor 层选择,并识别头感知重映射对保持准确性至关重要。所有这些方法都依赖 anchor 层计算精确 top-k 索引的全注意力。独立地,跨层 KV 缓存共享通过让多个层复用相同的键值张量来减少内存。然而,这些方法都需要全注意力作为 oracle,而 DSA 完全移除了这一点。

IndexCache 在两个方面不同。首先,oracle 要便宜得多,因为作者共享的是 DSA 轻量级索引器的输出,而不是全 O(L²) 注意力分数。其次,作者引入了优化共享配置的系统技术,包括无训练贪心搜索来识别最优结构布局,以及有训练多层蒸馏损失用于参数适应。虽然作者在 DSA 上实例化了 IndexCache,但核心原则延伸到任何不依赖固定稀疏模式而是涉及动态 token 选择步骤的稀疏注意力方法。

6. 结论

本文提出了 IndexCache,一种通过利用负责 token 选择的索引器的跨层冗余来加速稀疏注意力的方法。IndexCache 将层划分为少量保留索引器的 F 层和大多数复用继承的 top-k 索引的 S 层,通过单一条件分支消除了高达 75% 的 O(NL²) 总量索引器成本,而没有任何性能退化。

更广泛地说,本文的工作表明,跨层共享原则以前仅适用于全注意力作为 oracle 的地方,现在自然延伸到稀疏注意力。随着稀疏注意力成为前沿 LLMs 的默认配置(DeepSeek-V3.2、GLM-5),作者预计跨层索引复用将成为高效推理流水线的标准组成部分。