Skip to content

[PyTorch] Expert Parallelism: PyTorch wrapper + autograd ops with symm-mem zero-copy#3035

Draft
phu0ngng wants to merge 11 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-pytorch-on-commwindow
Draft

[PyTorch] Expert Parallelism: PyTorch wrapper + autograd ops with symm-mem zero-copy#3035
phu0ngng wants to merge 11 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-pytorch-on-commwindow

Conversation

@phu0ngng
Copy link
Copy Markdown
Collaborator

Summary

Second PR in the TE Expert Parallelism (EP) series. Adds the PyTorch binding on top of the common C API (#3034): exposes EP dispatch/combine as torch.library custom ops with autograd, and plumbs NCCL symmetric-memory windows through for the zero-copy path.

Payload tensors allocated via te.pytorch.symm_mem_alloc take the one-sided zero-copy path; anything else silently falls back to staged-copy, so the API is drop-in compatible with any allocator.

Implementation

Public Python API (transformer_engine/pytorch/ep.py)

from transformer_engine.pytorch.ep import (
    EpHandle, ep_bootstrap, ep_finalize,
    ep_prepare, ep_dispatch, ep_combine,
    symm_mem_alloc,
)
  • ep_bootstrap / ep_finalize — one-time per-process init + teardown (also auto-registered via atexit). Rank 0 mints an ncclUniqueId, broadcasts it on ep_group, backend opens its own ncclComm_t. Requi
    res ep_group.size() >= 4.
  • symm_mem_alloc(shape, dtype, ep_group) — allocate a per-rank tensor backed by NCCL symmetric memory, already rendezvoused on ep_group.
  • EpHandle — per-layer routing state; reuse across steps.
  • ep_prepare / ep_dispatch / ep_combine — per-step ops; both dispatch and combine are autograd-aware and registered as torch.library.custom_op, so they compose with torch.compile fullgraph capture and
    CUDA graphs.

C++ bindings (transformer_engine/pytorch/csrc/extensions/ep.cpp)

  • POD-only pybind boundary (py::bytes for ncclUniqueId, primitives for config) — no c10d ABI on the boundary.
  • maybe_make_window() looks up each payload tensor's NCCLSymmetricMemory window and returns an NVTECommWindow to the backend; non-symm-mem tensors get {nullptr, 0} and the backend picks staged-copy autom
    atically.
  • Warn-once hint when high-traffic payloads (tokens, recv_tokens, expert_out, grad) aren't symm-mem-backed. Routing-weight tensors stay silent (nice-to-have, not required). Suppress with NVTE_EP_SILENCE _NONSYMM_WARN=1.

Build

  • build_tools/pytorch.py propagates -DNVTE_WITH_NCCL_EP to the PyTorch extension. When NCCL EP is off, the extension still loads — nvte_ep_* come from the common stub and throw on first call.

Testing

  • tests/pytorch/distributed/run_ep.py — 17-test unittest suite: prepare correctness, dispatch/combine identity (uniform + non-uniform), 3D input, VJPs, top_k=1 all-to-one, alignment edge cases, CUDA grap
    h capture (eager + zero-copy), torch.compile fullgraph, bf16 autocast (eager + autograd), zero-copy autograd combine, symm-mem fallback, gradient checkpointing.
  • Launcher: tests/pytorch/distributed/run_test_ep.sh. Verified on 8×H200: Ran 17 tests in 19.8s … OK on every rank.
  • Example: examples/pytorch/ep/ep_moe.py — minimal end-to-end MoE forward+backward driver.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@phu0ngng phu0ngng requested review from ksivaman and ptrendx as code owners May 22, 2026 02:54
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR adds the PyTorch-level binding for Expert Parallelism (EP): a public Python API (ep.py), pybind11 C++ extensions (ep.cpp), a backend singleton (ep_backend.cpp), and a comprehensive distributed test suite. Payload tensors backed by NCCL symmetric memory take a zero-copy one-sided path; all others fall back to staged copy transparently.

  • transformer_engine/pytorch/ep.pyep_bootstrap/ep_finalize lifecycle, EpHandle per-layer state, ep_prepare/ep_dispatch/ep_combine as torch.library.custom_op with autograd, and symm_mem_alloc for symmetric-memory buffer allocation.
  • transformer_engine/pytorch/csrc/extensions/ep.cpp — pybind11 bindings that translate PyTorch tensors to NVTE descriptors, look up NCCL symmetric-memory windows via maybe_make_window, and forward to the C API.
  • transformer_engine/common/ep/ep_backend.cppEPBackend singleton wrapping ncclEpGroup_t, with per-layer handle caching and forward/backward dispatch/combine ops.

Confidence Score: 3/5

Two correctness bugs in the Python autograd layer and C++ bootstrap path need fixes before this lands in a production training run.

The _EpDispatch.backward fallback allocates a zero-gradient tensor shaped (max_tokens_per_rank, H) when the correct shape is (recv_capacity_per_rank, H); any training path where the upstream gradient of recv_tokens is None will hit either a runtime NVTE_CHECK or silent wrong-sized communication. Separately, ep_initialize in the C++ extension creates an NCCL communicator and then calls nvte_ep_initialize; if the latter throws, the communicator is never stored and can never be destroyed.

transformer_engine/pytorch/ep.py (_EpDispatch.backward zero-grad shape) and transformer_engine/pytorch/csrc/extensions/ep.cpp (ep_initialize NCCL comm lifetime) require the most attention before merge.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ep.py New public Python EP API: bootstrap, EpHandle, ep_prepare/dispatch/combine with autograd. Contains a shape bug in _EpDispatch.backward (wrong zero-gradient fallback for g_recv_tokens) and a non-reentrant _zero_copy_scope context manager.
transformer_engine/pytorch/csrc/extensions/ep.cpp New C++ PyTorch extension binding; implements ep_initialize, per-step ops, and symm-mem window lookup. Has a NCCL communicator resource leak on ep_initialize's error path when nvte_ep_initialize throws.
transformer_engine/common/ep/ep_backend.cpp EPBackend singleton: NCCL EP group creation, per-op dispatch/combine, and handle cache. Missing validate_config check for max_recv_tokens_per_rank > 0; otherwise well-structured with proper mutex usage and RAII handle guard.
tests/pytorch/distributed/run_ep.py Comprehensive 17-test multi-process suite covering correctness, autograd VJPs, CUDA graph capture, torch.compile, autocast, and symm-mem paths.

Reviews (1): Last reviewed commit: "ep: PyTorch wrapper, autograd ops, symm-..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/ep.py Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/ep.cpp Outdated
Comment thread transformer_engine/pytorch/ep.py Outdated
Comment on lines +558 to +568
@contextlib.contextmanager
def _zero_copy_scope(enabled: bool):
"""Toggles whether per-step ops apply the symm-mem NCCL window annotation."""
if enabled:
yield
return
tex.ep_set_zero_copy(False)
try:
yield
finally:
tex.ep_set_zero_copy(True)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 _zero_copy_scope does not save/restore the previous flag value

When enabled=False, the manager unconditionally sets g_zero_copy_enabled=False on entry and g_zero_copy_enabled=True on exit. If two callers both use zero_copy=False concurrently (e.g., pipeline-parallel microbatches dispatched from separate Python threads) or if the context is nested, the inner scope's finally block prematurely re-enables zero-copy while the outer scope is still active. The outer scope's finally then sets True again, but between the inner finally and the outer finally the C++ layer sees True unexpectedly.

The fix is to capture the previous value before writing and restore it unconditionally: save old = tex.ep_get_zero_copy() (adding a corresponding getter), then tex.ep_set_zero_copy(old) in the finally block. At minimum, document the single-caller-at-a-time assumption prominently so pipeline-parallel users know to serialize.

Comment thread transformer_engine/common/ep/ep_backend.cpp
@phu0ngng phu0ngng marked this pull request as draft May 22, 2026 03:03
@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch 2 times, most recently from b3966bf to b2931a6 Compare May 22, 2026 22:59
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch from f19ceff to 29ce8af Compare May 22, 2026 23:42
phu0ngng added 2 commits May 23, 2026 19:36
…em_reloc gating

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ps, symm-mem zero-copy, tests, examples

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch from 540ef54 to bacae5f Compare May 24, 2026 00:06
pre-commit-ci Bot and others added 4 commits May 24, 2026 00:06
…ang in --cuda-graph mode

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…namic combine_bwd_post, 1F1B test + bench symm-mem inputs

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Comment thread transformer_engine/pytorch/ep.py Outdated
device = expert_out.device
# Weight in payload dtype: single fused broadcast multiply into combine_in.
w = recv_topk_weights.unsqueeze(-1).to(expert_out.dtype)
torch.mul(expert_out, w, out=combine_in)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we need this?🤔
At the training scenario, the weight gets multiplied onto the activation between fc1 and fc2 (we also dispatch the weight at the same time as dispatching the tokens), or am I misunderstanding something here?

My understanding is that this multiplication is unnecessary. Furthermore, if it is removed, another problem becomes more prominent: how do we add symm buffer support for the combine input? This would require changes on the grouped GEMM side.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Second this. I saw unexpected kernel here and found this same problem. A potential solution is to provide a separate path when the weight is not provided. This means the weight multiplication is handled elsewhere, and in this case skip the multiplication here.

Copy link
Copy Markdown
Collaborator Author

@phu0ngng phu0ngng May 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to learn that we can fuse the weight x to the activation. I will make this optional.

We will need to change the GG to return the symmetric memory buf.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. we need change the grouped gemm I think

ep_group: dist.ProcessGroup,
num_experts: int,
max_tokens_per_rank: int,
recv_capacity_per_rank: int,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When allocating the buffer, we need to allocate according to the worst case. There are two scenarios here:

  • The first is rank-major, where the memory footprint is max_tokens_per_rank × num_of_ranks. This generally stays below 10 GB, which is the primary memory overhead of typical EP setups and is acceptable.
  • The second is expert-major, where the memory footprint is max_tokens_per_rank × num_of_ranks × min(topk, num_of_experts). This could reach 40–50 GB, which is unacceptable.

If I understand this correctly, we must find a way to optimize the memory usage in the expert-major layout — or alternatively, we need to fall back to the rank-major layout + explicit permutation approach.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the rank-major, you still need to overallocate the output buffer of local permute as in expert-major. Right?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two types of buffers:

The first is the EP buffer, which serves as the destination for communication (NCCL EP is a push-based design), so it requires a relatively costly registration process. These are reused globally as static buffers as much as possible, so they are allocated based on the worst-case size. In HEP, the rank-major output buffer is an EP buffer, so we only need a rank-major worst-case-size buffer. I haven't studied NCCL EP in detail, but my understanding is that if our output is a symmetric buffer, we don't need a built-in static comm buffer inside NCCL EP — meaning recv_capacity_per_rank is not needed when the output buffer is a symm buffer. I think this is worth discussing and clarifying.

The second type is regular GPU memory, which can be managed by the caching allocator. In HEP, the output of the permute operation falls into this category — it can be dynamically allocated each iteration based on the scan result, with just one additional sync required. Additionally, in sync-free mode, the size of this buffer is specified by the user.

To summarize, we may need to confirm whether recv_capacity_per_rank requires building an expert-major worst-case-size buffer inside NCCL EP. If the output is a symm buffer, we theoretically don't need such a buffer. However, if it is necessary, then we cannot accept an expert-major worst-case-size buffer. I also observed in my draft PR that NCCL EP uses more memory.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi,
It's correct that if the output buffer is a symmem, then we should not need to register the gigantic IPC/MC buffer in ep_group with the size based on recv_capacity_per_rank. Let's request NCCL EP to add an option to skip this buffer allocation.

However, I think we should still ask users to specify this recv_capacity_per_rank so that we can handle overflow policy in the metadata_preprocessing rather than delaying it to dispatch phase.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need an option to skip this internal buffer.
Also, are you thinking of using recv_capacity_per_rank to support the sync-free mechanism? That is, tokens exceeding the threshold get dropped, and then trigger the flipping of the overflow flag? I think this is incorrect — we should not set it at buffer initialization, but instead pass it as a parameter before the preprocess step of each dispatch, because the threshold changes every iteration.
cc @nanz-nv plz correct me if I made mistakes

phu0ngng added 4 commits May 27, 2026 20:55
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants