知識蒸餾是一種將更大的教師模型的知識轉移到更小的學生模型的方法,理想情況下可生成緊湊、易于部署的學生,且準確度與教師相當。知識蒸餾在預訓練設置中越來越受歡迎,但用于在監督式微調(Supervised Fine-Tuning,SFT)期間執行知識蒸餾的資源越來越少。
NVIDIA NeMo-Aligner 開源了一個在 SFT 期間使用知識蒸餾的實現,相較于標準 SFT,該實現的數據效率更高,準確性也更高 (Table 1)。
訓練目標 | 訓練步驟 | MMLU (5 次采樣) | MMLU (0 次采樣) | HumanEval (0 分) | MBPP (零射) | GSM8K (零射) | 數學 (0 分) |
表面張力損失 | 600000 | 65.3% | 56.9% | 64.6 | 71.7% | 84.2 | 30.12 |
KD = SFT 損失 | 420000 | 65.3% | 57.3% | 70.1 | 73.3% | 85.2 | 35.84 |
KD = SFT 損失 | 600000 | 65.3% | 57.6% | 72 | 73.8 | 84.8 | 36.6 |
在表 1 中,SFT 是使用數學/代碼數據集執行的。使用知識蒸餾微調的模型版本在所有數學和代碼相關基準測試中均優于基準,即使僅執行 70%的訓練步驟也是如此。
NeMo-Aligner 中的知識蒸餾?
在 SFT 期間,有許多方法可以從大型模型傳輸知識。最常見的方法是使用教師模型生成合成數據,我們稱之為 KD-SDG。然后,使用合成生成的數據微調學生模型。
還有一種開創性的方法,即訓練學生以匹配教師的輸出 logits。此方法在 Distilling the Knowledge in a Neural Network 中引入。我們將其稱為 KD-logit。
此方法利用跨類(稱為 暗知識 )的知識,生成信息更豐富的梯度信號。有關更多信息,請參閱神經網絡中的 Dark Knowledge。
在本文和 NeMo-Aligner 中,我們將重點介紹在 SFT 期間應用 KD-logit。
NeMo-Aligner 的離線 KD-logit 工作流包含以下關鍵步驟:
- 教師模型對訓練數據進行預測的預處理步驟。教師模型的 logits 添加到訓練數據中。
- 這是一個訓練步驟,其中對學生進行了訓練,使其 logits 與教師的 logits 相匹配。
只需緩存一次教師的 logits。與在訓練時動態計算教師邏輯相比,此方法具有以下優勢:
- 節省 內存: 您不必同時在 GPU 上加載教師和學生模型。
- 加快訓練速度: 您不必等待老師在訓練期間做出預測。
但是,將所有教師的 logits 保存到磁盤可能需要大量內存。為節省內存,我們僅將教師的最高 K logits 保存到磁盤,其中 K 是從業者選擇的超參數。
K 的值越大,學生可以從教師那里學習的細粒度信息越多,但內存壓力就越大。在實踐中,通常選擇 K 值在 100 左右,這比典型的詞匯量小。
將教師 logits 添加到數據集后,學生被訓練以匹配教師的 top- K logits。具體來說,知識蒸餾損失函數等于 K 學生和教師 logits 之間的前向 KL 差異:
此損失函數與 Vanilla SFT 交叉熵損失函數結合使用,以生成最終訓練目標,其中 ?控制 SFT 損失項相對于 KD 損失項的強度:
結果?
表 1 顯示,與 Vanilla SFT 相比,使用知識蒸餾目標微調模型可獲得更高的準確性和所需的訓練令牌。我們使用 基礎 Nemotron-4 15B 學生模型 和 微調的 Nemotron-4 340B 教師模型 進行實驗。
用于 SFT 的數據集是使用以下論文中描述的技術生成的組合:
- 數學數據集:OpenMathInstruct-2: 利用海量開源指令數據加速數學 AI (使用 Nemotron-4 340B,而非 Llama-3.1-405B-Instruct)
- 代碼數據集:Genetic Instruction:Scaling up Synthetic Generation of Coding Instructions for Large Language Models
數據集的數學和代碼部分均使用合成數據生成。這些實驗設置了 和
。
在相同數量的訓練步驟中,使用聯合知識蒸餾和 SFT 目標微調的模型在七個評估指標中的六個方面的表現優于 SFT 基準。特別是,我們看到 HumanEval、MBPP 和 MATH 基準測試有了顯著改進,這些基準用于衡量編碼和數學推理技能。在評估各種語言理解任務的 MMLU 上,KD 微調模型的表現至少與零樣本設置中的基準相當,并且在 5 鏡頭設置中優于基準。
KD-finetuned Nemotron-4 僅使用 70% 的訓練令牌,但在相同的六個評估指標上,其性能仍然優于 Vanilla SFT 模型。
結束語?
這些結果具有兩個重要含義。首先,我們已證明知識蒸餾(Knowledge Distillation)可用于提高微調模型的準確性。這在數據稀缺的設置中特別有用,因為需要更少的訓練令牌才能實現良好的準確性。
其次,我們已經證明 KD-logit 可以與您的 SDG 數據結合使用,以實現復合優勢。
有關如何在 NeMo-Aligner 中將知識蒸餾添加到 SFT 訓練的更多信息,請參閱使用知識蒸餾進行監督微調 (SFT)。