Expert Parallelism: common C API + NCCL EP backend #3034
Greptile Summary
This PR lands the foundational Expert Parallelism (EP) layer for TransformerEngine: a common C API (
Confidence Score: 4/5
Safe to merge with one build issue addressed: the public The
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Caller
participant C_API as nvte_ep_* (ep_api.cpp)
participant Backend as EPBackend singleton
participant NCCL_EP as ncclEp* (libnccl_ep.so)
Caller->>C_API: nvte_ep_initialize(ep_comm, group_config)
C_API->>Backend: EPBackend::initialize()
Backend->>NCCL_EP: ncclEpCreateGroup()
Caller->>C_API: "nvte_ep_register_layer(layer_config, &mem_size)"
C_API->>Backend: register_layer()
Backend->>NCCL_EP: ncclEpHandleMemSize()
Backend-->>Caller: handle_id + required mem_size
Note over Caller: Allocates handle_mem buffer
loop Per training step
Caller->>C_API: nvte_ep_prepare(handle, topk_idx, token_counts, stream)
C_API->>Backend: prepare() → ncclEpUpdateHandle()
Backend->>NCCL_EP: ncclEpUpdateHandle (AllGather routing map)
Caller->>C_API: nvte_ep_dispatch(handle, tokens, [win], weights, [win], stream)
C_API->>Backend: dispatch() → ncclEpDispatch()
Backend->>NCCL_EP: ncclEpDispatch (scatter tokens to expert ranks)
Note over Caller: Expert computation on recv_tokens
Caller->>C_API: nvte_ep_combine(handle, expert_out, [win], result, stream)
C_API->>Backend: combine() → ncclEpCombine()
Backend->>NCCL_EP: ncclEpCombine (scatter-sum back to source ranks)
end
Caller->>C_API: nvte_ep_shutdown()
C_API->>Backend: EPBackend::shutdown()
Backend->>NCCL_EP: ncclEpHandleDestroy + ncclEpGroupDestroy
Reviews (4): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
| cfg.algorithm = NCCL_EP_ALGO_HIGH_THROUGHPUT; | ||
| cfg.num_experts = static_cast<unsigned int>(group_config.num_experts); | ||
| cfg.max_dispatch_tokens_per_rank = static_cast<unsigned int>(group_config.max_tokens_per_rank); | ||
| cfg.max_token_bytes = static_cast<unsigned int>(group_config.hidden_dim * sizeof(nv_bfloat16)); |
max_token_bytes hardcoded to sizeof(nv_bfloat16) breaks float32 dispatch
cfg.max_token_bytes is computed as hidden_dim * sizeof(nv_bfloat16) (2 bytes), but nvte_dtype_to_nccl supports float32, float16, int32, int64, float8, etc. When a caller creates the EP group with this config and later dispatches float32 tokens (via nvte_ep_dispatch), the pre-allocated max_token_bytes is half the required size. NCCL EP uses this value to size internal staging buffers at group creation; dispatching a wider dtype silently overruns those buffers or triggers an internal NCCL error. NVTEEpGroupConfig needs a dtype (or max_token_element_bytes) field so callers can declare the maximum element width they will use.
Note for myself: Need to expose this option for users to set in ep_bootstrap.
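One way to read the suggestion in this thread is sketched below: the group config carries the widest element type the caller intends to dispatch, and the per-token byte budget is derived from it rather than from `sizeof(nv_bfloat16)`. The names `TokenDType`, `GroupConfigSketch`, `element_bytes` and `max_token_bytes` are hypothetical and not part of the PR; the real API would presumably key off its own dtype enum.

```cpp
#include <cstddef>
#include <stdexcept>

// Hypothetical dtype tag used only for this sketch.
enum class TokenDType { kFloat8, kBFloat16, kFloat16, kFloat32 };

// Width of a single element in bytes.
inline std::size_t element_bytes(TokenDType dtype) {
  switch (dtype) {
    case TokenDType::kFloat8:   return 1;
    case TokenDType::kBFloat16: return 2;
    case TokenDType::kFloat16:  return 2;
    case TokenDType::kFloat32:  return 4;
  }
  throw std::invalid_argument("unknown dtype");
}

// Hypothetical group config: declares the widest dtype that will ever be
// dispatched, so staging buffers can be sized once at group creation.
struct GroupConfigSketch {
  int hidden_dim;
  TokenDType max_token_dtype;
};

// Per-token byte budget derived from the declared dtype instead of a
// hardcoded 2-byte element.
inline std::size_t max_token_bytes(const GroupConfigSketch& cfg) {
  if (cfg.hidden_dim <= 0) throw std::invalid_argument("hidden_dim must be > 0");
  return static_cast<std::size_t>(cfg.hidden_dim) * element_bytes(cfg.max_token_dtype);
}
```

With `hidden_dim = 4096`, this gives 8192 bytes for BF16 and 16384 bytes for float32, which is the factor-of-two gap the comment describes.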
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
099857f to 17e5126
…em_reloc gating Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
| endif() | ||
|
|
||
| find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) | ||
| find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED NO_CMAKE_SYSTEM_PATH) |
| # No MPI dependency — processes are spawned by run_test_ep.sh with | ||
| # --rank / --nranks flags. ncclUniqueId exchange uses a | ||
| # shared temp file (see test_ep_common.h for details). |
I believe that the other distributed tests do rely on MPI, so why don't we also do that here?
| # nvrtc symbols are referenced from libtransformer_engine.so but not in its | ||
| # DT_NEEDED list (loaded via dlopen in Python). For cpp tests we link nvrtc | ||
| # explicitly with --no-as-needed so the linker keeps the dependency. | ||
| set(EP_TEST_LINK_OPTS "LINKER:--no-as-needed") |
This sounds like a bug actually, but the other tests do not need to do this; they instead specify nvrtc after TE_LIB in the LINKER_LIBS variable.
| // ── Error-checking macros ───────────────────────────────────────────────────── | ||
|
|
||
| #define CHECK_NCCL(expr) \ | ||
| do { \ | ||
| ncclResult_t _err = (expr); \ | ||
| if (_err != ncclSuccess) \ | ||
| FAIL() << "NCCL error " << _err << ": " << ncclGetErrorString(_err); \ | ||
| } while (false) | ||
|
|
||
| #define CHECK_CUDA(expr) \ | ||
| do { \ | ||
| cudaError_t _err = (expr); \ | ||
| if (_err != cudaSuccess) \ | ||
| FAIL() << "CUDA error " << _err << ": " << cudaGetErrorString(_err); \ | ||
| } while (false) | ||
|
|
||
| #define ASSERT_CUDA_OK(expr) \ | ||
| do { \ | ||
| cudaError_t _err = (expr); \ | ||
| if (_err != cudaSuccess) { \ | ||
| fprintf(stderr, "CUDA error %d: %s\n", _err, cudaGetErrorString(_err)); \ | ||
| exit(EXIT_FAILURE); \ | ||
| } \ | ||
| } while (false) | ||
|
|
||
| #define ASSERT_NCCL_OK(expr) \ | ||
| do { \ | ||
| ncclResult_t _err = (expr); \ | ||
| if (_err != ncclSuccess) { \ | ||
| fprintf(stderr, "NCCL error %d: %s\n", _err, ncclGetErrorString(_err)); \ | ||
| exit(EXIT_FAILURE); \ | ||
| } \ | ||
| } while (false) |
| struct TensorHandle { | ||
| NVTETensor tensor = nullptr; | ||
| void* dev_ptr = nullptr; | ||
|
|
||
| ~TensorHandle() { | ||
| if (tensor) nvte_destroy_tensor(tensor); | ||
| } | ||
|
|
||
| TensorHandle() = default; | ||
| TensorHandle(const TensorHandle&) = delete; | ||
| TensorHandle& operator=(const TensorHandle&) = delete; | ||
|
|
||
| TensorHandle(TensorHandle&& o) noexcept : tensor(o.tensor), dev_ptr(o.dev_ptr) { | ||
| o.tensor = nullptr; o.dev_ptr = nullptr; | ||
| } | ||
| TensorHandle& operator=(TensorHandle&& o) noexcept { | ||
| if (this != &o) { | ||
| if (tensor) nvte_destroy_tensor(tensor); | ||
| tensor = o.tensor; dev_ptr = o.dev_ptr; | ||
| o.tensor = nullptr; o.dev_ptr = nullptr; | ||
| } | ||
| return *this; | ||
| } | ||
| }; |
|
|
||
| // RAII owner for a cudaMalloc'd device buffer; frees on destruction. | ||
| template <typename T> | ||
| struct DevBuf { |
We already have a very similar thing in test_common.h.
| @@ -0,0 +1,64 @@ | |||
| /************************************************************************* | |||
I feel that having those tests as separate entities does not really make sense and will introduce overhead to the CI - the actual functionality tests would already be able to cover those initialization issues, no?
| }; | ||
|
|
||
| // Bundled NVTETensor views over an EPBuffers — one place to update the shape | ||
| // conventions when the C-API evolves. |
What do you mean by "when the C-API evolves"? We should aim for stability of the C API.
| CHECK_CUDA(cudaMemcpy(h_result.data(), buf.result.get(), | ||
| h_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); | ||
| auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); | ||
| // Spot-check 3 hidden-dim positions per token to catch partial-row writes. |
What? Why don't we check the full data?
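For illustration, a full comparison of the combined result against the expected values could look roughly like the sketch below. It assumes both buffers have already been copied to the host and converted to `float`; the helper name `first_mismatch` and the `rtol`/`atol` parameters are hypothetical and not taken from the PR or from `test_common.h`.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Compares every element, not a handful of probes. Returns the index of the
// first element where |actual - expected| exceeds atol + rtol * |expected|,
// or -1 when all elements are within tolerance.
inline long first_mismatch(const std::vector<float>& actual,
                           const std::vector<float>& expected,
                           float rtol, float atol) {
  if (actual.size() != expected.size()) {
    throw std::invalid_argument("size mismatch between actual and expected");
  }
  for (std::size_t i = 0; i < actual.size(); ++i) {
    const float diff = std::fabs(actual[i] - expected[i]);
    const float tol = atol + rtol * std::fabs(expected[i]);
    // Written as !(diff <= tol) so that a NaN in either buffer is reported.
    if (!(diff <= tol)) return static_cast<long>(i);
  }
  return -1;
}
```

Reporting the first failing index keeps the failure message short while still covering every hidden-dim position of every token.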
| // Spot-check 3 hidden-dim positions per token to catch partial-row writes. | ||
| const int probes[3] = {0, hidden_dim_ / 2, hidden_dim_ - 1}; | ||
| for (int tok = 0; tok < num_tokens_; ++tok) { | ||
| float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast<float>(top_k_); |
Why do we hardcode BF16 everywhere? I assume that NCCL EP works with the other datatypes, right?
| // BF16 has 7 mantissa bits; relative ULP ≈ 2^-7. Use 4× headroom for | ||
| // accumulation noise inside dispatch/combine. | ||
| static float bf16_tol(float magnitude) { | ||
| return 4.f * std::ldexp(std::fabs(magnitude) + 1e-3f, -7); | ||
| } |
So why can't we just use rtol 2^-5 rather than this formula? In general the error checking here is very custom, could we integrate it better with the rest of the tests?
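As a side note on the arithmetic, the helper in the diff expands to 4 · 2^-7 · (|m| + 10^-3) = 2^-5 · |m| + 2^-5 · 10^-3, so it already behaves like a relative tolerance of 2^-5 plus an absolute floor of about 3.1e-5. The sketch below reproduces `bf16_tol` from the diff next to that rtol/atol form; `kRtol`, `kAtol` and `rtol_atol_tol` are names introduced here for illustration only.

```cpp
#include <cmath>

// The helper from the diff above, reproduced for comparison.
inline float bf16_tol(float magnitude) {
  return 4.f * std::ldexp(std::fabs(magnitude) + 1e-3f, -7);
}

// Hypothetical equivalent expressed as a conventional rtol/atol pair.
constexpr float kRtol = 0.03125f;          // 2^-5, i.e. 4 * 2^-7
constexpr float kAtol = 1e-3f * 0.03125f;  // 3.125e-5, the scaled 1e-3 floor

inline float rtol_atol_tol(float magnitude) {
  return kAtol + kRtol * std::fabs(magnitude);
}
```

The only thing a plain `rtol = 2^-5` check would drop is the small absolute floor, which matters only for expected values at or near zero.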
| @@ -0,0 +1,562 @@ | |||
| /************************************************************************* | |||
What are the cases that this test would catch that the ep_pipeline one would not?
| namespace transformer_engine { | ||
| namespace ep { | ||
|
|
||
| /*! \brief EP backend singleton — owns the NCCL EP group; borrows the comm. */ |
If it borrows the communicator then on the framework side we need to make sure that it stays alive.
Also, if it is a singleton, how does it work with multiple GPUs per process?
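The PR text does not say how the singleton behaves when one process drives several GPUs. One common pattern, shown purely as an illustration, is to key the backend state by device ordinal instead of keeping a single process-wide instance. `BackendSketch` and `BackendRegistrySketch` are hypothetical stand-ins, not types from the PR.

```cpp
#include <cstddef>
#include <map>
#include <memory>
#include <mutex>

// Hypothetical stand-in for the per-device EP state (group, handles, ...).
struct BackendSketch {
  explicit BackendSketch(int device) : device_id(device) {}
  int device_id;
};

// Holds one backend per device ordinal rather than one per process.
class BackendRegistrySketch {
 public:
  // Returns the backend for device_id, creating it on first use.
  BackendSketch& get(int device_id) {
    std::lock_guard<std::mutex> lock(mu_);
    std::unique_ptr<BackendSketch>& slot = backends_[device_id];
    if (!slot) slot.reset(new BackendSketch(device_id));
    return *slot;
  }

  // Number of devices that currently have a backend.
  std::size_t size() const {
    std::lock_guard<std::mutex> lock(mu_);
    return backends_.size();
  }

 private:
  mutable std::mutex mu_;
  std::map<int, std::unique_ptr<BackendSketch>> backends_;
};
```

In a real backend the caller would pass the current CUDA device ordinal; the lifetime of the borrowed communicator would still have to be guaranteed by the framework, as noted above.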
|
|
||
| // Host-only: reserve a fresh handle_id, cache the layer config, and report | ||
| // the handle_mem buffer size the caller must allocate. | ||
| uint64_t register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size); |
Is it ever-growing? I don't see any free_layer API.
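To make the concern concrete, here is a minimal sketch of a handle registry that can also release entries. `free_layer` is a hypothetical counterpart to `register_layer`; the PR as reviewed does not expose such a call, and `LayerConfigSketch` only mirrors two fields of `NVTEEpLayerConfig` for illustration.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <unordered_map>

// Reduced stand-in for the per-layer config.
struct LayerConfigSketch {
  int top_k;
  std::size_t alignment;
};

class LayerRegistrySketch {
 public:
  // Reserves a fresh handle_id and caches the layer config.
  std::uint64_t register_layer(const LayerConfigSketch& cfg) {
    const std::uint64_t id = next_id_++;
    layers_.emplace(id, cfg);
    return id;
  }

  // Hypothetical: drops the cached config so the table does not grow forever.
  void free_layer(std::uint64_t id) {
    if (layers_.erase(id) == 0) throw std::out_of_range("unknown handle_id");
  }

  // Number of handle_ids that are still registered.
  std::size_t live_layers() const { return layers_.size(); }

 private:
  std::uint64_t next_id_ = 1;
  std::unordered_map<std::uint64_t, LayerConfigSketch> layers_;
};
```

Ids are never reused in this sketch, so a stale `handle_id` fails loudly instead of aliasing a newer layer.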
| typedef struct { | ||
| int ep_size; /*!< EP world size. */ | ||
| int num_experts; /*!< Total experts across all ranks. */ | ||
| int max_tokens_per_rank; /*!< Upper bound on tokens this rank sends per dispatch. */ | ||
| /*! Upper bound on tokens received per dispatch (worst-case top_k fan-out; must be > 0). */ | ||
| int max_recv_tokens_per_rank; | ||
| int hidden_dim; /*!< Token hidden dimension. */ | ||
| int max_num_sms; /*!< Max SMs for EP kernels. 0 = auto. */ | ||
| /*! 0 (default): throw on relocated handle_mem for a cached handle_id. 1: silently rebuild. */ | ||
| int allow_handle_mem_reloc; | ||
| } NVTEEpGroupConfig; | ||
|
|
||
| /*! \brief Per-layer EP configuration. */ | ||
| typedef struct { | ||
| int num_local_experts; /*!< Reserved for ABI stability (derived from group config). */ | ||
| int top_k; /*!< Per-token expert fan-out. Required. */ | ||
| size_t dispatch_output_per_expert_alignment; | ||
| /*!< Per-expert zone alignment in tokens (pow2; 0/1 = no padding). Must match | ||
| * between nvte_ep_register_layer and nvte_ep_prepare. */ | ||
| } NVTEEpLayerConfig; |
If we make this a public API then we should probably version those?
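One conventional way to version a public config struct, sketched here under the assumption that a leading version tag plus a size field would be acceptable, is shown below. `LayerConfigV1Sketch`, `kLayerConfigVersion`, `make_layer_config` and `is_supported` are all hypothetical; only `top_k` and `dispatch_output_per_expert_alignment` come from the struct in the diff.

```cpp
#include <cstddef>
#include <cstdint>

constexpr std::uint32_t kLayerConfigVersion = 1;

// Hypothetical versioned variant of the layer config.
struct LayerConfigV1Sketch {
  std::uint32_t version;      // filled in by the caller from the header it compiled against
  std::uint32_t struct_size;  // sizeof the struct as the caller sees it
  int top_k;
  std::size_t dispatch_output_per_expert_alignment;
};

// Convenience initializer so callers do not set the tag fields by hand.
inline LayerConfigV1Sketch make_layer_config(int top_k, std::size_t alignment) {
  LayerConfigV1Sketch cfg{};
  cfg.version = kLayerConfigVersion;
  cfg.struct_size = static_cast<std::uint32_t>(sizeof(LayerConfigV1Sketch));
  cfg.top_k = top_k;
  cfg.dispatch_output_per_expert_alignment = alignment;
  return cfg;
}

// Library-side check: reject configs built against a different layout.
inline bool is_supported(const LayerConfigV1Sketch& cfg) {
  return cfg.version == kLayerConfigVersion &&
         cfg.struct_size == sizeof(LayerConfigV1Sketch);
}
```

A library that later adds fields could bump `kLayerConfigVersion` and branch on the tag, rather than relying on reserved fields alone.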
Summary
First PR in the TE Expert Parallelism (EP) series. Lands the common C API and NCCL EP backend that later framework PRs (PyTorch, JAX) build on. No Python bindings yet — common-lib foundation plus build wiring only. Build/load works on any arch; SM and NCCL version gates fire at runtime.
Every network-bound payload tensor takes an optional `NVTECommWindow`. When the window is provided, the backend uses NCCL EP's symmetric-memory zero-copy path, which skips the D2D Memcpy from the user buffers to the Symmetric Staging Buffers.

Implementation

Public C API (`transformer_engine/common/include/transformer_engine/{ep.h,comm_window.h}`)

Types: `NVTEEpGroupConfig`, `NVTEEpLayerConfig`, `NVTEEpHandle`, `NVTECommWindow` (side-band `{ncclWindow_t window, size_t offset}`; NCCL peer handles are not carried on `NVTETensor`).

Lifecycle (host-only, eager):
- `nvte_ep_initialize`: borrow an external `ncclComm_t` for the EP sub-group and init the singleton backend.
- `nvte_ep_shutdown`: tear down the backend; idempotent; does not destroy `ep_comm`.
- `nvte_ep_register_layer`: reserve a `handle_id` for a layer config and report the `handle_mem` buffer size the caller must allocate. The pair `{id, mem}` becomes the per-step `NVTEEpHandle`.

Per-step (allocation-free, CUDA-graph capturable)
- `nvte_ep_prepare`: all-gather the routing map and write routing maps to `handle.mem`.
- `nvte_ep_dispatch`: scatter tokens and routing weights from source ranks to expert ranks. `tokens`, `topk_weights`, `recv_tokens`, `recv_topk_weights` each accept an optional symm-mem window for zero-copy.
- `nvte_ep_combine`: scatter-sum expert outputs back to source ranks (unweighted; caller pre-multiplies by `recv_topk_weights`). `expert_out` accepts a window.
- `nvte_ep_dispatch_bwd`: backward of dispatch; routes token and weight grads back to source. `grad` and `g_recv_topk_weights` accept windows; the gathered outputs (`grad_tokens`, `grad_topk_weights`).
- `nvte_ep_combine_bwd`: backward of combine; `grad` and `grad_expert_out` accept windows. Padded slots in `grad_expert_out` are zeroed.

Backend + build
- (`transformer_engine/common/ep/`): `EPBackend` singleton, HT-mode dispatch/combine over NCCL EP (`libnccl_ep.so`), group/layer registration. Internal helper `make_payload_tensor()` builds the per-call `ncclEpTensor_t`: when the caller's `NVTECommWindow.window != nullptr` it sets `win_hdl` + `win_offset` (zero-copy); otherwise it sets `data` from `nvte_tensor_data(t)` (HBM fallback).
- (`EPBackend::initialize`): SM>=90 (via `cudaDeviceGetAttribute`), NCCL>=2.30.4 (via `ncclGetVersion`), CUDA multicast/NVLS support.
- `NVTE_WITH_NCCL_EP=OFF`, `ep/ep_api_stub.cpp` provides throwing `nvte_ep_*` stubs so framework bindings link unconditionally; failure surfaces at first `nvte_ep_initialize`.
- `setup.py` builds `libnccl_ep.so` from `3rdparty/nccl` by default; auto-disables NCCL EP when no requested CUDA arch >= 90. Explicit `NVTE_BUILD_WITH_NCCL_EP=1` with all archs < 90 is treated as user error.
- `NVTE_BUILD_WITH_NCCL_EP=0` to opt out.
- `NCCL_HOME` resolved dynamically: explicit env → `/opt/nvidia/nccl`, `/usr/local/nccl`, `/usr` → `ldconfig -p` fallback.

Testing
- `tests/cpp_distributed/`.

Type of change

Checklist: