Bitmap topk #3009

Open
tdophung wants to merge 18 commits into NVIDIA:main from tdophung:bitmap_topk
Conversation

@tdophung
Collaborator

@tdophung tdophung commented May 18, 2026

Description

Add a new path to our topk kernel that outputs the routing map in bitmap format in addition to bytemap. The default stays bytemap, so there is no regression for existing consumers downstream of this op. However, since the op now requires an additional arg to specify the routing map type (bytemap or bitmap), we introduce a V2 of the API to accomplish this, while keeping the original API unchanged so as not to break customers.

This saves NCCL EP from having to convert token_indices (sparse format) to bitmap format for comms later.

Fixes #2999

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add a kernel branch that uses atomicOr to set, for each selected expert index, the bit at that expert's position, which converts expert indices to a bitmap.
  • Plumb the arg for routing map mode (bytemap or bitmap) through the PyTorch and JAX primitives

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

tdophung and others added 5 commits May 18, 2026 11:29
Signed-off-by: tdophung <tdophung@nvidia.com>
Without this XLA_FFI_REGISTER_ENUM_ATTR_DECODING the FFI handler
templates cannot instantiate AttrDecoding<JAXX_Routing_Map_Format>,
breaking the JAX build in router.cpp.

Signed-off-by: tdophung <tdophung@nvidia.com>
…g the routing map type enum

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung marked this pull request as ready for review May 20, 2026 20:42
@tdophung tdophung requested a review from phu0ngng May 20, 2026 20:43
@greptile-apps
Contributor

greptile-apps Bot commented May 20, 2026

Greptile Summary

This PR adds a BITMAP_U8 routing-map format to the fused MoE topk kernel, complementing the existing BYTEMAP (bool) format. Expert bits are LSB-packed into uint8 bytes along the expert axis, producing a [num_tokens, ceil(num_experts/8)] output instead of [num_tokens, num_experts]. The original C/Python APIs are preserved as V1 shims that delegate to new V2 entry points.

  • New NVTERoutingMapFormat C enum + Python RoutingMapFormat enum; _validate_routing_map_format helpers in both PyTorch and JAX accept str/enum/int inputs.
  • CUDA forward kernels accumulate the bitmap in per-warp shared memory via atomicOr, then flush to global memory; backward kernels bit-unpack on load into local bool shmem.
  • PyTorch and JAX primitives are fully plumbed with shape checks, sharding rules, and backward parity tests.
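To make the packed layout concrete, here is a minimal host-side sketch of LSB-first packing along the expert axis. It is not the kernel or the test-suite packer; `pack_bitmap_u8` is a hypothetical helper name, and the only things taken from the summary above are the bit order and the `ceil(num_experts/8)` row width.

```python
def pack_bitmap_u8(bytemap_row):
    """Pack one token's row of 0/1 flags into LSB-first uint8 values.

    Hypothetical reference helper: expert e lands in byte e // 8, bit e % 8.
    """
    num_experts = len(bytemap_row)
    packed = [0] * ((num_experts + 7) // 8)  # ceil(num_experts / 8) bytes per token
    for e, flag in enumerate(bytemap_row):
        if flag:
            packed[e // 8] |= 1 << (e % 8)
    return packed


# 130 experts is not a multiple of 8, so the last byte is only partly used.
row = [0] * 130
for e in (0, 9, 129):
    row[e] = 1
packed = pack_bitmap_u8(row)
print(len(packed), packed[0], packed[1], packed[16])
```

With 130 experts the row shrinks from 130 bool entries to 17 bytes, and the unused high bits of the last byte stay zero.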

Confidence Score: 5/5

Safe to merge; the new BITMAP_U8 path is gated behind an explicit format argument and the existing BYTEMAP default is unchanged.

The CUDA kernels are templated on the routing-map format, so BYTEMAP and BITMAP_U8 code paths are completely separated at compile time. The V1 API shims preserve backward compatibility. Shape checks are added at the C layer for both formats, and the new tests verify bit-identical results against a numpy reference packer for both straight-forward and aux-loss paths, including non-multiple-of-8 expert counts. Only minor inconsistencies in contiguity checks were found.

transformer_engine/pytorch/csrc/extensions/router.cpp — the backward functions check grad_probs and grad_logits for contiguity but skip the same check on routing_map and intermediate_output.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/fused_router/fused_topk_with_score_function.cu | Adds BITMAP_U8 template path: per-warp shmem atomicOr accumulator in forward, bit-unpack in backward. V1 API preserved as BYTEMAP shim over new V2 entry points. |
| transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu | Same BITMAP_U8 pattern as fused_topk. V1 shim added. Consistent with the fused_topk changes. |
| transformer_engine/common/include/transformer_engine/fused_router.h | Introduces NVTERoutingMapFormat enum and V2 C API declarations; V1 entries deprecated with clear documentation. |
| transformer_engine/pytorch/csrc/extensions/router.cpp | Refactors all four router C++ functions: multi-dim logits support, routing_map allocation helper, V2 kernel calls. Missing contiguity checks on routing_map and intermediate_output in backward functions. |
| transformer_engine/pytorch/router.py | Adds RoutingMapFormat enum alias, _validate_routing_map_format helper, and routing_map_format arg to both public functions. Backward return counts correctly updated. |
| transformer_engine/jax/cpp_extensions/router.py | Adds RoutingMapFormat IntEnum, threads routing_map_format through all primitive methods. Shardy rule correctly distinguishes packed_experts from num_experts. |
| transformer_engine/jax/csrc/extensions/router.cpp | FFI handlers correctly accept JAXX_Routing_Map_Format attr, compute routing_map_shape_2d per format, and call V2 NVTE entry points. |
| tests/pytorch/test_fused_router.py | New tests cover BITMAP_U8 vs BYTEMAP parity for both topk-fwd+bwd and aux-loss paths, including non-multiple-of-8 expert counts (130). |
| tests/jax/test_fused_router.py | Mirrors PyTorch test structure; verifies forward parity, backward grad bit-identity, and correct BITMAP_U8 shape. |

Sequence Diagram

sequenceDiagram
    participant Caller as Python caller
    participant Validate as _validate_routing_map_format
    participant FwdFn as fused_topk_with_score_function
    participant CppFwd as nvte_*_forward_v2 (C++)
    participant Kernel as CUDA kernel
    participant BwdFn as backward
    participant CppBwd as nvte_*_backward_v2 (C++)

    Caller->>Validate: "routing_map_format (str | enum | int)"
    Validate-->>FwdFn: validated int / RoutingMapFormat
    FwdFn->>CppFwd: logits, routing_map_format
    CppFwd->>CppFwd: allocate routing_map shape per format
    CppFwd->>Kernel: launch with RoutingMapFormat template param
    alt BYTEMAP
        Kernel-->>CppFwd: "routing_map[pos + e] = 1"
    else BITMAP_U8
        Kernel-->>CppFwd: "atomicOr(shmem[e/32], 1 << e%32) then flush row"
    end
    CppFwd-->>FwdFn: probs, routing_map, intermediate
    FwdFn-->>Caller: probs, routing_map

    Caller->>BwdFn: grad_probs
    BwdFn->>CppBwd: routing_map, intermediate, grad_probs, routing_map_format
    CppBwd->>Kernel: launch backward with RoutingMapFormat template
    alt BYTEMAP
        Kernel-->>CppBwd: "local_routing_map[i] = routing_map[pos+i] != 0"
    else BITMAP_U8
        Kernel-->>CppBwd: "local_routing_map[i] = (row[i/8] >> i%8) & 1"
    end
    CppBwd-->>BwdFn: grad_logits
    BwdFn-->>Caller: grad_logits
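
The two BITMAP_U8 branches in the diagram index differently: the forward path sets bit `e % 32` of 32-bit word `e / 32`, while the backward path reads bit `i % 8` of byte `i / 8`. The sketch below is a host-side model showing that these agree, assuming the flush writes the words out as little-endian bytes (the flush code itself is not shown on this page, so that is an assumption). All function names are hypothetical.

```python
import struct


def accumulate_words(topk_indices, num_experts):
    """Model of the forward accumulator: one bit per selected expert, 32-bit words."""
    words = [0] * ((num_experts + 31) // 32)
    for e in topk_indices:
        words[e // 32] |= 1 << (e % 32)  # plain OR stands in for atomicOr
    return words


def flush_to_bytes(words, num_experts):
    """Assumed flush: reinterpret the words as little-endian bytes, keep ceil(E/8)."""
    raw = struct.pack("<%dI" % len(words), *words)
    return list(raw[: (num_experts + 7) // 8])


def unpack_bit(row_bytes, i):
    """Model of the backward load: (row[i / 8] >> i % 8) & 1."""
    return (row_bytes[i // 8] >> (i % 8)) & 1


num_experts = 130
topk = [3, 40, 64, 129]
row = flush_to_bytes(accumulate_words(topk, num_experts), num_experts)
selected = [i for i in range(num_experts) if unpack_bit(row, i)]
print(len(row), selected)
```

Under the little-endian assumption, the experts recovered by the byte-level unpack are exactly the ones set through the word-level OR.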
Reviews (7): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment thread transformer_engine/jax/router.py
Comment thread transformer_engine/pytorch/router.py Outdated
Comment thread transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Outdated
Comment thread transformer_engine/common/include/transformer_engine/fused_router.h Outdated
Comment thread transformer_engine/common/include/transformer_engine/fused_router.h Outdated
Comment thread transformer_engine/common/fused_router/fused_topk_with_score_function.cu Outdated
Comment thread transformer_engine/jax/cpp_extensions/router.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/misc.h
Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

LGTM, thanks! I reviewed core and JAX changes but not PyTorch

@tdophung
Collaborator Author

/te_ci

Comment thread transformer_engine/pytorch/router.py Outdated
Comment thread transformer_engine/pytorch/router.py Outdated
Comment thread transformer_engine/pytorch/router.py Outdated
Comment thread transformer_engine/pytorch/router.py Outdated
@denera denera self-requested a review May 22, 2026 18:18
…uting_map_format

Apply the four CPU-overhead fixes the reviewer asked for and the
CLAUDE.md "CPU overhead in PyTorch wrappers" section codifies:

1. _validate_routing_map_format returns plain int (not enum); the
   autograd Function + tex.* bindings only see ints. Validates via
   precomputed frozenset and a single dict.get with canonical
   lowercase keys (no .lower()/.upper()).

2. Type annotations on Function.forward use int (not the string
   forward-ref 'RoutingMapFormat').

3. Removed every .view() from FusedTopkScoreFunction.{forward,backward}
   and FusedComputeScoresForMoEAuxLoss.{forward,backward}. C++
   extension now accepts N-D logits/grad_probs, computes num_tokens
   from the product of leading dims, num_experts from the last dim,
   allocates outputs at the user-facing N-D shape, and wraps tensors
   with an explicit 2D shape via makeTransformerEngineTensor only for
   the kernel call. Asserts is_contiguous() on inputs.

4. Bwd allocates grad_logits with torch.empty_like(grad_probs) (N-D)
   instead of allocate-2D-then-view.

PyTorch-extension boundary takes 'int routing_map_format' and casts
to NVTERoutingMapFormat inside; the common-layer C API (nvte_*_v2)
keeps the enum.

Signed-off-by: tdophung <tdophung@nvidia.com>
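
As an illustration of the shape handling described in item 3 of this commit message, the sketch below derives the 2D shapes handed to the kernel from an N-D logits shape. It is a Python model only, not the C++ extension code; `kernel_shapes` and its dictionary keys are hypothetical names.

```python
from math import prod


def kernel_shapes(logits_shape, bitmap):
    """Derive the 2D kernel-side shapes from an N-D logits shape (hypothetical helper).

    num_tokens is the product of the leading dims, num_experts is the last dim.
    """
    *leading, num_experts = logits_shape
    num_tokens = prod(leading)
    # Bitmap rows hold ceil(num_experts / 8) bytes; bytemap rows hold num_experts.
    map_cols = (num_experts + 7) // 8 if bitmap else num_experts
    return {
        "logits_2d": (num_tokens, num_experts),
        "routing_map_2d": (num_tokens, map_cols),
    }


shapes_bitmap = kernel_shapes((4, 512, 130), bitmap=True)
shapes_bytemap = kernel_shapes((4, 512, 130), bitmap=False)
print(shapes_bitmap, shapes_bytemap)
```

Computing the 2D view from the shape alone is what lets the wrapper avoid a `.view()` call on the Python side.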
pre-commit-ci Bot and others added 2 commits May 27, 2026 00:46
…nt subclass

pybind11 enum_<NVTERoutingMapFormat> binds as a standalone type, not a
subclass of int. The validator must check isinstance(x, RoutingMapFormat)
before the int branch and explicitly normalize via int(x).

Signed-off-by: tdophung <tdophung@nvidia.com>
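
The validator behaviour described in the two commit messages above can be sketched roughly as follows. This is not the code from the PR: the enum here is a plain-Python stand-in for the pybind11 type, and the numeric values (0 and 1) and the string keys are assumptions made for illustration.

```python
from enum import Enum


class RoutingMapFormat(Enum):
    """Stand-in for the pybind11 enum: not an int subclass, but int() works on it."""
    BYTEMAP = 0      # assumed value
    BITMAP_U8 = 1    # assumed value

    def __int__(self):
        return self.value


# Precomputed once at import time so the hot path does no allocation.
_VALID_INTS = frozenset(int(m) for m in RoutingMapFormat)
_NAME_TO_INT = {"bytemap": 0, "bitmap_u8": 1}  # assumed canonical lowercase keys


def validate_routing_map_format(x):
    """Return the format as a plain int, or raise ValueError (hypothetical sketch)."""
    # The enum check must come before the int branch because the enum is not an int.
    if isinstance(x, RoutingMapFormat):
        return int(x)
    if isinstance(x, int) and x in _VALID_INTS:
        return int(x)
    if isinstance(x, str):
        value = _NAME_TO_INT.get(x)  # single lookup, no .lower()/.upper()
        if value is not None:
            return value
    raise ValueError(f"unsupported routing_map_format: {x!r}")


print(validate_routing_map_format(RoutingMapFormat.BITMAP_U8),
      validate_routing_map_format("bytemap"),
      validate_routing_map_format(1))
```

Because keys are matched exactly, a string such as "BITMAP_U8" is rejected in this sketch rather than normalized, which keeps the check to one dictionary lookup.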
@tdophung
Collaborator Author

/te_ci

Comment thread tests/pytorch/test_fused_router.py Outdated
} else {
  for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
    int e = topk_indices[i];
    atomicOr(&local_bitmap_words[e / 32], 1u << (e % 32));
Member

What is the performance you are getting with this mode?

Collaborator Author

I did not benchmark this specific kernel, but I got numbers for fused_topk_with_score_function, pasted below in reply to your other comment. I would assume it is similar, because the way we use atomicOr is the same.

// copy the bytemap-equivalent bytes out to the global uint8 bitmap row.
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
  int e = topk_indices[i];
  atomicOr(&local_bitmap_words[e / 32], 1u << (e % 32));
Member

What is the performance of this? I'm pretty scared by this atomic or.

Collaborator Author

@tdophung tdophung May 28, 2026

There was not any meaningful difference from bytemap. I did a 3-way comparison (original kernel time, after the change with bytemap, and after the change with bitmap) and the differences were < 6% between any pair on all the sizes tested. The larger the number of tokens, the less visible the overhead of atomicOr is. Here is a chart that summarizes the numbers for kernel-only duration:
[image: chart of kernel-only durations]

And here is a chart that summarizes the measurements including the Python side (I did not do bytemap post-PR for this one because it was similar in kernel times based on the chart above). "fb" is forward + backward:
[image: chart of durations including the Python side]

As you can see, the time actually decreased because we removed the .view() reshapes that were there before this PR started (in ORIG).

Comment thread transformer_engine/pytorch/csrc/extensions/pybind.cpp Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/router.cpp Outdated
Comment on lines +76 to +79
// Wrap with explicit 2D shape for the kernel — the common-layer NVTE_CHECKs
// expect {num_tokens, num_experts} (or {num_tokens, ceil(num_experts/8)} for
// the bitmap routing_map). This is ~100ns of std::vector + TensorWrapper
// construction vs ~1.5us per .view() in Python.
Member

While I appreciate the benchmark, I don't think it is needed in this comment.

Collaborator Author

I'm removing the benchmark but still keeping the explanation of why we want to wrap with an explicit 2D shape.

Comment thread transformer_engine/pytorch/csrc/extensions.h Outdated
Comment thread transformer_engine/pytorch/router.py Outdated
tdophung and others added 4 commits May 27, 2026 19:34
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
# Conflicts:
#	transformer_engine/common/fused_router/fused_topk_with_score_function.cu
@tdophung
Collaborator Author

/te-ci

Development

Successfully merging this pull request may close these issues.

Output routing map from fused_topk_with_scores in bitmap format

3 participants