ทีม Tsinghua คลายปริศนา FlashAttention การฝึกฝนความแม่นยำต่ำ: ความเอนเอียงเชิงตัวเลขภายใต้ BF16 กระตุ้นการฝึกโมเดลขนาดใหญ่

สรุปสั้นๆ: ปรากฏการณ์ “ลึกลับ” ที่สร้างความสับสนให้กับชุมชนมาหลายปีในที่สุดก็ได้รับการคลี่คลาย: ในการฝึกฝนด้วยความแม่นยำต่ำเช่น BF16 FlashAttention ไม่ได้เกิดข้อผิดพลาดแบบสุ่ม แต่จะกระตุ้นความเอนเอียงเชิงตัวเลขที่มีทิศทางภายใต้เงื่อนไขเฉพาะ ความเอนเอียงนี้ถูกขยายอย่างต่อเนื่องด้วยทิศทางการอัปเดตอันดับต่ำที่เกิดขึ้นในกลไกความสนใจ ส่งผลให้สเปกตรัมนอร์มของน้ำหนักและค่าการกระตุ้นหลุดจากการควบคุม และทำให้ฟังก์ชันการสูญเสียระเบิดขึ้นอย่างกะทันหันในที่สุด งานวิจัยนี้ยังเสนอการแก้ไขเพียงเล็กน้อยที่แทบไม่ต้องปรับเปลี่ยนโมเดลเลย โดยทำการเปลี่ยนแปลงเฉพาะใน safe softmax ซึ่งการทดลองพิสูจน์แล้วว่าสามารถทำให้การฝึกฝนมีความเสถียรอย่างมีนัยสำคัญ

ภาพรวมข้อมูลงานวิจัย

ทีม Tsinghua คลายปริศนา FlashAttention การฝึกฝนความแม่นยำต่ำ: ความเอนเอียงเชิงตัวเลขภายใต้ BF16 กระตุ้นการฝึกโมเดลขนาดใหญ่

  • ชื่อเรื่อง: Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention
  • ผู้เขียน: ไห่เฉวียน ชิว, เฉวียนหมิง เหยา
  • สถาบัน: ภาควิชาวิศวกรรมอิเล็กทรอนิกส์ มหาวิทยาลัยชิงหวา
  • ส่งตีพิมพ์: ICLR 2026 Oral
  • คำสำคัญ: การฝึกฝนความแม่นยำต่ำ, BF16, FlashAttention, ความเสถียรเชิงตัวเลข, ข้อผิดพลาดการปัดเศษ, การแสดงอันดับต่ำ
  • ลิงก์งานวิจัย: https://arxiv.org/abs/2510.04212
  • ลิงก์โค้ด: https://github.com/ucker/why-low-precision-training-fails

ภูมิหลังการวิจัย: ความต้องการที่จำเป็นของการฝึกฝนความแม่นยำต่ำและความไวของกลไกความสนใจ

ความเป็นจริงของการฝึกฝนโมเดลขนาดใหญ่นั้น ความจำและปริมาณการประมวลผลเป็นตัวกำหนดทุกสิ่ง อุตสาหกรรมทั่วไปใช้ BF16/FP16 ในการฝึกฝนแบบความแม่นยำผสม และแม้กระทั่งผลักดันความแม่นยำในการคำนวณของเครือข่ายฟีดฟอร์เวิร์ด (FFN) ไปสู่ FP8 เพื่อแลกกับประสิทธิภาพการฝึกฝนที่สูงขึ้น อย่างไรก็ตาม การปฏิบัติทางวิศวกรรมก็โหดร้ายไม่แพ้กัน: ยิ่งใกล้เคียงกับ “ความแม่นยำสูงสุด” กระบวนการฝึกฝนก็ยิ่งมีแนวโน้มที่จะเกิดความไม่เสถียรที่อธิบายได้ยาก

FlashAttention ในฐานะองค์ประกอบสำคัญในการเร่งการฝึกฝนบริบทยาว ได้กลายเป็นมาตรฐานของอุตสาหกรรม ปัญหาอยู่ที่ว่าในชุมชนมีกรณีความล้มเหลวที่สามารถทำซ้ำได้แต่ยากที่จะอธิบายมานานแล้ว:
* เมื่อใช้ FlashAttention + BF16 ฝึกฝน GPT-2 โมเดลจะลู่เข้าสู่ค่าที่เหมาะสมได้ปกติในช่วงแรก แต่หลังจากฝึกฝนไปหลายพันขั้น ฟังก์ชันการสูญเสียจะระเบิดขึ้นอย่างกะทันหัน
* แม้ว่าจะสามารถ “ดับไฟ” ได้โดยการย้อนกลับไปใช้ความสนใจมาตรฐาน หรือเพิ่มความแม่นยำของการคำนวณที่สำคัญเป็น FP32 แต่ก็หมายถึงการเสียเปรียบด้านปริมาณการประมวลผลและความจำ

ปัญหาประเภทนี้ถูกรายงานมาหลายปี (issue ที่เกี่ยวข้องปรากฏซ้ำๆ ในหลายโครงการโอเพนซอร์ส) แต่ขาดกลไกที่สมบูรณ์ซึ่งสามารถ “อธิบายจากข้อผิดพลาดเชิงตัวเลขไปจนถึงการระเบิดของฟังก์ชันการสูญเสีย” ได้

ทีม Tsinghua คลายปริศนา FlashAttention การฝึกฝนความแม่นยำต่ำ: ความเอนเอียงเชิงตัวเลขภายใต้ BF16 กระตุ้นการฝึกโมเดลขนาดใหญ่

การค้นพบหลัก: ระบุสาเหตุของปัญหาที่อยู่ที่พจน์เฉพาะในการแพร่กลับของ FlashAttention

ผู้เขียนใช้วิธีการทางวิศวกรรมที่เข้มงวดและสามารถทำซ้ำได้ เพื่อจำกัดขอบเขตของปัญหาอย่างค่อยเป็นค่อยไป:

  1. ทำซ้ำความล้มเหลวอย่างเข้มงวด: ทำการฝึกฝนล่วงหน้า GPT-2 (12 ชั้น, 12 เฮด, มิติซ่อน 768, ความยาวบริบท 1024) โดยใช้ชุดข้อมูล OpenWebText โดยการบันทึกและเล่นลำดับชุดข้อมูลเดียวกันซ้ำ เพื่อขจัดความสุ่มที่เกิดจากลำดับข้อมูล
  2. ระบุชั้นและเฮดที่ผิดปกติ: ใช้ตัวชี้วัดเช่นสเปกตรัมนอร์มเพื่อจำกัดขอบเขตอย่างรวดเร็ว พบว่าความผิดปกติส่วนใหญ่มาจากโมดูลความสนใจเฉพาะของชั้นหนึ่ง และยิ่งไปกว่านั้น กระจุกตัวอยู่ที่เฮดความสนใจเพียงไม่กี่เฮด
  3. ระบุปริมาณกลางที่สำคัญ: การวิจัยพบว่า dP (เกรเดียนต์ของเมทริกซ์ความสนใจ P) ที่คำนวณเพื่อประสิทธิภาพในการแพร่กลับของ FlashAttention เป็นกุญแจสำคัญของปัญหา

งานวิจัยพบว่า: ตราบใดที่เมทริกซ์ P ที่ใช้ในการคำนวณ dP ได้มาจากวิธีที่ “มีความแม่นยำสูงกว่า หรือเทียบเท่าเชิงตัวเลขแต่เส้นทางต่างกัน” การฝึกฝนก็จะกลับมาเสถียรได้ กล่าวอีกนัยหนึ่ง สาเหตุของการล่มสลายในการฝึกฝนไม่ใช่กระบวนการฝึกฝนความแม่นยำต่ำทั้งหมด แต่เป็นสิ่งที่เฉพาะเจาะจงมาก: ข้อผิดพลาดเชิงตัวเลขของเมทริกซ์ P ภายใต้ความแม่นยำต่ำ ถูกนำเข้ามาเมื่อคำนวณ dP และปนเปื้อนเกรเดียนต์ในขั้นตอนต่อๆ ไป

คำอธิบายกลไกที่หนึ่ง: โครงสร้างอันดับต่ำที่คล้ายคลึงกันทำให้ข้อผิดพลาดกลายเป็น “แรงผลักดันที่ต่อเนื่อง”

หลังจากระบุที่ dP แล้ว ปัญหาสำคัญกลายเป็น: ทำไมข้อผิดพลาดเชิงตัวเลขที่ดูเหมือนเล็กน้อย จึงสามารถถูกขยายจนถึงระดับหายนะได้ในการฝึกฝน?

งานวิจัยเขียนความแตกต่างของเกรเดียนต์ภายใต้ความแม่นยำสูงและต่ำในรูปแบบที่เข้าใจง่าย: ข้อผิดพลาดของเกรเดียนต์เป็นสัดส่วนกับ (P_lp - P_hp) และถูกปรับด้วยพจน์บางอย่างในกลไกความสนใจ หลังจากแยกย่อยเพิ่มเติม การอัปเดตข้อผิดพลาดสามารถประมาณได้ว่าเป็นการซ้อนทับของพจน์อันดับ-1 หลายพจน์ สิ่งที่สำคัญกว่านั้นคือ ผู้เขียนสังเกตเห็นในการทดลองจริงว่า:
* ภายใต้โทเค็นและขั้นการฝึกฝนที่ต่างกัน โครงสร้างเมทริกซ์ที่เกี่ยวข้องแสดงความคล้ายคลึงกันสูง ซึ่งสามารถสรุปเป็นทิศทางอันดับต่ำร่วม R
* หากสัมประสิทธิ์ของ (P_lp - P_hp) มีความเอนเอียงในเชิงสถิติ (แทนที่จะผันผวนรอบศูนย์แบบสมมาตร) ข้อผิดพลาดก็จะไม่หักล้างกัน แต่จะสะสมอย่างต่อเนื่องไปตามทิศทาง R

ผลลัพธ์สุดท้ายคือ: การอัปเดตน้ำหนักถูก “ทำให้เบี่ยงเบน” สเปกตรัมนอร์มและค่าการกระตุ้นเติบโตผิดปกติ และในที่สุดก็ผลักดันการฝึกฝนไปสู่การระเบิดของฟังก์ชันการสูญเสีย

ทีม Tsinghua คลายปริศนา FlashAttention การฝึกฝนความแม่นยำต่ำ: ความเอนเอียงเชิงตัวเลขภายใต้ BF16 กระตุ้นการฝึกโมเดลขนาดใหญ่
ทีม Tsinghua คลายปริศนา FlashAttention การฝึกฝนความแม่นยำต่ำ: ความเอนเอียงเชิงตัวเลขภายใต้ BF16 กระตุ้นการฝึกโมเดลขนาดใหญ่
(รูปที่ 4/5 ของงานวิจัย: แผนภาพความคล้ายคลึงของโครงสร้างอันดับต่ำและการสะสมความเอนเอียง)

คำอธิบายกลไกที่สอง: ที่มาของความเอนเอียง – “ตัวกระตุ้นแบบไม่ต่อเนื่อง” ใน safe softmax และข้อผิดพลาดการปัดเศษ BF16

ห่วงโซ่เหตุผลที่สองขัดต่อสัญชาตญาณมากกว่า แต่ก็สำคัญกว่า: ทำไม (P_lp - P_hp) ถึงเอนเอียงไปในทิศทางเดียวกัน?

ผู้เขียนติดตามปัญหาย้อนกลับไปยังผลลัพธ์ที่ไม่ได้ปรับมาตรฐานในการแพร่ไปข้างหน้าของ FlashAttention:
* P_bar = exp(S - m) (รูปแบบทั่วไปของ safe softmax)
* P = P_bar / rowsum(P_bar)
* O = P @ V

ข้อสังเกตสำคัญของงานวิจัยคือ: P_bar จะเกิดความเอนเอียงเชิงระบบภายใต้ความแม่นยำ BF16 และเงื่อนไขการกระตุ้นความเอนเอียงนี้มีความเฉพาะเจาะจงมาก:

เงื่อนไขการกระตุ้น: เมื่อมีค่าสูงสุดหลายค่าที่เหมือนกันในแถวใดแถวหนึ่งของเมทริกซ์คะแนนความสนใจ S ตำแหน่งที่สอดคล้องกันใน P_bar จะปรากฏค่า 1 ที่แม่นยำ (เท่ากับ 1 อย่างแม่นยำในการแสดงเลขทศนิยม ไม่ใช่ประมาณ 1)

นี่ดูเหมือนเป็นรายละเอียด แต่จะผลักดันการคำนวณ O = P @ V ในขั้นตอนต่อไปให้เข้าสู่ช่วงอันตราย

แหล่งที่มาของความเอนเอียง: เมื่อ P_bar[t, j] = 1 และเมทริกซ์ค่า V ในบางมิติมีค่าลบเป็นหลัก การบวก BF16 จะมีแนวโน้ม “ยิ่งบวกยิ่งติดลบ” อย่างเป็นระบบ

ในบางมิติคุณลักษณะ การกระจายตัวของ V[:, i] อาจมีค่าลบเป็นหลัก ในกรณีนี้ หาก P_bar[t, j] = 1 พจน์ผลคูณก็คือ V[t, i] เอง (เลข BF16 ที่เป็นลบ) จำนวนลบหลายจำนวนในการบวกและการปัดเศษของ BF16 มีแนวโน้มที่จะกระตุ้นการล้นของแมนทิสซา การเลื่อนขวา และพฤติกรรมการปัดเศษที่เกี่ยวข้องกับ sticky bit ส่งผลให้การมีส่วนร่วมของข้อผิดพลาดไม่สมมาตร ซึ่งแสดงให้เห็นชัดเจนว่า:
* O_lp มีแนวโน้มที่จะ “เอนเอียงไปทางลบ” เมื่อเทียบกับ O_hp
* หากเกรเดียนต์ต้นทาง dO ในมิติที่สอดคล้องกันก็มีแนวโน้มที่จะเป็นลบเช่นกัน เมื่อคำนวณ dP ก็จะเกิดพจน์ข้อผิดพลาดที่เอนเอียงไปทางบวก
* dP ที่เอนเอียงไปทางบวกนี้ ไปขับเคลื่อนทิศทางอันดับต่ำที่คล้ายคลึงกัน R ที่กล่าวถึงก่อนหน้า จึงก่อให้เกิดวงจรอุบาทว์ที่ “ยิ่งฝึกยิ่งเบี่ยงเบน”

ทีม Tsinghua คลายปริศนา FlashAttention การฝึกฝนความแม่นยำต่ำ: ความเอนเอียงเชิงตัวเลขภายใต้ BF16 กระตุ้นการฝึกโมเดลขนาดใหญ่
(รูปที่ 6 ของงานวิจัย: เมื่อ P_bar = 1 ปรากฏ ข้อผิดพลาดของ O เกิด “การกระโดดเชิงลบ” ที่ชัดเจน)

แผนการแก้ไขที่เรียบง่ายที่สุด: รับประกันว่า P_bar มีค่าน้อยกว่า 1 อย่างเคร่งครัดเสมอ

เนื่องจากตัวกระตุ้นแบบไม่ต่อเนื่องของปัญหาคือการปรากฏของค่า 1 ที่แม่นยำใน P_bar แนวคิดการแก้ไขที่ผู้เขียนเสนอจึงตรงไปตรงมา:
* ตรวจสอบว่าค่าสูงสุดในแถว S ปรากฏมากกว่าหนึ่งครั้งหรือไม่
* ทันทีที่พบ “ค่าสูงสุดซ้ำ” ให้ปรับค่าคงที่การเลื่อนแถว m ของ safe softmax แบบไดนามิก เพื่อให้ผลการคำนวณเลขชี้กำลังที่ตำแหน่งสูงสุดสอดคล้องกันน้อยกว่า 1 อย่างเคร่งครัด

งานวิจัยให้การนำเสนอเชิงแนวคิดดังนี้:
python
rm = rowmax(S)
rs = rowsum(S == rm) # จำนวนครั้งที่ค่าสูงสุดปรากฏ
if rs > 1 and rm > 0:
m = β * rm # β > 1
elif rs > 1 and rm < 0:
m = 0
else:
m = rm
Pbar = exp(S - m) # ดังนั้น max(Pbar) < 1

(โปรดติดตามตอนต่อไป ซึ่งจะวิเคราะห์การทดสอบยืนยัน ผลของแผนการแก้ไข และผลกระทบที่กว้างขวางยิ่งขึ้น)

ขั้นตอนนี้จะไม่เปลี่ยนผลลัพธ์ความสนใจภายใต้เลขคณิตที่แม่นยำ (เพราะ softmax ไม่ไวต่อ “การลบค่าคงที่ออกจากทั้งแถว”) แต่ในการคำนวณความแม่นยำจำกัด มันสามารถหลีกเลี่ยง ทีม Tsinghua คลายปริศนา FlashAttention การฝึกฝนความแม่นยำต่ำ: ความเอนเอียงเชิงตัวเลขภายใต้ BF16 กระตุ้นการฝึกโมเดลขนาดใหญ่ ที่จะกระตุ้นการปัดเศษที่มีอคติในการบวกสะสม BF16 ในขั้นตอนต่อไป จึงตัดสายโซ่การแพร่กระจายข้อผิดพลาดจากต้นตอ

ผลการทดลอง: การฝึกฝนที่เสถียรไม่ “ล่มสลายกะทันหัน” อีกต่อไป

งานวิจัยตรวจสอบการวิเคราะห์และแผนการแก้ไขข้างต้นภายใต้การตั้งค่าความแม่นยำ BF16:
* GPT-2S: ใช้ FlashAttention ที่แก้ไขแล้ว สามารถฝึกฝนอย่างเสถียรจนถึง 600K ขั้น ภายใต้ตัวเพิ่มประสิทธิภาพทั้งสองชนิดคือ AdamW และ Muon
* GPT-2M: สามารถฝึกฝนอย่างเสถียรภายใต้ตัวเพิ่มประสิทธิภาพ AdamW เช่นกัน (งานวิจัยแสดงจนถึง 100K ขั้น)
* ความสม่ำเสมอของฮาร์ดแวร์: ปรากฏการณ์และข้อสรุปนี้มีความสอดคล้องกันบนแพลตฟอร์มฮาร์ดแวร์หลายชนิด (รวมถึง NVIDIA A100, RTX 4090 และ Huawei Ascend 910B)

ทีม Tsinghua คลายปริศนา FlashAttention การฝึกฝนความแม่นยำต่ำ: ความเอนเอียงเชิงตัวเลขภายใต้ BF16 กระตุ้นการฝึกโมเดลขนาดใหญ่
เปรียบเทียบเส้นโค้งการสูญเสียจากชุดตรวจสอบ (รูปที่ 7 ของงานวิจัย)

ข้อคิดหลัก: ข้อผิดพลาดความแม่นยำต่ำไม่ใช่ “สัญญาณรบกวนที่มีค่าเฉลี่ยเป็นศูนย์”

คุณค่าของงานวิจัยนี้ไม่เพียงอยู่ที่การแก้ไขปัญหาที่เฉพาะเจาะจง แต่ยังอยู่ที่การให้รูปแบบการวินิจฉัยเชิงตัวเลขที่สามารถถ่ายทอดได้:
* ความเอนเอียงเชิงระบบของข้อผิดพลาด: ข้อผิดพล


⚠️ หมายเหตุ: เนื้อหาได้รับการแปลโดย AI และตรวจสอบโดยมนุษย์ หากมีข้อผิดพลาดโปรดแจ้ง

本文来自网络搜集,不代表คลื่นสร้างอนาคต立场,如有侵权,联系删除。转载请注明出处:https://www.itsolotime.com/th/archives/23923

Like (0)
Previous 3 hours ago
Next 3 hours ago

相关推荐