Bitmap topk#3009
Signed-off-by: tdophung <tdophung@nvidia.com>
Without this XLA_FFI_REGISTER_ENUM_ATTR_DECODING the FFI handler templates cannot instantiate AttrDecoding<JAXX_Routing_Map_Format>, breaking the JAX build in router.cpp.
Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
…g the routing map type enum
Signed-off-by: tdophung <tdophung@nvidia.com>
Greptile Summary
This PR adds a
Confidence Score: 5/5
Safe to merge; the new BITMAP_U8 path is gated behind an explicit format argument and the existing BYTEMAP default is unchanged.
The CUDA kernels are templated on the routing-map format, so BYTEMAP and BITMAP_U8 code paths are completely separated at compile time. The V1 API shims preserve backward compatibility. Shape checks are added at the C layer for both formats, and the new tests verify bit-identical results against a numpy reference packer for both straight-forward and aux-loss paths, including non-multiple-of-8 expert counts. Only minor inconsistencies in contiguity checks were found.
transformer_engine/pytorch/csrc/extensions/router.cpp: the backward functions check grad_probs and grad_logits for contiguity but skip the same check on routing_map and intermediate_output.
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Caller as Python caller
participant Validate as _validate_routing_map_format
participant FwdFn as fused_topk_with_score_function
participant CppFwd as nvte_*_forward_v2 (C++)
participant Kernel as CUDA kernel
participant BwdFn as backward
participant CppBwd as nvte_*_backward_v2 (C++)
Caller->>Validate: "routing_map_format (str | enum | int)"
Validate-->>FwdFn: validated int / RoutingMapFormat
FwdFn->>CppFwd: logits, routing_map_format
CppFwd->>CppFwd: allocate routing_map shape per format
CppFwd->>Kernel: launch with RoutingMapFormat template param
alt BYTEMAP
Kernel-->>CppFwd: "routing_map[pos + e] = 1"
else BITMAP_U8
Kernel-->>CppFwd: "atomicOr(shmem[e/32], 1 << e%32) then flush row"
end
CppFwd-->>FwdFn: probs, routing_map, intermediate
FwdFn-->>Caller: probs, routing_map
Caller->>BwdFn: grad_probs
BwdFn->>CppBwd: routing_map, intermediate, grad_probs, routing_map_format
CppBwd->>Kernel: launch backward with RoutingMapFormat template
alt BYTEMAP
Kernel-->>CppBwd: "local_routing_map[i] = routing_map[pos+i] != 0"
else BITMAP_U8
Kernel-->>CppBwd: "local_routing_map[i] = (row[i/8] >> i%8) & 1"
end
CppBwd-->>BwdFn: grad_logits
BwdFn-->>Caller: grad_logits
Reviews (7): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
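The bit layout the diagram describes for BITMAP_U8 (expert `i` lives in byte `i/8`, bit `i%8`) can be sketched in plain Python. This is only an illustration of the layout, not the PR's numpy reference packer; the function names `pack_bytemap_row` and `unpack_bitmap_row` are hypothetical.

```python
def pack_bytemap_row(bytemap_row):
    """Pack a 0/1 bytemap row into ceil(num_experts / 8) bytes, LSB first."""
    num_experts = len(bytemap_row)
    packed = bytearray((num_experts + 7) // 8)
    for e, flag in enumerate(bytemap_row):
        if flag:
            # Expert e maps to byte e // 8, bit e % 8.
            packed[e // 8] |= 1 << (e % 8)
    return bytes(packed)


def unpack_bitmap_row(packed_row, num_experts):
    """Recover per-expert flags, mirroring (row[i/8] >> i%8) & 1."""
    return [(packed_row[i // 8] >> (i % 8)) & 1 for i in range(num_experts)]


# 10 experts (not a multiple of 8); experts 1, 3 and 9 are selected.
row = [0, 1, 0, 1, 0, 0, 0, 0, 0, 1]
packed = pack_bytemap_row(row)
assert len(packed) == 2
assert unpack_bitmap_row(packed, len(row)) == row
```

The padding bits in the last byte stay zero, which is why a round trip is exact even when the expert count is not a multiple of 8.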
Signed-off-by: tdophung <tdophung@nvidia.com>
jberchtold-nvidia
left a comment
LGTM, thanks! I reviewed core and JAX changes but not PyTorch
/te_ci
…uting_map_format
Apply the four CPU-overhead fixes the reviewer asked for and the
CLAUDE.md "CPU overhead in PyTorch wrappers" section codifies:
1. _validate_routing_map_format returns plain int (not enum); the
autograd Function + tex.* bindings only see ints. Validates via
precomputed frozenset and a single dict.get with canonical
lowercase keys (no .lower()/.upper()).
2. Type annotations on Function.forward use int (not the string
forward-ref 'RoutingMapFormat').
3. Removed every .view() from FusedTopkScoreFunction.{forward,backward}
and FusedComputeScoresForMoEAuxLoss.{forward,backward}. C++
extension now accepts N-D logits/grad_probs, computes num_tokens
from the product of leading dims, num_experts from the last dim,
allocates outputs at the user-facing N-D shape, and wraps tensors
with an explicit 2D shape via makeTransformerEngineTensor only for
the kernel call. Asserts is_contiguous() on inputs.
4. Bwd allocates grad_logits with torch.empty_like(grad_probs) (N-D)
instead of allocate-2D-then-view.
PyTorch-extension boundary takes 'int routing_map_format' and casts
to NVTERoutingMapFormat inside; the common-layer C API (nvte_*_v2)
keeps the enum.
Signed-off-by: tdophung <tdophung@nvidia.com>
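The shape bookkeeping from item 3 of the commit message above can be sketched as follows. This is a Python illustration of the arithmetic only, not the C++ extension code; `kernel_shapes` is a hypothetical helper, and the bitmap width of ceil(num_experts/8) is taken from the review thread on this PR.

```python
import math


def kernel_shapes(logits_shape, use_bitmap):
    """Return (num_tokens, num_experts, 2D routing_map shape) for N-D logits."""
    if len(logits_shape) < 1:
        raise ValueError("logits must have at least one dimension")
    *leading, num_experts = logits_shape
    # num_tokens is the product of all leading dims (1 if there are none).
    num_tokens = math.prod(leading)
    # Bitmap rows hold 8 experts per byte; bytemap rows hold one per byte.
    width = (num_experts + 7) // 8 if use_bitmap else num_experts
    return num_tokens, num_experts, (num_tokens, width)


# A [batch=2, seq=3, experts=10] input flattens to 6 tokens for the kernel.
print(kernel_shapes((2, 3, 10), use_bitmap=True))
print(kernel_shapes((2, 3, 10), use_bitmap=False))
```

Computing the 2D view on the C++ side is what lets the Python wrapper drop its `.view()` calls while the kernel still sees a `{num_tokens, num_experts}` tensor.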
…nt subclass
pybind11 enum_<NVTERoutingMapFormat> binds as a standalone type, not a subclass of int. The validator must check isinstance(x, RoutingMapFormat) before the int branch and explicitly normalize via int(x).
Signed-off-by: tdophung <tdophung@nvidia.com>
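A minimal sketch of the validator ordering these two commits describe, under stated assumptions: the `RoutingMapFormat` class here is a stand-in that mimics a pybind11 enum by not subclassing `int`, and the numeric values (0 and 1) and string keys are hypothetical, not taken from the PR.

```python
class RoutingMapFormat:
    """Stand-in for the pybind11 enum: deliberately NOT an int subclass."""

    def __init__(self, value):
        self._value = value

    def __int__(self):
        return self._value


# Assumed values, for illustration only.
BYTEMAP = RoutingMapFormat(0)
BITMAP_U8 = RoutingMapFormat(1)

_VALID_INTS = frozenset({0, 1})                   # precomputed once
_NAME_TO_INT = {"bytemap": 0, "bitmap_u8": 1}     # canonical lowercase keys


def validate_routing_map_format(fmt):
    """Normalize an enum, int, or str format argument to a plain int."""
    # The enum check must come first: the enum is not an int, so the int
    # branch below would never match it.
    if isinstance(fmt, RoutingMapFormat):
        return int(fmt)
    if isinstance(fmt, int) and fmt in _VALID_INTS:
        return fmt
    if isinstance(fmt, str):
        value = _NAME_TO_INT.get(fmt)  # single lookup, no .lower()/.upper()
        if value is not None:
            return value
    raise ValueError(f"unsupported routing_map_format: {fmt!r}")


assert validate_routing_map_format(BITMAP_U8) == 1
assert validate_routing_map_format("bytemap") == 0
```

Because the lookup uses canonical lowercase keys without case folding, a string such as `"BYTEMAP"` is rejected in this sketch; whether the real validator accepts other spellings is not stated in the commit message.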
/te_ci
} else {
  for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
    int e = topk_indices[i];
    atomicOr(&local_bitmap_words[e / 32], 1u << (e % 32));
What is the performance you are getting with this mode?
I did not benchmark this specific kernel, but I got numbers for fused_topk_with_score_function, pasted below in reply to your other comment. I would assume it is similar, because the atomicOr is used the same way.
// copy the bytemap-equivalent bytes out to the global uint8 bitmap row.
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
  int e = topk_indices[i];
  atomicOr(&local_bitmap_words[e / 32], 1u << (e % 32));
What is the performance of this? I'm pretty scared by this atomic or.
There was no meaningful difference from bytemap. I did a three-way comparison (the original kernel, the post-change kernel with bytemap, and the post-change kernel with bitmap), and the difference was < 6% between any pair on all the sizes tested. The larger the number of tokens, the less visible the overhead of atomicOr is. Here is a chart that summarizes the numbers for kernel-only duration:

And here is a chart that summarizes the measurements including the Python side (I did not do bytemap post-PR for this one because its kernel times were similar in the chart above). "fb" is forward + backward:

As you can see, the time actually decreased because we removed the .view() reshapes that were there before this PR started (in ORIG).
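For readers following the atomicOr discussion, here is a host-side Python emulation of the packing scheme in the quoted snippet: bits are ORed into 32-bit words (`e / 32`, `e % 32`) and the words are then flushed to the uint8 row. It assumes the flush copies each word in little-endian byte order, which would make the result match the byte-level layout (`e / 8`, `e % 8`); the function names are hypothetical and this is not the kernel code.

```python
def pack_row_via_words(topk_indices, num_experts):
    """Emulate OR-ing bits into 32-bit words, then flushing to bytes."""
    num_words = (num_experts + 31) // 32
    words = [0] * num_words
    for e in topk_indices:
        # On the GPU each lane would issue one atomicOr on a shared word.
        words[e // 32] |= 1 << (e % 32)
    # Assumption: words are written out in little-endian byte order.
    raw = b"".join(w.to_bytes(4, "little") for w in words)
    # Keep only the ceil(num_experts / 8) bytes that belong to the row.
    return raw[: (num_experts + 7) // 8]


def pack_row_via_bytes(topk_indices, num_experts):
    """Reference layout: expert e sets bit e % 8 of byte e // 8."""
    row = bytearray((num_experts + 7) // 8)
    for e in topk_indices:
        row[e // 8] |= 1 << (e % 8)
    return bytes(row)


indices = [0, 9, 31, 32, 39]
assert pack_row_via_words(indices, 40) == pack_row_via_bytes(indices, 40)
```

The two layouts agree because `e = 32 * w + b` puts bit `b` of word `w` into byte `4 * w + b // 8`, which equals `e // 8`, at bit position `b % 8`, which equals `e % 8`.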
// Wrap with explicit 2D shape for the kernel: the common-layer NVTE_CHECKs
// expect {num_tokens, num_experts} (or {num_tokens, ceil(num_experts/8)} for
// the bitmap routing_map). This is ~100ns of std::vector + TensorWrapper
// construction vs ~1.5us per .view() in Python.
While I appreciate the benchmark, I don't think it is needed in this comment.
I'm removing the benchmark but still keeping the explanation of why we want to wrap with an explicit 2D shape.
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
# Conflicts: # transformer_engine/common/fused_router/fused_topk_with_score_function.cu
/te-ci
Description
Add a new path to our topk kernel that outputs the routing map in bitmap format in addition to bytemap. The default remains bytemap, so there is no regression for existing consumers downstream of this op. However, since the op now requires an additional argument to specify the routing map type (bytemap or bitmap), we introduce a V2 of the API for this while keeping the original API unchanged so as not to break customers.
This saves NCCL EP from having to convert token_indices (sparse format) to bitmap format for comms later.
Fixes #2999
Type of change
Changes
Checklist: