[PyTorch] Allocate grouped linear wgrads as tensor views#3049
[PyTorch] Allocate grouped linear wgrads as tensor views#3049timmoon10 wants to merge 2 commits into
Conversation
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci pytorch |
Greptile SummaryThis PR fixes a race condition introduced by PR #2900 in grouped linear backward passes.
Confidence Score: 4/5Safe to merge; the core fix correctly replaces stream-unsafe raw-pointer tensor construction with standard PyTorch views, directly eliminating the documented race condition without introducing new correctness issues. The three backward paths are changed consistently and the new pattern already appears in the non-race-condition paths of the same files, so the approach is well-validated within the codebase. The only open questions — implicit 256-byte per-sub-tensor alignment and the unguarded uniform-weight-shape assumption — do not affect correctness for current usage. The three Python backward files deserve a second look to confirm the alignment and uniform-shape assumptions hold for all models in the project's test suite. Important Files Changed
Sequence DiagramsequenceDiagram
participant BW as Backward Pass
participant OLD as tex.bulk_allocate (old)
participant NEW as torch.empty (new)
participant CUDA as CUDA Allocator
participant GEMM as GEMM Kernel
Note over BW,GEMM: OLD path (race condition possible)
BW->>OLD: "bulk_allocate([shape]*N, dtype, device, [256]*N)"
OLD->>CUDA: at::empty(raw_bytes) [stream NOT recorded]
OLD->>OLD: at::from_blob(ptr+offset, shape) [bypasses stream tracking]
OLD-->>BW: list of N tensors
BW->>GEMM: wgrad GEMM on stream S1
Note over GEMM,CUDA: Another stream S2 may read stale data — race!
Note over BW,GEMM: NEW path (race condition fixed)
BW->>NEW: "torch.empty(N, *weight_shape, dtype, device)"
NEW->>CUDA: standard allocation [stream S1 recorded]
NEW-->>BW: "packed tensor (N, *weight_shape)"
BW->>BW: "wgrad_list = [packed[i] for i in range(N)]"
BW->>GEMM: wgrad GEMM on stream S1
Note over GEMM,CUDA: Stream recorded → sync enforced automatically
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| grad_weights_packed = torch.empty( | ||
| grouped_shape, | ||
| dtype=ctx.dtype, | ||
| device=device, | ||
| ) | ||
| grad_weights = [grad_weights_packed[i] for i in range(num_groups)] |
There was a problem hiding this comment.
Per-tensor 256-byte alignment no longer guaranteed
tex.bulk_allocate(..., [256] * num_groups) guaranteed each wgrad tensor started on a 256-byte boundary. With torch.empty(grouped_shape, ...), PyTorch aligns the base allocation but sub-tensors grad_weights_packed[i] land at byte offset i * out_features * in_features * element_size. That offset is a multiple of 256 bytes only when out_features * in_features * element_size % 256 == 0. For typical large weight matrices this holds, but it is not enforced. If a future model uses an unusual hidden dimension the misalignment could cause a silent performance regression in the GEMM kernel. The same pattern appears in backward_grouped_mlp.py and module/grouped_linear.py.
There was a problem hiding this comment.
That is true, there should be some check added for the sizes to be right.
Description
Under certain configurations, #2900 has introduced race conditions with the grouped linear weight grads. This is because
tex.bulk_allocateconstructs tensor views out of raw pointers, bypassing the stream synchronization logic in standard tensor views. This PR fixes this issue by constructing weight grad tensors as plain tensor views.A simple benchmark on a GB200 node finds that allocating 64 tensor views takes 50 us while
tex.bulk_allocatetakes 58 us. Weight grads are uniform, so they don't need all of the reshaping and dtype conversion logic intex.bulk_allocate.Type of change
Changes
Checklist: