ระบบฝึกอบรมโมเดลขนาดใหญ่มักเปรียบเสมือนโรงงานปิด: สายการผลิต โทโพโลยีการสื่อสาร การกำหนดเส้นทางผู้เชี่ยวชาญ การใช้หน่วยความจำซ้ำ การฝึกแบบผสมความแม่นยำ และการกู้คืนจุดตรวจ ล้วนทำงานด้วยความเร็วสูง แต่นักพัฒนามักมองไม่เห็นว่าเฟืองจักรเหล่านี้ทำงานประสานกันอย่างไร
เฟรมเวิร์กการผลิตมีประสิทธิภาพสูง แต่มักถูกห่อหุ้มด้วยโค้ด C++/CUDA หลายหมื่นบรรทัดและรันไทม์ที่ซับซ้อน ในขณะที่โค้ดน้ำหนักเบาอ่านง่าย แต่มักรับภาระปริมาณงานของการฝึก MoE จริงไม่ได้
Pith-Train พยายาม打破ทางเลือกสองทางนี้: มันใช้ Python ประมาณหนึ่งหมื่นบรรทัด จัดระเบียบ Pipeline, Expert, FSDP, Context แบบขนานสี่มิติ, การจัดตารางเดินหน้า-ถอยหลังซ้อนทับของ DualPipeV, การฝึก FP8 ของ DeepGEMM, และโอเปอเรเตอร์ Triton แบบฟิวชัน ให้เป็นระบบฝึกอบรมที่อ่านได้ตั้งแต่ต้นจนจบ
ความสำคัญของมันไม่ได้เป็นเพียง “เฟรมเวิร์กการฝึกอบรมอีกตัวหนึ่ง” แต่เป็นการแสดงให้เห็นว่าวิศวกรรมการฝึกอบรมประสิทธิภาพสูงสามารถถูกออกแบบใหม่ได้อย่างไรในยุคของตัวช่วยโค้ด AI
สารบัญ
- หนึ่ง: เริ่มต้นอย่างรวดเร็ว: จากการติดตั้งจนถึงการรัน MoE Pre-training ครั้งแรก
- สอง: ตำแหน่งของโปรเจกต์: ไม่ใช่เฟรมเวิร์กของเล่น แต่เป็น “ชุดฝึกอบรมระดับการผลิตที่อ่านได้”
- 2.1 ปัญหาหลักที่ PithTrain ต้องการแก้ไข
- 2.2 หน้าที่ของไดเรกทอรีหลัก
- สาม: เส้นทางหลักการฝึก: จาก torchrun สู่การอัปเดตพารามิเตอร์ครั้งเดียว
- 3.1 สคริปต์เริ่มต้น: ปรับอัตโนมัติสำหรับเครื่องเดียวและ SLURM
- 3.2 ฟังก์ชัน launch: สามชั้นบริบทห่อหุ้มวงจรชีวิตการฝึก
- 3.3 สิ่งที่เกิดขึ้นใน train_step ครั้งเดียว
- สี่: การขนานสี่มิติ: PP, DP, CP, EP ถูกรวมเป็น DeviceMesh ได้อย่างไร
- 4.1 การขนานสี่แบบแต่ละแบบมีหน้าที่ของตน
- 4.2 การแบ่งส่วนพารามิเตอร์ FSDP และ MoE
- ห้า: DualPipeV: “ศูนย์ควบคุมการจราจร” เพื่อเพิ่มปริมาณงานของ PithTrain
- 5.1 สายไปป์ไลน์สองทิศทางรูปตัว V
- 5.2 การจัดตารางแปดขั้นตอนในฟังก์ชัน step
- 5.3 การควบคุม hook ของ FSDP ด้วยตนเอง
- หก: การซ้อนทับห้าขั้นตอน: แยก Transformer หนึ่งชั้นเป็นชิ้นส่วนไปป์ไลน์ที่จัดตารางได้
- 6.1 โปรโตคอลโมเดลมาก่อน
- 6.2 ความหมายที่แท้จริงของห้าขั้นตอน
- 6.3 ทำไมต้องแยกเป็นห้าส่วน
- เจ็ด: การฝึก FP8: การหาปริมาณแบบฟิวชันของ DeepGEMM และ Triton
- 7.1 FP8Linear ทำงานอย่างไร
- 7.2 เคอร์เนลการหาปริมาณที่รับรู้สถาปัตยกรรม
- แปด: การกระจายแบบขนานผู้เชี่ยวชาญ: ใช้ Triton ฟิวชันโอเปอเรเตอร์เล็กๆ กว่ายี่สิบตัวเป็นสามเคอร์เนล
- 8.1 ทำไมการสื่อสาร MoE จึงยาก
- 8.2 การเรียงลำดับแบบนับ O(n) แทนที่ argsort O(n log n)
- เก้า: ข้อมูลและ checkpoint: “วิศวกรรมฐานราก” ของระบบฝึกอบรม
- 9.1 การอ่านข้อมูล mmap และการแบ่งส่วนแบบขนานบริบท
- 9.2 การสร้างข้อมูล: การแยกคำแบบหลายกระบวนการที่กู้คืนได้
- 9.3 รูปแบบมาตรฐานของ checkpoint
- สิบ: ความสำคัญทางเทคนิคและขอบเขตของ PithTrain
- 10.1 คุณค่าสูงสุด: เปลี่ยนระบบฝึกอบรมประสิทธิภาพสูงให้เป็นวัตถุที่เรียนรู้ได้
- 10.2 ความแตกต่างจากชุดฝึกอบรมแบบดั้งเดิม
- 10.3 ขอบเขตปัจจุบัน
- สิบเอ็ด: แผนผังข้อความหนึ่งภาพ: PithTrain ทำการฝึกอบรมครั้งเดียวได้อย่างไร
- 11.1 เส้นทางการดำเนินการแบบ end-to-end
- บทสรุป: แรงบันดาลใจที่ PithTrain มอบให้กับการออกแบบระบบฝึกอบรม
หนึ่ง: เริ่มต้นอย่างรวดเร็ว: จากการติดตั้งจนถึงการรัน MoE Pre-training ครั้งแรก
README ของ PithTrain ระบุข้อกำหนดด้านสภาพแวดล้อมอย่างตรงไปตรงมา: ต้องใช้ GPU NVIDIA Hopper (SM90) หรือ Blackwell (SM100), CUDA 13.0, Python >= 3.12 และใช้ uv จัดการ dependencies เส้นทางการติดตั้งขั้นต่ำมีดังนี้
git clone https://github.com/mlc-ai/Pith-Train.git && cd Pith-Train
uv venv
uv pip install .
# หากเป็นนักพัฒนา ต้องติดตั้ง dependencies สำหรับการพัฒนาและสภาพแวดล้อมซอร์สโค้ด:
uv sync
การ pre-training Qwen3-30B-A3B ทั่วไปตั้งแต่เริ่มต้นแบ่งเป็นสามขั้นตอน
- ขั้นตอนแรก: ดาวน์โหลดและแยกคำคลังข้อมูล pre-training DCLM:
bash examples/build_tokenized_corpus/launch.sh dclm-qwen3 - ขั้นตอนที่สอง: แก้ไขสคริปต์การฝึก ปรับขนาดการขนาน batch size อัตราการเรียนรู้ ฯลฯ ไฟล์คอนฟิกอยู่ที่:
examples/pretrain_language_model/qwen3-30b-a3b/script.py - ขั้นตอนที่สาม: เริ่มการฝึก:
bash examples/pretrain_language_model/launch.sh qwen3-30b-a3b
สคริปต์การฝึกจะตรวจจับจำนวน GPU โดยอัตโนมัติ และรองรับทั้งการรันบนเครื่องเดียวและสภาพแวดล้อมคลัสเตอร์หลายเครื่องแบบ SLURM จุดตรวจสอบโมเดลจะถูกบันทึกไว้ในไดเรกทอรี workspace โดยค่าเริ่มต้น และมีความสามารถในการกู้คืนการฝึกจาก checkpoint ล่าสุดโดยอัตโนมัติ เมื่อการฝึกสิ้นสุดลง คุณสามารถแปลงน้ำหนักโมเดลในรูปแบบ PyTorch Distributed Checkpoint เป็นรูปแบบที่เข้ากันได้กับ Hugging Face:
bash examples/convert_checkpoint/launch.sh qwen3-30b-a3b
สำหรับรายละเอียดเพิ่มเติมเกี่ยวกับโมเดลและการแปลง โปรดดู examples/build_tokenized_corpus/README.md, examples/pretrain_language_model/ และ examples/convert_checkpoint/README.md
unsetunsetสอง: ตำแหน่งของโปรเจกต์: ไม่ใช่เฟรมเวิร์กของเล่น แต่เป็น “ชุดฝึกอบรมระดับการผลิตที่อ่านได้”unsetunset
2.1 ปัญหาหลักที่ PithTrain ต้องการแก้ไข
README ของ PithTrain ระบุตำแหน่งของตัวเองอย่างเฉียบคม: Efficient, Python-native MoE training in ~10K lines of code. มันไม่ได้มุ่งเน้นที่ “วิธีการเขียนตัวอย่างการฝึก Transformer แบบง่าย” แต่ชี้ตรงไปยังเป้าหมายสามประการที่ยากจะบรรลุพร้อมกันในการฝึกโมเดลขนาดใหญ่แบบ MoE:
- ประสิทธิภาพใกล้เคียงระบบระดับการผลิต: รองรับการขนาน 4 มิติ, การซ้อนทับการคำนวณและการสื่อสาร, การฝึก FP8, DeepGEMM, FlashAttention, และโอเปอเรเตอร์ Triton/TileLang
- ตรรกะการทำงานโปร่งใสเพียงพอ: โค้ดหลักเขียนด้วย Python ขนาดคลังทั้งหมดประมาณหนึ่งหมื่นบรรทัด ทั้งนักพัฒนามนุษย์และ AI agent สามารถเข้าใจกลไกการทำงานแบบ end-to-end
- กระบวนการทางวิศวกรรมครบวงจร: ครอบคลุมการสร้างข้อมูล, วงรอบการฝึก, โทโพโลยีแบบกระจาย, การใช้งานโมเดล, การแปลง checkpoint, การบันทึก日志, การทดสอบ และการวัดประสิทธิภาพ
กล่าวโดยย่อ PithTrain ไม่ได้พยายามซ่อนความซับซ้อน แต่จัดระเบียบความซับซ้อนให้เป็นชั้นลำดับชั้นที่อ่านง่าย
คำอธิบายสถาปัตยกรรมใน README แบ่งระบบทั้งหมดออกเป็นสามชั้น:
- Upstream: วงรอบการฝึกสำหรับงานต่างๆ เช่น pre-training, SFT
- Core: ประกอบด้วยโมเดล, บล็อกการสร้าง, ไปป์ไลน์ DualPipeV, การฝึกแบบกระจาย และโครงสร้างพื้นฐานการฝึก
- Operators: ครอบคลุม PyTorch/NCCL, DeepGEMM, FlashAttention, และโอเปอเรเตอร์ Python DSL เช่น Triton, TileLang
2.2 ไดเรกทอรีหลักและหน้าที่
จากโครงสร้างคลัง โค้ดหลักของ PithTrain กระจุกตัวอยู่ในไดเรกทอรี
pithtrain/:
pithtrain/tasks/pretrain_language_model.py: จุดเริ่มต้นสำหรับ pre-training โมเดลภาษา รับผิดชอบการจัดระเบียบบริบท การโหลด checkpoint และการดำเนินวงรอบการฝึกpithtrain/modules/training.py: รับผิดชอบการเริ่มต้นคอนฟิกการฝึก, ชุดข้อมูล, โมเดล, FSDP, ตัวปรับให้เหมาะสม และตัวจัดตารางอัตราการเรียนรู้pithtrain/modules/distributed.py: ใช้สร้าง DeviceMesh แบบ 4 มิติ PP/DP/CP/EPpithtrain/modules/dataset.py: การอ่านข้อมูล token แบบแพ็ครวมโดยใช้ mmap และการสับเปลี่ยนแบบ globalpithtrain/dualpipe/dualpipev.py: ตัวจัดตารางไปป์ไลน์ DualPipeV ซึ่งเป็นศูนย์กลางหลักในการเพิ่มประสิทธิภาพปริมาณงานของระบบpithtrain/dualpipe/overlap.py: แยกชั้น Transformer ออกเป็นห้าขั้นตอน เพื่อให้เกิดการซ้อนทับแบบละเอียดของการคำนวณเดินหน้าและถอยหลังpithtrain/models/qwen3_30b_a3b.py,deepseek_v2_lite.py,gpt_oss.py: การใช้งานโครงสร้างโมเดลเฉพาะpithtrain/models/interface.py: อินเทอร์เฟซโปรโตคอลที่ชั้นโมเดลต้องปฏิบัติตามสำหรับ DualPipeVpithtrain/layers/deepgemm_fp8_linear.py: ชั้นเชิงเส้น FP8 และ MoE GroupLinear ที่ใช้ DeepGEMMpithtrain/operators/ep_dispatch.py: การใช้งาน Triton แบบฟิวชันสำหรับการกระจายแบบขนานผู้เชี่ยวชาญpithtrain/operators/deepgemm_fp8_quantize.py: เคอร์เนล Triton สำหรับการหาปริมาณ FP8pithtrain/modules/checkpoint.py: เครื่องมือแปลง checkpoint แบบมาตรฐานที่ไม่ขึ้นกับ PPexamples/: ประกอบด้วยสคริปต์ที่รันได้สำหรับการเตรียมข้อมูล, pre-training, การแปลง checkpoint ฯลฯ
การจัดระเบียบนี้ชาญฉลาดมาก: ชั้นงานจะไม่เขียนโอเปอเรเตอร์ที่ซับซ้อนโดยตรง และชั้นโอเปอเรเตอร์ก็ไม่รับรู้ถึงวงรอบการฝึก ระบบฝึกอบรมที่ซับซ้อนทั้งหมดถูกตัดเป็นโมดูลที่สามารถเปลี่ยนแทนกันได้ แต่ละโมดูลมีขอบเขตที่ชัดเจน
unsetunsetสาม: เส้นทางหลักการฝึก: จาก torchrun สู่การอัปเดตพารามิเตอร์ครั้งเดียวunsetunset
3.1 สคริปต์เริ่มต้น: ปรับอัตโนมัติสำหรับเครื่องเดียวและสภาพแวดล้อม SLURM
จุดเริ่มต้นของ pre-training เริ่มจากสคริปต์ shell สคริปต์นี้จะสร้างพารามิเตอร์ที่จำเป็นสำหรับ torchrun โดยอัตโนมัติตามตัวแปรสภาพแวดล้อม SLURM หรือจำนวน GPU ของเครื่อง
ที่มา: examples/pretrain_language_model/launch.sh
SLURM_NNODES=${SLURM_NNODES:-1}
SLURM_NODEID=${SLURM_NODEID:-0}
SLURM_STEP_GPUS=${SLURM_STEP_GPUS:-${CUDA_VISIBLE_DEVICES:-$(nvidia-smi –query-gpu=index –format=csv,noheader | paste -sd,)}}
LAUNCH_ARGS+=(–nnodes=$SLURM_NNODES –node-rank=$SLURM_NODEID)
LAUNCH_ARGS+=(–nproc-per-node=$(echo “$SLURM_STEP_GPUS” | tr ‘,’ ‘n’ | wc -l))
LAUNCH_ARGS+=(–rdzv-backend=c10d)
SCRIPT=examples/pretrain_language_model/$1/script.py
torchrun ${LAUNCH_ARGS[@]} $SCRIPT
สคริปต์นี้ดูเรียบง่ายแต่สำคัญมาก: PithTrain ปล่อยให้ “ความซับซ้อนของการเริ่มต้นหลายเครื่อง” อยู่ในสคริปต์ shell ภายนอก ในขณะที่โค้ด Python ภายในสมมติว่ามันทำงานภายใต้สภาพแวดล้อม torchrun มาตรฐาน
3.2 ฟังก์ชัน launch: สามชั้นบริบทห่อหุ้มวงจรชีวิตการฝึก
จุดเริ่มต้นการฝึก Python ที่แท้จริงอยู่ที่ pithtrain/tasks/pretrain_language_model.py ไฟล์นี้ใช้กลไก ExitStack เพื่อค่อยๆ สร้างบริบทการบันทึก日志, สภาพแวดล้อมแบบกระจาย และบริบทการฝึก จากนั้นโหลด checkpoint และเข้าสู่วงรอบการฝึกหลัก
# ที่มา: pithtrain/tasks/pretrain_language_model.py
@shutdown.record
def launch(cfg: PretrainLanguageModelCfg) -> None:
"""เริ่มต้นกระบวนการ pre-training โมเดลภาษา"""
with ExitStack() as stack:
ctx = PretrainLanguageModelCtx()
stack.enter_context(logging_context(cfg, ctx))
stack.enter_context(distributed_context(cfg, ctx))
stack.enter_context(training_context(cfg, ctx))
logger = ctx.logging.stdout
logger.info("launch(cfg=%s)" % cfg)
load_checkpoint(cfg, ctx)
raise_if_dataset_insufficient(cfg, ctx)
while ctx.training.step < cfg.training.max_steps:
train_step(cfg, ctx)
การออกแบบนี้สะท้อนปรัชญาการเขียนโปรแกรมแบบ Pythonic อย่างเต็มที่: การบันทึก日志, การเริ่มต้นแบบกระจาย, การสร้างโมเดล, การสร้างชุดข้อมูล และการสร้างตัวปรับให้เหมาะสม ล้วนถูกห่อหุ้มอยู่ในตัวจัดการบริบท ข้อดีคือวงจรชีวิตของอ็อบเจกต์ชัดเจนมาก และการจัดการเส้นทางข้อยกเว้นก็เป็นหนึ่งเดียวและกระชับยิ่งขึ้น
3.3 ขั้นตอนการทำงานของ train_step ครั้งเดียว
ฟังก์ชัน train_step เป็นภาพจำลองพฤติกรรมของระบบ มันเชื่อมโยงขั้นตอนการฝึกหนึ่งครั้งอย่างสมบูรณ์: รับ batch, ดำเนินการ DualPipeV, ปรับขนาดและตัดเกรเดียนต์, อัปเดต optimizer, ปรับอัตราการเรียนรู้, บันทึก日志 และบันทึก checkpoint
# ที่มา: pithtrain/tasks/pretrain_language_model.py
def train_step(cfg: PretrainLanguageModelCfg, ctx: PretrainLanguageModelCtx) -> None:
model = ctx.training.model
optimizer = ctx.training.optimizer
scheduler = ctx.training.scheduler
model.train()
accumulate_steps = global_batch_size // (micro_batch_size * dp_size * ep_size)
global_tokens, global_labels = get_global_batch(cfg, ctx, device)
loss, _ = model.step(
global_tokens,
num_chunks=accumulate_steps,
criterion=criterion,
labels=(global_labels,),
return_outputs=False,
)
if accumulate_steps > 1:
scale = 1.0 / accumulate_steps
for p in model.parameters():
if p.grad is not None:
p.grad.mul_(scale)
gradient_norm = clip_grad_norm_(model, max_norm=1.0, norm_type=2)
optimizer.step()
scheduler.step()
optimizer.zero_grad(set_to_none=True)
สิ่งที่ต้องสังเกตเป็นพิเศษคือ model ที่นี่ไม่ใช่ Transformer ทั่วไป แต่เป็นโมเดลไปป์ไลน์ที่ถูกห่อหุ้มด้วย DualPipeV ซึ่งหมายความว่าวงรอบการฝึกจะไม่เรียก model(input) โดยตรงอีกต่อไป แต่จะ แบ่ง batch ทั้งหมดออกเป็น micro-batch chunks หลายชิ้น และมอบหมายให้ตัวจัดตารางจัดการการสลับการทำงานของการส่งผ่านเดินหน้า, การส่งผ่านถอยหลัง, การสื่อสาร และการคำนวณเกรเดียนต์น้ำหนัก
การขนานสี่มิติ: PP, DP, CP, EP ถูกรวมเป็น DeviceMesh ได้อย่างไร
4.1 การแบ่งหน้าที่ของการขนานสี่แบบ
มิติการขนานที่ PithTrain รองรับ ได้แก่:
- PP, Pipeline Parallelism (การขนานไปป์ไลน์): แบ่งชั้นต่างๆ ของโมเดลไปยังเฟส GPU ที่แตกต่างกัน
- DP, Data Parallelism (การขนานข้อมูล): คัดลอกหรือแบ่งส่วนโมเดลเพื่อประมวลผลตัวอย่างข้อมูลที่แตกต่างกัน
- CP, Context Parallelism (การขนานบริบท): แบ่งตามมิติความยาวลำดับ และแลกเปลี่ยน KV cache ผ่านกลไก ring attention
- EP, Expert Parallelism (การขนานผู้เชี่ยวชาญ): กระจายผู้เชี่ยวชาญต่างๆ ในชั้น MoE ไปยัง GPU ที่แตกต่างกัน
กลยุทธ์การขนานเหล่านี้ไม่ได้เป็นแค่การนำมาประกอบกันอย่างง่าย แต่แสดงออกอย่างเป็นหนึ่งเดียวผ่าน PyTorch DeviceMesh
# ที่มา: pithtrain/modules/distributed.py
def setup_device_mesh(cfg: DistributedCfg, ctx: DistributedCtx) -> None:
ctx.ep_size = cfg.expert_parallel_size
ctx.pp_size = cfg.pipeline_parallel_size
ctx.cp_size = cfg.context_parallel_size
ctx.dp_size = ctx.world_size // (ctx.ep_size * ctx.pp_size * ctx.cp_size)
kwargs = dict()
kwargs["device_type"] = "cuda"
kwargs["mesh_shape"] = (ctx.pp_size, ctx.dp_size, ctx.cp_size, ctx.ep_size)
kwargs["mesh_dim_names"] = ("pp", "dp", "cp", "ep")
ctx.device_mesh = torch.distributed.init_device_mesh(**kwargs)
ctx.dp_rank = ctx.device_mesh.get_local_rank("dp")
ctx.pp_rank = ctx.device_mesh.get_local_rank("pp")
ctx.cp_rank = ctx.device_mesh.get_local_rank("cp")
ctx.ep_rank = ctx.device_mesh.get_local_rank("ep")
มีรายละเอียดทางวิศวกรรมที่สำคัญที่นี่: รูปร่างของ Mesh ถูกกำหนดเป็น (PP, DP, CP, EP) 注释明确指出,将 EP 和 CP 放置在内层,目的是让高频通信尽可能在 NVLink 域内完成 ซึ่งหมายความว่าการจัดวางโทโพโลยีไม่ใช่แนวคิดทางคณิตศาสตร์นามธรรมล้วนๆ แต่ให้บริการโดยตรงเพื่อเพิ่มประสิทธิภาพการสื่อสาร
4.2 การแบ่งส่วนพารามิเตอร์ FSDP และ MoE
ใน training.py PithTrain ใช้ FSDP Mesh ที่แตกต่างกันสำหรับพารามิเตอร์ผู้เชี่ยวชาญ MoE และพารามิเตอร์ทั่วไป:
# ที่มา: pithtrain/modules/training.py
def apply_fsdp(model, mesh: torch.distributed.DeviceMesh):
moe_fsdp_mesh = mesh["dp", "cp"]._flatten()
other_fsdp_mesh = mesh["dp", "cp", "ep"]._flatten()
for i in range(2):
for layer in model[i].layers.values():
if hasattr(layer.mlp, "experts"):
fully_shard(
layer.mlp.experts,
mesh=moe_fsdp_mesh,
reshard_after_forward=False,
mp_policy=mp,
)
fully_shard(layer, mesh=other_fsdp_mesh, reshard_after_forward=False, mp_policy=mp)
โค้ดนี้เผยให้เห็นความเข้าใจอย่างลึกซึ้งของ PithTrain เกี่ยวกับ MoE: พารามิเตอร์ผู้เชี่ยวชาญถูกกระจายผ่านมิติ EP อยู่แล้ว ดังนั้นผู้เชี่ยวชาญจึงต้องถูกแบ่งส่วนตามมิติ DP/CP เท่านั้น ในขณะที่พารามิเตอร์อื่นๆ ที่ไม่ใช่ผู้เชี่ยวชาญสามารถถูกแบ่งส่วนร่วมกันในสามมิติ DP/CP/EP การออกแบบนี้หลีกเลี่ยงการแบ่งส่วนซ้ำซ้อนในมิติเดียวกัน และยังรับประกันว่าหน้าที่ของการขนานผู้เชี่ยวชาญและ FSDP ชัดเจนและไม่ขัดแย้งกัน
ห้า: DualPipeV: “ศูนย์ควบคุมการจราจร” เพื่อเพิ่มปริมาณงานของ PithTrain
5.1 สายไปป์ไลน์สองทิศทางรูปตัว V
pithtrain/dualpipe/dualpipev.pyระบุอย่างชัดเจนว่ามันอิงตามแนวคิดการออกแบบ DeepSeek DualPipe และขยายคุณสมบัติต่างๆ เช่น การซ้อนทับห้าขั้นตอน, การรวม FSDP2, แคชน้ำหนัก FP8 และการจัดสรรเทนเซอร์กลางล่วงหน้า
คุณค่าหลักของ DualPipeV สามารถเข้าใจได้ด้วยการเปรียบเทียบที่ชัดเจน:
- ไปป์ไลน์แบบดั้งเดิมเปรียบเสมือนถนนเดินรถทางเดียว รถยนต์ที่คำนวณเดินหน้าต้องผ่านไปทั้งหมดก่อน รถยนต์ที่คำนวณถอยหลังจึงจะเริ่มวิ่งได้ ทำให้เกิด “ฟองอากาศ” ที่ว่างเปล่าจำนวนมากในระหว่างนั้น
- ในขณะที่ DualPipeV เปรียบเสมือนสะพานลอยสองชั้น มันจัดเรียงการคำนวณเดินหน้า, การคำนวณถอยหลัง, การสื่อสาร และการคำนวณเกรเดียนต์น้ำหนักอย่างชาญฉลาดในช่วงเวลาที่แตกต่างกัน เพื่อให้แน่ใจว่า GPU ทำงานอยู่ตลอดเวลา หลีกเลี่ยงการรอคอยที่ว่างเปล่า
5.2 การจัดตารางแปดขั้นตอนในฟังก์ชัน step
DualPipeV.step() เป็นหัวใจของตัวจัดตารางทั้งหมด มันต้องการ num_chunks >= pp_size * 2 จากนั้นดำเนินการจัดตารางที่มีแปดขั้นตอน: การเดินหน้า预热, การเดินหน้าสองเฟส, การสลับ B/W/F, วงรอบหลัก, การถอยหลังปิดท้าย และการคำนวณเกรเดียนต์น้ำหนักแบบไร้ฟองอากาศ ฯลฯ
# ที่มา: pithtrain/dualpipe/dualpipev.py
def step(
self,
*inputs: Optional[torch.Tensor],
num_chunks: int = 0,
criterion: Optional[Callable] = None,
labels: List[Optional[torch.Tensor]] = [],
return_outputs: bool = False,
):
assert num_chunks > 0 and num_chunks >= pp_size * 2
self._reset_states()
if FP8WeightCacheControl.enabled:
FP8WeightCacheControl.step()
self._ensure_intermediate_tensors_allocated(num_chunks)
if self.is_first_pp_rank:
self.input_chunks = (scatter(inputs, num_chunks, self.batch_dim), [])
self.labels = scatter(labels, num_chunks, self.batch_dim)
self.criterion = criterion
# Step 1: nF0
# Step 2: nF0F1
# Step 3: nB1W1F1
# Step 4: nF0B1F1B0
# Step 5: nB1F1B0
# Step 6: nB1B0
# Step 7: nWB0
# Step 8: nW
### 5.3 การควบคุม hook ของ FSDP ด้วยตนเอง
ในซอร์สโค้ดที่สมบูรณ์ 注释เหล่านี้สอดคล้องกับการเรียก `_forward_chunk`, `_backward_chunk`, `_forward_backward_chunk` และ `_weight_chunk` จำนวนมาก นวัตกรรมที่แท้จริงไม่ใช่ "การเขียนลูป" แต่เป็นการจัดเฟสการดำเนินการที่แตกต่างกันของแต่ละ micro-batch ให้เป็นตารางเวลา เพื่อให้เกิดการซ้อนทับการคำนวณสูงสุด
DualPipeV ยังแก้ปัญหาทางวิศวกรรมระดับต่ำมาก: การส่งผ่านถอยหลังแบบกำหนดเองถูกกระตุ้นซ้ำๆ ในไปป์ไลน์ หากปล่อยให้ hook เริ่มต้นของ FSDP เข้ามาแทรกแซงทุกครั้ง จะเกิดค่าใช้จ่าย CPU ที่ไม่จำเป็น ดังนั้น โค้ดจึงระงับการเรียกกลับหลังของโหนดราก FSDP ก่อน และเรียกด้วยตนเองหลังจากลูปทั้งหมดสิ้นสุดลง
```python
# ที่มา: pithtrain/dualpipe/dualpipev.py
for module in self.module:
if isinstance(module, FSDPModule):
module.set_is_last_backward(False)
module.set_reshard_after_backward(False)
module.set_requires_gradient_sync(False)
if not self.forward_only:
fully_shard.state(module)._state_ctx.post_backward_final_callback_queued = True
โค้ดนี้แสดงให้เห็นว่า PithTrain ไม่ใช่แค่เฟรมเวิร์กระดับบนที่ “เรียก PyTorch API” แต่เข้าใจความขัดแย้งที่อาจเกิดขึ้นระหว่างจังหวะเวลาการดำเนินการของ FSDP, กลไกการซิงค์เกรเดียนต์ และการจัดตารางไปป์ไลน์อย่างลึกซึ้ง
หก: การซ้อนทับห้าขั้นตอน: แยก Transformer หนึ่งชั้นเป็นชิ้นส่วนไปป์ไลน์ที่จัดตารางได้
6.1 กำหนดโปรโตคอลโมเดลก่อน
เพื่อให้โมเดลประเภทต่างๆ สามารถเชื่อมต่อกับ DualPipeV ได้ PithTrain จึงกำหนดชุดอินเทอร์เฟซโปรโตคอลไว้ล่วงหน้า แต่ละชั้น decoder ต้องมีสามเมธอดหลัก: forward_attn, forward_mlp และ forward_aggregate
# ที่มา: pithtrain/models/interface.py
class DecoderLayerProtocol(Protocol):
def forward_attn(self, hidden_states: torch.Tensor) -> ForwardAttnOutput:
"""LN + Attn + LN + Expert selection."""
def forward_mlp(
self,
gathered_tokens: torch.Tensor,
expert_idxs: Optional[torch.Tensor] = None,
expand_idx: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""MLP forward."""
def forward_aggregate(
self,
moe_outs: torch.Tensor,
moe_local_idxs: Optional[torch.Tensor],
topk_weight: Optional[torch.Tensor],
residual: torch.Tensor,
) -> torch.Tensor:
"""Weighted expert output + residual connection."""
ด้วยวิธีนี้ แม้ว่าโมเดลต่างๆ เช่น Qwen, DeepSeek, GPT-OSS จะมีโครงสร้างที่แตกต่างกัน แต่ตราบใดที่ปฏิบัติตามโปรโตคอลนี้ ก็สามารถใช้ตัวจัดตารางไปป์ไลน์ชุดเดียวกันได้
6.2 ความหมายที่แท้จริงของห้าขั้นตอน
ในตอนต้นของ dualpipev.py มีการระบุความสัมพันธ์ของการแมปห้าขั้นตอนโดยตรง:
- Attention: LN + Attention + LN + Expert selection
- Dispatch: การกระจาย all-to-all แบบขนานผู้เชี่ยวชาญ
- MLP: การคำนวณ MLP ของผู้เชี่ยวชาญหรือทั่วไป
- Combine: การรวม all-to-all แบบขนานผู้เชี่ยวชาญ
- Aggregate: ผลลัพธ์ผู้เชี่ยวชาญแบบถ่วงน้ำหนักและการเชื่อมต่อ residual
ในขณะที่ overlap.py รับผิดชอบการสลับการทำงานของโมดูลเดินหน้า 0 และโมดูลถอยหลัง 1
# ที่มา: pithtrain/dualpipe/overlap.py
def overlapped_forward_backward(
module0: ModelProtocol,
inputs0: List[torch.Tensor],
criterion0: Optional[Callable],
labels0: Optional[List[torch.Tensor]],
intermediate_tensors0: IntermediateTensors,
module1: ModelProtocol,
loss1: Optional[torch.Tensor],
outputs1: Optional[List[torch.Tensor]],
output_grads1: Optional[List[torch.Tensor]],
intermediate_tensors1: IntermediateTensors,
comm_stream: Optional[torch.cuda.Stream],
ep_group: Optional[torch.distributed.ProcessGroup] = None,
):
# Interleaves forward stage1/2/3/4/5 with backward stage5/4/3/2/1
ในลูปหลัก โค้ดจะสลับการทำงานของฟังก์ชันต่างๆ เช่น stage5_b, stage4_b, stage1_f, stage2_f, stage3_b, stage3_w, stage3_f สามารถเข้าใจการออกแบบนี้ได้ดังนี้:
- เมื่อ token หนึ่งชุดกำลังรอการสื่อสารผู้เชี่ยวชาญ การคำนวณ attention หรือการส่งผ่านถอยหลังของ token อีกชุดสามารถเริ่มต้นได้ก่อน
- เมื่อบางเฟสต้องรอสตรีมการสื่อสาร สตรีมการคำนวณสามารถดำเนินการงานของเฟสอื่นต่อไปได้
6.3 ทำไมต้องแยกเป็นห้าส่วน
⚠️ หมายเหตุ: เนื้อหาได้รับการแปลโดย AI และตรวจสอบโดยมนุษย์ หากมีข้อผิดพลาดโปรดแจ้ง
☕ สนับสนุนค่ากาแฟทีมงาน
หากคุณชอบบทความนี้ สามารถสนับสนุนเราได้ผ่าน PromptPay
本文来自网络搜集,不代表คลื่นสร้างอนาคต立场,如有侵权,联系删除。转载请注明出处:http://www.itsolotime.com/th/archives/32986
