[ROCm][DSv4] AITER-accelerated MLA decode for DeepSeek V4 on MI355X (rebased on tj/dsv4prrebase) #901
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small, essential subset of tests runs automatically, and you can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Parity check on MI355X (cherry-picked branch)

Re-ran cosine-vs-PyTorch-reference parity on the cherry-picked branch (threshold: cosine_mean > 0.999 AND cosine_min > 0.99).

TP2 NaN — not from the cherry-picked code

A drill-down (…)
Cherry-pick verification
Parity check rerun on MI355X — all six configs PASS

After landing the head-split workaround for the AITER qh64_qseqlen4 kernel, all six TP2/TP4/TP8 parity configs pass.

Threshold: cosine_mean > 0.999 AND cosine_min > 0.99.

Runtime kernel trace confirms only the working kernel (mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps) is loaded; the buggy qh64_qseqlen4 path is never selected.

Test setup: MI355X gfx950, GPU 0 inside …
Force-pushed from f51f25f to 779a9f5
Rebased onto the latest tj/dsv4prrebase (base 7b2a8671). New/old heads for the stack:
| PR | New head | Old head |
|---|---|---|
| #901 | 779a9f5c | f51f25f |
| #902 | 2833da39 | ea14ac0 |
| #903 | eb39ee63 | b790bdb |
Re-ran _parity_test_aiter_decode.py on top of the rebased stack (HEAD eb39ee63 on base 7b2a8671) — all six configs PASS; the numbers match the previous run to within float roundoff (a pure sanity check confirming the rebase did not break anything):
| Config | h_q | cosine min / mean / max | result |
|---|---|---|---|
| TP2 swa_only | 64 | 0.998151 / 0.999135 / 0.999492 | PASS |
| TP2 swa+extra | 64 | 0.998568 / 0.999278 / 0.999538 | PASS |
| TP4 swa_only | 32 | 0.998532 / 0.999131 / 0.999474 | PASS |
| TP4 swa+extra | 32 | 0.998528 / 0.999280 / 0.999494 | PASS |
| TP8 swa_only | 16 | 0.998682 / 0.999152 / 0.999455 | PASS |
| TP8 swa+extra | 16 | 0.998927 / 0.999292 / 0.999505 | PASS |
Threshold: cosine_mean > 0.999 AND cosine_min > 0.99.
Only mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps is loaded at runtime; the
buggy qh64_qseqlen4 kernel path is bypassed by the head-split workaround.
End-to-end gsm8k on top of the rebased stack is queued — will follow up.
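For concreteness, here is a minimal sketch of the cosine parity metric quoted throughout this thread. The function names and the flattening choice are illustrative, not the actual contents of _parity_test_aiter_decode.py:

```python
import torch

def cosine_parity(aiter_out: torch.Tensor, ref_out: torch.Tensor):
    """Compare AITER vs PyTorch-reference decode outputs of shape (b, h_q, d_v)."""
    # Flatten each (h_q, d_v) slice so we get one cosine score per batch row.
    a = aiter_out.flatten(1).float()
    b = ref_out.flatten(1).float()
    cos = torch.nn.functional.cosine_similarity(a, b, dim=1)
    return cos.min().item(), cos.mean().item(), cos.max().item()

def passes(cos_min: float, cos_mean: float) -> bool:
    # Threshold from the comments above: cosine_mean > 0.999 AND cosine_min > 0.99.
    return cos_mean > 0.999 and cos_min > 0.99
```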
Replace the PyTorch reference sparse MLA decode with AITER's
persistent-mode ASM kernel (aiter.mla.mla_decode_fwd) on gfx950.
This gives ~2-3x decode speedup at high batch sizes.
Key changes:
- New module: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py
- AiterSparseScratch: lazy-init persistent-mode metadata buffers,
  keyed by (batch, nhead, topk, dtype) so all 61 layers share one
  allocation per decode step (see the sketch after this commit message)
- aiter_sparse_attn_decode: drop-in replacement handling dual-scope
attention (SWA + extra), LSE-based merging, and attn_sink correction
- Uses FP8/FP8 path only (gfx950 persistent-mode + return_lse
requires FP8)
- Fixed-stride kv_indices layout with -1 sentinels (required by
AITER persistent-mode kernels)
- deepseek_v4_attention.py:
- Add _aiter_scratch / _aiter_extra_scratch fields to __init__
- Gate ROCm decode path: VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE=1
routes to _forward_decode_aiter, otherwise falls back to the
existing PyTorch reference
- Fix missing RoutingMethodType import in fused_moe/oracle/mxfp4.py
Validated numerically (cosine > 0.999) across TP2/TP4/TP8 configs
on MI355X. Micro-benchmarked at 2.4x speedup (b=128, dual-scope).
Co-authored-by: Claude
Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
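A rough sketch of the lazy-init scratch pattern the first bullet describes, assuming torch buffers and illustrative field names; the real AiterSparseScratch also carries persistent-mode metadata not shown here:

```python
from typing import Optional
import torch

class AiterSparseScratch:
    """Buffers allocated once per (batch, nhead, topk, dtype) key and reused."""

    def __init__(self) -> None:
        self._key: Optional[tuple] = None
        self.kv_indices: Optional[torch.Tensor] = None
        self.lse: Optional[torch.Tensor] = None

    def ensure(self, batch: int, nhead: int, topk: int,
               dtype: torch.dtype, device: torch.device) -> None:
        key = (batch, nhead, topk, dtype)
        if key == self._key:
            return  # layers 2..61 of the same decode step reuse the buffers
        # Fixed-stride indices, initialized to the -1 sentinel the AITER
        # persistent-mode kernels expect for unused slots.
        self.kv_indices = torch.full((batch, topk), -1,
                                     dtype=torch.int32, device=device)
        self.lse = torch.empty((batch, nhead), dtype=torch.float32,
                               device=device)
        self._key = key
```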
Address review feedback (tjtanaa, vllm-project#40889): on ROCm, DeepSeek sparse attention can only run through AITER, so gating the op with VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE adds no value.

- Remove is_aiter_dsv4_decode_enabled() and the env-var lookup from
  vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py.
- Simplify the ROCm branch in DeepseekV4MLAAttention._forward_decode to
  dispatch unconditionally to _forward_decode_aiter.
- Drop the now-unused os import and the env-var mention in the
  _forward_decode_aiter docstring.

Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
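A minimal sketch of the simplified dispatch this commit describes; the method names follow the PR text, and the class body is trimmed to the two relevant methods:

```python
class DeepseekV4MLAAttention:  # trimmed to the dispatch described above
    def _forward_decode(self, *args, **kwargs):
        # DeepSeek sparse attention is AITER-only on ROCm, so there is no
        # env-var gate: dispatch unconditionally to the AITER path.
        return self._forward_decode_aiter(*args, **kwargs)

    def _forward_decode_aiter(self, *args, **kwargs):
        ...  # calls aiter_sparse_attn_decode(); body elided in this sketch
```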
Force-pushed from 779a9f5 to 77515cd
The AITER persistent ASM kernel mla_a8w8_qh64_qseqlen4_gqaratio16_lse_ps (selected when h_q == 64, e.g. TP=2 on a 128-head model or TP=4 on a 256-head model) returns lse=+inf for some (batch, head) rows when called with return_lse=True on gfx950, which propagates NaN through the dual-scope LSE merge. The qh16_qseqlen1 kernel selected for h_q <= 32 is numerically clean (cosine > 0.999 vs the PyTorch ref on TP4/TP8).

Fix: recurse at the outer level in aiter_sparse_attn_decode — split q + attn_sink in half, run a complete sparse_attn_decode per half, and concatenate the final (b, h_q, d_v) outputs. Doing the split at the outer level (instead of inside _aiter_decode_one_scope) keeps each recursive call self-contained: LSE merging and sink correction never cross the h_q boundary, so the result does not depend on the kernel's lse shape convention (which has differed across aiter versions). An earlier inner-level split surfaced as a shape mismatch on aiter v0.1.10.post3 in the upstream nightly docker.

Cost: 2x kernel launches per decode call in the split case (TP=2 with 128 heads or TP=4 with 256 heads), but those configs are rarely used in production; correctness first.

TODO: drop once AITER fixes qh64_qseqlen4_gqaratio16_lse_ps.

Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
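A sketch of the outer-level split under stated assumptions: decode_fn is a self-contained call returning the final (b, h_q, d_v) output, and attn_sink carries one value per head in its last dimension (both names follow the commit text; the shapes are assumptions):

```python
import torch

def decode_with_head_split(decode_fn, q: torch.Tensor,
                           attn_sink: torch.Tensor, **kwargs) -> torch.Tensor:
    """decode_fn(q, attn_sink, **kwargs) -> (b, h_q, d_v), self-contained."""
    h_q = q.shape[1]
    if h_q <= 32:
        # qh16_qseqlen1 kernel path: numerically clean, no split needed.
        return decode_fn(q, attn_sink, **kwargs)
    # h_q == 64 would select the buggy qh64_qseqlen4 kernel: split the head
    # dim in half, run a complete decode per half (LSE merging and sink
    # correction stay inside each half), then concatenate the outputs.
    q_lo, q_hi = q.chunk(2, dim=1)
    sink_lo, sink_hi = attn_sink.chunk(2, dim=-1)  # assumed per-head layout
    out_lo = decode_with_head_split(decode_fn, q_lo, sink_lo, **kwargs)
    out_hi = decode_with_head_split(decode_fn, q_hi, sink_hi, **kwargs)
    return torch.cat([out_lo, out_hi], dim=1)
```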
Force-pushed from 77515cd to e960500
@tjtanaa thanks for trying it on …

Root cause: our previous head-split lived inside _aiter_decode_one_scope, so its result depended on the lse shape convention of the installed aiter build.

Fix (commit …): move the split to the outer level in aiter_sparse_attn_decode, per the commit message above. Re-ran parity on all six configs:
| Config | h_q | cosine min / mean / max | result |
|---|---|---|---|
| TP2 swa_only | 64 | 0.998151 / 0.999135 / 0.999492 | PASS |
| TP2 swa+extra | 64 | 0.998568 / 0.999278 / 0.999538 | PASS |
| TP4 swa_only | 32 | 0.998532 / 0.999131 / 0.999474 | PASS |
| TP4 swa+extra | 32 | 0.998528 / 0.999280 / 0.999494 | PASS |
| TP8 swa_only | 16 | 0.998682 / 0.999152 / 0.999455 | PASS |
| TP8 swa+extra | 16 | 0.998927 / 0.999292 / 0.999505 | PASS |
Only mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps is loaded at runtime; the buggy qh64_qseqlen4 path is bypassed.
Note on PR base
We also re-pointed all three PRs (#901 / #902 / #903) at hexwang/dsv4_adapt_upstream, per the team's decision that this is the accuracy baseline going forward. The re-rebase had zero manual conflicts, since neither hexwang/dsv4_adapt_upstream nor tj/dsv4prrebase touches vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py or our _forward_decode_aiter additions.
Could you re-pull chuali/aiter-mla-dsv4-decode-rebased (or whichever you're cherry-picking) and try the V4-Flash + TP=4 + --tokenizer-mode "deepseek_v4" recipe again? If it still trips, please grab the q.shape and lse.shape at the failure point — that pins down whether your aiter build returns lse in a fundamentally different layout, and we can guard for that explicitly.
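For readers following the NaN trail: the dual-scope combine is a standard log-sum-exp merge, so a +inf lse row poisons the output. A generic sketch (shapes and names are illustrative, not the kernel's exact convention):

```python
import torch

def merge_scopes(out_a: torch.Tensor, lse_a: torch.Tensor,
                 out_b: torch.Tensor, lse_b: torch.Tensor) -> torch.Tensor:
    """Merge two partial attention results of shape (b, h, d_v) whose
    softmax normalizers are given as log-sum-exp values of shape (b, h)."""
    lse = torch.logaddexp(lse_a, lse_b)         # combined normalizer
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)  # per-(b, h) weight
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)
    # If lse_a is +inf for a row, lse is +inf too, lse_a - lse becomes
    # inf - inf = NaN, and the merged output for that (batch, head) row is NaN.
    return w_a * out_a + w_b * out_b
```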
select_mxfp4_moe_backend in mxfp4.py references RoutingMethodType.DeepseekV4, but the symbol is not imported on the hexwang/dsv4_adapt_upstream base, raising NameError: name 'RoutingMethodType' is not defined during model init. Hexiang's recipe happens to skip this code path via the --moe-backend triton_unfused CLI flag, but the LLM offline API takes the same path and trips on it.

Minimal one-line fix: pull RoutingMethodType into the same multi-import block that already imports FusedMoEQuantConfig / FusedMoEQuantDesc / mxfp4_*_moe_quant_config from fused_moe.config.

This fix was originally in vllm-project#40889 but got auto-merged-out during the cherry-pick onto tj/dsv4prrebase (which already had it); it is reintroduced here when rebasing onto hexwang/dsv4_adapt_upstream, which does not.

Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
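The fix, as the commit text describes it; the surrounding import names are the ones listed above, and the exact block layout is illustrative:

```python
# Add RoutingMethodType to the existing multi-import from fused_moe.config
# in mxfp4.py (per the commit message; full module path assumed).
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    FusedMoEQuantDesc,
    RoutingMethodType,  # was missing -> NameError in select_mxfp4_moe_backend
)
```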
Summary
Rebased version of vllm-project#40889 for review against the tj/dsv4prrebase integration branch.

Per Tun Jian's note that vLLM upstream is now planning to merge DSv4 via the rebased PR vllm-project#40860 instead of vllm-project#40760, the AITER decode stack needs to land on top of that newer base. This PR cherry-picks the 2 commits from vllm-project#40889 onto ROCm/vllm:tj/dsv4prrebase.

This does not duplicate vllm-project#40889 — that PR will stay open as a draft against vllm-project/vllm:main until vllm-project#40860 lands; this PR exists so reviewers can look at the AITER decode code on top of the new base today. Once vllm-project#40860 lands and tj/dsv4prrebase merges, only one of the two will graduate to a real upstream PR; the other will be closed.

What changed during rebase

- mxfp4.py already has the RoutingMethodType import on tj/dsv4prrebase, so the duplicate import that vllm-project#40889 added was removed during cherry-pick.
- Re-verified the cross-file touch points: the _forward_decode_aiter ↔ aiter_sparse_attn_decode signatures, _dequantize_blocked_k_cache(out=), current_workspace_manager, and AiterSparseScratch.

Original PR description
This PR adds an AITER-accelerated sparse MLA decode path for DeepSeek V4 on AMD MI355X (gfx950).
The existing ROCm decode path uses a PyTorch reference implementation. This PR replaces it with AITER's persistent-mode ASM kernel (aiter.mla.mla_decode_fwd), achieving ~2-3x decode speedup at high batch sizes while maintaining numerical correctness.

Changes
New file: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py

- AiterSparseScratch: lazy-initialized persistent-mode metadata buffers. Keyed by (batch_size, nhead, topk, dtype, kvtype) so all 61 DSv4 attention layers share one allocation per decode step, eliminating per-layer metadata rebuild overhead.
- aiter_sparse_attn_decode(): drop-in replacement for _ref_sparse_attn_decode, handling:
  - dual-scope attention (SWA + extra) with LSE-based merging (return_lse=True) and attn_sink correction
  - the fixed-stride kv_indices layout with -1 sentinels (required by AITER persistent-mode kernels; see the sketch below)
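An illustrative construction of that fixed-stride layout; build_kv_indices and its ragged input are hypothetical names, for exposition only:

```python
import torch

def build_kv_indices(per_row_indices: list, topk: int,
                     device: torch.device) -> torch.Tensor:
    """Give every batch row exactly `topk` slots; pad unused slots with the
    -1 sentinel the AITER persistent-mode kernels expect."""
    batch = len(per_row_indices)
    kv_indices = torch.full((batch, topk), -1, dtype=torch.int32, device=device)
    for row, idx in enumerate(per_row_indices):
        n = min(len(idx), topk)
        kv_indices[row, :n] = torch.as_tensor(idx[:n], dtype=torch.int32,
                                              device=device)
    return kv_indices
```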
Modified: vllm/model_executor/layers/deepseek_v4_attention.py

- _aiter_scratch / _aiter_extra_scratch fields added to __init__
- _forward_decode now unconditionally dispatches to the new _forward_decode_aiter() (DeepSeek sparse attention is AITER-only on ROCm; no env-var flag)
Test plan

- Parity (_parity_test_aiter_decode.py) on tj/dsv4prrebase
- End-to-end gsm8k on tj/dsv4prrebase (Hexiang)

Parity numbers from the original vllm-project#40889 base were cosine > 0.999 across all configs. Re-running on the new base.
AI assistance disclosure
AI assistance (Claude / Cursor agent) was used for this rebase. The submitter has reviewed every cherry-picked diff line and verified cross-file integration. Test results will be posted as a follow-up comment once parity finishes.
Related