Describe the bug
When calling `nki.collectives.all_reduce` (imported as `nkc`; on either `nl.sbuf` or `nl.shared_hbm` tensors) inside an `@nki.jit` kernel traced via `neuronx_distributed.trace.parallel_model_trace(tp_degree=2)` with `NEURON_LOGICAL_NC_CONFIG=2` (LNC=2, the default for trn2), the compiler crashes with:

```
[INTERNAL_ERROR] [NCC_ILLC059] Could not find MemoryLocation named
inst____mp_main__.<kernel_name>-0:<buffer_name>.N on core 1
```

Rank 0's NEFF compiles successfully; the crash occurs while compiling rank 1's NEFF. The pattern is identical for `nl.sbuf` and `nl.shared_hbm` sources, and reproduces at `tp_degree=2` and `tp_degree=4`.
Model Name
First surfaced when optimizing Qwen3-MoE (Qwen3-30B-A3B) for the MLSys competition.
Describe the workload type
Inference (Token Generation). The goal is to fuse the TP all-reduce CC operation directly into an NKI kernel, avoiding HBM materialization before the CC op. The target integration is `NeuronQwen3MoeForCausalLM` with `tp_degree=4` and `NEURON_LOGICAL_NC_CONFIG=2`.
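The module-level reshape only retiles the `(1, HIDDEN)` activation into the `(128, 16)` layout the kernel expects (SBUF tensors use a 128-partition outer dimension); a quick host-side sanity check with plain NumPy (shapes taken from the reproducer below, no Neuron API involved) confirms the retiling round-trips losslessly:

```python
import numpy as np

# Shapes from the reproducer: HIDDEN = PAR * FREE = 128 * 16 = 2048.
HIDDEN, PAR, FREE = 2048, 128, 16

x = np.arange(HIDDEN, dtype=np.float32).reshape(1, HIDDEN)
tile = x.reshape(PAR, FREE)        # layout the kernel sees: 128 partitions x 16
roundtrip = tile.reshape(1, HIDDEN)

assert tile.shape == (PAR, FREE)
assert (roundtrip == x).all()      # retiling is lossless
```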
Instance Type
trn2.3xlarge
Release version
```
aws-neuronx-collectives/unknown,now 2.30.59.0-f5cdefb39 amd64 [installed]
aws-neuronx-dkms/unknown,now 2.26.10.0 all [installed]
aws-neuronx-oci-hook/unknown,now 2.14.102.0 amd64 [installed]
aws-neuronx-runtime-lib/unknown,now 2.30.51.0-faafe26f0 amd64 [installed]
aws-neuronx-tools/unknown,now 2.28.23.0-f1c114a9d amd64 [installed]
libneuronxla 2.2.15515.0+50c26cbd
neuronx-cc 2.23.6484.0+3b612583
neuronx-distributed 0.17.26814+4b18de63
neuronx-distributed-inference 0.8.16251+f3ca5575
torch 2.9.0
torch-neuronx 2.9.0.2.12.22436+0f1dac25
torch-xla 2.9.0
torchvision 0.24.0
transformers 4.57.6
```
Reproduction Steps
Minimal self-contained reproducer — save as `test_ncc_illc059.py` and run on a trn2 instance:

```python
"""
Minimal reproducer for NCC_ILLC059 with nkc.collectives.all_reduce + LNC=2 + TP=2.
Expected: neuronx-cc fails with NCC_ILLC059 when compiling rank 1.
"""
import os, sys

os.environ["NEURON_PLATFORM_TARGET_OVERRIDE"] = "trn2"
os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2"

import torch
import torch.nn as nn
import nki
import nki.isa as nisa
import nki.language as nl
import nki.collectives as nkc
from neuronx_distributed.trace import parallel_model_trace
import tempfile

HIDDEN = 2048
PAR, FREE = 128, 16  # HIDDEN = PAR * FREE


@nki.jit
def sbuf_allreduce_tp2(x):
    """SBUF all-reduce across TP=2 ranks."""
    P, F = x.shape
    sbuf_in = nl.ndarray((P, F), dtype=x.dtype, buffer=nl.sbuf)
    sbuf_out = nl.ndarray((P, F), dtype=x.dtype, buffer=nl.sbuf)
    nisa.dma_copy(dst=sbuf_in, src=x)
    nkc.all_reduce(
        srcs=[sbuf_in], dsts=[sbuf_out],
        replica_group=nkc.ReplicaGroup([[0, 1]]),
        op=nl.add,
    )
    out = nl.ndarray((P, F), dtype=x.dtype, buffer=nl.shared_hbm)
    nisa.dma_copy(dst=out, src=sbuf_out)
    return out


class AllReduceModule(nn.Module):
    def forward(self, x):
        return sbuf_allreduce_tp2(x.reshape(PAR, FREE)).reshape(1, HIDDEN)


def factory():
    m = AllReduceModule()
    m.eval()
    return m, {}


x = torch.randn(1, HIDDEN, dtype=torch.bfloat16)
with tempfile.TemporaryDirectory() as d:
    try:
        parallel_model_trace(
            factory, (x,),
            tp_degree=2,
            compiler_workdir=d,
            compiler_args="--target=trn2 --model-type transformer -O1",
            inline_weights_to_neff=True,
        )
        print("PASS (unexpected)")
    except Exception as e:
        err = str(e)
        if "NCC_ILLC059" in err or "neuronx-cc failed with 70" in err:
            print("FAIL: NCC_ILLC059 reproduced")
        else:
            print(f"FAIL (different error): {err[:300]}")
```
Run with:

```shell
source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
python test_ncc_illc059.py
```

Expected output:

```
.
Compiler status PASS   # rank 0 compiles OK
[INTERNAL_ERROR] [NCC_ILLC059] Could not find MemoryLocation named
inst____mp_main__.sbuf_allreduce_tp2-0:sbuf_in.23 on core 1
FAIL: NCC_ILLC059 reproduced
```
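For triaging which rank and buffer fail across the per-rank logs in `compiler_workdir`, the MemoryLocation name in the message splits cleanly into module prefix, kernel instance, buffer, and core. A hypothetical helper (plain Python, not part of any Neuron tool; the regex is derived only from the message format shown above):

```python
import re

# Hypothetical triage helper: splits the MemoryLocation name in an
# NCC_ILLC059 message into module prefix, kernel instance, buffer, and
# failing core, so failures can be tabulated per rank.
ILLC059_RE = re.compile(
    r"\[NCC_ILLC059\] Could not find MemoryLocation named\s+"
    r"(?P<prefix>[^.]+)\.(?P<kernel>[^:]+):(?P<buffer>\S+) on core (?P<core>\d+)"
)

def parse_illc059(log_text):
    """Return the parsed fields as a dict, or None if the message is absent."""
    m = ILLC059_RE.search(log_text)
    return m.groupdict() if m else None

msg = ("[INTERNAL_ERROR] [NCC_ILLC059] Could not find MemoryLocation named\n"
       "inst____mp_main__.sbuf_allreduce_tp2-0:sbuf_in.23 on core 1 -")
fields = parse_illc059(msg)
# fields["kernel"] == "sbuf_allreduce_tp2-0", fields["buffer"] == "sbuf_in.23",
# fields["core"] == "1"
```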
All three variants reproduce the same NCC_ILLC059:

| Variant | Buffer type | tp_degree | Result |
| --- | --- | --- | --- |
| SBUF all-reduce | `nl.sbuf` | 2 | NCC_ILLC059 on rank 1 |
| HBM all-reduce | `nl.shared_hbm` | 2 | NCC_ILLC059 on rank 1 |
| SBUF all-reduce | `nl.sbuf` | 4 | NCC_ILLC059 on ranks 1–3 |
In all cases rank 0 compiles successfully and ranks 1+ fail.
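Once the compiler bug is fixed, the expected numerics are easy to check against a host-side oracle: with `op=nl.add`, every rank in the replica group should end up holding the elementwise sum of all ranks' inputs. A minimal reference in plain NumPy (standard all-reduce semantics, not Neuron API):

```python
import numpy as np

# Host-side reference for all_reduce(op=add) over replica group [[0, 1]]:
# each rank contributes a (128, 16) tile; every rank's output is the
# elementwise sum of all ranks' inputs.
PAR, FREE = 128, 16
rng = np.random.default_rng(0)
rank_inputs = [rng.standard_normal((PAR, FREE)).astype(np.float32)
               for _ in range(2)]

expected = np.sum(rank_inputs, axis=0)            # what each rank should hold
per_rank_outputs = [expected.copy() for _ in rank_inputs]

# Sanity: all ranks agree and equal the sum of the inputs.
for out in per_rank_outputs:
    np.testing.assert_allclose(out, rank_inputs[0] + rank_inputs[1], rtol=1e-6)
```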
Regression Issue
Possible Solution
No response
Logs/Context/Additional Information
Full compiler error (appears for each failing rank):

```
[INTERNAL_ERROR] [NCC_ILLC059] Could not find MemoryLocation named
inst____mp_main__.sbuf_allreduce_tp2-0:sbuf_in.23 on core 1 -
Please open a support ticket at https://github.com/aws-neuron/aws-neuron-sdk/issues/new.
You may also be able to obtain more information using the 'XLA_IR_DEBUG'
and 'XLA_HLO_DEBUG' environment variables.
Non-signal exit. Backend exited with code 1
```

- Rank 0 always compiles successfully; failure is always on rank 1 (or ranks 1 to N-1 for TP=N).
- NCC_ILLC059 is flagged as an internal compiler error, not a user input validation error.
- Reproduces with both `-O1` and `-O3` compiler flags.
- Reproduces with both `nl.sbuf` and `nl.shared_hbm` buffer types.
- The `nki.collectives` module documentation states SBUF tensors are supported for `all_reduce` (single tensor, non-coalesced), so this appears to be an unimplemented path for the LNC=2 multi-rank case.