
NCC_ITEN404: 7x7 Conv2d backward fails on trn1 (BirCodeGenLoop partition alignment) #1295

@JunjieTang-D1

Description


Environment

  • Instance: trn1.32xlarge
  • Neuron SDK: 2.23 (neuronx-cc 2.23.6484.0+3b612583)
  • PyTorch: 2.8 (torch-neuronx, /opt/aws_neuronx_venv_pytorch_2_8)
  • OS: Ubuntu 22.04 (Deep Learning AMI)

Summary

The backward pass of a 7x7 Conv2d layer (the ResNet-34 stem) fails with an NCC_ITEN404 internal compiler error during NKI kernel code generation. All 3x3 and 1x1 convolutions, strided or not, compile successfully; only the 7x7 kernel triggers this error.

Minimal Reproducer

import torch, torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# This FAILS:
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(),
).to(device).train()

x = torch.randn(1, 3, 256, 1024, device=device, requires_grad=True)
out = model(x)
loss = out.sum()
loss.backward()  # <-- NCC_ITEN404 here
xm.mark_step()

# This PASSES (same structure, 3x3 kernel):
model = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(),
).to(device).train()

x = torch.randn(1, 64, 128, 512, device=device, requires_grad=True)
out = model(x)
loss = out.sum()
loss.backward()  # <-- PASSES
xm.mark_step()

Error Message

[INTERNAL_ERROR] [NCC_ITEN404] Internal tensorizer error: BirCodeGenLoop:
tensorcopy src start_partition(0) or dst start_partition(i_3) is not multiple of partitions_per_bank (32).
tensorcopy: float32<1 x 12288> TongaSB partitions[1] float32 (2, 128, 21609)
%'conv2d_dw_fb01_io01_01bf_rep_nhwc_Pcinh..._a0_img_local_prefetch'

The error occurs in the NKI kernel `conv2d_dw_fb01_io01_01bf_rep_nhwc_Pcinh` during backward-graph compilation. The specific failure is a tensorcopy partition-alignment check (partitions_per_bank = 32).

Systematic Test Results

We tested all conv layer types individually with NEURON_CC_FLAGS="--optlevel=1":

| Conv Type | Input Shape | Backward Compile | Time |
|-----------|-------------|------------------|------|
| 7x7, stride=2, 3→64 | (1, 3, 256, 1024) | FAIL (NCC_ITEN404) | |
| 7x7, stride=2, 3→64 | (1, 3, 64, 64) | FAIL (NCC_ITEN404) | |
| 7x7, stride=2, 1→64 | (1, 1, 256, 256) | FAIL (NCC_ITEN404) | |
| 3x3, stride=1, 64→64 | (1, 64, 128, 512) | PASS | 62.4 s |
| 3x3, stride=1, 256→256 | (1, 256, 16, 64) | PASS | 30.6 s |
| 3x3, stride=1, 512→512 | (1, 512, 8, 32) | PASS | 4.6 s |
| 3x3, stride=2, 128→256 | (1, 128, 32, 128) | PASS | 23.5 s |
| 1x1, 64→128 | (1, 64, 8, 32) | PASS | 1.8 s |
| Full ResNet-34 layer1-4 (no stem) | (1, 64, 128, 512) | PASS | 60.1 s |

The 7x7 conv fails regardless of input size and number of input channels.
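One pattern we noticed (a hypothesis only, not confirmed against the compiler internals): in every passing configuration the contraction dimension Cin × Kh × Kw is a multiple of 32, the partitions_per_bank value from the error, while in every failing one it is not (7×7×3 = 147, 7×7×1 = 49). Incidentally, the 21609 in the error's tensor shape equals 147². A quick check over the tested configurations:

```python
# Hypothesis (unconfirmed): the dw-conv kernel lays out Cin*Kh*Kw along
# SBUF partitions, and the tensorcopy requires a multiple of 32
# (partitions_per_bank from the error message).
PARTITIONS_PER_BANK = 32

configs = [
    # (label, Cin, kernel size, observed backward-compile result)
    ("7x7, 3->64",    3,   7, "FAIL"),
    ("7x7, 1->64",    1,   7, "FAIL"),
    ("3x3, 64->64",   64,  3, "PASS"),
    ("3x3, 256->256", 256, 3, "PASS"),
    ("3x3, 512->512", 512, 3, "PASS"),
    ("3x3, 128->256", 128, 3, "PASS"),
    ("1x1, 64->128",  64,  1, "PASS"),
]

for label, cin, k, result in configs:
    contraction = cin * k * k
    aligned = contraction % PARTITIONS_PER_BANK == 0
    print(f"{label}: Cin*Kh*Kw = {contraction}, "
          f"32-aligned = {aligned}, observed = {result}")
```

Running this, 32-alignment matches PASS/FAIL exactly for every row of the table above.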

Impact

This blocks full end-to-end training of any model using a 7x7 conv (ResNet-18/34/50/101/152, most vision backbones). Our specific use case is training DiffusionDrive (CVPR 2025, end-to-end autonomous driving) with two ResNet-34 encoders.

Workaround: Freeze conv1 and bn1 of each ResNet encoder (~9.5K params) and run them under torch.no_grad(). All subsequent layers compile and train successfully; 99.98% of the model's 60.7M parameters remain trainable.
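A minimal sketch of that workaround, assuming torchvision-style attribute names `conv1`/`bn1` (our actual DiffusionDrive encoders wrap two such ResNet-34s, so names may differ):

```python
import torch
import torch.nn as nn

def freeze_stem(conv1: nn.Module, bn1: nn.Module) -> None:
    """Freeze the 7x7 stem so its backward kernel is never compiled."""
    for m in (conv1, bn1):
        m.eval()  # also fix BN running stats
        for p in m.parameters():
            p.requires_grad_(False)

def stem_forward(conv1: nn.Module, bn1: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Run the frozen stem under no_grad so autograd records nothing for it;
    # gradients start flowing only from the layers after the stem.
    with torch.no_grad():
        return torch.relu(bn1(conv1(x)))
```

Downstream layers still receive gradients normally, since their own weights require grad; only the stem's backward (the failing 7x7 dw-conv kernel) is never traced.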

Question

Has this been fixed in newer neuronx-cc versions (2.24+)? We were unable to upgrade within the PyTorch 2.8 DLAMI venv (pinned dependencies). If a newer version fixes this, please let us know which AMI/venv to use.
