
Draft

[ROCm][DSv4] AITER-accelerated MLA decode for DeepSeek V4 on MI355X (rebased on tj/dsv4prrebase) #901

ChuanLi1101 wants to merge 4 commits into ROCm:hexwang/dsv4_adapt_upstream from ChuanLi1101:chuali/aiter-mla-dsv4-decode-rebased

Conversation

@ChuanLi1101 commented Apr 27, 2026

Summary

Rebased version of vllm-project#40889 for review against the tj/dsv4prrebase integration branch.

Per Tun Jian's note that vLLM upstream is now planning to merge DSv4 via the rebased PR vllm-project#40860 instead of vllm-project#40760, the AITER decode stack needs to land on top of that newer base. This PR cherry-picks the 2 commits from vllm-project#40889 onto ROCm/vllm:tj/dsv4prrebase.

This does not duplicate vllm-project#40889 — that PR will stay open as a draft against vllm-project/vllm:main until vllm-project#40860 lands; this PR exists so reviewers can look at the AITER decode code on top of the new base today. Once vllm-project#40860 lands and tj/dsv4prrebase merges, only one of the two will graduate to a real upstream PR; the other will be closed.

What changed during rebase

  • 5 commits across the original 3 PRs cherry-picked cleanly with 0 manual conflict resolution.
  • One small drop: mxfp4.py already has the RoutingMethodType import on tj/dsv4prrebase, so the duplicate import added by vllm-project#40889 ([ROCm] Add AITER-accelerated MLA decode for DeepSeek V4 on MI355X) was removed during cherry-pick.
  • All cross-file references verified: _forward_decode_aiter / aiter_sparse_attn_decode signatures, _dequantize_blocked_k_cache(out=), current_workspace_manager, AiterSparseScratch.

Original PR description

This PR adds an AITER-accelerated sparse MLA decode path for DeepSeek V4 on AMD MI355X (gfx950).

The existing ROCm decode path uses a PyTorch reference implementation. This PR replaces it with AITER's persistent-mode ASM kernel (aiter.mla.mla_decode_fwd), achieving ~2-3x decode speedup at high batch sizes while maintaining numerical correctness.

Changes

New file: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py

  • AiterSparseScratch: Lazy-initialized persistent-mode metadata buffers. Keyed by (batch_size, nhead, topk, dtype, kvtype) so all 61 DSv4 attention layers share one allocation per decode step, eliminating per-layer metadata rebuild overhead.
  • aiter_sparse_attn_decode(): Drop-in replacement for _ref_sparse_attn_decode, handling:
    • Dual-scope attention (SWA + extra blocked K) with LSE-based output merging
    • Attention sink correction using LSE values from the kernel
    • FP8/FP8 input casting (required by gfx950 persistent-mode + return_lse=True)
    • Fixed-stride kv_indices layout with -1 sentinels (required by AITER persistent-mode kernels)
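The LSE-based output merging and attention-sink correction above can be sketched in a few lines. This is a minimal NumPy illustration of the arithmetic, not the actual kernel or vLLM code; the helper names and shapes are illustrative:

```python
import numpy as np

def merge_scopes_lse(o1, lse1, o2, lse2):
    """Merge two attention-scope outputs via their log-sum-exps.
    o*: (b, h, d_v) per-scope outputs; lse*: (b, h) per-scope LSEs.
    Each scope is weighted by its share of the combined softmax mass."""
    lse = np.logaddexp(lse1, lse2)          # combined LSE, (b, h)
    w1 = np.exp(lse1 - lse)[..., None]      # scope-1 softmax mass
    w2 = np.exp(lse2 - lse)[..., None]      # scope-2 softmax mass
    return w1 * o1 + w2 * o2, lse

def apply_sink(o, lse, sink):
    """Attention-sink correction: rescale the output as if an extra
    per-head logit `sink` had participated in the softmax denominator."""
    scale = np.exp(lse - np.logaddexp(lse, sink[None, :]))  # (b, h)
    return o * scale[..., None]
```

Merging the two halves of one softmax this way reproduces the full-softmax result exactly, which is what makes the dual-scope (SWA + extra) split sound.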

Modified: vllm/model_executor/layers/deepseek_v4_attention.py

  • Added _aiter_scratch / _aiter_extra_scratch fields to __init__
  • On ROCm, _forward_decode now unconditionally dispatches to the new _forward_decode_aiter() (DeepSeek sparse attention is AITER-only on ROCm; no env-var flag).
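The fixed-stride kv_indices layout with -1 sentinels mentioned above can be illustrated with a small packing helper. This is a hypothetical sketch (pack_kv_indices is not a real vLLM/AITER API): variable-length per-sequence top-k index lists are written into a dense (batch, topk) buffer, with unused slots filled by the -1 sentinel the persistent-mode kernels expect.

```python
import numpy as np

def pack_kv_indices(indices_per_seq, topk):
    """Pack variable-length per-sequence KV indices into a fixed-stride
    (batch, topk) int32 buffer, padding unused slots with -1 sentinels.
    Hypothetical helper; names and dtype are illustrative."""
    batch = len(indices_per_seq)
    out = np.full((batch, topk), -1, dtype=np.int32)
    for i, idx in enumerate(indices_per_seq):
        n = min(len(idx), topk)   # truncate anything beyond topk
        out[i, :n] = idx[:n]
    return out
```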

Test plan

  • Cherry-pick verified clean (0 manual conflicts, all symbols cross-link)
  • Cosine-vs-PyTorch-ref parity on MI355X (TP2/TP4/TP8, SWA-only and SWA+extra) — running now on tj/dsv4prrebase
  • E2E generation smoke test — pending baseline validation on tj/dsv4prrebase (Hexiang)
  • lm_eval (gsm8k / mmlu_stem) — pending baseline validation

Parity numbers from the original vllm-project#40889 base were cosine > 0.999 across all configs. Re-running on the new base.

AI assistance disclosure

AI assistance (Claude / Cursor agent) was used for this rebase. The submitter has reviewed every cherry-picked diff line and verified cross-file integration. Test results will be posted as a follow-up comment once parity finishes.

Related

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs — a small, essential subset of tests that catches errors quickly.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@ChuanLi1101
Author

Parity check on MI355X (cherry-picked branch)

Re-ran cosine-vs-PyTorch-reference parity on the cherry-picked branch
(tj/dsv4prrebase + 5 commits, HEAD 4564d3b) using
bench_remote/_parity_test_aiter_decode.py on a single MI355X / gfx950
GPU inside chuali_glm51. Synthetic inputs: b=8, topk=2048,
d_qk=576, d_v=512.

| Config | h_q | cosine min / mean / max | max abs diff | result |
| --- | --- | --- | --- | --- |
| TP4 swa_only | 32 | 0.998532 / 0.999131 / 0.999474 | 1.66e-02 | PASS |
| TP4 swa+extra | 32 | 0.998528 / 0.999280 / 0.999494 | 1.56e-02 | PASS |
| TP8 swa_only | 16 | 0.998682 / 0.999152 / 0.999455 | 1.46e-02 | PASS |
| TP8 swa+extra | 16 | 0.998927 / 0.999292 / 0.999505 | 1.01e-02 | PASS |
| TP2 swa_only | 64 | 0.994775 / 0.998978 / 0.999511 | 2.55e-02 | borderline |
| TP2 swa+extra | 64 | NaN | NaN | FAIL |

(Threshold: cosine_mean > 0.999 AND cosine_min > 0.99)
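The parity metric can be sketched as follows — an illustrative helper, not the actual implementation in bench_remote/_parity_test_aiter_decode.py: rows are flattened over head/value dims, per-row cosine similarity is computed in float64, and the PASS verdict applies both thresholds.

```python
import numpy as np

def cosine_parity(out, ref, mean_thr=0.999, min_thr=0.99):
    """Per-row cosine similarity between kernel output and reference,
    plus a PASS/FAIL verdict under the mean/min thresholds."""
    a = out.reshape(out.shape[0], -1).astype(np.float64)
    b = ref.reshape(ref.shape[0], -1).astype(np.float64)
    cos = (a * b).sum(-1) / (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1))
    verdict = bool(cos.mean() > mean_thr and cos.min() > min_thr)
    return cos.min(), cos.mean(), cos.max(), verdict
```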

TP2 NaN — not from the cherry-picked code

A drill-down (_parity_tp2_diag.py) shows the NaN comes from the AITER
persistent ASM kernel mla_a8w8_qh64_qseqlen4_gqaratio16_lse_ps itself,
not from the integration code:

TP2 swa+extra b=8:
  out_swa  : nans=0     infs=0   ok
  out_ext  : nans=8192  infs=0   <-- NaN appears in scope #2's output
  lse_ext  : nans=0     infs=16  <-- kernel returns +inf for 16 rows

(lse_ext - lse_total).exp() with both operands +inf yields NaN, which
then propagates through the LSE merge. Reproducible across batch sizes
(1/2/4/8/16) and topk (256/512/1024/2048) at TP2; completely clean
on TP4 and TP8 (which use the smaller qh16_qseqlen1 kernel). Worth a
ping to the AITER team but doesn't block landing the decode path on
TP4/TP8 (the production deployment shapes on MI355X).
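The failure mode is easy to reproduce in isolation. A minimal NumPy sketch of the merge arithmetic (array values are illustrative, not taken from the actual run):

```python
import numpy as np

# When the kernel returns lse_ext = +inf for a row, the merged
# lse_total is also +inf for that row, so the scope weight
# exp(lse_ext - lse_total) evaluates inf - inf -> NaN.
lse_ext = np.array([np.inf, 3.0])            # row 0: the buggy +inf from the kernel
lse_swa = np.array([2.0, 2.0])
lse_total = np.logaddexp(lse_ext, lse_swa)   # +inf wherever lse_ext is +inf
with np.errstate(invalid="ignore"):
    w_ext = np.exp(lse_ext - lse_total)      # row 0: NaN, poisoning the merge
```

The healthy row (finite LSEs) produces the expected softmax weight; only the +inf row turns into NaN and then propagates through the dual-scope merge.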

Cherry-pick verification

  • 5 commits applied with 0 manual conflict resolution.
  • One small drop during cherry-pick: the duplicate RoutingMethodType
    import in mxfp4.py already exists on tj/dsv4prrebase.
  • All cross-file references verified: _forward_decode_aiter /
    aiter_sparse_attn_decode signatures, _dequantize_blocked_k_cache(out=),
    current_workspace_manager, AiterSparseScratch.

@ChuanLi1101
Author

Parity check rerun on MI355X — all six configs PASS

After landing the head-split workaround for the AITER qh64_qseqlen4_gqaratio16_lse_ps
kernel lse=+inf issue (commit f51f25f on this stack's bottom branch), re-ran
bench_remote/_parity_test_aiter_decode.py against the top of the stack
(b790bdb):

| Config | h_q | cosine min / mean / max | max abs diff | result |
| --- | --- | --- | --- | --- |
| TP2 swa_only | 64 | 0.998151 / 0.999135 / 0.999492 | 2.15e-02 | PASS |
| TP2 swa+extra | 64 | 0.998568 / 0.999278 / 0.999538 | 1.46e-02 | PASS |
| TP4 swa_only | 32 | 0.998532 / 0.999131 / 0.999474 | 1.66e-02 | PASS |
| TP4 swa+extra | 32 | 0.998528 / 0.999280 / 0.999494 | 1.56e-02 | PASS |
| TP8 swa_only | 16 | 0.998682 / 0.999152 / 0.999455 | 1.46e-02 | PASS |
| TP8 swa+extra | 16 | 0.998927 / 0.999292 / 0.999505 | 1.01e-02 | PASS |

Threshold: cosine_mean > 0.999 AND cosine_min > 0.99. OVERALL: PASS — all six configs clear both thresholds.

Runtime kernel trace confirms only the working kernel is loaded:

[aiter] hipModuleLoad: ... mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps.co

The buggy qh64_qseqlen4_gqaratio16_lse_ps kernel is no longer invoked.

Test setup: MI355X gfx950, GPU 0 inside chuali_glm51,
b=8 s_q=1 d_qk=576 d_v=512 n_blk=4096 blk_sz=1, topk=2048 per scope,
valid lengths sampled in [topk//2, topk). Reference is
_ref_sparse_attn_decode ported verbatim from DeepseekV4MLAAttention
(bf16). Full log + summary in bench_remote/logs/parity_run_v2.log and
bench_remote/logs/PARITY_SUMMARY.md.

@ChuanLi1101 force-pushed the chuali/aiter-mla-dsv4-decode-rebased branch from f51f25f to 779a9f5 on April 28, 2026 06:23
@ChuanLi1101
Author

Rebased onto latest tj/dsv4prrebase (HEAD 7b2a8671)

tj/dsv4prrebase was force-updated since the original cherry-pick (vllm-project#40860
base plus subsequent fixes including _old_use_hadamard, the silu clamp for the
shared expert, topk_softplus, and an upstream main merge). Rebased the stack on
top of the new HEAD — 0 manual conflicts (the new base does not touch
vllm/model_executor/layers/deepseek_v4_attention.py or our new
rocm_aiter_dsv4_decode.py).

New commit hashes:

| PR | New head | Old head |
| --- | --- | --- |
| #901 | 779a9f5c | f51f25f |
| #902 | 2833da39 | ea14ac0 |
| #903 | eb39ee63 | b790bdb |

Re-ran _parity_test_aiter_decode.py on top of the rebased stack (HEAD
eb39ee63 on base 7b2a8671) — all six configs PASS, numbers match the
previous run within float-roundoff (purely a sanity-check confirming the
rebase did not break anything):

| Config | h_q | cosine min / mean / max | result |
| --- | --- | --- | --- |
| TP2 swa_only | 64 | 0.998151 / 0.999135 / 0.999492 | PASS |
| TP2 swa+extra | 64 | 0.998568 / 0.999278 / 0.999538 | PASS |
| TP4 swa_only | 32 | 0.998532 / 0.999131 / 0.999474 | PASS |
| TP4 swa+extra | 32 | 0.998528 / 0.999280 / 0.999494 | PASS |
| TP8 swa_only | 16 | 0.998682 / 0.999152 / 0.999455 | PASS |
| TP8 swa+extra | 16 | 0.998927 / 0.999292 / 0.999505 | PASS |

Threshold: cosine_mean > 0.999 AND cosine_min > 0.99.
Only mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps is loaded at runtime; the
buggy qh64_qseqlen4 kernel path is bypassed by the head-split workaround.

End-to-end gsm8k on top of the rebased stack is queued — will follow up.

Replace the PyTorch reference sparse MLA decode with AITER's
persistent-mode ASM kernel (aiter.mla.mla_decode_fwd) on gfx950.
This gives ~2-3x decode speedup at high batch sizes.

Key changes:
- New module: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py
  - AiterSparseScratch: lazy-init persistent-mode metadata buffers,
    keyed by (batch, nhead, topk, dtype) so 61 layers share one
    allocation per decode step
  - aiter_sparse_attn_decode: drop-in replacement handling dual-scope
    attention (SWA + extra), LSE-based merging, and attn_sink correction
  - Uses FP8/FP8 path only (gfx950 persistent-mode + return_lse
    requires FP8)
  - Fixed-stride kv_indices layout with -1 sentinels (required by
    AITER persistent-mode kernels)

- deepseek_v4_attention.py:
  - Add _aiter_scratch / _aiter_extra_scratch fields to __init__
  - Gate ROCm decode path: VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE=1
    routes to _forward_decode_aiter, otherwise falls back to the
    existing PyTorch reference

- Fix missing RoutingMethodType import in fused_moe/oracle/mxfp4.py

Validated numerically (cosine > 0.999) across TP2/TP4/TP8 configs
on MI355X. Micro-benchmarked at 2.4x speedup (b=128, dual-scope).

Co-authored-by: Claude
Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
Address review feedback (tjtanaa, vllm-project#40889): on ROCm,
DeepSeek sparse attention can only run through AITER, so gating the
op with VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE adds no value.

- Remove is_aiter_dsv4_decode_enabled() and the env-var lookup from
  vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py.
- Simplify the ROCm branch in DeepseekV4MLAAttention._forward_decode
  to dispatch unconditionally to _forward_decode_aiter.
- Drop the now-unused os import and the env-var mention in the
  _forward_decode_aiter docstring.

Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
@ChuanLi1101 force-pushed the chuali/aiter-mla-dsv4-decode-rebased branch from 779a9f5 to 77515cd on April 28, 2026 08:49
@ChuanLi1101 changed the base branch from tj/dsv4prrebase to hexwang/dsv4_adapt_upstream on April 28, 2026 08:50
The AITER persistent ASM kernel mla_a8w8_qh64_qseqlen4_gqaratio16_lse_ps
(selected when h_q==64, e.g. TP=2 on a 128-head model or TP=4 on a
256-head model) returns lse=+inf for some (batch, head) rows when called
with return_lse=True on gfx950, which propagates NaN through the
dual-scope LSE merge.

The qh16_qseqlen1 kernel selected for h_q<=32 is numerically clean
(cosine > 0.999 vs PyTorch ref on TP4/TP8). Recurse at the outer level
in aiter_sparse_attn_decode — split q + attn_sink in half, run a
complete sparse_attn_decode per half, concatenate the final (b, h_q, d_v)
outputs.

Doing the split at the outer level (instead of inside
_aiter_decode_one_scope) keeps each recursive call self-contained:
LSE merging and sink correction never cross the h_q boundary, so the
result does not depend on the kernel's lse shape convention (which has
differed across aiter versions). Earlier inner-level split surfaced as
a shape mismatch on aiter v0.1.10.post3 in the upstream nightly docker.

2x kernel launches per decode call in the split case (TP=2 with 128
heads or TP=4 with 256 heads), but those configs are rarely used in
production; correctness first.

TODO: drop once AITER fixes qh64_qseqlen4_gqaratio16_lse_ps.

Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
@ChuanLi1101 force-pushed the chuali/aiter-mla-dsv4-decode-rebased branch from 77515cd to e960500 on April 28, 2026 09:06
@ChuanLi1101
Author

@tjtanaa thanks for trying it on tj/dsv4prrebase — that error is a real bug in our head-split workaround, fixed with the latest force-push.

Root cause

Our previous head-split lived inside _aiter_decode_one_scope: when h_q > 32 it recursed with q split in half, ran the kernel twice, and concatenated the two lse tensors back at dim=-1. That assumes lse from aiter.mla.mla_decode_fwd is shape (total_q, h_q) — which is what we get on our gfx950 / aiter build, so the parity sweep on our box passed. On your aiter v0.1.10.post3 (upstream nightly docker) the kernel returns lse with a different shape, so the post-recursion (sink - lse).exp() failed to broadcast. Hence the 64 vs 32 mismatch you saw.

Fix (commit e9605007ea on chuali/aiter-mla-dsv4-decode-rebased)

Moved the head-split up to the outer aiter_sparse_attn_decode level. When h_q > 32 we now split q and attn_sink in half, call aiter_sparse_attn_decode recursively (a complete sparse attention per half — its own LSE merge, its own sink correction), and concatenate the final (b, h_q, d_v) outputs at dim=-2.

Each recursive call is fully self-contained: LSE merging and sink correction never cross the h_q boundary, so the result no longer depends on the kernel's lse shape convention.
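The outer-level split can be sketched as follows — a NumPy illustration under the stated assumptions, where decode_kernel stands in for the complete AITER per-call path (kernel launch + LSE merge + sink correction), which is why the recursion stays correct regardless of the kernel's lse shape convention:

```python
import numpy as np

def sparse_attn_decode(q, attn_sink, decode_kernel, max_heads=32):
    """Outer-level head-split workaround (sketch): when h_q exceeds the
    head count the clean kernel handles, split q (b, h_q, d) and
    attn_sink (h_q,) in half, run a complete decode per half, and
    concatenate the final outputs on the head dim. `decode_kernel` is a
    hypothetical stand-in for the full per-call decode path."""
    h_q = q.shape[1]
    if h_q > max_heads:
        half = h_q // 2
        lo = sparse_attn_decode(q[:, :half], attn_sink[:half], decode_kernel, max_heads)
        hi = sparse_attn_decode(q[:, half:], attn_sink[half:], decode_kernel, max_heads)
        return np.concatenate([lo, hi], axis=1)   # (b, h_q, d_v)
    return decode_kernel(q, attn_sink)
```

Because attention heads are independent, splitting on the head dim and concatenating the per-half outputs is exact; the cost is the extra kernel launch per half in the h_q > 32 configs.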

Verified

Re-ran bench_remote/_parity_test_aiter_decode.py on top of the rebased stack (HEAD 29ef9be512 on hexwang/dsv4_adapt_upstream 66ed64f5):

| Config | h_q | cosine min / mean / max | result |
| --- | --- | --- | --- |
| TP2 swa_only | 64 | 0.998151 / 0.999135 / 0.999492 | PASS |
| TP2 swa+extra | 64 | 0.998568 / 0.999278 / 0.999538 | PASS |
| TP4 swa_only | 32 | 0.998532 / 0.999131 / 0.999474 | PASS |
| TP4 swa+extra | 32 | 0.998528 / 0.999280 / 0.999494 | PASS |
| TP8 swa_only | 16 | 0.998682 / 0.999152 / 0.999455 | PASS |
| TP8 swa+extra | 16 | 0.998927 / 0.999292 / 0.999505 | PASS |

Only mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps is loaded at runtime; the buggy qh64_qseqlen4 path is bypassed.

Note on PR base

We also re-pointed all three PRs (#901 / #902 / #903) at hexwang/dsv4_adapt_upstream per the team's decision that this is the accuracy baseline going forward. Re-rebase had 0 manual conflicts since neither hexwang/dsv4_adapt_upstream nor tj/dsv4prrebase touches vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py or our _forward_decode_aiter additions.

Could you re-pull chuali/aiter-mla-dsv4-decode-rebased (or whichever you're cherry-picking) and try the V4-Flash + TP=4 + --tokenizer-mode "deepseek_v4" recipe again? If it still trips, please grab the q.shape and lse.shape at the failure point — that pins down whether your aiter build returns lse in a fundamentally different layout, and we can guard for that explicitly.

select_mxfp4_moe_backend in mxfp4.py references
RoutingMethodType.DeepseekV4 but the symbol is not imported on the
hexwang/dsv4_adapt_upstream base, raising NameError: name
'RoutingMethodType' is not defined during model init. Hexiang's recipe
happens to skip this code path via --moe-backend triton_unfused CLI
flag, but the LLM offline API takes the same path and trips on it.

Minimal one-line fix: pull RoutingMethodType into the same multi-import
block that already imports FusedMoEQuantConfig / FusedMoEQuantDesc /
mxfp4_*_moe_quant_config from fused_moe.config.

This fix was originally in vllm-project#40889 but was dropped during the
cherry-pick onto tj/dsv4prrebase (which already had the import); it is
reintroduced here when rebasing onto hexwang/dsv4_adapt_upstream, which does not.

Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor