ในแวดวงการเขียนโปรแกรม GPU มีภาวะกลืนไม่เข้าคายไม่ออกที่น่าอึดอัดใจมานาน ด้านหนึ่ง การไล่ตามประสิทธิภาพสูงสุดต้องพึ่งพา CUDA C++ หรือแม้แต่การเขียน PTX Assembly โดยตรง อีกด้านหนึ่ง เพื่อเพิ่มประสิทธิภาพการพัฒนา นักพัฒนามักต้องยอมรับการปรับแต่งแบบกล่องดำและความไม่สามารถควบคุมได้ของโค้ดที่สร้างโดยอัตโนมัติจากคอมไพเลอร์อย่าง Triton, Pallas
เมื่อการจัดตารางคำสั่งที่สร้างโดย Triton ไม่เป็นไปตามที่คาดหวัง เมื่อคุณต้องปรับจังหวะ mbarrier, โหมด multicast ของ TMA หรือการจัดสรร register ของ tcgen05 อย่างละเอียด ทางออกเดียวแทบจะหนีไม่พ้นการถอยกลับไปสู่สถานะดั้งเดิมของการเขียน Inline Assembly ใน C++
โปรเจกต์ pyptx มอบคำตอบที่แตกต่างอย่างสิ้นเชิง นั่นคือการใช้ Syntactic Sugar ของ Python เพื่อห่อหุ้มคำสั่งแต่ละคำสั่งใน PTX อย่างแม่นยำ โดยไม่ทำการปรับแต่งใดๆ และไม่นำเสนอการแปลง IR ใดๆ ระหว่างทาง สิ่งที่คุณเขียนใน Python จะถูกสร้างขึ้นใน PTX ตามต้นฉบับทุกประการ การเรียกใช้ ptx.inst.add.f32() หนึ่งครั้ง จะตรงกับคำสั่ง add.f32 หนึ่งคำสั่งพอดี
บน Blackwell B200 มันทำประสิทธิภาพ GEMM ได้ถึง 1240 TFLOPS (คิดเป็น 77% ของ cuBLAS) บน Hopper H100 มันทำได้ถึง 815 TFLOPS และเหนือกว่า cuBLAS นี่ไม่ใช่ของเล่น มันเป็นเครื่องมือทางวิศวกรรมที่แท้จริง ที่ให้คุณใช้ประสบการณ์แบบโต้ตอบของ Python เพื่อเขียน GPU Assembly จริง
ประสิทธิภาพบน Blackwell (B200, bf16)
| Kernel | Shape | pyptx | cuBLAS | best / cuBLAS |
|---|---|---|---|---|
| GEMM (tcgen05.mma, 4-stage pipeline, 1SM) | 1240 TFLOPS | 1610 | 77% | |
| GEMM (1SM) | 1194 TFLOPS | 1532 | 78% | |
| GEMM 2SM (cta_group::2, 5-stage) | 649 TFLOPS (beats 1SM) | 1006 | 64% | |
| Grouped GEMM (tcgen05, MoE) | G=4 M=2048 N=256 K=2048 | 401 TFLOPS | torch ref | ~10.0× |
| RMS norm / Layer norm / SwiGLU | maintained Blackwell ports | benchmarked | torch ref | see kernel suite |
ประสิทธิภาพบน Hopper (H100 SXM5, bf16 / f32)
| Kernel | Shape | pyptx | vs reference |
|---|---|---|---|
| GEMM (wgmma, warp-specialized) | 815 TFLOPS | beats cuBLAS ≥ 6K | |
| Grouped GEMM (bf16→f32) | G=8 M=K=2048 | 104 TFLOPS | — |
| RMS norm (f32) | B=2048 N=8192 | 2.6 TB/s (88% HBM) | 3.9× torch |
| Layer norm (f32) | B=2048 N=8192 | 2.5 TB/s (83% HBM) | 1.5× F.layer_norm |
| SwiGLU (f32) | M=2048 F=8192 | 2.8 TB/s (94% HBM) | 1.6× F.silu(g)*u |
| Softmax (f32, row-wise) | B=2048 N=8192 | 2.8 TB/s (95% HBM) | 1.16× torch.softmax |
| Flash attention (bf16) | M=N=4096, HD=64 | 88 µs | 3.0× naive torch |
หมายเหตุ: การใช้งาน GEMM 815 TFLOPS อยู่ใน examples/hopper/gemm_highperf_hopper.py ดูรายละเอียดได้ที่: Fastest kernels written from scratch[1]
สารบัญบทความนี้
- เริ่มต้นใช้งานอย่างรวดเร็ว
- หนึ่ง ภาพรวมสถาปัตยกรรมและปรัชญาการออกแบบ
- 1.1 ข้อตกลงหลัก “หนึ่งการเรียกใช้ หนึ่งคำสั่ง”
- 1.2 โครงสร้างโมดูลโดยรวม
- สอง การสร้างโค้ดแบบติดตาม: เวทมนตร์จาก Python สู่ PTX
- 2.1 TraceContext: ตัวรวบรวมคำสั่งระดับเธรด
- 2.2 Kernel._trace(): วงจรชีวิตที่สมบูรณ์ของการติดตาม
- สาม ระบบรีจิสเตอร์: ตัวดำเนินการทางคณิตศาสตร์คือคำสั่ง PTX
- 3.1 การโอเวอร์โหลดตัวดำเนินการของคลาส Reg
- 3.2 RegArray และการปรับแต่ง copy propagation
- สี่ การสร้างแบบจำลองที่แม่นยำของหน่วยความจำที่ใช้ร่วมกันและดั้งเดิมของฮาร์ดแวร์
- 4.1 การออกแบบสองโหมดของการจัดสรร SMEM
- 4.2 การใช้งาน GMMA Swizzle ด้วย 3 คำสั่ง
- ห้า การกระจายขณะรันไทม์: หนึ่งเคอร์เนล สามเส้นทาง
- 5.1 จุดเข้าใช้งานรวมและตรรกะการกระจาย
- 5.2 เส้นทางด่วน Turbo
- หก ตัวอย่างการปฏิบัติจริงที่สมบูรณ์: RMS Norm
- เจ็ด ตัวแปล PTX: เครื่องมืออันทรงพลังสำหรับวิศวกรรมย้อนกลับ
- สรุป
เริ่มต้นใช้งานอย่างรวดเร็ว
# ติดตั้ง DSL หลัก (ไม่จำเป็นต้องใช้ GPU เพื่อสร้าง PTX)
pip install pyptx
# หากต้องการเปิดใช้เคอร์เนลใน PyTorch
pip install 'pyptx[torch]'
# หากต้องการเปิดใช้เคอร์เนลใน JAX
pip install 'pyptx[jax]'
# ไม่บังคับ: เร่งการคอมไพล์ JIT ของส่วนขยาย C++ PyTorch
pip install ninja
หลังจากติดตั้งเสร็จ คุณสามารถเขียนและดูโค้ดเคอร์เนลได้:
from pyptx import kernel, reg, ptx, Tile
from pyptx.types import f32, u32
@kernel(
in_specs=(Tile(4, 64, f32),),
out_specs=(Tile(4, 64, f32),),
grid=(4, 1, 1),
block=(64, 1, 1),
arch="sm_90a",
)
def identity(X, Y):
px, py = ptx.global_ptrs(X, Y)
tid = reg.scalar(u32)
ptx.inst.mov.u32(tid, ptx.special.tid.x())
val = reg.scalar(f32)
ptr = px + tid * 4
ptx.inst.ld.global_.f32(val, ptx.addr(ptr))
ptx.inst.st.global_.f32(ptx.addr(py + tid * 4), val)
ptx.ret()
# ดู PTX ที่สร้างขึ้น
print(identity.ptx())
ตัวอย่างที่สมบูรณ์เพิ่มเติมโปรดดูที่ examples/hopper/[2] และ examples/blackwell/[3] ข้อมูลประสิทธิภาพสามารถตรวจสอบได้ที่ pyptx.dev/performance[4]
หนึ่ง ภาพรวมสถาปัตยกรรมและปรัชญาการออกแบบ
1.1 ข้อตกลงหลัก “หนึ่งการเรียกใช้ หนึ่งคำสั่ง”
แนวคิดการออกแบบของ pyptx สามารถสรุปได้ในประโยคเดียว: ทุกการเรียกใช้
ptx.*ในฟังก์ชัน Python จะตรงกับหนึ่งคำสั่งในข้อความ PTX ที่สร้างขึ้นทุกประการ ไม่มี Optimizer, ไม่มี Scheduler, ไม่มีการแปลง IR ระดับกลาง นี่คือเครื่องมือเขียน PTX แบบ “เห็นอะไรได้อย่างนั้น”
คุณค่าของการออกแบบนี้คือ เมื่อคุณดีบักเคอร์เนล GEMM ประสิทธิภาพสูง คุณไม่จำเป็นต้องเดาว่าคอมไพเลอร์ทำการแปลงอะไรไปบ้าง สิ่งที่ print(my_kernel.ptx()) แสดงออกมา คือสิ่งที่คุณเขียน
1.2 โครงสร้างโมดูลโดยรวม
โค้ดหลักของโปรเจกต์มีประมาณ 360,000 บรรทัด Python แบ่งชั้นโครงสร้างอย่างชัดเจน:
| โมดูล | หน้าที่รับผิดชอบ |
|---|---|
kernel.py |
ตัวตกแต่ง @kernel, การติดตาม, แคชเฉพาะทาง, การกระจายขณะรันไทม์ |
_trace.py |
บริบทการติดตามระดับเธรด โหนด IR ทั้งหมดมารวมกันที่นี่ |
ptx.py |
อินเทอร์เฟซ DSL สำหรับคำสั่ง PTX ทั้งหมด (166KB ครอบคลุม ISA ที่สมบูรณ์) |
reg.py |
การจัดสรร register + Syntactic Sugar สำหรับตัวดำเนินการทางคณิตศาสตร์ |
smem.py |
การจัดสรรหน่วยความจำที่ใช้ร่วมกัน, mbarrier, GMMA swizzle |
emitter/ |
การทำให้ IR → ข้อความ PTX เป็นอนุกรม |
parser/ |
การดีซีเรียลไลซ์ข้อความ PTX → IR |
codegen/ |
ตัวแปล PTX → Python (เครื่องมือวิศวกรรมย้อนกลับ) |
jax_support.py |
การรวม JAX typed FFI |
torch_support.py |
การรวม PyTorch eager + torch.compile |
สอง การสร้างโค้ดแบบติดตาม: เวทมนตร์จาก Python สู่ PTX
2.1 TraceContext: ตัวรวบรวมคำสั่งระดับเธรด
การสร้างโค้ดของ pyptx ไม่ใช่การวิเคราะห์แบบคงที่หรือการแปลง AST มันคือ การติดตามขณะรันไทม์
เมื่อฟังก์ชันที่ตกแต่งด้วย @kernel ถูกเรียกใช้ TraceContext จะถูกเปิดใช้งาน หลังจากนั้นทุกการเรียกใช้ reg.*, smem.*, ptx.* จะ “บันทึก” โหนด IR ลงในบริบทนี้
# ที่มา: pyptx/_trace.py
class TraceContext:
def __init__(self, *, ptx_version: tuple[int, int] | None = None) -> None:
self.reg_decls: list[RegDecl] = []
self.var_decls: list[VarDecl] = []
self.statements: list[Statement] = []
self.dyn_smem_offset: int = 0
self.force_dynamic_smem: bool = False
def emit(self, stmt: Statement) -> None:
"""บันทึกคำสั่ง (instruction, label, ฯลฯ)"""
self.statements.append(stmt)
def body(self) -> tuple[Statement, ...]:
"""ส่งคืนเนื้อหาฟังก์ชันเต็ม: decls ตามด้วย statements"""
parts: list[Statement] = []
parts.extend(self.reg_decls)
parts.extend(self.var_decls)
parts.extend(self.statements)
return tuple(parts)
ความชาญฉลาดของการออกแบบนี้คือ: โฟลว์ควบคุมของ Python (for, if) จะถูก ขยายตามต้นฉบับ เป็นลำดับคำสั่งเชิงเส้นระหว่างการติดตาม ตัวอย่างเช่น ลูป for i in range(4) จะไม่สร้างลูป PTX แต่มันจะสร้างคำสั่งที่ถูกขยาย 4 ชุด ซึ่งแตกต่างอย่างสิ้นเชิงจากโมเดลการเขียนโปรแกรมระดับ Tile ของ Triton pyptx มอบการควบคุมระดับคำสั่งอย่างสมบูรณ์ให้กับคุณ
2.2 Kernel._trace(): วงจรชีวิตที่สมบูรณ์ของการติดตาม
Kernel._trace()เป็นศูนย์กลางหลักของทั้งระบบ ขั้นตอนการทำงานของมันคือ:
การแยกวิเคราะห์พารามิเตอร์และขั้นตอนการติดตาม
ในเมธอด _trace ของ pyptx กระบวนการติดตามเคอร์เนลถูกแบ่งออกเป็นหกขั้นตอนที่ชัดเจน:
- การแยกพารามิเตอร์: แยกพจนานุกรม
kwargsออกเป็นสองประเภท ได้แก่ พารามิเตอร์เทมเพลต (เช่นBM=128) และตัวแปรรูปร่าง (เช่นM=4096) - การสร้าง Placeholder: สร้างออบเจ็กต์
TensorSpecสำหรับพารามิเตอร์ตำแหน่งแต่ละตัว ซึ่งมีข้อมูลรูปร่างและประเภทข้อมูลของเทนเซอร์ - การเปิดใช้งานบริบทการติดตาม: สร้างสภาพแวดล้อมบริบทระดับเธรดโดยการเรียกใช้
trace_scope() - การดำเนินการฟังก์ชันผู้ใช้: ดำเนินการบรรทัดโค้ด
self._fn(*positional, **resolved)ในเวลานี้ การดำเนินการทั้งหมดภายใน DSL จะถูกบันทึกลงในบริบทการติดตามปัจจุบันโดยอัตโนมัติ - การประกอบ Module: ห่อหุ้มการประกาศ register, การประกาศหน่วยความจำที่ใช้ร่วมกัน และลำดับคำสั่งที่รวบรวมได้ระหว่างการติดตาม ให้เป็นออบเจ็กต์ IR Module ที่เป็นหนึ่งเดียว
- การจัดการหน่วยความจำที่ใช้ร่วมกันแบบไดนามิก: หากตรวจพบว่าการใช้หน่วยความจำที่ใช้ร่วมกันทั้งหมดเกิน 48KB ระบบจะสลับไปยังโหมดไดนามิกและดำเนินการติดตามอีกครั้ง
ต่อไปนี้คือโค้ดการใช้งานแบบย่อจาก pyptx/kernel.py:
def _trace(self, _shape_env=None, **kwargs):
# ... ตรรกะการแยกวิเคราะห์พารามิเตอร์ ...
while True:
with trace_scope(ptx_version=self._version) as ctx:
if force_dynamic_trace:
ctx.force_dynamic_smem = True
self._fn(*positional, **resolved) # โค้ดผู้ใช้ทำงานที่นี่!
total_smem = ctx.dyn_smem_offset
if total_smem > 48 * 1024 and not ctx.force_dynamic_smem:
force_dynamic_trace = True
continue # ติดตามใหม่
break
module = Module(
version=Version(...),
target=Target((self._arch,)),
address_size=AddressSize(64),
directives=(..., func),
)
return module
โปรดสังเกตการรวมกันของ while True + continue ที่นี่ ซึ่งเป็นกลไก “ค้นพบ-ลองใหม่” ที่ชาญฉลาดมาก ในการติดตามครั้งแรก ระบบอาจพบว่าหน่วยความจำที่ใช้ร่วมกันเกินขีดจำกัดฮาร์ดแวร์ 48KB ในเวลานี้ มันจะสลับไปยังโหมดหน่วยความจำที่ใช้ร่วมกันแบบไดนามิก extern โดยอัตโนมัติ จากนั้นดำเนินการติดตามอีกครั้ง เพื่อให้แน่ใจถึงความสอดคล้องของการคำนวณที่อยู่ทั้งหมด
สาม ระบบรีจิสเตอร์: ตัวดำเนินการทางคณิตศาสตร์คือคำสั่ง PTX
3.1 การโอเวอร์โหลดตัวดำเนินการของคลาส Reg
คลาส Reg ที่อยู่ใน reg.py โดยพื้นฐานแล้วคือ “การห่อหุ้มแบบบาง” ที่ออกแบบมาอย่างพิถีพิถัน ตัวดำเนินการทางคณิตศาสตร์ที่โอเวอร์โหลด +, *, <<, & ฯลฯ ทุกครั้งที่เรียกใช้จะปล่อยคำสั่ง PTX หนึ่งคำสั่งพอดี:
# ที่มา: pyptx/reg.py
class Reg:
def __add__(self, other: Any) -> "Reg":
return _emit_int_add(self, other)
def __mul__(self, other: Any) -> "Reg":
return _emit_int_mul(self, other)
ยกตัวอย่าง _emit_int_add เมื่อคุณเขียน px + offset ในโค้ดเคอร์เนล สิ่งที่เกิดขึ้นจริงคือ:
# ที่มา: pyptx/reg.py (แบบย่อ)
def _emit_int_add(left: Reg, right: Any) -> Reg:
ctx = get_ctx()
kind = _int_dtype_kind(left.dtype)
if kind == "64":
result = scalar(left.dtype)
right_op = _widen_to_64(ctx, right)
ctx.emit(Instruction(
opcode="add", modifiers=(".s64",),
operands=(
RegisterOperand(result.name),
RegisterOperand(left.name),
right_op,
),
))
return result
กล่าวอีกนัยหนึ่ง: การบวก Python หนึ่งครั้ง → คำสั่ง PTX add.s64 หนึ่งคำสั่ง ในกระบวนการทั้งหมดไม่มี Optimization Pass, ไม่มีการกำจัดนิพจน์ย่อยร่วม (CSE) และไม่มีการพับค่าคงที่ (Constant Folding) ข้อยกเว้นเพียงอย่างเดียวคือค่าคงที่ int ของ Python จะถูกจัดการเป็น Immediate Value โดยธรรมชาติ
3.2 RegArray และการปรับแต่ง copy propagation
เมธอด RegArray.__setitem__ มีเทคนิคการปรับแต่งเล็กๆ แต่ชาญฉลาด เมื่อคุณเขียนโค้ดเช่น r[90] = (r[89] << 7) ตัวดำเนินการ << ได้ปล่อยคำสั่ง shl.b32 %r_fresh, %r89, 7 ไปแล้ว ในเวลานี้ __setitem__ จะไม่ปล่อยคำสั่ง mov เพิ่มเติมอีก แต่จะ เขียนทับรีจิสเตอร์เป้าหมายของคำสั่งสุดท้าย โดยตรงเป็น %r90:
# ที่มา: pyptx/reg.py
def __setitem__(self, idx: int, value: "Reg") -> None:
ctx = get_ctx()
if ctx.statements:
last = ctx.statements[-1]
if (isinstance(last, Instruction)
and last.operands[0].name == value.name):
# เปลี่ยนชื่อรีจิสเตอร์เป้าหมายเป็นชื่อของ r[idx] โดยตรง
new_dst = RegisterOperand(name=f"{self._base}{idx}")
ctx.statements[-1] = replace(last, operands=(new_dst,) + last.operands[1:])
return
# fallback: emit mov
นี่คือ “การปรับแต่ง” เพียงอย่างเดียวที่มีอยู่ ในทั้งโปรเจกต์ และบทบาทของมันก็แค่กำจัดคำสั่ง mov ที่ซ้ำซ้อนซึ่งเกิดจาก Syntactic Sugar ของ DSL โดยไม่เปลี่ยนแปลงเจตนาทางความหมายของผู้ใช้แต่อย่างใด
สี่ การสร้างแบบจำลองที่แม่นยำของหน่วยความจำที่ใช้ร่วมกันและดั้งเดิมของฮาร์ดแวร์
4.1 การออกแบบสองโหมดของการจัดสรร SMEM
smem.alloc() มีกลยุทธ์การจัดสรรสองแบบ เมื่อความต้องการหน่วยความจำที่ใช้ร่วมกันทั้งหมดไม่เกิน 48KB ระบบจะใช้การประกาศ .shared แบบตั้งชื่อแบบคงที่ เมื่อเกินขีดจำกัด 48KB มันจะสลับไปยังโหมด extern .shared .b8 dyn_smem[] แบบรวมศูนย์โดยอัตโนมัติ ในเวลานี้ การจัดสรรหน่วยความจำทั้งหมดจะทำการระบุที่อยู่โดยการคำนวณออฟเซ็ต:
# ที่มา: pyptx/smem.py
def alloc(dtype, shape, ...):
ctx = get_ctx()
off = ctx.dyn_smem_offset
# การจัดการการจัดตำแหน่ง
if align > 0 and off % align != 0:
off = ((off + align - 1) // align) * align
this_offset = off
ctx.dyn_smem_offset = off + byte_count
if ctx.force_dynamic_smem:
return SharedAlloc("dyn_smem", dtype, shape, swizzle, byte_offset=this_offset)
else:
ctx.var_decls.append(VarDecl(...))
return SharedAlloc(name, dtype, shape, swizzle, byte_offset=this_offset)
4.2 การใช้งาน GMMA Swizzle ด้วยสามคำสั่ง
คำสั่ง WGMMA ในสถาปัตยกรรม Hopper และ Blackwell กำหนดให้ข้อมูลในหน่วยความจำที่ใช้ร่วมกันต้องเป็นไปตามรูปแบบการจัดเรียง Swizzle ที่เฉพาะเจาะจง เพื่อจำลองการแปลง Swizzle<B,4,3> ของ CuTe ใน pyptx นักพัฒนาใช้คำสั่ง ALU เพียงสามคำสั่ง:
# ที่มา: pyptx/smem.py
_SWIZZLE_PARAMS = {
"32B": (0x080, 3), # Swizzle<1,4,3>
"64B": (0x180, 3), # Swizzle<2,4,3>
"128B": (0x380, 3), # Swizzle<3,4,3>
}
def apply_swizzle(logical_offset, swizzle):
mask, shift = _SWIZZLE_PARAMS[swizzle]
# xor_bits = (logical_offset & mask) >> shift
# physical = logical_offset ^ xor_bits
# 3 คำสั่งพอดี: and, shr, xor
การใช้งานนี้สอดคล้องกับสูตรการคำนวณ physical = logical XOR ((logical & yyy_msk) >> S) ของ Swizzle<B, M, S> ในซอร์สโค้ด CUTLASS อย่างสมบูรณ์
ห้า การกระจายขณะรันไทม์: หนึ่งเคอร์เนล สามเส้นทาง
5.1 จุดเข้าใช้งานรวมและตรรกะการกระจาย
Kernel.__call__()เป็นจุดเข้าใช้งานรวมขณะรันไทม์ของทั้งระบบ มันจะเลือกเส้นทางการกระจายที่เหมาะสมโดยอัตโนมัติตามประเภทเฉพาะของเทนเซอร์อินพุต:
- PyTorch Tensor →
torch_support.call_kernel_via_torch_compile() - JAX Array →
jax_support.call_kernel_via_ffi()
ขั้นตอนการคอมไพล์ JIT หลักประกอบด้วยขั้นตอนต่อไปนี้:
- แยก shape_env จากข้อมูลรูปร่างของเทนเซอร์อินพุต
- ติดตามและสร้างข้อความ PTX
- คอมไพล์ JIT โค้ด PTX เป็น cubin ผ่าน
cuModuleLoadData - ลงทะเบียนและแคชผลลัพธ์การคอมไพล์ใน cubin registry
5.2 เส้นทางด่วน Turbo
สำหรับสถานการณ์ที่ PyTorch เรียกใช้เคอร์เนลซ้ำๆ ด้วยรูปร่างอินพุตเดียวกันในลูปการอนุมาน pyptx ได้ออกแบบ “เส้นทางด่วน Turbo” เส้นทางนี้จะข้ามการตรวจสอบรูปร่างและตรรกะการกระจายทั้งหมดในระดับ Python และเรียกใช้ส่วนขยาย C++ ที่คอมไพล์ไว้ล่วงหน้าโดยตรง:
# ที่มา: pyptx/kernel.py
if not kwargs and any_torch_tensors(input_arrays):
turbo = getattr(self, '_turbo_torch', None)
if turbo is not None:
shapes_key = tuple(a.shape for a in input_arrays)
if turbo[0] == shapes_key:
# เรียกใช้เคอร์เนลผ่านส่วนขยาย C++ โดยตรง ~14µs
return _ext.launch_kernel(_handle, stream_ptr, ...)
ด้วยการปรับแต่งนี้ โอเวอร์เฮดของ PyTorch eager dispatch ถูกบีบอัดจากประมาณ 34µs เหลือประมาณ 14µs
หก ตัวอย่างการปฏิบัติจริงที่สมบูรณ์: RMS Norm
ไฟล์
examples/hopper/rms_norm.pyแสดงเวิร์กโฟลว์ทั่วไปของ pyptx ซึ่งเป็นเคอร์เนลการทำให้เป็นมาตรฐาน RMS ที่สามารถบรรลุการใช้แบนด์วิดท์ HBM 88% บน H100:
“`python
ที่มา: examples/hopper/rms_norm.py (ส่วนย่อยหลัก)
@kernel(
in_specs=(Tile(B, N, f32), Tile(N, f32)),
out_specs=(Tile(B, N, f32),),
grid=(B, 1, 1), block=(block, 1, 1), arch=”sm_90a”,
)
def rms_norm(X, W, Y):
partials = smem.alloc(f32, (num_warps, 1))
px, pw, py = ptx.global_ptrs(X, W, Y)
tid = reg.scalar(u32)
ptx.inst.mov.u32(tid, ptx.special.tid.x())
v4 vectorized load + fma สะสมผลรวมกำลังสอง
sum_sq = reg.scalar(f32, init=0.0)
for j in range(v4_iters):
ptx.inst.ld.global_.v4.f32([x_vals[j4], …], ptx.addr(ptr))
for sub in range(4):
ptx.inst.fma.rn.f32(sum_sq, x_vals[j4+sub], x_vals[j*4+sub], sum_sq)
warp-level butterfly reduction
ptx.warp.reduce_sum(sum_sq)
จะสังเกตได้ว่า ลูป for ของ Python จะถูกขยายโดยอัตโนมัติระหว่างการติดตาม ทำให้เกิดลำดับคำสั่ง ld.global.v4.f32 และ fma.rn.f32 แบบเวกเตอร์ ในขณะที่การใช้งานเบื้องหลังของ ptx.warp.reduce_sum คือการดำเนินการลดแบบผีเสื้อมาตรฐาน (shfl.sync.bfly) ซึ่งทุกครั้งที่เรียกใช้จะสร้างจำนวนคำสั่ง PTX ที่คงที่
เจ็ด ตัวแปล PTX: เครื่องมืออันทรงพลังสำหรับวิศวกรรมย้อนกลับ
pyptx ไม่เพียงแต่สนับสนุนการสร้าง PTX จากโค้ด Python เท่านั้น แต่ยังมีความสามารถในการ แปลย้อนกลับ PTX ที่มีอยู่เป็นโค้ด Python ของ pyptx:
python -m pyptx.codegen kernel.ptx –sugar –name my_kernel > my_kernel.py
ในโหมด --sugar เครื่องมือนี้สามารถจดจำ Spin-loop โดยอัตโนมัติและแปลงเป็น ptx.loop(...), ย่อลำดับ mbarrier-wait เป็นการเรียกใช้ฟังก์ชันระดับสูง และจัดระเบียบห่วงโซ่นิพจน์ เครื่องมือนี้ผ่านการตรวจสอบ ความสอดคล้องแบบไปกลับในระดับไบต์ กับคลัง PTX มากกว่า 218 รายการจากโปรเจกต์ต่างๆ เช่น CUTLASS, Triton, DeepGEMM, ThunderKittens
ซึ่งหมายความว่าคุณสามารถรับ PTX ที่คอมไพล์โดย Triton, แปลเป็น Python, ปรับแต่งอย่างแม่นยำในนั้น (เช่น แก้ไขความลึกของไปป์ไลน์, เปลี่ยนจังหวะ Barrier) จากนั้นดำเนินการโดยตรง โดยไม่จำเป็นต้องเข้าใจขั้นตอนการคอมไพล์ภายในของ Triton เลย
สรุป
pyptx แสดงให้เห็นถึงจุดยืนทางปรัชญาแบบ “ต่อต้านคอมไพเลอร์” ในแวดวงการเขียนโปรแกรม GPU นั่นคือ แทนที่จะพึ่งพานามธรรมของคอมไพเลอร์ที่ซับซ้อนมากขึ้นเพื่อซ่อนรายละเอียดฮ
⚠️ หมายเหตุ: เนื้อหาได้รับการแปลโดย AI และตรวจสอบโดยมนุษย์ หากมีข้อผิดพลาดโปรดแจ้ง
☕ สนับสนุนค่ากาแฟทีมงาน
หากคุณชอบบทความนี้ สามารถสนับสนุนเราได้ผ่าน PromptPay
本文来自网络搜集,不代表คลื่นสร้างอนาคต立场,如有侵权,联系删除。转载请注明出处:https://www.itsolotime.com/th/archives/32400
