Conversation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
…e global_mesh_resource().fsdp_resource (NVIDIA#2088) * Enforce global MeshResource is set Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * Use global_mesh_resource().fsdp_resource in gemm primitive Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * Update tests Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * Update gemm.py Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * Update test_layer.py Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…A#2092) Avoid garbage collection when capturing a CUDA Graph Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Fix incorrect version checks for atomic GEMM Signed-off-by: Tim Moon <tmoon@nvidia.com> * Fix typo Signed-off-by: Tim Moon <tmoon@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon <tmoon@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* added cp strategy arg to DPA api Signed-off-by: Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com> * converted DPA cp_strategy to string Signed-off-by: Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com> --------- Signed-off-by: Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com>
* Return dummy wgrad tensors when requested by Mcore Signed-off-by: Tim Moon <tmoon@nvidia.com> * Apply suggestions from code review Co-authored-by: Jan Bielak <janekb04@icloud.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Tim Moon <tmoon@nvidia.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Jan Bielak <janekb04@icloud.com>
* added shardy warning Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Revert "[Common] PDL for Quantization Kernels (NVIDIA#2001)" This reverts commit bfab8c6. Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
* Bump cuDNN FE to 1.14.0 Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Change submodule hash Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Pick up a cuDNN FE fix Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * New model configs in tests Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Exclude cuDNN backend for some configs Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> --------- Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Revert "[Common] PDL for Blockwise Quantization (NVIDIA#2066)" This reverts commit ebca615. Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…#2083) * code drop Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> --------- Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Saving... Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Saving... Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Test fixure Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Saving... Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Saving... Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Fix axes Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Refactor Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Refactor & fixes Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Saving... Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Gemm-RS Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Gemm-AR, not working... 
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Fixes Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Tweak tolerance Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * First shot at fp8 Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * More test configs Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Support comm_sm_count Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Tweak scaling Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Amax ptr Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Flags parity with cublas_gemm, saving... Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Cleanup Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Bias tests Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Fix bias test Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Aux, saving... 
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * aux_ld Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * A fix Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Use test::Tensor Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Set scale inv Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Tweak tests Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * More test config Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Fix merge fallout Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Fix nvshmem build Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Fix incomplete 
libcal removal Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Remove leftover code Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Remove now unused argument Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com> Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> --------- Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov <vcherepanov@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
…kv caching (NVIDIA#2121) Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
* disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
dev-base initialize: with version change and log verify
feat: pure nvshmem comm based cp
… replace natural cpp nvshmem code
Greptile Summary

This PR introduces two ~4000-line Python modules that replace point-to-point distributed collectives with NVSHMEM-based one-sided get operations for KV exchange. Critical issues identified:
Confidence Score: 1/5

Last reviewed commit: f63d3c3
```python
def _store_fa_nvshmem(out_tensor, softmax_tensor):
    # allocate symmetric tensors lazily and copy
    if not causal:
        return
    try:
        nvshmem_fa_out = [
            tex.create_nvshmem_tensor(list(out_tensor.shape), out_tensor.dtype)
            for _ in range(cp_size)
        ]
        nvshmem_fa_softmax_lse = [
            tex.create_nvshmem_tensor(list(softmax_tensor.shape), softmax_tensor.dtype)
            for _ in range(cp_size)
        ]
    except Exception:
        nvshmem_fa_out = [None for _ in range(cp_size)]
        nvshmem_fa_softmax_lse = [None for _ in range(cp_size)]
    # if nvshmem_fa_out is not None:
    #     for idx in range(cp_size):
    #         if nvshmem_fa_out[idx] is not None:
    #             nvshmem_fa_out[idx].copy_(out_tensor)
    #             nvshmem_fa_softmax_lse[idx].copy_(softmax_tensor)

# create two streams to resolve wave quantization issue of Flash Attn in each step
flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
# synchronize fwd results correction across steps
fwd_results_correction_done = torch.cuda.Event()
```
Function signature mismatch causes TypeError at runtime
The function _store_fa_nvshmem is defined with 2 parameters but called with 3 positional arguments at lines 932 and 2999:
```python
# Definition (line 699): 2 params
def _store_fa_nvshmem(out_tensor, softmax_tensor):
    ...

# Call sites: 3 args
_store_fa_nvshmem(i, out_per_step[i], softmax_lse_per_step[i])  # lines 932, 2999
```

This will raise `TypeError: _store_fa_nvshmem() takes 2 positional arguments but 3 were given` at runtime.
The commented-out version at line 682 had the correct 3-parameter signature. Restore it with the idx parameter:
Suggested change (restore the `idx` parameter and per-index lazy allocation):

```python
def _store_fa_nvshmem(idx, out_tensor, softmax_tensor):
    # allocate symmetric tensors lazily and copy
    if not causal:
        return
    try:
        nvshmem_fa_out[idx] = tex.create_nvshmem_tensor(list(out_tensor.shape), out_tensor.dtype)
        nvshmem_fa_softmax_lse[idx] = tex.create_nvshmem_tensor(
            list(softmax_tensor.shape), softmax_tensor.dtype
        )
    except Exception:
        nvshmem_fa_out[idx] = None
        nvshmem_fa_softmax_lse[idx] = None
    if nvshmem_fa_out[idx] is not None:
        nvshmem_fa_out[idx].copy_(out_tensor)
        nvshmem_fa_softmax_lse[idx].copy_(softmax_tensor)
```
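The mismatch is easy to confirm in isolation; a minimal standalone repro (the stand-in function name `_store` is hypothetical) produces the same error the reviewer predicts:

```python
# Hypothetical stand-in for the 2-parameter definition of _store_fa_nvshmem.
def _store(out_tensor, softmax_tensor):
    return (out_tensor, softmax_tensor)

# Mirror the 3-argument call sites (lines 932 and 2999 in the PR).
try:
    _store(0, "out_per_step", "softmax_lse_per_step")
    raised = False
except TypeError as exc:
    raised = True
    message = str(exc)

assert raised
assert "takes 2 positional arguments but 3 were given" in message
```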
```python
def backward(ctx, dout):
    # pylint: disable=missing-function-docstring
    nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
    cp_size_a2a = ctx.cp_size_a2a
    rank_a2a = ctx.rank_a2a

    cp_size = get_distributed_world_size(ctx.cp_group)
    rank = get_distributed_rank(ctx.cp_group)
    send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
    recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
    device_compute_capability = get_device_compute_capability()
    batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (
        device_compute_capability < (10, 0) and cp_size == 2
    )

    q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
        restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
    )
    cu_seqlens_q_per_step = other_tensors[:cp_size]
    cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2]
    rng_states = other_tensors[cp_size * 2 : cp_size * 3]
```
Undefined variable cp_global_ranks in backward
The backward function accesses cp_global_ranks at lines 1805 and 1835 as a bare local variable, but this variable is never assigned locally. The code only saves and retrieves ctx.cp_global_ranks. At runtime, this will raise NameError: name 'cp_global_ranks' is not defined.
```python
# Lines 1598-1599: correct usage with ctx prefix
send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]

# Lines 1805, 1835: incorrect usage without ctx prefix
owner_global = cp_global_ranks[owner_idx * cp_size_a2a + rank_a2a]  # NameError
```

Add this assignment near the top of the backward method alongside the other context unpacking:
Suggested change (extend the context unpacking at the top of `backward`):

```python
cp_size_a2a = ctx.cp_size_a2a
rank_a2a = ctx.rank_a2a
cp_global_ranks = ctx.cp_global_ranks
```
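The difference between the attribute lookup and the bare local can also be reproduced standalone (minimal stand-in `Ctx` class, names hypothetical):

```python
# Minimal stand-in for the autograd ctx object.
class Ctx:
    pass

ctx = Ctx()
ctx.cp_global_ranks = [0, 1, 2, 3]

def backward_fixed(ctx):
    # unpack once, as the suggestion does; later code uses the local name safely
    cp_global_ranks = ctx.cp_global_ranks
    return cp_global_ranks[1]

def backward_buggy(ctx):
    return cp_global_ranks[1]  # bare name never assigned locally: NameError

assert backward_fixed(ctx) == 1
try:
    backward_buggy(ctx)
    failed = False
except NameError:
    failed = True
assert failed
```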
```cpp
#else
  NVTE_ERROR(
      "Internal TE error: nvshmem_get_on_current_stream cannot be initialized with valid PyTorch ",
      "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
```
Inverted error message in #else branch
The #else branch executes when NVTE_ENABLE_NVSHMEM is not defined, but the error message incorrectly states the function "cannot be initialized… when TE is compiled with NVTE_ENABLE_NVSHMEM=1" — the exact opposite condition.
Suggested change (state the actual requirement):

```cpp
  NVTE_ERROR(
      "Internal TE error: nvshmem_get_on_current_stream requires TE to be compiled with ",
      "NVTE_ENABLE_NVSHMEM=1!");
```
```python
with torch.cuda.stream(flash_attn_streams[i % 2]):
    # wait until KV is received
    for req in send_recv_reqs[(i + 1) % 2]:
        req.wait()

    if i < (cp_size - 1):
        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
        if nvshmem_kv is not None:
            # Use NVSHMEM get: compute owner of the (i+1)-th step KV block
            owner_idx = (rank - (i + 1)) % cp_size
            # Map owner idx to global rank (accounting for a2a groups)
            owner_global = cp_global_ranks[owner_idx * cp_size_a2a + rank_a2a]
            # nvshmem_get: dst (local buffer), src (symmetric address), peer=owner_global
            tex.nvshmem_get_on_current_stream(
                p2p_comm_buffers[i + 1], nvshmem_kv, int(owner_global)
            )
        else:
            # fallback to P2P if NVSHMEM not available
            send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                rank,
                p2p_comm_buffers[i],
                send_dst,
                p2p_comm_buffers[i + 1],
                recv_src,
                cp_group,
                batch_p2p_comm,
            )
```
Missing cross-stream synchronization for NVSHMEM gets
In the forward loop, the NVSHMEM get for p2p_comm_buffers[i+1] is issued on stream flash_attn_streams[i % 2] (line 765). In the next iteration, the computation consumes p2p_comm_buffers[i+1] on stream flash_attn_streams[(i+1) % 2] — a different stream — with no event-based synchronization.
The existing P2P fallback updates send_recv_reqs with requests that are waited on (line 754-755), but when the NVSHMEM path is taken, send_recv_reqs[(i) % 2] remains empty, so the wait is a no-op:
```python
# Line 765: NVSHMEM get on stream[i % 2]
tex.nvshmem_get_on_current_stream(p2p_comm_buffers[i + 1], nvshmem_kv, int(owner_global))
# send_recv_reqs[i % 2] is never populated with a request

# Next iteration (i+1):
# Lines 754-755: wait is a no-op because send_recv_reqs[(i+1) % 2] is empty
for req in send_recv_reqs[(i + 1) % 2]:
    req.wait()
# Lines 779-782: computation accesses p2p_comm_buffers[i+1] on stream[(i+1) % 2]
```

Without synchronization, the computation kernel on stream `[(i+1) % 2]` may observe stale data. Add CUDA event-based synchronization after the NVSHMEM get and before subsequent iterations use the buffer.
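One standard way to add that ordering is the record/wait event pattern: record an event on the stream that issued the get, and have the consuming stream wait on it before launching attention. Real code would use `torch.cuda.Event().record(stream)` and `stream.wait_event(event)`; the sketch below illustrates the same contract with a stdlib mock (`MockStream`/`MockEvent` are illustrative names, not TE APIs):

```python
# Illustrative stdlib mock of the CUDA record/wait contract. Each stream
# keeps an ordered log of the work enqueued on it.
class MockEvent:
    def __init__(self):
        self.recorded = False

class MockStream:
    def __init__(self, name):
        self.name = name
        self.log = []

    def launch(self, op):
        self.log.append(("launch", op))

    def record_event(self, event):
        # mark that all work enqueued so far on this stream precedes the event
        event.recorded = True
        self.log.append(("record", event))

    def wait_event(self, event):
        # work enqueued after this point must not run until the event fires
        self.log.append(("wait", event))

comm = MockStream("flash_attn_streams[i % 2]")
compute = MockStream("flash_attn_streams[(i+1) % 2]")

kv_ready = MockEvent()
comm.launch("nvshmem_get -> p2p_comm_buffers[i+1]")
comm.record_event(kv_ready)    # completion marker for the one-sided get
compute.wait_event(kv_ready)   # cross-stream barrier before consuming
compute.launch("flash_attn(p2p_comm_buffers[i+1])")

# the wait is enqueued on the consumer stream before any use of the buffer
assert compute.log[0][0] == "wait" and kv_ready.recorded
```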
```python
def torchrun_uid_init_bcast_object_no_reinit(cp_group=None):
    local_rank = torch.cuda.current_device()
    dev = Device(local_rank)
    dev.set_current()

    if cp_group is None:
        rank_id = dist.get_rank()
        num_ranks = dist.get_world_size()
    else:
        rank_id = dist.get_rank(group=cp_group)
        num_ranks = dist.get_world_size(group=cp_group)

    uniqueid = nvshmem.get_unique_id(empty=True)

    if rank_id == 0:
        uniqueid = nvshmem.get_unique_id()
        broadcast_objects = [uniqueid]
    else:
        broadcast_objects = [None]

    dist.broadcast_object_list(broadcast_objects, src=0, group=cp_group)

    dist.barrier(group=cp_group)

    nvshmem.init(
        device=dev,
        uid=broadcast_objects[0],
        rank=rank_id,
        nranks=num_ranks,
        initializer_method="uid",
    )

    return True
```
NVSHMEM initialized unconditionally on every forward pass
Despite its name torchrun_uid_init_bcast_object_no_reinit, this function calls nvshmem.init(...) unconditionally (line 457) with no guard against re-initialization. The function is invoked inside the forward method (line 730) on every forward pass. Calling nvshmem.init() multiple times may cause undefined behavior or errors depending on the NVSHMEM pybinding version.
Move NVSHMEM initialization to module-load time or wrap the call in a one-time initialization guard:
```python
_nvshmem_initialized = False

def _maybe_init_nvshmem(cp_group):
    global _nvshmem_initialized
    if _nvshmem_initialized:
        return
    torchrun_uid_init_bcast_object_no_reinit(cp_group)
    _nvshmem_initialized = True
```

Then call `_maybe_init_nvshmem(cp_group)` instead of `torchrun_uid_init_bcast_object_no_reinit(cp_group)` at line 730.
```python
import torch.distributed as dist
from cuda.core.experimental import Device
from cuda.core.experimental import Stream
from nvshmem.core.interop.torch import tensor_get_buffer
```
Unused import tensor_get_buffer
tensor_get_buffer is imported at line 14 but is not referenced anywhere else in the module. This creates an unnecessary dependency and will trigger linting warnings.
Remove this unused import:
Suggested change: delete the line

```python
from nvshmem.core.interop.torch import tensor_get_buffer
```
Description
To overlap computation with communication, we create a dedicated stream for communication.
To make the NVSHMEM APIs easier to use and the code cleaner, we replace the original C++ code with the NVSHMEM Python bindings.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: