在人工智能领域,处理长上下文序列一直是大型语言模型面临的核心挑战之一。传统的密集注意力机制虽然功能强大,但其计算复杂度随序列长度呈二次方增长,这严重限制了模型处理长文本、代码或多轮对话的能力。今年2月,月之暗面提出的MoBA(Mixture of Block Attention)机制为这一难题提供了创新解决方案。MoBA将混合专家(MoE)原理引入注意力机制,允许查询(Query)仅稀疏关注少量关键-值(Key-Value)块,理论上可大幅降低计算成本。然而,这一创新在实际应用中却遭遇了硬件实现效率低下的瓶颈。

MIT与NVIDIA的研究团队通过深入的理论分析发现,MoBA性能的核心在于路由器能否基于查询-关键相似度准确区分相关块与无关块。他们建立了一个统计模型,推导出信噪比公式,将架构参数与检索准确率形式化关联。分析揭示了两条关键改进路径:一是采用更小的块尺寸,二是在关键上应用短卷积以增强块内语义信号聚集。理论模型明确显示,较小的块尺寸能带来显著的质量提升——当块尺寸减小时,路由器需要处理的块数量增加,这迫使模型进行更精细的语义区分,从而提高注意力分配的准确性。
然而,理论优势在现有GPU实现中却转化为实际障碍。小块尺寸导致严重的内存访问碎片化:当查询需要从不同位置收集稀疏、不连续的键值块时,GPU无法进行高效的合并内存读取,大量时间浪费在从高带宽内存(HBM)中随机获取数据上。同时,块数量增加使路由器评分和Top-k选择的开销急剧膨胀——原始实现需要显式生成巨大的分数矩阵,产生不可承受的内存开销。更严重的是,每个块的工作量减少导致GPU占用率低下,大量独立内核的启动开销进一步恶化了并行度。



面对这一矛盾,研究团队提出了FlashMoBA——一种硬件感知的CUDA内核,专门为小块MoBA场景优化。FlashMoBA的核心创新在于三个深度融合的内核设计,最大限度地减少了HBM往返次数,使计算模式与GPU架构特性对齐。
首先,分块Top-K选择机制彻底重构了路由过程。原始实现中,显式生成完整分数矩阵并串行处理批次序列是主要瓶颈。FlashMoBA将其替换为高度优化的三阶段流水线:第一步,Triton内核计算键块的质心,生成紧凑的矩阵表示;第二步,受FlashAttention-2启发的分块内核直接为每个查询找到Top-k键块,完全避免将完整分数矩阵写入HBM;第三步,高效后处理将查询中心索引重新格式化为键块中心的变长布局。整个流水线在批次和注意力头间完全并行化,消除了原始性能瓶颈。





前向传播采用创新的“收集并致密化”策略处理MoBA的不规则稀疏性。内核设计区分逻辑块与物理块:逻辑块是外层循环迭代的大型连续查询块和键块,而物理块是加载到SRAM中进行矩阵乘法的小图块。内核将逻辑查询块分配给线程块,遍历所有逻辑键块,使用变长索引查找相关查询,然后将这些子集分批处理成稠密物理块。这种两级方法的关键在于,SRAM中缓存的查询数据可在逻辑键块的所有物理图块间复用,通过高效的稠密GEMM计算分摊不规则内存访问的成本。







反向传播设计同样精妙,采用三个内核序列实现。主内核在键维度上并行化,每个线程块处理一个键块,镜像前向传播的“收集并致密化”策略。遵循FlashAttention-2的内存高效原则,研究者在反向传播期间重计算注意力分数,避免存储完整注意力矩阵。虽然键和值的梯度直接写入HBM,但部分查询梯度需要跨多个键块累加,这是通过高精度全局缓冲区的原子加法高效处理的。这种设计确保反向传播在序列长度上保持线性复杂度,相对于标准注意力的二次复杂度是重大改进。考虑到反向传播通常比前向传播慢2-3倍,这种高效实现对于长序列的实际训练至关重要。


实验验证了FlashMoBA的卓越性能。从零开始预训练的模型在可控实验中显示,优化后的MoBA在性能上可与密集注意力基线匹敌。对于小块场景,FlashMoBA相比FlashAttention-2实现了最高14.7倍的加速。这一突破不仅使MoBA机制从理论创新走向实际应用,更为处理超长上下文序列开辟了新路径。随着模型规模不断扩大和序列处理需求日益增长,FlashMoBA所代表的硬件感知优化将成为未来大模型发展的关键方向。
论文地址:https://arxiv.org/pdf/2511.11571
项目地址:https://github.com/mit-han-lab/flash-moba
论文标题:OPTIMIZING MIXTURE OF BLOCK ATTENTION
— 图片补充 —












关注“鲸栖”小程序,掌握最新AI资讯
本文由鲸栖原创发布,未经许可,请勿转载。转载请注明出处:http://www.itsolotime.com/archives/6779
