经过一年的开发,FlashAttention-4 正式发布。
作为深度学习领域一项关键的底层优化技术,FlashAttention 迎来了重大版本更新。其核心作者、普林斯顿大学助理教授 Tri Dao 表示,在 Blackwell GPU 上,注意力机制的执行速度现已几乎与矩阵乘法相当,尽管两者的瓶颈截然不同。

当前,Tensor Core 的速度已变得极快,以至于注意力前向传播的瓶颈转移到了指数运算单元,而注意力反向传播的瓶颈则在于共享内存带宽。
新算法针对这些瓶颈进行了重新设计,引入了多项优化机制,包括使用多项式进行指数近似、新的在线 softmax 算法以减少 90% 的重缩放操作,以及利用 2CTA MMA 指令让两个线程块共享操作数以降低共享内存流量等。

- 论文地址:https://github.com/Dao-AILab/flash-attention/blob/main/assets/fa4_paper.pdf
- 代码链接:https://github.com/Dao-AILab/flash-attention
接下来,我们将详细解析 FlashAttention-4 的技术细节。
硬件趋势:非对称扩展
长期以来,作为 Transformer 架构的核心层,注意力机制一直是大型语言模型和长上下文应用的主要性能瓶颈。此前,FlashAttention-3 通过异步执行和 warp 专门化等技术,针对 Hopper GPU(如 H100)架构进行了优化。
然而,AI 行业正迅速转向部署 Blackwell 架构系统(如 B200 和 GB200)。现代加速器如 Blackwell GPU 延续了一种趋势:硬件的非对称扩展。在这种趋势下,张量核心的吞吐量增长速度远快于其他硬件资源,例如共享内存带宽、用于指数等超越函数运算的特殊函数单元,以及通用整数与浮点 ALU。
举例来说,从 Hopper H100 到 Blackwell B200,BF16 张量核心的吞吐量提升了 2.25 倍,但 SFU 数量和共享内存带宽基本保持不变。这种扩展的不对称性对注意力这类复杂内核的优化产生了深远影响。
具体而言,注意力机制的核心包含两个通用矩阵乘法运算,中间夹着 softmax 操作。但在实际应用中,注意力还涉及大量辅助工作,如数据搬运、同步、布局转换、元素级运算、调度和掩码处理等。
传统观点认为,注意力的性能完全由 GEMM 的速度决定。然而,对 B200 的“速度与馈送”分析显示,主要瓶颈并非张量核心,而是:
1. 前向传播中用于 Softmax 指数运算的 SFU 单元;
2. 反向传播中受共享内存带宽限制的共享内存流量。
为此,研究团队推出了 FlashAttention-4,这是一种算法与内核的协同设计方案。其核心目标是通过最大化矩阵乘法与其他瓶颈资源之间的重叠,在 B200(BF16)上实现高达 1605 TFLOPs/s 的性能(利用率为 71%),相比 cuDNN 9.13 快 1.3 倍,相比 Triton 快 2.7 倍。
协同设计的核心思路如下:
* 新型流水线:为前向和反向传播分别设计了新的软件流水线,利用 Blackwell 的全异步 MMA 和更大的分块尺寸,最大化张量核心计算、softmax 计算与内存操作之间的重叠执行。
* 前向传播优化:在 FMA 单元上通过多项式近似实现指数函数的软件仿真,以提升指数计算吞吐量;同时引入条件式 softmax 重缩放,跳过不必要的重缩放操作,从而缓解 SFU 瓶颈。
* 反向传播优化:利用张量内存存储中间结果,以缓解共享内存流量压力;结合 Blackwell 新增的 2-CTA MMA 模式,进一步降低共享内存访问,并将原子归约操作次数减少一半;此外,还支持确定性执行模式,以实现可复现的训练。
* 调度优化:引入新的分块调度器,以解决因果掩码和变长序列导致的负载不均衡问题。
Blackwell 的新硬件特性
- 张量内存:在 B200 上,每个流式多处理器都配备了 256 KB 的 TMEM,与张量核心直接连接,用于存储 warp 同步的中间结果。
- 完全异步的第五代张量核心:
tcgen05.mma指令支持异步执行,并将累加结果存储在 TMEM 中。对于 BF16 和 FP16,单个 CTA 可使用的最大 UMMA 分块尺寸为 128×256×16,约为 Hopper 架构中最大 WGMMA 原子块的 2 倍。UMMA 由单个线程发起,减轻了寄存器压力,使得在不出现寄存器溢出的情况下,更容易使用更大的分块和更深的流水线。这同时也使 warp 专门化更具可行性,部分 warp 负责搬运数据,另一些 warp 负责发起 MMA,从而实现计算与内存访问的重叠。tcgen05.mma还可以直接从 TMEM 中读取操作数 A。 - 2-CTA MMA:Blackwell 支持在同一集群中由一对 CTA 共同执行一个 UMMA 运算,并跨越两个 CTA 的 TMEM。由 leader CTA 中的一个线程发起 MMA,但执行期间两个 CTA 都必须保持活跃。通过在这对 CTA 之间拆分 M 和 N 维度,可以将 MMA 的分块尺寸扩展到 256×256×16,从而减少冗余数据传输并降低每个 CTA 的资源占用。在一个内核中,CTA 组的大小(1 或 2)在 TMEM 操作和张量核心运算之间必须保持一致。

编程语言与框架:CuTe-DSL
FlashAttention-4 完全使用 CuTe-DSL 实现,这是 CUTLASS 提供的 Python 内核领域特定语言。内核代码使用 Python 编写,随后 DSL 会将其降级为 PTX,再由 CUDA 工具链编译为 GPU 机器代码。
该编程模型在抽象层面与 CuTe / CUTLASS 保持一致,同时提供了 PTX 级别的底层控制接口。与使用 C++ 模板相比,这种方式可以将编译时间缩短约 20–30 倍。Tri Dao 对此表示兴奋,因为这使得安装和“编译”过程仅需几秒钟,而非过去的几分钟甚至几小时。

Attention 性能基准测试
研究团队展示了 FlashAttention-4 在 B200(BF16)上的性能结果,并将其与 FlashAttention-2 以及 Triton、Gluon 和 cuDNN 的实现进行了对比。结果显示:
- 前向传播性能:在 Blackwell GPU 上,FlashAttention-4 的前向传播速度比 cuDNN 9.13 快 1.1–1.3 倍,比 Triton 实现快 2.1–2.7 倍。
- 反向传播性能:在处理长序列时,FlashAttention-4 的反向传播性能持续超越其他基准方法。




FlashAttention-4 的发布引发了广泛关注。PyTorch 团队宣布,其 FlexAttention 功能现已支持 FlashAttention-4 后端。

FlexAttention 长期以来帮助研究人员快速原型化各种自定义注意力机制变体,已被上千个代码库采用并获数十篇论文引用,但用户常面临性能瓶颈。随着 FlashAttention-4 的推出,PyTorch 团队在 Hopper 和 Blackwell GPU 上为 FlexAttention 集成了该后端。PyTorch 现在能够自动生成 CuTeDSL 的 score/mask 修改代码,并通过 JIT 编译为自定义注意力变体实例化 FlashAttention-4。
测试结果显示,在算力受限的工作负载下,相比 Triton 实现,此举仍能带来 1.2 倍到 3.2 倍的性能提升,使研究人员无需在“灵活性”与“高性能”之间做出妥协。
有评论指出,FlashAttention-4 是一个里程碑。在 Blackwell 架构上,注意力计算的速度已接近矩阵乘法,这意味着计算瓶颈将完全转移到内存带宽与通信上。其约 1600 TFLOPs 的注意力性能相比 FlashAttention-3 提升了 2–3 倍,这将直接惠及所有前沿大模型,因为它意味着更长的有效上下文窗口、更低的推理成本以及更强的规模化推理能力。

关注“鲸栖”小程序,掌握最新AI资讯
本文来自网络搜集,不代表鲸林向海立场,如有侵权,联系删除。转载请注明出处:https://www.itsolotime.com/archives/24598
