How do you launch an NKI kernel across several LNCs? #119
Hi! I am working on the AWS MoE challenge for MLSys 2026 (https://github.com/aws-neuron/nki-moe/) and was wondering how to create and call NKI kernels across multiple LNCs. The SPMD examples in the AWS Neuron SDK docs (https://awsdocs-neuron.readthedocs-hosted.com/en/latest/nki/guides/tutorials/spmd_multiple_nc_tensor_addition.html) either only go across an LNC=2 pair or use unsupported functions from NKI 1. Here's what I have so far, documented with what I've done and the issues I hit:

import os
import numpy as np
import torch
import torch_xla.runtime as xr
from torch_xla.distributed.spmd.debugging import (
visualize_tensor_sharding
)
from torch_xla.distributed.spmd import Mesh
import torch_xla.distributed.spmd as xs
import nki
import nki.isa as nisa
import nki.language as nl
# Enable SPMD
xr.use_spmd()
@nki.jit(mode='torchxla')
def f(input_hbm: torch.Tensor):
    # Debug print statements
    print(input_hbm.shape)
    print('dims', nl.program_ndim())
    print('program_id(0)', nl.program_id(0))
    input_sbuf = nl.ndarray(
        input_hbm.shape, input_hbm.dtype, buffer=nl.sbuf
    )
    nisa.dma_copy(input_sbuf, input_hbm)
    output_hbm = nl.ndarray(
        input_hbm.shape, input_hbm.dtype, buffer=nl.shared_hbm
    )
    return output_hbm
if __name__ == '__main__':
    # Good luck charms
    os.environ['PJRT_DEVICE'] = 'NEURON'
    os.environ['NEURON_PLATFORM_TARGET_OVERRIDE'] = 'trn2'
    # This doesn't work, I think because we're not going through NxDI
    os.environ['NEURON_RT_DEBUG_OUTPUT_DIR'] = 'debug_output'

    num_devices = xr.global_runtime_device_count()  # 4
    device_ids = np.arange(num_devices)  # [0, 1, 2, 3]

    # Row-wise data split
    mesh = Mesh(device_ids, (num_devices, 1), ('data', 'model'))
    t = torch.arange(8).reshape(4, 2).to(torch.float32).to('xla')
    xs.mark_sharding(t, mesh, ('data', 'model'))

    # [LNC-0 LNC-0]
    # [LNC-1 LNC-1]
    # [LNC-2 LNC-2]
    # [LNC-3 LNC-3]
    visualize_tensor_sharding(t, use_color=False)

    # I expect each of the four devices to hold one row of the tensor, of shape (1, 2),
    # but each one gets the entire (4, 2) tensor even though they're on different HBM banks
    print(f[num_devices](t))
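For reference, here is a small NumPy sketch (no Neuron hardware needed, and not using any NKI API) of the per-device shards I would expect the ('data', 'model') = (4, 1) mesh to produce from the (4, 2) tensor above:

```python
import numpy as np

# The (4, 2) input tensor from the example above.
t = np.arange(8, dtype=np.float32).reshape(4, 2)

# Row-wise split across a (4, 1) mesh: the 'data' axis (size 4)
# shards dim 0, and the 'model' axis (size 1) leaves dim 1 whole.
shards = np.split(t, 4, axis=0)

for device_id, shard in enumerate(shards):
    print(f'LNC-{device_id}: shape={shard.shape}, row={shard.ravel().tolist()}')
# Each shard is (1, 2), e.g. LNC-0 holds [0.0, 1.0].
```

So the question is why each device sees the full (4, 2) tensor instead of its (1, 2) shard.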
Replies: 1 comment
Hello, thanks for your question! Today the best way to run collectives with NKI kernels is the newly released NKI collectives API here. For inspiration and samples, I'd suggest checking out the NKI Kernel Library here; a nice example is this all_gather for sbuf2sbuf. You should not go through the SPMD grid for this use case: when you want to move data across multiple NeuronCores, however you have set their logical configuration, it's better to go through collectives. Also, when you are using the XLA stack it's probably easier to keep things scoped within the NxD Model Builder API, which you can use to set up the distributed process group.
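To illustrate the semantics only (this is a plain NumPy sketch, not the actual NKI collectives API, whose calls and signatures differ): in an all_gather, every rank contributes its local tile and every rank receives the concatenation of all tiles.

```python
import numpy as np

def simulated_all_gather(local_tiles):
    """NumPy stand-in for an all_gather collective: every rank
    contributes its local tile and receives the concatenation
    of all tiles along the gather axis."""
    gathered = np.concatenate(local_tiles, axis=0)
    # In a real collective, every rank gets its own copy of `gathered`.
    return [gathered.copy() for _ in local_tiles]

# Four ranks, each holding a (1, 2) row tile (rank r holds [r, r]).
tiles = [np.full((1, 2), rank, dtype=np.float32) for rank in range(4)]
results = simulated_all_gather(tiles)
print(results[0].shape)  # every rank ends up with the full (4, 2) tensor
```

The NKI all_gather sample linked above does this across NeuronCores at the SBUF level rather than in host memory.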
