我们之前推送过多篇关于 Mega Kernel 的文章,今天来探讨这篇:《无需手动构建MegaKernels!Luminal 编译生成 MegaKernels:解决 GPU SM 负载不均,消除内核启动开销与内存气泡,适配任意架构!》。作者郑启航深入分析了开源编译器 Luminal,并结合其在 H200 上运行 gemma-3-4b 的实际测试,梳理了其 IR 设计与搜索机制。
编译流程分为六个步骤:前端利用 GraphTensor 描述计算过程,通过 ShapeTracker 记录张量的布局信息,从而消除了大量显式的形状操作,最终生成仅包含 20 个基本操作(primop)的高层 IR——HLIR。整个计算图按照 graph_break 切分成多个 chunk,结构相同的 chunk 会被合并为 group。随后,以 group 为单位进行 egraph 饱和搜索(egraph saturation),生成包括 CUDA kernel 和 cuBLAS 调用在内的候选方案,再通过实际测量延迟来筛选出最优实现。提取出的低级 IR(LLIR)经过模板参数替换和 NVRTC 即时编译,生成 GPU 可执行代码,最后由 Runtime 构建 CUDA Graph 来执行推理。
实测结果显示,Luminal 的 fp32 推理吞吐量远低于 vLLM,且当前并未实现宣传中提到的“自动导出 FlashAttention”的融合。作者指出,该编译器缺少内存层级描述与 tiling 优化,认为其宣传目标与实际进展之间存在落差,并对其“编译器”定位提出了质疑。
我去年就已经关注到了 Luminal 编译器。它宣称通过全自动编译可以达到 80% 的峰值性能,并且能搜索出 FlashAttention,之后还获得了投资。最近,我实际运行了它的 gemma-3-4b 示例,并借此机会梳理了它的 IR 设计和搜索机制。
一、整体概览
Luminal 的编译流程大致分为六步:
- Frontend:用户通过
GraphTensorAPI 来编写算子(如matmul、softmax)。前端中的expand_dim或permute等操作,仅修改张量附带的ShapeTracker元数据,不会生成新的 op。因此,最终 HLIR 图中的 op 数量远少于用户表达式的节点数。 - HLIR:前端操作最终凝聚成一张由 20 个 primop 组成的张量 DAG,这就是 Luminal 自身的高层 IR。
- Partition / Group:按照前端插入的
graph_break,将整张 HLIR 切成若干 chunk,再将结构相同的 chunk 合并成唯一的 group。后续的步骤都以 group 为单位推进。 - Egglog saturation:将每个 group 序列化为 egglog 程序,并执行等价关系饱和搜索。对于 4B 模型,在单核 CPU 上这个过程大约需要 30 分钟,这是编译开销的主要来源。
- Extraction / LLIR:从饱和后的 egraph 中先提取候选方案,然后将其降级(lower)到 LLIR。
- Codegen / Runtime:每个 LLIR 节点先通过 Codegen 生成 CUDA kernel(或 cuBLAS 调用),再由 Runtime 将它们串联进 CUDA Graph 来执行推理。整体上更像一个 JIT 过程:kernel 编译在 Codegen 阶段即时发生,buffer 分配在 Runtime 启动时完成,而不是提前全编译好的 AOT 方式。
二、HLIR
HLIR 是 Luminal 的高层张量 IR,只包含 20 个 primop,代表最小的原子运算。一个 Gemma 3 4B 模型的 HLIR 大约包含 5000 个 primop。
这 20 个 primop 可以分为七类:
| 分类 | 操作 |
|---|---|
| I/O | Input, Output, Constant |
| DType / Range | Cast, Iota |
| Unary | Exp2, Log2, Sin, Recip, Sqrt |
| Binary | Add, Mul, Mod, LessThan |
| Reduction | SumReduce, MaxReduce, Softmax |
| Indexing | Gather, Scatter |
| Fallback | CustomOpKind |
以矩阵乘法为例:a: [M, K] @ b: [K, N] -> [M, N]。HLIR 中不存在显式的 for k 循环,其对应的前端代码如下:
// src/frontend/matmul.rs
let mul = self.expand_dim(1, n) * rhs.permute((1, 0)).expand_dim(0, m);
let ret = mul.sum(2);
当代入具体 shape 后,构造出的 HLIR 仅有 5 个节点。
该设计思路与 Jittor 高度相似,都是通过扩展 layout 来表征循环区域。观察上述 Mul 和 SumReduce 节点:input 节点的 rank 为 2 维,而 Mul 使用 dims=[2, 4, 3],两个输入的 strides 分别是 [(z*3), 0, z] 和 [0, z, (z*4)](其中 z 代表 sizeof(dtype))。stride 中的 0 正是由 expand_dim 产生的 broadcast 维度。系统中没有独立的 Shape 操作 op,相关功能基本都通过 ShapeTracker 来实现。
值得留意的是,Softmax 并未被拆解为 Exp2 + SumReduce + Div 的组合,这很可能是为了后续的 rewrite 和 pattern match 操作更加便捷而做出的设计取舍。
2.1 ShapeTracker
ShapeTracker的核心职责是替代显式的Expand/Reshape/Permute等操作。可以这样理解:它先记录 Layout 信息,然后在后续计算中实际应用,从而表达这些形状变换操作。其工作流程大致如下:
- 每个
GraphTensor都附带一个ShapeTracker,其中记录了当前的dims、strides、offset、mask等影响数据访问顺序的信息。 expand_dim、permute、reshape、slice这类前端函数仅修改ShapeTracker,不会向 HLIR 图中插入新节点。- 当真正创建计算 op(
Mul、Add、SumReduce)时,当前的ShapeTracker会被读取并固化到该 op 的输入签名中。
因此,在上述例子中,HLIR 包含的是一个携带 shape/stride 信息的 Mul,而不是 Expand -> Permute -> Mul 这样的操作链。
具体到几种常见操作:
expand_dim:在dims中插入一维,对应的 stride 设为0,表示广播操作。permute:重新排列dims和strides,表示仅改变观察顺序,不涉及数据搬移。reshape/slice:更新dims、offset、mask等视图信息,同样不新建 HLIR op。
三、Partition / Group
HLIR 构建完整图后,直接对整个图进行 egg 搜索代价过高,尤其对于 Transformer 这类结构高度重复的模型而言,也缺乏必要性。因此,这一步执行两项任务:
- Partition:将整张 HLIR 划分为多个 chunk,每个 chunk 是一个“完整的子图,内部统一进行搜索/编译”。切分点由前端显式指定(
graph_break),典型位置包括 transformer 每层的边界或 KV cache 更新处这类天然分界点。 - Group:再将结构完全一致的 chunk 合并为同一个 group。每个 group 只需进行一次 egraph 搜索,结果可供所有成员 chunk 共享。
以 Gemma 3 4B 在 H200 上的运行情况为例,相关规模数据如下:
| 层级 | 数量 | 说明 |
|---|---|---|
| chunk | 35 | 整张图切分为 35 块,每块约含 140 个 HLIR op |
| group | 5 | 35 块按结构去重后剩下 5 类模板 |
这 5 个 group 对应的模型结构分别为:
- 1 个 decoder layer group:34 层 decoder layer 全部共享这一套模板,这是去重收益的主要来源。
- 1 个 embedding group:处理 token lookup 部分。
- 1 个 final norm + logits group:处理模型最后的输出头部。
- 2 个辅助 group:对应 prefill / decode 入口、RoPE / mask 等不属于主干 decoder layer 的模块。
四、Egglog saturation
该步骤利用 egraph saturation 技术,对 HLIR 进行等价变换与优化,生成大量等价的实现候选。搜索过程主要完成四项工作:
4.1 单个 HLIR primop 匹配 kernel op
每个 HLIR op(Add、Mul、SumReduce、Exp2……)都有对应的 kernel_rewrite<HLIR, Kernel> 规则,用于将其扩展为 dialect 级 KernelOp(如 CUDA 的 KernelAdd、Metal 的对应 op 等)。17 个 HLIR 计算 op 各有一条这样的 rewrite 规则(位于 crates/luminal_cuda_lite/src/kernel/hlir.rs)。这一步将纯 HLIR op 转化为“可真正执行的候选”。
最简单的这类 rewrite 规则,在代码中实际上是一个通用 helper:
pub fn kernel_rewrite<H: Default + EgglogOp, L: Default + EgglogOp>() -> Rule {
...
rule(union(hlir_op.clone(), llir_op)).fact(eq(dt, dtype(hlir_op)))
}
其做法非常直接:当看到一个 HLIR op 时,就将其与对应的 KernelOp union 到同一个 eclass 中。例如,Mul 可以被 rewrite 为 KernelMul,Add 可以被 rewrite 为 KernelAdd。
4.2 多个 HLIR primop 对应高效库调用
像 Mul + SumReduce 这种模式(即矩阵乘法)会被单独识别,并 lower 到 cuBLAS / cuBLASLt 的 sgemm 变体。相关规则的命名示例包括:cublas sgemm row-major x column-major、cublaslt batched column-major x row-major(位于 crates/luminal_cuda_lite/src/host/cublas/ 和 cublaslt/ 目录下)。同一个模式可根据 shape / stride 的不同匹配到不同的库变体。
这类高阶模式的改写规则大致如下:
(rewrite
(Op (SumReduce ...) (ICons (Op (Mul ...) ...) (INil)))
(Op (CuBlasSGemm ...) ...)
:name "cublas sgemm row-major × column-major")
4.3 Batch 与 Shape 展平
基于 Layout 进行简化操作,相关规则位于 src/egglog_utils/matmul_flattening/*.egg(共三条规则):
batch_merge_a_contig.egg/batch_merge_b_contig.egg:当 batch × matmul 运算中,一侧为 contiguous 布局,另一侧为 broadcast 时,将其展平为二维 matmul。squeeze.egg:移除无效维度。
4.4 In-place 候选与别名检查
Scatter 操作会被改写为 ScatterNoCopy(ConsumedBuffer(dest), ...)。其中 ConsumedBuffer 并非实际执行的操作,而是搜索阶段用于标记所有权的标识符。由于 egraph 中的节点可能存在环形依赖关系,且难以统计使用者数量,因此引入 ConsumedBuffer 的目的是将使用情况分析显式纳入搜索空间:如果目标 buffer dest 后续不再被其他操作读取,就可以实现原地写入。
后续的 cleanup / base_cleanup 规则集正是负责检查这一点:
- 如果
dest之后没有其他读取者,则保留ConsumedBuffer(dest),最终允许采用ScatterNoCopy,即执行原地写入。 - 如果
dest之后还有其他读取者,则移除该候选方案,回退到普通的Scatter操作。
4.5 Saturation
Luminal 并未将所有重写规则混合执行,而是划分为 4 个规则集,分阶段进行搜索,以缩小每轮重写的匹配范围,降低编译开销:
expr:主重写规则集,涵盖 HLIR 对应的 kernel 候选生成、batch matmul 展平、ConsumedBuffer 注入等全部操作。dtype_prop:辅助函数(function dtype (IR) DType :merge new)沿数据流传播 dtype 的规则。cleanup:若dest被其他操作读取,则删除ConsumedBuffer,并级联清除 ScatterNoCopy 候选。base_cleanup:独立的规则集,放在最后执行,专门处理(union ?cb ?dest)这类不可逆操作,必须等待前面所有规则集饱和后才能安全执行。代码中已标注 TODO,承认这是系统的脆弱点。
实际执行顺序如下:
(repeat 10 (saturate expr) (saturate dtype_prop) (run))
(saturate expr)
(saturate cleanup)
(saturate base_cleanup)
在我的实验中(Gemma 3 4B,H200,34 层 transformer),34 层被切分为 35 个 chunk,合并为 5 个结构等价的 group。每个 group 的 egraph saturation 产生了约 5076 个 enode 和 3633 个 eclass,单 CPU 核耗时约 30 分钟。
五、Extraction
saturation 完成后,Luminal 直接通过实际执行时的延迟来获取真实开销:
- 随机选择:对于每个 eclass,随机选取一个 enode,将其 lower 为 LLIR,通过 NVRTC 编译,实际执行并测量延迟(默认重复 10 次取平均值),同时检查结果中是否出现 NaN。若编译失败或出现 NaN,则更换候选方案,最多重试 100 次;若全部失败则直接 panic(参见
src/graph.rs:653)。 - 变异:以当前最快的候选方案作为种子(默认保留 1 套),每代生成 30 个变异版本:在存在多个可选 enode 的 eclass 中,随机选取若干节点替换为其他选择,并通过哈希去重避免重复测量。
- 评估:每个变异同样经历 lower + 编译 + 执行 + 测量的流程。若比种子方案更快,则取代种子。
- 预算:每个 group 最多评估
options.limit个候选方案(Gemma 3 4B 有 5 个 group,设置GEMMA_SEARCH_GRAPHS=3即每个 group 评估 3 个候选,全模型共 5 × 3 = 15 次 NVRTC + profile)。官方默认值为 500,搜索时间会显著延长。
这种方法用实际测量替代了传统分析建模中难以精确预测的成本模型,但仅适用于稳定的硬件环境,在设计阶段无法使用。
六、LLIR
代码中将 LLIR 定义为:
pub type LLIRGraph = StableGraph<LLIROp, ()>;
StableGraph 可简单理解为节点编号稳定的图容器。LLIROp 是节点内容,边表示依赖关系。dump 出来的内容大致如下:
LLIROp(DialectOp(KernelMul { out_shape: [4, s, 256], ... }))
LLIROp(DialectOp(KernelSumReduce { out_shape: [s], ... }))
LLIROp(DialectOp(CuBlasLt { m: 1024, n: s, k: 2560, ... }))
其中每个节点直接对应一个具体的执行单元,例如:
- CUDA kernel 源码(后续由 NVRTC 实时编译为 GPU 可执行代码)
- Metal kernel(Apple 后端)
- host 端的库调用(如 cuBLAS、cuBLASLt 等现成的 sgemm)
同时包含供 Runtime 使用的元信息:
- 输出 buffer 的大小(符号表达式,支持动态 shape)
- 读写字节数、计算 FLOPs
- 输出是否复用某个输入 buffer(in-place 写入)等
在 LLIR 这一层,可以认为节点间均通过 global memory 传递数据,而 shared memory、register 等更细粒度的层次不会反映在 LLIR 中。
6.1 Gemma 3 4B 的 LLIR
好的,遵照您的指示,我将对提供的文章片段进行深度重写与降重,严格遵守所有规则。
Gemma 3 4B 模型经过编译后,生成的 LLIR 大约包含 7250 个节点,其具体分布如下:
KernelMul 2043 KernelGather 205
KernelAdd 810 KernelSin 68
KernelIota 648 KernelScatter 66
KernelCast 438 KernelLessThan 63
KernelConstant 409 KernelExp2 35
KernelRecip 378 KernelExp 35
KernelSumReduce 375 KernelMaxReduce 34
KernelSqrt 205 KernelSigmoid 32
KernelScatterNoCopy 2
从数据中可以清晰地看到,Elementwise 类型的操作占据了绝对主导地位。值得注意的是,代码中仅有的两个 KernelScatterNoCopy 节点,全部用于实现 KV cache 的原地写入操作。这正是前文所述 ConsumedBuffer 机制发挥作用的结果:只有当某个 buffer 只有一个使用者时,egglog 系统才会将常规的 Scatter 操作保留并优化为 ScatterNoCopy 版本。
七、代码生成(Code Generation)
LLIR 本质上只是一组数据描述,GPU 无法直接执行。Luminal 对此的处理流程如下:
7.1 模板与参数化
每一种 kernel 操作都维护着自身的 C++ kernel 模板。在代码生成阶段,系统会将节点中的 shape、stride、dtype 等参数填入这些模板,从而生成一段具体的 CUDA 源码。接着,这段源码会被提交给 NVRTC 进行 JIT 编译,最终生成 GPU 可以实际执行的 kernel。整个流程中没有 loop-level IR、schedule pass 或 tiling 等中间环节,kernel 的形态完全由模板决定。
例如,对于一个 KernelAdd 节点,代码生成的实际工作就是进行模板替换,最终拼接成类似下面的完整源码:
extern "C" {
__global__ void add_k(float *C, const float *A, const float *B, const int* dyn_dims) {
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (const_z >= /* n_elements */) return;
C[/* out_idx */] = A[/* a_idx */] + B[/* b_idx */];
}
}
为了避免对相同的 kernel 进行重复编译,Luminal 会以生成的源码为键值进行进程内缓存。当两个节点最终拼接出的源码完全一致时,系统会直接复用已经编译好的函数。
7.2 库调用
当然,并非所有 LLIR 节点都需要走源码生成这条路。如前文搜索部分所述,Mul + SumReduce 的组合会被重写为 matmul 操作,最终对应的是 cuBLAS / cuBLASLt 库的入口封装。对于这类节点,代码生成的任务仅仅是选择一个合适的库函数入口,并将 stride、leading dimension 等参数调整为 cuBLAS 所支持的格式。在执行时,直接调用诸如 cublasSgemm 这样的 host 端函数即可。
八、运行时(Runtime)
由于代码已在上一阶段编译完成,交付给运行时的是一组相互独立的 kernel 和库调用。运行时的工作主要分为两个阶段:
load_llir:首先装配好每个 group 的 LLIR,设置输入/输出指针,分配中间 buffer,并将其捕获为 CUDA Graph。execute:每次推理时,按 chunk 顺序回放对应的 CUDA Graph,并取出输出结果。
8.1 加载阶段
加载阶段首先读取每个 group 的 LLIR,然后执行以下操作:
1. 分配 Buffer:
运行时遍历 LLIR 中的每一个节点,根据节点的输出大小表达式,结合当前的 dyn_map(例如 M=1024, N=4096)计算出所需的字节数。然后,它会与已有的 buffer 进行比较:如果现有 buffer 容量足够,则直接复用;否则,会调用 cudaMalloc 来分配新的 buffer。
此外,节点可能携带一个标记,表明其输出可以复用某个输入 buffer 的地址。在这种情况下,运行时会直接将 output pointer 指向该 input pointer,而不再进行额外的内存分配。前文提到的 KernelScatterNoCopy(用于 KV cache 原地写入)正是通过这种方式处理的。
2. 打包 CUDA Graph:
每个 group 会被单独处理。运行时按照 LLIR 的顺序排列好该 group 内的所有 kernel,然后调用 CUDA Graph API 将这一整段 launch 序列捕获成一张图。在 Gemma 3 4B 模型上,总共构建了 5 张 CUDA Graph(每个 group 一张),每张图内部封装了 12 到 180 个不等的 kernel。在执行时,只需一次 cuGraphLaunch 调用即可发出整段序列,从而显著减少 launch overhead。
8.2 执行阶段
此阶段的操作非常简单:
- 提供输入数据的指针。
- 按照 chunk 的顺序,依次启动每个 chunk 所属 group 的 CUDA Graph。
- 如果需要,将 output buffer 中的数据读取回 host 端。
总结
首先,将最终测试得到的数据汇总如下:
| 框架 | dtype | TTFT | TPOT | TPS |
|---|---|---|---|---|
| vLLM | bf16 | — | 3.71 ms | 269 |
| vLLM | fp32 | — | 5.81 ms | 172 |
| Luminal main | fp32 | 202 ms | 37.42 ms | 26.7 |
| Luminal fusion | fp32 | 250 ms | 48.13 ms | 20.8 |
再对比 Luminal 官方在 README.md 中的宣传:
- 在性能方面,他们声称 Q8 Llama 3 8B 在 H100 上能达到约 80% 的理论峰值性能。
- 在技术方面,他们表示这套搜索系统 可以自动导出 FlashAttention。
然而,根据我在 H200 上对 Gemma 3 4B fp32 模型的实测结果来看,这两点都难以成立:
- 实际的 TPS 与 vLLM 相比仍有巨大差距。
- 当前代码中也没有一条能够真正将 attention 融合成 FlashAttention 的路径。其 egraph saturation 过程中,没有任何一条规则能够跨越
SoftmaxOp 并将其重写为 FlashAttention。
我的几点看法如下:
- 缺少对 Bufferization / Memory Hierarchy 的描述。
-
缺少 fusion / tiling / scheduling 等方面的优化。
-
早期宣传中声称能自动生成 FlashAttention,然而其输入和规则实际上经过了精心设计[1]。此外,之前的 IR 设计与当前版本差异巨大,原先包含
LoopOut, Let等元素,能够表达复杂程序,但规则编写困难且搜索空间更大。如今又回退到类似 linalg 的纯算子 IR,这使得之前宣传的自动生成 FlashAttention 难以实现。
至少截至本文撰写时,最新的 PR 仍在进行 elementwise fusion[2]。以目前的开发进度来看,显然难以匹配其宣传目标和投资规模[3],我怀疑这是在“编译器”的幌子下,进行手动优化的实质操作。
参考资料
[1] flash_attention_demo/src/code.lisp: https://github.com/luminal-ai/luminal/blob/0ccd344a69226205f1992f43f0dc3ef590bd56b2/flash_attention_demo/src/code.lisp
[2] luminal-ai/luminal/pull/274: https://github.com/luminal-ai/luminal/pull/274
[3] Luminal raises $5.3 million to build a better GPU code framework: https://techcrunch.com/2025/11/17/luminal-raises-5-3-million-to-build-a-better-gpu-code-framework/
关注“鲸栖”小程序,掌握最新AI资讯
本文来自网络搜集,不代表鲸林向海立场,如有侵权,联系删除。转载请注明出处:https://www.itsolotime.com/archives/32384

