統一FP8:混合精度を超え、安定的な高速化を実現するMoE RL訓練

TL;DR:私たちはRLにおける完全FP8サンプリングと訓練フローを実現しました。実験によると、MoEモデルにおいて、BF16訓練とFP8ロールアウトを組み合わせた場合、モデル規模が大きくなるほど訓練・推論の不整合が深刻化します。統一FP8を訓練とロールアウトに使用することで、量子化誤差による訓練・推論の不整合を効果的に解消し、RL訓練の速度と安定性を向上させることができます。

SGLang RLチームとMilesコミュニティは、RL訓練の安定性と高速化において興味深い探求を行ってきました。これにはSGLangとFSDPバックエンドの整合による厳密なゼロKL乖離の実現、およびSpeculative DecodingとオンラインSFTを組み合わせたドラフトモデルの活用が含まれます。

この基盤の上に、私たちは安定性とパフォーマンスのバランスを取る新たな進展を共有します——エンドツーエンドFP8 RL訓練とサンプリングパイプラインです。milesフレームワークは完全にサポートしており、Qwen3-4BとQwen3-30B-A3BのFP8 RL訓練(詳細はこちら)が、すぐに使用可能です。

本研究はInfiXAI Team、Ant Group AQ Team、SGLang RL Team、Miles Teamの共同作業によるものです。特にVerda Cloudには計算リソースの提供を、NVIDIAにはTransformer Engine (TE)の技術サポートを頂きました。

FP8訓練のハードウェア基盤

Tensor Coresと低精度サポート

低精度計算は、ハードウェア・ソフトウェア協調設計の結晶です。そのハードウェア基盤はTensor Coresであり、深層学習の中核計算である大規模行列の乗算累積のために設計されたGPU専用ハードウェアアクセラレータユニットです。従来のCUDAコアと比較して、Tensor Coresは低精度フォーマット(FP16、BF16、FP8など)に対してより高いスループットを提供します。その発展は基本的なFMA命令とDP4Aベクトル化から始まり、Voltaアーキテクチャで初めて専用Tensor Coresが導入され、その後Ampere、Hopper、Blackwellが継続的に改善を進めています:

  • 規模の拡張:単一操作でより大きな行列を処理し、計算対メモリ比を向上
  • 精度の低下:FP/BF16、FP8などのより低い精度フォーマットを継続的にサポート
アーキテクチャFP64F16INT8INT4FP8MXFP
Volta✅ FP16
Turing✅ FP16
Ampere✅ FP16/BF16
Hopper✅ FP16/BF16
(累積はFP22のみサポート)
Blackwell✅ FP16/BF16✅ MXFP(8/6/4)
NVFP4
Blackwell Ultra✅ (reduced FLOPs)✅ FP16/BF16✅ (reduced FLOPS)✅ MXFP(8/6/4)
NVFP4

画像出典:zartbotSemiAnalysis

この傾向により、低精度ストレージと計算がより魅力的になりました。具体的な利点には以下が含まれます:

  1. メモリ使用量の大幅削減:FP8は理論的にモデルの重みとアクティベーションメモリを半減し、VRAMの圧力を緩和
  2. 理論上2倍の計算スループット:H100 SXMでは、FP8 Tensor Coresは1979 TFLOPSに達し、BF16(989 TFLOPS)の2倍
  3. メモリ帯域幅ボトルネックの緩和:データがよりコンパクトになり、HBMから計算コアへの転送を削減

FP8フォーマット

FP8は8ビット浮動小数点フォーマットで、FP32(32ビット)やFP16/BF16(16ビット)と比較して、ストレージと転送コストを1/4または1/2に削減し、VRAMと帯域幅のボトルネックを緩和し、訓練と推論のパフォーマンスを向上させます。現在、主に2つのフォーマットがあります:

  • E4M3:4ビット指数+3ビット仮数。動的範囲は小さいが精度が高い
  • E5M2:5ビット指数+2ビット仮数。動的範囲は大きいが精度が低い

FP8 E4M3 vs E5M2

画像出典:OCP whitepaper

この設計は、ハードウェアスループットを最大化しながら、十分な数値範囲と精度を維持します。

FP8スケール選択

次元FP32 Scale(全精度スケーリング係数)E8M0 Scale(指数のみのスケーリング)
フォーマット定義FP32 (IEEE 754単精度浮動小数点)E8M0 (8ビット指数、0ビット仮数)
数値特性任意精度実数表現2の累乗のみサポート(1、2、0.5など);1.5などは表現不可
核心思想高精度でスケーリング係数を管理し、訓練の数値安定性を確保スケーリング係数を低精度に組み込み、ビット演算で効率化
主な利点1. 高精度、安定した訓練:動的範囲を正確に捕捉し、量子化誤差を削減、発散を防止
2. 幅広いサポート:NVIDIA Transformer Engineのデフォルト、エコシステムが成熟
1. 極めてハードウェアフレンドリー:スケーリングが単純なビットシフトで、高速・低消費電力
2. 統一パイプライン:全8ビット実行、ハードウェア設計を簡素化
主な欠点1. ストレージオーバーヘッド:各量子化テンソルに追加のFP32スケールが必要、VRAMを消費
2. 計算オーバーヘッド:スケール計算と変換にFP32が必要
1. 精度損失リスク:2の累乗への強制的な丸めがノイズを導入、逆伝播で累積し発散の原因に
2. 限定的な動的範囲分解能:複雑なテンソル分布への精細な適応が困難
総括業界で最も一般的で安全な方案精度を犠牲にして極限のハードウェア効率を追求

総合的な評価の結果、私たちはFP32を訓練スケール精度として選択しました。理由:

  1. 精度整合と訓練安定性:FP32スケールはテンソルの動的範囲を精細に捕捉し、FP8訓練の損失曲線をBF16ベースラインに近づけます
  2. 推論エコシステムとの一致:主流の推論モデルもFP32量子化スケールを使用
  3. 実際のハードウェメリット
    - Hopper (H100/H800):FP8 Tensor Coresをサポートするが、E8M0専用ユニットなし
    - Blackwell (B100/B200):MXFP8を導入し、E8M0類似のブロック単位スケーリングをサポート(arXiv:2506.08027

したがって、現在のHシリーズクラスタでは、E8M0を強制しても明確な高速化はなく、ソフトウェアエミュレーションのオーバーヘッドと精度リスクを導入するだけです。

FP8量子化

一般的な量子化戦略にはper-tensorper-blockper-tokenがあります。粒度に関わらず、量子化は通常2つのステップで行われます:

FP8 quantization flow

画像出典:InfiR2: A Comprehensive FP8 Training Recipe for Reasoning-Enhanced Language Models

ステップ1:スケーリング係数Sの計算

テンソル(またはブロック)の最大絶対値 max|X| を取り、FP8の最大表現可能値 V_max で除算:

S = max|X| / V_max

ステップ2:量子化値Qの計算

Sを使用して元のテンソルXの各要素xをSで除算し、丸める:

Q(x) = round(x / S)

FP8はFP16/BF16より精度が低いため、実際には安定性と効率のバランスが必要で、順伝播/逆伝播では異なる戦略と粒度が採用されることが多い:

  • Activations:通常per-token量子化。活性化関数は顕著な外れ値を含むことが多く、細粒度により外れ値の影響を局所化し、全体的な精度を保持
  • Weights:通常per-block量子化。収束後の重み分布は滑らか(ほぼガウス分布)で外れ値が少ないが、量子化誤差に敏感。ブロック状(block_size × block_sizeなど)は精度、ハードウェア最適化、効率とメモリ節約のバランスを取る
  • Gradients:通常per-token量子化