• <xmp id="om0om">
  • <table id="om0om"><noscript id="om0om"></noscript></table>
  • 開發與優化

    使用 FlashInfer 運行 NVIDIA 的高性能 LLM 推理內核

    出色的 LLM 推理需要兩個關鍵要素:速度和開發者速度。速度是指通過使用高度優化的計算內核算法,最大限度地提高底層硬件的效率。開發者速度是指快速采用這些新內核并加速新模型、算法和硬件的能力。最終,這種速度的基礎是快速部署在底層 GPU 上運行的新計算內核,以及將這些內核輕松引入框架。

    FlashInfer architecture diagram shows an inference-engine-agnostic library with a unified API and flexible sparse KV-cache primitives.
    圖 1。FlashInfer 技術架構

    FlashInfer 是一個可定制的高效庫,用于構建高效的 LLM 服務引擎。它使用塊稀疏和可組合格式優化 KV 緩存存儲,以改善內存訪問并減少冗余,并具有可定制的注意力模板,可通過即時 (JIT) 編譯來適應各種設置。其負載平衡調度算法可根據動態用戶請求進行調整,同時保持與 NVIDIA CUDA Graph 靜態配置兼容。FlashInfer 已集成到領先的 LLM 服務框架 (例如 MLC Engine、SGLang 和 vLLM) 以及多個自定義引擎中。

    FlashInfer 最初是卡內基梅隆大學 Paul G. Allen 計算機科學與工程學院和 OctoAI (已被 NVIDIA 收購) 的協作研究項目。該團隊旨在創建一個靈活的 LLM 推理內核庫,該庫與引擎無關、高度優化且易于擴展,適用于 KV 緩存重用算法等新技術。它現在是一個蓬勃發展的開源項目,生產部署以及整個 AI 系統社區的研發團隊所做的貢獻。

    技術論文《FlashInfer:Efficient and Customizable Attention Engine for LLM Inference Serving》在MLSys 2025上榮獲最佳論文獎。

    NVIDIA 目前正在 FlashInfer 中積極發布性能超強的 LLM 推理內核,包括來自 NVIDIA TensorRT-LLM 的推理內核,以便輕松集成到 vLLM、SGLang 和自定義推理引擎中。

    FlashInfer 架構概述

    作為用于 LLM 服務的專用 NVIDIA GPU Operator Stack,FlashInfer 旨在提高最新內核的速度和開發者速度。推理平臺可以采用新創意,而無需等待新庫或在 CUDA C++ 中重寫內核。這些內核將通過 DLPack API 提供給所有框架,并注冊為 PyTorch 運算符,以便輕松集成到許多推理引擎中。JIT 功能使用戶能夠實現目標模型使用的核函數,這意味著 FlashInfer 的依賴項占用空間最小。

    FlashInfer 將 LLM 工作負載分為四個運算符系列 ( Attention、GEMM、通信和采樣) ,并通過輕量級、高性能的集合公開每個系列的運算符,這些集合只需更改最少的代碼即可放入任何服務引擎。

    Attention

    現代推理請求的序列長度、KV 緩存塊大小、掩碼規則和位置編碼方案各不相同。FlashInfer 通過以下方式吸收這種活力:

    • 統一存儲:將每個緩存布局表示為block/vector稀疏矩陣。
    • 模板和 JIT 內核:CUDA/CUTLASS 代碼庫,其專用旋鈕、logits/key/query、分組、MLA 和未來變體。
    • Inspector-executor 接口:一個 PyTorch 友好型 API,可首先檢查請求形狀和前綴共享模式,然后通過輕量級調度程序啟動經過調優的內核,以保持 GPU 飽和。
    A workflow showing modern inference requests with different sequence lengths, KV cache block sizes, masking rules, and positional-encoding schemes handled by FlashInfer.
    圖 2。FlashInfer 架構

    GEMM 和通信

    LLM 塊仍然嚴重依賴矩陣乘法。除了傳統的 GEMV/GEMM 計算和全歸約通信之外,近期的進展 (例如混合專家層和 LoRA 層) 引入了新的要求,例如分組 GEMM (單次調用中的許多小矩陣乘法) 和多對多通信。FlashInfer 選擇最快的開源或 NVIDIA 內核 (包括 fp4/fp8 tensor-core 路徑) ,并將它們提供給一個一致的 API,因此服務堆棧可以在不接觸應用邏輯的情況下交換 GPU 或內核。

    Token 采樣

    生成下一個 token 通常會成為 Top-K/ Top-P 過濾的瓶頸。傳統的實現會對整個詞匯進行分類,而當只有少數 logit 很重要時,這會造成浪費。FlashInfer 使用基于拒絕、無排序的采樣器取代了全局排序,該采樣器可實時修剪不可能出現的 token,從而降低大型詞匯表的延遲,并保持數字準確性。

    面向未來的推理

    有了這些層,服務框架就可以改變 KV 緩存布局,引入新的注意力設計,批量任意長度,或追求更嚴格的延遲目標,而無需重寫內核或回退到 CPU。從第一個查詢到最終 token,FlashInfer 在 GPU 上保留關鍵推理路徑,靈活、面向未來且快速。

    使用 FlashInfer

    PyPI 上提供了 Flashinfer 軟件包。您可以通過以下方式進行嘗試:

    pip install flashinfer-python

    FlashInfer 具有 Torch 原生 API,其設計為 plan/run,可解內核編譯/ 選擇/ 調整和內核運行。為便于注意,API 如下所示:

    from flashinfer.attention import BatchAttention
    attention = BatchAttention(backend="cutlass") # we provide multiple backend implementations
    attention.plan(
      qo_offsets, # offsets of each request in variable length query/output
      kv_lens, # kv length of each request in page table
      kv_block_table, # block table denoting the block indices in page table, could be packed/padded
      num_qo_heads, # number of query/output heads
      num_kv_heads, # number of key/value heads
      head_dim_qk, # head dimension of query/key
      head_dim_vo, # head dimension of value/output
      dtype_q=torch.bfloat16, # query data type
      dtype_kv=torch.bfloat16, # kv data type
      dtype_o=torch.bfloat16, # output data type
      **variant_kwargs, # other arguments specifying attention variants
    )
    O, lse = attention.run(q, (k, v)) # return output/lse

    plan 階段執行Kernel選擇和調整,該階段收集Kernel所需的元數據。相同的計劃信息可重復用于共享相同元數據 (LLM 生成步驟中的所有層) 的后續運行。

    用戶可以從多個注意力后端中進行選擇,為其用例實現最佳性能。所有內核均支持 CUDAGraph,可實現低延遲 LLM 推理服務。

    對于 logits 處理,模塊化接口由不同的 logits 處理器組成,而 flashinfer 發出基于融合拒絕采樣的高效實現。我們最近的博文解釋了 flashinfer rejection sampling 算法的工作原理。

    import flashinfer
    from flashinfer.logits_processor import LogitsPipe, Temperature, Softmax, TopP, Sample
     
    # Create a pipeline
    pipe = LogitsPipe([
        Temperature(),      # Scale logits by temperature
        Softmax(),          # Convert logits to probabilities
        TopP(),             # Apply top-p filtering
        Sample()            # Sample from the distribution
    ])
     
    # Apply the pipeline
    logits = torch.randn(batch_size, vocab_size, device="cuda")
    output_ids = pipe(logits, temperature=0.7, top_p=0.9)

    要開始使用 FlashInfer,請參閱 GitHub 資源庫文檔

    ?

    0

    標簽

    人人超碰97caoporen国产