用Python写GPU汇编?pyptx在Blackwell上实现1240 TFLOPS,性能超越cuBLAS

GPU 编程领域长期存在一个令人尴尬的困境:一方面,追求极致性能必须依赖 CUDA C++ 甚至直接编写 PTX 汇编;另一方面,为了提升开发效率,开发者往往不得不接受 Triton、Pallas 等编译器自动生成代码时的黑盒优化与不可控性。

当 Triton 生成的指令调度不符合预期,当你需要精细调控 mbarrier 时序、TMA 的 multicast 模式,或是 tcgen05 的寄存器分配时,几乎唯一的出路就是退回到 C++ 内联汇编的原始状态。

pyptx 项目给出了一个截然不同的答案:利用 Python 的语法糖,精准封装 PTX 中的每一条指令,不进行任何优化,也不引入任何中间 IR 变换——你在 Python 中写下的内容,PTX 中就原样生成。一个 ptx.inst.add.f32() 调用,恰好对应一条 add.f32 指令。

在 Blackwell B200 上,它实现了 1240 TFLOPS 的 GEMM 性能(达到 cuBLAS 的 77%);在 Hopper H100 上,它达到了 815 TFLOPS,并超越了 cuBLAS。这绝非一个玩具——它是一个真正的工程工具,让你能用 Python 的交互式体验去编写真实的 GPU 汇编。

Blackwell (B200, bf16) 性能表现

Kernel Shape pyptx cuBLAS best / cuBLAS
GEMM (tcgen05.mma, 4-stage pipeline, 1SM) 1240 TFLOPS 1610 77%
GEMM (1SM) 1194 TFLOPS 1532 78%
GEMM 2SM (cta_group::2, 5-stage) 649 TFLOPS (beats 1SM) 1006 64%
Grouped GEMM (tcgen05, MoE) G=4 M=2048 N=256 K=2048 401 TFLOPS torch ref ~10.0×
RMS norm / Layer norm / SwiGLU maintained Blackwell ports benchmarked torch ref see kernel suite

Hopper (H100 SXM5, bf16 / f32) 性能表现

Kernel Shape pyptx vs reference
GEMM (wgmma, warp-specialized) 815 TFLOPS beats cuBLAS ≥ 6K
Grouped GEMM (bf16→f32) G=8 M=K=2048 104 TFLOPS
RMS norm (f32) B=2048 N=8192 2.6 TB/s (88% HBM) 3.9× torch
Layer norm (f32) B=2048 N=8192 2.5 TB/s (83% HBM) 1.5× F.layer_norm
SwiGLU (f32) M=2048 F=8192 2.8 TB/s (94% HBM) 1.6× F.silu(g)*u
Softmax (f32, row-wise) B=2048 N=8192 2.8 TB/s (95% HBM) 1.16× torch.softmax
Flash attention (bf16) M=N=4096, HD=64 88 µs 3.0× naive torch

注:815 TFLOPS 的 GEMM 实现位于 examples/hopper/gemm_highperf_hopper.py,详见:Fastest kernels written from scratch[1]。

本文目录

  • 快速上手
  • 一、架构总览与设计哲学
  • 1.1 “一调用一指令”的核心契约
  • 1.2 整体模块拓扑
  • 二、追踪式代码生成:从 Python 到 PTX 的魔法
  • 2.1 TraceContext:线程局部的指令收集器
  • 2.2 Kernel._trace():追踪的完整生命周期
  • 三、寄存器系统:算术运算符即 PTX 指令
  • 3.1 Reg 类的运算符重载
  • 3.2 RegArray 与 copy propagation 优化
  • 四、共享内存与硬件原语的精确建模
  • 4.1 SMEM 分配的双模态设计
  • 4.2 GMMA Swizzle 的 3 指令实现
  • 五、运行时分发:一个内核,三条路径
  • 5.1 统一入口与分发逻辑
  • 5.2 Turbo 快速路径
  • 六、一个完整的实战示例:RMS Norm
  • 七、PTX 转译器:逆向工程的利器
  • 总结

快速上手

# 安装核心 DSL(无需 GPU 即可生成 PTX)
pip install pyptx

# 如需在 PyTorch 中启动内核
pip install 'pyptx[torch]'

# 如需在 JAX 中启动内核
pip install 'pyptx[jax]'

# 可选:加速 PyTorch C++ 扩展的 JIT 编译
pip install ninja

完成安装后,即可编写并查看内核代码:

from pyptx import kernel, reg, ptx, Tile
from pyptx.types import f32, u32

@kernel(
in_specs=(Tile(4, 64, f32),),
out_specs=(Tile(4, 64, f32),),
grid=(4, 1, 1),
block=(64, 1, 1),
arch="sm_90a",
)
def identity(X, Y):
px, py = ptx.global_ptrs(X, Y)
tid = reg.scalar(u32)
ptx.inst.mov.u32(tid, ptx.special.tid.x())
val = reg.scalar(f32)
ptr = px + tid * 4
ptx.inst.ld.global_.f32(val, ptx.addr(ptr))
ptx.inst.st.global_.f32(ptx.addr(py + tid * 4), val)
ptx.ret()

# 查看生成的 PTX
print(identity.ptx())

更多完整示例请参考 examples/hopper/[2] 和 examples/blackwell/[3],性能数据可查阅 pyptx.dev/performance[4]。

一、架构总览与设计哲学

1.1 “一调用一指令”的核心契约

pyptx 的设计理念可以用一句话概括:Python 函数中的每个 ptx.* 调用,恰好对应生成的 PTX 文本中的一条指令。没有优化器,没有调度器,没有中间表示变换。这是一个”所见即所得”的 PTX 编写工具。

这种设计的价值在于——当你调试一个高性能 GEMM 内核时,你不需要猜测编译器做了什么变换。print(my_kernel.ptx()) 输出的内容,就是你写的内容。

1.2 整体模块拓扑

项目核心代码约 36 万行 Python,结构分层清晰:

模块 职责
kernel.py @kernel 装饰器、追踪、特化缓存、运行时分发
_trace.py 线程局部追踪上下文,所有 IR 节点在此汇聚
ptx.py 全部 PTX 指令的 DSL 接口(166KB,覆盖完整 ISA)
reg.py 寄存器分配 + 算术运算符语法糖
smem.py 共享内存分配、mbarrier、GMMA swizzle
emitter/ IR → PTX 文本的序列化
parser/ PTX 文本 → IR 的反序列化
codegen/ PTX → Python 的转译器(逆向工程工具)
jax_support.py JAX typed FFI 集成
torch_support.py PyTorch eager + torch.compile 集成

二、追踪式代码生成:从 Python 到 PTX 的魔法

2.1 TraceContext:线程局部的指令收集器

pyptx 的代码生成并非静态分析或 AST 变换——它是运行时追踪

@kernel 装饰的函数被调用时,一个 TraceContext 被激活,此后所有 reg.*smem.*ptx.* 的调用都会将 IR 节点”录入”到这个上下文中。

# 来源:pyptx/_trace.py
class TraceContext:
def __init__(self, *, ptx_version: tuple[int, int] | None = None) -> None:
self.reg_decls: list[RegDecl] = []
self.var_decls: list[VarDecl] = []
self.statements: list[Statement] = []
self.dyn_smem_offset: int = 0
self.force_dynamic_smem: bool = False

def emit(self, stmt: Statement) -> None:
"""Record a statement (instruction, label, etc.)."""
self.statements.append(stmt)

def body(self) -> tuple[Statement, ...]:
"""Return the full function body: decls then statements."""
parts: list[Statement] = []
parts.extend(self.reg_decls)
parts.extend(self.var_decls)
parts.extend(self.statements)
return tuple(parts)

这个设计的精妙之处在于:Python 的控制流(forif)在追踪时原样展开为线性指令序列。循环 for i in range(4) 不会生成 PTX 循环——它生成 4 份展开的指令。这与 Triton 的 tile-level 编程模型截然不同,pyptx 给你的是指令级的完全控制。

2.2 Kernel._trace():追踪的完整生命周期

Kernel._trace() 是整个系统的核心枢纽。它的工作流程是:

参数解析与追踪流程

在 pyptx 的 _trace 方法中,内核的追踪过程被分解为六个清晰的步骤:

  1. 参数分离:将 kwargs 字典拆解为两类——模板参数(例如 BM=128)和形状变量(例如 M=4096)。
  2. 占位符构建:为每一个位置参数生成对应的 TensorSpec 对象,其中携带了张量的形状与数据类型信息。
  3. 追踪上下文激活:通过调用 trace_scope() 来建立一个线程局部的上下文环境。
  4. 用户函数执行:执行 self._fn(*positional, **resolved) 这一行代码。此时,所有 DSL 内的操作都会被自动记录到当前的追踪上下文中。
  5. Module 组装:将追踪过程中收集到的寄存器声明、共享内存声明以及指令序列,统一封装成一个 IR Module 对象。
  6. 动态共享内存处理:若检测到总共享内存使用量超过 48KB,则系统会切换为动态模式,并重新执行一遍追踪流程。

以下是来自 pyptx/kernel.py 的简化实现代码:

def _trace(self, _shape_env=None, **kwargs):
# ... 参数解析逻辑 ...
while True:
with trace_scope(ptx_version=self._version) as ctx:
if force_dynamic_trace:
ctx.force_dynamic_smem = True
self._fn(*positional, **resolved)  # 用户代码在此执行!

total_smem = ctx.dyn_smem_offset
if total_smem > 48 * 1024 and not ctx.force_dynamic_smem:
force_dynamic_trace = True
continue  # 重新追踪
break

module = Module(
version=Version(...),
target=Target((self._arch,)),
address_size=AddressSize(64),
directives=(..., func),
)
return module

请注意这里的 while True + continue 组合——这是一个非常巧妙的“发现-重试”机制。第一次追踪时,系统可能发现共享内存超出了 48KB 的硬件限制。此时,它会自动切换到 extern 动态共享内存模式,然后重新执行一遍追踪,从而确保所有地址计算的一致性。


三、寄存器系统:算术运算符即 PTX 指令

3.1 Reg 类的运算符重载

位于 reg.py 中的 Reg 类,本质上是一个精心设计的“薄包装”。它重载的算术运算符 +*<<& 等,每一次调用都恰好发射一条 PTX 指令

# 来源:pyptx/reg.py
class Reg:
def __add__(self, other: Any) -> "Reg":
return _emit_int_add(self, other)

def __mul__(self, other: Any) -> "Reg":
return _emit_int_mul(self, other)

_emit_int_add 为例,当你在内核代码中写下 px + offset 时,实际发生的事情是这样的:

# 来源:pyptx/reg.py(简化)
def _emit_int_add(left: Reg, right: Any) -> Reg:
ctx = get_ctx()
kind = _int_dtype_kind(left.dtype)
if kind == "64":
result = scalar(left.dtype)
right_op = _widen_to_64(ctx, right)
ctx.emit(Instruction(
opcode="add", modifiers=(".s64",),
operands=(
RegisterOperand(result.name),
RegisterOperand(left.name),
right_op,
),
))
return result

换句话说:一条 Python 加法 → 一条 add.s64 PTX 指令。整个过程中没有任何优化 pass,没有公共子表达式消除(CSE),也没有常量折叠——唯一的例外是 Python 的 int 常量天然会被当作立即数处理。

3.2 RegArray 与 copy propagation 优化

RegArray.__setitem__ 方法中包含了一个小而精的优化技巧。当你写出 r[90] = (r[89] << 7) 这样的代码时,<< 运算符已经发射了一条 shl.b32 %r_fresh, %r89, 7 指令。此时,__setitem__ 不会再额外发射一条 mov 指令,而是直接将最后一条指令的目标寄存器改写%r90

# 来源:pyptx/reg.py
def __setitem__(self, idx: int, value: "Reg") -> None:
ctx = get_ctx()
if ctx.statements:
last = ctx.statements[-1]
if (isinstance(last, Instruction)
and last.operands[0].name == value.name):
# 直接把目标寄存器名改为 r[idx] 的名字
new_dst = RegisterOperand(name=f"{self._base}{idx}")
ctx.statements[-1] = replace(last, operands=(new_dst,) + last.operands[1:])
return
# fallback: emit mov

这是整个项目中唯一存在的“优化”,而且它的作用仅仅是消除 DSL 语法糖带来的冗余 mov 指令,完全不会改变用户的语义意图。


四、共享内存与硬件原语的精确建模

4.1 SMEM 分配的双模态设计

smem.alloc() 提供了两种分配策略。当总的共享内存需求不超过 48KB 时,系统会采用静态命名的 .shared 声明;一旦超出 48KB 的限制,它会自动切换到统一的 extern .shared .b8 dyn_smem[] 模式,此时所有内存分配都通过计算偏移量来完成寻址:

# 来源:pyptx/smem.py  
def alloc(dtype, shape, ...):  
ctx = get_ctx()  
off = ctx.dyn_smem_offset  
# 对齐处理  
if align > 0 and off % align != 0:  
off = ((off + align - 1) // align) * align  
this_offset = off  
ctx.dyn_smem_offset = off + byte_count  

if ctx.force_dynamic_smem:  
return SharedAlloc("dyn_smem", dtype, shape, swizzle, byte_offset=this_offset)  
else:  
ctx.var_decls.append(VarDecl(...))  
return SharedAlloc(name, dtype, shape, swizzle, byte_offset=this_offset)  

4.2 用三条指令实现 GMMA Swizzle

Hopper 和 Blackwell 架构中的 WGMMA 指令要求共享内存中的数据必须遵循特定的 swizzle 排列格式。为了在 pyptx 中复现 CuTe 的 Swizzle<B,4,3> 变换,开发者仅使用了三条 ALU 指令:

# 来源:pyptx/smem.py  
_SWIZZLE_PARAMS = {  
"32B":  (0x080, 3),   # Swizzle<1,4,3>  
"64B":  (0x180, 3),   # Swizzle<2,4,3>  
"128B": (0x380, 3),   # Swizzle<3,4,3>  
}  

def apply_swizzle(logical_offset, swizzle):  
mask, shift = _SWIZZLE_PARAMS[swizzle]  
# xor_bits = (logical_offset & mask) >> shift  
# physical = logical_offset ^ xor_bits  
# 恰好 3 条指令:and, shr, xor  

这段实现与 CUTLASS 源代码中 Swizzle<B, M, S> 的计算公式 physical = logical XOR ((logical & yyy_msk) >> S) 完全对应。

五、运行时分发:一个内核,三条路径

5.1 统一入口与分发逻辑

Kernel.__call__() 是整个系统的运行时统一入口。它会根据输入张量的具体类型,自动选择合适的分发路径:

  • PyTorch 张量torch_support.call_kernel_via_torch_compile()
  • JAX 数组jax_support.call_kernel_via_ffi()

核心的 JIT 编译流程包含以下几个步骤:

  1. 从输入张量的形状信息中提取 shape_env
  2. 追踪并生成 PTX 文本
  3. 通过 cuModuleLoadData 将 PTX 代码 JIT 编译为 cubin
  4. 将编译结果注册并缓存到 cubin registry 中

5.2 Turbo 快速路径

针对 PyTorch 在推理循环中反复调用相同形状输入的场景,pyptx 设计了一个“Turbo 快速路径”。该路径会跳过所有 Python 层面的形状检查和分发逻辑,直接调用预编译好的 C++ 扩展:

# 来源:pyptx/kernel.py  
if not kwargs and any_torch_tensors(input_arrays):  
turbo = getattr(self, '_turbo_torch', None)  
if turbo is not None:  
shapes_key = tuple(a.shape for a in input_arrays)  
if turbo[0] == shapes_key:  
# 直接走 C++ 扩展发射内核,~14µs  
return _ext.launch_kernel(_handle, stream_ptr, ...)  

通过这种优化,PyTorch eager dispatch 的开销从大约 34µs 被压缩到了约 14µs。

六、一个完整的实战示例:RMS Norm

examples/hopper/rms_norm.py 文件展示了一个典型的 pyptx 工作流——这是一个能够在 H100 上达到 88% HBM 带宽利用率的 RMS 归一化内核:

“`python

来源:examples/hopper/rms_norm.py(核心片段)

@kernel(
in_specs=(Tile(B, N, f32), Tile(N, f32)),
out_specs=(Tile(B, N, f32),),
grid=(B, 1, 1), block=(block, 1, 1), arch=”sm_90a”,
)
def rms_norm(X, W, Y):
partials = smem.alloc(f32, (num_warps, 1))
px, pw, py = ptx.global_ptrs(X, W, Y)
tid = reg.scalar(u32)
ptx.inst.mov.u32(tid, ptx.special.tid.x())

v4 向量化加载 + fma 累加平方和

sum_sq = reg.scalar(f32, init=0.0)
for j in range(v4_iters):
ptx.inst.ld.global_.v4.f32([x_vals[j4], …], ptx.addr(ptr))
for sub in range(4):
ptx.inst.fma.rn.f32(sum_sq, x_vals[j
4+sub], x_vals[j*4+sub], sum_sq)

warp 级蝶式归约

ptx.warp.reduce_sum(sum_sq)

通过观察可以发现,Python 的 for 循环在执行追踪时会被自动展开,形成向量化的 ld.global.v4.f32fma.rn.f32 指令序列。而 ptx.warp.reduce_sum 底层实现则是标准的蝶式归约操作(shfl.sync.bfly),每一次调用都会对应生成固定数量的 PTX 指令。

七、PTX 转译器:逆向工程的利器

pyptx 不仅支持从 Python 代码生成 PTX,还具备将现有 PTX 反向转译为 pyptx Python 代码的能力:

python -m pyptx.codegen kernel.ptx –sugar –name my_kernel > my_kernel.py

--sugar 模式下,该工具能够自动识别 spin-loop 并转换为 ptx.loop(...),将 mbarrier-wait 序列折叠为高层函数调用,同时整理表达式链。这个工具已在 218 多个来自 CUTLASS、Triton、DeepGEMM、ThunderKittens 等项目的 PTX 语料上通过了字节级别的往返一致性验证。

这意味着你可以获取 Triton 编译产生的 PTX,将其转译为 Python,在其中进行精确调整(例如修改流水线深度、变更 barrier 时序),然后直接执行——完全无需了解 Triton 内部的编译流程。

总结

pyptx 体现了 GPU 编程领域一种“反编译器”式的哲学立场:与其依赖日益复杂的编译器抽象来隐藏硬件细节,不如借助更强大的 DSL 工具,让程序员能够舒适地直接面对硬件。这套方案证明了以下几点:

  1. Python 的表达能力足以精确描述 PTX 级别的语义,无需依赖 C++ 或内联汇编
  2. 追踪式代码生成(而非 AST 变换)是实现“零抽象开销 DSL”的最简路径
  3. 一个内核对象同时服务于 JAX 和 PyTorch 是可行的工程实践

对于需要在 Hopper/Blackwell 架构上榨取极致性能的 AI 基础设施团队而言,pyptx 提供了一条独特的路径:既保留了 Python 生态的便利性,又不牺牲任何硬件控制精度。

参考资料[1] Fastest kernels written from scratch: https://github.com/pranjalssh/fast.cu

[2] examples/hopper: https://github.com/patrick-toulme/pyptx/tree/main/examples/hopper

[3] examples/blackwell: https://github.com/patrick-toulme/pyptx/tree/main/examples/blackwell

[4] pyptx.dev/performance: https://pyptx.dev/performance/


关注“鲸栖”小程序,掌握最新AI资讯

本文来自网络搜集,不代表鲸林向海立场,如有侵权,联系删除。转载请注明出处:https://www.itsolotime.com/archives/32399

(0)
上一篇 1小时前
下一篇 1小时前

相关推荐

  • Chandra OCR:重塑文档AI新标杆,以结构感知开启OCR 2.0时代

    OCR技术已历经长期发展,关于“文档智能”的愿景也层出不穷。然而,当面对真正复杂的文档材料时,大多数OCR系统的表现往往不尽如人意: 📄 模糊的PDF文件🧮 老旧数学作业纸的扫描件🗂️ 多栏版式的报纸扫描件✍️ 数十年前的手写表格 现有的一些OCR方案在页面干净规整时表现尚可,但一旦涉及文档结构、上下文理解或内容意图,就显得力不从心。 Chandra OCR…

    2025年12月24日
    37300
  • 4款惊艳AI开源项目盘点:从图表重建到桌面助手,解锁智能新体验

    01 图片、PDF转为可编辑 Edit Banana 是一个由北京理工大学开发的开源项目。它能够将不可编辑的图片或PDF格式的统计图表、流程图,转换为可完全编辑的格式,例如 DrawIO 的 XML 或 PPTX。 该项目并非简单的OCR工具,而是基于计算机视觉模型,对图表中的逻辑关系、形状组件和文本进行深度重建,实现高保真还原。生成的图形元素可以独立选中和…

    2026年2月21日
    56900
  • 百度GenFlow 4.0全面升级:一句话搞定PPT、Excel、Word,办公效率炸裂

    百度GenFlow 4.0全面升级:一句话搞定PPT、Excel、Word,办公效率炸裂 打工人最想甩给AI干的活儿,无非就是Office三件套:PPT、Excel、Word。 百度GenFlow这次大升级,直接把这些活儿全包圆了。 4月27日,百度文库与百度网盘联合发布了通用智能体GenFlow 4.0,对Office Agent进行了全面升级。 同时,还…

    AI产品库 1小时前
    1200
  • 阿里万相2.6发布:国内首个声画一致角色定制模型,将专业影棚搬入手机

    2025年,视频生成技术迎来突破性进展,行业范式正在重塑。9月,OpenAI发布的Sora 2通过“客串”功能攻克了长期困扰行业的角色一致性难题,使AI视频从随机生成转向可控创作。商业应用同步加速:B端AI短剧与漫剧批量上线,显著降低制作成本;C端社交平台涌现“粘土滤镜”等爆款特效,众多博主开始常态化使用AI制作剧情短片。 然而,顶尖技术对普通用户而言仍存在…

    2025年12月17日
    51800
  • LibTV震撼发布:首款人+Agent双视角AI视频创作平台,无限画布+节点工作流颠覆传统

    LibTV:首款人+Agent双视角AI视频创作平台 传统AI视频工具往往仅聚焦于内容生成本身,而LibTV提出了一个全新的设计理念:将人类创作者与AI智能体(Agent)置于同等地位,提供“手动”与“自动”两种并行的创作范式。 这一设计在AI创作社区引发了广泛关注。其核心在于,它并非单一的生成工具,而是一个整合了无限画布、节点式工作流与丰富专业功能的一站式…

    2026年3月20日
    1.4K00