Add TKG-optimized contribs for 4 dense Mistral models#132

Open
jimburtoft wants to merge 1 commit into aws-neuron:main from jimburtoft:contrib/mistral-dense-tkg

Conversation

@jimburtoft
Contributor

Summary

  • Add NKI TKG fused attention kernel optimization for 4 dense Mistral-family models
  • Each model directory is self-contained with setup scripts, TKG kernel, NKI 0.3.0 compat fixes, and integration tests
  • All benchmarked on SDK 2.29 (DLAMI 20260410), vLLM 0.16.0, trn2.3xlarge

Models

| Model | tok/s (TKG) | tok/s (baseline) | Improvement | TKG Mode | GPU Comparison |
|---|---|---|---|---|---|
| Mistral-7B-Instruct-v0.3 | 115 | 97 | +19% | multi-KV | 1.47x GPU, 15% cheaper |
| Ministral-8B-Instruct-2410 | 99 | 82 | +21-29% | multi-KV | 1.32x GPU, 24% cheaper |
| Ministral-3-3B-Instruct-2512-BF16 | 158 | 134 | +18-27% | multi-KV | GPU blocked |
| Mistral-Small-3.1-24B-Instruct-2503 | 48 | 45 | +6-14% | stock | 1.27x GPU, 26% cheaper |

Key Technical Details

TKG kernel: Fused attention NKI kernel derived from Leanstral (Ministral-3-14B contrib). Two variants:

  • Multi-KV: For models with >1 KV head per TP rank (7B, 8B, 3B at TP=4). Uses virtual batch approach.
  • Stock: For models with 1 KV head per TP rank (24B at TP=8). No custom kernel needed, just config flags.
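The variant choice follows directly from the number of KV heads landing on each tensor-parallel rank. A minimal sketch of that selection rule (the helper name is illustrative, not code from this PR; head counts in the example are hypothetical):

```python
def tkg_variant(num_kv_heads: int, tp_degree: int) -> str:
    """Pick the TKG kernel variant from KV heads per tensor-parallel rank."""
    assert num_kv_heads % tp_degree == 0, "KV heads must divide evenly across ranks"
    kv_per_rank = num_kv_heads // tp_degree
    # >1 KV head per rank -> multi-KV (virtual batch); exactly 1 -> stock path
    return "multi-KV" if kv_per_rank > 1 else "stock"

# e.g. 8 KV heads at TP=4 gives 2 per rank -> multi-KV,
# while 8 KV heads at TP=8 gives 1 per rank -> stock (no custom kernel)
print(tkg_variant(8, 4))  # multi-KV
print(tkg_variant(8, 8))  # stock
```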

Limitations:

  • TKG requires max_num_seqs=1 (BIR verification fails at batch size >1 on neuronx-cc 2.24)
  • --no-enable-prefix-caching is required (prefix caching corrupts the block KV cache)
  • Ministral 3B and 24B require extracting the text model from the multimodal checkpoint (LlamaForCausalLM path)
  • Ministral 8B requires a config patch (sliding_window=null, layer_types removed)

Files Changed

  • New: Mistral-7B-Instruct-v0.3/ (10 files)
  • New: Ministral-8B-Instruct-2410/ (12 files, includes patch_config.py)
  • New: Ministral-3-3B-Instruct-2512-BF16/ (11 files, includes extract_text_model.py)
  • Updated: Mistral-Small-3.1-24B-Instruct-2503/ (8 new files + README update, includes extract_text_model.py)
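Putting the serving constraints above together, a launch command consistent with them might look like the following sketch (the model id and port are assumptions, not taken from this PR; --max-num-seqs is the standard vLLM flag behind the max_num_seqs=1 requirement):

```shell
# Hypothetical vLLM launch honoring the TKG constraints:
# batch size pinned to 1 and prefix caching disabled.
MODEL="mistralai/Mistral-7B-Instruct-v0.3"   # assumed model id
CMD="vllm serve $MODEL --max-num-seqs 1 --no-enable-prefix-caching --port 8000"
echo "$CMD"
```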

Add NKI TKG fused attention kernel optimization for the Mistral dense
model family, achieving 6-29% throughput improvement over baseline:

- Mistral-7B-Instruct-v0.3: 115 tok/s (+19%), multi-KV TKG kernel
- Ministral-8B-Instruct-2410: 99 tok/s (+21-29%), multi-KV TKG + config fix
- Ministral-3-3B-Instruct-2512-BF16: 158 tok/s (+18-27%), fastest in family
- Mistral-Small-3.1-24B-Instruct-2503: 48 tok/s (+6-14%), stock TKG kernel

Each model directory is self-contained with setup scripts, TKG kernel
files, NKI 0.3.0 compatibility fixes, and integration tests. Models
using the Llama path (3B, 24B) include text extraction from multimodal
checkpoints. All models are cost-competitive with GPU at spot pricing.

Tested on SDK 2.29 (DLAMI 20260410), vLLM 0.16.0, trn2.3xlarge.