Python MoE Training Framework Pith-Train: หนึ่งหมื่นบรรทัดโค้ดที่ผสานการทำงานแบบสี่มิติและการควอนไทซ์ FP8 ทำลายทางเลือกระหว่างประสิทธิภาพระดับโปรดักชั่นกับความสามารถในการอ่านโค้ด

ระบบฝึกอบรมโมเดลขนาดใหญ่มักเปรียบเสมือนโรงงานปิด: สายการผลิต โทโพโลยีการสื่อสาร การกำหนดเส้นทางผู้เชี่ยวชาญ การใช้หน่วยความจำซ้ำ การฝึกแบบผสมความแม่นยำ และการกู้คืนจุดตรวจ ล้วนทำงานด้วยความเร็วสูง แต่นักพัฒนามักมองไม่เห็นว่าเฟืองจักรเหล่านี้ทำงานประสานกันอย่างไร

เฟรมเวิร์กการผลิตมีประสิทธิภาพสูง แต่มักถูกห่อหุ้มด้วยโค้ด C++/CUDA หลายหมื่นบรรทัดและรันไทม์ที่ซับซ้อน ในขณะที่โค้ดน้ำหนักเบาอ่านง่าย แต่มักรับภาระปริมาณงานของการฝึก MoE จริงไม่ได้

Pith-Train พยายาม打破ทางเลือกสองทางนี้: มันใช้ Python ประมาณหนึ่งหมื่นบรรทัด จัดระเบียบ Pipeline, Expert, FSDP, Context แบบขนานสี่มิติ, การจัดตารางเดินหน้า-ถอยหลังซ้อนทับของ DualPipeV, การฝึก FP8 ของ DeepGEMM, และโอเปอเรเตอร์ Triton แบบฟิวชัน ให้เป็นระบบฝึกอบรมที่อ่านได้ตั้งแต่ต้นจนจบ

ความสำคัญของมันไม่ได้เป็นเพียง “เฟรมเวิร์กการฝึกอบรมอีกตัวหนึ่ง” แต่เป็นการแสดงให้เห็นว่าวิศวกรรมการฝึกอบรมประสิทธิภาพสูงสามารถถูกออกแบบใหม่ได้อย่างไรในยุคของตัวช่วยโค้ด AI

สารบัญ

  • หนึ่ง: เริ่มต้นอย่างรวดเร็ว: จากการติดตั้งจนถึงการรัน MoE Pre-training ครั้งแรก
  • สอง: ตำแหน่งของโปรเจกต์: ไม่ใช่เฟรมเวิร์กของเล่น แต่เป็น “ชุดฝึกอบรมระดับการผลิตที่อ่านได้”
  • 2.1 ปัญหาหลักที่ PithTrain ต้องการแก้ไข
  • 2.2 หน้าที่ของไดเรกทอรีหลัก
  • สาม: เส้นทางหลักการฝึก: จาก torchrun สู่การอัปเดตพารามิเตอร์ครั้งเดียว
  • 3.1 สคริปต์เริ่มต้น: ปรับอัตโนมัติสำหรับเครื่องเดียวและ SLURM
  • 3.2 ฟังก์ชัน launch: สามชั้นบริบทห่อหุ้มวงจรชีวิตการฝึก
  • 3.3 สิ่งที่เกิดขึ้นใน train_step ครั้งเดียว
  • สี่: การขนานสี่มิติ: PP, DP, CP, EP ถูกรวมเป็น DeviceMesh ได้อย่างไร
  • 4.1 การขนานสี่แบบแต่ละแบบมีหน้าที่ของตน
  • 4.2 การแบ่งส่วนพารามิเตอร์ FSDP และ MoE
  • ห้า: DualPipeV: “ศูนย์ควบคุมการจราจร” เพื่อเพิ่มปริมาณงานของ PithTrain
  • 5.1 สายไปป์ไลน์สองทิศทางรูปตัว V
  • 5.2 การจัดตารางแปดขั้นตอนในฟังก์ชัน step
  • 5.3 การควบคุม hook ของ FSDP ด้วยตนเอง
  • หก: การซ้อนทับห้าขั้นตอน: แยก Transformer หนึ่งชั้นเป็นชิ้นส่วนไปป์ไลน์ที่จัดตารางได้
  • 6.1 โปรโตคอลโมเดลมาก่อน
  • 6.2 ความหมายที่แท้จริงของห้าขั้นตอน
  • 6.3 ทำไมต้องแยกเป็นห้าส่วน
  • เจ็ด: การฝึก FP8: การหาปริมาณแบบฟิวชันของ DeepGEMM และ Triton
  • 7.1 FP8Linear ทำงานอย่างไร
  • 7.2 เคอร์เนลการหาปริมาณที่รับรู้สถาปัตยกรรม
  • แปด: การกระจายแบบขนานผู้เชี่ยวชาญ: ใช้ Triton ฟิวชันโอเปอเรเตอร์เล็กๆ กว่ายี่สิบตัวเป็นสามเคอร์เนล
  • 8.1 ทำไมการสื่อสาร MoE จึงยาก
  • 8.2 การเรียงลำดับแบบนับ O(n) แทนที่ argsort O(n log n)
  • เก้า: ข้อมูลและ checkpoint: “วิศวกรรมฐานราก” ของระบบฝึกอบรม
  • 9.1 การอ่านข้อมูล mmap และการแบ่งส่วนแบบขนานบริบท
  • 9.2 การสร้างข้อมูล: การแยกคำแบบหลายกระบวนการที่กู้คืนได้
  • 9.3 รูปแบบมาตรฐานของ checkpoint
  • สิบ: ความสำคัญทางเทคนิคและขอบเขตของ PithTrain
  • 10.1 คุณค่าสูงสุด: เปลี่ยนระบบฝึกอบรมประสิทธิภาพสูงให้เป็นวัตถุที่เรียนรู้ได้
  • 10.2 ความแตกต่างจากชุดฝึกอบรมแบบดั้งเดิม
  • 10.3 ขอบเขตปัจจุบัน
  • สิบเอ็ด: แผนผังข้อความหนึ่งภาพ: PithTrain ทำการฝึกอบรมครั้งเดียวได้อย่างไร
  • 11.1 เส้นทางการดำเนินการแบบ end-to-end
  • บทสรุป: แรงบันดาลใจที่ PithTrain มอบให้กับการออกแบบระบบฝึกอบรม

หนึ่ง: เริ่มต้นอย่างรวดเร็ว: จากการติดตั้งจนถึงการรัน MoE Pre-training ครั้งแรก

README ของ PithTrain ระบุข้อกำหนดด้านสภาพแวดล้อมอย่างตรงไปตรงมา: ต้องใช้ GPU NVIDIA Hopper (SM90) หรือ Blackwell (SM100), CUDA 13.0, Python >= 3.12 และใช้ uv จัดการ dependencies เส้นทางการติดตั้งขั้นต่ำมีดังนี้

git clone https://github.com/mlc-ai/Pith-Train.git && cd Pith-Train
uv venv
uv pip install .

# หากเป็นนักพัฒนา ต้องติดตั้ง dependencies สำหรับการพัฒนาและสภาพแวดล้อมซอร์สโค้ด:
uv sync

การ pre-training Qwen3-30B-A3B ทั่วไปตั้งแต่เริ่มต้นแบ่งเป็นสามขั้นตอน

  • ขั้นตอนแรก: ดาวน์โหลดและแยกคำคลังข้อมูล pre-training DCLM: bash examples/build_tokenized_corpus/launch.sh dclm-qwen3
  • ขั้นตอนที่สอง: แก้ไขสคริปต์การฝึก ปรับขนาดการขนาน batch size อัตราการเรียนรู้ ฯลฯ ไฟล์คอนฟิกอยู่ที่: examples/pretrain_language_model/qwen3-30b-a3b/script.py
  • ขั้นตอนที่สาม: เริ่มการฝึก: bash examples/pretrain_language_model/launch.sh qwen3-30b-a3b

สคริปต์การฝึกจะตรวจจับจำนวน GPU โดยอัตโนมัติ และรองรับทั้งการรันบนเครื่องเดียวและสภาพแวดล้อมคลัสเตอร์หลายเครื่องแบบ SLURM จุดตรวจสอบโมเดลจะถูกบันทึกไว้ในไดเรกทอรี workspace โดยค่าเริ่มต้น และมีความสามารถในการกู้คืนการฝึกจาก checkpoint ล่าสุดโดยอัตโนมัติ เมื่อการฝึกสิ้นสุดลง คุณสามารถแปลงน้ำหนักโมเดลในรูปแบบ PyTorch Distributed Checkpoint เป็นรูปแบบที่เข้ากันได้กับ Hugging Face:

bash examples/convert_checkpoint/launch.sh qwen3-30b-a3b

สำหรับรายละเอียดเพิ่มเติมเกี่ยวกับโมเดลและการแปลง โปรดดู examples/build_tokenized_corpus/README.md, examples/pretrain_language_model/ และ examples/convert_checkpoint/README.md

unsetunsetสอง: ตำแหน่งของโปรเจกต์: ไม่ใช่เฟรมเวิร์กของเล่น แต่เป็น “ชุดฝึกอบรมระดับการผลิตที่อ่านได้”unsetunset

2.1 ปัญหาหลักที่ PithTrain ต้องการแก้ไข

README ของ PithTrain ระบุตำแหน่งของตัวเองอย่างเฉียบคม: Efficient, Python-native MoE training in ~10K lines of code. มันไม่ได้มุ่งเน้นที่ “วิธีการเขียนตัวอย่างการฝึก Transformer แบบง่าย” แต่ชี้ตรงไปยังเป้าหมายสามประการที่ยากจะบรรลุพร้อมกันในการฝึกโมเดลขนาดใหญ่แบบ MoE:

  1. ประสิทธิภาพใกล้เคียงระบบระดับการผลิต: รองรับการขนาน 4 มิติ, การซ้อนทับการคำนวณและการสื่อสาร, การฝึก FP8, DeepGEMM, FlashAttention, และโอเปอเรเตอร์ Triton/TileLang
  2. ตรรกะการทำงานโปร่งใสเพียงพอ: โค้ดหลักเขียนด้วย Python ขนาดคลังทั้งหมดประมาณหนึ่งหมื่นบรรทัด ทั้งนักพัฒนามนุษย์และ AI agent สามารถเข้าใจกลไกการทำงานแบบ end-to-end
  3. กระบวนการทางวิศวกรรมครบวงจร: ครอบคลุมการสร้างข้อมูล, วงรอบการฝึก, โทโพโลยีแบบกระจาย, การใช้งานโมเดล, การแปลง checkpoint, การบันทึก日志, การทดสอบ และการวัดประสิทธิภาพ

กล่าวโดยย่อ PithTrain ไม่ได้พยายามซ่อนความซับซ้อน แต่จัดระเบียบความซับซ้อนให้เป็นชั้นลำดับชั้นที่อ่านง่าย

คำอธิบายสถาปัตยกรรมใน README แบ่งระบบทั้งหมดออกเป็นสามชั้น:

  • Upstream: วงรอบการฝึกสำหรับงานต่างๆ เช่น pre-training, SFT
  • Core: ประกอบด้วยโมเดล, บล็อกการสร้าง, ไปป์ไลน์ DualPipeV, การฝึกแบบกระจาย และโครงสร้างพื้นฐานการฝึก
  • Operators: ครอบคลุม PyTorch/NCCL, DeepGEMM, FlashAttention, และโอเปอเรเตอร์ Python DSL เช่น Triton, TileLang

2.2 ไดเรกทอรีหลักและหน้าที่

จากโครงสร้างคลัง โค้ดหลักของ PithTrain กระจุกตัวอยู่ในไดเรกทอรี pithtrain/:

  • pithtrain/tasks/pretrain_language_model.py: จุดเริ่มต้นสำหรับ pre-training โมเดลภาษา รับผิดชอบการจัดระเบียบบริบท การโหลด checkpoint และการดำเนินวงรอบการฝึก
  • pithtrain/modules/training.py: รับผิดชอบการเริ่มต้นคอนฟิกการฝึก, ชุดข้อมูล, โมเดล, FSDP, ตัวปรับให้เหมาะสม และตัวจัดตารางอัตราการเรียนรู้
  • pithtrain/modules/distributed.py: ใช้สร้าง DeviceMesh แบบ 4 มิติ PP/DP/CP/EP
  • pithtrain/modules/dataset.py: การอ่านข้อมูล token แบบแพ็ครวมโดยใช้ mmap และการสับเปลี่ยนแบบ global
  • pithtrain/dualpipe/dualpipev.py: ตัวจัดตารางไปป์ไลน์ DualPipeV ซึ่งเป็นศูนย์กลางหลักในการเพิ่มประสิทธิภาพปริมาณงานของระบบ
  • pithtrain/dualpipe/overlap.py: แยกชั้น Transformer ออกเป็นห้าขั้นตอน เพื่อให้เกิดการซ้อนทับแบบละเอียดของการคำนวณเดินหน้าและถอยหลัง
  • pithtrain/models/qwen3_30b_a3b.py, deepseek_v2_lite.py, gpt_oss.py: การใช้งานโครงสร้างโมเดลเฉพาะ
  • pithtrain/models/interface.py: อินเทอร์เฟซโปรโตคอลที่ชั้นโมเดลต้องปฏิบัติตามสำหรับ DualPipeV
  • pithtrain/layers/deepgemm_fp8_linear.py: ชั้นเชิงเส้น FP8 และ MoE GroupLinear ที่ใช้ DeepGEMM
  • pithtrain/operators/ep_dispatch.py: การใช้งาน Triton แบบฟิวชันสำหรับการกระจายแบบขนานผู้เชี่ยวชาญ
  • pithtrain/operators/deepgemm_fp8_quantize.py: เคอร์เนล Triton สำหรับการหาปริมาณ FP8
  • pithtrain/modules/checkpoint.py: เครื่องมือแปลง checkpoint แบบมาตรฐานที่ไม่ขึ้นกับ PP
  • examples/: ประกอบด้วยสคริปต์ที่รันได้สำหรับการเตรียมข้อมูล, pre-training, การแปลง checkpoint ฯลฯ

การจัดระเบียบนี้ชาญฉลาดมาก: ชั้นงานจะไม่เขียนโอเปอเรเตอร์ที่ซับซ้อนโดยตรง และชั้นโอเปอเรเตอร์ก็ไม่รับรู้ถึงวงรอบการฝึก ระบบฝึกอบรมที่ซับซ้อนทั้งหมดถูกตัดเป็นโมดูลที่สามารถเปลี่ยนแทนกันได้ แต่ละโมดูลมีขอบเขตที่ชัดเจน

unsetunsetสาม: เส้นทางหลักการฝึก: จาก torchrun สู่การอัปเดตพารามิเตอร์ครั้งเดียวunsetunset

3.1 สคริปต์เริ่มต้น: ปรับอัตโนมัติสำหรับเครื่องเดียวและสภาพแวดล้อม SLURM

จุดเริ่มต้นของ pre-training เริ่มจากสคริปต์ shell สคริปต์นี้จะสร้างพารามิเตอร์ที่จำเป็นสำหรับ torchrun โดยอัตโนมัติตามตัวแปรสภาพแวดล้อม SLURM หรือจำนวน GPU ของเครื่อง

ที่มา: examples/pretrain_language_model/launch.sh

SLURM_NNODES=${SLURM_NNODES:-1}
SLURM_NODEID=${SLURM_NODEID:-0}
SLURM_STEP_GPUS=${SLURM_STEP_GPUS:-${CUDA_VISIBLE_DEVICES:-$(nvidia-smi –query-gpu=index –format=csv,noheader | paste -sd,)}}

LAUNCH_ARGS+=(–nnodes=$SLURM_NNODES –node-rank=$SLURM_NODEID)
LAUNCH_ARGS+=(–nproc-per-node=$(echo “$SLURM_STEP_GPUS” | tr ‘,’ ‘n’ | wc -l))
LAUNCH_ARGS+=(–rdzv-backend=c10d)

SCRIPT=examples/pretrain_language_model/$1/script.py
torchrun ${LAUNCH_ARGS[@]} $SCRIPT

สคริปต์นี้ดูเรียบง่ายแต่สำคัญมาก: PithTrain ปล่อยให้ “ความซับซ้อนของการเริ่มต้นหลายเครื่อง” อยู่ในสคริปต์ shell ภายนอก ในขณะที่โค้ด Python ภายในสมมติว่ามันทำงานภายใต้สภาพแวดล้อม torchrun มาตรฐาน

3.2 ฟังก์ชัน launch: สามชั้นบริบทห่อหุ้มวงจรชีวิตการฝึก

จุดเริ่มต้นการฝึก Python ที่แท้จริงอยู่ที่ pithtrain/tasks/pretrain_language_model.py ไฟล์นี้ใช้กลไก ExitStack เพื่อค่อยๆ สร้างบริบทการบันทึก日志, สภาพแวดล้อมแบบกระจาย และบริบทการฝึก จากนั้นโหลด checkpoint และเข้าสู่วงรอบการฝึกหลัก

# ที่มา: pithtrain/tasks/pretrain_language_model.py  
@shutdown.record  
def launch(cfg: PretrainLanguageModelCfg) -> None:  
"""เริ่มต้นกระบวนการ pre-training โมเดลภาษา"""  
with ExitStack() as stack:  
ctx = PretrainLanguageModelCtx()  
stack.enter_context(logging_context(cfg, ctx))  
stack.enter_context(distributed_context(cfg, ctx))  
stack.enter_context(training_context(cfg, ctx))  
logger = ctx.logging.stdout  
logger.info("launch(cfg=%s)" % cfg)  
load_checkpoint(cfg, ctx)  
raise_if_dataset_insufficient(cfg, ctx)  
while ctx.training.step < cfg.training.max_steps:  
train_step(cfg, ctx)  

การออกแบบนี้สะท้อนปรัชญาการเขียนโปรแกรมแบบ Pythonic อย่างเต็มที่: การบันทึก日志, การเริ่มต้นแบบกระจาย, การสร้างโมเดล, การสร้างชุดข้อมูล และการสร้างตัวปรับให้เหมาะสม ล้วนถูกห่อหุ้มอยู่ในตัวจัดการบริบท ข้อดีคือวงจรชีวิตของอ็อบเจกต์ชัดเจนมาก และการจัดการเส้นทางข้อยกเว้นก็เป็นหนึ่งเดียวและกระชับยิ่งขึ้น

3.3 ขั้นตอนการทำงานของ train_step ครั้งเดียว

ฟังก์ชัน train_step เป็นภาพจำลองพฤติกรรมของระบบ มันเชื่อมโยงขั้นตอนการฝึกหนึ่งครั้งอย่างสมบูรณ์: รับ batch, ดำเนินการ DualPipeV, ปรับขนาดและตัดเกรเดียนต์, อัปเดต optimizer, ปรับอัตราการเรียนรู้, บันทึก日志 และบันทึก checkpoint

# ที่มา: pithtrain/tasks/pretrain_language_model.py  
def train_step(cfg: PretrainLanguageModelCfg, ctx: PretrainLanguageModelCtx) -> None:  
model = ctx.training.model  
optimizer = ctx.training.optimizer  
scheduler = ctx.training.scheduler  
model.train()  

accumulate_steps = global_batch_size // (micro_batch_size * dp_size * ep_size)  
global_tokens, global_labels = get_global_batch(cfg, ctx, device)  

loss, _ = model.step(  
global_tokens,  
num_chunks=accumulate_steps,  
criterion=criterion,  
labels=(global_labels,),  
return_outputs=False,  
)  

if accumulate_steps > 1:  
scale = 1.0 / accumulate_steps  
for p in model.parameters():  
if p.grad is not None:  
p.grad.mul_(scale)  

gradient_norm = clip_grad_norm_(model, max_norm=1.0, norm_type=2)  
optimizer.step()  
scheduler.step()  
optimizer.zero_grad(set_to_none=True)  

สิ่งที่ต้องสังเกตเป็นพิเศษคือ model ที่นี่ไม่ใช่ Transformer ทั่วไป แต่เป็นโมเดลไปป์ไลน์ที่ถูกห่อหุ้มด้วย DualPipeV ซึ่งหมายความว่าวงรอบการฝึกจะไม่เรียก model(input) โดยตรงอีกต่อไป แต่จะ แบ่ง batch ทั้งหมดออกเป็น micro-batch chunks หลายชิ้น และมอบหมายให้ตัวจัดตารางจัดการการสลับการทำงานของการส่งผ่านเดินหน้า, การส่งผ่านถอยหลัง, การสื่อสาร และการคำนวณเกรเดียนต์น้ำหนัก

การขนานสี่มิติ: PP, DP, CP, EP ถูกรวมเป็น DeviceMesh ได้อย่างไร

4.1 การแบ่งหน้าที่ของการขนานสี่แบบ

มิติการขนานที่ PithTrain รองรับ ได้แก่:

  • PP, Pipeline Parallelism (การขนานไปป์ไลน์): แบ่งชั้นต่างๆ ของโมเดลไปยังเฟส GPU ที่แตกต่างกัน
  • DP, Data Parallelism (การขนานข้อมูล): คัดลอกหรือแบ่งส่วนโมเดลเพื่อประมวลผลตัวอย่างข้อมูลที่แตกต่างกัน
  • CP, Context Parallelism (การขนานบริบท): แบ่งตามมิติความยาวลำดับ และแลกเปลี่ยน KV cache ผ่านกลไก ring attention
  • EP, Expert Parallelism (การขนานผู้เชี่ยวชาญ): กระจายผู้เชี่ยวชาญต่างๆ ในชั้น MoE ไปยัง GPU ที่แตกต่างกัน

กลยุทธ์การขนานเหล่านี้ไม่ได้เป็นแค่การนำมาประกอบกันอย่างง่าย แต่แสดงออกอย่างเป็นหนึ่งเดียวผ่าน PyTorch DeviceMesh

# ที่มา: pithtrain/modules/distributed.py  
def setup_device_mesh(cfg: DistributedCfg, ctx: DistributedCtx) -> None:  
ctx.ep_size = cfg.expert_parallel_size  
ctx.pp_size = cfg.pipeline_parallel_size  
ctx.cp_size = cfg.context_parallel_size  
ctx.dp_size = ctx.world_size // (ctx.ep_size * ctx.pp_size * ctx.cp_size)  

kwargs = dict()  
kwargs["device_type"] = "cuda"  
kwargs["mesh_shape"] = (ctx.pp_size, ctx.dp_size, ctx.cp_size, ctx.ep_size)  
kwargs["mesh_dim_names"] = ("pp", "dp", "cp", "ep")  
ctx.device_mesh = torch.distributed.init_device_mesh(**kwargs)  

ctx.dp_rank = ctx.device_mesh.get_local_rank("dp")  
ctx.pp_rank = ctx.device_mesh.get_local_rank("pp")  
ctx.cp_rank = ctx.device_mesh.get_local_rank("cp")  
ctx.ep_rank = ctx.device_mesh.get_local_rank("ep")  

มีรายละเอียดทางวิศวกรรมที่สำคัญที่นี่: รูปร่างของ Mesh ถูกกำหนดเป็น (PP, DP, CP, EP) 注释明确指出,将 EP 和 CP 放置在内层,目的是让高频通信尽可能在 NVLink 域内完成 ซึ่งหมายความว่าการจัดวางโทโพโลยีไม่ใช่แนวคิดทางคณิตศาสตร์นามธรรมล้วนๆ แต่ให้บริการโดยตรงเพื่อเพิ่มประสิทธิภาพการสื่อสาร

4.2 การแบ่งส่วนพารามิเตอร์ FSDP และ MoE

ใน training.py PithTrain ใช้ FSDP Mesh ที่แตกต่างกันสำหรับพารามิเตอร์ผู้เชี่ยวชาญ MoE และพารามิเตอร์ทั่วไป:

# ที่มา: pithtrain/modules/training.py  
def apply_fsdp(model, mesh: torch.distributed.DeviceMesh):  
moe_fsdp_mesh = mesh["dp", "cp"]._flatten()  
other_fsdp_mesh = mesh["dp", "cp", "ep"]._flatten()  

for i in range(2):  
for layer in model[i].layers.values():  
if hasattr(layer.mlp, "experts"):  
fully_shard(  
layer.mlp.experts,  
mesh=moe_fsdp_mesh,  
reshard_after_forward=False,  
mp_policy=mp,  
)  
fully_shard(layer, mesh=other_fsdp_mesh, reshard_after_forward=False, mp_policy=mp)  

โค้ดนี้เผยให้เห็นความเข้าใจอย่างลึกซึ้งของ PithTrain เกี่ยวกับ MoE: พารามิเตอร์ผู้เชี่ยวชาญถูกกระจายผ่านมิติ EP อยู่แล้ว ดังนั้นผู้เชี่ยวชาญจึงต้องถูกแบ่งส่วนตามมิติ DP/CP เท่านั้น ในขณะที่พารามิเตอร์อื่นๆ ที่ไม่ใช่ผู้เชี่ยวชาญสามารถถูกแบ่งส่วนร่วมกันในสามมิติ DP/CP/EP การออกแบบนี้หลีกเลี่ยงการแบ่งส่วนซ้ำซ้อนในมิติเดียวกัน และยังรับประกันว่าหน้าที่ของการขนานผู้เชี่ยวชาญและ FSDP ชัดเจนและไม่ขัดแย้งกัน

ห้า: DualPipeV: “ศูนย์ควบคุมการจราจร” เพื่อเพิ่มปริมาณงานของ PithTrain

5.1 สายไปป์ไลน์สองทิศทางรูปตัว V

pithtrain/dualpipe/dualpipev.py ระบุอย่างชัดเจนว่ามันอิงตามแนวคิดการออกแบบ DeepSeek DualPipe และขยายคุณสมบัติต่างๆ เช่น การซ้อนทับห้าขั้นตอน, การรวม FSDP2, แคชน้ำหนัก FP8 และการจัดสรรเทนเซอร์กลางล่วงหน้า

คุณค่าหลักของ DualPipeV สามารถเข้าใจได้ด้วยการเปรียบเทียบที่ชัดเจน:

  • ไปป์ไลน์แบบดั้งเดิมเปรียบเสมือนถนนเดินรถทางเดียว รถยนต์ที่คำนวณเดินหน้าต้องผ่านไปทั้งหมดก่อน รถยนต์ที่คำนวณถอยหลังจึงจะเริ่มวิ่งได้ ทำให้เกิด “ฟองอากาศ” ที่ว่างเปล่าจำนวนมากในระหว่างนั้น
  • ในขณะที่ DualPipeV เปรียบเสมือนสะพานลอยสองชั้น มันจัดเรียงการคำนวณเดินหน้า, การคำนวณถอยหลัง, การสื่อสาร และการคำนวณเกรเดียนต์น้ำหนักอย่างชาญฉลาดในช่วงเวลาที่แตกต่างกัน เพื่อให้แน่ใจว่า GPU ทำงานอยู่ตลอดเวลา หลีกเลี่ยงการรอคอยที่ว่างเปล่า

5.2 การจัดตารางแปดขั้นตอนในฟังก์ชัน step

DualPipeV.step() เป็นหัวใจของตัวจัดตารางทั้งหมด มันต้องการ num_chunks >= pp_size * 2 จากนั้นดำเนินการจัดตารางที่มีแปดขั้นตอน: การเดินหน้า预热, การเดินหน้าสองเฟส, การสลับ B/W/F, วงรอบหลัก, การถอยหลังปิดท้าย และการคำนวณเกรเดียนต์น้ำหนักแบบไร้ฟองอากาศ ฯลฯ

# ที่มา: pithtrain/dualpipe/dualpipev.py  
def step(  
self,  
    *inputs: Optional[torch.Tensor],  
num_chunks: int = 0,  
criterion: Optional[Callable] = None,  
labels: List[Optional[torch.Tensor]] = [],  
return_outputs: bool = False,  
):  
assert num_chunks > 0 and num_chunks >= pp_size * 2  

self._reset_states()  
if FP8WeightCacheControl.enabled:  
FP8WeightCacheControl.step()  
self._ensure_intermediate_tensors_allocated(num_chunks)  

if self.is_first_pp_rank:  
self.input_chunks = (scatter(inputs, num_chunks, self.batch_dim), [])  
self.labels = scatter(labels, num_chunks, self.batch_dim)  
self.criterion = criterion  

# Step 1: nF0  
# Step 2: nF0F1  
# Step 3: nB1W1F1  
# Step 4: nF0B1F1B0  
# Step 5: nB1F1B0  
# Step 6: nB1B0  
# Step 7: nWB0  
# Step 8: nW

### 5.3 การควบคุม hook ของ FSDP ด้วยตนเอง

ในซอร์สโค้ดที่สมบูรณ์ 注释เหล่านี้สอดคล้องกับการเรียก `_forward_chunk`, `_backward_chunk`, `_forward_backward_chunk` และ `_weight_chunk` จำนวนมาก นวัตกรรมที่แท้จริงไม่ใช่ "การเขียนลูป" แต่เป็นการจัดเฟสการดำเนินการที่แตกต่างกันของแต่ละ micro-batch ให้เป็นตารางเวลา เพื่อให้เกิดการซ้อนทับการคำนวณสูงสุด

DualPipeV ยังแก้ปัญหาทางวิศวกรรมระดับต่ำมาก: การส่งผ่านถอยหลังแบบกำหนดเองถูกกระตุ้นซ้ำๆ ในไปป์ไลน์ หากปล่อยให้ hook เริ่มต้นของ FSDP เข้ามาแทรกแซงทุกครั้ง จะเกิดค่าใช้จ่าย CPU ที่ไม่จำเป็น ดังนั้น โค้ดจึงระงับการเรียกกลับหลังของโหนดราก FSDP ก่อน และเรียกด้วยตนเองหลังจากลูปทั้งหมดสิ้นสุดลง

```python
# ที่มา: pithtrain/dualpipe/dualpipev.py  
for module in self.module:  
if isinstance(module, FSDPModule):  
module.set_is_last_backward(False)  
module.set_reshard_after_backward(False)  
module.set_requires_gradient_sync(False)  
if not self.forward_only:  
fully_shard.state(module)._state_ctx.post_backward_final_callback_queued = True  

โค้ดนี้แสดงให้เห็นว่า PithTrain ไม่ใช่แค่เฟรมเวิร์กระดับบนที่ “เรียก PyTorch API” แต่เข้าใจความขัดแย้งที่อาจเกิดขึ้นระหว่างจังหวะเวลาการดำเนินการของ FSDP, กลไกการซิงค์เกรเดียนต์ และการจัดตารางไปป์ไลน์อย่างลึกซึ้ง

หก: การซ้อนทับห้าขั้นตอน: แยก Transformer หนึ่งชั้นเป็นชิ้นส่วนไปป์ไลน์ที่จัดตารางได้

6.1 กำหนดโปรโตคอลโมเดลก่อน

เพื่อให้โมเดลประเภทต่างๆ สามารถเชื่อมต่อกับ DualPipeV ได้ PithTrain จึงกำหนดชุดอินเทอร์เฟซโปรโตคอลไว้ล่วงหน้า แต่ละชั้น decoder ต้องมีสามเมธอดหลัก: forward_attn, forward_mlp และ forward_aggregate

# ที่มา: pithtrain/models/interface.py  
class DecoderLayerProtocol(Protocol):  
def forward_attn(self, hidden_states: torch.Tensor) -> ForwardAttnOutput:  
"""LN + Attn + LN + Expert selection."""  

def forward_mlp(  
self,  
gathered_tokens: torch.Tensor,  
expert_idxs: Optional[torch.Tensor] = None,  
expand_idx: Optional[torch.Tensor] = None,  
) -> torch.Tensor:  
"""MLP forward."""  

def forward_aggregate(  
self,  
moe_outs: torch.Tensor,  
moe_local_idxs: Optional[torch.Tensor],  
topk_weight: Optional[torch.Tensor],  
residual: torch.Tensor,  
) -> torch.Tensor:  
"""Weighted expert output + residual connection."""  

ด้วยวิธีนี้ แม้ว่าโมเดลต่างๆ เช่น Qwen, DeepSeek, GPT-OSS จะมีโครงสร้างที่แตกต่างกัน แต่ตราบใดที่ปฏิบัติตามโปรโตคอลนี้ ก็สามารถใช้ตัวจัดตารางไปป์ไลน์ชุดเดียวกันได้

6.2 ความหมายที่แท้จริงของห้าขั้นตอน

ในตอนต้นของ dualpipev.py มีการระบุความสัมพันธ์ของการแมปห้าขั้นตอนโดยตรง:

  1. Attention: LN + Attention + LN + Expert selection
  2. Dispatch: การกระจาย all-to-all แบบขนานผู้เชี่ยวชาญ
  3. MLP: การคำนวณ MLP ของผู้เชี่ยวชาญหรือทั่วไป
  4. Combine: การรวม all-to-all แบบขนานผู้เชี่ยวชาญ
  5. Aggregate: ผลลัพธ์ผู้เชี่ยวชาญแบบถ่วงน้ำหนักและการเชื่อมต่อ residual

ในขณะที่ overlap.py รับผิดชอบการสลับการทำงานของโมดูลเดินหน้า 0 และโมดูลถอยหลัง 1

# ที่มา: pithtrain/dualpipe/overlap.py  
def overlapped_forward_backward(  
module0: ModelProtocol,  
inputs0: List[torch.Tensor],  
criterion0: Optional[Callable],  
labels0: Optional[List[torch.Tensor]],  
intermediate_tensors0: IntermediateTensors,  
module1: ModelProtocol,  
loss1: Optional[torch.Tensor],  
outputs1: Optional[List[torch.Tensor]],  
output_grads1: Optional[List[torch.Tensor]],  
intermediate_tensors1: IntermediateTensors,  
comm_stream: Optional[torch.cuda.Stream],  
ep_group: Optional[torch.distributed.ProcessGroup] = None,  
):  
# Interleaves forward stage1/2/3/4/5 with backward stage5/4/3/2/1  

ในลูปหลัก โค้ดจะสลับการทำงานของฟังก์ชันต่างๆ เช่น stage5_b, stage4_b, stage1_f, stage2_f, stage3_b, stage3_w, stage3_f สามารถเข้าใจการออกแบบนี้ได้ดังนี้:

  • เมื่อ token หนึ่งชุดกำลังรอการสื่อสารผู้เชี่ยวชาญ การคำนวณ attention หรือการส่งผ่านถอยหลังของ token อีกชุดสามารถเริ่มต้นได้ก่อน
  • เมื่อบางเฟสต้องรอสตรีมการสื่อสาร สตรีมการคำนวณสามารถดำเนินการงานของเฟสอื่นต่อไปได้

6.3 ทำไมต้องแยกเป็นห้าส่วน


⚠️ หมายเหตุ: เนื้อหาได้รับการแปลโดย AI และตรวจสอบโดยมนุษย์ หากมีข้อผิดพลาดโปรดแจ้ง

☕ สนับสนุนค่ากาแฟทีมงาน

หากคุณชอบบทความนี้ สามารถสนับสนุนเราได้ผ่าน PromptPay

PromptPay QR
SCAN TO PAY WITH ANY BANK

本文来自网络搜集,不代表คลื่นสร้างอนาคต立场,如有侵权,联系删除。转载请注明出处:http://www.itsolotime.com/th/archives/32986

Like (0)
Previous 1 hour ago
Next 59 mins ago

相关推荐