[PyTorch] Allocate grouped linear wgrads as tensor views#3049

Open
timmoon10 wants to merge 2 commits into
NVIDIA:mainfrom
timmoon10:tmoon/debug-bulk-allocate

Conversation

@timmoon10
Member

Description

Under certain configurations, #2900 introduced race conditions in the grouped linear weight grads. The cause is that tex.bulk_allocate constructs tensor views out of raw pointers, bypassing the stream synchronization logic that standard tensor views go through. This PR fixes the issue by constructing the weight grad tensors as plain tensor views.

In a simple benchmark on a GB200 node, allocating 64 tensor views takes 50 us, while tex.bulk_allocate takes 58 us. Weight grads are uniform, so they do not need the reshaping and dtype conversion logic in tex.bulk_allocate.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Allocate grouped linear wgrads as tensor views

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 requested a review from vthumbe1503 May 28, 2026 00:21
@timmoon10 timmoon10 requested a review from ksivaman as a code owner May 28, 2026 00:21
@timmoon10 timmoon10 added the bug and cpu_overhead labels May 28, 2026
@timmoon10 timmoon10 requested a review from zhongbozhu May 28, 2026 00:21
@timmoon10
Member Author

/te-ci pytorch

@greptile-apps
Contributor

greptile-apps Bot commented May 28, 2026

Greptile Summary

This PR fixes a race condition introduced by PR #2900 in grouped linear backward passes. tex.bulk_allocate constructed tensors via raw pointers (at::from_blob), bypassing PyTorch's stream-usage recording, which could cause multi-stream data hazards on the weight gradient buffers.

  • Root cause addressed: replaces tex.bulk_allocate with torch.empty(num_groups, *weight_shape, ...) followed by index-based slice views in all three backward paths (module/grouped_linear.py, ops/basic/grouped_linear.py, ops/fused/backward_grouped_mlp.py), so PyTorch's allocator properly tracks stream usage.
  • Clarifying comment added: allocate.cpp now documents that bulk_allocate bypasses stream synchronization and should be used with caution, which correctly reflects the remaining callers' responsibility.
  • Trade-off: the explicit 256-byte per-tensor alignment that tex.bulk_allocate enforced is no longer guaranteed for sub-tensor views; alignment now depends on whether weight_rows * weight_cols * element_size is a multiple of 256, which holds for typical large matrices but is not asserted anywhere.
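The allocation pattern described in the first bullet can be sketched roughly as follows. This is a minimal illustration, not the code from the PR: the helper name `allocate_wgrads_as_views` and the shapes are hypothetical, and it runs on CPU only to show that the per-group tensors are ordinary views into one packed buffer.

```python
import torch


def allocate_wgrads_as_views(num_groups, weight_shape, dtype=torch.float32, device="cpu"):
    """Hypothetical helper: allocate one packed buffer and return per-group views."""
    # One standard allocation of shape (num_groups, *weight_shape).
    packed = torch.empty(num_groups, *weight_shape, dtype=dtype, device=device)
    # Indexing along dim 0 returns views that share storage with `packed`.
    views = [packed[i] for i in range(num_groups)]
    return packed, views


packed, wgrads = allocate_wgrads_as_views(4, (8, 16))
bytes_per_group = 8 * 16 * packed.element_size()

for i, w in enumerate(wgrads):
    # Each view starts at a fixed byte offset from the packed buffer's base.
    assert w.data_ptr() == packed.data_ptr() + i * bytes_per_group

# Writing through a view is visible in the packed buffer.
wgrads[2].fill_(1.0)
assert bool((packed[2] == 1.0).all())
```

Because every view is created by regular indexing, the list elements keep the packed storage alive even if the `packed` name goes out of scope.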

Confidence Score: 4/5

Safe to merge; the core fix correctly replaces stream-unsafe raw-pointer tensor construction with standard PyTorch views, directly eliminating the documented race condition without introducing new correctness issues.

The three backward paths are changed consistently and the new pattern already appears in the non-race-condition paths of the same files, so the approach is well-validated within the codebase. The only open questions — implicit 256-byte per-sub-tensor alignment and the unguarded uniform-weight-shape assumption — do not affect correctness for current usage.

The three Python backward files deserve a second look to confirm the alignment and uniform-shape assumptions hold for all models in the project's test suite.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/csrc/extensions/allocate.cpp | Adds a descriptive comment to bulk_allocate warning about the raw-pointer stream-sync bypass; no functional changes to the C++ code. |
| transformer_engine/pytorch/module/grouped_linear.py | Replaces tex.bulk_allocate with a standard torch.empty packed buffer and index views; fixes the race condition but drops the explicit 256-byte per-tensor alignment. |
| transformer_engine/pytorch/ops/basic/grouped_linear.py | Same tex.bulk_allocate → torch.empty + slice-view pattern; uses pre-computed grouped_shape consistently. |
| transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | Applies the same fix in the fused MLP backward path; the wgrad_packed local correctly keeps storage alive via the view references in w_list. |

Sequence Diagram

sequenceDiagram
    participant BW as Backward Pass
    participant OLD as tex.bulk_allocate (old)
    participant NEW as torch.empty (new)
    participant CUDA as CUDA Allocator
    participant GEMM as GEMM Kernel

    Note over BW,GEMM: OLD path (race condition possible)
    BW->>OLD: "bulk_allocate([shape]*N, dtype, device, [256]*N)"
    OLD->>CUDA: at::empty(raw_bytes) [stream NOT recorded]
    OLD->>OLD: at::from_blob(ptr+offset, shape) [bypasses stream tracking]
    OLD-->>BW: list of N tensors
    BW->>GEMM: wgrad GEMM on stream S1
    Note over GEMM,CUDA: Another stream S2 may read stale data — race!

    Note over BW,GEMM: NEW path (race condition fixed)
    BW->>NEW: "torch.empty(N, *weight_shape, dtype, device)"
    NEW->>CUDA: standard allocation [stream S1 recorded]
    NEW-->>BW: "packed tensor (N, *weight_shape)"
    BW->>BW: "wgrad_list = [packed[i] for i in range(N)]"
    BW->>GEMM: wgrad GEMM on stream S1
    Note over GEMM,CUDA: Stream recorded → sync enforced automatically

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment thread transformer_engine/pytorch/module/grouped_linear.py
Comment on lines +1396 to +1401
grad_weights_packed = torch.empty(
    grouped_shape,
    dtype=ctx.dtype,
    device=device,
)
grad_weights = [grad_weights_packed[i] for i in range(num_groups)]
Contributor

P2 Per-tensor 256-byte alignment no longer guaranteed

tex.bulk_allocate(..., [256] * num_groups) guaranteed each wgrad tensor started on a 256-byte boundary. With torch.empty(grouped_shape, ...), PyTorch aligns the base allocation but sub-tensors grad_weights_packed[i] land at byte offset i * out_features * in_features * element_size. That offset is a multiple of 256 bytes only when out_features * in_features * element_size % 256 == 0. For typical large weight matrices this holds, but it is not enforced. If a future model uses an unusual hidden dimension the misalignment could cause a silent performance regression in the GEMM kernel. The same pattern appears in backward_grouped_mlp.py and module/grouped_linear.py.

Member

That is true; a check should be added to confirm that the sizes are right.
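One possible shape for such a check, as a sketch only: the function names and the `alignment` parameter below are hypothetical and not part of Transformer Engine. It checks byte offsets relative to the start of the packed buffer and assumes the base allocation itself is suitably aligned.

```python
import math


def view_byte_offsets(num_groups, weight_shape, element_size):
    """Byte offset of each packed[i] view relative to the packed buffer's base."""
    stride_bytes = math.prod(weight_shape) * element_size
    return [i * stride_bytes for i in range(num_groups)]


def check_view_alignment(num_groups, weight_shape, element_size, alignment=256):
    """Raise ValueError if any view offset is not a multiple of `alignment` bytes."""
    for i, offset in enumerate(view_byte_offsets(num_groups, weight_shape, element_size)):
        if offset % alignment != 0:
            raise ValueError(
                f"wgrad view {i} starts at byte offset {offset}, "
                f"which is not a multiple of {alignment}"
            )


# A (4096, 4096) weight with 2-byte elements: every offset is 256-byte aligned.
check_view_alignment(4, (4096, 4096), element_size=2)

# A (100, 33) weight with 2-byte elements: 6600 bytes per group, so view 1 is misaligned.
try:
    check_view_alignment(2, (100, 33), element_size=2)
except ValueError as err:
    print(err)
```

Whether a failed check should raise, warn, or fall back to a padded layout is a design decision this sketch does not try to settle.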

Member

@cspades cspades left a comment

Labels

2.16.0, bug, cpu_overhead

4 participants