Axe异构布局编译器:跨GPU/TPU/Trainium的统一编程模型,开启机器学习编译新纪元

Axe Layout 的提出,是机器学习系统领域向统一抽象迈进的重要一步。这种统一抽象的威力,在于让开发者能够以接近手工调优代码的性能,轻松编写出高效利用最新 GPU 特性、实现通信计算重叠、并能跨 GPU 和 AI 加速器移植的复杂内核
Axe 不仅仅是一个编译器或 DSL,它更是一种思维范式。它试图弥合高层分布式编程与底层硬件微架构之间的语义鸿沟,为下一代机器学习编译器和框架奠定了坚实的基础。

在当今大模型时代,训练和部署像 GPT-4、DeepSeek-R1、Llama 这样的巨型模型,已不仅仅是算法创新的比拼,更是对底层计算系统的终极考验。模型参数动辄千亿,数据量浩如烟海,计算任务需要分布在成千上万的 GPU 或专用 AI 加速器上协同完成。

然而,如何高效地将数据和计算任务“摆放”在从设备集群到芯片内部寄存器的每一层硬件上,成为了一个极其复杂且分裂的系统难题。传统的解决方案往往“头痛医头,脚痛医脚”:
* 分布式框架(如 PyTorch DTensor、GSPMD)负责设备间的数据分片与复制;
* 设备级 DSL(如 Triton、CuTe)则专注于 GPU 线程块和寄存器级别的数据排布与计算映射。

这种割裂的抽象导致跨层优化困难,程序难以移植,更无法应对 GPU、TPU、Trainium 等异构硬件带来的挑战。

Axe异构布局编译器:跨GPU/TPU/Trainium的统一编程模型,开启机器学习编译新纪元

今天,我们解读一篇来自学术界和工业界合作者的重要论文《AXE: A Simple Unified Layout Abstraction for Machine Learning Compilers》。这项工作提出了一种名为 Axe Layout 的革命性抽象,旨在用一套统一的“语言”,描述从设备集群、GPU 内存层次到 AI 加速器片上存储的完整数据布局与计算映射,并在此基础上构建了一个强大的编译栈。

Axe异构布局编译器:跨GPU/TPU/Trainium的统一编程模型,开启机器学习编译新纪元
图 1 | Axe Layout 组成要素示意图。该图展示了 Axe Layout 的三个组成部分:分片部分 D(由多个 Iter 组成的列表)、副本部分 R(一组 Iter)、偏移量 O。一个 Iter 定义了在某个轴上的线性带步长访问(范围,步长,轴)。Axe 布局以 Iter 为核心构建逻辑索引与硬件资源的映射,D、R、O 协同实现多场景适配。D 将逻辑张量拆解到硬件轴(如 GPU 的 lane、warp),像在张量核心计算中,可把矩阵行、列维度分别映射到 lane 和 warp 轴,保证线程级数据分配合理;R 解决数据复用,如在多 warp 协作时,让不同 warp 持有相同数据块,减少通信开销;O 能避开硬件资源冲突,例如将数据部署到特定编号 warp。这种设计打破传统布局局限,为跨设备、跨内存层级映射提供统一框架,是实现硬件感知能力的关键。

实验表明,基于 Axe 生成的代码,性能可逼近手工调优的高性能内核,同时在编程效率和硬件覆盖率上实现了巨大飞跃。更多性能数据见后文

Axe异构布局编译器:跨GPU/TPU/Trainium的统一编程模型,开启机器学习编译新纪元
图 11 | Qwen3-30B MoE 层延迟对比。在不同输入 token 数量下,Axe(橙色线)的延迟始终低于 FlashInfer(蓝色线)和 SGLang(深蓝线)。MoE 层是 LLM 关键组件,计算密集且数据依赖复杂,传统编译器难优化。Axe 通过细粒度流水线优化和高效算子复用提升性能,在 MoE 层 “门控选择 – 专家计算 – 结果聚合” 流程中,部分专家计算完成后即可启动下部分,减少等待时间,且统一布局复用算子。实验显示,Axe 在所有输入 token 数量下均优于 FlashInfer 和 SGLang,短输入场景比 FlashInfer 快 1.20 – 1.36 倍,长输入场景优势更显著,因 Axe 合理分配张量到 GPU 内存层级,预加载长输入数据,实现 “计算 – 传输” 重叠。Axe 开发成本低,对 LLM 推理服务快速迭代部署意义重大。

Axe异构布局编译器:跨GPU/TPU/Trainium的统一编程模型,开启机器学习编译新纪元
图 12 | 多 GPU GEMM+Reduce-Scatter 延迟对比。该图显示,在不同问题规模下,Axe(红色)的延迟远低于 cuBLAS+NCCL(橙色)和 Triton-distributed(蓝色)。GEMM + Reduce – Scatter 是 LLM 分布式训练核心任务,传统方案将计算与通信串行执行,延迟高;Triton – distributed 因 GEMM 性能优化不足,延迟也较高。Axe 通过 “计算 – 通信融合” 策略,在 GEMM 计算中同步传输部分结果,实现细粒度重叠,大幅减少总延迟;且统一分布式张量表示,编译器自动推断数据分布规则,减少开发成本与错误风险。实验中 Axe 在所有权重形状下延迟最低,比 cuBLAS + NCCL 提速 1.20 – 1.40 倍,比 Triton – distributed 提速 1.08 – 1.32 倍。TP=8 的设置模拟中大规模 LLM 分布式训练场景,Axe 低延迟性能提升训练吞吐量,缩短周期,且开发成本低,为大规模 LLM 分布式训练提供高效方案。

问题一:统一的代价——Axe 的跨硬件抽象是否牺牲了特定架构的极致性能?

Axe 声称实现从线程到设备的“统一布局抽象”,展示了在 NVIDIA B200、Trainium 等硬件上的性能数据,但实际异构硬件(如 NVIDIA Tensor Core、TPU/训练芯片)中,未深入探讨 Axe 布局【是否】在【某些特定硬件上】【无法】表达其最优内存布局 如 Tensor Core 的特定 swizzle 模式、TPU 的脉动阵列布局, 换句话说,其“统一”是否以牺牲特定硬件的最优布局为代价? 若统一抽象不得不引入冗余或对齐约束,则可能在高性能计算场景中带来不可忽略的开销,论文中的性能对比是否充分证明了其在各类硬件上均能接近手工优化内核?能否澄清“统一”的真实代价与边界?

Axe 的统一抽象【并未】以显著牺牲硬件最优性为代价,论文中的实验数据表明其在多种硬件上能接近或达到手工优化内核的性能,但其“统一”能力仍依赖于布局表达式的正确构造, 且在某些极端场景下【可能存在】表达性或优化空间上的妥协 。具体来说:

| 主条目 | 描述 |
| :— | :— |
| 布局表达能力的完备性 | 1. 核心组件:通过 D(分片)、R(复制)、O(偏移) 三个组件,支持对线程、线程束、内存库、设备网格 等多个维度的映射进行编码。
2. 统一语法覆盖场景
– NVIDIA Tensor Core 的寄存器布局(跨 lane/warp/reg 轴)
– 分布式 GPU 网格上的分片与复制(跨 gpuid 轴)
– AI 加速器中的多维片上内存(跨 P/F 轴)
3. 核心结论:在语法层面 具备跨硬件表达主流布局模式的能力。 |
| 性能证据支持“接近手工优化” | 1. NVIDIA B200(GPU):FP16 GEMM 达到 cuBLAS 的 97%~100%,MoE 层相比 FlashInfer 有 1.20–1.36 倍加速。
2. 多 GPU 场景:GEMM+ReduceScatter 比 cuBLAS+NCCL 快达 1.40 倍。
3. Trainium 1(AI 加速器):FP16 GEMM 匹配手工 NKI 库,MHA 达 1.44 倍加速。
4. 核心结论在主流 GPU 与 AI 加速器上,Axe 生成的内核性能与手工优化库相当或更优。 |
| 潜在的代价与边界 | 1. 非幂二次形状的限制:线性布局抽象对非幂二次形状支持有限,Axe 整数步幅模型支持更灵活形状,但极端不规则布局仍可能需额外转换。
2. 布局构造的复杂性:统一抽象下,为特定硬件(如 Tensor Core 的 swizzle 模式)构造最优布局,仍需开发者/编译器理解硬件约束;未与极度特化、汇编级手工微调内核对比,存在未覆盖边界。 |

Axe 在保持统一抽象的同时,通过其灵活的轴映射机制避免了严重的性能损失,且在论文涵盖的硬件和负载中表现出竞争力。然而,“统一”并不意味着在所有场景下都能自动达到硬件绝对最优,它依赖编译器或开发者正确构造布局,并且在面对未来新型硬件时,其抽象可能需进一步扩展。

问题二:抽象的边界——Axe 是否将硬件映射的复杂性重新交给了开发者?

Axe 的“多粒度、分布感知”编程模型依赖于开发者正确设置布局与执行作用域,这是否将复杂的硬件映射责任重新转移给了程序员?这与 Axe 试图降低开发成本的初衷是否矛盾?在实际使用中的“易用性”,论文示例显示开发者需显式指定Layout 中的轴映射、复制、偏移等细节,并正确使用 with warp()with cta() 等作用域。 这实质上要求程序员对硬件架构有深入理解,而非完全依赖编译器自动化 。这与 Triton 等更高级的抽象(以线程块为集体单位)形成对比。看起来 Axe 在“抽象层次”上的定位模糊性, 它究竟是面向编译器开发者的底层抽象,还是面向算法工程师的高层工具?

Axe 确实将部分硬件映射责任交给了程序员, 但这是一种“可控的暴露”,旨在通过统一的抽象降低跨平台、跨层级的开发成本,而非完全隐藏硬件细节。其目标用户是系统开发者与高性能库作者,而非终端算法工程师。从下面几个角度来说:

目标用户与设计定位

| 主条目 | 细分说明 |
| :— | :— |
| 目标用户定位 | 1. 设计要求:其 DSL 和布局 API 要求显式指定轴映射、分片因子和执行作用域(如 with warp()with cta())。
2. 面向人群:主要服务于编译器开发者、内核库作者和框架集成者,而非普通的机器学习研究者
3. 核心用途:用于构建机器学习编译器和框架,提供可重用、声明式的操作符,以替代冗余的样板代码。 |
| 与 Triton/CuTe 的对比揭示其定位 | 1. CuTe:暴露线程级的循环变换与绑定,追求峰值硬件效率,但编程复杂度高。
2. Triton:提供线程块级的集体语义,开发生产力更高,但会限制部分底层优化的空间。
3. Axe 定位(中间道路):允许在同一内核中混合线程局部控制与集体操作,开发者可自由选择控制粒度。
4. 表达能力:既可通过集体语义实现 Triton 风格的拷贝,也可通过线程级绑定实现 CuTe 风格的地址计算,不强制固定抽象范式。 |
| 降低的开发成本体现在何处 | 1. 代码复用:同一套布局描述,可适配跨设备、跨内存层级的数据映射。
2. 跨平台移植:仅需更换轴绑定,即可将内核从 GPU 迁移至 AI 加速器(参考 Trainium 代码生成示例)。
3. 编译器辅助:自带布局规范化、分组、平铺、切片等代数操作,可自动匹配硬件指令(如 TMA 异步拷贝、脉动阵列 GEMM),大幅减少手动地址计算工作。 |
| 与完全自动化的对比 | 1. 同类路线:Halide/TVM 的自动调度代表了完全自动化的布局优化路线。
2. Axe 设计理念:不追求完全隐藏硬件细节,而是提供精确、可组合的语义工具,让系统开发者在可控的复杂度下实现高性能。
3. 生态互补:可与 PyTorch DTensor、Alpa 等高级分布式抽象形成互补,为其提供底层内核生成支撑。 |

Axe 并未将硬件责任“重新转移”给程序员,而是针对其目标用户(系统开发者)提供了一套比纯手写内核更高效、比全自动编译器更可控的抽象。其“降低开发成本”体现在跨硬件的一致性和布局驱动的代码生成上,而非完全免除硬件知识。这种设计选择与构建可维护、可移植的高性能内核库的需求是一致的。

核心贡献与设计哲学总结
* Axe 的核心贡献在于提供了一套表达力强、跨层级统一、编译器可推理的布局抽象,并在实验中证明其能在多种硬件上实现接近手工优化的性能。
* 它的设计哲学是暴露必要的硬件控制,但通过统一抽象降低跨平台与跨层级的开发复杂度,适用于需要兼顾性能与可移植性的系统级开发场景。

一、 挑战:跨越尺度的布局迷宫

要理解 Axe 的价值,首先需看清当前深度学习系统在布局(Layout)与映射(Mapping)上面临的三重挑战,如图 1 所示。

Axe异构布局编译器:跨GPU/TPU/Trainium的统一编程模型,开启机器学习编译新纪元
图 1 | Axe Layout 组成要素示意图。该图展示了 Axe Layout 的三个组成部分:分片部分 D(由多个 Iter 组成的列表)、副本部分 R(一组 Iter)、偏移量 O。一个 Iter 定义了在某个轴上的线性带步长访问(范围,步长,轴)。Axe 布局以 Iter 为核心构建逻辑索引与硬件资源的映射,D、R、O 协同实现多场景适配。D 将逻辑张量拆解到硬件轴(如 GPU 的 lane、warp),像在张量核心计算中,可把矩阵行、列维度分别映射到 lane 和 warp 轴,保证线程级数据分配合理;R 解决数据复用,如在多 warp 协作时,让不同 warp 持有相同数据块,减少通信开销;O 能避开硬件资源冲突,例如将数据部署到特定编号 warp。这种设计打破传统布局局限,为跨设备、跨内存层级映射提供统一框架,是实现硬件感知能力的关键。

1.1 分布式执行(Inter-Device)

当模型太大,单设备无法容纳时,我们必须将模型或数据分割到多个设备上。这涉及到数据分片(Sharding)复制(Replication)设备网格(Device Mesh) 上的放置策略。不同的策略(如数据并行、模型并行、专家混合并行)对应不同的通信模式,如 All-Reduce、All-Gather,需要框架或编译器做出明确选择并优化通信与计算的重叠。

1.2 内存与线程层次(Intra-Device)

在单个 GPU 或加速器内部,硬件具有复杂的层次结构:
* 内存层次:全局内存 -> 共享内存 -> 寄存器。
* 线程层次:线程网格(Grid) -> 线程块(Block/CTA) -> 线程束(Warp) -> 线程(Lane)。

高效的内核必须精心设计数据如何在各级内存间分块(Tiling)、搬运,以及计算任务如何映射到线程层次上。特别是像 Tensor Core 这样的专用计算单元,要求特定线程组以特定格式协同读取寄存器中的数据。

1.3 硬件异构性(Heterogeneity)

硬件的世界并非只有 GPU。Google 的 TPU、AWS 的 Trainium 等 AI 加速器有着与 GPU 截然不同的内存架构(如多维暂存器、存储体约束)。即便在 NVIDIA 家族内部,从 Ampere 到 Hopper 再到 Blackwell,Tensor Core 的片上数据格式和要求也在不断变化。编译器必须为每种硬件生成定制化代码,同时为程序员提供相对统一的体验。

现有的工作往往只聚焦于某一层:
* GSPMDAlpa 擅长分布式;
* TritonCuTe 专注于 GPU 内核开发;
* PallasNKI 则针对特定加速器。

缺乏一个贯穿多层的统一抽象,是系统优化道路上的一道鸿沟。

二、 核心创新:Axe 布局抽象——命名轴的统一映射

Axe Layout 的核心理念非常优雅:它将逻辑张量的索引,通过一组命名的轴(Named Axes),映射到一个多维的物理空间坐标上。这个物理空间可以涵盖设备 ID、GPU 线程束、内存存储体等任何硬件资源。

一个 Axe 布局 L 由一个三元组定义:L = (D, R, O)。让我们通过论文中的图 1 和示例来拆解它。

Axe异构布局编译器:跨GPU/TPU/Trainium的统一编程模型,开启机器学习编译新纪元
图 1 | Axe Layout 组成要素示意图。该图展示了 Axe Layout 的三个组成部分:分片部分 D(由多个 Iter 组成的列表)、副本部分 R(一组 Iter)、偏移量 O。一个 Iter 定义了在某个轴上的线性带步长访问(范围,步长,轴)。

2.1 分片(D – Shard)

  • 是什么:一个有序的 Iter 列表。每个 Iter 是一个三元组 (范围, 步长, 轴),例如 (8, 4@lane) 表示在 lane 轴上,有 8 个连续元素,每个元素在物理空间中间隔 4 个单位。
  • 作用D 将逻辑索引空间划分到多个硬件轴上,产生一个基础坐标。它是对传统“形状-步长(shape-stride)”模型的泛化,允许步长与命名的硬件轴(如 thread, warp, gpuid, sram_bank)绑定。
  • 示例:将一个形状为 (8, 16) 的逻辑块映射到 GPU 线程和寄存器。可以表示为:
    D = ( (8, 4@lane), (2, 1@warp), (4, 1@lane), (2, 1@reg) )
    这表示逻辑索引被分解为 8, 2, 4, 2 四个因子,分别分布在 lane, warp, lane, reg 轴上。

2.2 副本(R – Replica)

  • 是什么:一个无序的 Iter 集合。这些 Iter 的枚举独立于逻辑索引
  • 作用:将 D 产生的基础坐标进行复制或广播每个副本 Iter 定义了一组偏移量,加到基础坐标上,从而在物理空间创建多个数据副本。
  • 示例:在 warp 轴上复制 2 份,副本间间隔 4 个 warp:R = [ (2, 4@warp) ]

2.3 偏移(O – Offset)

  • 是什么:一个固定的坐标偏移向量(每个轴一个整数值)。
  • 作用将所有坐标整体平移,用于数据对齐、预留资源或实现特殊的放置策略。
  • 示例:在 warp 轴上整体偏移 5 个单位:O = 5@warp

完整的映射公式:对于一个逻辑索引 x,Axe 布局产生一个物理坐标的集合:L(x) = { D(x) + r + O | r ∈ R }

如果 R 为空,则 L(x) 是单点集;否则,其大小等于 R 中所有 Iter 范围的乘积。

统一性的体现:让我们通过图 2 中的例子,感受 Axe 如何统一不同场景。

Axe异构布局编译器:跨GPU/TPU/Trainium的统一编程模型,开启机器学习编译新纪元 图 2 | 不同场景下的 Axe 布局示例。该图从左至右展示了三个场景:左) 将 8×16 逻辑块映射到 4 个 GPU warp(32 lane/线程)和 2 个寄存器;中) 在 2×2 GPU 网格上分布式共享一个 64×128 矩阵;右) 映射到 AI 加速器的 2D 分区 SRAM 和 NVIDIA Blackwell 的 2D 张量内存。所有场景均使用同一套(D, R, O)语法描述。

上图左列、中列、右列分别代表不同场景:

  • 场景 A(GPU Tensor Core):描述一个适配 Tensor Core 指令的寄存器布局,涉及 lanewarpreg 轴。该场景 聚焦单设备细粒度资源分配,8×16 张量块适配 GPU 的 warp – lane – reg 层级,D 拆解维度、R 实现数据复用、O 调整位置,充分发挥线程级并行性,减少交互开销。
  • 场景 B(分布式 GPU 网格):描述一个矩阵在 4 个 GPU 间的分片与复制,涉及 gpuid_xgpuid_ym(内存)轴。这可以直接对应像 S(0)S(1)(完全分片)或 S(0)R(行分片并复制)这样的高层分布式策略。针对多 GPU 分布式场景,Axe 布局灵活实现全分片或分片 + 复制策略,全分片减少跨设备通信,分片 + 复制平衡通信与计算效率,适配 Alpa 等分布式框架并行策略。
  • 场景 C(AI 加速器内存):描述数据在加速器专用内存(如片内 SRAM)中的排布,涉及分区轴 P 和自由轴 F体现对异构加速器的适配,AI 加速器 SRAM 通过布局避免内存 bank 冲突,Blackwell GPU 张量内存布局适配专用存取指令,验证 Axe 可贯穿全硬件栈,为统一编译优化奠基。

可以看到,无论是设备间的分片,还是设备内寄存器的排布,亦或是加速器特殊内存的约束, 都可以用同一套(D, R, O)语言来形式化地描述。 这为编译器进行跨层、跨硬件的统一分析和优化提供了可能。

三、 Axe 编译器:多粒度感知的编程与编译

有了强大的布局抽象,Axe 团队在此基础上构建了一个多粒度、分布式感知的编译器。其核心思想是:允许程序员在一个内核中,自由混合线程级别的细粒度控制和线程块/设备级别的集体操作语义, 由编译器基于 Axe 布局自动推导出高效的硬件原生调度。

编译器的工作流程概览如下:

Axe异构布局编译器:跨GPU/TPU/Trainium的统一编程模型,开启机器学习编译新纪元 图 3 | Axe 编译器概述。该图展示了 Axe 编译器的工作流程。左侧是用 Axe DSL 编写的 GEMM 内核代码,其中高亮显示了执行作用域、带 Axe 布局的张量以及算子。程序使用了加载和 GEMM 宏,以及一个三阶段流水线。右侧展示了张量在共享内存和寄存器中的布局,以及copy.async算子如何根据布局被 lowering 为线程绑定的循环,发出具体的硬件指令。

  • 左侧:GEMM 内核代码通过执行作用域明确硬件层级,kernel 覆盖所有线程,cta 限定线程块内操作,warp 针对线程束级计算。带 Axe 布局的张量定义是关键,共享内存张量确保线程块内数据高效共享,寄存器张量适配 warp 内线程访问模式,让编译器自动推断数据位置。
  • 右侧上方:共享内存(shared memory)和寄存器(register)中的张量携带 Axe 布局信息,标注不同布局下迭代器的范围、步长与轴(如 “16@reg、4@lane、2@reg、1@lane、16@reg” 等)。
  • 右侧中间:“tile” 工具通过规范化、分组、切片实现布局变换与匹配,将寄存器分片与 lane 轴组合形成 warp 级视图,识别适配张量核心指令的硬件原生布局,如 “8@reg、1@lane” 等布局的组合。
  • 右侧下方:代码生成中,编译器依据 Axe 布局推导地址,生成绑定线程的循环和专用指令:copy.async 算子被下转为绑定线程的循环,该循环发出 cp.async.cg.shared.global 指令,指令中的地址由 Axe 布局推导得出。开发者仅需关注高层逻辑,兼顾手写优化内核性能与开发效率,解决 “性能与效率难兼顾” 问题。

3.1 多粒度执行作用域(Execution Scopes)

Axe DSL 引入了显式的作用域概念,来界定一组线程(或设备)共同执行一个操作:

  • kernel: 内核启动的所有线程。
  • cta (或 block): 一个线程块。
  • warp: 一个线程束。
  • thread: 单个线程。
  • device: 跨设备的集合。

程序员可以在不同作用域内编写代码,编译器会理解其语义。例如,在 cta 作用域内的一个 copy 操作,意味着整个线程块协同完成一次数据搬运。

3.2 携带布局的张量抽象(Tensor with Layout)

在 Axe 中,张量是一个一等公民,它携带了形状、数据类型、内存指针和至关重要的 Axe 布局信息。这使得编译器能够精确地知道每个张量元素在物理硬件上的位置。

“`python

示例:定义一个分布在4个GPU上,并在每个GPU内按特定方式排布的分布式张量

input = Tensor(shape=(4, 64, 64), layout=((4, 64, 64), (1@gpuid, 64, 1)))
“`

3.3 高阶算子与调度(Operators and Schedules)

Axe 提供了一组高阶算子(如 copygemmreduce),类似于嵌入在原生内核语言中的集体库(如 CUB),但更加通用。关键创新在于算子的具体实现(调度)是根据操作数张量的 Axe 布局和当前执行作用域,由编译器自动分派的

  • 同一 copy 算子,用在 thread 作用域操作寄存器张量,可能被编译为简单的寄存器移动指令。
  • 用在 cta 作用域从全局内存拷贝到共享内存,可能被分派为向量化加载指令或异步拷贝指令(如 cp.async)。
  • 如果源和目的张量分布在不同的设备上,copy 可能会在底层被翻译为一个 all-gatherbroadcast 集合通信操作。

3.4 布局操作与编译器分析

布局是编译器进行分析和优化的基石。Axe 定义了一组核心布局操作:

  • 规范化(Canonicalize):将布局转换为唯一的标准形式,用于判断两个布局在语义上是否等价。这涉及消除大小为1的迭代维度、合并相同轴上的相邻迭代维度等规则。
  • 分组(Group):给定一个逻辑形状 S,将布局中的迭代维度列表分割或融合成连续的块,使得每个块的维度乘积等于 S 的对应维度。这是进行分块、切片等操作的前提。
  • 分块(Tile / Kronecker Product):这是支持分块计算和利用 SIMD/张量核心指令的关键。给定两个布局 A 和 B,分块操作 A ⊗ B 产生一个新布局,其中 B 布局作为“内部块”,A 布局作为“外部块”并按 B 的跨度(span)进行缩放,以确保内部块互不重叠。公式如下:f_{A⊗B}(x || y) = f_A(x) ⊙ span(f_B) + f_B(y)
  • 切片(Slice):给定一个张量布局和它的一个逻辑子区域 R,推导出该子区域对应的布局 L[R:S],使得其映射与原布局在该区域上完全一致。这允许编译器只对感兴趣的数据区域生成高效代码。

通过这些布局操作,编译器能够:
* 匹配硬件指令:判断某个张量(或切片)的布局是否符合特定硬件指令(如 Tensor Core、TMA 异步拷贝)的要求。
* 推导地址计算:自动生成从逻辑索引到复杂物理地址(可能涉及设备 ID、线程 ID、内存体偏移)的计算代码。
* 优化数据搬运:为 copy 等算子选择最合适的实现方式。

四、 效果评估:性能与生产力的双重胜利

Axe 的实现基于 Apache TVM 的 TensorIR。评估围绕三个核心问题展开:

4.1 在最新 GPU 上能否达到接近最优性能?

测试平台为 NVIDIA B200 GPU。对比基线为行业标杆 cuBLAS 和流行的 Triton。
* FP16 GEMM:在多种来自真实模型(如 LLaMA-3.1, Qwen3)的权重形状上,Axe 达到了 cuBLAS 97%到 100%的吞吐量,而 Triton 约为 90%。Axe 的成功在于其 DSL 能轻松表达Warp 专业化线程块集群(Thread Block Cluster)等先进特性。例如,在 Blackwell 架构上,Axe 内核可以显式指定两个 SM 协同处理一个 GEMM 块,而 Triton 编译器自动生成的计划则未能利用此特性。
* MoE 层:在 Qwen3-30B MoE 层推理中,对比 FlashInfer 和 SGLang(基于 Triton),Axe 获得了最高达 1.36 倍的加速。Axe 能够精细地编排第一组和第二组 GEMM 之间的流水线,实现计算重叠。

Axe异构布局编译器:跨GPU/TPU/Trainium的统一编程模型,开启机器学习编译新纪元
图 11 | Qwen3-30B MoE 层延迟对比。在不同输入 token 数量下,Axe(橙色线)的延迟始终低于 FlashInfer(蓝色线)和 SGLang(深蓝线)。

4.2 多设备执行能否提升?

测试多 GPU GEMM+Reduce-Scatter 工作负载。Axe 将分布式张量、求和算子与计算融合在单个内核中,由编译器自动分派到 multinem.ld_reduce 等底层原语。
* 对比非融合的 cuBLAS+NCCL 基线以及 Triton-distributedAxe 实现了最高 1.40 倍的加速。关键在于Axe 在单个内核内实现了通信与计算的细粒度重叠,从而提高了内存带宽和 Tensor Core 利用率。

Axe异构布局编译器:跨GPU/TPU/Trainium的统一编程模型,开启机器学习编译新纪元
图 12 | 多 GPU GEMM+Reduce-Scatter 延迟对比。在不同问题规模下,Axe(红色)的延迟远低于 cuBLAS+NCCL(橙色)和 Triton-distributed(蓝色)。

4.3 能否支持异构硬件后端?

测试平台为 AWS Trainium-1 AI 加速器。对比基线为供应商手工优化的 NKI 库。
* FP16 GEMM:Axe 生成的代码性能与手工 NKI 库完全匹配
* 多头注意力(MHA)Axe 实现了平均 1.26 倍,最高 1.44 倍的加速。通过更优的软件流水线和内存分配计划超越了手工实现。
* 生产力:手工 NKI 实现需要 120 行代码(GEMM)和 1188 行代码(MHA),而Axe 仅需 78 行和 228 行。高级的 DSL 极大简化了调度和地址计算。

Axe异构布局编译器:跨GPU/TPU/Trainium的统一编程模型,开启机器学习编译新纪元
图 13 | FP16 GEMM 和多头注意力的测试结果。左图显示 Axe(橙色)在 FP16 GEMM 上性能与手工 NKI(蓝色)持平;右图显示 Axe 在 MHA 上性能显著优于手工 NKI。

五、 相关工作对比:Axe 的独特定位

5.1 布局系统

5.1 底层布局抽象

  • CuTe:Axe 继承了 CuTe 的形状-步长代数,并将其泛化。核心区别在于,CuTe 的映射是单值的,主要用于 GPU 内核内的工作划分和 TMA 地址计算;而 Axe 引入了命名轴和(R, O),支持多值映射(副本),并天然覆盖分布式和异构硬件。
  • 线性布局(Linear Layouts):采用基于 F2 的位线性函数,对形状有 2 的幂次限制,在处理非 2 幂次形状(如某些分布式场景)时受限。 Axe 的整数线性形式更为通用。

5.2 深度学习编译器与 DSL

  • Halide/TVM:算法与调度分离的开创者。Axe 更侧重于为异构、分布式环境定义统一的底层数据映射原语。
  • Triton:提供线程块级别的集体编程模型,隐藏线程级细节。 Axe 则允许在同一内核中混合集体语义和线程级控制,兼具生产力和对尖端硬件特性的控制力。
  • CuTeDSL / Mojo / TileLang:这些 DSL 在不同层次上抽象数据布局和计算。Axe 的布局抽象可以作为它们底层的一个通用中间表示,增强其跨硬件和分布式的能力。

5.3 分布式机器学习框架

  • GSPMD / Alpa / PyTorch DTensor:这些框架在高层定义张量在设备网格上的分片。Axe 可以看作是其向设备内部的延伸,用同一套语言描述了分片张量在每个设备内部的具体内存和线程布局,使得跨层联合优化成为可能。
  • TileLink / Triton-Distributed:将集合通信引入内核。Axe 的分布式感知能力与之类似,但其基于统一布局抽象的设计更具扩展性和硬件无关性。

六、 总结与展望:迈向统一的软硬件协同栈

Axe Layout 的提出,是机器学习系统领域向统一抽象迈进的重要一步。它通过(D, R, O)这一简洁而富有表达力的三元组,为数据与计算在跨越设备、内存层次和异构单元的物理空间中的放置,提供了一套共享的词汇表。

基于此构建的 Axe 编译器,验证了这种统一抽象的威力:它让开发者能够以接近手工调优代码的性能, 轻松编写出高效利用最新 GPU 特性、实现通信计算重叠、并能跨 GPU 和 AI 加速器移植的复杂内核。

Axe 不仅仅是一个编译器或 DSL,它更是一种思维范式。它试图弥合高层分布式编程与底层硬件微架构之间的语义鸿沟,为下一代机器学习编译器和框架奠定了坚实的基础。未来,我们有理由期待 Axe 或类似的思想被更广泛地采纳,从而真正实现“编写一次,高效运行在任何 AI 硬件之上”的愿景。


关注“鲸栖”小程序,掌握最新AI资讯

本文来自网络搜集,不代表鲸林向海立场,如有侵权,联系删除。转载请注明出处:http://www.itsolotime.com/archives/20035

(0)
上一篇 2天前
下一篇 1天前

相关推荐

  • AdaptCLIP:西门子与腾讯优图联合打造零样本工业异常检测新框架,无需微调实现精准定位

    AdaptCLIP:无需微调的零样本工业异常检测新框架 当前,视觉模型在工业“缺陷检测”等领域的应用已相对成熟。然而,广泛使用的传统模型在训练时对数据要求极高,需要大量精细标注的数据才能达到理想效果。 大模型则有望在“零样本/少样本识别” 条件下,达到与传统模型相当的性能。CLIP 是 OpenAI 于 2021 年发布的开源视觉-语言基础模型。本研究在其基…

    2026年1月19日
    6700
  • Python进阶之路:避开6个常见陷阱,从中级迈向高级开发者

    这已经不再是语法的问题。 如果到了 2026 年你还在学新的 Python 语法,你不是卡住了——你是在拖延。 刻薄吗?也许。 是真的吗?绝对。 大多数中级 Python 开发者不是因为不够懂 Python 而失败。 他们失败,是因为还在用新手的思维……只是写得更快。 过去 4 年多里,我审阅过上百个 Python 代码库——创业项目、内部工具、“在我机器上…

    2026年1月11日
    5000
  • 华为诺亚&港中文发布SCOPE框架:让LLM Agent从错误中学习,实现Prompt自我进化

    在 LLM Agent 领域,一个常见的问题是:Agent 明明“看到了”错误信息,却总是重蹈覆辙。 当 Agent 遇到工具调用错误时,错误日志里往往已经包含了解决方案——正确的参数格式、有效的 API 用法、甚至是直接可用的替代方案。然而,静态的 Prompt 无法让 Agent 从这些反馈中“学到教训”,导致它们陷入“错误循环”:承认失败,却重复同样的…

    2025年12月26日
    9700
  • 跨越模态边界:构建真正理解图像、表格与文本的多模态RAG系统

    构建多模态 RAG 系统的终极指南 三个月前,我们新开发的 AI 应用在诸多看似简单的问题上频频“翻车”。问题根源并非 AI 不够智能或数据不足,而是因为答案蕴含在一张图片里,而当时的系统仅能处理文本。 这一时刻迫使我直面一个在构建 RAG 系统时长期回避的核心问题:我们花费数年时间教 AI “阅读”文字,却忽略了人类同样通过图像、表格、公式和流程图来“表达…

    2025年12月16日
    8100
  • Context7架构革命:子代理架构如何将AI上下文消耗降低65%?

    VibeCoding 必备的 MCP 工具之一 Context7 刚完成了一次重要的架构重构,旨在解决上下文臃肿问题,让 AI 更高效地获取项目文档。此前,用户的一个简单问题,系统就会拉取大量文档,平均上下文大小达到 3000 tokens。这不仅拖慢了响应速度,还增加了不必要的成本。 新架构细节 针对这一问题,团队对产品做了一个关键改进:子代理架构。开发者…

    2025年12月27日
    12800