“人类直觉本质上不足以捕捉代数变换、数据布局和硬件特定调度决策之间的组合交互。” 这句话来自 Prism 论文,精准揭示了在过去十年中,ML 系统优化领域始终无法跨越的核心瓶颈。
从 TensorFlow 到 TVM,从 cuDNN 到 FlashAttention,我们一直依赖专家手工编写的规则和内核来推动 AI 模型性能的飞跃。然而,这种范式正逐渐触及天花板。手工规则只能覆盖有限的优化空间,而每当有新算子或新硬件出现,适配工作往往需要投入数月的人力。超优化技术曾被视为希望所在,它试图通过自动搜索来发现最优程序,但枚举式搜索会遭遇组合爆炸,采样式搜索又难以保证找到最优解。
当前的主流超优化器主要分为两大流派:枚举式(如 TASO、Mirage)和采样式(如 AlphaEvolve)。
- 枚举式超优化器通过穷举所有可能的程序结构来寻找最优解,虽然能提供最优性保证,但候选程序数量会随算子数量和执行层级呈组合式增长,导致其无法处理复杂的 LLM 工作负载。
- 采样式超优化器则借助大语言模型或进化算法来引导搜索,虽然能探索更大的优化空间,但它将优化景观视为无结构的,搜索行为不稳定,且无法保证覆盖到最优解。
Prism 的核心洞察在于:大多数候选程序之间存在高度的结构相似性, 它们共享相同的计算结构,只是在并行化参数和数据映射上有所不同 。与其逐一枚举每个具体程序,不如用符号变量表示这些可变部分,将整个程序族编码为一个单一的符号结构。
基于这一洞察,Prism 开创了符号化超优化的新范式,将优化过程分解为两个层级进行搜索:
- 上层构建符号图(sGraph),用于编码整个程序族;
- 下层则将符号图实例化为具体的实现。通过符号推理对搜索空间进行结构化剪枝,Prism 在保证最优性的同时,大幅提升了可扩展性。
实验结果表明,在五个主流 LLM 工作负载上,Prism 的性能比当前最优的超优化器 Mirage 最高提升了 2.2 倍,比传统编译器方法(PyTorch Compiled)最高提升了 4.9 倍,同时端到端优化时间最高缩短了 3.4 倍。更关键的是,Prism 的符号搜索只需运行一次,就能覆盖所有输入配置, 而 Mirage 则需要为每个配置单独进行搜索。
本文将深入剖析 Prism 的核心技术设计,从 sGraph 的符号表示到两级搜索的实现细节,从符号剪枝的数学原理到 e-graph 等价验证的机制,全面解读这一突破性成果如何重新定义张量程序优化的边界。
本文目录
- 一、ML 系统优化的困境与超优化的演进
- 1.1 传统优化范式的局限性
- 1.2 超优化技术的兴起与挑战
- 二、Prism 的核心设计:符号化超优化的新范式
- 2.1 核心洞察:从枚举程序到编码程序族
- 2.2 sGraph:符号化的层次化图表示
- 2.3 Prism 的整体工作流程
- 三、sGraph 生成:符号化搜索与剪枝
- 3.1 符号维度匹配
- 3.2 表达式引导的剪枝
- 3.3 映射实例化与对称性破缺
- 四、sGraph 验证:基于 e-graph 的符号等价性检查
- 4.1 表达式语言与并行化算子
- 4.2 等价公理与 e-graph 重写
- 4.3 公理适配的特殊处理
- 五、参数实例化:自动调优与性能优化
- 5.1 自动调优的搜索空间
- 5.2 随机采样与并行编译
- 六、实验评估与结果分析
- 6.1 实验设置
- 6.2 内核性能结果
- 6.3 总优化时间结果
- 6.4 搜索时间分解与图多样性
- 6.5 消融研究:符号映射的影响
- 七、相关工作
- 7.1 专家手工优化的内核
- 7.2 基于规则的编译器
- 7.3 超优化器
- 7.4 符号图表示
- 八、结论与展望
- 8.1 结论总结
- 8.2 进阶分析
- 8.3 未来工作
一、ML 系统优化的困境与超优化的演进
好的,遵照您的指示,我已对提供的文章片段进行了深度重写与降重,并清洗了所有杂质。
以下是重写后的 Markdown 文本:
现代 AI 应用的性能瓶颈已不再局限于模型本身,而是转移到了硬件执行效率上。GPU 作为 AI 训练与推理的核心硬件,其计算能力的增长远超内存带宽与延迟的改善速度。因此,如何最大化 GPU 计算资源的利用率,已成为机器学习系统优化的核心议题。
张量程序作为机器学习模型的标准描述方式,通过有向无环图(DAG)来刻画计算流程,图中的节点代表张量算子,边则代表张量数据。优化张量程序的核心,就是在确保功能等价的前提下,为特定硬件找到执行速度最快的程序实现。
1.1 传统优化范式的局限性
传统机器学习系统的优化范式主要包含两类:基于规则的编译器与手工优化的内核库。这两类方法在过去十年间推动了 AI 应用的迅猛发展,但随着模型规模与硬件复杂度的持续攀升,其局限性也愈发显著。
基于规则的编译器(如 TensorFlow XLA、PyTorch TorchInductor、TVM) 主要依赖专家预设的图重写规则和调度模板来优化张量程序。这些规则是专家经验的结晶,能覆盖诸如算子融合、常量折叠、公共子表达式消除等常见优化模式。然而,这些规则仅能捕捉人类所能预见的优化,无法发现非直观的组合优化策略。 举例来说,将多个算子融合为一个自定义内核能显著减少内存流量,但融合的顺序与方式存在无数种可能性,人类工程师无法穷举所有有效的融合模式。
手工优化的内核库(如 cuDNN、cuBLAS、FlashAttention) 针对特定算子和硬件进行了极致优化,能充分发挥硬件的计算潜力。例如,FlashAttention 通过将注意力计算分块并利用共享内存,将注意力算子的内存复杂度从 O(n²) 降至 O(n),实现了数倍的性能提升。然而,手工优化内核库所支持的算子集合是固定的,无法快速适配新的算子或模型结构。此外,适配新硬件需要巨大的工程投入,将 FlashAttention 移植到新的 GPU 架构上通常需要数月时间。
论文明确指出,传统优化范式存在两个根本性的缺陷:
- 高昂的工程成本:适配新算子与新硬件需要大量专业工程师的投入,难以跟上 AI 模型与硬件快速迭代的步伐。
- 有限的优化空间:人类直觉无法捕捉代数变换、数据布局与硬件调度之间复杂的组合交互,导致大量性能潜力被白白浪费。
1.2 超优化技术的兴起与挑战
超优化技术源于编译器领域,其核心理念是通过自动搜索候选程序空间,找到与原始程序功能等价但执行速度更快的实现。与依赖规则的优化不同,超优化无需人工编写优化规则,而是通过自动探索来发现最优程序。
枚举式超优化是最早的超优化方法,它通过穷举所有可能的程序结构来寻找最优解。 TASO 是首个将超优化应用于张量程序的系统,它通过枚举小的计算子图,利用随机测试和形式验证来识别等价变换,进而优化目标程序。TASO 证明了超优化在张量程序优化中的有效性,但其优化范围仅限于图层面,无法深入到内核内部。
Mirage 则将超优化进一步扩展到 GPU 执行的多个层级(内核、线程块、线程),通过 μGraph 表示来统一不同层级的优化。μGraph 是一种层次化表示,它在三个层级上描述张量程序:内核图、线程块图和线程图。每个层级都通过 imap(输入映射)、fmap(循环映射)和 omap(输出映射)来指定张量在并行化维度上的分区或复制方式,以及网格、块和循环维度的具体大小。Mirage 的多级搜索能够协调代数变换与调度变换,合成全新的自定义内核,其性能远超传统编译器。
然而,枚举式超优化面临着组合爆炸这一根本性瓶颈。 随着算子数量与执行层级的增加,候选程序的数量呈指数级增长。例如,对于一个仅有两种并行化维度和两种数据维度的简单算子,就可能存在数十种映射组合,每种组合都对应一个独立的 μGraph。对于复杂的 LLM 工作负载,枚举式搜索很快就会变得不可行。 论文中提到,Mirage 在处理 RMSNorm-MLP 工作负载时,搜索一小时仍未完成。
为解决组合爆炸问题,研究人员提出了采样式超优化方法。AlphaEvolve 利用大语言模型引导进化搜索,能够探索更大的优化空间。 它采用一个进化框架,让大语言模型迭代地提出和改进代码候选,再由自动评估器进行验证。AlphaEvolve 在多种定义明确的任务上证明,LLM 引导的搜索能够发现超越人类工程师和先前自动化解决方案的优化。然而,采样式超优化将优化景观视为无结构的,导致搜索行为不稳定,且无法保证能覆盖到最优解。
传统优化范式受限于人类经验的有限性,枚举式超优化受限于组合爆炸,采样式超优化受限于最优性缺失。这三大瓶颈共同构成了当前张量程序优化的核心挑战。Prism 的符号化超优化正是为了同时突破这三大瓶颈而诞生。
二、Prism 的核心设计:符号化超优化的新范式
Prism 是世界上首个针对张量程序的符号超优化器。其核心创新在于提出了符号图(sGraph)表示,将整个程序族编码为单个符号结构,从而将优化过程分解为两级搜索:上层构建符号图,下层实例化为具体实现。
这种两级搜索的设计使 Prism 能够在符号层面进行结构化剪枝,在生成具体程序之前就剔除可证明的次优区域,同时保留最优性保证,完美地结合了枚举式搜索的严谨性和采样式搜索的可扩展性。
2.1 核心洞察:从枚举程序到编码程序族
Prism 的核心洞察在于:大多数候选程序之间存在高度的结构相似性,它们共享相同的计算结构,仅在并行化参数和数据映射上有所不同。与其枚举每个具体程序,不如用符号变量表示这些可变部分,将整个程序族编码为单个符号结构。
在传统的枚举式超优化器(如 Mirage)中,每个不同的映射和维度分配都会生成一个独立的 μGraph 候选,需要单独进行生成、验证和性能分析。例如,对于一个拥有两种并行化维度和两种数据维度的简单算子,就可能存在数十种映射组合,每种组合都对应一个独立的 μGraph。如果再考虑不同的并行化参数值,候选数量会进一步激增至数千甚至数百万。
而在 Prism 中,这些共享相同计算结构的程序被编码为一个 sGraph,其中并行化维度的大小和数据映射都用符号变量表示。一个 sGraph 可以代表数千甚至数百万个具体的 μGraph,从而将搜索空间的复杂度从 O(N * M * P) 降低到 O(N + M + P),其中 N 是图结构的数量,M 是映射分配的数量,P 是并行化参数配置的数量。
这种表示方式的优势显而易见:
- 大幅减少搜索空间:一个 sGraph 可以代表整个程序族,无需枚举每个具体程序。
- 符号层面的推理与剪枝:可以在符号层面进行形状匹配和等价性检查,剔除无效的候选,而无需生成具体程序。
- 正确性与性能解耦:一旦一个 sGraph 被验证为正确,其所有参数实例化都将自动正确,无需重新验证。
2.2 sGraph:符号化的层次化图表示
sGraph 是对 Mirage 的 μGraph 的符号化扩展。它保留了 μGraph 的层次化结构(内核图、线程块图、线程图),但将并行化参数和数据映射用符号变量表示。为理解 sGraph 的设计,我们首先回顾 GPU 编程模型和 μGraph 表示。
GPU 计算以核函数为单位。核函数启动时定义一个线程块网格,每个线程块被调度到一个流式多处理器(SM)上执行,内部包含一组线程。线程通过寄存器保存私有状态,同一线程块内的线程通过共享内存进行通信,而核函数之间的数据交换则通过全局内存进行。这种层次化的执行模型要求张量程序的表示能够捕捉不同层级的并行性。
Mirage 提出的 μGraph 正是为了满足这一需求。μGraph 是一种层次化表示,在三个层级上描述张量程序:
- 内核图:描述核函数之间的依赖关系。
- 线程块图:描述每个核函数在线程块层级的计算。
- 线程图:描述每个线程块在线程层级的计算。
在每个层级,μGraph 通过三个映射来指定张量如何在并行化维度上进行处理。
2.2 sGraph:符号化的核心抽象
在 Prism 的符号化框架中,三个核心映射关系被重新定义为符号化形式:
* imap:定义了输入张量如何被分割并分配到不同的并行化维度上。
* fmap:指定了循环计算的维度如何与并行化维度相对应。
* omap:描述了输出张量如何从各个并行化维度上组合起来。
除此之外,μGraph 还精确规定了网格、块以及循环维度的具体尺寸。例如,一个 μGraph 可能设定网格维度为特定值,这意味着输入张量的行维度被划分到了 64 个线程块上进行处理。
sGraph 对 μGraph 的符号化扩展,主要体现在以下两个关键方面:
-
符号化并行化参数:在传统的 μGraph 中,网格维度、块维度和循环维度的大小都是固定不变的整数。而在 sGraph 中,这些尺寸被抽象为符号整数变量。例如,网格维度的大小不再是具体的数字,而是一个可以变化的符号变量。
-
符号化映射关系:在 μGraph 中,imap、fmap 和 omap 都是具体的映射关系(如
imap(A, row) = grid.x)。而在 sGraph 中,这些映射关系通过布尔变量来表达。针对每个张量T、每个数据维度d和每个并行化维度p,系统会引入一个布尔变量m_{T,d,p}。当m_{T,d,p} = 1时,表示数据维度d沿着并行化维度p进行分区;当所有相关的m变量都为 0 时,则表示张量T在并行化维度p上是完全复制的。
这些符号变量必须满足两个基本约束,以确保映射的有效性:
其中,第一个约束规定,对于给定的张量 T,每个并行化维度 p 最多只能对其一个数据维度进行分区。第二个约束则要求,每个数据维度 d 最多只能被一个网格维度分区。这两个约束共同确保了张量的分区方式是明确且不会产生冲突的。
基于这些符号变量,张量的形状也相应地变成了符号表达式。对于一个原始大小为 D 的数据维度 d,它在每个线程块、每次迭代中的大小可以表示为:
size = D / (∏_{p} (1 - m_{T,d,p} + m_{T,d,p} * factor_p))
这里的 factor_p 代表数据维度 d 被并行化维度 p 分区的总倍数。
- 当
m_{T,d,p} = 1时,乘积项为factor_p,表示数据维度d被并行化维度p分区。 - 当
m_{T,d,p} = 0时,乘积项为1,表示该维度不被并行化维度p分区。
由于约束 (2) 确保了每个数据维度最多能被一个网格维度和一个循环维度分区,因此最终的乘积中最多包含两个非 1 的项。
为了更直观地理解 sGraph 与 μGraph 的区别,我们来看论文中的图 1:
图 1 | 融合Softmax-矩阵乘法运算的图表示。(a) 输入计算图;(b) 带有特定映射和并行化参数的具体μGraph;(c) 本文提出的符号图(sGraph),其映射和维度以符号变量表示。该直观呈现传统具体图与Prism核心创新sGraph的本质差异,揭示符号化设计的核心价值。传统μGraph将网格维度、块维度、张量映射固化为具体数值,每个并行配置对应独立图,直接导致搜索空间随参数数量呈组合爆炸。而sGraph把并行化参数、映射关系抽象为符号变量,单张图即可代表一族功能等价的张量程序,无需逐一枚举具体配置。这种设计让Prism能在符号层面开展逻辑推理,提前剪枝次优搜索区域,既规避枚举式方法的效率瓶颈,又为后续跨配置通用优化奠定基础,是符号化超优化区别于过往方法的关键突破。表 1 | sGraph等价性验证所用的部分等价公理。符号说明:t表示张量,v表示(批处理)向量,d表示数据维度,p表示并行化维度。该表汇总Prism验证阶段的核心等价公理,是保障符号图功能正确性的形式化基石。公理覆盖矩阵乘法代数性质、并行算子交换律、抵消恒等式、并行化算子交互规则四大类,共约70条规则,精准刻画张量运算与GPU并行逻辑的数学特性。例如矩阵乘法分配律、part与repl交换律、comb抵消part恒等式等,让e-图能通过定向重写规则推导表达式等价性。公理设计秉持“实用优先、兼顾严谨”原则,不追求理论完备性但覆盖关键优化场景,既避免规则冗余拖慢效率,又保障验证结果可靠,支撑高效符号等价检查。
从图 1 中可以看到,(a) 是原始的计算图,包含了 Softmax 和 Matmul 这两个算子;(b) 是一个具体的 μGraph,其中网格维度为某个值,循环维度为另一个值,映射关系是 imap(A, row) = grid.x 和 omap(C, row) = grid.x;(c) 是对应的 sGraph,其中网格维度和循环维度的大小被符号变量 D1 和 D2 所替代,映射关系则用布尔变量 m_{A,row,grid.x}、m_{C,row,grid.x} 等来表示。在这个 sGraph 中,任何一组满足约束条件的符号变量赋值,都能生成一个有效的 μGraph。例如,对符号变量赋值 D1=64、D2=1、m_{A,row,grid.x}=1、m_{C,row,grid.x}=1 等,就会生成图 (b) 中的具体 μGraph。
通过将并行化参数和数据映射符号化,sGraph 实现了对整个程序族的紧凑编码。这种表示方式使得 Prism 能够在符号层面进行推理和剪枝,而无需逐一生成和验证每个具体的程序实例,这正是 Prism 能够突破组合爆炸瓶颈的核心所在。
2.3 Prism 的整体工作流程
Prism 的整体工作流程分为四个主要阶段:sGraph 生成、映射实例化、sGraph 验证和参数实例化,如图 2 所示:
图 2 | Prism系统流程概览。sGraph生成:穷举搜索构建含符号映射的sGraph,通过维度匹配与表达式引导剪枝剔除无效分支;映射实例化:枚举满足所有约束的候选具体映射赋值;sGraph验证:基于重写公理进行等价性检查;参数实例化:结合GPU性能剖析的随机采样调优并行化参数。该图清晰拆解Prism四级核心流程,各环节解耦又协同,兼顾优化严谨性与可扩展性。生成阶段通过符号化抽象,避免枚举式搜索的无效开销,双重剪枝技术进一步压缩搜索范围;映射实例化筛选合法赋值,过滤无意义候选;验证阶段依托e-图与等价公理,保障符号图与原图功能完全一致;参数实例化复用成熟自动调优策略,平衡效率与性能。该流程将符号推理、形式化验证、性能调优深度融合,既保留穷举搜索的正确性保障,又解决传统超优化器难以适配大模型复杂工作负载的痛点,构成Prism高效优化的核心架构支撑。
- sGraph 生成:Prism 以输入的张量程序为起点,通过逐步添加算子的方式构建候选的 sGraph。在构建过程中,系统会运用符号维度匹配和表达式引导的剪枝技术,剔除掉那些无效的部分 sGraph,从而大幅缩减搜索空间。
- 映射实例化:对于所有通过了剪枝的候选 sGraph,系统会枚举出所有满足约束条件的具体映射分配(即为布尔变量
m赋值)。同时,通过对称性破缺技术来消除冗余的映射。 - sGraph 验证:针对每个具有具体映射的 sGraph,系统会使用基于 e-graph 的等价性检查,来验证其与输入程序在功能上是否等价。由于验证过程不依赖于具体的并行化参数值,因此一次验证即可覆盖所有可能的参数实例化情况。
图 3 | 含并行化维度的张量表示。该图是理解sGraph张量建模的基础可视化工具,直观呈现GPU并行场景下张量数据维度与并行维度的关联逻辑。GPU采用SPMD并行模式,张量需沿网格、线程块等并行维度划分或复制,此图清晰展示单数据维度张量与并行维度的映射关系。在sGraph中,张量形状由符号化映射变量和并行参数共同推导,图3的抽象模型为后续定义符号化张量尺寸、推导维度匹配约束提供底层依据。该建模方式让Prism能统一处理不同并行策略下的张量形态,为符号化维度匹配、等价性验证提供标准化建模基础,是系统实现的关键前置设计。图 4 | sGraph验证所用的并行算子。该图定义了sGraph等价性验证的四类核心并行算子,是实现符号图形式化验证的核心工具集。part算子负责沿并行维度划分张量数据维度,comb算子反向合并划分后的张量块,red算子沿并行维度做元素归约,repl算子复制张量至所有并行单元,四类算子覆盖GPU并行中张量分布、重组、归约、复制的全场景核心操作。Prism将复杂并行逻辑拆解为基础算子组合,再结合等价公理验证逻辑正确性,简化了并行逻辑的形式化描述难度,让e-图能高效处理符号图等价性检查,兼顾验证效率与严谨性。
- 参数实例化:对于所有通过验证的 sGraph,系统会采用随机采样并结合 GPU 性能分析的方法,来调整并行化参数的具体数值(即为符号变量
d赋值),以找到性能最优的具体实现方案。
这种两级搜索的设计,将正确性验证与性能调优完全解耦:一旦一个 sGraph 被验证为正确,其所有可能的参数实例化都将是正确的,无需重新验证。这不仅极大地降低了验证的开销,还使得 Prism 能够轻松适应不同的硬件配置和输入规模。例如,当输入张量的大小发生变化时,系统只需重新运行参数实例化阶段,而无需再次进行符号搜索和验证。
三、sGraph 生成:符号化搜索与剪枝
sGraph 生成:符号层面的高效搜索
sGraph 的生成构成了 Prism 两级搜索架构中的顶层阶段,其核心目标是以最高效率探索所有可行的 sGraph 结构及其对应的正确映射方案。与传统的暴力枚举搜索策略截然不同,Prism 将搜索提升至符号层面,巧妙地解耦了图结构搜索、映射枚举以及参数调优这三个环节,从而在早期阶段就有效规避了组合爆炸的风险。
为了在浩瀚的符号搜索空间中实现高效导航,Prism 引入了两项相辅相成的剪枝技术:符号维度匹配与表达式引导的剪枝。这两项技术能够在具体程序被生成之前,就果断剔除大量无效的候选方案,从而显著提升整体搜索效率。
3.1 符号维度匹配
在 sGraph 的语境下,张量的形状被抽象为符号表达式。因此,算子之间的形状兼容性要求,实际上转化为了对符号变量施加的约束条件。例如,矩阵乘法运算要求左输入矩阵的列维度必须与右输入矩阵的行维度相等。
- 在具体的、非符号化的图表示中,检查形状兼容性仅需比对两个整数是否相等;
- 而在符号化的图表示中,这要求我们验证两个符号表达式是否在数学上等价。
符号维度匹配的核心思想在于:形状的兼容性必须对所有可能的并行化参数取值都成立。因此,那些需要匹配的维度表达式,作为 的函数,必须完全一致。这实际上将形状匹配问题简化为仅对映射变量 施加的约束。其根本原因在于,如果两个表达式对于 的所有可能取值都相等,那么它们的系数必然完全相同。
每当向一个部分构建的 sGraph 中添加一个新的算子时,Prism 会执行两项关键任务:
- 收集与映射变量相关的等式约束,这些约束将在后续的“映射实例化”阶段被强制执行。
- 检查新生成的维度表达式是否与现有结构兼容,并立即剔除那些不兼容的部分 sGraph。
该图展示了一个自定义算子的符号化图(sGraph)表示:左侧是高层Kernel Graph,定义了输入X(4096×4096)、W(4096×128)与输出O(4096×128);右侧是线程块级ThreadBlock Graph,通过Input Loader1/2加载数据,依次进行指数运算、矩阵乘法及累加,再由除法融合Softmax与矩阵乘逻辑,最终通过Output Saver写回结果。图中采用σ(T,d)的符号化公式刻画分块维度,并借助imap/fmap/omap映射参数,实现了算子从高层逻辑到线程块调度的抽象表达,兼顾了计算流程与并行分块策略。
以图 1(c)中的 Matmul 算子为例,其左输入来自 Exp 算子的输出(继承了 InputLoader1 的形状),右输入则来自 InputLoader2(张量 W,原始形状为[4096,128])。为了使收缩维度匹配,Exp 算子的 c 维度必须与 InputLoader2 的 r 维度相等:
等式两边同时乘以 ,得到 。由于 和 都是关于 的函数,该等式对所有 成立,当且仅当二者作为 (或相应形式)的线性表达式拥有完全相同的系数。
具体而言, 包含项 和 ,而 包含项 和 。通过比较 与 的系数,我们可以得到如下约束:
这些约束将在映射实例化阶段被强制执行,确保只有满足这些条件的映射才会被纳入考量。如果在应用这些约束后,维度表达式仍然不兼容,则该部分 sGraph 会立即被剪枝。
符号维度匹配的优势在于,它能在图生成的早期阶段就剔除大量形状不兼容的候选方案,无需等到映射实例化阶段再进行判断。这极大地减少了需要处理的候选数量,从而提升了搜索效率。
3.2 表达式引导的剪枝
符号维度匹配虽然能剔除形状不兼容的候选,但无法识别那些形状兼容、却在功能上不可能与输入程序等价的候选。例如,一个计算 的图与一个计算 的图,它们的形状可能是兼容的,但功能完全不同。为了解决这个问题,Prism 引入了表达式引导的剪枝技术。
表达式引导的剪枝基于一个核心观察:任何可行的 Graph 的完整形式,必须对所有并行化参数值 都满足表达式检查。因此,在某个特定的 赋值下检查表达式条件,可以得到该图可行性的一个必要条件。如果一个部分 Graph 在某个 赋值下无法通过表达式检查,那么它就不可能被补全成一个可行的 Graph。
Prism 选择 的赋值。在此赋值下, 对所有张量 和维度 成立,张量形状与映射变量 无关,部分 Graph 也因此简化为一个具有具体形状的非符号化图。在这个简化后的图上,Prism 应用了 Mirage 中的抽象表达式检查:检查每个中间张量的抽象表达式,是否是最终输出表达式 的子表达式(即在语法结构上构成其子树)。如果不是,则该部分 Graph 不可能被补全成一个可行的 Graph,因此会被剪枝。
抽象表达式检查的基本逻辑是:如果一个中间张量的表达式不是 的子表达式,那么它就不可能对最终输出产生语义上的贡献。因此,包含该中间张量的图不可能等价于输入程序。举例来说,若最终输出为 ,而某中间张量的表达式为 ,则 不是 的子表达式,对应的部分 Graph 将被剪枝。
这种剪枝技术是“欠剪枝”的:它永远不会丢弃那些可能导向可行 sGraph 的部分图,但可能会保留一些实际上不可行的候选方案。这些残留的不可行方案会在后续的映射实例化和验证阶段被过滤掉。在实践中,这种计算成本低廉的检查能够剪枝掉绝大部分的搜索空间。论文指出,表达式引导的剪枝可以将搜索空间缩减几个数量级。
3.3 映射实例化与对称性破缺
在生成了通过剪枝的候选 sGraph 之后,Prism 需要枚举所有满足约束的具体映射分配(即为布尔变量 m 赋值)。一个有效的映射必须满足两类约束:
-
线性约束:源自公式(1)和(2),要求每个并行化维度最多映射到一个数据维度,同时每个数据维度最多被一个网格维度映射。
-
等式约束:源自符号维度匹配过程,要求匹配的维度必须以完全相同的方式进行分区。
枚举所有可能的映射分配的复杂度为 ,其中 是张量的数量, 是每个张量的数据维度数量, 是并行化维度的数量。 对于复杂的工作负载,这个数值可能相当庞大,但由于先前的剪枝技术已经剔除了大部分无效的 sGraph,实际需要枚举的映射数量是可控的。
为了进一步减少需要验证的候选数量,Prism 引入了对称性破缺技术 。不同的映射分配,如果仅仅是在并行化维度的排列顺序上有所不同,那么它们会生成功能上等价的 sGraph。例如,交换网格维度 和 ,并相应地调整所有映射变量,会生成一个功能完全相同的内核。为了消除这种冗余的验证工作,Prism 只保留每个等价类中字典序最小的分配。这可以将候选数量减少最多 倍,其中 是网格维度的数量。
例如,对于拥有两个网格维度 和 的情况,对称性破缺会要求 对所有张量 和数据维度 成立。这确保了只有字典序最小的映射分配会被保留,从而将候选数量减少一半。
sGraph 生成阶段通过符号维度匹配和表达式引导的剪枝,在符号层面大幅缩减了搜索空间 。而映射实例化阶段则在保证正确性的前提下,借助对称性破缺技术,进一步减少了需要验证的候选数量。这一系列技术使得 Prism 能够高效地搜索到可行的 sGraph,成功突破了传统枚举式搜索所面临的组合爆炸瓶颈。
四、sGraph 验证:基于 e-graph 的符号等价性检查
sGraph 验证是 Prism 确保正确性的核心环节,其目标是验证一个拥有具体映射但参数为符号化的 sGraph,是否与输入程序在功能上等价。与传统的随机测试不同,Prism 的验证过程不依赖于具体的张量形状或并行化参数值,一次验证即可覆盖所有可能的参数实例化。
为了实现这一目标, Prism 将输入程序和候选的 sGraph 都编码为表达式,然后利用 e-graph(等价图)在一组代数公理下检查它们的等价性 。这种方法结合了形式化验证的严谨性与符号推理的可扩展性。
4.1 表达式语言与并行化算子
为了对 sGraph 的语义进行编码,Prism 定义了一种表达式语言。该语言不仅包含标准的张量算子(例如 matmul、add、exp、div 等),还引入了四个专门的并行化算子。这四个并行化算子完整地描述了张量在 GPU 并行执行过程中所涉及的分区、组合、归约和复制操作。
核心并行化算子
Prism 定义了四个基础并行化算子,用于在符号层面描述张量在 GPU 线程块间的分布与通信:
part(t, m, x):将张量t的第m个数据维度均匀切分成多个等长块,并将这些块沿并行化维度x进行分布。每个线程块仅处理分配给它的那一块数据。comb(t, m, x):作为part的逆操作,它将张量t中沿并行化维度x分布、且属于第m个数据维度的多个块拼接起来,重建出原始的完整维度。red(t, x):沿并行化维度x对张量t执行逐元素的求和归约。每个线程块先独立计算其局部和,再通过全局规约得到最终的标量或降维结果。repl(t, x):将张量t完整复制到并行化维度x的所有位置上,确保每个线程块都能访问到t的全部内容。
将 sGraph 编码为表达式
借助这四个算子,Prism 能够将任意一个 sGraph 编码成一个表达式。编码过程遵循拓扑顺序遍历整个图结构,并为其中的每个张量计算出其对应的表达式。对于大多数算子(如逐元素算子或矩阵乘法),这个过程是直接的:输出表达式就是将算子应用于输入表达式的结果。
关键的编码环节在于线程块图中的 InputLoader 和 OutputSaver:
- InputLoader:根据
imap映射,应用part或repl操作。如果映射非空(例如{row}),则应用part操作,将输入张量t的行维度row分区到并行化维度x上;如果映射为空集{},则应用repl操作,将输入张量t复制到所有线程块。 - OutputSaver:根据
omap映射,应用comb操作。如果映射非空(例如{row}),则将各线程块输出的张量t在并行化维度x上沿行维度row进行拼接,从而重建完整的行维度。
图 5 展示了一个简单的例子:
图 5 | 将 sGraph 编码为表达式。图中每个张量都标注了其对应的表达式;输入加载器根据 imap 执行 part 操作,输出保存器则根据 omap 执行 comb 操作。此图清晰地展示了从 sGraph 到表达式的编码逻辑,是连接符号图结构与形式化等价验证的关键桥梁。Prism 不直接验证图结构的等价性,而是将算子运算和并行映射转化为包含四类并行算子的表达式,从而将图等价性问题转化为表达式等价性问题。其中,输入加载器和输出保存器分别对应 part 和 comb 算子,中间算子则继承表达式的运算逻辑,最终形成一条完整的表达式链。这种编码方式能够直接复用 e-图的重写规则体系,摆脱传统随机测试对具体张量形状的依赖,在符号层面实现通用的等价验证,显著提升了验证效率与通用性。
在这个例子中,内核图包含一个执行逐元素指数运算的 CustomOp,其 imap 为 {row},omap 为 {row}。InputLoader 将输入变量 x 的行维度在并行化维度 p 上进行分区,得到 part(x, row, p)。应用 exp 算子后得到 exp(part(x, row, p))。OutputSaver 将结果在并行化维度 p 上连接起来,得到最终表达式 comb(exp(part(x, row, p)), row, p)。根据表 1 中的取消公理 comb(part(t, d, p), d, p) = t,该表达式等价于 exp(x),从而验证了该 sGraph 的正确性。
4.2 等价公理与 e-graph 重写
Prism 定义了一套包含约 70 条等价公理的系统,这些公理精确捕捉了张量算子和并行化算子的数学性质。表 1 列出了其中部分重要的公理:
表 1 | sGraph 等价性验证所用的部分等价公理。符号说明:t 表示张量,v 表示(批处理)向量,d 表示数据维度,p 表示并行化维度。该表汇总了 Prism 验证阶段的核心等价公理,是保障符号图功能正确性的形式化基石。这些公理覆盖了矩阵乘法代数性质、并行算子交换律、抵消恒等式以及并行化算子交互规则四大类,共约 70 条规则,精准刻画了张量运算与 GPU 并行逻辑的数学特性。例如,矩阵乘法的分配律、part 与 repl 的交换律、comb 抵消 part 的恒等式等,使得 e-图能够通过定向重写规则推导出表达式的等价性。公理的设计秉持“实用优先、兼顾严谨”的原则,虽不追求理论上的完备性,但覆盖了所有关键优化场景,既避免了规则冗余拖慢效率,又保障了验证结果的可靠性,是实现高效符号等价检查的基础。
这些公理可分为以下几类:
- 矩阵乘法的代数性质:包括结合律
(A * B) * C = A * (B * C)、分配律A * (B + C) = A * B + A * C、与标量乘法的兼容性(a * A) * B = a * (A * B),以及与标量除法(非零标量)的兼容性(A / a) * B = (A * B) / a。 - 并行化算子的交换性:例如
part(repl(t, d1, p), d2, p) = repl(part(t, d2, p), d1, p)(当d1 != d2)、part(part(t, d1, p), d2, p) = part(part(t, d2, p), d1, p)(当d1 != d2)等。这些公理刻画了不同并行化算子之间的可交换条件。 - 取消恒等式:如
comb(part(t, d, p), d, p) = t、red(repl(t, p), p) = sum(t)(在适当维度约束下)。这些公理刻画了分区与组合、复制与归约之间的互逆关系。 - 并行化矩阵乘法:
matmul(part(A, col, p), part(B, row, p), p) = red(matmul(A, B), p),其中col和row分别表示对左矩阵的列、右矩阵的行沿并行维度p分区,再沿p进行归约收缩。 - 并行化求和:
sum(A, dim=d) = red(part(A, d, p), p)(沿d维求和),等价于先分区后归约。 - 逐元素算子与并行化算子的交换性:
exp(part(t, d, p)) = part(exp(t), d, p)、add(part(t1, d, p), part(t2, d, p)) = part(add(t1, t2), d, p)。
Prism 使用 egg 库实现 e-graph 重写,将这些公理转换为重写规则。e-graph 是一种高效的数据结构,用于表示表达式之间的等价关系。它将每个表达式表示为一个节点,等价的节点被分组到同一个等价类中。通过应用重写规则,e-graph 不断扩展等价类,直到达到不动点。
在验证过程中,Prism 将输入程序的表达式和候选 sGraph 的表达式都添加到 e-graph 中,然后应用所有重写规则。如果两个表达式最终属于同一个等价类,则证明它们是功能等价的。这种方法的优势在于它能够自动探索所有可能的等价变换,而无需人类指定变换的顺序。
4.3 公理适配的特殊处理
e-graph 重写要求每个重写规则
LHS → RHS的右侧只能引入左侧已经出现的变量。这对于双向公理来说意味着只能在满足这个约束的方向上应用。例如,公理comb(part(t, d, p), d, p) = t只能应用为comb(part(t, d, p), d, p) → t,而不能应用为t → comb(part(t, d, p), d, p),因为后者会在右侧引入新的变量d和p。
这个约束在处理多个连续并行化算子时会带来挑战。例如,对于连续的并行化矩阵乘法 matmul(part(A, col, p1), part(B, row, p1), p1),仅使用从左到右的并行化 matmul 公理无法验证这个等价性,因为嵌套的算子无法一次剥离一个。为了解决这个问题,Prism 为与 matmul 相关的公理引入了“逆”重写规则:
red(matmul(A, B), p) → matmul(part(A, col, p), part(B, row, p), p)
其中 col 和 row 表示并行化算子的逆(如 part 是 comb 的逆,反之亦然)。这个规则满足变量子集约束,能够将并行化算子向外推,从而使 e-graph 能够建立等价性。Prism 对其他与并行化算子有类似交互的计算算子也应用了相同的技术。
基于 e-graph 的符号等价性检查是 Prism 保证正确性的核心。它不依赖于具体的输入值或并行化参数,一次验证即可覆盖所有可能的实例化。这不仅大幅减少了验证开销,还保证了生成的内核在任何硬件配置和输入大小下都是正确的。
五、参数实例化:自动调优与性能优化
参数实例化是 Prism 两级搜索的下层阶段,其目标是为通过验证的 sGraph 找到最优的并行化参数值,以最大化内核在目标硬件上的性能。这是一个标准的自动调优问题:给定一组参数化的内核模板和目标硬件平台,找到使执行时间最小的模板和参数值。
Prism 采用随机采样的方法进行参数调优,这种方法能够最大化编译并行性,避免迭代方法的长依赖链,同时在实践中能够找到接近最优的参数配置。
5.1 自动调优的搜索空间
参数实例化的搜索空间包括每个并行化参数的有效值:网格维度大小和循环迭代次数。这些参数必须满足一系列硬件约束和性能约束:
- 硬件限制:网格维度的大小不能超过 GPU 支持的最大网格维度(A100 是 2^31-1),线程块的大小不能超过 GPU 支持的最大线程块大小(通常是 1024)。
- 共享内存约束:每个线程块使用的共享内存大小不能超过 GPU 每个 SM 的共享内存容量(A100 是 16KB 每个线程块,最多 48KB 每个 SM)。如果共享内存使用量超过限制,内核将无法启动。
- 张量大小约束:并行化参数的取值必须能够整除张量的对应维度,以确保分区是均匀的。例如,如果张量的行维度大小是 4096,那么网格维度 x 的大小必须是 4096 的约数。
对于每个 sGraph,并行化参数的取值范围受到上述约束的限制。虽然理论上可能的参数配置数量仍然很大,但由于前面的阶段已经将 sGraph 的数量减少到很小的范围(每个工作负载 9-23 个),因此实际需要调优的参数配置数量是可控的。
5.2 随机采样与并行编译
在自动调优流程中,最消耗计算资源的环节当属内核编译与性能剖析。一个复杂的 GPU 内核往往需要花费数秒乃至数十秒才能完成编译,而性能剖析则必须让内核多次运行,才能获取到精确的执行耗时。因此,自动调优方法的整体效率,在很大程度上取决于它能否将编译与性能分析这两个阶段进行高效的并行化处理。
与进化搜索、模拟退火这类迭代式方法截然不同,随机采样无需等待前一批次的结果,便可直接生成下一批次的候选方案。这一特性使得随机采样能够最大化编译过程的并行性,从而充分挖掘多核 CPU 与多 GPU 的计算潜力。
Prism 的参数实例化流程如下:
- 针对每一个通过验证的 sGraph,生成所有满足硬件约束、共享内存限制以及张量尺寸要求的并行化参数配置。
- 从这些配置中,通过均匀随机采样的方式挑选出一定数量的候选方案(论文中未明确说明具体的采样数量,但指出随机采样在实践中足以找到接近最优的配置)。
- 将所有候选内核并行编译,并在目标 GPU 上执行性能剖析。每个内核均运行 1000 次,取平均执行时间作为最终指标。
- 返回执行时间最短的内核作为最终成果。
论文同时指出,集成更复杂的调优策略(例如学习成本模型、进化搜索等)是未来的研究方向。但在当前的实现中,随机采样已经能够找到性能出色的内核,同时将调优开销维持在较低水平。这是因为 Prism 的符号搜索已经将搜索空间大幅缩小至一个很小的范围,剩余的参数调优问题相对简单。
参数实例化阶段的核心任务,是将符号化的 sGraph 转化为具体的、可执行的内核。通过随机采样与并行编译的结合,Prism 能够在较短的时间内找到接近最优的参数配置,从而实现了符号推理与实际硬件性能的完美融合。
六、实验评估与结果分析
Prism 的实验评估在 NVIDIA A100 GPU 上进行,并与五个主流 LLM 工作负载进行了对比,包括 RMSNorm、RMSNorm-MLP、SwiGLU、Attention 和 QK-Attention。评估指标涵盖内核执行时间与端到端优化时间 ,对比基线包括 PyTorch Eager、PyTorch Compiled、TVM(Ansor)以及当前最优的超优化器 Mirage。
实验数据显示,Prism 在所有 10 个配置上均取得了最优的内核性能 ,并且在大多数情况下大幅缩短了优化时间,这充分验证了符号化超优化的有效性。
6.1 实验设置
- 硬件平台:采用双 Intel Xeon Platinum 8275CL CPU(48 核,96 线程)进行符号搜索,NVIDIA A100 GPU 用于内核性能剖析。
- 软件版本:PyTorch 2.5.1,Triton 3.1.0,Apache TVM 0.18.0(使用 Ansor 自动调度器,每个工作负载进行 1000 次调优试验)。
- 工作负载:选取了五个 LLM 中常见的工作负载,每个工作负载评估两种不同的输入配置:
- RMSNorm:融合归一化与线性层,,变化隐藏维度 和批量大小 。
- RMSNorm-MLP:采用融合归一化的 GLU 风格门控 MLP,,固定 ,变化 。
- SwiGLU:LLaMA 风格模型中使用的门控激活函数,,固定 ,变化 。
- Attention:解码阶段的组查询注意力(GQA),,固定批量大小 ,头数 ,查询序列长度为 1,头维度 ,变化键值序列长度 。
- QK-Attention:带有查询-键归一化的 GQA,,配置与 Attention 相同。
- 评估指标:
- 内核执行时间:每个内核运行 1000 次,取平均执行时间。
- 总优化时间:包括搜索时间与性能剖析时间。对 Prism 而言,搜索时间涵盖符号图生成、映射枚举与验证(每个工作负载运行一次,所有配置共享),实例化时间则包括参数调优与性能剖析(每个配置单独运行)。对 Mirage,搜索时间是每个配置单独的图生成与映射枚举时间,性能剖析时间是每个配置的内核编译与基准测试时间。对 TVM,则是 Ansor 的自动调优时间(每个配置 1000 次试验)。
6.2 内核性能结果
图 6 的上半部分展示了 Prism 与所有基线的内核执行时间对比:
图 6 | 5种工作负载下的内核性能与优化时间。上图:相对内核执行时间;下图:总优化时间拆解——Mirage时间含图生成与性能剖析,Prism时间含符号图生成与实例化,TVM时间为Ansor自动调度时间(1000次试调)。该图从内核性能、优化效率双维度,全面验证Prism相较于PyTorch、TVM、Mirage等基线的优势。性能层面,Prism在LLM主流工作负载中均取得最优结果,较先进超优化器Mirage提速最高2.2倍,较传统编译器提速最高4.9倍,注意力类任务提升最显著,因符号化能探索更多可行并行策略。效率层面,Prism单次搜索适配多配置,较Mirage优化时间最高减少3.4倍,虽简单任务存在固定实例化开销,但内核性能提升幅度远覆盖开销。结果充分证明,符号化超优化可兼顾高性能与高效率,适配现代大模型复杂张量程序的优化需求。
实验结果总结如下:
- Prism 在所有 10 个配置上均取得了最优的内核性能,优于 PyTorch Eager、PyTorch Compiled、TVM 以及 Mirage。这表明符号化超优化能够发现比现有方法更优的程序实现。
- 与传统编译器方法相比,Prism 的优势最为显著:以 RMSNorm-MLP(d=1024, n=8)为例,Prism 比 PyTorch Compiled 快 4.9 倍,比 TVM 快 5.4 倍。这种优势源于超优化技术能够发掘新颖的融合内核,将多个算子合并为单个 GPU 内核,从而减少内存流量与内核启动开销。传统编译器的融合启发式策略无法发现这些复杂的融合模式。
- 与当前最优的超优化器 Mirage 相比,Prism 在 8 个配置上找到了严格更优的内核,在 2 个配置(SwiGLU)上与 Mirage 持平。最大的性能提升出现在注意力工作负载上:
- 在 QK-Attention 上,Prism 比 Mirage 快 1.8 倍(h=1024)和 2.2 倍(h=2048)。
- 在标准 Attention 上,Prism 比 Mirage 快 1.2 倍(h=1024)和 1.3 倍(h=2048)。
- 在 RMSNorm-MLP 上,Prism 比 Mirage 快 1.2 倍(n=16)和 1.9 倍(n=8)。值得注意的是,在这两个配置下,Mirage 的具体搜索在一小时后超时,而 Prism 成功完成了搜索并找到了更优的内核。
- 在 RMSNorm 上,Prism 比 Mirage 快 1.2 倍(d=4096)和 1.1 倍(d=1024)。
- 在 SwiGLU 上,Prism 和 Mirage 找到了相同的内核在 RMSNorm-MLP 上,Prism 比 Mirage 快 ,这是因为其简单的图结构只有较少的映射选择,Mirage 的枚举式搜索能够覆盖整个搜索空间。
注意力工作负载上的巨大提升,源于其 3D 张量结构(批量、序列、头),这提供了大量可能的并行化策略。Mirage 使用启发式方法只探索了这些映射和并行化参数的一个子集, 而 Prism 则通过符号搜索探索了整个空间,从而发现了 Mirage 遗漏的更优策略 。QK-Attention 比标准 Attention 受益更多,因为额外的归一化算子进一步拓宽了有用并行化策略的空间。
6.3 总优化时间结果
图 6 的下半部分展示了 Prism 与 Mirage 和 TVM 的总优化时间对比:
图 6 | 五种负载下的内核性能与优化耗时对比。上图:相对内核执行时间;下图:总体优化时间分解——Mirage耗时包含图生成与性能剖析,Prism耗时涵盖符号图生成与实例化,TVM耗时则为Ansor自动调度时间(经1000次调优尝试)。该图表从内核性能与优化效率两大维度,系统性地验证了Prism相较于PyTorch、TVM及Mirage等基准方案的显著优势。在性能方面,Prism在所有LLM主流工作负载中均取得了最优结果,相比先进超优化器Mirage,提速最高达2.2倍;相较于传统编译器,提速最高达4.9倍。其中,注意力类任务的性能提升最为突出,因为符号化方法能够探索更多可行的并行策略。在效率方面,Prism通过单次搜索适配多种配置,相较于Mirage,优化时间最多可减少3.4倍。尽管在处理简单任务时存在固定的实例化开销,但内核性能的提升幅度足以完全覆盖这部分额外成本。这些结果有力地证明,符号化超优化能够同时实现高性能与高效率,完美契合现代大模型复杂张量程序的优化需求。
实验结果总结:
- Prism 在 RMSNorm-MLP 任务上实现了最大的优化时间缩减:Mirage 的搜索在一小时后超时(总耗时分别为 3713 秒和 3632 秒),而 Prism 仅用 1111 秒和 1180 秒便完成了任务,速度快了 3.1 至 3.4 倍,同时还发现了速度快 1.2 至 1.9 倍的内核。这充分验证了符号化超优化在复杂工作负载上的可扩展性优势。
- 在 QK-Attention(h=1024)任务上,Prism 耗时 128 秒,而 Mirage 耗时 199 秒,TVM 耗时 276 秒。同时,Prism 的内核比 Mirage 快 1.8 倍,比 TVM 快 2.8 倍。
- 在某些配置下,Prism 的总耗时高于 Mirage:
- RMSNorm(d=4096):Prism 耗时 135 秒,Mirage 耗时 52 秒。
- Attention(h=2048)和 QK-Attention(h=2048):Prism 耗时约 152 秒,Mirage 耗时 13 秒。
- 这种差异源于 Prism 实例化阶段存在固定开销,需要编译并分析所有已发现的图模板。当 Mirage 针对某个配置的搜索本身已经很快时,这部分固定开销就会占据主导地位。然而,即便在这些情况下,Prism 依然找到了更优的内核(RMSNorm 快 1.2 倍,Attention 和 QK-Attention 分别快 1.3 倍和 2.2 倍),因此,额外的优化时间转化为了更快的端到端推理。对于生产环境中的部署而言,这种一次性的优化开销是完全值得的,因为内核会被执行数百万次。
6.4 搜索时间分解与图多样性
表 2 展示了 Prism 和 Mirage 在仅搜索时间上的对比:
表 2 | 仅搜索耗时对比(秒)。Prism 单次搜索适配单一工作负载(所有配置共享),Mirage 则按配置独立搜索;“×”表示搜索 1 小时仍未完成。该表量化对比了 Prism 与 Mirage 的搜索效率,直观凸显了符号化搜索的可扩展性优势。Mirage 采用具体枚举策略,为每个输入配置独立开展图搜索,简单任务耗时数十秒,而复杂任务(如 RMSNorm-MLP)在 1 小时内仍无法完成搜索。相比之下,Prism 的符号化搜索单次即可覆盖同一工作负载的所有配置,耗时仅为 0.3 至 871 秒,在复杂任务上的效率提升超过 300 倍。在注意力任务上,尽管两者搜索耗时接近,但 Prism 能够探索完整的搜索空间,从而发现 Mirage 遗漏的最优内核。数据表明,符号化方法通过解耦结构搜索与参数枚举,从根本上解决了枚举式搜索所面临的组合爆炸难题。
关键发现:
- Prism 的搜索为每个工作负载运行一次,并覆盖所有输入配置,而 Mirage 则需要为每个配置单独进行搜索。这是 Prism 的一个关键优势,因为在实际应用中,一个工作负载可能对应多个不同的输入配置。
- 在 RMSNorm 和 SwiGLU 任务上,Prism 的符号搜索速度极快:分别仅需 0.3 秒和 1.0 秒,而 Mirage 每个配置则需要 11 至 46 秒。这种加速源于将图结构搜索与映射枚举解耦,从而避免了在每一步都尝试所有可能的 imap、fmap 和 omap 分配所导致的组合爆炸。
- 在 RMSNorm-MLP 任务上,Mirage 的具体搜索在两个配置上均达到了一小时超时,而 Prism 在 871 秒内完成。RMSNorm-MLP 融合了两个矩阵乘法、归一化和门控乘法,产生了多种有效的算子执行顺序,在具体搜索中,每个映射分配都需要探索这些顺序,从而引发了组合爆炸。
- 在注意力工作负载上,搜索时间呈现出不同的情况。注意力具有受限的图结构,但由于其 3D 张量结构(批量、序列、头),映射空间非常庞大。 Mirage 使用启发式方法,仅探索可能映射和并行化参数的一个子集,因此每个配置的搜索速度很快(Attention 为 10 至 42 秒,QK-Attention 为 10 至 155 秒)。 而 Prism 的符号搜索一次性为所有配置运行,耗时 41 至 42 秒,探索了整个空间,这解释了为何尽管搜索时间相当,它却能发现更优的内核。
表 3 展示了 Prism 和 Mirage 所发现的唯一图数量对比:
表 3 | Prism 与 Mirage 发现的唯一图数量对比(数值越高越好)。若算子序列或映射不同,则视为不同的图。该表统计了两大系统发现的唯一图数量,直接体现了 Prism 搜索空间覆盖的全面性。Mirage 受限于枚举式搜索的效率瓶颈,依赖启发式策略筛选候选,单个配置仅能发现 1 至 14 种有效图,Swiglu 任务仅 1 种,注意力任务仅 3 至 4 种,覆盖范围极为狭窄。而 Prism 通过符号化遍历完整的合法空间,单次搜索可发现 9 至 23 种唯一图,RMSNorm-MLP 任务达 23 种,注意力任务达 14 种,涵盖了不同的网格维度、循环划分策略以及算子执行顺序。更多有效的候选图意味着找到最优内核的概率更高,这直接解释了 Prism 在复杂任务上的性能优势。
关键发现:
- Prism 通过一次符号搜索,每个工作负载可发现 9 至 23 个唯一图,而 Mirage 每个配置仅能发现 1 至 14 个唯一图。
- 这些图在多个维度上存在差异:活动网格维度的数量(1、2 或 3)、不同的循环分区策略,以及融合内核内不同的算子执行顺序。
- 差异最为显著的是 SwiGLU 任务,Mirage 每个配置只发现 1 个唯一结构,而 Prism 发现了 12 个。在注意力工作负载上,3D 张量结构允许最多 3 个网格维度,Prism 探索了这些维度上所有有效的网格和循环分区组合,为 Attention 和 QK-Attention 都发现了 14 个唯一图。Mirage 基于启发式的映射探索,每个配置只发现 3 至 4 个结构,错过了 Prism 发现的许多策略。这种更广泛的覆盖直接转化为了在注意力工作负载上观察到的内核性能提升。
6.5 消融研究:符号映射的影响
为了理解哪种映射变量对搜索时间减少的贡献最大,Prism 进行了消融研究,选择性地将单个映射类型设为具体(即在搜索期间枚举),同时保持其他映射类型为符号。表 4 展示了 RMSNorm(d=4096, n=8)任务的结果:
表 4 | RMSNorm 任务消融实验:搜索阶段选择性枚举映射变量时的耗时。“S”=符号化(延迟至实例化),“C”=具体化(搜索阶段枚举)。该表通过消融实验量化了不同映射变量符号化的贡献,揭示了 Prism 高效搜索的核心设计关键。当所有映射均符号化时,搜索耗时仅为 0.3 秒;而当所有映射均具体化枚举时,耗时则飙升至 312 秒,差距超过千倍。在三类映射变量中,输入映射(imap)的影响最大,单独枚举耗时 20.5 秒;循环映射(fmap)和输出映射(omap)分别为 5.5 秒和 2.5 秒。当多类映射同时枚举时,耗时呈指数级增长。结果明确表明,符号化映射变量是压缩搜索空间的核心手段,尤其是输入映射,同时验证了解耦映射枚举与结构搜索的设计合理性,为后续张量超优化器的设计提供了关键指导。
关键发现:
八、结论与展望
Prism 作为全球首个面向张量程序的符号化超优化器,凭借创新的 sGraph 表示与两级搜索架构,成功突破了传统超优化器面临的组合爆炸问题,同时维持了最优性保障。实验数据显示,Prism 在 LLM 工作负载上实现了显著的性能提升与优化时间缩减,充分彰显了符号化超优化的巨大潜能。然而,任何技术都存在其固有的局限与适用范围。 本节将系统总结 Prism 的核心贡献,客观剖析其方法论的不足之处,并对未来研究方向进行展望。
8.1 结论总结
Prism 的核心贡献可归纳为以下几个方面:
由于映射搜索空间会随着并行化维度与数据维度的数量呈指数级扩张,符号化映射对于可扩展性而言至关重要。当所有三种映射类型均被具体枚举时,搜索过程耗时 312 秒;而通过符号化所有映射,这一时间骤降至 0.3 秒,实现了超过 1000 倍的加速比。
在这三种映射类型中,imap 的贡献最为显著:仅枚举 imap 就需要 20.5 秒,而仅枚举 fmap 或 omap 分别只需 5.5 秒和 2.5 秒。究其原因,imap 决定了输入张量的分区方式,对图结构产生了最为深远的影响。
当多个映射类型被一同枚举时,成本会进一步叠加:枚举全部三种映射类型需要 312 秒,远远超出它们各自成本之和(28.5 秒)。这一现象表明,映射变量之间存在着复杂的交互关系,而符号化技术能有效规避这种组合爆炸效应。
实验成果全方位验证了 Prism 符号化超优化的有效性。它不仅在所有工作负载上均取得了最优的内核性能,而且在多数场景下显著缩短了优化时间。通过符号搜索探索更为广阔的优化空间,Prism 成功发现了现有超优化器所遗漏的更优并行化策略,尤其在复杂的 LLM 工作负载上表现尤为突出。
unsetunset七、相关工作unsetunset
张量程序优化是机器学习系统领域的核心研究方向之一,过去十年间涌现了大量研究成果。这些工作大致可归为三类:专家手工优化的内核、基于规则的编译器以及超优化器。Prism 的符号化超优化在这些工作的基础上,提出了一种全新的范式,同时突破了既有方法的局限性。
7.1 专家手工优化的内核
专家手工优化的内核是当前机器学习系统中性能最高的实现方式。NVIDIA 的 cuDNN 与 cuBLAS 库为常见算子(如卷积、矩阵乘法)提供了高度优化的实现,能够充分发挥 GPU 张量核心与内存层次结构的优势。近年来,针对注意力算子的手工优化取得了重大突破,FlashAttention 系列通过将注意力计算分块并利用共享内存,将注意力算子的内存复杂度从 O(n²) 降至 O(n),实现了数倍的性能提升。
然而,手工优化的内核存在两个根本性缺陷:
- 支持的算子集合固定,无法快速适配新型算子或模型架构。随着 AI 模型的快速演进,新型算子与模型架构层出不穷,手工优化难以跟上这一发展节奏。
- 新硬件的适配需要大量工程投入,无法跟上硬件快速迭代的步伐。GPU 架构每 1-2 年便会更新一次,每次更新都会引入新的硬件特性,手工优化的内核必须重新编写才能利用这些特性。
随着 GPU 架构的快速演进(如 A100 的张量核心、H100 的线程块集群、B200 的张量内存),手工优化的内核越来越容易错过那些难以通过人类设计识别的非直观性能机会。
7.2 基于规则的编译器
基于规则的编译器通过专家编写的图重写规则与调度模板来优化张量程序。代表性工作包括 TensorFlow XLA、PyTorch TorchInductor 与 TVM。
TensorFlow XLA 是首个被广泛使用的机器学习编译器,它将 TensorFlow 计算图编译为高效的机器代码。XLA 利用线性代数优化器来优化计算图,并为每个算子生成高效的内核。然而,XLA 的优化效果受限于人类编写的规则,无法发现非直观的组合优化。
TVM 是一个端到端的深度学习编译器,它引入了张量表达式语言与自动调度器。TVM 的 Ansor 自动调度器能够自动生成高效的内核调度,无需人类编写调度模板。然而,Ansor 仍然依赖于人类定义的调度原语,无法突破这些原语的限制。此外,Ansor 的搜索空间基于人类经验设计,可能会遗漏某些最优的调度策略。
PyTorch TorchInductor 是 PyTorch 2.0 引入的新编译器,它采用 Triton 作为后端,能够生成高效的 GPU 内核。TorchInductor 的优势在于与 PyTorch 的动态图无缝集成,能够自动融合算子并生成高效的内核。然而,TorchInductor 的融合启发式算法仍然是人工编写的,无法发现复杂的融合模式。
7.3 超优化器
超优化器通过自动搜索候选程序空间来发现最优实现,是当前最具前景的优化方向之一。
- TASO:首个将超优化应用于张量程序的系统,由 Zhihao Jia 等人于 2019 年提出。TASO 枚举小的计算子图,通过随机测试与形式验证识别等价对,随后应用这些变换来优化目标程序。TASO 证明了超优化在张量程序优化中的有效性,但它仅在图层面进行优化,无法深入到内核内部。
- Mirage:由 Mengdi Wu 等人于 2025 年提出,将超优化扩展到 GPU 执行的多个层级,通过 μGraph 表示统一不同层级的优化。Mirage 的多级搜索能够协调代数变换与调度变换,合成全新的自定义内核,性能远超传统编译器。但 Mirage 的枚举式搜索面临组合爆炸问题,无法处理复杂的 LLM 工作负载。
- AlphaEvolve:由 Alexander Novikov 等人于 2025 年提出,利用大语言模型引导进化搜索,能够探索更大的优化空间。AlphaEvolve 采用进化框架,其中大语言模型迭代地提出并改进代码候选,随后由自动评估器进行验证。AlphaEvolve 在各种定义明确的任务上证明了 LLM 引导的搜索能够发现超越人类工程与先前自动化解决方案的优化。然而,采样式超优化将优化景观视为无结构的,搜索行为不稳定,无法保证覆盖最优解。
Prism 与这些超优化器的核心区别在于:
- 与枚举式超优化器(TASO、Mirage)相比,Prism 采用符号化表示编码整个程序族,将搜索空间复杂度从 O(|G||M||D|) 降至 O(|G|),突破了组合爆炸瓶颈。
- 与采样式超优化器(AlphaEvolve)相比,Prism 的符号剪枝是可靠的,不会丢弃最优解,保留了枚举式搜索的最优性保证。
7.4 符号图表示
先前的工作已引入了多级图表示来描述张量程序。例如,Welder 与 ASPEN 采用基于瓦片的多级图表示,能够捕捉不同层级的并行性与数据局部性。Mirage 引入了 μGraph 表示来捕捉 GPU 层次结构,能够统一内核、线程块与线程层级的优化。
与这些方法不同,Prism 的 sGraph 是一种符号化的层次化表示,它不是表示单个具体的张量程序,而是紧凑地编码了大量等价类的张量程序,从而大幅缩减了搜索空间。这种符号化表示是 Prism 能够突破组合爆炸瓶颈的关键所在。
Prism 的符号化超优化是张量程序优化领域的一个重要里程碑。它结合了枚举式搜索的严谨性与采样式搜索的可扩展性,同时突破了传统优化范式的局限性,为未来的机器学习系统优化指明了新的方向。
符号化超优化的核心贡献与价值
- 提出全新的符号化超优化范式:该范式将张量程序优化的核心从枚举单个具体程序,转变为对整个程序族进行编码,从而极大地压缩了搜索空间的复杂度。这一创新是张量程序优化领域的重大突破,为应对组合爆炸这一经典难题提供了崭新的解决思路。
- 设计 sGraph 符号化层次化图表示:通过引入符号变量来表征并行化参数与数据映射关系,sGraph 能够紧凑地编码整个程序族。它继承了 μGraph 的层次化结构特性,可以精确捕捉 GPU 执行过程中不同层级的并行性。
- 开发符号维度匹配与表达式引导的剪枝技术:这两种技术能够在符号层面高效地筛选并剔除无效的候选方案,从而显著提升搜索效率。它们可以在生成具体程序之前,就排除掉大量无效的候选,大幅减轻后续阶段的工作负担。
- 提出基于 e-graph 的符号等价性检查方法:该方法不依赖于具体的输入值或并行化参数,一次验证即可覆盖所有可能的实例化情况。这不仅大幅度降低了验证开销,更重要的是,它保证了所生成的内核在任何硬件配置和输入尺寸下都是正确的。
- 全面的实验评估:在五个主流的 LLM 工作负载上进行的实验表明,Prism 相比当前最优的超优化器,性能提升最高可达 2.2 倍;相比传统的编译器方法,性能提升最高可达 4.9 倍;同时,端到端的优化时间最高可缩短 3.4 倍。这些结果充分验证了符号化超优化的有效性与可扩展性。
方法论价值
Prism 的核心方法论价值在于,它有力地展示了符号推理与超优化相结合的巨大潜力。通过将人类的数学知识编码为公理和约束,Prism 能够在保证正确性的前提下,自动探索人类直觉难以触及的优化空间。这种方法不仅适用于张量程序优化,更有望推广至其他领域的程序优化问题。
8.2 进阶分析
尽管 Prism 取得了令人瞩目的成果,但我们仍需客观认识其方法论的局限性与适用范围:
- 问题解决的阶段性:Prism 并未从根本上解决张量程序优化的组合爆炸问题,而是通过符号化将组合爆炸的压力从图生成阶段转移到了映射实例化和参数实例化阶段。对于极其复杂的工作负载(例如包含数十个算子的融合内核),映射实例化阶段的枚举仍然可能面临组合爆炸。此外,参数实例化阶段采用的随机采样虽然在实践中有效,但本质上仍是一种启发式方法,无法保证找到全局最优的参数配置。
- 公理系统的不完备性:Prism 的等价性检查依赖于约 70 条手工编写的公理,这套公理系统并不完备。论文明确指出,存在一些等价的程序无法被当前的公理系统证明,例如
T+T等价于2*T。公理系统的不完备性可能导致 Prism 错过一些最优的程序实现。并且,添加新的公理需要人类专家介入,无法自动扩展。 - 实验设计的局限性:Prism 的实验评估仅在 NVIDIA A100 GPU 上进行,未评估在其他硬件平台(如 AMD GPU、Intel GPU、TPU)上的性能。不同硬件平台的架构特性各异,Prism 的符号化表示和优化策略可能需要调整才能在其他平台上取得良好效果。此外,实验仅评估了 LLM 中的常见算子,未涵盖更广泛的 ML 工作负载(如 CNN、扩散模型)。
- 固定开销的问题:Prism 的实例化阶段存在固定开销,需要编译和性能分析所有发现的图模板。对于简单的工作负载或较小的输入配置,这项固定开销可能超过性能提升带来的收益。此外,Prism 的符号搜索虽然每个工作负载只运行一次,但对于非常大的工作负载,其本身耗时可能仍然很长。
- 并行化维度的限制:当前的 Prism 实现假设只有一个 for-loop 维度。虽然论文提到扩展到多个循环维度是直接的,但尚未实现和评估。许多复杂的算子(如卷积)需要多个循环维度才能高效实现,这限制了 Prism 的适用范围。
哪些场景下 Prism 表现欠佳?
从实验结果来看,Prism 在处理简单工作负载(如 SwiGLU)时,性能表现与 Mirage 相当,但优化时间可能更长。这是因为简单工作负载的优化空间较小,符号搜索的优势无法充分发挥,而固定开销却占据了主导地位。此外,对于需要多个 for-loop 维度的工作负载,当前的 Prism 实现无法支持。对于需要快速迭代的开发场景,Prism 较长的优化时间也可能成为一个问题。
8.3 未来工作
8.3.1 原文计划
论文作者明确提出的未来工作方向包括:
- 集成更复杂的自动调优策略:当前的参数实例化阶段使用随机采样,未来可以集成学习成本模型、进化搜索等更复杂的调优策略,以进一步减少调优时间并找到更优的参数配置。学习成本模型可以预测不同参数配置的性能,从而减少需要编译和性能分析的候选数量。
- 扩展到多个 for-loop 维度:当前的实现假设只有一个 for-loop 维度,未来可以扩展到支持多个循环维度,以覆盖更广泛的优化空间。这将使 Prism 能够优化更复杂的算子,如卷积和池化。
- 完善公理系统:当前的公理系统并不完备,未来可以添加更多的公理,以覆盖更多的等价程序。此外,可以研究自动发现公理的方法,减少对人类专家的依赖。
- 支持更多的算子和硬件平台:当前的实现主要针对 LLM 中的常见算子和 NVIDIA GPU,未来可以扩展到支持更多的算子和硬件平台,如 AMD GPU、Intel GPU 和 TPU。这将使 Prism 能够应用于更广泛的场景。
8.3.2 NeuralTalk 视角
从更宏观的领域发展趋势来看,Prism 的成果可能催生以下几个有前途的研究方向:
- 符号推理与大语言模型的结合:将 Prism 的符号推理能力与大语言模型的代码生成能力相结合,可能是未来超优化器的发展方向。大语言模型可以提出新颖的程序结构,而符号推理可以验证这些结构的正确性并进行剪枝,从而实现更高效、更可靠的搜索。这种结合可以充分发挥两者的优势:大语言模型的创造力和符号推理的严谨性。
- 端到端的符号化编译:Prism 目前只针对张量程序的超优化,未来可以将符号化表示扩展到整个编译流程,从高级的计算图到低级的机器代码,实现端到端的符号化编译。这将使编译器能够在整个编译栈上进行全局优化,发现更多的性能机会。例如,符号化编译可以同时优化算子融合、并行化和内存布局,而不是分阶段进行优化。
- 跨硬件平台的符号化优化:Prism 的符号化表示与具体的硬件平台无关,这使得它非常适合跨硬件平台的优化。未来可以开发一个统一的符号化优化框架,能够自动生成针对不同硬件平台的最优内核,大幅减少新硬件的适配成本。这将使 ML 模型能够轻松部署到各种硬件平台上,而无需进行大量的工程投入。
- 形式化验证的集成:Prism 的等价性检查已经使用了形式化方法,未来可以进一步集成更强大的形式化验证技术,如 SMT 求解器、定理证明器,以验证更复杂的程序属性,如数值稳定性、内存安全性等。这将使生成的内核不仅性能优异,而且更加可靠。
- 动态符号化优化:当前的 Prism 是静态优化器,在编译时进行所有优化。未来可以开发动态符号化优化器,能够在运行时根据输入数据的特征动态调整优化策略,进一步提升性能。例如,对于注意力算子,当序列长度较小时,可以使用一种并行化策略;当序列长度较大时,可以使用另一种并行化策略。动态符号化优化可以根据实际的输入数据选择最优的策略。
Prism 的符号化超优化为张量程序优化打开了一扇新的大门。它不仅在当前的 LLM 工作负载上取得了显著的成果,还为未来的 ML 系统优化指明了新的方向。随着符号推理技术的不断发展和完善,我们有理由相信,符号化超优化将成为未来 ML 系统的核心技术之一,推动 AI 应用的性能不断提升。
关注“鲸栖”小程序,掌握最新AI资讯
本文来自网络搜集,不代表鲸林向海立场,如有侵权,联系删除。转载请注明出处:https://www.itsolotime.com/archives/35313

