[Common] Optimize fused router forward/backward kernels #3012

Open
harryzhou2000 wants to merge 18 commits into NVIDIA:main from harryzhou2000:hhanyu/router_fix_p3R

harryzhou2000 (Member) commented May 19, 2026

Summary

Optimizes the fused router CUDA kernels introduced in #2821 (fused_topk_with_score_function and fused_score_for_moe_aux_loss). In the benchmarks below, effective bandwidth for topk ≥ 8 improves by 13% to 92% on the forward kernels and by 16% to 410% on the backward kernels. Small-topk configurations (e.g., E=256, topk=4) keep the original kernel structure, and the measured topk=4 cases show no regression.

Key results (B300, float32, 8192 tokens):

  • Forward (E=2304, K=36, softmax): 673 → 964 GB/s (+43%)
  • Backward (E=2304, K=36, softmax): 543 → 2766 GB/s (+410%)
  • Forward (E=512, K=4, softmax): no regression (+0.3%)

Changes

Forward kernels

  • Persistent grid with async double-buffered prefetch: RawAsyncLoader<T> uses cp.async (sm_80+) for non-blocking global→shmem loads. Occupancy-aware grid sizing (compute_persistent_grid) keeps all SMs saturated across multiple rounds.
  • Packed 8-bit radix histogram: Reduces radix topk register usage from 32 to 4 registers by packing 16 bucket counts into 4×u32 with 8-bit fields. Eliminates local memory spill at large E.
  • Compile-time score function dispatch: ScoreFunc template parameter with if constexpr removes runtime branches from the hot loop.
  • Simple kernel path for small topk: When topk < NVTE_RADIX_TOPK_THRESHOLD (default 8), dispatches to a lightweight kernel matching the original structure — no async loader, no persistent grid — avoiding scheduling overhead that dominates at small K.
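
The packed histogram idea can be sketched on the host as follows. This is a minimal model, not the PR's device code: the struct name, the bucket-to-byte layout, and the 4-bit digit extraction are assumptions made for illustration. The only facts taken from the PR are that 16 bucket counts are packed into 4×u32 with 8-bit fields, and that the limit kMaxExpertsRadixTopk = 8160 = 255 × 32 guards against 8-bit overflow.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side model of a 16-bucket histogram packed into four
// 32-bit words: bucket b lives in word b / 4, byte b % 4 (assumed layout).
struct PackedHist16 {
  uint32_t words[4] = {0u, 0u, 0u, 0u};

  // Add one to bucket `b`. The caller must keep every count <= 255,
  // otherwise the carry would spill into the neighbouring 8-bit field.
  void increment(unsigned b) {
    assert(b < 16u);
    words[b >> 2] += 1u << ((b & 3u) * 8u);
  }

  // Read the 8-bit count stored for bucket `b`.
  unsigned count(unsigned b) const {
    assert(b < 16u);
    return (words[b >> 2] >> ((b & 3u) * 8u)) & 0xFFu;
  }
};

// Extract the 4-bit radix digit of `key` at digit position `pass`
// (0 = lowest). A 4-bit digit is assumed because it yields 16 buckets.
inline unsigned radix_digit(uint32_t key, unsigned pass) {
  return (key >> (pass * 4u)) & 0xFu;
}
```

Because four counters share one word, a single add updates the right field without touching its neighbours, provided no field exceeds 255. That is the overflow condition the expert-count limit is there to rule out.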

Backward kernels

  • Two-pass fused design: Pass 1 accumulates warp-level sums via register reduction + warp_allreduce_sum. Pass 2 computes per-element gradients using scalar helpers. Eliminates the comp_buf shared memory buffer (saves E × warps × 4 bytes per block).
  • Double-buffered async loading: All backward inputs (grad, activation, mask) loaded through RawAsyncLoader with always-on double buffering.
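
To illustrate why the two-pass structure needs no E-sized scratch buffer, here is a host-side sketch for the plain softmax case. It uses the standard softmax gradient; the function names and signatures are hypothetical, and the real kernels also handle masks, other score functions, and warp-level reductions that are omitted here.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical scalar helper: gradient of softmax w.r.t. one logit, given
// the saved output y_i, the incoming gradient dy_i, and dot = sum_j(dy_j * y_j).
inline float softmax_bwd_scalar_sketch(float y, float dy, float dot) {
  return y * (dy - dot);
}

// Two-pass backward for one token (one row of E experts).
std::vector<float> softmax_bwd_two_pass(const std::vector<float>& y,
                                        const std::vector<float>& dy) {
  // Pass 1: reduce sum_j(dy_j * y_j) into a single scalar. On the GPU this
  // would be per-lane partial sums followed by a warp all-reduce.
  float dot = 0.0f;
  for (std::size_t j = 0; j < y.size(); ++j) dot += dy[j] * y[j];

  // Pass 2: each element needs only its own inputs plus the scalar,
  // so nothing of size E has to be staged between the passes.
  std::vector<float> dx(y.size());
  for (std::size_t i = 0; i < y.size(); ++i)
    dx[i] = softmax_bwd_scalar_sketch(y[i], dy[i], dot);
  return dx;
}
```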

Infrastructure

  • async_loader.h: RawAsyncLoader<T>, compute_persistent_grid(), choose_num_buffers(), vectorized global store/fill helpers.
  • NVTE_RADIX_TOPK_THRESHOLD env var (default 8): configurable naive↔radix crossover.
  • Templated warp_reduce_on_shmem<T, ReduceFuncType> eliminates function-pointer overhead.
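
The compile-time dispatch used for the templated reduction can be illustrated with a small serial example. The enum and function names below are stand-ins rather than the PR's identifiers, and the loop is a plain host loop instead of a shared-memory warp reduction. The point is only that `if constexpr` leaves a single operation in each instantiation.

```cpp
#include <cassert>

// Hypothetical stand-in for the PR's ReduceFuncType.
enum class ReduceKind { Sum, Max };

// The operation is selected at compile time, so each instantiation contains
// only one of the two branches and no indirect call.
template <typename T, ReduceKind Kind>
T reduce_sketch(const T* data, int n) {
  assert(n > 0);
  T acc = data[0];
  for (int i = 1; i < n; ++i) {
    if constexpr (Kind == ReduceKind::Sum) {
      acc += data[i];
    } else {
      acc = data[i] > acc ? data[i] : acc;
    }
  }
  return acc;
}
```

A call site then names the operation as a template argument, for example `reduce_sketch<float, ReduceKind::Max>(ptr, n)`.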

Hardening

  • Host-side checks: num_tokens * num_experts <= INT_MAX, topk ∈ [1, E], and topk % group_topk == 0 when group_topk > 0
  • Device-side: assert(data_size <= kMaxExpertsRadixTopk) in the radix path
  • choose_num_buffers() now queries cudaDevAttrMaxSharedMemoryPerMultiprocessor (the per-SM budget) for the buffer-count decision
  • Fix: the prefetch no longer clobbers the current buffer when shmem is too tight for double buffering and a single buffer is used
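
The host-side argument checks listed above could look roughly like the following. The function name and the bool return are assumptions for illustration; the commit log indicates the actual launchers use NVTE_CHECK, which reports an error instead.

```cpp
#include <climits>
#include <cstdint>

// Returns true when the arguments satisfy the host-side checks listed above.
// Hypothetical helper: a real launcher would raise an error instead.
inline bool router_args_ok(int num_tokens, int num_experts, int topk, int group_topk) {
  if (num_experts <= 0) return false;
  if (topk < 1 || topk > num_experts) return false;
  // Widen before multiplying so the check itself cannot overflow.
  if (static_cast<int64_t>(num_tokens) * num_experts > INT_MAX) return false;
  // The divisibility rule only applies when grouped top-k is enabled.
  if (group_topk > 0 && topk % group_topk != 0) return false;
  return true;
}
```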

Compatibility

  • No regression for small configs: The simple forward kernel path is an exact replica of the original kernel structure, ensuring E=256/topk=4 (common in standard MoE) performs identically.
  • All existing tests pass: 891/891 test_fused_router.py tests pass, 117 skipped (fp8/multi-node).
  • No API changes: Same Python/C++ interface, same output semantics.
  • Tunable: Set NVTE_RADIX_TOPK_THRESHOLD=0 to force radix everywhere, or =16 to use naive for topk<16.
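
A sketch of how the threshold could be read and applied is shown below. The helper names are hypothetical. The read-once static local follows the pattern described in the commit log, and the dispatch condition (topk >= threshold and num_experts <= 8160) is taken from the review flowchart. Parsing with `std::atoi` is an assumption, and it maps non-numeric input to 0.

```cpp
#include <cstdlib>

constexpr int kDefaultRadixTopkThreshold = 8;  // default stated in this PR
constexpr int kMaxExpertsRadixTopk = 8160;     // 255 * 32, from this PR

// Parse the raw environment string; fall back to the default when unset.
inline int parse_threshold(const char* raw, int fallback) {
  return raw != nullptr ? std::atoi(raw) : fallback;
}

// Read the variable once and cache it for the lifetime of the process.
inline int radix_topk_threshold() {
  static const int cached =
      parse_threshold(std::getenv("NVTE_RADIX_TOPK_THRESHOLD"), kDefaultRadixTopkThreshold);
  return cached;
}

// True when the optimized radix path would be selected (assumed condition).
inline bool use_radix_path(int topk, int num_experts, int threshold) {
  return topk >= threshold && num_experts <= kMaxExpertsRadixTopk;
}
```

With the default of 8, `use_radix_path(4, 256, radix_topk_threshold())` is false, so that configuration takes the simple kernel, while a threshold of 0 selects radix for every valid topk.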

Performance (B300 SXM6, sm_103, float32, 8192 tokens)

Effective bandwidth (GB/s) is computed as the minimum bytes that must be transferred to/from global memory for one kernel invocation, divided by the measured wall time. For example, the topk forward kernel reads logits (T×E×dtype) and writes probs (T×E×dtype), routing_map (T×E×1), and intermediate_output (T×E×4). This metric captures how well the kernel utilizes memory bandwidth — higher is better, with the device peak around 8 TB/s on B300. Config format is num_experts/topk.
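
The metric can be written out as a short calculation. The function names are illustrative, and 1 GB is assumed to mean 1e9 bytes; the byte breakdown follows the topk forward example in the paragraph above.

```cpp
#include <cassert>
#include <cstdint>

// Minimum global-memory traffic for one topk forward call, per the breakdown
// above: logits read + probs write + routing_map write + intermediate write.
inline int64_t topk_fwd_min_bytes(int64_t tokens, int64_t experts, int64_t dtype_bytes) {
  const int64_t elems = tokens * experts;
  return elems * dtype_bytes   // logits (read)
       + elems * dtype_bytes   // probs (write)
       + elems * 1             // routing_map (write, 1 byte per element)
       + elems * 4;            // intermediate_output (write, 4 bytes per element)
}

// Effective bandwidth in GB/s, assuming 1 GB = 1e9 bytes.
inline double effective_gbps(int64_t bytes, double seconds) {
  assert(seconds > 0.0);
  return static_cast<double>(bytes) / seconds / 1e9;
}
```

For 8192 tokens, E=2304, and float32 this gives 245,366,784 bytes per call, so under these assumptions a reported 964 GB/s would correspond to a wall time of roughly 0.25 ms.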

Full benchmark table (softmax)

| kernel | pass | config (E/K) | before (GB/s) | after (GB/s) | change |
|---|---|---|---|---|---|
| topk | fprop | 512/4 | 1779 | 1784 | +0.3% |
| topk | fprop | 512/8 | 798 | 904 | +13% |
| topk | fprop | 512/22 | 514 | 924 | +80% |
| topk | fprop | 512/36 | 499 | 908 | +82% |
| topk | fprop | 2304/4 | 1803 | 1802 | 0% |
| topk | fprop | 2304/8 | 660 | 993 | +51% |
| topk | fprop | 2304/22 | 602 | 972 | +61% |
| topk | fprop | 2304/36 | 673 | 964 | +43% |
| topk | bprop | 512/22 | 3391 | 5362 | +58% |
| topk | bprop | 2304/36 | 543 | 2766 | +410% |
| aux_loss | fprop | 512/22 | 519 | 896 | +73% |
| aux_loss | fprop | 2304/36 | 645 | 891 | +38% |
| aux_loss | bprop | 512/22 | 5289 | 6155 | +16% |
| aux_loss | bprop | 2304/36 | 2272 | 4201 | +85% |

Full benchmark table (sigmoid)

| kernel | pass | config (E/K) | before (GB/s) | after (GB/s) | change |
|---|---|---|---|---|---|
| topk | fprop | 512/4 | 1728 | 1736 | +0.5% |
| topk | fprop | 512/22 | 470 | 891 | +90% |
| topk | fprop | 2304/36 | 639 | 798 | +25% |
| topk | bprop | 512/22 | 3169 | 4398 | +39% |
| topk | bprop | 2304/36 | 533 | 2274 | +327% |
| aux_loss | fprop | 512/22 | 475 | 912 | +92% |
| aux_loss | fprop | 2304/36 | 598 | 867 | +45% |
| aux_loss | bprop | 2304/36 | 1965 | 2757 | +40% |

harryzhou2000 force-pushed the hhanyu/router_fix_p3R branch 2 times, most recently from 14a302c to a805f38, May 19, 2026 10:22
harryzhou2000 marked this pull request as ready for review May 20, 2026 08:29

greptile-apps Bot (Contributor) commented May 20, 2026

Greptile Summary

This PR optimizes the fused router CUDA kernels (fused_topk_with_score_function and fused_score_for_moe_aux_loss) with persistent-grid double-buffered async loading, a packed 8-bit radix histogram top-K, and a two-pass fused backward that eliminates the comp_buf shared memory buffer. A lightweight simple-kernel path is preserved for small topk to avoid scheduling overhead.

  • Forward: Persistent grid with cp.async double buffering via RawAsyncLoader<T>; compile-time score-function dispatch (ScoreFunc template); naive path dispatches to an exact copy of the original kernel when topk < NVTE_RADIX_TOPK_THRESHOLD.
  • Backward: Two-pass warp-level design (Pass 1 accumulates reduction sums, Pass 2 computes per-element gradients); all three inputs (grad, activations, mask) loaded through double-buffered async loaders.
  • Infrastructure: async_loader.h adds RawAsyncLoader<T>, compute_persistent_grid(), choose_num_buffers() (occupancy-aware, uses cudaDevAttrMaxSharedMemoryPerMultiprocessor), and vectorized global store/fill helpers.

Confidence Score: 5/5

Safe to merge; the optimization is well-structured, all existing tests pass, and no API surface changes.

The double-buffering pipeline, radix topk, two-pass backward, and occupancy-aware buffer selection are all correctly implemented. The shmem capacity checks are correctly scoped per code path, fixing the premature-check issues from the prior review. The only item flagged is null-pointer arithmetic on device code in the simple forward kernel when group_topk==0 — NVCC generates harmless PTX arithmetic in this case and the pointer is never dereferenced, but a one-line null-guard would align it with the already-correct optimized kernel.

transformer_engine/common/fused_router/fused_topk_with_score_function.cu — simple kernel null-pointer arithmetic for masked_scores/group_scores when group_topk==0.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/common/fused_router/async_loader.h | New header introducing RawAsyncLoader, compute_persistent_grid(), choose_num_buffers(), and vectorized store/fill helpers. Logic for double-buffering, alignment checks, and cp.async pipeline management is sound; the scalar-fallback cp_async_commit() bug from the previous review is correctly absent. |
| transformer_engine/common/fused_router/fused_topk_with_score_function.cu | Adds simple kernel path (exact upstream structure) and optimized persistent/double-buffered kernel with compile-time score-function dispatch and radix topk. Shmem check is now correctly gated per path. Minor UB (null + offset) exists in the simple kernel for masked_scores/group_scores when group_topk==0. |
| transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu | Mirrors the topk file: adds simple kernel + optimized kernel with async loader and persistent grid. Shmem capacity checks correctly scoped per code path. Two-pass backward eliminates comp_buf. Math for all three score functions verified correct. |
| transformer_engine/common/fused_router/utils.h | Adds warp_reduce_on_shmem<T,type> template (compile-time ReduceFuncType dispatch), scalar activation/backward helpers, packed 8-bit radix topk, and naive topk variants. All implementations appear numerically correct; radix constraint kMaxExpertsRadixTopk=8160 properly enforced with device-side assert. |
| transformer_engine/common/fused_router/fused_moe_aux_loss.cu | Minimal change: updates warp_reduce_on_shmem call from runtime ReduceFuncType dispatch to templated compile-time form. Correct and consistent with the new API. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Launcher: fused_topk_forward] --> B{topk >= radix_threshold\nAND num_experts <= 8160?}
    B -- No --> C[Simple kernel path\nnaive topk, runtime score_fn\nno async loader]
    B -- Yes --> D[choose_num_buffers\noccupancy-aware]
    D --> E{double-buffer fits?}
    E -- Yes\nnum_buffers=2 --> F[compute_persistent_grid\ncudaOccupancyMax...]
    E -- No\nnum_buffers=1 --> F
    F --> G[Optimized kernel\npersistent grid + RawAsyncLoader]
    G --> H[Round loop]
    H --> I[wait current buffer]
    I --> J[start_load next buffer\ndouble-buf only]
    J --> K[Convert DataType to CompType\nApply score fn via if constexpr\nScoreFunc=0/1/2]
    K --> L[radix_topk_and_mask\n8-bit packed histogram]
    L --> M[Write probs + routing_map\nvec_store_global]
    M --> N[loader.flip]
    N --> H

    style C fill:#d4edda
    style G fill:#cce5ff
    style L fill:#fff3cd
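
The round loop in the flowchart, together with the single-buffer fix described in this PR, can be traced on the host with a small model. It only records the order of operations. The function name and log format are invented for illustration, and the real kernel issues cp.async loads and waits rather than appending strings.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Host-side trace of the round loop; returns the ordered operations.
std::vector<std::string> trace_rounds(int num_rounds, int num_buffers) {
  assert(num_buffers == 1 || num_buffers == 2);
  std::vector<std::string> log;
  auto load = [&](int round, int buf) {
    log.push_back("load r" + std::to_string(round) + " -> b" + std::to_string(buf));
  };
  int cur = 0;
  if (num_rounds > 0) load(0, cur);  // first round is issued before the loop
  for (int r = 0; r < num_rounds; ++r) {
    // Single-buffered: fetch this round's data at the top of the iteration,
    // after the previous round has finished reading the only buffer.
    if (num_buffers == 1 && r > 0) load(r, cur);
    // Double-buffered: prefetch the next round into the other buffer.
    if (num_buffers > 1 && r + 1 < num_rounds) load(r + 1, 1 - cur);
    log.push_back("compute r" + std::to_string(r) + " @ b" + std::to_string(cur));
    if (num_buffers > 1) cur = 1 - cur;  // flip to the prefetched buffer
  }
  return log;
}
```

With two buffers the load for round 1 is issued before round 0 is computed. With one buffer it is deferred until round 0 is done, so the buffer being read is never overwritten.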

Reviews (4): Last reviewed commit: "[Common] Restore cudnn-frontend pointer ..."

Comment thread transformer_engine/common/fused_router/async_loader.h
tdophung self-assigned this May 20, 2026

Replace multi-loop preprocess (separate clear/load/score/save/bias loops)
with single fused loops per score function in all 4 kernel paths (topk
forward, topk backward, aux_loss forward, aux_loss backward).

Replace multi-pass backward (array-based helpers + comp_buf shmem) with
a two-pass approach using scalar helpers:
  Pass 1: reduction — warp-level sums via warp_allreduce_sum()
  Pass 2: element-wise — scalar gradient computation → write to global

Add scalar helpers to utils.h: sigmoid_scalar, sqrtsoftplus_scalar,
sigmoid_bwd_scalar, sqrtsoftplus_bwd_scalar, normalize_bwd_scalar,
softmax_bwd_scalar.

Remove dead array helpers from utils.h: apply_sigmoid_on_float,
apply_sigmoid_bwd_on_float, apply_sqrtsoftplus_on_float,
apply_sqrtsoftplus_bwd_on_float, apply_softmax_bwd_on_float,
masked_warp_reduce_on_shmem.

Backward shmem reduced by E×W×sizeof(float) per kernel (comp_buf
eliminated).  Net -226 lines across 3 files.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>

Add async_loader.h with:
  - RawAsyncLoader<T>: cp.async on sm_80+, int4 fallback on sm_70,
    stores data in original type (no conversion during copy)
  - compute_persistent_grid(): occupancy-based grid sizing
  - choose_num_buffers(): shmem-aware 1-vs-2 buffer decision
  - vec_fill_global(), vec_store_global(): vectorized output helpers

Forward kernels (topk + aux_loss):
  - Logits loaded via RawAsyncLoader with double-buffered prefetch
  - Persistent grid replaces 1-shot grid launch
  - DataType→CompType conversion during compute, not during load
  - vec_fill_global for clearing probs/routing_map

Backward kernels (topk + aux_loss):
  - All inputs loaded via RawAsyncLoader (topk: 3 loaders for
    grad/act/mask; aux_loss: 2 loaders for grad/act)
  - Always double-buffered (kBwdNumBuffers=2, kAuxBwdNumBuffers=2)
  - Persistent grid with occupancy-based sizing

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>

Replace counts[16] + total_counts[16] (32 registers) with 4 packed u32
registers using 8-bit fields (4 counters per register).  Eliminates
massive register spill to local memory on large kernels (81% of L1
traffic on E=2304, K=36).

Add kMaxExpertsRadixTopk constant (8160 = 255 * 32) and runtime checks
in both forward launchers to guard against 8-bit overflow.  All current
MoE configurations (max E=2304) are well within this limit.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>

…dispatch

Replace runtime score_function parameter in all 4 kernel __global__
functions with template int ScoreFunc (0=sigmoid, 1=softmax,
2=sqrtsoftplus).  All score_function branches now use if constexpr,
eliminating dead-code register pressure and branch overhead.

Forward launchers dispatch on TopkFunc × ScoreFunc = 6 instantiations
per DataType.  Backward launchers dispatch on ScoreFunc = 3
instantiations per DataType.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>

Fix broken topk < 0 threshold (radix was always selected, naive
unreachable).  Replace with configurable NVTE_RADIX_TOPK_THRESHOLD
env var (default 0, i.e. always use radix).  Set to 16 to restore
the old naive-for-small-K behavior.

Uses the standard TE pattern: static local + getenv (read once,
cached for process lifetime).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>

When choose_num_buffers() returns 1 (shmem too tight for double
buffering, e.g. E=1024 with group_topk scratch), buf_[0] and buf_[1]
alias the same memory.  The prefetch via start_load(next_buf()) then
overwrites the current buffer while compute is still reading it.

Fix: guard the prefetch on num_buffers > 1.  When single-buffered,
load the current round's data at the top of each iteration instead.
The first round's load_current is still issued before the loop.

Backward kernels are unaffected (always kBwdNumBuffers=2).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>

Code review fixes:

- C1: choose_num_buffers() now queries cudaDevAttrMaxSharedMemoryPerMultiprocessor
  (per-SM budget) instead of cudaDevAttrMaxSharedMemoryPerBlockOptin (per-block
  max).  These coincide on Hopper/Blackwell but differ on Ampere.

- H3: Remove dead fallback branch in choose_num_buffers() — since
  total_double >= total_single always, blocks_single >= blocks_double,
  so the old ternary always returned 1 anyway.

- H4/M8: Add host-side NVTE_CHECK in all 4 launchers:
  - num_experts > 0
  - topk in [1, num_experts]
  - (int64_t)num_tokens * num_experts <= INT_MAX (kernel uses int offsets)

- M9: Assert topk % group_topk == 0 when group_topk > 0.

- H6: Add device-side assert(data_size <= kMaxExpertsRadixTopk) in
  radix_topk_and_mask() — zero cost in release (NDEBUG), catches
  8-bit histogram overflow in debug builds.

- L1: Fix stale comments claiming default threshold is 16 (it is 0).
- L4: Fix typo 'hanlded' -> 'handled'.
- L8: Remove unused topk parameter from aux loss backward kernel.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>

Move the duplicated static function from both .cu files into utils.h
as an inline function.  Each TU gets its own static local (read-once
per TU), which is safe since environment variables are immutable
during process lifetime.  Documented this in a NOTE comment.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>

Replace runtime function-pointer dispatch with compile-time if constexpr.
Eliminates indirect call overhead in the reduction loop and warp shuffle
butterfly, allowing the compiler to emit straight-line arithmetic.

Removes the now-unused max<T>() and sum<T>() helper functions.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>

When topk < NVTE_RADIX_TOPK_THRESHOLD (default 8), use a lightweight
forward kernel that avoids the async loader and persistent grid overhead.
The simple kernel loads logits directly from global memory to shmem and
uses Naive iterative-argmax topk — matching the baseline structure that
was faster for small K due to lower launch/scheduling overhead.

The optimized path (async loader + persistent grid + radix topk) remains
the default for topk >= 8 where the compute savings dominate.

Both topk and aux_loss forward kernels get the simple variant.
Backward kernels are unchanged (always use the optimized path).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>

Use 0.0f instead of 0 to avoid ambiguity between __nv_bfloat16(float)
and __nv_bfloat16(double) constructors on older CUDA toolkits.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
harryzhou2000 force-pushed the hhanyu/router_fix_p3R branch from 9a7cb7e to 3bab7cb, May 21, 2026 03:03
Comment thread transformer_engine/common/fused_router/fused_topk_with_score_function.cu Outdated
Comment thread transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Outdated
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>

3 participants