เราได้เผยแพร่บทความเกี่ยวกับ Mega Kernel หลายครั้งก่อนหน้านี้ วันนี้เราจะมาสำรวจบทความนี้: 《ไม่ต้องสร้าง MegaKernels ด้วยมือ! Luminal คอมไพล์สร้าง MegaKernels: แก้ปัญหา GPU SM โหลดไม่สมดุล กำจัดค่าใช้จ่ายในการเริ่มเคอร์เนลและฟองหน่วยความจำ รองรับทุกสถาปัตยกรรม!》。ผู้เขียน เจิ้ง ฉีหัง ได้วิเคราะห์เชิงลึกเกี่ยวกับคอมไพเลอร์โอเพนซอร์ส Luminal และรวมกับการทดสอบจริงของ gemma-3-4b ที่ทำงานบน H200 เพื่อจัดระเบียบการออกแบบ IR และกลไกการค้นหา
ขั้นตอนการคอมไพล์แบ่งออกเป็นหกขั้นตอน: ส่วนหน้าใช้ GraphTensor เพื่ออธิบายกระบวนการคำนวณ ใช้ ShapeTracker เพื่อบันทึกข้อมูลเค้าโครงของเทนเซอร์ ซึ่งช่วยขจัดการดำเนินการเกี่ยวกับรูปร่างที่ชัดเจนจำนวนมาก และสุดท้ายสร้าง IR ระดับสูงที่มีเพียง 20 การดำเนินการพื้นฐาน (primop) ที่เรียกว่า HLIR กราฟการคำนวณทั้งหมดถูกแบ่งออกเป็นหลาย chunk ตาม graph_break และ chunk ที่มีโครงสร้างเหมือนกันจะถูกรวมเป็น group ต่อจากนั้น การค้นหา egraph saturation จะดำเนินการในหน่วยของ group เพื่อสร้างโซลูชันตัวเลือก รวมถึงการเรียก CUDA kernel และ cuBLAS จากนั้นจึงคัดเลือกการใช้งานที่ดีที่สุดโดยการวัดเวลาแฝงจริง IR ระดับต่ำ (LLIR) ที่สกัดออกมาจะถูกแทนที่ด้วยพารามิเตอร์เทมเพลตและคอมไพล์ทันทีด้วย NVRTC เพื่อสร้างโค้ดที่ปฏิบัติการได้บน GPU สุดท้าย Runtime จะสร้าง CUDA Graph เพื่อดำเนินการอนุมาน
ผลการทดสอบจริงแสดงให้เห็นว่า ปริมาณงานอนุมาน fp32 ของ Luminal ต่ำกว่า vLLM มาก และในปัจจุบันยังไม่ได้ใช้งานการรวม “การส่งออก FlashAttention อัตโนมัติ” ตามที่โฆษณาไว้ ผู้เขียนชี้ให้เห็นว่าคอมไพเลอร์นี้ขาดคำอธิบายลำดับชั้นหน่วยความจำและการปรับแต่ง tiling และเชื่อว่ามีช่องว่างระหว่างเป้าหมายที่โฆษณาและความคืบหน้าที่แท้จริง และตั้งคำถามเกี่ยวกับตำแหน่ง “คอมไพเลอร์” ของมัน
ฉันสังเกตเห็นคอมไพเลอร์ Luminal ตั้งแต่ปีที่แล้ว มันอ้างว่าสามารถบรรลุประสิทธิภาพสูงสุด 80% ผ่านการคอมไพล์อัตโนมัติเต็มรูปแบบ และสามารถค้นหา FlashAttention ได้ หลังจากนั้นก็ได้รับเงินทุน เมื่อเร็ว ๆ นี้ ฉันได้รันตัวอย่าง gemma-3-4b จริง ๆ และใช้โอกาสนี้จัดระเบียบการออกแบบ IR และกลไกการค้นหาของมัน
หนึ่ง ภาพรวม
ขั้นตอนการคอมไพล์ของ Luminal แบ่งออกเป็นหกขั้นตอนโดยประมาณ:
- Frontend: ผู้ใช้เขียนโอเปอเรเตอร์ (เช่น
matmul,softmax) ผ่าน APIGraphTensorการดำเนินการเช่นexpand_dimหรือpermuteในส่วนหน้าจะแก้ไขเฉพาะข้อมูลเมตาShapeTrackerที่แนบมากับเทนเซอร์เท่านั้น และจะไม่สร้าง op ใหม่ ดังนั้น จำนวน op ในกราฟ HLIR สุดท้ายจึงน้อยกว่าโหนดในนิพจน์ของผู้ใช้มาก - HLIR: การดำเนินการส่วนหน้าจะรวมตัวกันเป็น DAG ของเทนเซอร์ที่ประกอบด้วย primop 20 ตัว ซึ่งก็คือ IR ระดับสูงของ Luminal เอง
- Partition / Group: ตาม
graph_breakที่แทรกโดยส่วนหน้า HLIR ทั้งหมดจะถูกตัดเป็น chunk หลายชิ้น จากนั้น chunk ที่มีโครงสร้างเหมือนกันจะถูกรวมเป็น group เดียวที่ไม่ซ้ำกัน ขั้นตอนต่อ ๆ ไปจะดำเนินการในหน่วยของ group - Egglog saturation: ทำให้แต่ละ group เป็น serialized เป็นโปรแกรม egglog และดำเนินการค้นหา saturation ความสัมพันธ์ที่เท่าเทียมกัน สำหรับโมเดล 4B กระบวนการนี้บน CPU แกนเดียวใช้เวลาประมาณ 30 นาที ซึ่งเป็นสาเหตุหลักของค่าใช้จ่ายในการคอมไพล์
- Extraction / LLIR: ขั้นแรกให้แยกโซลูชันตัวเลือกจาก egraph ที่อิ่มตัวแล้ว จากนั้นลดระดับ (lower) ลงเป็น LLIR
- Codegen / Runtime: แต่ละโหนด LLIR จะสร้าง CUDA kernel (หรือการเรียก cuBLAS) ก่อนผ่าน Codegen จากนั้น Runtime จะเชื่อมต่อเข้าด้วยกันเป็น CUDA Graph เพื่อดำเนินการอนุมาน โดยรวมแล้วมันคล้ายกับกระบวนการ JIT มากกว่า: การคอมไพล์ kernel เกิดขึ้นทันทีในขั้นตอน Codegen การจัดสรร buffer จะเสร็จสิ้นเมื่อ Runtime เริ่มทำงาน ไม่ใช่วิธี AOT ที่คอมไพล์ล่วงหน้าทั้งหมด
สอง HLIR
HLIR คือ IR เทนเซอร์ระดับสูงของ Luminal ซึ่งมีเพียง 20 primop ซึ่งแสดงถึงการดำเนินการอะตอมที่เล็กที่สุด HLIR ของโมเดล Gemma 3 4B หนึ่งตัวมี primop ประมาณ 5000 ตัว
primop 20 ตัวนี้สามารถแบ่งออกเป็นเจ็ดประเภท:
| ประเภท | การดำเนินการ |
|---|---|
| I/O | Input, Output, Constant |
| DType / Range | Cast, Iota |
| Unary | Exp2, Log2, Sin, Recip, Sqrt |
| Binary | Add, Mul, Mod, LessThan |
| Reduction | SumReduce, MaxReduce, Softmax |
| Indexing | Gather, Scatter |
| Fallback | CustomOpKind |
ยกตัวอย่างการคูณเมทริกซ์: a: [M, K] @ b: [K, N] -> [M, N] ไม่มีลูป for k ที่ชัดเจนใน HLIR โค้ดส่วนหน้าที่สอดคล้องกันมีดังนี้:
// src/frontend/matmul.rs
let mul = self.expand_dim(1, n) * rhs.permute((1, 0)).expand_dim(0, m);
let ret = mul.sum(2);
เมื่อแทนที่ shape จริงแล้ว HLIR ที่สร้างขึ้นจะมีเพียง 5 โหนด
แนวคิดการออกแบบนี้คล้ายกับ Jittor อย่างมาก โดยทั้งคู่ขยายเลย์เอาต์เพื่อแสดงพื้นที่ลูป สังเกตโหนด Mul และ SumReduce ข้างต้น: โหนด input มี rank 2 มิติ ในขณะที่ Mul ใช้ dims=[2, 4, 3] และ strides ของอินพุตทั้งสองคือ [(z*3), 0, z] และ [0, z, (z*4)] (โดยที่ z แทน sizeof(dtype)) 0 ใน stride เกิดจากมิติ broadcast ที่สร้างโดย expand_dim ไม่มี op การดำเนินการ Shape อิสระในระบบ ฟังก์ชันที่เกี่ยวข้องส่วนใหญ่จะดำเนินการผ่าน ShapeTracker
สิ่งที่ควรสังเกตคือ Softmax ไม่ได้ถูกแยกย่อยเป็นชุดของ Exp2 + SumReduce + Div ซึ่งน่าจะเป็นการออกแบบที่เลือกเพื่อให้การ rewrite และ pattern match ในภายหลังสะดวกยิ่งขึ้น
2.1 ShapeTracker
หน้าที่หลักของ
ShapeTrackerคือการแทนที่การดำเนินการExpand/Reshape/Permuteที่ชัดเจน สามารถเข้าใจได้ดังนี้: ขั้นแรกจะบันทึกข้อมูล Layout จากนั้นจึงนำไปใช้จริงในการคำนวณครั้งต่อไป เพื่อแสดงการดำเนินการเปลี่ยนรูปร่างเหล่านี้ ขั้นตอนการทำงานโดยประมาณมีดังนี้:
- แต่ละ
GraphTensorจะมาพร้อมกับShapeTrackerซึ่งบันทึกข้อมูลdims,strides,offset,maskปัจจุบันที่ส่งผลต่อลำดับการเข้าถึงข้อมูล - ฟังก์ชันส่วนหน้าเช่น
expand_dim,permute,reshape,sliceจะแก้ไขเฉพาะShapeTrackerเท่านั้น และจะไม่แทรกโหนดใหม่ลงในกราฟ HLIR - เมื่อสร้าง op การคำนวณจริง (
Mul,Add,SumReduce)ShapeTrackerปัจจุบันจะถูกอ่านและทำให้คงที่ในลายเซ็นอินพุตของ op นั้น
ดังนั้น ในตัวอย่างข้างต้น HLIR จึงมี Mul ที่มีข้อมูล shape/stride ไม่ใช่ห่วงโซ่การดำเนินการเช่น Expand -> Permute -> Mul
สำหรับการดำเนินการทั่วไปหลายประเภท:
expand_dim: แทรกหนึ่งมิติในdimsและตั้งค่า stride ที่สอดคล้องกันเป็น0ซึ่งแสดงถึงการดำเนินการ broadcastpermute: จัดเรียงdimsและstridesใหม่ ซึ่งหมายถึงการเปลี่ยนลำดับการสังเกตเท่านั้น ไม่เกี่ยวข้องกับการย้ายข้อมูลreshape/slice: อัปเดตข้อมูลมุมมองเช่นdims,offset,maskและไม่สร้าง HLIR op ใหม่เช่นกัน
สาม Partition / Group
หลังจากสร้างกราฟ HLIR เสร็จสมบูรณ์แล้ว การค้นหา egg โดยตรงบนกราฟทั้งหมดมีค่าใช้จ่ายสูงเกินไป โดยเฉพาะอย่างยิ่งสำหรับโมเดลที่มีโครงสร้างซ้ำซ้อนสูง เช่น Transformer และไม่จำเป็น ดังนั้น ขั้นตอนนี้จึงดำเนินการสองงาน:
- Partition: แบ่ง HLIR ทั้งหมดออกเป็นหลาย chunk โดยแต่ละ chunk คือ “กราฟย่อยที่สมบูรณ์ ซึ่งจะถูกค้นหา/คอมไพล์ภายใน” จุดแบ่งถูกระบุอย่างชัดเจนโดยส่วนหน้า (
graph_break) ตำแหน่งทั่วไปรวมถึงขอบเขตของแต่ละเลเยอร์ของ transformer หรือจุดที่อัปเดต KV cache ซึ่งเป็นจุดแบ่งตามธรรมชาติ - Group: จากนั้นรวม chunk ที่มีโครงสร้างเหมือนกันทุกประการเป็น group เดียวกัน แต่ละ group ต้องทำการค้นหา egraph เพียงครั้งเดียว และผลลัพธ์สามารถแชร์กับ chunk สมาชิกทั้งหมดได้
จากกรณีการทำงานของ Gemma 3 4B บน H200 ข้อมูลขนาดที่เกี่ยวข้องมีดังนี้:
| ระดับ | จำนวน | คำอธิบาย |
|---|---|---|
| chunk | 35 | กราฟทั้งหมดถูกแบ่งเป็น 35 ชิ้น แต่ละชิ้นมี HLIR op ประมาณ 140 ตัว |
| group | 5 | หลังจากลบ chunk ที่ซ้ำกันตามโครงสร้างจาก 35 ชิ้นแล้ว เหลือเทมเพลต 5 ประเภท |
โครงสร้างโมเดลที่สอดคล้องกับ 5 group นี้คือ:
- 1 decoder layer group: 34 เลเยอร์ decoder ทั้งหมดใช้เทมเพลตชุดนี้ร่วมกัน ซึ่งเป็นแหล่งที่มาหลักของผลประโยชน์จากการลบข้อมูลซ้ำซ้อน
- 1 embedding group: จัดการส่วน token lookup
- 1 final norm + logits group: จัดการส่วนหัวเอาต์พุตสุดท้ายของโมเดล
- 2 auxiliary groups: สอดคล้องกับจุดเข้า prefill / decode, โมดูล RoPE / mask ฯลฯ ที่ไม่ได้อยู่ในเลเยอร์ decoder หลัก
สี่ Egglog saturation
ขั้นตอนนี้ใช้เทคนิค egraph saturation เพื่อทำการแปลงที่เท่าเทียมกันและการปรับให้เหมาะสมบน HLIR โดยสร้างตัวเลือกการใช้งานที่เทียบเท่าจำนวนมาก กระบวนการค้นหาส่วนใหญ่ทำงานสี่อย่าง:
4.1 การจับคู่ HLIR primop เดี่ยวกับ kernel op
HLIR op แต่ละตัว (Add, Mul, SumReduce, Exp2…) มีกฎ kernel_rewrite<HLIR, Kernel> ที่สอดคล้องกัน ซึ่งใช้ในการขยายเป็น KernelOp ระดับภาษา (เช่น KernelAdd ของ CUDA, op ที่สอดคล้องกันของ Metal ฯลฯ) HLIR op การคำนวณ 17 ตัวแต่ละตัวมีกฎ rewrite ดังกล่าวหนึ่งข้อ (อยู่ใน crates/luminal_cuda_lite/src/kernel/hlir.rs) ขั้นตอนนี้จะแปลง HLIR op บริสุทธิ์เป็น “ตัวเลือกที่สามารถดำเนินการได้จริง”
กฎ rewrite ที่ง่ายที่สุดประเภทนี้ ในโค้ดจริง ๆ แล้วเป็น helper ทั่วไป:
pub fn kernel_rewrite<H: Default + EgglogOp, L: Default + EgglogOp>() -> Rule {
...
rule(union(hlir_op.clone(), llir_op)).fact(eq(dt, dtype(hlir_op)))
}
วิธีการนั้นตรงไปตรงมา: เมื่อเห็น HLIR op จะรวมมันกับ KernelOp ที่สอดคล้องกันใน eclass เดียวกัน ตัวอย่างเช่น Mul สามารถ rewrite เป็น KernelMul, Add สามารถ rewrite เป็น KernelAdd
4.2 HLIR primop หลายตัวที่สอดคล้องกับการเรียกใช้ไลบรารีที่มีประสิทธิภาพ
รูปแบบเช่น Mul + SumReduce (นั่นคือการคูณเมทริกซ์) จะถูกระบุแยกต่างหากและลดระดับลงเป็นตัวแปร sgemm ของ cuBLAS / cuBLASLt ตัวอย่างการตั้งชื่อกฎที่เกี่ยวข้องรวมถึง: cublas sgemm row-major x column-major, cublaslt batched column-major x row-major (อยู่ในไดเรกทอรี crates/luminal_cuda_lite/src/host/cublas/ และ cublaslt/) รูปแบบเดียวกันสามารถจับคู่กับตัวแปรไลบรารีที่แตกต่างกันตาม shape / stride ที่แตกต่างกัน
กฎการเขียนซ้ำรูปแบบระดับสูงประเภทนี้มีลักษณะดังนี้:
(rewrite
(Op (SumReduce ...) (ICons (Op (Mul ...) ...) (INil)))
(Op (CuBlasSGemm ...) ...)
:name "cublas sgemm row-major × column-major")
4.3 การทำให้ Batch และ Shape แบนราบ
ดำเนินการทำให้ง่ายขึ้นตาม Layout กฎที่เกี่ยวข้องอยู่ใน src/egglog_utils/matmul_flattening/*.egg (รวมสามกฎ):
batch_merge_a_contig.egg/batch_merge_b_contig.egg: เมื่อในการดำเนินการ batch × matmul ด้านหนึ่งเป็นเค้าโครง contiguous และอีกด้านหนึ่งเป็น broadcast ให้ทำให้แบนราบเป็นการคูณเมทริกซ์สองมิติsqueeze.egg: ลบมิติที่ไม่ถูกต้อง
4.4 ตัวเลือก In-place และการตรวจสอบนามแฝง
การดำเนินการ Scatter จะถูกเขียนใหม่เป็น ScatterNoCopy(ConsumedBuffer(dest), ...) โดยที่ ConsumedBuffer ไม่ใช่การดำเนินการที่ดำเนินการจริง แต่เป็นตัวระบุที่ใช้ในขั้นตอนการค้นหาเพื่อทำเครื่องหมายความเป็นเจ้าของ เนื่องจากโหนดใน egraph อาจมีความสัมพันธ์แบบพึ่งพาวงกลม และเป็นการยากที่จะนับจำนวนผู้ใช้ ดังนั้นจุดประสงค์ของการแนะนำ ConsumedBuffer คือการรวมการวิเคราะห์การใช้งานลงในพื้นที่การค้นหาอย่างชัดเจน: หากบัฟเฟอร์เป้าหมาย dest ไม่ถูกอ่านโดยการดำเนินการอื่นในภายหลัง ก็สามารถดำเนินการเขียนในที่เดิมได้
ชุดกฎ cleanup / base_cleanup ที่ตามมามีหน้าที่ตรวจสอบสิ่งนี้:
- หาก
destไม่มี ผู้อ่านอื่นหลังจากนั้น ให้คงConsumedBuffer(dest)ไว้ และสุดท้ายอนุญาตให้ใช้ScatterNoCopyซึ่งก็คือการดำเนินการเขียนในที่เดิม - หาก
destยังมี ผู้อ่านอื่นหลังจากนั้น ให้ลบโซลูชันตัวเลือกนั้นออก และถอยกลับไปใช้การดำเนินการScatterปกติ
4.5 Saturation
Luminal ไม่ได้ผสมกฎการเขียนซ้ำทั้งหมดเข้าด้วยกัน แต่แบ่งออกเป็น 4 ชุดกฎ และค้นหาเป็นระยะ เพื่อลดขอบเขตการจับคู่ของการเขียนซ้ำแต่ละรอบ และลดค่าใช้จ่ายในการคอมไพล์:
expr: ชุดกฎการเขียนซ้ำหลัก ครอบคลุมการสร้างตัวเลือก kernel ที่สอดคล้องกับ HLIR, การทำให้ batch matmul แบนราบ, การฉีด ConsumedBuffer และการดำเนินการอื่น ๆ ทั้งหมดdtype_prop: กฎของฟังก์ชันเสริม(function dtype (IR) DType :merge new)ที่เผยแพร่ dtype ไปตามการไหลของข้อมูลcleanup: หากdestถูกอ่านโดยการดำเนินการอื่น ให้ลบConsumedBufferและล้างตัวเลือก ScatterNoCopy แบบเรียงซ้อนbase_cleanup: ชุดกฎอิสระ วางไว้ที่ส่วนท้ายสุด จัดการกับการดำเนินการที่ไม่สามารถย้อนกลับได้เช่น(union ?cb ?dest)โดยเฉพาะ ซึ่งต้องรอให้ชุดกฎก่อนหน้าทั้งหมดอิ่มตัวก่อนจึงจะดำเนินการได้อย่างปลอดภัย ในโค้ดมี TODO ที่ระบุไว้แล้ว โดยยอมรับว่านี่คือจุดอ่อนของระบบ
ลำดับการดำเนินการจริงมีดังนี้:
(repeat 10 (saturate expr) (saturate dtype_prop) (run))
(saturate expr)
(saturate cleanup)
(saturate base_cleanup)
ในการทดลองของฉัน (Gemma 3 4B, H200, 34 เลเยอร์ transformer) 34 เลเยอร์ถูกแบ่งเป็น 35 chunk และรวมเป็น 5 group ที่มีโครงสร้างเทียบเท่ากัน egraph saturation ของแต่ละ group สร้าง enode ประมาณ 5076 ตัวและ eclass 3633 ตัว ใช้เวลาประมาณ 30 นาทีบน CPU แกนเดียว
ห้า Extraction
หลังจาก saturation เสร็จสมบูรณ์ Luminal จะได้รับค่าใช้จ่ายจริงโดยตรงจากเวลาแฝงของการดำเนินการจริง:
- การสุ่มเลือก: สำหรับแต่ละ eclass ให้สุ่มเลือก enode หนึ่งตัว ลดระดับเป็น LLIR คอมไพล์ด้วย NVRTC ดำเนินการจริงและวัดเวลาแฝง (ค่าเริ่มต้นทำซ้ำ 10 ครั้งแล้วหาค่าเฉลี่ย) พร้อมตรวจสอบว่าผลลัพธ์มี NaN หรือไม่ หากการคอมไพล์ล้มเหลวหรือมี NaN ให้เปลี่ยนโซลูชันตัวเลือก สูงสุดลองใหม่ 100 ครั้ง หากทั้งหมดล้มเหลวให้ panic (ดู
src/graph.rs:653) - การกลายพันธุ์: ใช้โซลูชันตัวเลือกที่เร็วที่สุดในปัจจุบันเป็นเมล็ด (ค่าเริ่มต้นเก็บ 1 ชุด) สร้างเวอร์ชันกลายพันธุ์ 30 ตัวต่อรุ่น: ใน eclass ที่มี enode ที่เลือกได้หลายตัว ให้สุ่มเลือกโหนดบางส่วนแทนที่ด้วยตัวเลือกอื่น และใช้การแฮชเพื่อลบข้อมูลซ้ำซ้อนเพื่อหลีกเลี่ยงการวัดซ้ำ
- การประเมินผล: การกลายพันธุ์แต่ละตัวจะผ่านกระบวนการ lower + คอมไพล์ + ดำเนินการ + วัดผลเช่นกัน หากเร็วกว่าโซลูชันเมล็ด ก็จะแทนที่เมล็ด
- งบประมาณ: แต่ละ group จะประเมินโซลูชันตัวเลือกสูงสุด
options.limitตัว (Gemma 3 4B มี 5 group ตั้งค่าGEMMA_SEARCH_GRAPHS=3หมายถึงแต่ละ group ประเมิน 3 ตัวเลือก ทั้งโมเดลรวม 5 × 3 = 15 ครั้ง NVRTC + profile) ค่าเริ่มต้นอย่างเป็นทางการคือ 500 ซึ่งจะทำให้เวลาในการค้นหาเพิ่มขึ้นอย่างมาก
วิธีนี้ใช้การวัดจริงแทนโมเดลต้นทุนที่ยากต่อการคาดการณ์อย่างแม่นยำในการสร้างแบบจำลองการวิเคราะห์แบบดั้งเดิม แต่ใช้ได้เฉพาะในสภาพแวดล้อมฮาร์ดแวร์ที่เสถียรเท่านั้น และไม่สามารถใช้ในขั้นตอนการออกแบบได้
หก LLIR
ในโค้ด LLIR ถูกกำหนดเป็น:
pub type LLIRGraph = StableGraph<LLIROp, ()>;
StableGraph สามารถเข้าใจง่าย ๆ ว่าเป็นคอนเทนเนอร์กราฟที่มีหมายเลขโหนดคงที่ LLIROp คือเนื้อหาของโหนด และขอบแสดงถึงความสัมพันธ์แบบพึ่งพา เนื้อหาที่ dump ออกมาประมาณดังนี้:
LLIROp(DialectOp(KernelMul { out_shape: [4, s, 256], ... }))
LLIROp(DialectOp(KernelSumReduce { out_shape: [s], ... }))
LLIROp(DialectOp(CuBlasLt { m: 1024, n: s, k: 2560, ... }))
โดยที่แต่ละโหนดสอดคล้องโดยตรงกับหน่วยการดำเนินการเฉพาะ เช่น:
- ซอร์สโค้ด CUDA kernel (ต่อมาคอมไพล์แบบเรียลไทม์โดย NVRTC เป็นโค้ดที่ปฏิบัติการได้บน GPU)
- Metal kernel (แบ็กเอนด์ Apple)
- การเรียกใช้ไลบรารีฝั่ง host (เช่น sgemm สำเร็จรูปของ cuBLAS, cuBLASLt)
พร้อมกันนั้นยังมีข้อมูลเมตาที่ Runtime ใช้:
- ขนาดของบัฟเฟอร์เอาต์พุต (นิพจน์สัญลักษณ์ รองรับ shape แบบไดนามิก)
- จำนวนไบต์ที่อ่าน/เขียน, FLOPs การคำนวณ
- ว่าเอาต์พุตใช้บัฟเฟอร์อินพุตบางตัวซ้ำหรือไม่ (การเขียนในที่เดิม) ฯลฯ
ในระดับ LLIR ถือว่าข้อมูลถูกส่งผ่านระหว่างโหนดโดย global memory ในขณะที่ระดับที่ละเอียดกว่าเช่น shared memory, register จะไม่สะท้อนใน LLIR
6.1 LLIR ของ Gemma 3 4B
โอเค ตามคำแนะนำของคุณ ฉันจะทำการเขียนใหม่เชิงลึกและลดความซ้ำซ้อนสำหรับส่วนของบทความที่ให้มา โดยปฏิบัติตามกฎทั้งหมดอย่างเคร่งครัด
หลังจากโมเดล Gemma 3 4B ถูกคอมไพล์ LLIR ที่สร้างขึ้นจะมีประมาณ 7250 โหนด โดยมีการกระจายเฉพาะดังนี้:
KernelMul 2043 KernelGather 205
KernelAdd 810 KernelSin 68
KernelIota 648 KernelScatter 66
KernelCast 438 KernelLessThan 63
KernelConstant 409 KernelExp2 35
KernelRecip 378 KernelExp 35
KernelSumReduce 375 KernelMaxReduce 34
KernelSqrt 205 KernelSigmoid 32
KernelScatterNoCopy 2
จากข้อมูลจะเห็นได้ชัดเจนว่าการดำเนินการแบบ Elementwise ครอบงำอย่างสมบูรณ์ สิ่งที่น่าสังเกตคือโหนด KernelScatterNoCopy เพียงสองโหนดในโค้ดนั้นใช้สำหรับการดำเนินการเขียนในที่เดิมของ KV cache ทั้งหมด นี่คือผลลัพธ์ของกลไก ConsumedBuffer ที่กล่าวถึงก่อนหน้านี้: เฉพาะเมื่อบัฟเฟอร์หนึ่งมีผู้ใช้เพียงรายเดียว ระบบ egglog จะคงการดำเนินการ Scatter ปกติและปรับให้เหมาะสมเป็นเวอร์ชัน ScatterNoCopy
เจ็ด การสร้างโค้ด (Code Generation)
โดยพื้นฐานแล้ว LLIR เป็นเพียงชุดคำอธิบายข้อมูล GPU ไม่สามารถดำเนินการได้โดยตรง กระบวนการจัดการของ Luminal สำหรับสิ่งนี้มีดังนี้:
7.1 เทมเพลตและการกำหนดพารามิเตอร์
การดำเนินการ kernel แต่ละประเภทจะรักษาเทมเพลต C++ kernel ของตัวเอง ในขั้นตอนการสร้างโค้ด ระบบจะเติมพารามิเตอร์เช่น shape, stride, dtype จากโหนดลงในเทมเพลตเหล่านี้ เพื่อสร้างซอร์สโค้ด CUDA ที่เป็นรูปธรรม จากนั้นซอร์สโค้ดนี้จะถูกส่งไปยัง NVRTC เพื่อทำการคอมไพล์ JIT และสุดท้ายสร้าง kernel ที่ GPU สามารถดำเนินการได้จริง ในกระบวนการทั้งหมดไม่มี IR ระดับลูป, schedule pass หรือ tiling เป็นขั้นตอนกลาง รูปแบบของ kernel ถูกกำหนดโดยเทมเพลตทั้งหมด
ตัวอย่างเช่น สำหรับโหนด KernelAdd งานจริงของการสร้างโค้ดคือการแทนที่เทมเพลต และสุดท้ายประกอบเป็นซอร์สโค้ดที่สมบูรณ์ดังนี้:
extern "C" {
__global__ void add_k(float *C, const float *A, const float *B, const int* dyn_dims) {
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (const_z >= /* n_elements */) return;
C[/* out_idx */] = A[/* a_idx */] + B[/* b_idx */];
}
}
เพื่อหลีกเลี่ยงการคอมไพล์ kernel ที่เหมือนกันซ้ำ ๆ Luminal จะใช้ซอร์สโค้ดที่สร้างขึ้นเป็นคีย์สำหรับแคชภายในกระบวนการ เมื่อซอร์สโค้ดที่ประกอบขึ้นสุดท้ายของสองโหนดเหมือนกันทุกประการ ระบบจะนำฟังก์ชันที่คอมไพล์แล้วกลับมาใช้ใหม่โดยตรง
7.2 การเรียกใช้ไลบรารี
แน่นอนว่าไม่ใช่ทุกโหนด LLIR ที่ต้องผ่านเส้นทางการสร้างซอร์สโค้ด ดังที่กล่าวไว้ในส่วนการค้นหาก่อนหน้านี้ ชุดค่าผสม Mul + SumReduce จะถูกเขียนใหม่เป็นการดำเนินการ matmul และสุดท้ายจะสอดคล้องกับการห่อหุ้มจุดเข้าไลบรารี cuBLAS / cuBLASLt สำหรับโหนดประเภทนี้ งานของการสร้างโค้ดเป็นเพียงการเลือกจุดเข้าใช้งานฟังก์ชันไลบรารีที่เหมาะสม และปรับพารามิเตอร์เช่น stride, leading dimension ให้เป็นรูปแบบที่ cuBLAS รองรับ เมื่อดำเนินการ ก็สามารถเรียกใช้ฟังก์ชันฝั่ง host เช่น cublasSgemm ได้โดยตรง
แปด รันไทม์ (Runtime)
เนื่องจากโค้ดถูกคอมไพล์เสร็จสิ้นในขั้นตอนก่อนหน้าแล้ว สิ่งที่ส่งมอบให้กับรันไทม์คือชุดของ kernel และการเรียกใช้ไลบรารีที่เป็นอิสระต่อกัน งานของรันไทม์แบ่งออกเป็นสองขั้นตอนหลัก:
load_llir: ขั้นแรกให้ประกอบ LLIR ของแต่ละ group ตั้งค่าพอยน์เตอร์อินพุต/เอาต์พุต จัดสรรบัฟเฟอร์กลาง และจับภาพเป็น CUDA Graphexecute: ในการอนุมานแต่ละครั้ง ให้เล่น CUDA Graph ที่สอดคล้องกันตามลำดับ chunk และดึงผลลัพธ์เอาต์พุตออกมา
8.1 ขั้นตอนการโหลด
ขั้นตอนการโหลดจะอ่าน LLIR ของแต่ละ group ก่อน จากนั้นดำเนินการดังต่อไปนี้:
1. จัดสรร Buffer:
รันไทม์จะสำรวจแต่ละโหนดใน LLIR ตามนิพจน์ขนาดเอาต์พุตของโหนด รวมกับ dyn_map ปัจจุบัน (เช่น M=1024, N=4096) เพื่อคำนวณจำนวนไบต์ที่ต้องการ จากนั้นจะเปรียบเทียบกับบัฟเฟอร์ที่มีอยู่: หากบัฟเฟอร์ที่มีอยู่มีความจุเพียงพอก็จะนำกลับมาใช้ใหม่โดยตรง มิฉะนั้นจะเรียก cudaMalloc เพื่อจัดสรรบัฟเฟอร์ใหม่
นอกจากนี้ โหนดอาจมีเครื่องหมายระบุว่าเอาต์พุตสามารถใช้ที่อยู่ของบัฟเฟอร์อินพุตบางตัวซ้ำได้ ในกรณีนี้ รันไทม์จะชี้พอยน์เตอร์เอาต์พุตไปยังพอยน์เตอร์อินพุตนั้นโดยตรง โดยไม่ต้องจัดสรรหน่วยความจำเพิ่มเติม KernelScatterNoCopy (ใช้สำหรับการเขียนในที่เดิมของ KV cache) ที่กล่าวถึงก่อนหน้านี้ได้รับการจัดการด้วยวิธีนี้
2. แพ็ค CUDA Graph:
แต่ละ group จะถูกประมวลผลแยกกัน รันไทม์จะจัดเรียง kernel ทั้งหมดภายใน group นั้นตามลำดับ LLIR จากนั้นเรียกใช้ CUDA Graph API เพื่อจับภาพลำดับการ launch ทั้งหมดนี้เป็นกราฟเดียว บนโมเดล Gemma 3 4B มีการสร้าง CUDA Graph ทั้งหมด 5 กราฟ (หนึ่งกราฟต่อ group) แต่ละกราฟภายในประกอบด้วย kernel 12 ถึง 180 ตัว ในการดำเนินการ เพียงแค่เรียก cuGraphLaunch เพียงครั้งเดียวก็สามารถปล่อยลำดับทั้งหมดได้ ซึ่งช่วยลดค่าใช้จ่ายในการ launch ได้อย่างมาก
8.2 ขั้นตอนการดำเนินการ
การดำเนินการในขั้นตอนนี้ง่ายมาก:
- ระบุพอยน์เตอร์ของข้อมูลอินพุต
- ตามลำดับของ chunk ให้เริ่ม CUDA Graph ของ group ที่ chunk นั้นสังกัดอยู่ตามลำดับ
- หากจำเป็น ให้อ่านข้อมูลในบัฟเฟอร์
⚠️ หมายเหตุ: เนื้อหาได้รับการแปลโดย AI และตรวจสอบโดยมนุษย์ หากมีข้อผิดพลาดโปรดแจ้ง
☕ สนับสนุนค่ากาแฟทีมงาน
หากคุณชอบบทความนี้ สามารถสนับสนุนเราได้ผ่าน PromptPay
SCAN TO PAY WITH ANY BANK本文来自网络搜集,不代表คลื่นสร้างอนาคต立场,如有侵权,联系删除。转载请注明出处:https://www.itsolotime.com/th/archives/32385
