คำสำคัญ: TileKernels, TileLang, MoE Routing, Low-Precision Quantization, Operator Fusion
ในทางปฏิบัติของการฝึกอบรมและการอนุมานโมเดลขนาดใหญ่ ประสิทธิภาพของโอเปอเรเตอร์มักเป็นปัจจัยสำคัญที่กำหนดประสิทธิภาพโดยรวมของระบบ
โปรเจกต์ TileKernels ที่ DeepSeek เปิดเผยเป็นโอเพนซอร์สในเดือนเมษายน 2026 ได้ตอบสนองต่อความท้าทายนี้ในวิธีที่คาดไม่ถึง นั่นคือ ไม่ได้ใช้ CUDA C++ เลย แต่ใช้ภาษาเฉพาะทางในโดเมน Python อย่าง TileLang เพียงอย่างเดียว ก็สามารถทำให้ประสิทธิภาพของโอเปอเรเตอร์บนเส้นทางสำคัญของโมเดลขนาดใหญ่ เช่น MoE Routing, Multi-Precision Quantization (FP8/FP4/E5M6), SwiGLU Fusion, Engram Gating, Manifold HyperConnection ใกล้เคียงหรือถึงขีดจำกัดทางทฤษฎีของการคำนวณและแบนด์วิธของ GPU

- ที่อยู่โปรเจกต์: deepseek-ai/TileKernels (ไลบรารีเคอร์เนลที่เขียนด้วย tilelang)
- เวลาในการอ่าน: ประมาณ 4000 คำ / 20 นาที พร้อมเวอร์ชันพอดแคสต์ 19 นาที
ที่สำคัญกว่านั้น โอเปอเรเตอร์เหล่านี้ไม่ใช่ต้นแบบในห้องปฏิบัติการ พวกมันถูกใช้งานจริงในไปป์ไลน์การฝึกอบรมและการอนุมานภายในของ DeepSeek สิ่งนี้ทำให้เกิดคำถามสำคัญ: เมื่อ “การเขียนโอเปอเรเตอร์ GPU ด้วย Python” ไม่ใช่การประนีประนอมด้านประสิทธิภาพอีกต่อไป แต่เป็นทางเลือกทางวิศวกรรมที่เข้าใกล้ขีดจำกัดแล้ว กระบวนทัศน์การพัฒนาโครงสร้างพื้นฐานของโมเดลขนาดใหญ่กำลังถูกนิยามใหม่หรือไม่?
ผู้อ่านที่สังเกตดีอาจสังเกตเห็นว่าภายใต้องค์กรทางการของ TileLang อย่าง tile-ai ก็มีไลบรารีโอเปอเรเตอร์ที่คล้ายกันชื่อ TileOPs[1] แล้ว ความสัมพันธ์ระหว่างทั้งสองคืออะไร?
ความสัมพันธ์ระหว่างทั้งสามสามารถสรุปได้ดังนี้: TileLang (DSL Compiler) → TileOPs (Official General Operator Library) / TileKernels (DeepSeek-Specific Operator Library)
- TileOPs คือ “ไลบรารีตัวอย่างทางการ” ที่ทีม TileLang สร้างขึ้นเอง มีตำแหน่งคล้ายกับ
torchvisionในระบบนิเวศ PyTorch โดยให้บริการโอเปอเรเตอร์พื้นฐานทั่วไป เช่น GEMM, elementwise เน้นการออกแบบที่เป็นมาตรฐานแบบ Spec-driven (โอเปอเรเตอร์แต่ละตัวประกาศลายเซ็น โหลดงาน และสูตร roofline ผ่านops_manifest.yaml) มุ่งเป้าไปที่ความต้องการในการสร้างอัตโนมัติของนักพัฒนาชุมชนและ AI Agent - TileKernels นั้นแตกต่างอย่างสิ้นเชิง มันคือการที่ DeepSeek ในฐานะผู้ใช้งานหนักของ TileLang นำโอเปอเรเตอร์ที่สำคัญที่สุดและไวต่อประสิทธิภาพมากที่สุดในโมเดลของตนเองมาใช้และเปิดเผยเป็นโอเพนซอร์ส โอเปอเรเตอร์เหล่านี้มี “เอกลักษณ์ของ DeepSeek” อย่างชัดเจน เช่น MoE Routing แบบครบวงจร, Engram Gating, Manifold HyperConnection, SwiGLU+FP8 Fused Quantization ล้วนเป็นส่วนประกอบเฉพาะของสถาปัตยกรรมโมเดล DeepSeek ยึดหลักปฏิบัติจริงเป็นอันดับแรก Kernel คืออินเทอร์เฟซ ไม่มีเลเยอร์นามธรรมกลาง
| โปรเจกต์ | tile-ai/TileOPs | deepseek-ai/TileKernels |
| :— | :— | :— |
| ตำแหน่ง | โอเปอเรเตอร์ทั่วไป, Spec-driven | เฉพาะสำหรับโมเดล DeepSeek, ปฏิบัติจริงเป็นอันดับแรก |
| ประเภทโอเปอเรเตอร์ | โอเปอเรเตอร์พื้นฐาน เช่น GEMM, elementwise | MoE Routing, Engram Gating, mHC, Fused Quantization |
| แนวคิดการออกแบบ | แบ่งเป็น 2 ชั้น Op/Kernel, ประกาศด้วย manifest | Kernel คืออินเทอร์เฟซ, ใช้งานจริงในระบบผลิตแล้ว |
พูดง่ายๆ: tile-ai คือ “ผู้สร้างเครื่องยนต์” ในขณะที่ deepseek-ai คือ “ผู้สร้างรถยนต์ทั้งคัน” TileKernels เปรียบเสมือน DeepSeek เปิดเผยเทอร์โบชาร์จเจอร์ที่พัฒนาขึ้นเอง และยังเป็นการยืนยันความสามารถระดับการผลิตของเครื่องยนต์ TileLang ในทางกลับกัน
unsetunsetสารบัญunsetunset
- เริ่มต้นใช้งานอย่างรวดเร็ว
- หนึ่ง ภาพรวมสถาปัตยกรรมและปรัชญาการออกแบบ
- 1.1 โครงสร้างโปรเจกต์และการแบ่งโมดูล
- 1.2 แนวคิดการออกแบบหลัก: Declarative Tile Programming
- สอง โอเปอเรเตอร์ MoE Routing: จากการเลือก Top-K ไปจนถึง Fused Expand
- 2.1 Top-K Gating: การนำวิธีการหาค่าสูงสุดซ้ำๆ มาใช้อย่างชาญฉลาด
- 2.2 Fused Expand: การย้าย Token และ Scaling Factor ใน Kernel เดียว
- สาม โอเปอเรเตอร์ Quantization: Per-Token FP8 Casting และ SwiGLU Fusion
- 3.1 โครงสร้างพื้นฐาน Quantization: Config Abstraction แบบรวมศูนย์
- 3.2 กลยุทธ์การแบ่งบล็อกของ Per-Token Quantization Kernel
- 3.3 การผสาน SwiGLU และ Quantization อย่างถึงที่สุด
- สี่ Engram Gating: Asynchronous Pipeline ที่ได้รับการปรับแต่งอย่างสูง
- 4.1 Dual-Pass Asynchronous Pipeline
- 4.2 การ复用 Buffer ข้าม Pass
- 4.3 วิศวกรรมสุดยอดของ Backward Kernel
- ห้า Manifold HyperConnection: Sinkhorn Normalization บน GPU
- หก โอเปอเรเตอร์ Transpose: ตัวอย่างตำราเรียนในการกำจัด Bank Conflict
- เจ็ด สรุปและแนวโน้มในอนาคต

unsetunsetเริ่มต้นใช้งานอย่างรวดเร็วunsetunset
ข้อกำหนดของระบบ: Python ≥ 3.10, PyTorch ≥ 2.10, TileLang ≥ 0.1.9 และ GPU สถาปัตยกรรม NVIDIA SM90/SM100 (เช่น H100/B200) พร้อม CUDA Toolkit ≥ 13.1
bash
ติดตั้งเวอร์ชันเผยแพร่
pip install tile-kernels
หรือติดตั้งเวอร์ชันพัฒนา (รวม dependencies สำหรับทดสอบ)
pip install -e “.[dev]”
หลังจากติดตั้งเสร็จ สามารถเรียกใช้โดยตรงใน Python:
python
import torch
from tile_kernels.moe import topk_gate
from tile_kernels.quant import per_token_cast
from tile_kernels.transpose import transpose
MoE Top-K Gating Selection
scores = torch.randn(1024, 256, dtype=torch.float32, device=’cuda’)
topk_idx = topk_gate(scores, num_topk=8)
Per-Token FP8 Quantization
x = torch.randn(1024, 4096, dtype=torch.bfloat16, device=’cuda’)
out, out_sf = per_token_cast(x, fmt=’e4m3′, num_per_channels=128)
High-Performance Transpose
mat = torch.randn(2048, 4096, dtype=torch.bfloat16, device=’cuda’)
mat_t = transpose(mat)
รันการทดสอบและ benchmark:bash
pytest tests/transpose/test_transpose.py -n 4 # ความถูกต้อง
pytest tests/transpose/test_transpose.py --run-benchmark # ความถูกต้อง + ประสิทธิภาพ
รายละเอียดเพิ่มเติมเกี่ยวกับการใช้งานโมดูล MoE, Quantization, Engram ฯลฯ อ้างอิงจาก README.md
หนึ่ง ภาพรวมสถาปัตยกรรมและปรัชญาการออกแบบ
1.1 โครงสร้างโปรเจกต์และการแบ่งโมดูล
การจัดระเบียบโค้ดของ TileKernels สะอาดมาก โค้ดหลักทั้งหมดอยู่ในไดเรกทอรี tile_kernels/:tile_kernels/
├── moe/ # MoE Routing: Top-K Gating, Token-Expert Mapping, Fused Expand/Reduce
├── quant/ # Multi-Precision Quantization: FP8/FP4/E5M6 casting, รวมถึง SwiGLU Fusion
├── transpose/ # High-Performance Batched Transpose
├── engram/ # Engram Gating: Forward/Backward แบบ Fused RMSNorm
├── mhc/ # Manifold HyperConnection: Sinkhorn Normalization, Mix Split
├── modeling/ # ชั้น封装 PyTorch autograd.Function
├── torch/ # การอ้างอิงการใช้งาน PyTorch (ใช้เปรียบเทียบความถูกต้อง)
└── testing/ # เครื่องมือทดสอบและ benchmark
1.2 แนวคิดการออกแบบหลัก: Declarative Tile Programming
TileKernels สร้างขึ้นบน TileLang[2] ทั้งหมด
TileLang คือชุด DSL แบบฝังตัวใน Python ซึ่งนามธรรมหลักคือ Tile (บล็อกข้อมูล) และ Fragment (ส่วนของรีจิสเตอร์) นักพัฒนาอธิบายการเคลื่อนย้ายและการคำนวณข้อมูลระหว่าง shared memory และ register file ด้วยไวยากรณ์แบบประกาศ โดยคอมไพเลอร์จะจัดการการแมปเธรด การทำ vectorization การแทรก pipeline และการปรับแต่งระดับต่ำอื่นๆ
จากไฟล์การกำหนดค่าระดับโลกจะเห็นแนวคิดการออกแบบที่ “รับรู้ฮาร์ดแวร์” นี้:
python
ที่มา: tile_kernels/config.py
@functools.lru_cache(maxsize=None)
def get_device_num_sms() -> int:
prop = torch.cuda.get_device_properties(torch.cuda.current_device())
return prop.multi_processor_count
def get_num_sms() -> int:
global _num_sms
if _num_sms == 0:
return get_device_num_sms()
return _num_sms
จำนวน SM ถูกใช้เพื่อกำหนดจำนวนบล็อกของ persistent kernel การจัดสรรงบประมาณ shared memory และพารามิเตอร์สำคัญอื่นๆ แบบไดนามิก ซึ่งทำให้โค้ดชุดเดียวกันสามารถปรับให้เข้ากับ GPU ที่มีสเปกต่างกันได้โดยอัตโนมัติ
สอง โอเปอเรเตอร์ MoE Routing: จากการเลือก Top-K ไปจนถึง Fused Expand
MoE (Mixture of Experts) เป็นส่วนประกอบสถาปัตยกรรมหลักของโมเดล DeepSeek
tile_kernels/moe/ประกอบด้วยห่วงโซ่โอเปอเรเตอร์การจัดเส้นทางที่สมบูรณ์: Top-K Gating → Group Counting → Token-Expert Mapping → Fused Expand/Reduce → Weight Normalization
2.1 Top-K Gating: การนำวิธีการหาค่าสูงสุดซ้ำๆ มาใช้อย่างชาญฉลาด
ในฐานะบรรณาธิการเทคนิคมืออาชีพ ผมได้เขียนส่วนที่ระบุใหม่ตามที่คุณร้องขอ ต่อไปนี้คือเนื้อหาในรูปแบบ Markdown หลังจากล้างโฆษณา/คิวอาร์โค้ดและคงไว้ซึ่งตัวยึดตำแหน่ง [[IMAGE_X]]
ภารกิจหลักของ Top-K Gating คือการเลือก num_topk ผู้เชี่ยวชาญที่มีคะแนนสูงสุดสำหรับแต่ละ token จาก num_experts ผู้เชี่ยวชาญ TileKernels ใช้กลยุทธ์แบบวนซ้ำที่ใช้งานง่ายมาก: ทำซ้ำ K ครั้งด้วยการดำเนินการ “หาค่าสูงสุด → ทำเครื่องหมายเป็นค่าลบอนันต์”
python
ที่มา: tile_kernels/moe/topk_gate_kernel.py
@T.prim_func
def topk_gate_kernel(
scores: T.Tensor[(num_tokens, num_experts), T.float32],
topk_idx: T.Tensor[(num_tokens, num_topk), T.int64],
):
with T.Kernel(num_tokens, threads=num_threads) as pid:
scores_fragment = T.alloc_fragment((num_aligned_experts,), T.float32)
idx_reducer = T.alloc_reducer((1,), T.int32, ‘min’, replication=’all’)
# โหลดคะแนน ตำแหน่งที่เกินขอบเขตให้เติมค่าลบอนันต์
for i in T.Parallel(num_aligned_experts):
if i < num_experts:
scores_fragment[i] = scores[pid, i]
else:
scores_fragment[i] = -T.infinity(T.float32)
# ทำซ้ำ K ครั้ง: หาค่าสูงสุด → ใช้ดัชนีต่ำสุด (เสถียรเมื่อเสมอ) → ตั้งเป็นค่าลบอนันต์
for k in T.unroll(num_topk):
T.reduce_max(scores_fragment, amax_fragment)
T.fill(idx_reducer, T.max_value(T.int32))
for i in T.Parallel(num_aligned_experts):
if scores_fragment[i] == amax_fragment[0]:
idx_reducer[0] = T.min(idx_reducer[0], idx_fragment[i])
T.finalize_reducer(idx_reducer)
topk_idx_shared[k] = idx_reducer[0]
# ผู้เชี่ยวชาญที่ถูกเลือกแล้วตั้งเป็นค่าลบอนันต์
for i in T.Parallel(num_aligned_experts):
if idx_fragment[i] == idx_reducer[0]:
scores_fragment[i] = -T.infinity(T.float32)
มีสองจุดที่ชาญฉลาดที่นี่:
- ประการแรก
T.unroll(num_topk)จะคลี่ลูปออกอย่างสมบูรณ์ กำจัดค่าใช้จ่ายในการทำนายสาขา - ประการที่สอง เมื่อมีคะแนนเท่ากัน “เสมอ” การใช้
T.alloc_reducer('min')จะทำให้แน่ใจว่าจะเลือกผู้เชี่ยวชาญที่มีดัชนีน้อยที่สุด เสมอ ทำให้ผลลัพธ์มีเสถียรภาพและสามารถทำซ้ำได้
โอเปอเรเตอร์ทั้งหมดใช้เพียง warp เดียว (32 เธรด) ซึ่งสอดคล้องกับ threads=32 และสามารถใส่ลงใน register file ได้อย่างสมบูรณ์แบบเมื่อจำนวนผู้เชี่ยวชาญไม่เกินสองสามร้อย
2.2 Fused Expand: การย้าย Token และ Scaling Factor ใน Kernel เดียว
หลังจากเลือกผู้เชี่ยวชาญ Top-K แล้ว จำเป็นต้อง “ขยาย” ค่า activation ของแต่ละ token ไปยังช่องของผู้เชี่ยวชาญที่เกี่ยวข้องตามผลการจัดเส้นทาง expand_to_fused_kernel ดำเนินการนี้และรองรับการย้าย scaling factor หลัง quantization พร้อมกัน:
python
ที่มา: tile_kernels/moe/expand_to_fused_kernel.py
for k in T.serial(num_topk):
T.assume(pos_local[k] < num_expanded_tokens)
if pos_local[k] >= 0:
for i in T.Parallel(hidden_aligned):
expanded_x[pos_local[k], i] = x_fragment[i]
if num_per_channels is not None:
for i in T.Parallel(hidden_sf_aligned):
if use_tma_aligned_col_major_sf:
expanded_x_sf[i, pos_local[k]] = x_sf_fragment[i]
else:
expanded_x_sf[pos_local[k], i] = x_sf_fragment[i]
T.assume() คือคำใบ้คอมไพเลอร์ที่ TileLang จัดให้ ซึ่งบอกแบ็กเอนด์ว่า “เงื่อนไขนี้เป็นจริงเสมอ” เพื่อให้คอมไพเลอร์สามารถกำจัดการตรวจสอบขอบเขตได้ ในขณะเดียวกัน โค้ดใช้ T.Kernel(T.max(num_tokens, num_expanded_tokens)) เพื่อรวม “การเติมตำแหน่งที่ไม่ถูกต้องด้วยศูนย์” และ “การคัดลอกข้อมูลที่ถูกต้อง” ไว้ในชุดบล็อกเดียวกัน หลีกเลี่ยงการเปิดใช้เคอร์เนลเพิ่มเติม
สาม โอเปอเรเตอร์ Quantization: Per-Token FP8 Casting และ SwiGLU Fusion
3.1 โครงสร้างพื้นฐาน Quantization: Config Abstraction แบบรวมศูนย์
ในฐานะบรรณาธิการเทคนิคมืออาชีพ ผมได้เขียนต้นฉบับใหม่ตามที่คุณร้องขอ ต่อไปนี้คือส่วนที่ 4/6 หลังจากล้างเนื้อหาโฆษณา/คิวอาร์โค้ดแล้ว และคงไว้ซึ่งตัวยึดตำแหน่ง [[IMAGE_X]]
แกนหลักการออกแบบของโมดูล Quantization
แกนหลักการออกแบบของโมดูล Quantization คือ dataclass สองตัวคือ CastInputConfig และ CastOutputConfig ซึ่งอธิบายรูปแบบต่างๆ ทั้งหมดอย่างเป็นเอกภาพ เช่น ชนิดข้อมูลอินพุต/เอาต์พุต ขนาด scaling block, การใช้ SF แบบ column-major ที่สอดคล้องกับ TMA หรือไม่, การใช้รูปแบบ packed UE8M0 หรือไม่:
python
ที่มา: tile_kernels/quant/common.py
@dataclass(frozen=True)
class CastOutputConfig(BaseCastConfig):
round_sf: bool = False
custom_clamp_min_value: Optional[float] = None
@property
def clamp_min_value(self) -> float:
if self.custom_clamp_min_value is not None:
return self.custom_clamp_min_value
elif self.dtype == T.float8_e4m3fn:
return 1e-4
elif self.dtype == T.float4_e2m1fn:
return T.max_value(self.dtype) * (2**-126)
clamp_min_value ทำให้แน่ใจว่า scaling factor จะไม่เล็กเกินไปจนทำให้ค่า quantization ล้น สำหรับ FP4 (E2M1) ขอบเขตล่างนี้ถูกตั้งค่าอย่างแม่นยำเป็น max_value * 2^(-126) ซึ่งเป็นขอบเขตของ FP32 denorm แสดงให้เห็นถึงความเข้าใจอย่างลึกซึ้งเกี่ยวกับการแทนค่าจุดลอยตัว
3.2 กลยุทธ์การแบ่งบล็อกของ Per-Token Quantization Kernel
per_token_cast_kernel เป็นโอเปอเรเตอร์หลักที่สุดในโมดูล quantization มันแปลงเมทริกซ์ BF16/FP32 ขนาด [num_tokens, hidden] เป็น FP8/FP4 ในที่ พร้อมกับสร้าง scaling factor แบบ per-group
python
ที่มา: tile_kernels/quant/per_token_cast_kernel.py
with T.Kernel(T.ceildiv(num_tokens, block_m), T.ceildiv(hidden, block_k),
threads=num_threads) as (pid_token, pid_hidden):
x_fragment = T.alloc_fragment((block_m, block_k), in_config.dtype)
T.annotate_layout({
x_fragment: T.Fragment(
(block_m, block_k),
forward_fn=x_layout_fn,
)
})
# 1. โหลดข้อมูลไปยัง register
T.copy(x[pid_token * block_m, pid_hidden * block_k], x_fragment, disable_tma=True)
# 2. Reduce หา absmax
amax_fragment = T.alloc_fragment((block_m, num_groups), in_config.dtype)
x_fragment_reshaped = T.reshape(x_fragment, [block_m, num_groups, num_per_channels])
T.reduce_absmax(x_fragment_reshaped, amax_fragment, dim=2)
# 3. คำนวณ SF และจัดเก็บ
for i, j in T.Parallel(block_m, num_groups):
sf, sf_inv = get_sf_and_inv(amax, out_config)
store_sf(out_sf, sf, m_idx, k_idx, out_config)
sf_inv_fragment[i, j] = sf_inv
# 4. คูณด้วย inverse ของ SF และเขียนออก
for i, j in T.Parallel(block_m, block_k):
out_shared[i, j] = x_fragment[i, j] * sf_inv_fragment[i, j // num_per_channels]
การ T.annotate_layout ที่นี่กำหนดวิธีการแมป fragment ไปยังเธรดใน register file แบบกำหนดเอง x_layout_fn โดยมีจุดประสงค์เพื่อให้ 128 เธรดแต่ละเธรดรับผิดชอบการโหลดแบบ vectorized ของ 32 องค์ประกอบที่ต่อเนื่องกัน T.reshape จะตีความ fragment ใหม่เป็นมุมมองสามมิติในเวลาคอมไพล์ ทำให้ T.reduce_absmax(dim=2) สามารถลดขนาดตาม粒度 num_per_channels ได้โดยตรง ทั้งหมดนี้เสร็จสิ้นภายใน register โดยไม่มีการเข้าถึงหน่วยความจำเพิ่มเติม
3.3 การผสาน SwiGLU และ Quantization อย่างถึงที่สุด
เลเยอร์ FFN ของโมเดลขนาดใหญ่มักใช้ฟังก์ชัน激活 SwiGLU การใช้งานแบบพื้นฐานต้องใช้การเปิดใช้เคอร์เนลสามครั้ง (SwiGLU → คำนวณ SF → Quantization) แต่ละครั้งต้องอ่านและเขียนข้อมูลทั้งหมดหนึ่งรอบ TileKernels ผสานสามขั้นตอนนี้เข้าเป็นเคอร์เนลเดียว:
python
ที่มา: tile_kernels/quant/swiglu_forward_and_per_token_cast_kernel.py
SwiGLU + clamp + การคูณน้ำหนักแบบเลือกได้
val_l = T.float32(xl_fragment[i, j])
val_r = T.float32(xr_fragment[i, j])
if use_clamp:
val_l = T.min(val_l, swiglu_clamp_value)
val_r = T.max(T.min(val_r, swiglu_clamp_value), -swiglu_clamp_value)
if with_weight:
val = val_l / (1 + T.exp(-val_l)) * val_r * topk_weights_fragment[i]
else:
val = val_l / (1 + T.exp(-val_l)) * val_r
จากนั้นทำ per-group absmax → SF → quantization เขียนออกในที่
T.reduce_absmax(x_fragment_reshaped, sf_inv_fragment, dim=2)

《DeepSeek เปิดเผย TileKernels แบบโอเพนซอร์ส: โอเปอเรเตอร์ GPU ที่เขียนด้วย Python เข้าใกล้ขีดจำกัดประสิทธิภาพฮาร์ดแวร์》—— ส่วนที่ 5/6
สังเกตว่า val_l / (1 + T.exp(-val_l)) คือคำจำกัดความทางคณิตศาสตร์ของ SiLU (Swish) การดำเนินการ SwiGLU ทั้งหมด, การนับ clamp แบบเลือกได้, การคูณน้ำหนัก Top-K, การลดขนาด absmax, การคำนวณ SF และการ cast FP8 ทั้งหมดเสร็จสิ้นในชุด register เดียวกันแบบ pipeline สำหรับสถานการณ์การฝึกที่ต้องการนับจำนวนครั้งของการ clamp ยังมีการรวมการนับข้ามบล็อกผ่าน T.alloc_reducer('sum') และ T.atomic_add โดยใช้กลยุทธ์ persistent kernel เพื่อหลีกเลี่ยงการเปิดใช้เพิ่มเติม
สี่ Engram Gating: Asynchronous Pipeline ที่ได้รับการปรับแต่งอย่างสูง
Engram Gating เป็นโอเปอเรเตอร์ที่มีความซับซ้อนทางวิศวกรรมมากที่สุดใน TileKernels โดย forward kernel
engram_gate_fwd_kernelแสดงให้เห็นการออกแบบ GPU pipeline ระดับตำราเรียน
4.1 Dual-Pass Asynchronous Pipeline
กระบวนการ forward แบ่งออกเป็นสอง pass: Pass 1 คำนวณ gate score (密集型 Reduction) และ Pass 2 ส่งออก x + gate * v (密集型 Memory Access) ทั้งสองซ้อนทับกันผ่าน cp.async Asynchronous Copy:
python
ที่มา: tile_kernels/engram/engram_gate_kernel.py
Pass 1: โหลด x และ k แบบ double buffer ผ่าน cp.async pipeline
for i_b in T.Serial(1, num_blk):
phase = i_b % 2
T.async_copy(hidden_states[i_s, pid_h, i_b * blk_d:(i_b+1) * blk_d],
x_smem[i_b * blk_d:(i_b+1) * blk_d])
T.async_copy(k[i_s, pid_h, i_b * blk_d:(i_b+1) * blk_d],
kv_smem[phase, :])
T.ptx_wait_group(2) # รอให้กลุ่ม async สูงสุด 2 กลุ่มเสร็จ
# คำนวณ rstd_x, rstd_k, gate_score (dot product + weighted)
for i_k in T.serial(vec_size):
rstd_x_local[0] += x_local[i_k] * x_local[i_k]
rstd_k_local[0] += k_local[i_k] * k_local[i_k]
gate_score_local[0] += x_local[i_k] * w_local[i_k] * k_local[i_k]
รายละเอียดสำคัญอยู่ที่การใช้ T.ptx_wait_group(2) ซึ่งแมปโดยตรงกับคำสั่ง cp.async.wait_group ของ PTX ทำให้สตรีมการคำนวณสามารถคงการคัดลอกแบบอะซิงโครนัสที่กำลังดำเนินการอยู่ได้สูงสุด 2 รายการ ทำให้การเคลื่อนย้ายข้อมูลและการดำเนินการ multiply-add ซ้อนทับกันอย่างสมบูรณ์
4.2 การ复用 Buffer ข้าม Pass
ในช่วงท้ายของ Pass 1 และต้นของ Pass 2 โค้ด复用 kv_smem เพื่อดึงข้อมูลเวกเตอร์ v ล่วงหน้า:
python
ท้าย Pass 1: 复用 bank kv_smem ที่ถูกปล่อยเพื่อดึง v[0] ล่วงหน้า
T.async_copy(v[i_s, 0:blk_d], kv_smem[v_start_phase, :])
Pass 2: ใช้ x_smem (ยังใช้ได้) และ kv_smem (บรรจุ v แล้ว) เพื่อเขียน output
for i_k in T.vectorized(vec_size):
output[i_s, pid_h, sub_base + thread_idx * vec_size + i_k] =
x_local[i_k] + gate_score_reducer[0] * v_local[i_k]
กลยุทธ์การสลับบัฟเฟอร์แบบ “ใช้ shared memory เพิ่มเติมเป็นศูนย์” นี้ ช่วยลดการใช้ shared memory ให้เหลือน้อยที่สุด ทำให้แต่ละ SM สามารถเก็บบล็อกได้มากขึ้นพร้อมกัน เพิ่ม occupancy
4.3 วิศวกรรมสุดยอดของ Backward Kernel
Backward kernel engram_gate_bwd_kernel ใช้ 8 warp (256 เธรด) โดยทุก 2 warp ทำงานร่วมกันเพื่อประมวลผลหนึ่ง head (hc_mult=4) สิ่งที่โดดเด่นที่สุดคือ กลยุทธ์การสะสม grad_w ใน register: warp pair แต่ละคู่จะรักษาองค์ประกอบ grad_w จำนวน hidden_size / threads_per_head ไว้ใน register อย่างสมบูรณ์ จะสะสมจนครบทุก token แล้วจึงเขียนกลับไปยัง global memory เพียงครั้งเดียว ช่วยลดความต้องการแบนด์วิธการเขียนให้เหลือน้อยที่สุด
ห้า Manifold HyperConnection: Sinkhorn Normalization บน GPU
tile_kernels/mhc/sinkhorn_kernel.py ใช้การแพร่กระจายไปข้างหน้าและย้อนกลับของ Sinkhorn Normalization โดยพื้นฐานแล้ว Sinkhorn Normalization คือการสลับกันทำ row normalization และ column normalization ซึ่งสามารถเข้าใจได้ว่าเป็นการ “ทำให้เมทริกซ์กลายเป็น double stochastic matrix”
python
ที่มา: tile_kernels/mhc/sinkhorn_kernel.py
Initial softmax + eps
T.reduce_max(comb_frag, row_max, dim=2)
for i, j, k in T.Parallel(token_block_size, hidden_size, hidden_size):
comb_frag[i, j, k] = T.exp(comb_frag[i, j, k] – row_max[i, j])
T.reduce_sum(comb_frag, row_sum, dim=2)
for i, j, k in T.Parallel(token_block_size, hidden_size, hidden_size):
comb_frag[i, j, k] = comb_frag[i, j, k] / row_sum[i, j] + eps
สลับ normalization ซ้ำ repeat ครั้ง
for _ in T.serial(repeat – 1):
T.reduce_sum(comb_frag, row_sum, dim=2) # Row normalization
…
T.reduce_sum(comb_frag, col_sum, dim=1) # Column normalization
…
จุดเด่นของ backward kernel คือ: มันเก็บผลลัพธ์กลางของแต่ละขั้นตอนในกระบวนการ forward ไว้ในอาร์เรย์ xs และ sums ของ shared memory จากนั้นจึงวนซ้ำ “checkpoint” เหล่านี้ในลำดับย้อนกลับเพื่อคำนวณเกรเดียนต์ กลยุทธ์การคำนวณใหม่แบบ all-register + shared memory นี้หลีกเลี่ยงการเขียนสถานะกลางกลับไปยัง HBM ทำให้ forward + backward ทั้งหมดอยู่ใน chip เมื่อ hidden_size มีขนาดเล็ก (เช่น
⚠️ หมายเหตุ: เนื้อหาได้รับการแปลโดย AI และตรวจสอบโดยมนุษย์ หากมีข้อผิดพลาดโปรดแจ้ง
☕ สนับสนุนค่ากาแฟทีมงาน
หากคุณชอบบทความนี้ สามารถสนับสนุนเราได้ผ่าน PromptPay
本文来自网络搜集,不代表คลื่นสร้างอนาคต立场,如有侵权,联系删除。转载请注明出处:https://www.itsolotime.com/th/archives/31750
