เขียน GPU Assembly ด้วย Python? pyptx ทำ 1240 TFLOPS บน Blackwell แซงหน้า cuBLAS

ในแวดวงการเขียนโปรแกรม GPU มีภาวะกลืนไม่เข้าคายไม่ออกที่น่าอึดอัดใจมานาน ด้านหนึ่ง การไล่ตามประสิทธิภาพสูงสุดต้องพึ่งพา CUDA C++ หรือแม้แต่การเขียน PTX Assembly โดยตรง อีกด้านหนึ่ง เพื่อเพิ่มประสิทธิภาพการพัฒนา นักพัฒนามักต้องยอมรับการปรับแต่งแบบกล่องดำและความไม่สามารถควบคุมได้ของโค้ดที่สร้างโดยอัตโนมัติจากคอมไพเลอร์อย่าง Triton, Pallas

เมื่อการจัดตารางคำสั่งที่สร้างโดย Triton ไม่เป็นไปตามที่คาดหวัง เมื่อคุณต้องปรับจังหวะ mbarrier, โหมด multicast ของ TMA หรือการจัดสรร register ของ tcgen05 อย่างละเอียด ทางออกเดียวแทบจะหนีไม่พ้นการถอยกลับไปสู่สถานะดั้งเดิมของการเขียน Inline Assembly ใน C++

โปรเจกต์ pyptx มอบคำตอบที่แตกต่างอย่างสิ้นเชิง นั่นคือการใช้ Syntactic Sugar ของ Python เพื่อห่อหุ้มคำสั่งแต่ละคำสั่งใน PTX อย่างแม่นยำ โดยไม่ทำการปรับแต่งใดๆ และไม่นำเสนอการแปลง IR ใดๆ ระหว่างทาง สิ่งที่คุณเขียนใน Python จะถูกสร้างขึ้นใน PTX ตามต้นฉบับทุกประการ การเรียกใช้ ptx.inst.add.f32() หนึ่งครั้ง จะตรงกับคำสั่ง add.f32 หนึ่งคำสั่งพอดี

บน Blackwell B200 มันทำประสิทธิภาพ GEMM ได้ถึง 1240 TFLOPS (คิดเป็น 77% ของ cuBLAS) บน Hopper H100 มันทำได้ถึง 815 TFLOPS และเหนือกว่า cuBLAS นี่ไม่ใช่ของเล่น มันเป็นเครื่องมือทางวิศวกรรมที่แท้จริง ที่ให้คุณใช้ประสบการณ์แบบโต้ตอบของ Python เพื่อเขียน GPU Assembly จริง

ประสิทธิภาพบน Blackwell (B200, bf16)

Kernel Shape pyptx cuBLAS best / cuBLAS
GEMM (tcgen05.mma, 4-stage pipeline, 1SM) 1240 TFLOPS 1610 77%
GEMM (1SM) 1194 TFLOPS 1532 78%
GEMM 2SM (cta_group::2, 5-stage) 649 TFLOPS (beats 1SM) 1006 64%
Grouped GEMM (tcgen05, MoE) G=4 M=2048 N=256 K=2048 401 TFLOPS torch ref ~10.0×
RMS norm / Layer norm / SwiGLU maintained Blackwell ports benchmarked torch ref see kernel suite

ประสิทธิภาพบน Hopper (H100 SXM5, bf16 / f32)

Kernel Shape pyptx vs reference
GEMM (wgmma, warp-specialized) 815 TFLOPS beats cuBLAS ≥ 6K
Grouped GEMM (bf16→f32) G=8 M=K=2048 104 TFLOPS
RMS norm (f32) B=2048 N=8192 2.6 TB/s (88% HBM) 3.9× torch
Layer norm (f32) B=2048 N=8192 2.5 TB/s (83% HBM) 1.5× F.layer_norm
SwiGLU (f32) M=2048 F=8192 2.8 TB/s (94% HBM) 1.6× F.silu(g)*u
Softmax (f32, row-wise) B=2048 N=8192 2.8 TB/s (95% HBM) 1.16× torch.softmax
Flash attention (bf16) M=N=4096, HD=64 88 µs 3.0× naive torch

หมายเหตุ: การใช้งาน GEMM 815 TFLOPS อยู่ใน examples/hopper/gemm_highperf_hopper.py ดูรายละเอียดได้ที่: Fastest kernels written from scratch[1]

สารบัญบทความนี้

  • เริ่มต้นใช้งานอย่างรวดเร็ว
  • หนึ่ง ภาพรวมสถาปัตยกรรมและปรัชญาการออกแบบ
  • 1.1 ข้อตกลงหลัก “หนึ่งการเรียกใช้ หนึ่งคำสั่ง”
  • 1.2 โครงสร้างโมดูลโดยรวม
  • สอง การสร้างโค้ดแบบติดตาม: เวทมนตร์จาก Python สู่ PTX
  • 2.1 TraceContext: ตัวรวบรวมคำสั่งระดับเธรด
  • 2.2 Kernel._trace(): วงจรชีวิตที่สมบูรณ์ของการติดตาม
  • สาม ระบบรีจิสเตอร์: ตัวดำเนินการทางคณิตศาสตร์คือคำสั่ง PTX
  • 3.1 การโอเวอร์โหลดตัวดำเนินการของคลาส Reg
  • 3.2 RegArray และการปรับแต่ง copy propagation
  • สี่ การสร้างแบบจำลองที่แม่นยำของหน่วยความจำที่ใช้ร่วมกันและดั้งเดิมของฮาร์ดแวร์
  • 4.1 การออกแบบสองโหมดของการจัดสรร SMEM
  • 4.2 การใช้งาน GMMA Swizzle ด้วย 3 คำสั่ง
  • ห้า การกระจายขณะรันไทม์: หนึ่งเคอร์เนล สามเส้นทาง
  • 5.1 จุดเข้าใช้งานรวมและตรรกะการกระจาย
  • 5.2 เส้นทางด่วน Turbo
  • หก ตัวอย่างการปฏิบัติจริงที่สมบูรณ์: RMS Norm
  • เจ็ด ตัวแปล PTX: เครื่องมืออันทรงพลังสำหรับวิศวกรรมย้อนกลับ
  • สรุป

เริ่มต้นใช้งานอย่างรวดเร็ว

# ติดตั้ง DSL หลัก (ไม่จำเป็นต้องใช้ GPU เพื่อสร้าง PTX)
pip install pyptx

# หากต้องการเปิดใช้เคอร์เนลใน PyTorch
pip install 'pyptx[torch]'

# หากต้องการเปิดใช้เคอร์เนลใน JAX
pip install 'pyptx[jax]'

# ไม่บังคับ: เร่งการคอมไพล์ JIT ของส่วนขยาย C++ PyTorch
pip install ninja

หลังจากติดตั้งเสร็จ คุณสามารถเขียนและดูโค้ดเคอร์เนลได้:

from pyptx import kernel, reg, ptx, Tile
from pyptx.types import f32, u32

@kernel(
in_specs=(Tile(4, 64, f32),),
out_specs=(Tile(4, 64, f32),),
grid=(4, 1, 1),
block=(64, 1, 1),
arch="sm_90a",
)
def identity(X, Y):
px, py = ptx.global_ptrs(X, Y)
tid = reg.scalar(u32)
ptx.inst.mov.u32(tid, ptx.special.tid.x())
val = reg.scalar(f32)
ptr = px + tid * 4
ptx.inst.ld.global_.f32(val, ptx.addr(ptr))
ptx.inst.st.global_.f32(ptx.addr(py + tid * 4), val)
ptx.ret()

# ดู PTX ที่สร้างขึ้น
print(identity.ptx())

ตัวอย่างที่สมบูรณ์เพิ่มเติมโปรดดูที่ examples/hopper/[2] และ examples/blackwell/[3] ข้อมูลประสิทธิภาพสามารถตรวจสอบได้ที่ pyptx.dev/performance[4]

หนึ่ง ภาพรวมสถาปัตยกรรมและปรัชญาการออกแบบ

1.1 ข้อตกลงหลัก “หนึ่งการเรียกใช้ หนึ่งคำสั่ง”

แนวคิดการออกแบบของ pyptx สามารถสรุปได้ในประโยคเดียว: ทุกการเรียกใช้ ptx.* ในฟังก์ชัน Python จะตรงกับหนึ่งคำสั่งในข้อความ PTX ที่สร้างขึ้นทุกประการ ไม่มี Optimizer, ไม่มี Scheduler, ไม่มีการแปลง IR ระดับกลาง นี่คือเครื่องมือเขียน PTX แบบ “เห็นอะไรได้อย่างนั้น”

คุณค่าของการออกแบบนี้คือ เมื่อคุณดีบักเคอร์เนล GEMM ประสิทธิภาพสูง คุณไม่จำเป็นต้องเดาว่าคอมไพเลอร์ทำการแปลงอะไรไปบ้าง สิ่งที่ print(my_kernel.ptx()) แสดงออกมา คือสิ่งที่คุณเขียน

1.2 โครงสร้างโมดูลโดยรวม

โค้ดหลักของโปรเจกต์มีประมาณ 360,000 บรรทัด Python แบ่งชั้นโครงสร้างอย่างชัดเจน:

โมดูล หน้าที่รับผิดชอบ
kernel.py ตัวตกแต่ง @kernel, การติดตาม, แคชเฉพาะทาง, การกระจายขณะรันไทม์
_trace.py บริบทการติดตามระดับเธรด โหนด IR ทั้งหมดมารวมกันที่นี่
ptx.py อินเทอร์เฟซ DSL สำหรับคำสั่ง PTX ทั้งหมด (166KB ครอบคลุม ISA ที่สมบูรณ์)
reg.py การจัดสรร register + Syntactic Sugar สำหรับตัวดำเนินการทางคณิตศาสตร์
smem.py การจัดสรรหน่วยความจำที่ใช้ร่วมกัน, mbarrier, GMMA swizzle
emitter/ การทำให้ IR → ข้อความ PTX เป็นอนุกรม
parser/ การดีซีเรียลไลซ์ข้อความ PTX → IR
codegen/ ตัวแปล PTX → Python (เครื่องมือวิศวกรรมย้อนกลับ)
jax_support.py การรวม JAX typed FFI
torch_support.py การรวม PyTorch eager + torch.compile

สอง การสร้างโค้ดแบบติดตาม: เวทมนตร์จาก Python สู่ PTX

2.1 TraceContext: ตัวรวบรวมคำสั่งระดับเธรด

การสร้างโค้ดของ pyptx ไม่ใช่การวิเคราะห์แบบคงที่หรือการแปลง AST มันคือ การติดตามขณะรันไทม์

เมื่อฟังก์ชันที่ตกแต่งด้วย @kernel ถูกเรียกใช้ TraceContext จะถูกเปิดใช้งาน หลังจากนั้นทุกการเรียกใช้ reg.*, smem.*, ptx.* จะ “บันทึก” โหนด IR ลงในบริบทนี้

# ที่มา: pyptx/_trace.py
class TraceContext:
def __init__(self, *, ptx_version: tuple[int, int] | None = None) -> None:
self.reg_decls: list[RegDecl] = []
self.var_decls: list[VarDecl] = []
self.statements: list[Statement] = []
self.dyn_smem_offset: int = 0
self.force_dynamic_smem: bool = False

def emit(self, stmt: Statement) -> None:
"""บันทึกคำสั่ง (instruction, label, ฯลฯ)"""
self.statements.append(stmt)

def body(self) -> tuple[Statement, ...]:
"""ส่งคืนเนื้อหาฟังก์ชันเต็ม: decls ตามด้วย statements"""
parts: list[Statement] = []
parts.extend(self.reg_decls)
parts.extend(self.var_decls)
parts.extend(self.statements)
return tuple(parts)

ความชาญฉลาดของการออกแบบนี้คือ: โฟลว์ควบคุมของ Python (for, if) จะถูก ขยายตามต้นฉบับ เป็นลำดับคำสั่งเชิงเส้นระหว่างการติดตาม ตัวอย่างเช่น ลูป for i in range(4) จะไม่สร้างลูป PTX แต่มันจะสร้างคำสั่งที่ถูกขยาย 4 ชุด ซึ่งแตกต่างอย่างสิ้นเชิงจากโมเดลการเขียนโปรแกรมระดับ Tile ของ Triton pyptx มอบการควบคุมระดับคำสั่งอย่างสมบูรณ์ให้กับคุณ

2.2 Kernel._trace(): วงจรชีวิตที่สมบูรณ์ของการติดตาม

Kernel._trace() เป็นศูนย์กลางหลักของทั้งระบบ ขั้นตอนการทำงานของมันคือ:

การแยกวิเคราะห์พารามิเตอร์และขั้นตอนการติดตาม

ในเมธอด _trace ของ pyptx กระบวนการติดตามเคอร์เนลถูกแบ่งออกเป็นหกขั้นตอนที่ชัดเจน:

  1. การแยกพารามิเตอร์: แยกพจนานุกรม kwargs ออกเป็นสองประเภท ได้แก่ พารามิเตอร์เทมเพลต (เช่น BM=128) และตัวแปรรูปร่าง (เช่น M=4096)
  2. การสร้าง Placeholder: สร้างออบเจ็กต์ TensorSpec สำหรับพารามิเตอร์ตำแหน่งแต่ละตัว ซึ่งมีข้อมูลรูปร่างและประเภทข้อมูลของเทนเซอร์
  3. การเปิดใช้งานบริบทการติดตาม: สร้างสภาพแวดล้อมบริบทระดับเธรดโดยการเรียกใช้ trace_scope()
  4. การดำเนินการฟังก์ชันผู้ใช้: ดำเนินการบรรทัดโค้ด self._fn(*positional, **resolved) ในเวลานี้ การดำเนินการทั้งหมดภายใน DSL จะถูกบันทึกลงในบริบทการติดตามปัจจุบันโดยอัตโนมัติ
  5. การประกอบ Module: ห่อหุ้มการประกาศ register, การประกาศหน่วยความจำที่ใช้ร่วมกัน และลำดับคำสั่งที่รวบรวมได้ระหว่างการติดตาม ให้เป็นออบเจ็กต์ IR Module ที่เป็นหนึ่งเดียว
  6. การจัดการหน่วยความจำที่ใช้ร่วมกันแบบไดนามิก: หากตรวจพบว่าการใช้หน่วยความจำที่ใช้ร่วมกันทั้งหมดเกิน 48KB ระบบจะสลับไปยังโหมดไดนามิกและดำเนินการติดตามอีกครั้ง

ต่อไปนี้คือโค้ดการใช้งานแบบย่อจาก pyptx/kernel.py:

def _trace(self, _shape_env=None, **kwargs):
# ... ตรรกะการแยกวิเคราะห์พารามิเตอร์ ...
while True:
with trace_scope(ptx_version=self._version) as ctx:
if force_dynamic_trace:
ctx.force_dynamic_smem = True
self._fn(*positional, **resolved)  # โค้ดผู้ใช้ทำงานที่นี่!

total_smem = ctx.dyn_smem_offset
if total_smem > 48 * 1024 and not ctx.force_dynamic_smem:
force_dynamic_trace = True
continue  # ติดตามใหม่
break

module = Module(
version=Version(...),
target=Target((self._arch,)),
address_size=AddressSize(64),
directives=(..., func),
)
return module

โปรดสังเกตการรวมกันของ while True + continue ที่นี่ ซึ่งเป็นกลไก “ค้นพบ-ลองใหม่” ที่ชาญฉลาดมาก ในการติดตามครั้งแรก ระบบอาจพบว่าหน่วยความจำที่ใช้ร่วมกันเกินขีดจำกัดฮาร์ดแวร์ 48KB ในเวลานี้ มันจะสลับไปยังโหมดหน่วยความจำที่ใช้ร่วมกันแบบไดนามิก extern โดยอัตโนมัติ จากนั้นดำเนินการติดตามอีกครั้ง เพื่อให้แน่ใจถึงความสอดคล้องของการคำนวณที่อยู่ทั้งหมด


สาม ระบบรีจิสเตอร์: ตัวดำเนินการทางคณิตศาสตร์คือคำสั่ง PTX

3.1 การโอเวอร์โหลดตัวดำเนินการของคลาส Reg

คลาส Reg ที่อยู่ใน reg.py โดยพื้นฐานแล้วคือ “การห่อหุ้มแบบบาง” ที่ออกแบบมาอย่างพิถีพิถัน ตัวดำเนินการทางคณิตศาสตร์ที่โอเวอร์โหลด +, *, <<, & ฯลฯ ทุกครั้งที่เรียกใช้จะปล่อยคำสั่ง PTX หนึ่งคำสั่งพอดี:

# ที่มา: pyptx/reg.py
class Reg:
def __add__(self, other: Any) -> "Reg":
return _emit_int_add(self, other)

def __mul__(self, other: Any) -> "Reg":
return _emit_int_mul(self, other)

ยกตัวอย่าง _emit_int_add เมื่อคุณเขียน px + offset ในโค้ดเคอร์เนล สิ่งที่เกิดขึ้นจริงคือ:

# ที่มา: pyptx/reg.py (แบบย่อ)
def _emit_int_add(left: Reg, right: Any) -> Reg:
ctx = get_ctx()
kind = _int_dtype_kind(left.dtype)
if kind == "64":
result = scalar(left.dtype)
right_op = _widen_to_64(ctx, right)
ctx.emit(Instruction(
opcode="add", modifiers=(".s64",),
operands=(
RegisterOperand(result.name),
RegisterOperand(left.name),
right_op,
),
))
return result

กล่าวอีกนัยหนึ่ง: การบวก Python หนึ่งครั้ง → คำสั่ง PTX add.s64 หนึ่งคำสั่ง ในกระบวนการทั้งหมดไม่มี Optimization Pass, ไม่มีการกำจัดนิพจน์ย่อยร่วม (CSE) และไม่มีการพับค่าคงที่ (Constant Folding) ข้อยกเว้นเพียงอย่างเดียวคือค่าคงที่ int ของ Python จะถูกจัดการเป็น Immediate Value โดยธรรมชาติ

3.2 RegArray และการปรับแต่ง copy propagation

เมธอด RegArray.__setitem__ มีเทคนิคการปรับแต่งเล็กๆ แต่ชาญฉลาด เมื่อคุณเขียนโค้ดเช่น r[90] = (r[89] << 7) ตัวดำเนินการ << ได้ปล่อยคำสั่ง shl.b32 %r_fresh, %r89, 7 ไปแล้ว ในเวลานี้ __setitem__ จะไม่ปล่อยคำสั่ง mov เพิ่มเติมอีก แต่จะ เขียนทับรีจิสเตอร์เป้าหมายของคำสั่งสุดท้าย โดยตรงเป็น %r90:

# ที่มา: pyptx/reg.py
def __setitem__(self, idx: int, value: "Reg") -> None:
ctx = get_ctx()
if ctx.statements:
last = ctx.statements[-1]
if (isinstance(last, Instruction)
and last.operands[0].name == value.name):
# เปลี่ยนชื่อรีจิสเตอร์เป้าหมายเป็นชื่อของ r[idx] โดยตรง
new_dst = RegisterOperand(name=f"{self._base}{idx}")
ctx.statements[-1] = replace(last, operands=(new_dst,) + last.operands[1:])
return
# fallback: emit mov

นี่คือ “การปรับแต่ง” เพียงอย่างเดียวที่มีอยู่ ในทั้งโปรเจกต์ และบทบาทของมันก็แค่กำจัดคำสั่ง mov ที่ซ้ำซ้อนซึ่งเกิดจาก Syntactic Sugar ของ DSL โดยไม่เปลี่ยนแปลงเจตนาทางความหมายของผู้ใช้แต่อย่างใด


สี่ การสร้างแบบจำลองที่แม่นยำของหน่วยความจำที่ใช้ร่วมกันและดั้งเดิมของฮาร์ดแวร์

4.1 การออกแบบสองโหมดของการจัดสรร SMEM

smem.alloc() มีกลยุทธ์การจัดสรรสองแบบ เมื่อความต้องการหน่วยความจำที่ใช้ร่วมกันทั้งหมดไม่เกิน 48KB ระบบจะใช้การประกาศ .shared แบบตั้งชื่อแบบคงที่ เมื่อเกินขีดจำกัด 48KB มันจะสลับไปยังโหมด extern .shared .b8 dyn_smem[] แบบรวมศูนย์โดยอัตโนมัติ ในเวลานี้ การจัดสรรหน่วยความจำทั้งหมดจะทำการระบุที่อยู่โดยการคำนวณออฟเซ็ต:

# ที่มา: pyptx/smem.py  
def alloc(dtype, shape, ...):  
ctx = get_ctx()  
off = ctx.dyn_smem_offset  
# การจัดการการจัดตำแหน่ง  
if align > 0 and off % align != 0:  
off = ((off + align - 1) // align) * align  
this_offset = off  
ctx.dyn_smem_offset = off + byte_count  

if ctx.force_dynamic_smem:  
return SharedAlloc("dyn_smem", dtype, shape, swizzle, byte_offset=this_offset)  
else:  
ctx.var_decls.append(VarDecl(...))  
return SharedAlloc(name, dtype, shape, swizzle, byte_offset=this_offset)  

4.2 การใช้งาน GMMA Swizzle ด้วยสามคำสั่ง

คำสั่ง WGMMA ในสถาปัตยกรรม Hopper และ Blackwell กำหนดให้ข้อมูลในหน่วยความจำที่ใช้ร่วมกันต้องเป็นไปตามรูปแบบการจัดเรียง Swizzle ที่เฉพาะเจาะจง เพื่อจำลองการแปลง Swizzle<B,4,3> ของ CuTe ใน pyptx นักพัฒนาใช้คำสั่ง ALU เพียงสามคำสั่ง:

# ที่มา: pyptx/smem.py  
_SWIZZLE_PARAMS = {  
"32B":  (0x080, 3),   # Swizzle<1,4,3>  
"64B":  (0x180, 3),   # Swizzle<2,4,3>  
"128B": (0x380, 3),   # Swizzle<3,4,3>  
}  

def apply_swizzle(logical_offset, swizzle):  
mask, shift = _SWIZZLE_PARAMS[swizzle]  
# xor_bits = (logical_offset & mask) >> shift  
# physical = logical_offset ^ xor_bits  
# 3 คำสั่งพอดี: and, shr, xor  

การใช้งานนี้สอดคล้องกับสูตรการคำนวณ physical = logical XOR ((logical & yyy_msk) >> S) ของ Swizzle<B, M, S> ในซอร์สโค้ด CUTLASS อย่างสมบูรณ์

ห้า การกระจายขณะรันไทม์: หนึ่งเคอร์เนล สามเส้นทาง

5.1 จุดเข้าใช้งานรวมและตรรกะการกระจาย

Kernel.__call__() เป็นจุดเข้าใช้งานรวมขณะรันไทม์ของทั้งระบบ มันจะเลือกเส้นทางการกระจายที่เหมาะสมโดยอัตโนมัติตามประเภทเฉพาะของเทนเซอร์อินพุต:

  • PyTorch Tensortorch_support.call_kernel_via_torch_compile()
  • JAX Arrayjax_support.call_kernel_via_ffi()

ขั้นตอนการคอมไพล์ JIT หลักประกอบด้วยขั้นตอนต่อไปนี้:

  1. แยก shape_env จากข้อมูลรูปร่างของเทนเซอร์อินพุต
  2. ติดตามและสร้างข้อความ PTX
  3. คอมไพล์ JIT โค้ด PTX เป็น cubin ผ่าน cuModuleLoadData
  4. ลงทะเบียนและแคชผลลัพธ์การคอมไพล์ใน cubin registry

5.2 เส้นทางด่วน Turbo

สำหรับสถานการณ์ที่ PyTorch เรียกใช้เคอร์เนลซ้ำๆ ด้วยรูปร่างอินพุตเดียวกันในลูปการอนุมาน pyptx ได้ออกแบบ “เส้นทางด่วน Turbo” เส้นทางนี้จะข้ามการตรวจสอบรูปร่างและตรรกะการกระจายทั้งหมดในระดับ Python และเรียกใช้ส่วนขยาย C++ ที่คอมไพล์ไว้ล่วงหน้าโดยตรง:

# ที่มา: pyptx/kernel.py  
if not kwargs and any_torch_tensors(input_arrays):  
turbo = getattr(self, '_turbo_torch', None)  
if turbo is not None:  
shapes_key = tuple(a.shape for a in input_arrays)  
if turbo[0] == shapes_key:  
# เรียกใช้เคอร์เนลผ่านส่วนขยาย C++ โดยตรง ~14µs  
return _ext.launch_kernel(_handle, stream_ptr, ...)  

ด้วยการปรับแต่งนี้ โอเวอร์เฮดของ PyTorch eager dispatch ถูกบีบอัดจากประมาณ 34µs เหลือประมาณ 14µs

หก ตัวอย่างการปฏิบัติจริงที่สมบูรณ์: RMS Norm

ไฟล์ examples/hopper/rms_norm.py แสดงเวิร์กโฟลว์ทั่วไปของ pyptx ซึ่งเป็นเคอร์เนลการทำให้เป็นมาตรฐาน RMS ที่สามารถบรรลุการใช้แบนด์วิดท์ HBM 88% บน H100:

“`python

ที่มา: examples/hopper/rms_norm.py (ส่วนย่อยหลัก)

@kernel(
in_specs=(Tile(B, N, f32), Tile(N, f32)),
out_specs=(Tile(B, N, f32),),
grid=(B, 1, 1), block=(block, 1, 1), arch=”sm_90a”,
)
def rms_norm(X, W, Y):
partials = smem.alloc(f32, (num_warps, 1))
px, pw, py = ptx.global_ptrs(X, W, Y)
tid = reg.scalar(u32)
ptx.inst.mov.u32(tid, ptx.special.tid.x())

v4 vectorized load + fma สะสมผลรวมกำลังสอง

sum_sq = reg.scalar(f32, init=0.0)
for j in range(v4_iters):
ptx.inst.ld.global_.v4.f32([x_vals[j4], …], ptx.addr(ptr))
for sub in range(4):
ptx.inst.fma.rn.f32(sum_sq, x_vals[j
4+sub], x_vals[j*4+sub], sum_sq)

warp-level butterfly reduction

ptx.warp.reduce_sum(sum_sq)

จะสังเกตได้ว่า ลูป for ของ Python จะถูกขยายโดยอัตโนมัติระหว่างการติดตาม ทำให้เกิดลำดับคำสั่ง ld.global.v4.f32 และ fma.rn.f32 แบบเวกเตอร์ ในขณะที่การใช้งานเบื้องหลังของ ptx.warp.reduce_sum คือการดำเนินการลดแบบผีเสื้อมาตรฐาน (shfl.sync.bfly) ซึ่งทุกครั้งที่เรียกใช้จะสร้างจำนวนคำสั่ง PTX ที่คงที่

เจ็ด ตัวแปล PTX: เครื่องมืออันทรงพลังสำหรับวิศวกรรมย้อนกลับ

pyptx ไม่เพียงแต่สนับสนุนการสร้าง PTX จากโค้ด Python เท่านั้น แต่ยังมีความสามารถในการ แปลย้อนกลับ PTX ที่มีอยู่เป็นโค้ด Python ของ pyptx:

python -m pyptx.codegen kernel.ptx –sugar –name my_kernel > my_kernel.py

ในโหมด --sugar เครื่องมือนี้สามารถจดจำ Spin-loop โดยอัตโนมัติและแปลงเป็น ptx.loop(...), ย่อลำดับ mbarrier-wait เป็นการเรียกใช้ฟังก์ชันระดับสูง และจัดระเบียบห่วงโซ่นิพจน์ เครื่องมือนี้ผ่านการตรวจสอบ ความสอดคล้องแบบไปกลับในระดับไบต์ กับคลัง PTX มากกว่า 218 รายการจากโปรเจกต์ต่างๆ เช่น CUTLASS, Triton, DeepGEMM, ThunderKittens

ซึ่งหมายความว่าคุณสามารถรับ PTX ที่คอมไพล์โดย Triton, แปลเป็น Python, ปรับแต่งอย่างแม่นยำในนั้น (เช่น แก้ไขความลึกของไปป์ไลน์, เปลี่ยนจังหวะ Barrier) จากนั้นดำเนินการโดยตรง โดยไม่จำเป็นต้องเข้าใจขั้นตอนการคอมไพล์ภายในของ Triton เลย

สรุป

pyptx แสดงให้เห็นถึงจุดยืนทางปรัชญาแบบ “ต่อต้านคอมไพเลอร์” ในแวดวงการเขียนโปรแกรม GPU นั่นคือ แทนที่จะพึ่งพานามธรรมของคอมไพเลอร์ที่ซับซ้อนมากขึ้นเพื่อซ่อนรายละเอียดฮ


⚠️ หมายเหตุ: เนื้อหาได้รับการแปลโดย AI และตรวจสอบโดยมนุษย์ หากมีข้อผิดพลาดโปรดแจ้ง

☕ สนับสนุนค่ากาแฟทีมงาน

หากคุณชอบบทความนี้ สามารถสนับสนุนเราได้ผ่าน PromptPay

PromptPay QR
SCAN TO PAY WITH ANY BANK

本文来自网络搜集,不代表คลื่นสร้างอนาคต立场,如有侵权,联系删除。转载请注明出处:https://www.itsolotime.com/th/archives/32400

Like (0)
Previous 1 hour ago
Next 1 hour ago

相关推荐