Bug Description
Compilation of multiple Qwen models consistently fails with internal error NCC_IXCG856 when the target is inf2. The error fires inside the XLA tensorizer during token-generation graph compilation:
[INTERNAL_ERROR] [NCC_IXCG856] MATCH_REPLACE8 Instruction expects at least 8 input elements per partition
The bug affects both Qwen2ForCausalLM (Qwen2.5-14B) and Qwen3ForCausalLM (Qwen3-8B, DeepSeek-R1-0528-Qwen3-8B) but does not affect Qwen2ForCausalLM at the 7B size (same architecture, same container, same TP degree). Qwen2.5-7B compiles cleanly.
Environment
| Item | Value |
|---|---|
| Container image | public.ecr.aws/neuron/pytorch-inference-vllm-neuronx:0.13.0-neuronx-py312-sdk2.28.0-ubuntu24.04 |
| NxDI version | 0.8.16251 (Neuron SDK 2.28.0) |
| libneuronxla | 2.2.15515.0+50c26cbd |
| vLLM Neuron Plugin | 0.4.1 |
| Compilation | CPU cross-compilation (NEURON_PLATFORM_TARGET_OVERRIDE=inf2, no physical Inf2 device) |
| --target | inf2 |
Affected Models
| Model | HF architecture | num_attention_heads | num_key_value_heads | TP degree | Failing tensor | Status |
|---|---|---|---|---|---|---|
| Qwen/Qwen2.5-7B-Instruct | Qwen2ForCausalLM | 28 | 4 | 2 | N/A | ✅ Compiles |
| Qwen/Qwen2.5-14B-Instruct | Qwen2ForCausalLM | 40 | 8 | 2 | float32<1 x 6> | ❌ NCC_IXCG856 |
| Qwen/Qwen3-8B | Qwen3ForCausalLM | 32 | 8 | 2 | float32<1 x 4> | ❌ NCC_IXCG856 |
| deepseek-ai/DeepSeek-R1-0528-Qwen3-8B | Qwen3ForCausalLM | 32 | 8 | 2 | float32<1 x 4> | ❌ NCC_IXCG856 |
Full error — Qwen2.5-14B-Instruct
2026-03-30 13:32:48 [ERROR]: Failed compilation with ['neuronx-cc', 'compile',
'--framework=XLA',
'/tmp/nxd_model/token_generation_model/_tp0_bk0/model.MODULE_1c428201b5d29e196ecb+511d3580.hlo_module.pb',
'--output', '...neff',
'--target=inf2',
'--enable-saturate-infinity', '--enable-mixed-precision-accumulation',
'--auto-cast=none', '--model-type', 'transformer', '-O1',
'--tensorizer-options=--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2 --vectorize-strided-dma',
'--internal-hlo2tensorizer-options=--verify-hlo=true',
'--verbose=35', '--logfile=...log-neuron-cc.txt', '--enable-internal-neff-wrapper']:
(I-93885-0), tensorizer(
output tensor: float32<2 x 15201> $93884, id: 68326
output tensor: float32<1 x 6> $93885, id: 68326
) [INTERNAL_ERROR] [NCC_IXCG856]
MATCH_REPLACE8 Instruction expects at least 8 input elements per partition
[libneuronxla 2.2.15515.0+50c26cbd]
subprocess.CalledProcessError: ... returned non-zero exit status 70.
Compile parameters: BATCH_SIZE=2, SEQUENCE_LENGTH=2048, NUM_CORES=2
Full error — Qwen3-8B (and DeepSeek-R1-0528-Qwen3-8B, identical)
2026-03-30 12:09:36 [ERROR]: Failed compilation with ['neuronx-cc', 'compile',
'--framework=XLA', ...
'--target=inf2',
'--auto-cast=none', '--model-type=transformer', '-O2',
'--tensorizer-options=--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=1 --vectorize-strided-dma',
'--lnc=1', ...]:
(I-69043-0), tensorizer(
output tensor: float32<2 x 15190> $69042, id: 51394
output tensor: float32<1 x 4> $69043, id: 51394
) [INTERNAL_ERROR] [NCC_IXCG856]
MATCH_REPLACE8 Instruction expects at least 8 input elements per partition
[libneuronxla 2.2.15515.0+50c26cbd]
subprocess.CalledProcessError: ... returned non-zero exit status 70.
Compile parameters: BATCH_SIZE=2, SEQUENCE_LENGTH=2048, NUM_CORES=2
Analysis / pattern
The MATCH_REPLACE8 hardware instruction requires the vals lookup tensor to have at least 8 elements per NeuronCore partition. In both failing cases, a small tensor (6 or 4 elements) ends up as the vals operand after the tensorizer has sharded the graph. The tensor likely represents a small set of special token IDs (e.g. eos_token_id, bos_token_id, think-start/end tokens) used in post-processing logits.
The working model (Qwen2.5-7B-Instruct, num_key_value_heads=4) does not trigger this lowering, while models with num_key_value_heads=8 do — suggesting the graph pattern is shaped differently once the number of KV heads changes. Lowering to a different instruction (e.g. a scalar loop or a padded MATCH_REPLACE8 with a dummy fill) for small vals tensors would resolve the issue.
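The padding idea can be sketched in plain NumPy. This is illustrative only: the sentinel value and the per-partition layout are assumptions based on the error logs, not the compiler's actual internals.

```python
import numpy as np

SENTINEL = -1.0  # assumed safe fill value: real token IDs are non-negative

def pad_vals(vals: np.ndarray, min_elems: int = 8) -> np.ndarray:
    """Pad a small match/replace lookup tensor along its last axis so it
    meets the minimum per-partition element count of MATCH_REPLACE8.
    The sentinel never equals a real token ID, so matches are unaffected."""
    n = vals.shape[-1]
    if n >= min_elems:
        return vals
    pad_width = [(0, 0)] * (vals.ndim - 1) + [(0, min_elems - n)]
    return np.pad(vals, pad_width, constant_values=SENTINEL)

# The two failing shapes from the error logs:
print(pad_vals(np.zeros((1, 6), dtype=np.float32)).shape)  # (1, 8)
print(pad_vals(np.zeros((1, 4), dtype=np.float32)).shape)  # (1, 8)
```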
Expected behaviour
Compilation completes successfully, producing a .neff artifact for all listed models.
Actual behaviour
neuronx-cc exits with code 70 and internal error NCC_IXCG856 for every tested model with num_key_value_heads=8 in the Qwen2/Qwen3 family; all failing runs used --target=inf2.
Additional info
While Qwen/Qwen2.5-7B-Instruct compiles, the EOS token appears to be ignored when serving the model with vLLM: generation does not stop at EOS. This issue is not present with Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2, which compiles and stops generating when it reaches the EOS token. This may belong in a separate issue; please advise.
Only inf2 was used for the tests as the models are deployed to EKS in eu-west-1 region.
Model Name
Qwen/Qwen2.5-14B-Instruct
Qwen/Qwen3-8B
deepseek-ai/DeepSeek-R1-0528-Qwen3-8B
Describe the workload type
Inference: using vLLM to serve some exotic models, accessed through LiteLLM.
Instance Type
inf2.*
Release version
2.28.0
Reproduction Steps
compile_model.py:
import hashlib
import importlib.metadata
import logging
import os
import sys
from pathlib import Path
def _force_inf2_neuron_compiler_env() -> None:
target = os.environ.get("NEURON_COMPILE_TARGET", "inf2").strip().lower()
os.environ["NEURON_PLATFORM_TARGET_OVERRIDE"] = target
default_flags = (
"--enable-saturate-infinity --enable-mixed-precision-accumulation "
f"--model-type transformer --target {target}"
)
existing = os.environ.get("NEURON_CC_FLAGS")
if existing is None or not str(existing).strip():
os.environ["NEURON_CC_FLAGS"] = default_flags
else:
ex = str(existing).strip()
if "--target" not in ex:
os.environ["NEURON_CC_FLAGS"] = f"{ex} --target {target}"
else:
os.environ["NEURON_CC_FLAGS"] = ex
_force_inf2_neuron_compiler_env()
from huggingface_hub import snapshot_download
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s – %(message)s",
datefmt="%Y-%m-%dT%H:%M:%S",
)
logger = logging.getLogger("compile_model")
def _truthy(name: str, default: str = "false") -> bool:
return os.environ.get(name, default).lower() in ("1", "true", "yes", "on")
def _env_optional_bool(name: str) -> bool | None:
"""If unset, return None (let vLLM apply model defaults). If set, parse bool."""
if name not in os.environ:
return None
return _truthy(name, "false")
def load_config() -> dict:
model_id = os.environ.get("MODEL_ID")
if not model_id:
logger.error("MODEL_ID environment variable is required")
sys.exit(1)
dtype = os.environ.get("VLLM_TORCH_DTYPE", "bfloat16")
if dtype in ("bf16",):
dtype = "bfloat16"
return {
"model_id": model_id,
"batch_size": int(os.environ.get("BATCH_SIZE", 1)),
"sequence_length": int(os.environ.get("SEQUENCE_LENGTH", 2048)),
"num_cores": int(os.environ.get("NUM_CORES", 2)),
"auto_cast_type": os.environ.get("AUTO_CAST_TYPE", "bf16"),
"output_dir": Path(os.environ.get("OUTPUT_DIR", "compiled_model")),
# Match Helm / `vllm serve` (charts/vllm-neuron/values.yaml → vllm.extraArgs).
"block_size": int(os.environ.get("BLOCK_SIZE", "32")),
"dtype": dtype,
"trust_remote_code": _truthy("TRUST_REMOTE_CODE", "false"),
# None = same as omitting CLI flags (vLLM 0.13+ picks model-capability defaults).
"enable_prefix_caching": _env_optional_bool("ENABLE_PREFIX_CACHING"),
"enable_chunked_prefill": _env_optional_bool("ENABLE_CHUNKED_PREFILL"),
"num_gpu_blocks_override_raw": os.environ.get("NUM_GPU_BLOCKS_OVERRIDE"),
}
def get_compiler_version() -> str:
try:
return importlib.metadata.version("neuronx-cc")
except importlib.metadata.PackageNotFoundError:
logger.warning("neuronx-cc not found in installed packages, version unknown")
return "unknown"
def _ensure_neuron_platform_for_vllm_config() -> None:
import vllm.platforms as vllm_platforms
from vllm_neuron.platform import NeuronPlatform
vllm_platforms._current_platform = NeuronPlatform()
logger.info("Forced vLLM current_platform to NeuronPlatform (no /dev/neuron in this environment)")
def build_inference_config_like_vllm_serve(model_dir: Path, cfg: dict):
_ensure_neuron_platform_for_vllm_config()
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm_neuron.worker.neuronx_distributed_model_loader import (
_get_default_neuron_config,
_get_model_configs,
_get_neuron_config_after_override,
_get_neuron_model_cls,
_handle_pa_num_blocks,
_validate_neuron_config,
_validate_override_neuron_config,
)
ea_kwargs: dict = {
"model": str(model_dir),
"tensor_parallel_size": cfg["num_cores"],
"max_model_len": cfg["sequence_length"],
"max_num_seqs": cfg["batch_size"],
"dtype": cfg["dtype"],
"block_size": cfg["block_size"],
"trust_remote_code": cfg["trust_remote_code"],
}
raw_ov = cfg.get("num_gpu_blocks_override_raw")
if raw_ov is not None and raw_ov != "":
ea_kwargs["num_gpu_blocks_override"] = int(raw_ov)
if cfg["enable_prefix_caching"] is not None:
ea_kwargs["enable_prefix_caching"] = cfg["enable_prefix_caching"]
if cfg["enable_chunked_prefill"] is not None:
ea_kwargs["enable_chunked_prefill"] = cfg["enable_chunked_prefill"]
engine_args = EngineArgs(**ea_kwargs)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.OPENAI_API_SERVER,
)
cc = vllm_config.cache_config
sc = vllm_config.scheduler_config
logger.info(
"Resolved vLLM cache/scheduler for hash: enable_prefix_caching=%s "
"block_size=%s max_model_len=%s max_num_seqs=%s enable_chunked_prefill=%s",
getattr(cc, "enable_prefix_caching", None),
getattr(cc, "block_size", None),
vllm_config.model_config.max_model_len,
sc.max_num_seqs,
sc.enable_chunked_prefill,
)
additional_config = vllm_config.additional_config or {}
lora_serving_config = None
try:
architecture, _, _ = _get_model_configs(vllm_config.model_config.hf_config)
except ValueError as exc:
logger.error(
"vLLM-neuron does not support the architecture for %s "
"(model_type=%r, architectures=%s). "
"Verify the model is in the supported list for the container image being used.",
ea_kwargs["model"],
getattr(vllm_config.model_config.hf_config, "model_type", "?"),
getattr(vllm_config.model_config.hf_config, "architectures", "?"),
)
raise
default_neuron_config_args = _get_default_neuron_config(
vllm_config.model_config,
vllm_config.cache_config,
vllm_config.parallel_config,
vllm_config.scheduler_config,
lora_serving_config,
None,
)
override_neuron_config = additional_config.get("override_neuron_config")
if override_neuron_config is None:
override_neuron_config = {}
override_neuron_config = _validate_override_neuron_config(
override_neuron_config, vllm_config.model_config
)
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, override_neuron_config
)
if neuron_config.get("is_block_kv_layout"):
neuron_config = _handle_pa_num_blocks(
vllm_config.cache_config, neuron_config, override_neuron_config
)
neuron_config = _validate_neuron_config(
vllm_config.cache_config,
vllm_config.scheduler_config,
vllm_config.model_config,
neuron_config,
)
model_cls = _get_neuron_model_cls(architecture)
neuron_cfg_obj = model_cls.get_neuron_config_cls()(**neuron_config)
infer_cfg = model_cls.get_config_cls()(
neuron_cfg_obj,
load_config=load_pretrained_config(str(model_dir)),
)
return infer_cfg, model_cls, architecture
def compile_model(cfg: dict) -> None:
model_id = cfg["model_id"]
output_dir = cfg["output_dir"]
output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Downloading weights for {model_id} → {output_dir}")
snapshot_download(
repo_id=model_id,
local_dir=str(output_dir),
ignore_patterns=["*.msgpack", "flax_model*", "tf_model*", "rust_model*"],
)
config_json = output_dir / "config.json"
if not config_json.exists():
raise RuntimeError(
f"snapshot_download completed but no config.json found at {output_dir}. "
"The model repository may not use the standard HF transformers format "
"(check for GGUF-only repos or non-standard layouts)."
)
import json as _json
_hf_raw = _json.loads(config_json.read_text(encoding="utf-8"))
logger.info(
"Downloaded model config: model_type=%r architectures=%s",
_hf_raw.get("model_type"),
_hf_raw.get("architectures"),
)
infer_cfg, model_cls, architecture = build_inference_config_like_vllm_serve(
output_dir, cfg
)
logger.info(f"HF architecture {architecture} → {model_cls.__name__}")
target_hash = hashlib.md5(
infer_cfg.to_json_string().encode("utf-8"),
usedforsecurity=False,
).hexdigest()
logger.info(
"vLLM compilation hash: %s (model_dir=%s — must match `vllm serve` path)",
target_hash,
output_dir.resolve(),
)
compiled_model_path = output_dir / "neuron-compiled-artifacts" / target_hash
compiled_model_path.mkdir(parents=True, exist_ok=True)
# vLLM / torch-neuronx may mutate Neuron env during EngineArgs import; re-apply
# before neuronx-cc runs so NEFFs are always built for the intended silicon.
_force_inf2_neuron_compiler_env()
logger.info(
f"Starting AOT compilation for {model_id} "
f"(NEURON_PLATFORM_TARGET_OVERRIDE={os.environ['NEURON_PLATFORM_TARGET_OVERRIDE']!r}, "
f"tp={cfg['num_cores']}, max_num_seqs={cfg['batch_size']}, "
f"max_model_len={cfg['sequence_length']}, block_size={cfg['block_size']})"
)
model = model_cls(str(output_dir), infer_cfg)
model.compile(str(compiled_model_path))
(output_dir / "neuron_artifact_md5.txt").write_text(
target_hash + "\n", encoding="utf-8"
)
logger.info(f"AOT compilation complete. Artifacts at: {output_dir}")
logger.info(f" Compiler version : {get_compiler_version()}")
logger.info(f" Compiled path : {compiled_model_path}")
def main() -> None:
cfg = load_config()
compile_model(cfg)
if __name__ == "__main__":
main()
Executed the script with the following command:
docker run --rm \
--entrypoint /opt/conda/bin/python3 \
-e MODEL_ID=Qwen/Qwen2.5-14B-Instruct \
-e BATCH_SIZE=2 \
-e SEQUENCE_LENGTH=2048 \
-e NUM_CORES=2 \
-e AUTO_CAST_TYPE=bf16 \
-e OUTPUT_DIR=/models/Qwen/Qwen2.5-14B-Instruct/test/seq2048 \
-e NEURON_PLATFORM_TARGET_OVERRIDE=inf2 \
-e ENABLE_PREFIX_CACHING=false \
-e BLOCK_SIZE=2048 \
-v /tmp/models:/models \
-v /path/to/compile_model.py:/workspace/scripts/compile_model.py \
public.ecr.aws/neuron/pytorch-inference-vllm-neuronx:0.13.0-neuronx-py312-sdk2.28.0-ubuntu24.04 \
/workspace/scripts/compile_model.py
Regression Issue
Possible Solution
The tensorizer in neuronx-cc / libneuronxla selects the nc_match_replace8 instruction for a small constant vals tensor (typically special token IDs in the logits post-processing / sampling path of the token-generation graph).
After sharding (common with NUM_CORES=2 for TP=2 on --target=inf2), this vals tensor ends up with fewer than 8 elements per partition (e.g., 4 or 6 elements). The MATCH_REPLACE8 instruction requires at least 8 input elements per partition (per the compiler's own error message), triggering the internal error NCC_IXCG856.
This affects models such as Qwen2.5-14B, Qwen3-8B, and DeepSeek-R1-Qwen3-8B, but not Qwen2.5-7B (different num_key_value_heads leads to a safe size or different lowering).
Recommended Fix (in tensorizer / HLO-to-Neuron IR lowering pass)
Add a check before emitting MATCH_REPLACE8:
- Compute the effective element count per partition for the vals operand (accounting for current partitioning and flattening of free dimensions).
- If the count ≥ 8 → proceed with the existing MATCH_REPLACE8 path (unchanged behavior).
- If the count < 8 (and > 0):
• Preferred fast path (when vals is a compile-time constant): Insert an HLO-level Pad operation to bring the vals tensor to exactly 8 elements per partition using a safe dummy value (e.g., -1.0, which cannot match real token IDs). Then emit MATCH_REPLACE8 as usual. This preserves the optimized hardware instruction with negligible overhead.
• Fallback path (for non-constant or other cases): Emit a general vectorized replace implementation (unrolled compares or small loop using standard vector engine operations). For N ≤ 8 this has near-zero performance impact.
- Improve diagnostics:
• Emit a clear WARNING (not error) when padding or fallback is used:
"MATCH_REPLACE8 vals tensor has X elements per partition (expected 8). Applying padding/fallback."
• Optionally expose a tensorizer flag (e.g., --tensorizer-options=--force-match-replace8-fallback) for debugging.
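Taken together, the checks above amount to a three-way dispatch in the lowering pass. A minimal sketch of that dispatch follows; the function and path names are illustrative, not real neuronx-cc APIs.

```python
import numpy as np

MIN_ELEMS = 8  # MATCH_REPLACE8 minimum input elements per partition

def select_lowering(vals: np.ndarray, is_constant: bool) -> str:
    """Decide how to lower a match/replace op for a given vals operand.
    Returns the name of the lowering path that would be emitted."""
    n = vals.shape[-1]  # effective elements per partition for this operand
    if n >= MIN_ELEMS:
        return "match_replace8"         # fast path, unchanged behavior
    if is_constant and n > 0:
        return "match_replace8_padded"  # pad with a sentinel, then fast path
    return "vector_fallback"            # unrolled compares / small loop

# The failing Qwen2.5-14B tensor (float32<1 x 6>) would take the padded path:
print(select_lowering(np.zeros((1, 6), dtype=np.float32), is_constant=True))
# match_replace8_padded
```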
Benefits
• Makes compilation robust for any small constant lookup tables (common in modern LLMs with thinking tokens, tool-use delimiters, etc.).
• No behavior change for models that already produce exactly 8 elements per partition.
• Negligible performance cost.
• Covers edge cases: different TP degrees, batch sizes, sequence lengths, and future models with other small vals sizes (3, 5, 7, …).
Verification:
• Recompile failing models (Qwen2.5-14B, Qwen3-8B) → successful .neff generation.
• Regression test: Qwen2.5-7B and other common models compile and perform identically.
• End-to-end inference test: confirm correct handling of EOS / special tokens.
This change should be implemented in the instruction selection / legalization logic inside libneuronxla. A quick HLO-level padding workaround in the vLLM Neuron plugin could serve as a temporary unblock if needed.
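Why sentinel padding is semantically safe can be seen from a reference implementation of the assumed match/replace semantics (wherever the input equals a key, substitute the paired value; names and values here are illustrative):

```python
import numpy as np

def match_replace_ref(x, keys, vals):
    """Reference match/replace: out[i] = vals[j] where x[i] == keys[j],
    else x[i]. Unrolled compares; near-zero cost for 8 or fewer keys."""
    out = x.copy()
    for k, v in zip(keys, vals):
        out = np.where(x == k, v, out)
    return out

x = np.array([7.0, 42.0, 7.0])
# Padding keys/vals to length 8 with -1.0 entries changes nothing,
# because no real token ID ever equals the sentinel:
short = match_replace_ref(x, [7.0], [0.0])
padded = match_replace_ref(x, [7.0] + [-1.0] * 7, [0.0] * 8)
print(np.array_equal(short, padded))  # True
```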
Logs/Context/Additional Information
No response