
Feat/cp nvshmem enhanced #2737

Open

Knight-of-Thunder wants to merge 21 commits into NVIDIA:main from ETOgaosion:feat/cp_nvshmem_enhanced

Conversation

@Knight-of-Thunder

Description

To overlap computation with communication, we create a dedicated CUDA stream for communication.
To make the NVSHMEM APIs easier to use and keep the code cleaner, we replace the original C++ code with the NVSHMEM Python bindings.
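
A minimal sketch of the intended overlap pattern (editor illustration, not the PR implementation): the KV exchange is issued as a one-sided NVSHMEM get on a dedicated stream while attention runs on the current stream, with a CUDA event guarding the hand-off. The tex.nvshmem_get_on_current_stream binding is the one added by this PR; the helper name, comm_stream, and the buffer arguments are assumptions for illustration.

import torch
import transformer_engine_torch as tex  # assumed import path for the TE extension

comm_stream = torch.cuda.Stream()  # dedicated communication stream

def exchange_and_compute(recv_buf, kv_symmetric, peer_rank, compute_fn, *compute_args):
    """Overlap a one-sided KV fetch with attention compute (illustrative only)."""
    kv_ready = torch.cuda.Event()
    with torch.cuda.stream(comm_stream):
        # pull the peer's KV block from symmetric memory into a local buffer
        tex.nvshmem_get_on_current_stream(recv_buf, kv_symmetric, int(peer_rank))
        kv_ready.record()
    out = compute_fn(*compute_args)  # runs on the current stream, overlapping the get
    # make the current stream wait before the next step consumes recv_buf
    torch.cuda.current_stream().wait_event(kv_ready)
    return out, recv_buf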

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

ptrendx and others added 21 commits August 18, 2025 16:24
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
…e global_mesh_resource().fsdp_resource (NVIDIA#2088)

* Enforce global MeshResource is set

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Use global_mesh_resource().fsdp_resource in gemm primitive

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update tests

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update gemm.py

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update test_layer.py

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…A#2092)

Avoid garbage collection when capturing a CUDA Graph

Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Fix incorrect version checks for atomic GEMM

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix typo

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* added cp strategy arg to DPA api

Signed-off-by: Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com>

* converted DPA cp_strategy to string

Signed-off-by: Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com>

---------

Signed-off-by: Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com>
* Return dummy wgrad tensors when requested by Mcore

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Apply suggestions from code review

Co-authored-by: Jan Bielak <janekb04@icloud.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Jan Bielak <janekb04@icloud.com>
* added shardy warning

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>


---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Revert "[Common] PDL for Quantization Kernels (NVIDIA#2001)"

This reverts commit bfab8c6.

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
* Bump cuDNN FE to 1.14.0

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Change submodule hash

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Pick up a cuDNN FE fix

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* New model configs in tests

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Exclude cuDNN backend for some configs

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

---------

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Revert "[Common] PDL for Blockwise Quantization (NVIDIA#2066)"

This reverts commit ebca615.

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…#2083)

* code drop

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Pick up cuBLASMp during build

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Change lib order to fix link error

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Context creation, incomplete...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Test fixure

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* A sanity AgGemm test, failing...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix axes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Take care of uneven distribution

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use MPI to get position of local matrices

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor & fixes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-RS

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-AR, not working...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fixes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Setting all-reduce epilogue for gemm-ar

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use supported shapes for GEMM-AR

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tolerance

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* First shot at fp8

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use TensorHolder in tests

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More test configs

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Support comm_sm_count

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Parametrize dtypes for A, B and D separately

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak scaling

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Amax ptr

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Flags parity with cublas_gemm, saving...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Cleanup

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Bias tests

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix bias test

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Aux, saving...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* aux_ld

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* A fix

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use test::Tensor

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Set scale inv

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove unsupported test configs

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tests

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Replace libcal with NCCL

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add NVTX markers to API functions

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak GemmAr tests

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More test config

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix merge fallout

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove MPI dependency, comment API, add algo parameter

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem dependency

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem build

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Excluse CommGemm tests from L0_cppunittest

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add cpp_distributed sh file for CI

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Adapt tp TensorAllocator

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Skip GemmAr test on unsupported HW

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Oversibscribe is needed on some clusters

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix incomplete libcal removal

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Move CI tests to L1

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Rename context to include NVTE prefix

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove leftover code

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* NVTE_WITH_CUBLASMP off by default

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed NVTE_CHECK diag

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Comment API

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Include stdbool header for legacy C compilers

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove now unused argument

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Abstract away cuBLASMp algo behind our own enum

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed shape diag messages

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/common/include/transformer_engine/comm_gemm.h

Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com>

* Add license

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

---------

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com>
Co-authored-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
…kv caching (NVIDIA#2121)

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
* disable determinism for sm100+ and cudnn<9.14

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix remaining CI failures

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert some changes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert more changes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove sm100 from determinism table

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
dev-base initialize: with version change and log verify
@greptile-apps

greptile-apps bot commented Mar 5, 2026

Greptile Summary

This PR introduces two ~4000-line Python modules that replace point-to-point distributed collectives with NVSHMEM-based one-sided get operations for KV exchange, adds nvshmem_get_on_current_stream C++ binding, and brings in a new comm_gemm module using cuBLASMp. Several ancillary changes update FlashAttention version caps and cuDNN thresholds.

Critical issues identified:

  1. context_parallel_nvshmem_enhanced.py: _store_fa_nvshmem is defined with 2 parameters but called with 3 at lines 932 and 2999 — this raises TypeError at runtime.

  2. context_parallel_nvshmem_enhanced.py: The backward pass uses bare variable cp_global_ranks at lines 1805 and 1835, which is never assigned locally; only ctx.cp_global_ranks is persisted. This raises NameError on backward.

  3. context_parallel_nvshmem_enhanced.py: NVSHMEM get operations (line 765) are issued on flash_attn_streams[i % 2], but data is consumed in the next iteration on flash_attn_streams[(i+1) % 2] — a different stream — with no cross-stream CUDA event synchronization. The send_recv_reqs list stays [[], []] on the NVSHMEM path, bypassing the intended wait. This risks silent data corruption.

  4. nvshmem_comm.cpp: The #else error message (line 100) incorrectly states the function "cannot be initialized when TE is compiled with NVTE_ENABLE_NVSHMEM=1" — the exact opposite of the true condition.

  5. context_parallel_nvshmem.py: torchrun_uid_init_bcast_object_no_reinit calls nvshmem.init() unconditionally on every forward pass (line 730); repeated initialization can cause undefined behavior.

  6. context_parallel_nvshmem.py: tensor_get_buffer is imported at line 14 but never used.

Confidence Score: 1/5

  • Critical — not safe to merge. Two new modules contain runtime crashes (TypeError, NameError) and a cross-stream data-race that produces silently incorrect results.
  • Three separate code paths will fail before end-to-end testing can pass: (1) wrong-arity call to _store_fa_nvshmem raises TypeError, (2) undefined cp_global_ranks in backward raises NameError, and (3) missing cross-stream synchronization between NVSHMEM get stream and FlashAttention compute stream creates a data-race that can silently corrupt gradients. Additionally, NVSHMEM is re-initialized on every forward pass with undefined behavior, and an error message is logically inverted.
  • context_parallel_nvshmem_enhanced.py (critical), context_parallel_nvshmem.py (critical), nvshmem_comm.cpp (logic fix required)

Last reviewed commit: f63d3c3
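
For orientation, a minimal sketch of the one-sided KV-exchange pattern the summary describes, assuming the tex.create_nvshmem_tensor and tex.nvshmem_get_on_current_stream bindings introduced by this PR; the function name and variables below are illustrative, not the module's actual code.

import torch
import torch.distributed as dist
import transformer_engine_torch as tex  # assumed import path for the TE extension

def fetch_kv_block(kv_local, step, rank, cp_size, cp_global_ranks):
    # Publish the local KV block in NVSHMEM symmetric memory (collective allocation,
    # so every rank in the CP group must call this).
    kv_symmetric = tex.create_nvshmem_tensor(list(kv_local.shape), kv_local.dtype)
    kv_symmetric.copy_(kv_local)
    # In practice a device-side sync plus a group barrier is needed here so that
    # peers do not read kv_symmetric before the copy above has completed.
    torch.cuda.current_stream().synchronize()
    dist.barrier()
    # Owner of the (step+1)-th KV block under the ring schedule used by the PR.
    owner_idx = (rank - (step + 1)) % cp_size
    owner_global = cp_global_ranks[owner_idx]
    # One-sided get: pull the owner's symmetric buffer into a local receive buffer,
    # replacing the previous point-to-point send/recv pair.
    recv_buf = torch.empty_like(kv_local)
    tex.nvshmem_get_on_current_stream(recv_buf, kv_symmetric, int(owner_global))
    return recv_buf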

Comment on lines +699 to +720
def _store_fa_nvshmem(out_tensor, softmax_tensor):
    # allocate symmetric tensors lazily and copy
    if not causal:
        return
    try:
        nvshmem_fa_out = [tex.create_nvshmem_tensor(list(out_tensor.shape), out_tensor.dtype) for _ in range(cp_size)]
        nvshmem_fa_softmax_lse = [tex.create_nvshmem_tensor(
            list(softmax_tensor.shape), softmax_tensor.dtype
        ) for _ in range(cp_size)]
    except Exception:
        nvshmem_fa_out = [None for _ in range(cp_size)]
        nvshmem_fa_softmax_lse = [None for _ in range(cp_size)]
    # if nvshmem_fa_out is not None:
    #     for idx in range(cp_size):
    #         if nvshmem_fa_out[idx] is not None:
    #             nvshmem_fa_out[idx].copy_(out_tensor)
    #             nvshmem_fa_softmax_lse[idx].copy_(softmax_tensor)

# create two streams to resolve wave quantization issue of Flash Attn in each step
flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
# synchronize fwd results correction across steps
fwd_results_correction_done = torch.cuda.Event()

Function signature mismatch causes TypeError at runtime

The function _store_fa_nvshmem is defined with 2 parameters but called with 3 positional arguments at lines 932 and 2999:

# Definition (line 699): 2 params
def _store_fa_nvshmem(out_tensor, softmax_tensor):
    ...

# Call sites: 3 args
_store_fa_nvshmem(i, out_per_step[i], softmax_lse_per_step[i])  # lines 932, 2999

This will raise TypeError: _store_fa_nvshmem() takes 2 positional arguments but 3 were given at runtime.

The commented-out version at line 682 had the correct 3-parameter signature. Restore it with the idx parameter:

Suggested change

# Current:
def _store_fa_nvshmem(out_tensor, softmax_tensor):
    # allocate symmetric tensors lazily and copy
    if not causal:
        return
    try:
        nvshmem_fa_out = [tex.create_nvshmem_tensor(list(out_tensor.shape), out_tensor.dtype) for _ in range(cp_size)]
        nvshmem_fa_softmax_lse = [tex.create_nvshmem_tensor(
            list(softmax_tensor.shape), softmax_tensor.dtype
        ) for _ in range(cp_size)]
    except Exception:
        nvshmem_fa_out = [None for _ in range(cp_size)]
        nvshmem_fa_softmax_lse = [None for _ in range(cp_size)]
    # if nvshmem_fa_out is not None:
    #     for idx in range(cp_size):
    #         if nvshmem_fa_out[idx] is not None:
    #             nvshmem_fa_out[idx].copy_(out_tensor)
    #             nvshmem_fa_softmax_lse[idx].copy_(softmax_tensor)

# create two streams to resolve wave quantization issue of Flash Attn in each step
flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
# synchronize fwd results correction across steps
fwd_results_correction_done = torch.cuda.Event()

# Suggested:
def _store_fa_nvshmem(idx, out_tensor, softmax_tensor):
    # allocate symmetric tensors lazily and copy
    if not causal:
        return
    try:
        nvshmem_fa_out[idx] = tex.create_nvshmem_tensor(list(out_tensor.shape), out_tensor.dtype)
        nvshmem_fa_softmax_lse[idx] = tex.create_nvshmem_tensor(
            list(softmax_tensor.shape), softmax_tensor.dtype
        )
    except Exception:
        nvshmem_fa_out[idx] = None
        nvshmem_fa_softmax_lse[idx] = None
    if nvshmem_fa_out[idx] is not None:
        nvshmem_fa_out[idx].copy_(out_tensor)
        nvshmem_fa_softmax_lse[idx].copy_(softmax_tensor)

Comment on lines +1590 to +1610
def backward(ctx, dout):
    # pylint: disable=missing-function-docstring
    nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
    cp_size_a2a = ctx.cp_size_a2a
    rank_a2a = ctx.rank_a2a

    cp_size = get_distributed_world_size(ctx.cp_group)
    rank = get_distributed_rank(ctx.cp_group)
    send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
    recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
    device_compute_capability = get_device_compute_capability()
    batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (
        device_compute_capability < (10, 0) and cp_size == 2
    )

    q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
        restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
    )
    cu_seqlens_q_per_step = other_tensors[:cp_size]
    cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2]
    rng_states = other_tensors[cp_size * 2 : cp_size * 3]

Undefined variable cp_global_ranks in backward

The backward function accesses cp_global_ranks at lines 1805 and 1835 as a bare local variable, but this variable is never assigned locally. The code only saves and retrieves ctx.cp_global_ranks. At runtime, this will raise NameError: name 'cp_global_ranks' is not defined.

# Line 1598-1599: correct usage with ctx prefix
send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]

# Line 1805, 1835: incorrect usage without ctx prefix
owner_global = cp_global_ranks[owner_idx * cp_size_a2a + rank_a2a]  # NameError

Add this assignment near the top of the backward method alongside the other context unpacking:

Suggested change

# Current:
def backward(ctx, dout):
    # pylint: disable=missing-function-docstring
    nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
    cp_size_a2a = ctx.cp_size_a2a
    rank_a2a = ctx.rank_a2a
    cp_size = get_distributed_world_size(ctx.cp_group)
    rank = get_distributed_rank(ctx.cp_group)
    send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
    recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
    device_compute_capability = get_device_compute_capability()
    batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (
        device_compute_capability < (10, 0) and cp_size == 2
    )
    q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
        restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
    )
    cu_seqlens_q_per_step = other_tensors[:cp_size]
    cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2]
    rng_states = other_tensors[cp_size * 2 : cp_size * 3]

# Suggested (unpack cp_global_ranks alongside the other context fields):
    cp_size_a2a = ctx.cp_size_a2a
    rank_a2a = ctx.rank_a2a
    cp_global_ranks = ctx.cp_global_ranks

Comment on lines +98 to +102
#else
  NVTE_ERROR(
      "Internal TE error: nvshmem_get_on_current_stream cannot be initialized with valid PyTorch ",
      "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif

Inverted error message in #else branch

The #else branch executes when NVTE_ENABLE_NVSHMEM is not defined, but the error message incorrectly states the function "cannot be initialized… when TE is compiled with NVTE_ENABLE_NVSHMEM=1" — the exact opposite condition.

Suggested change

// Current:
#else
  NVTE_ERROR(
      "Internal TE error: nvshmem_get_on_current_stream cannot be initialized with valid PyTorch ",
      "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif

// Suggested:
  NVTE_ERROR(
      "Internal TE error: nvshmem_get_on_current_stream requires TE to be compiled with ",
      "NVTE_ENABLE_NVSHMEM=1!");

Comment on lines +752 to +776
with torch.cuda.stream(flash_attn_streams[i % 2]):
    # wait until KV is received
    for req in send_recv_reqs[(i + 1) % 2]:
        req.wait()

    if i < (cp_size - 1):
        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
        if nvshmem_kv is not None:
            # Use NVSHMEM get: compute owner of the (i+1)-th step KV block
            owner_idx = (rank - (i + 1)) % cp_size
            # Map owner idx to global rank (accounting for a2a groups)
            owner_global = cp_global_ranks[owner_idx * cp_size_a2a + rank_a2a]
            # nvshmem_get: dst (local buffer), src (symmetric address), peer=owner_global
            tex.nvshmem_get_on_current_stream(p2p_comm_buffers[i + 1], nvshmem_kv, int(owner_global))
        else:
            # fallback to P2P if NVSHMEM not available
            send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                rank,
                p2p_comm_buffers[i],
                send_dst,
                p2p_comm_buffers[i + 1],
                recv_src,
                cp_group,
                batch_p2p_comm,
            )

Missing cross-stream synchronization for NVSHMEM gets

In the forward loop, the NVSHMEM get for p2p_comm_buffers[i+1] is issued on stream flash_attn_streams[i % 2] (line 765). In the next iteration, the computation consumes p2p_comm_buffers[i+1] on stream flash_attn_streams[(i+1) % 2] — a different stream — with no event-based synchronization.

The existing P2P fallback updates send_recv_reqs with requests that are waited on (line 754-755), but when the NVSHMEM path is taken, send_recv_reqs[(i) % 2] remains empty, so the wait is a no-op:

# Line 765: NVSHMEM get on stream[i % 2]
tex.nvshmem_get_on_current_stream(p2p_comm_buffers[i + 1], nvshmem_kv, int(owner_global))
# send_recv_reqs[(i) % 2] is never populated with a request

# Next iteration (i+1):
# Line 754-755: wait is a no-op because send_recv_reqs[(i+1) % 2] is empty
for req in send_recv_reqs[(i + 1) % 2]:
    req.wait()
# Line 779-782: computation accesses p2p_comm_buffers[i+1] on stream[(i+1) % 2]

Without synchronization, the computation kernel on stream [(i+1) % 2] may observe stale data. Add CUDA event-based synchronization after the NVSHMEM get and before subsequent iterations use the buffer.
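
A sketch of the event-based fix suggested above, reusing the variables from the excerpt (flash_attn_streams, p2p_comm_buffers, nvshmem_kv, owner_global); the nvshmem_kv_ready event list is a hypothetical addition for illustration, not code from the PR.

# Hypothetical per-step events guarding the NVSHMEM gets (not part of the PR).
nvshmem_kv_ready = [torch.cuda.Event() for _ in range(cp_size)]

with torch.cuda.stream(flash_attn_streams[i % 2]):
    tex.nvshmem_get_on_current_stream(p2p_comm_buffers[i + 1], nvshmem_kv, int(owner_global))
    nvshmem_kv_ready[i + 1].record()  # recorded on the stream that issued the get

# In iteration i + 1, before p2p_comm_buffers[i + 1] is consumed on the other stream:
flash_attn_streams[(i + 1) % 2].wait_event(nvshmem_kv_ready[i + 1])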

Comment on lines +429 to +465
def torchrun_uid_init_bcast_object_no_reinit(cp_group=None):
    local_rank = torch.cuda.current_device()
    dev = Device(local_rank)
    dev.set_current()

    if cp_group is None:
        rank_id = dist.get_rank()
        num_ranks = dist.get_world_size()
    else:
        rank_id = dist.get_rank(group=cp_group)
        num_ranks = dist.get_world_size(group=cp_group)

    uniqueid = nvshmem.get_unique_id(empty=True)

    if rank_id == 0:
        uniqueid = nvshmem.get_unique_id()
        broadcast_objects = [uniqueid]
    else:
        broadcast_objects = [None]

    dist.broadcast_object_list(
        broadcast_objects,
        src=0,
        group=cp_group
    )

    dist.barrier(group=cp_group)

    nvshmem.init(
        device=dev,
        uid=broadcast_objects[0],
        rank=rank_id,
        nranks=num_ranks,
        initializer_method="uid"
    )

    return True

NVSHMEM initialized unconditionally on every forward pass

Despite its name torchrun_uid_init_bcast_object_no_reinit, this function calls nvshmem.init(...) unconditionally (line 457) with no guard against re-initialization. The function is invoked inside the forward method (line 730) on every forward pass. Calling nvshmem.init() multiple times may cause undefined behavior or errors depending on the NVSHMEM pybinding version.

Move NVSHMEM initialization to module-load time or wrap the call in a one-time initialization guard:

_nvshmem_initialized = False

def _maybe_init_nvshmem(cp_group):
    global _nvshmem_initialized
    if _nvshmem_initialized:
        return
    torchrun_uid_init_bcast_object_no_reinit(cp_group)
    _nvshmem_initialized = True

Then call _maybe_init_nvshmem(cp_group) instead of torchrun_uid_init_bcast_object_no_reinit(cp_group) at line 730.

import torch.distributed as dist
from cuda.core.experimental import Device
from cuda.core.experimental import Stream
from nvshmem.core.interop.torch import tensor_get_buffer

Unused import tensor_get_buffer

tensor_get_buffer is imported at line 14 but is not referenced anywhere else in the module. This creates an unnecessary dependency and will trigger linting warnings.

Remove this unused import:

Suggested change

# Current:
from nvshmem.core.interop.torch import tensor_get_buffer

# Suggested:
# Removed: from nvshmem.core.interop.torch import tensor_get_buffer

@sbhavani added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) on Mar 5, 2026