TL;DR
- 日本語モデルでDPOが動かない問題に遭遇
- 英語版は1段階で成功、日本語版は全滅
- 2段階アプローチ(SFT → DPO)で完全解決
- 成功率100%、マージン21.1%改善を達成
問題の始まり
MacBookでLLM強化学習を試そうと思い、DPO (Direct Preference Optimization) を実装した。英語版(distilgpt2)は完璧に動作:
- Chosen logp: +4.07 ✅
- Rejected logp: -1.89 ✅
失敗の連続
v1: 標準設定
- データ: 12ペア
- 学習率: 5e-5
- エポック: 5
- LoRAランク: 16
結果: Chosen変化: -0.0041 ❌ Rejected変化: -0.0893 ✅ → 学習不足
v2: 強い設定
- 学習率: 1e-4 (2倍)
- エポック: 10 (2倍)
- LoRAランク: 32 (2倍)
結果: Chosen変化: -1.2946 ❌ Rejected変化: -5.2282 ✅ → 過学習
v3: データ拡張
- データ: 50ペア (4倍)
結果: Chosen変化: -0.4492 ❌ Rejected変化: -2.3007 ✅ → 部分的改善
v4: モデル変更
パターン:Rejectedは減るが、Chosenが増えない- モデル: cyberagent/open-calm-small
結果: Chosen変化: -12.6191 ❌ Rejected変化: -39.9840 ✅ → 激しい過学習
原因の仮説
1段階DPOの問題点:
- モデルがまだ「良い応答」を生成できない
- DPOは「既存の能力」の上で選好を学習
- 能力がないのに選好だけ学習 → 失敗
解決策:2段階アプローチ
Stage 1: SFT (Supervised Fine-Tuning)
目的: Chosenレスポンスを生成する能力を獲得Chosenデータのみで教師あり学習
trainer = SFTLoRATrainer( model_name="rinna/japanese-gpt2-medium", lora_r=16, lora_alpha=32 )
trainer.train( preference_data=data, # Chosenのみ使用 epochs=3, batch_size=2, learning_rate=5e-5 )
結果:
- 損失: 2.9112 → 2.8333
- モデルが良い応答を生成できるようになる
Stage 2: DPO
目的: 獲得した能力の上で選好学習SFT済みモデルをベースにDPO
trainer = SFTDPOLoRATrainer( model_name="rinna/japanese-gpt2-medium", sft_adapter_path="./sft_lora_adapter_ja", # Stage 1の成果 lora_r=16, lora_alpha=32 )
trainer.train( preference_data=data, # Chosen + Rejected epochs=3, batch_size=2, learning_rate=5e-5, beta=0.1 )
結果:
- 損失: 0.6994 → 0.6693
- 選好学習が成功
検証結果
基本指標
| 指標 | 変化量 | 評価 | |-----|--------|------| | Chosen logp | +0.1355 | ✅ 増加 | | Rejected logp | -0.1514 | ✅ 減少 |
DPOの核心:マージン検証
重要: DPOの本質は「ChosenとRejectedの差を広げること」マージン = Chosen logp - Rejected logp
結果:
- ベースモデル平均マージン: 1.4958
- SFT+DPOモデル平均マージン: 1.8114
- 改善: +0.3156 (+21.1%)
ペアごとの成功率
全10ペアで検証:1. 天気: +0.2894 ✅ 2. Python: +0.0299 ✅ 3. レストラン: +0.5170 ✅ 4. ありがとう: +0.2566 ✅ 5. 機械学習: +0.3415 ✅ 6. 自己紹介: +0.4129 ✅ 7. プログラミング: +0.1639 ✅ 8. 好きな本: +0.3800 ✅ 9. 量子力学: +0.4197 ✅ 10. 愛: +0.3455 ✅
成功率: 10/10 (100%)全アプローチ比較
| アプローチ | Chosen変化 | Rejected変化 | マージン改善 | 状態 | |-----------|-----------|-------------|------------|------| | v1 (標準) | -0.0041 ❌ | -0.0893 ✅ | - | 学習不足 | | v2 (強化) | -1.2946 ❌ | -5.2282 ✅ | - | 過学習 | | v3 (データ拡張) | -0.4492 ❌ | -2.3007 ✅ | - | 部分的 | | v4 (open-calm) | -12.6191 ❌ | -39.9840 ✅ | - | 激しい過学習 | | 2段階 | +0.1355 ✅ | -0.1514 ✅ | +21.1% | 完璧 | | 英語版 | +4.07 ✅ | -1.89 ✅ | - | 完璧 |
なぜ2段階が必要だったのか
英語版が1段階で成功した理由
- distilgpt2は大規模な英語データで事前訓練済み
- 既に「良い応答」の生成能力を持っている
- DPOで選好学習するだけで十分
日本語版で失敗した理由
- rinnaの事前訓練データは異なる
- 選好データの「良い応答」を生成する能力が不足
- 能力がないのに選好だけ学習 → 失敗
2段階アプローチの利点
1. Stage 1 (SFT): Chosenデータで能力を獲得 2. Stage 2 (DPO): 獲得した能力の上で選好学習 3. 能力構築 → 選好学習の順序が正しい実装のポイント
1. SFTトレーナー
class SFTLoRATrainer: def train(self, preference_data, ...): for item in batch: # Chosenのみ使用 full_text = item["prompt"] + item["chosen"]
# 次トークン予測の損失 loss = F.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) )
2. DPOトレーナー(SFTベース)
class SFTDPOLoRATrainer:
def __init__(self, sft_adapter_path, ...):
# SFTアダプターをマージしてからDPO用LoRA追加
sft_model = PeftModel.from_pretrained(base, sft_adapter_path)
merged_sft = sft_model.merge_and_unload()
self.policy_model = get_peft_model(merged_sft, dpo_lora_config)
3. 検証スクリプト
SFT+DPOモデルの正しい読み込み
sft_temp = PeftModel.from_pretrained(base, "./sft_lora_adapter_ja")
merged_sft = sft_temp.merge_and_unload() # 重要!
sft_dpo_model = PeftModel.from_pretrained(
merged_sft,
"./sft_dpo_lora_adapter_ja"
)
学んだこと
1. 理論的理解の重要性
- DPOは「既存能力」の上で選好を学習
- 能力がなければ選好学習は無意味
- 基礎能力の構築が先
2. 言語・モデルによる違い
- 英語版で成功 ≠ 日本語版で成功
- 事前訓練の違いを考慮する必要
- ドメイン適応が重要
3. 検証方法の重要性
- Chosen/Rejected個別の変化だけでは不十分
- マージン(差)の検証が本質
- 100%成功率で初めて確信
4. 2段階アプローチは標準
- 世間的にも広く使われている手法
- InstructGPT、Claude、Llama2なども同様
- SFT → RLHF/DPO が王道
結論
「時間がかかっても良いので、本質的なやり方で根本的解決を目指したい」 この目標は完全に達成された:✅ 理論的に正しいアプローチ(SFT → DPO) ✅ 数学的に検証可能(マージン+21.1%改善) ✅ 実用的な成功率(100%) ✅ 再現可能な実装
日本語LLMでDPOを成功させるには、2段階アプローチが必須だった。技術スタック
- モデル: rinna/japanese-gpt2-medium (336M)
- フレームワーク: PyTorch + Transformers + PEFT
- LoRA: rank=16, alpha=32
- データ: 52ペアの日本語選好データ
- デバイス: MacBook (Apple Silicon MPS)
コード
全コードはGitHubで公開予定:
sft_lora_trainer_ja.py: SFTトレーナーsft_dpo_lora_trainer_ja.py: DPOトレーナーverify_sft_dpo_ja.py: 検証スクリプトverify_dpo_detailed.py: マージン詳細検証
