GPU 编程领域长期存在一个令人尴尬的困境:一方面,追求极致性能必须依赖 CUDA C++ 甚至直接编写 PTX 汇编;另一方面,为了提升开发效率,开发者往往不得不接受 Triton、Pallas 等编译器自动生成代码时的黑盒优化与不可控性。
当 Triton 生成的指令调度不符合预期,当你需要精细调控 mbarrier 时序、TMA 的 multicast 模式,或是 tcgen05 的寄存器分配时,几乎唯一的出路就是退回到 C++ 内联汇编的原始状态。
pyptx 项目给出了一个截然不同的答案:利用 Python 的语法糖,精准封装 PTX 中的每一条指令,不进行任何优化,也不引入任何中间 IR 变换——你在 Python 中写下的内容,PTX 中就原样生成。一个 ptx.inst.add.f32() 调用,恰好对应一条 add.f32 指令。
在 Blackwell B200 上,它实现了 1240 TFLOPS 的 GEMM 性能(达到 cuBLAS 的 77%);在 Hopper H100 上,它达到了 815 TFLOPS,并超越了 cuBLAS。这绝非一个玩具——它是一个真正的工程工具,让你能用 Python 的交互式体验去编写真实的 GPU 汇编。
Blackwell (B200, bf16) 性能表现
| Kernel | Shape | pyptx | cuBLAS | best / cuBLAS |
|---|---|---|---|---|
| GEMM (tcgen05.mma, 4-stage pipeline, 1SM) | 1240 TFLOPS | 1610 | 77% | |
| GEMM (1SM) | 1194 TFLOPS | 1532 | 78% | |
| GEMM 2SM (cta_group::2, 5-stage) | 649 TFLOPS (beats 1SM) | 1006 | 64% | |
| Grouped GEMM (tcgen05, MoE) | G=4 M=2048 N=256 K=2048 | 401 TFLOPS | torch ref | ~10.0× |
| RMS norm / Layer norm / SwiGLU | maintained Blackwell ports | benchmarked | torch ref | see kernel suite |
Hopper (H100 SXM5, bf16 / f32) 性能表现
| Kernel | Shape | pyptx | vs reference |
|---|---|---|---|
| GEMM (wgmma, warp-specialized) | 815 TFLOPS | beats cuBLAS ≥ 6K | |
| Grouped GEMM (bf16→f32) | G=8 M=K=2048 | 104 TFLOPS | — |
| RMS norm (f32) | B=2048 N=8192 | 2.6 TB/s (88% HBM) | 3.9× torch |
| Layer norm (f32) | B=2048 N=8192 | 2.5 TB/s (83% HBM) | 1.5× F.layer_norm |
| SwiGLU (f32) | M=2048 F=8192 | 2.8 TB/s (94% HBM) | 1.6× F.silu(g)*u |
| Softmax (f32, row-wise) | B=2048 N=8192 | 2.8 TB/s (95% HBM) | 1.16× torch.softmax |
| Flash attention (bf16) | M=N=4096, HD=64 | 88 µs | 3.0× naive torch |
注:815 TFLOPS 的 GEMM 实现位于 examples/hopper/gemm_highperf_hopper.py,详见:Fastest kernels written from scratch[1]。
本文目录
- 快速上手
- 一、架构总览与设计哲学
- 1.1 “一调用一指令”的核心契约
- 1.2 整体模块拓扑
- 二、追踪式代码生成:从 Python 到 PTX 的魔法
- 2.1 TraceContext:线程局部的指令收集器
- 2.2 Kernel._trace():追踪的完整生命周期
- 三、寄存器系统:算术运算符即 PTX 指令
- 3.1 Reg 类的运算符重载
- 3.2 RegArray 与 copy propagation 优化
- 四、共享内存与硬件原语的精确建模
- 4.1 SMEM 分配的双模态设计
- 4.2 GMMA Swizzle 的 3 指令实现
- 五、运行时分发:一个内核,三条路径
- 5.1 统一入口与分发逻辑
- 5.2 Turbo 快速路径
- 六、一个完整的实战示例:RMS Norm
- 七、PTX 转译器:逆向工程的利器
- 总结
快速上手
# 安装核心 DSL(无需 GPU 即可生成 PTX)
pip install pyptx
# 如需在 PyTorch 中启动内核
pip install 'pyptx[torch]'
# 如需在 JAX 中启动内核
pip install 'pyptx[jax]'
# 可选:加速 PyTorch C++ 扩展的 JIT 编译
pip install ninja
完成安装后,即可编写并查看内核代码:
from pyptx import kernel, reg, ptx, Tile
from pyptx.types import f32, u32
@kernel(
in_specs=(Tile(4, 64, f32),),
out_specs=(Tile(4, 64, f32),),
grid=(4, 1, 1),
block=(64, 1, 1),
arch="sm_90a",
)
def identity(X, Y):
px, py = ptx.global_ptrs(X, Y)
tid = reg.scalar(u32)
ptx.inst.mov.u32(tid, ptx.special.tid.x())
val = reg.scalar(f32)
ptr = px + tid * 4
ptx.inst.ld.global_.f32(val, ptx.addr(ptr))
ptx.inst.st.global_.f32(ptx.addr(py + tid * 4), val)
ptx.ret()
# 查看生成的 PTX
print(identity.ptx())
更多完整示例请参考 examples/hopper/[2] 和 examples/blackwell/[3],性能数据可查阅 pyptx.dev/performance[4]。
一、架构总览与设计哲学
1.1 “一调用一指令”的核心契约
pyptx 的设计理念可以用一句话概括:Python 函数中的每个
ptx.*调用,恰好对应生成的 PTX 文本中的一条指令。没有优化器,没有调度器,没有中间表示变换。这是一个”所见即所得”的 PTX 编写工具。
这种设计的价值在于——当你调试一个高性能 GEMM 内核时,你不需要猜测编译器做了什么变换。print(my_kernel.ptx()) 输出的内容,就是你写的内容。
1.2 整体模块拓扑
项目核心代码约 36 万行 Python,结构分层清晰:
| 模块 | 职责 |
|---|---|
kernel.py |
@kernel 装饰器、追踪、特化缓存、运行时分发 |
_trace.py |
线程局部追踪上下文,所有 IR 节点在此汇聚 |
ptx.py |
全部 PTX 指令的 DSL 接口(166KB,覆盖完整 ISA) |
reg.py |
寄存器分配 + 算术运算符语法糖 |
smem.py |
共享内存分配、mbarrier、GMMA swizzle |
emitter/ |
IR → PTX 文本的序列化 |
parser/ |
PTX 文本 → IR 的反序列化 |
codegen/ |
PTX → Python 的转译器(逆向工程工具) |
jax_support.py |
JAX typed FFI 集成 |
torch_support.py |
PyTorch eager + torch.compile 集成 |
二、追踪式代码生成:从 Python 到 PTX 的魔法
2.1 TraceContext:线程局部的指令收集器
pyptx 的代码生成并非静态分析或 AST 变换——它是运行时追踪。
当 @kernel 装饰的函数被调用时,一个 TraceContext 被激活,此后所有 reg.*、smem.*、ptx.* 的调用都会将 IR 节点”录入”到这个上下文中。
# 来源:pyptx/_trace.py
class TraceContext:
def __init__(self, *, ptx_version: tuple[int, int] | None = None) -> None:
self.reg_decls: list[RegDecl] = []
self.var_decls: list[VarDecl] = []
self.statements: list[Statement] = []
self.dyn_smem_offset: int = 0
self.force_dynamic_smem: bool = False
def emit(self, stmt: Statement) -> None:
"""Record a statement (instruction, label, etc.)."""
self.statements.append(stmt)
def body(self) -> tuple[Statement, ...]:
"""Return the full function body: decls then statements."""
parts: list[Statement] = []
parts.extend(self.reg_decls)
parts.extend(self.var_decls)
parts.extend(self.statements)
return tuple(parts)
这个设计的精妙之处在于:Python 的控制流(for、if)在追踪时原样展开为线性指令序列。循环 for i in range(4) 不会生成 PTX 循环——它生成 4 份展开的指令。这与 Triton 的 tile-level 编程模型截然不同,pyptx 给你的是指令级的完全控制。
2.2 Kernel._trace():追踪的完整生命周期
Kernel._trace()是整个系统的核心枢纽。它的工作流程是:
参数解析与追踪流程
在 pyptx 的 _trace 方法中,内核的追踪过程被分解为六个清晰的步骤:
- 参数分离:将
kwargs字典拆解为两类——模板参数(例如BM=128)和形状变量(例如M=4096)。 - 占位符构建:为每一个位置参数生成对应的
TensorSpec对象,其中携带了张量的形状与数据类型信息。 - 追踪上下文激活:通过调用
trace_scope()来建立一个线程局部的上下文环境。 - 用户函数执行:执行
self._fn(*positional, **resolved)这一行代码。此时,所有 DSL 内的操作都会被自动记录到当前的追踪上下文中。 - Module 组装:将追踪过程中收集到的寄存器声明、共享内存声明以及指令序列,统一封装成一个 IR Module 对象。
- 动态共享内存处理:若检测到总共享内存使用量超过 48KB,则系统会切换为动态模式,并重新执行一遍追踪流程。
以下是来自 pyptx/kernel.py 的简化实现代码:
def _trace(self, _shape_env=None, **kwargs):
# ... 参数解析逻辑 ...
while True:
with trace_scope(ptx_version=self._version) as ctx:
if force_dynamic_trace:
ctx.force_dynamic_smem = True
self._fn(*positional, **resolved) # 用户代码在此执行!
total_smem = ctx.dyn_smem_offset
if total_smem > 48 * 1024 and not ctx.force_dynamic_smem:
force_dynamic_trace = True
continue # 重新追踪
break
module = Module(
version=Version(...),
target=Target((self._arch,)),
address_size=AddressSize(64),
directives=(..., func),
)
return module
请注意这里的 while True + continue 组合——这是一个非常巧妙的“发现-重试”机制。第一次追踪时,系统可能发现共享内存超出了 48KB 的硬件限制。此时,它会自动切换到 extern 动态共享内存模式,然后重新执行一遍追踪,从而确保所有地址计算的一致性。
三、寄存器系统:算术运算符即 PTX 指令
3.1 Reg 类的运算符重载
位于 reg.py 中的 Reg 类,本质上是一个精心设计的“薄包装”。它重载的算术运算符 +、*、<<、& 等,每一次调用都恰好发射一条 PTX 指令:
# 来源:pyptx/reg.py
class Reg:
def __add__(self, other: Any) -> "Reg":
return _emit_int_add(self, other)
def __mul__(self, other: Any) -> "Reg":
return _emit_int_mul(self, other)
以 _emit_int_add 为例,当你在内核代码中写下 px + offset 时,实际发生的事情是这样的:
# 来源:pyptx/reg.py(简化)
def _emit_int_add(left: Reg, right: Any) -> Reg:
ctx = get_ctx()
kind = _int_dtype_kind(left.dtype)
if kind == "64":
result = scalar(left.dtype)
right_op = _widen_to_64(ctx, right)
ctx.emit(Instruction(
opcode="add", modifiers=(".s64",),
operands=(
RegisterOperand(result.name),
RegisterOperand(left.name),
right_op,
),
))
return result
换句话说:一条 Python 加法 → 一条 add.s64 PTX 指令。整个过程中没有任何优化 pass,没有公共子表达式消除(CSE),也没有常量折叠——唯一的例外是 Python 的 int 常量天然会被当作立即数处理。
3.2 RegArray 与 copy propagation 优化
RegArray.__setitem__ 方法中包含了一个小而精的优化技巧。当你写出 r[90] = (r[89] << 7) 这样的代码时,<< 运算符已经发射了一条 shl.b32 %r_fresh, %r89, 7 指令。此时,__setitem__ 不会再额外发射一条 mov 指令,而是直接将最后一条指令的目标寄存器改写为 %r90:
# 来源:pyptx/reg.py
def __setitem__(self, idx: int, value: "Reg") -> None:
ctx = get_ctx()
if ctx.statements:
last = ctx.statements[-1]
if (isinstance(last, Instruction)
and last.operands[0].name == value.name):
# 直接把目标寄存器名改为 r[idx] 的名字
new_dst = RegisterOperand(name=f"{self._base}{idx}")
ctx.statements[-1] = replace(last, operands=(new_dst,) + last.operands[1:])
return
# fallback: emit mov
这是整个项目中唯一存在的“优化”,而且它的作用仅仅是消除 DSL 语法糖带来的冗余 mov 指令,完全不会改变用户的语义意图。
四、共享内存与硬件原语的精确建模
4.1 SMEM 分配的双模态设计
smem.alloc() 提供了两种分配策略。当总的共享内存需求不超过 48KB 时,系统会采用静态命名的 .shared 声明;一旦超出 48KB 的限制,它会自动切换到统一的 extern .shared .b8 dyn_smem[] 模式,此时所有内存分配都通过计算偏移量来完成寻址:
# 来源:pyptx/smem.py
def alloc(dtype, shape, ...):
ctx = get_ctx()
off = ctx.dyn_smem_offset
# 对齐处理
if align > 0 and off % align != 0:
off = ((off + align - 1) // align) * align
this_offset = off
ctx.dyn_smem_offset = off + byte_count
if ctx.force_dynamic_smem:
return SharedAlloc("dyn_smem", dtype, shape, swizzle, byte_offset=this_offset)
else:
ctx.var_decls.append(VarDecl(...))
return SharedAlloc(name, dtype, shape, swizzle, byte_offset=this_offset)
4.2 用三条指令实现 GMMA Swizzle
Hopper 和 Blackwell 架构中的 WGMMA 指令要求共享内存中的数据必须遵循特定的 swizzle 排列格式。为了在 pyptx 中复现 CuTe 的 Swizzle<B,4,3> 变换,开发者仅使用了三条 ALU 指令:
# 来源:pyptx/smem.py
_SWIZZLE_PARAMS = {
"32B": (0x080, 3), # Swizzle<1,4,3>
"64B": (0x180, 3), # Swizzle<2,4,3>
"128B": (0x380, 3), # Swizzle<3,4,3>
}
def apply_swizzle(logical_offset, swizzle):
mask, shift = _SWIZZLE_PARAMS[swizzle]
# xor_bits = (logical_offset & mask) >> shift
# physical = logical_offset ^ xor_bits
# 恰好 3 条指令:and, shr, xor
这段实现与 CUTLASS 源代码中 Swizzle<B, M, S> 的计算公式 physical = logical XOR ((logical & yyy_msk) >> S) 完全对应。
五、运行时分发:一个内核,三条路径
5.1 统一入口与分发逻辑
Kernel.__call__()是整个系统的运行时统一入口。它会根据输入张量的具体类型,自动选择合适的分发路径:
- PyTorch 张量 →
torch_support.call_kernel_via_torch_compile() - JAX 数组 →
jax_support.call_kernel_via_ffi()
核心的 JIT 编译流程包含以下几个步骤:
- 从输入张量的形状信息中提取 shape_env
- 追踪并生成 PTX 文本
- 通过
cuModuleLoadData将 PTX 代码 JIT 编译为 cubin - 将编译结果注册并缓存到 cubin registry 中
5.2 Turbo 快速路径
针对 PyTorch 在推理循环中反复调用相同形状输入的场景,pyptx 设计了一个“Turbo 快速路径”。该路径会跳过所有 Python 层面的形状检查和分发逻辑,直接调用预编译好的 C++ 扩展:
# 来源:pyptx/kernel.py
if not kwargs and any_torch_tensors(input_arrays):
turbo = getattr(self, '_turbo_torch', None)
if turbo is not None:
shapes_key = tuple(a.shape for a in input_arrays)
if turbo[0] == shapes_key:
# 直接走 C++ 扩展发射内核,~14µs
return _ext.launch_kernel(_handle, stream_ptr, ...)
通过这种优化,PyTorch eager dispatch 的开销从大约 34µs 被压缩到了约 14µs。
六、一个完整的实战示例:RMS Norm
examples/hopper/rms_norm.py文件展示了一个典型的 pyptx 工作流——这是一个能够在 H100 上达到 88% HBM 带宽利用率的 RMS 归一化内核:
“`python
来源:examples/hopper/rms_norm.py(核心片段)
@kernel(
in_specs=(Tile(B, N, f32), Tile(N, f32)),
out_specs=(Tile(B, N, f32),),
grid=(B, 1, 1), block=(block, 1, 1), arch=”sm_90a”,
)
def rms_norm(X, W, Y):
partials = smem.alloc(f32, (num_warps, 1))
px, pw, py = ptx.global_ptrs(X, W, Y)
tid = reg.scalar(u32)
ptx.inst.mov.u32(tid, ptx.special.tid.x())
v4 向量化加载 + fma 累加平方和
sum_sq = reg.scalar(f32, init=0.0)
for j in range(v4_iters):
ptx.inst.ld.global_.v4.f32([x_vals[j4], …], ptx.addr(ptr))
for sub in range(4):
ptx.inst.fma.rn.f32(sum_sq, x_vals[j4+sub], x_vals[j*4+sub], sum_sq)
warp 级蝶式归约
ptx.warp.reduce_sum(sum_sq)
通过观察可以发现,Python 的 for 循环在执行追踪时会被自动展开,形成向量化的 ld.global.v4.f32 和 fma.rn.f32 指令序列。而 ptx.warp.reduce_sum 底层实现则是标准的蝶式归约操作(shfl.sync.bfly),每一次调用都会对应生成固定数量的 PTX 指令。
七、PTX 转译器:逆向工程的利器
pyptx 不仅支持从 Python 代码生成 PTX,还具备将现有 PTX 反向转译为 pyptx Python 代码的能力:
python -m pyptx.codegen kernel.ptx –sugar –name my_kernel > my_kernel.py
在 --sugar 模式下,该工具能够自动识别 spin-loop 并转换为 ptx.loop(...),将 mbarrier-wait 序列折叠为高层函数调用,同时整理表达式链。这个工具已在 218 多个来自 CUTLASS、Triton、DeepGEMM、ThunderKittens 等项目的 PTX 语料上通过了字节级别的往返一致性验证。
这意味着你可以获取 Triton 编译产生的 PTX,将其转译为 Python,在其中进行精确调整(例如修改流水线深度、变更 barrier 时序),然后直接执行——完全无需了解 Triton 内部的编译流程。
总结
pyptx 体现了 GPU 编程领域一种“反编译器”式的哲学立场:与其依赖日益复杂的编译器抽象来隐藏硬件细节,不如借助更强大的 DSL 工具,让程序员能够舒适地直接面对硬件。这套方案证明了以下几点:
- Python 的表达能力足以精确描述 PTX 级别的语义,无需依赖 C++ 或内联汇编
- 追踪式代码生成(而非 AST 变换)是实现“零抽象开销 DSL”的最简路径
- 一个内核对象同时服务于 JAX 和 PyTorch 是可行的工程实践
对于需要在 Hopper/Blackwell 架构上榨取极致性能的 AI 基础设施团队而言,pyptx 提供了一条独特的路径:既保留了 Python 生态的便利性,又不牺牲任何硬件控制精度。
参考资料[1] Fastest kernels written from scratch: https://github.com/pranjalssh/fast.cu
[2] examples/hopper: https://github.com/patrick-toulme/pyptx/tree/main/examples/hopper
[3] examples/blackwell: https://github.com/patrick-toulme/pyptx/tree/main/examples/blackwell
[4] pyptx.dev/performance: https://pyptx.dev/performance/
关注“鲸栖”小程序,掌握最新AI资讯
本文来自网络搜集,不代表鲸林向海立场,如有侵权,联系删除。转载请注明出处:https://www.itsolotime.com/archives/32399

