GPU 上的 GEMM 優化是一個模塊化問題。高性能實現需要指定超參數,例如圖塊形狀、數學和復制指令以及線程束專用方案。這些超參數在很大程度上彼此獨立;此外,最佳選擇可能會因硬件、問題形狀或其他用戶需求而有顯著差異。
通過重新設計 3.x,CUTLASS 旨在通過可組合、正交構建塊的分層系統最大限度地覆蓋 GEMM 實現空間,同時提高代碼可讀性,并將支持擴展到后續的 NVIDIA 架構 (如 Hopper 和 Blackwell) 。由于這種設計理念與 GPU 的分層硬件設計相關聯,因此對于其他 GPU 應用程序也是一個不錯的選擇,例如,FlashAttention-3 在其設計中使用熟悉的 CUTLASS 抽象概念。
在 CUTLASS 博客系列的第二篇博文中,我們將探討 CUTLASS 3.x 中 GEMM 分層系統背后的設計原則,并解壓 CUTLASS 如何從第 1 部分中介紹的低級 CuTe 抽象中構建 GEMM 內核。
CUTLASS 3.x 中的新概念 GEMM 層次結構
CUTLASS 3.x 開發了一個獨立于特定硬件功能的概念 GEMM 層次結構。它分為五個層:

- 原子層:特定于架構的指令和相關的元信息
cute::Mma_Atom<>
和cute::Copy_Atom<>
- Tiled MMA/ Copy:空間微核,支持架構特定原子的任意交錯和平鋪
cute::TiledMma<>
和cute::TiledCopy<>
- 集合層:時間微核函數,使用架構特定的同步來編排執行一個或多個空間微核函數,以計算單個輸出圖塊
cutlass::gemm::collective::CollectiveMma<>
、cutlass::epilogue::collective::CollectiveEpilogue<>
- 內核層:用于在線程塊/ 集群網格上執行內核的設備代碼
cutlass::gemm::kernel::GemmUniversal<>
- 設備層:主機側設置和接口
cutlass::gemm::device::GemmUniversalAdapter<>
每個層都用作前一層抽象的合成點,可以使用模板參數進行高度定制。用戶可以堅持使用最高層,信任 CUTLASS 的編譯時邏輯來提供高性能 GEMM 實現,也可以選擇使用較低級別的層次結構所帶來的高級修改。Atom 和 Tiled MMA/ Copy 層提供的空間微核是 CuTe 的領域,我們將在第 1 部分討論這些微核。本文的其余部分將介紹高層中提供的 GEMM 的時間級和內核級組織。
以下是如何在 CUTLASS 3.x 中定義 GEMM 內核的基本示例:
// Step 1: Generate the required collective layer mainloop specialization using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator, TilesShape, ClusterShape, cutlass::gemm::collective::StageCountAuto, cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; // Step 2: Specify the collective layer epilogue type using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< cutlass::gemm::TagToStrideC_t<LayoutC>, cutlass::gemm::TagToStrideC_t<LayoutC>, cutlass::epilogue:: thread ::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>>; // Step 3: Compose the mainloop and epilogue together at the kernel layer using GemmKernel = cutlass::gemm::kernel::GemmUniversal< cute::Shape< int , int , int , int >, // ProblemShape [M,N,K,L] CollectiveMainloop, CollectiveEpilogue >; // Step 4: Wrap up the kernel::GemmUniversal kernel class // with the device adapter to obtain a host-side handle to the kernel using GemmHandle = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; |
集合層:Mainloop
集合是一組線程,它們相互協作以執行工作,并且可以并行重復以形成整個內核。通常,這是線程塊或集群。TiledMMA 和 TiledCopy 對象用于描述并行工作進程對計算和復制工作的空間分配 (例如,線程束、線程組,甚至 Blackwell MMA 的線程塊) ,而集合層則負責以時間方式組織此工作,方法是設置工作流和線程束專用方案,以及使用硬件加速同步基元來管理工作流和集合主回路的定義如下:
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< DispatchPolicy, TileShape, ElementA, // dtype, e.g. float StrideA, // e.g. Stride<_1, int> for M-major ElementB, StrideB, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, TransformA, GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, TransformB >; |
集合主回路是低層抽象的合成點:TiledMma、每個操作數的 GMEM 到 SMEM 加載的 TiledCopy,以及用于 SMEM 到 RMEM 加載的可選復制原子,用于寄存器來源的 MMA。這些抽象在很大程度上是正交的,允許將不同的 MMA 操作與不同的復制操作結合起來,同時更大限度地重復使用代碼。
可以說,最重要的部分是調度策略,該策略定義了特定算法或 GPU 架構的主循環專用化。例如,調度策略 MainloopSm90TmaGmmaWarpSpecialized
將 CollectiveMma 專門用于 Hopper TMA 線程束專用實現。它本身就是一個模板,可以針對工作流階段、集群形狀和內核調度選擇 (例如針對 Hopper GEMM 內核的 pingpong 或協同調度) 進行參數化。
您可以在 GEMM 集合文件夾中找到專門的集合主循環實現示例。
集合構建器
CollectiveMma 具有各種調優旋鈕,允許用戶根據 TiledCopy 和 TiledMma 對象精確指定 GEMM 主回路,但伴隨這種靈活性,復雜性也隨之增加。通常,用戶希望從有關流水線、硬件功能和資源可用性的高階考慮因素中推斷出這些對象和相關的 SMEM 布局。CUTLASS 還可以使用 CollectiveBuilder 接口執行此推理。使用 CollectiveBuilder 的主循環聲明如下所示:
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, // e.g. cute::arch::Sm90 for Hopper OpClass, // e.g. cute::arch::OpClassTensorOp for Tensor Cores ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape, StageCount, // e.g. cutlass::gemm::collective::StageCountAuto KernelSchedule // e.g. cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; |
模板參數從用戶友好型標準中選擇,并使用它們將較低級別的參數推導至 CollectiveMma 模板:
- 架構專業領域:GPU 架構和 MMA 運算符類型 (例如 SIMT 或 Tensor Core) 。
- 操作數和累加器信息:操作數和累加器的數據類型,以及全局內存中操作數 (例如,行或列主) 的對齊和編譯時布局信息。
- 圖塊形狀:用于推理 TiledMma 和 TiledCopy 對象以及 SMEM 布局。
- 調度信息:集群形狀、工作流階段計數和內核調度均由調度算法使用。對于階段計數和內核調度參數,有默認的“Auto” (自動) 選項,這些選項指示 CUTLASS 嘗試為給定的架構和參數自動選擇最佳選項。
集合層:Epilogue
集合 epilogue 是集合 API 的另一端。它負責在每次主循環迭代后對工作圖的后處理和輸出存儲進行時間編排。與主循環一樣,這意味著集合結語是復制運算 (輸出存儲) 和一些數學運算 (通常是元素級運算,但可能也包括歸約) 的合成點。與主循環不同,這些數學運算本身通過結語訪客樹 (EVT) 形式高度可組合。這對于 AI 工作負載尤其有用,因為這些工作負載通常需要在 GEMM 之后立即計算激活函數。CUTLASS 的集合結語負責將此激活函數融合到內核中,從而消除不必要的數據移動。
CUTLASS 在 GitHub 上定義了幾個結語。模板參數在不同實現之間存在顯著差異,但通常包括以下信息:
- 矩陣 C 和 D 的數據類型和編譯時布局信息。
- 指定任何其他后處理的融合運算。
- GMEM 商店和任何 SMEM 暫存的平鋪復制操作。
- 與集合主循環一樣,調度策略包含有關集群大小、TMA 使用、線程束專門化等的信息。
適用于結語的 CollectiveBuilder 提供了一個更統一、更高級的界面:
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OpClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementCompute, ElementC, GmemLayoutTagC, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD, EpilogueScheduleType, FusionOpOrCallbacks >::CollectiveOp; |
其中許多參數在主循環構建器中很常見,但也有一些是新參數:
- 結語可以將 CTA 圖塊劃分為更小的圖塊,以實現更好的數學拷貝重疊。
- 累加器 (主回路的輸出) 現在是結語的輸入。結語計算可在不同的中間數據類型 (由
ElementCompute
給出) 中進行。 - CUTLASS 提供多種常見的融合運算,例如
D = activation(alpha * AB + beta * C)
。用戶還可以使用 Epilogue Visitor Trees 構建定制的融合操作。有關結語訪客樹的更多信息,請參閱此 Colfax 教程。 - 結語調度類型定義了 TMA 和線程束專用化的用法。默認的
EpilogueScheduleAuto
指示 CUTLASS 嘗試推斷出最佳選項。
要了解這兩個集合構建器的實際應用,我們參考了用于 Hopper 的 CUTLASS 示例 49 和用于 Blackwell 的示例 71。
內核層
集合層完全定義了核函數執行期間集合所完成的計算。內核層的作用是將集合擴展到涵蓋整個動態大小問題空間的線程塊或集群網格上。內核層通過將加載、存儲、MMA 等的基本程序拼接在一起,將集合主回路和集合結語組合到設備內核中。
內核層的入口點 API 是 cutlass::gemm::kernel::GemmUniversal 類,這是一種無狀態通用設備內核,可將 GEMM 實現為集合主回路和集合結語的合成。無狀態意味著調用者通過向內核傳入參數來管理內核的狀態。通用意味著 GemmUniversal
是 2.x 和 3.x GEMM 內核的入口點。對于 3.x API,GemmUniversal
的基本用法如下所示:
using GemmKernel = cutlass::gemm::kernel::GemmUniversal< ProblemShape, // e.g. Shape<int, int, int> for a fully generic GEMM CollectiveMainloop, CollectiveEpilogue >; |
與 TiledMma
和 tg_ 21 一樣,tg_ 22 和 tg_ 23 是通過 tg_ 24 合成的正交抽象。第一個模板參數,即問題形狀,主要用于在普通 GEMM (具有 rank-3 問題形狀) 和批量 GEMM (具有 rank-4 問題形狀) 之間進行選擇,但如果需要,也可以靜態地限制某些問題維度。
GemmUniversal
的實例化可以在 tg_ 26 形式的文件中找到,其中 tg_ 27 主要基于集合主循環的 tg_ 28 參數進行調度。所有實例化均提供一致的接口:
- 用于向內核傳遞參數的接口,包括問題形狀、硬件信息、張量的指針和布局,以及結語參數。
- 靜態初始化功能用于獲取網格和塊維度,檢查內核是否可在硬件上實現,并為結語或圖塊調度程序所需的任何歸約操作或全局屏障設置全局內存工作空間。
- 最重要的是,它們將核函數邏輯實現為
operator()
。這是一個設備函數,雖然內核層包含內核執行的所有邏輯,但尚未顯示從主機啟動的方法。
例如,此處定義了 Blackwell 的 TMA 線程束專用內核。
圖塊調度
內核層也是用于指定圖塊調度程序的合成點。正如內核調度程序定義集合內工作的時間安排一樣,圖塊調度程序定義集合內工作的順序和分布。對于最基本的圖塊調度程序,每個輸出圖塊分配一個 CTA。CUTLASS 3.x 為 Hopper 實現了兩個額外的圖塊調度程序:一個是持久性調度程序,可為每個 SM 啟動一個 CTA,并讓每個 CTA (可能) 在其生命周期內計算多個輸出圖塊;另一個是 Stream-K 調度程序,它也是持久性的,但會沿 K 模式額外劃分一些輸出圖塊工作,以實現更好的負載平衡。在 Blackwell 架構中,則使用具有集群啟動控制的調度程序。有關圖塊調度的更多深入信息,請參閱此 Colfax 教程。
我們可以使用以下命令擴展上述核函數以使用 Stream-K 圖塊調度程序:
using GemmKernel = cutlass::gemm::kernel::GemmUniversal< cute::Shape<int,int,int,int>, CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::StreamKScheduler >; |
CUTLASS 示例 74 是使用 Stream-K 調度程序的更詳細示例。
設備層
用于核函數啟動 (包括使用集群支持或在不同設備或 CUDA 流上啟動) 的主機端邏輯在設備層中實施。設備層的主要入口點是 cutlass::gemm::device::GemmUniversalAdapter
,它將 tg_ 32 核函數封裝在一個有狀態、可重復使用的句柄中。有狀態意味著句柄實例包含核函數需要運行的狀態 (即,它管理核函數參數本身) 。可重用意味著同一句柄實例可用于多次使用不同參數調用核函數。
GemmUniversalAdapter
是在 GitHub 上實現的。此示例展示了如何使用 GemmUniversalAdapter
啟動核函數:
using GemmHandle = cutlass::gemm::kernel::GemmUniversalAdapter<GemmKernel>; using Arguments = typename GemmHandle::Arguments; // surfaced from GemmKernel Arguments args { cutlass::Gemm::kBatched, // mode (here batched GEMM) cute::make_shape(M, N, K, L), // problem shape {A, stride_A, B, stride_B}, // mainloop args {{alpha, beta}, C, stride_C, D, stride_D}, // epilogue args make_kernel_hardware_info(device_id), // hardware info {} // scheduler args (here default) }; GemmHandle gemm; // Check that problem can run with given shape and hardware cutlass::Status status; status = GemmHandle::can_implement(args); if (status != cutlass::Status::kSuccess) { std::cerr << "Problem not supported\n" ; exit (EXIT_FAILURE); } // Set up global memory workspace size_t workspace_size = GemmHandle::get_workspace_size(args); cutlass::device_memory::allocation< uint8_t > workspace(workspace_size); // Initialize GEMM handle state from arguments status = gemm.initialize(args, workspace.get()); if (status != cutlass::Status::kSuccess) { std::cerr << "Failed to initialize GEMM kernel\n" ; exit (EXIT_FAILURE); } // Launch kernel status = gemm.run(); // can supply CUDA stream and CUDA host adaptor here if (status != cutlass::Status::kSuccess) { std::cerr << "Failed to launch GEMM kernel\n" ; exit (EXIT_FAILURE); } |
總結
在本文中,我們討論了如何將 CUTLASS 庫從概念上組織為層次結構,其中每層的對象由來自下層的正交對象組成。這種設計可實現高度可定制的 GEMM 實現,并實現高級別的代碼重用。在該系列的下一篇也是最后一篇文章中,我們將介紹 CUTLASS 4.0 中引入的更改,尤其是 CuTe Python DSL。
有關更多信息,您可以在 GitHub 上下載軟件,閱讀我們的文檔,或加入我們的開發者論壇進行更深入的討論。
致謝
感謝 Cris Cecka、Jack Kosaian、Mark Hoemmen、Haicheng Wu 和 Matt Nicely 為本文做出的貢獻。
?