[PyTorch] Expert Parallelism: PyTorch wrapper + autograd ops with symm-mem zero-copy #3035
phu0ngng wants to merge 11 commits into
Greptile Summary
This PR adds the PyTorch-level binding for Expert Parallelism (EP): a public Python API (
Confidence Score: 3/5
Two correctness bugs in the Python autograd layer and C++ bootstrap path need fixes before this lands in a production training run.
The _EpDispatch.backward fallback allocates a zero-gradient tensor shaped (max_tokens_per_rank, H) when the correct shape is (recv_capacity_per_rank, H); any training path where the upstream gradient of recv_tokens is None will hit either a runtime NVTE_CHECK or silent wrong-sized communication. Separately, ep_initialize in the C++ extension creates an NCCL communicator and then calls nvte_ep_initialize; if the latter throws, the communicator is never stored and can never be destroyed.
transformer_engine/pytorch/ep.py (_EpDispatch.backward zero-grad shape) and transformer_engine/pytorch/csrc/extensions/ep.cpp (ep_initialize NCCL comm lifetime) require the most attention before merge.
Important Files Changed
Reviews (1): Last reviewed commit: "ep: PyTorch wrapper, autograd ops, symm-..."
    @contextlib.contextmanager
    def _zero_copy_scope(enabled: bool):
        """Toggles whether per-step ops apply the symm-mem NCCL window annotation."""
        if enabled:
            yield
            return
        tex.ep_set_zero_copy(False)
        try:
            yield
        finally:
            tex.ep_set_zero_copy(True)
_zero_copy_scope does not save/restore the previous flag value
When enabled=False, the manager unconditionally sets g_zero_copy_enabled=False on entry and g_zero_copy_enabled=True on exit. If two callers both use zero_copy=False concurrently (e.g., pipeline-parallel microbatches dispatched from separate Python threads) or if the context is nested, the inner scope's finally block prematurely re-enables zero-copy while the outer scope is still active. The outer scope's finally then sets True again, but between the inner finally and the outer finally the C++ layer sees True unexpectedly.
The fix is to capture the previous value before writing and restore it unconditionally: save old = tex.ep_get_zero_copy() (adding a corresponding getter), then tex.ep_set_zero_copy(old) in the finally block. At minimum, document the single-caller-at-a-time assumption prominently so pipeline-parallel users know to serialize.
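The save/restore pattern suggested above can be sketched as follows. This is a minimal illustration, not the PR's code: `_FakeTex` is a stand-in for the real extension module, and `ep_get_zero_copy` is the hypothetical getter the comment proposes adding.

```python
import contextlib


class _FakeTex:
    """Stand-in for the C++ extension; it only models the global flag."""

    def __init__(self):
        self._zero_copy = True

    def ep_set_zero_copy(self, value: bool) -> None:
        self._zero_copy = bool(value)

    def ep_get_zero_copy(self) -> bool:
        # Hypothetical getter: the review comment proposes adding one.
        return self._zero_copy


tex = _FakeTex()


@contextlib.contextmanager
def zero_copy_scope(enabled: bool):
    """Disable zero-copy inside the block, then restore the previous value."""
    if enabled:
        yield
        return
    previous = tex.ep_get_zero_copy()  # capture before writing
    tex.ep_set_zero_copy(False)
    try:
        yield
    finally:
        # Restore what was there before, instead of forcing True.
        tex.ep_set_zero_copy(previous)


# Nested scopes: the inner exit no longer re-enables zero-copy early.
with zero_copy_scope(False):
    with zero_copy_scope(False):
        pass
    assert tex.ep_get_zero_copy() is False
assert tex.ep_get_zero_copy() is True
```

This addresses the nesting case only. Threads still share one process-wide flag, so concurrent callers would additionally need a lock or a per-call argument.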
Force-pushed from b3966bf to b2931a6
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Force-pushed from f19ceff to 29ce8af
…em_reloc gating Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ps, symm-mem zero-copy, tests, examples Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Force-pushed from 540ef54 to bacae5f
for more information, see https://pre-commit.ci
…ang in --cuda-graph mode Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
…namic combine_bwd_post, 1F1B test + bench symm-mem inputs Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
    device = expert_out.device
    # Weight in payload dtype: single fused broadcast multiply into combine_in.
    w = recv_topk_weights.unsqueeze(-1).to(expert_out.dtype)
    torch.mul(expert_out, w, out=combine_in)
Why do we need this? 🤔
In the training scenario, the weight is multiplied onto the activation between fc1 and fc2 (we also dispatch the weight at the same time as the tokens), or am I misunderstanding something here?
My understanding is that this multiplication is unnecessary. Furthermore, if it is removed, another problem becomes more prominent: how do we add symm buffer support for the combine input? This would require changes on the grouped GEMM side.
Seconded. I saw an unexpected kernel here and found the same problem. A potential solution is to provide a separate path for when the weight is not provided. That would mean the weight multiplication is handled elsewhere, so the multiplication here is skipped.
Good to learn that we can fuse the weight multiplication into the activation. I will make this optional.
We will need to change the grouped GEMM (GG) to return the symmetric memory buffer.
Yes, I think we need to change the grouped GEMM.
    ep_group: dist.ProcessGroup,
    num_experts: int,
    max_tokens_per_rank: int,
    recv_capacity_per_rank: int,
When allocating the buffer, we need to allocate according to the worst case. There are two scenarios here:
- The first is rank-major, where the memory footprint is max_tokens_per_rank × num_of_ranks. This generally stays below 10 GB, which is the primary memory overhead of typical EP setups and is acceptable.
- The second is expert-major, where the memory footprint is max_tokens_per_rank × num_of_ranks × min(topk, num_of_experts). This could reach 40–50 GB, which is unacceptable.
If I understand this correctly, we must find a way to optimize the memory usage in the expert-major layout — or alternatively, we need to fall back to the rank-major layout + explicit permutation approach.
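To make the two footprints concrete, here is a small arithmetic sketch. The comment gives the sizes in tokens; converting to bytes requires a hidden size and element width, which are assumptions here, as are all the numeric values (picked only for illustration, not taken from the PR).

```python
def ep_buffer_bytes(
    max_tokens_per_rank: int,
    num_ranks: int,
    hidden_size: int,
    bytes_per_elem: int,
    top_k: int = 1,
    num_experts: int = 1,
    expert_major: bool = False,
) -> int:
    """Worst-case receive-buffer size in bytes for one rank."""
    rows = max_tokens_per_rank * num_ranks
    if expert_major:
        # Extra factor quoted in the comment for the expert-major layout.
        rows *= min(top_k, num_experts)
    return rows * hidden_size * bytes_per_elem


GIB = 2**30
# Hypothetical configuration, chosen only to make the arithmetic concrete.
cfg = dict(max_tokens_per_rank=8192, num_ranks=64, hidden_size=7168, bytes_per_elem=2)
rank_major = ep_buffer_bytes(**cfg)
expert_major = ep_buffer_bytes(**cfg, top_k=6, num_experts=256, expert_major=True)
print(f"rank-major:   {rank_major / GIB:.1f} GiB")    # 7.0 GiB
print(f"expert-major: {expert_major / GIB:.1f} GiB")  # 42.0 GiB
```

With these assumed values the rank-major buffer stays in single-digit GiB while the expert-major one grows by the `min(top_k, num_experts)` factor, which is the gap the comment is concerned about.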
With rank-major, you still need to over-allocate the output buffer of the local permute, as in expert-major, right?
There are two types of buffers:
The first is the EP buffer, which serves as the destination for communication (NCCL EP is a push-based design), so it requires a relatively costly registration process. These are reused globally as static buffers as much as possible, so they are allocated based on the worst-case size. In HEP, the rank-major output buffer is an EP buffer, so we only need a rank-major worst-case-size buffer. I haven't studied NCCL EP in detail, but my understanding is that if our output is a symmetric buffer, we don't need a built-in static comm buffer inside NCCL EP — meaning recv_capacity_per_rank is not needed when the output buffer is a symm buffer. I think this is worth discussing and clarifying.
The second type is regular GPU memory, which can be managed by the caching allocator. In HEP, the output of the permute operation falls into this category — it can be dynamically allocated each iteration based on the scan result, with just one additional sync required. Additionally, in sync-free mode, the size of this buffer is specified by the user.
To summarize, we may need to confirm whether recv_capacity_per_rank requires building an expert-major worst-case-size buffer inside NCCL EP. If the output is a symm buffer, we theoretically don't need such a buffer. However, if it is necessary, then we cannot accept an expert-major worst-case-size buffer. I also observed in my draft PR that NCCL EP uses more memory.
Hi,
It's correct that if the output buffer is symmetric memory, we should not need to register the gigantic IPC/MC buffer in ep_group, sized according to recv_capacity_per_rank. Let's ask NCCL EP to add an option to skip this buffer allocation.
However, I think we should still ask users to specify recv_capacity_per_rank so that we can handle the overflow policy in metadata_preprocessing rather than delaying it to the dispatch phase.
We need an option to skip this internal buffer.
Also, are you thinking of using recv_capacity_per_rank to support the sync-free mechanism? That is, tokens exceeding the threshold get dropped, which then flips the overflow flag? I think this is incorrect: we should not set it at buffer initialization, but instead pass it as a parameter before the preprocess step of each dispatch, because the threshold changes every iteration.
cc @nanz-nv, please correct me if I've made any mistakes.
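The per-dispatch threshold idea can be sketched in plain Python as below. This is a toy model under stated assumptions: `clamp_to_capacity` is a hypothetical helper, and the "accept in source order, drop the tail" policy is an assumption, since the thread does not say how the real metadata preprocessing chooses which tokens to drop.

```python
from typing import List, Tuple


def clamp_to_capacity(
    tokens_per_source: List[int], recv_capacity: int
) -> Tuple[List[int], bool]:
    """Accept tokens in source order until recv_capacity is reached.

    Returns the accepted count per source and an overflow flag that is
    True when at least one token was dropped.
    """
    accepted = []
    remaining = recv_capacity
    for count in tokens_per_source:
        take = min(count, remaining)
        accepted.append(take)
        remaining -= take
    overflow = sum(accepted) < sum(tokens_per_source)
    return accepted, overflow


# The capacity is an argument of each call, so it can change per iteration
# instead of being fixed when the buffer is initialized.
incoming = [100, 50, 80]
print(clamp_to_capacity(incoming, recv_capacity=200))  # ([100, 50, 50], True)
print(clamp_to_capacity(incoming, recv_capacity=256))  # ([100, 50, 80], False)
```

Passing the threshold per call keeps the buffer size (a one-time allocation decision) separate from the drop policy (a per-step decision), which is the separation the comment argues for.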
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Summary
Second PR in the TE Expert Parallelism (EP) series. Adds the PyTorch binding on top of the common C API (#3034): exposes EP dispatch/combine as torch.library custom ops with autograd, and plumbs NCCL symmetric-memory windows through for the zero-copy path.
Payload tensors allocated via te.pytorch.symm_mem_alloc take the one-sided zero-copy path; anything else silently falls back to staged-copy, so the API is drop-in compatible with any allocator.
Implementation
Public Python API (transformer_engine/pytorch/ep.py)
- ep_bootstrap / ep_finalize: one-time per-process init + teardown (also auto-registered via atexit). Rank 0 mints an ncclUniqueId, broadcasts it on ep_group, backend opens its own ncclComm_t. Requires ep_group.size() >= 4.
- symm_mem_alloc(shape, dtype, ep_group): allocate a per-rank tensor backed by NCCL symmetric memory, already rendezvoused on ep_group.
- EpHandle: per-layer routing state; reuse across steps.
- ep_prepare / ep_dispatch / ep_combine: per-step ops; both dispatch and combine are autograd-aware and registered as torch.library.custom_op, so they compose with torch.compile fullgraph capture and CUDA graphs.
C++ bindings (transformer_engine/pytorch/csrc/extensions/ep.cpp)
- … py::bytes for ncclUniqueId, primitives for config): no c10d ABI on the boundary.
- maybe_make_window() looks up each payload tensor's NCCLSymmetricMemory window and returns an NVTECommWindow to the backend; non-symm-mem tensors get {nullptr, 0} and the backend picks staged-copy automatically.
- … tokens, recv_tokens, expert_out, grad) aren't symm-mem-backed. Routing-weight tensors stay silent (nice-to-have, not required). Suppress with NVTE_EP_SILENCE_NONSYMM_WARN=1.
Build
- build_tools/pytorch.py propagates -DNVTE_WITH_NCCL_EP to the PyTorch extension. When NCCL EP is off, the extension still loads; nvte_ep_* come from the common stub and throw on first call.
Testing
- tests/pytorch/distributed/run_ep.py: 17-test unittest suite: prepare correctness, dispatch/combine identity (uniform + non-uniform), 3D input, VJPs, top_k=1 all-to-one, alignment edge cases, CUDA graph capture (eager + zero-copy), torch.compile fullgraph, bf16 autocast (eager + autograd), zero-copy autograd combine, symm-mem fallback, gradient checkpointing.
- tests/pytorch/distributed/run_test_ep.sh. Verified on 8×H200: Ran 17 tests in 19.8s … OK on every rank.
- examples/pytorch/ep/ep_moe.py: minimal end-to-end MoE forward+backward driver.
Type of change
Checklist: