當今的大型語言模型(LLM)基于 2017 年推出的 Transformer 模型架構。自那時以來,AI 計算性能的快速進步使創建更大的基于 Transformer 的 LLM 成為可能,這極大地提高了其功能。基于 Transformer 的高級 LLM 正在為許多令人興奮的應用提供支持,如智能聊天機器人、計算機代碼生成和甚至 芯片設計 等。
訓練先進的 LLM 需要高效且通用的軟件堆棧。為此,NVIDIA NeMo 提供了一個端到端平臺,用于構建、自定義和部署 LLM。Integrated 深入集成到 NeMo 框架中的是 Megatron-Core,一個基于 PyTorch 的庫,提供了大規模訓練 LLM 所需的基本組件和優化。隨著模型開發者探索新的模型架構,NVIDIA 平臺不斷擴展,以實現創新。
今天,NVIDIA 宣布 NeMo 和 Megatron-Core 現在分別支持狀態空間模型(SSM)的預訓練和微調。另外,NeMo 現在還支持基于 Google DeepMind 所述的 Griffin 架構的模型訓練。
為何要探索替代模型架構?
Transformer 模型擅長捕捉遠程依賴項,因為它們通過著名的注意力機制實現了全局上下文理解,從而非常適合處理需要這種理解的任務。
但是,注意力的計算復雜性會隨序列長度二次擴展,從而導致訓練時間和訓練成本的大幅增加。隨著序列長度的增加,這些成本也會增加。此外,在推理期間,注意力需要存儲鍵值對緩存(稱為 KV 緩存),這些緩存會隨著序列長度線性增長。這會導致內存占用隨著序列長度的增加而增加。
最近,SSM 模型架構克服了一些注意力限制,成為處理序列建模任務的極具吸引力的模型架構。
SSM 可以實現長序列長度訓練的更高效率
SSM 是一類模型,在深度學習社區中越來越受歡迎,作為基于注意力的 Transformer 模型的高效替代方案,適用于序列建模任務。
SSM 具有以下令人信服的特性:
- 線性復雜性:SSM 在計算和內存復雜性方面都是線性的,而注意力在這兩方面都是二次的。這意味著,相比注意力,SSM 可以更高效地對序列中的長程依賴項建模。
- 高質量和高精度:與注意力一樣,SSM 查看輸入序列的標記,使模型能夠專注于最相關的部分,從而實現與基于 Transformer 的模型相當的質量和準確性。
- 高效推理:SSM 只需存儲常量大小的向量,而不是 KV 緩存,這使得推理變得更加內存高效,特別是在序列長度較長的情況下。
為說明 SSM 為更長序列長度提供的優勢,下圖顯示了 Mamba-2 層 (稍后將在本文中介紹的狀態空間模型變體) 與訓練 Transformer 層相比的相對速度隨著序列長度的增加。隨著序列長度增加到 256K,Mamba-2 層的速度比 Transformer 層快 18 倍。

Transformer 模型維度為 4096,擁有 32 個頭。Mamba-2 模型維度為 4096,狀態維度為 128,分為 8 組
一些 SSM 變體在 AI 社區中已經流行起來,其中包括 Hyena、Mamba-1 和最近推出的 Mamba-2。
結構化狀態空間對偶性和 Mamba-2
Mamba-2 作為最新版本脫穎而出,在多個基準測試中實現了非常高的準確性。其核心是一個新的結構化狀態空間對偶(SSD)層。在實踐中,這個 SSD 層重新表述了 Mamba-1 模型中使用的 SSM 數學運算。這種重新表述將 SSM 計算重構為矩陣乘法,從而使它們能夠利用 NVIDIA Tensor Core 的重要矩陣乘法性能。
因此,與 Mamba-1 相比,Mamba-2 的訓練速度要快得多。另外,在語言建模任務中,Mamba-2 還提供了與 Transformer 相比具有競爭力的質量和準確性。當在混合模型中將幾個注意力層與 SSD 層相結合時,Mamba-2 可以產生更好的結果。
但是,純粹的 SSM 并非沒有限制。例如,它們在“干草堆”類型場景中很困難,這些場景需要在非常長的序列中精確地調用信息。
混合模型可以改善結果,同時提高性能
結合 SSM、SSD、RNN 和 Transformer 的混合模型可以充分發揮每個模型架構的優勢,同時減輕其各自的弱點。
在最近的一篇論文中,包括 NVIDIA 應用深度學習研究 (ADLR) 團隊成員在內的研究人員描述了混合 Mamba-Transformer 模型。在這些混合模型中,標準 Transformer 層和新型 SSM 層可以在任意配置中交織在一起。例如,本文中描述的 8B 混合模型有 56 層,其中包括 4 層自注意力層、24 層 Mamba-2 層和 28 層多層感知器 (MLP) 層。這些層的分配是這樣的:首先出現 Mamba-2 層,然后是注意力層,最后 MLP 層均勻分布在整個模型中。
根據該論文,該團隊評估的混合 8B Mamba-2-Hybrid 模型在所有 12 項標準任務中都超過了 8B Transformer。此外,該模型還“預計在推理時生成令牌的速度將提高 8 倍”。

除了在推理過程中提高執行任務的能力和顯著的性能優勢外,Mamba-2-Hybrid 模型還顯示了更高的計算效率。在序列長度增加時,訓練 8B Mamba-2-Hybrid 模型和訓練 8B Transformer 模型所需的計算量進行了比較,結果如下圖所示。

當序列長度為 2048 個令牌時,兩者的計算需求大致相同,其中混合模型略有優勢。然而,當序列長度擴展到多達 32768 個令牌時,8B Transformer 模型的計算需求將翻倍,而混合模型僅增長 13%。由于現代語言模型支持 1M 令牌及以上的序列長度,因此 SSM-Transformer 混合模型的這種優勢只會增長。
支持新型模型架構的第一步
模型架構創新對于實現更高水平的智能至關重要。除了為構建基于 Transformer 的模型提供出色支持外,NeMo 和 Megatron-Core 現在還為社區提供了訓練 Self-Supervised Models(SSMs)和 Single-Stage Detectors(SSDs)的能力,以及將其優勢與 Transformer 模型優勢相結合的混合模型的能力。
在此版本的 NeMo 中,我們提供了以下初始功能,以便社區能夠快速開始試驗:
- 支持 SSD 模型,包括Mamba-2。
- 支持 RG-LRU (Griffin 架構)
- 支持 Transformer/SSM 混合模型的組合。
- 微調支持 Recurrent Gemma(Griffin)模型、純 Mamba-2 模型和 8B Mamba-2-Hybrid 模型。
- 分片和模型并行支持。
在即將發布的版本中,我們計劃支持其他子二次模型架構,以及其他性能優化和 FP8 訓練的支持。
?