[ROCm][DSv4] Make AITER sparse decode cudagraph-clean (rebased, stacked on #901) #902
Conversation
Parity check on MI355X (cherry-picked branch)
Re-ran cosine-vs-PyTorch-reference parity on the cherry-picked branch
(threshold: cosine_mean > 0.999 AND cosine_min > 0.99).
TP2 NaN: not from the cherry-picked code. A drill-down [remainder of comment collapsed].
Cherry-pick verification
Force-pushed bb248d3 to ea14ac0
Parity check rerun on MI355X: all six configs PASS after landing the
head-split workaround for the AITER kernel (full results below).
Runtime kernel trace confirms only the working kernel is loaded; the
buggy kernel path is bypassed. Test setup: MI355X gfx950, GPU 0
[remainder of comment collapsed].
@ChuanLi1101 I cherry-picked this PR into [truncated]. I am using aiter
v0.1.10.post3 from the upstream docker. My server command is [truncated].
Signed-off-by: ganyi <ygan@amd.com> Made-with: Cursor
Force-pushed ea14ac0 to 2833da3
Rebased onto latest:

| PR | New head | Old head |
|---|---|---|
| #901 | 779a9f5c | f51f25f |
| #902 | 2833da39 | ea14ac0 |
| #903 | eb39ee63 | b790bdb |
Re-ran _parity_test_aiter_decode.py on top of the rebased stack (HEAD
eb39ee63 on base 7b2a8671): all six configs PASS, and the numbers match
the previous run within float roundoff (purely a sanity check
confirming the rebase did not break anything):
| Config | h_q | cosine min / mean / max | result |
|---|---|---|---|
| TP2 swa_only | 64 | 0.998151 / 0.999135 / 0.999492 | PASS |
| TP2 swa+extra | 64 | 0.998568 / 0.999278 / 0.999538 | PASS |
| TP4 swa_only | 32 | 0.998532 / 0.999131 / 0.999474 | PASS |
| TP4 swa+extra | 32 | 0.998528 / 0.999280 / 0.999494 | PASS |
| TP8 swa_only | 16 | 0.998682 / 0.999152 / 0.999455 | PASS |
| TP8 swa+extra | 16 | 0.998927 / 0.999292 / 0.999505 | PASS |
Threshold: cosine_mean > 0.999 AND cosine_min > 0.99.
Only mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps is loaded at runtime; the
buggy qh64_qseqlen4 kernel path is bypassed by the head-split workaround.
End-to-end gsm8k on top of the rebased stack is queued — will follow up.
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Replace the PyTorch reference sparse MLA decode with AITER's
persistent-mode ASM kernel (aiter.mla.mla_decode_fwd) on gfx950.
This gives ~2-3x decode speedup at high batch sizes.
Key changes:
- New module: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py
- AiterSparseScratch: lazy-init persistent-mode metadata buffers,
keyed by (batch, nhead, topk, dtype) so 61 layers share one
allocation per decode step
- aiter_sparse_attn_decode: drop-in replacement handling dual-scope
attention (SWA + extra), LSE-based merging, and attn_sink correction
- Uses the FP8/FP8 path only (gfx950 persistent mode with return_lse
  requires FP8)
- Fixed-stride kv_indices layout with -1 sentinels (required by
AITER persistent-mode kernels)
- deepseek_v4_attention.py:
- Add _aiter_scratch / _aiter_extra_scratch fields to __init__
- Gate ROCm decode path: VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE=1
routes to _forward_decode_aiter, otherwise falls back to the
existing PyTorch reference
- Fix missing RoutingMethodType import in fused_moe/oracle/mxfp4.py
Validated numerically (cosine > 0.999) across TP2/TP4/TP8 configs
on MI355X. Micro-benchmarked at 2.4x speedup (b=128, dual-scope).
Co-authored-by: Claude
Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
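As an illustration of the fixed-stride kv_indices layout with -1 sentinels described in the commit above, here is a minimal pure-Python sketch. `pack_kv_indices` is a hypothetical helper, not the vLLM code; the real path writes into a preallocated torch tensor, while plain lists stand in here.

```python
# Hedged sketch (assumed helper name): pack variable-length per-batch KV
# index lists into a fixed-stride batch x topk table, padding unused slots
# with the -1 sentinel that the AITER persistent-mode kernels require.

def pack_kv_indices(per_batch_indices, topk):
    """Return a batch x topk table; unused slots hold the -1 sentinel."""
    table = []
    for idx_list in per_batch_indices:
        row = list(idx_list[:topk])          # clamp to the fixed stride
        row += [-1] * (topk - len(row))      # pad short rows with sentinels
        table.append(row)
    return table

packed = pack_kv_indices([[7, 3, 9], [5]], topk=4)
# Every row has the same stride, so the kernel can index row i at i * topk.
```

Because every row has the same stride, the table stays pointer-stable across steps: only the contents change, never the shape.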
Address review feedback (tjtanaa, vllm-project#40889): on ROCm, DeepSeek
sparse attention can only run through AITER, so gating the op with
VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE adds no value.
- Remove is_aiter_dsv4_decode_enabled() and the env-var lookup from
  vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py.
- Simplify the ROCm branch in DeepseekV4MLAAttention._forward_decode to
  dispatch unconditionally to _forward_decode_aiter.
- Drop the now-unused os import and the env-var mention in the
  _forward_decode_aiter docstring.
Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
2833da3 to
143c695
Compare
The AITER persistent ASM kernel mla_a8w8_qh64_qseqlen4_gqaratio16_lse_ps
(selected when h_q == 64, e.g. TP=2 on a 128-head model or TP=4 on a
256-head model) returns lse=+inf for some (batch, head) rows when called
with return_lse=True on gfx950, which propagates NaN through the
dual-scope LSE merge. The qh16_qseqlen1 kernel selected for h_q <= 32 is
numerically clean (cosine > 0.999 vs the PyTorch reference on TP4/TP8).

Workaround: recurse at the outer level in aiter_sparse_attn_decode.
Split q and attn_sink in half, run a complete sparse_attn_decode per
half, and concatenate the final (b, h_q, d_v) outputs. Doing the split
at the outer level (instead of inside _aiter_decode_one_scope) keeps
each recursive call self-contained: LSE merging and sink correction
never cross the h_q boundary, so the result does not depend on the
kernel's lse shape convention (which has differed across aiter
versions). An earlier inner-level split surfaced as a shape mismatch on
aiter v0.1.10.post3 in the upstream nightly docker.

Cost: 2x kernel launches per decode call in the split case (TP=2 with
128 heads or TP=4 with 256 heads), but those configs are rarely used in
production; correctness first. TODO: drop once AITER fixes
qh64_qseqlen4_gqaratio16_lse_ps.

Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
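The outer-level head-split recursion can be sketched in pure Python. Everything here is an illustrative stand-in, not the vLLM implementation: lists stand in for tensors, `decode_fn` stands in for the AITER-backed per-scope decode, and the 64-head threshold mirrors the buggy-kernel selection described above.

```python
# Hedged sketch: when the head count would select the buggy qh64 kernel,
# split the query heads (and their attn_sink rows) in half, run a complete
# decode per half, and concatenate along the head dimension. Each half is
# fully self-contained, so LSE merging never crosses the h_q boundary.

def sparse_attn_decode(q_heads, attn_sink, decode_fn, split_threshold=64):
    h_q = len(q_heads)
    if h_q >= split_threshold:
        half = h_q // 2
        lo = sparse_attn_decode(q_heads[:half], attn_sink[:half],
                                decode_fn, split_threshold)
        hi = sparse_attn_decode(q_heads[half:], attn_sink[half:],
                                decode_fn, split_threshold)
        return lo + hi  # concatenate the per-half outputs
    return decode_fn(q_heads, attn_sink)
```

With an identity `decode_fn`, a 64-head input comes back unchanged, which is the invariant the workaround relies on: splitting must be invisible to the caller.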
Force-pushed 143c695 to 450b838
@tjtanaa thanks for trying it on [truncated].
Root cause: our previous head-split lived inside [truncated].
Fix (commit [truncated]):
| Config | h_q | cosine min / mean / max | result |
|---|---|---|---|
| TP2 swa_only | 64 | 0.998151 / 0.999135 / 0.999492 | PASS |
| TP2 swa+extra | 64 | 0.998568 / 0.999278 / 0.999538 | PASS |
| TP4 swa_only | 32 | 0.998532 / 0.999131 / 0.999474 | PASS |
| TP4 swa+extra | 32 | 0.998528 / 0.999280 / 0.999494 | PASS |
| TP8 swa_only | 16 | 0.998682 / 0.999152 / 0.999455 | PASS |
| TP8 swa+extra | 16 | 0.998927 / 0.999292 / 0.999505 | PASS |
Only mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps is loaded at runtime; the buggy qh64_qseqlen4 path is bypassed.
Note on PR base
We also re-pointed all three PRs (#901 / #902 / #903) at hexwang/dsv4_adapt_upstream per the team's decision that this is the accuracy baseline going forward. Re-rebase had 0 manual conflicts since neither hexwang/dsv4_adapt_upstream nor tj/dsv4prrebase touches vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py or our _forward_decode_aiter additions.
Could you re-pull chuali/aiter-mla-dsv4-decode-rebased (or whichever you're cherry-picking) and try the V4-Flash + TP=4 + --tokenizer-mode "deepseek_v4" recipe again? If it still trips, please grab the q.shape and lse.shape at the failure point — that pins down whether your aiter build returns lse in a fundamentally different layout, and we can guard for that explicitly.
select_mxfp4_moe_backend in mxfp4.py references
RoutingMethodType.DeepseekV4, but the symbol is not imported on the
hexwang/dsv4_adapt_upstream base, raising NameError: name
'RoutingMethodType' is not defined during model init. Hexiang's recipe
happens to skip this code path via the --moe-backend triton_unfused CLI
flag, but the LLM offline API takes the same path and trips on it.

Minimal one-line fix: pull RoutingMethodType into the same multi-import
block that already imports FusedMoEQuantConfig / FusedMoEQuantDesc /
mxfp4_*_moe_quant_config from fused_moe.config. This fix was originally
in vllm-project#40889 but got auto-merged-out during the cherry-pick
onto tj/dsv4prrebase (which already had it); it is reintroduced when
rebasing onto hexwang/dsv4_adapt_upstream, which does not.

Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
Hoist all per-step allocations on the AITER sparse-MLA decode path into
the existing `AiterSparseScratch` cache so cudagraph capture sees stable
memory layouts. Previously each layer's call site freshly allocated
`qo_indptr`, `kv_indptr`, `kv_indices`, `kv_last_page_lens`, `q_scale`,
`kv_scale`, `q_fp8`, the bf16 output buffer, and intermediate boolean
masks every step, which was incompatible with HIP-graph capture and
generated unnecessary allocator pressure with 61 DSv4 attention layers.
Changes
-------
* `AiterSparseScratch` now caches:
* Static buffers: `qo_indptr` (arange), `kv_last_page_lens` (ones),
`col_arange`, `q_scale`, `kv_scale`.
* Per-step write buffers: `kv_indptr`, `kv_indices_2d`, `valid_mask`,
`valid_lens`, `q_fp8`, `out_buf`.
* `rebuild()` allocates and (where applicable) initialises every buffer
once per `(total_q, h_q, topk, d_qk, d_v, dtype, kvtype)` key and runs
`aiter.get_mla_metadata_v1` against the persistent qo/kv/last-page
tensors.
* `_aiter_decode_one_scope` rewrites the per-step buffers in-place via
`torch.lt(out=)`, `tensor.copy_`, `masked_fill_`, and
`torch.cumsum(out=)` instead of fresh allocations.
* The public `aiter_sparse_attn_decode` signature is unchanged, so
`DeepseekV4MLAAttention._forward_decode_aiter` keeps working as-is.
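The in-place rewrite discipline in the list above (mutate preallocated buffers every step, never reallocate) can be illustrated with a minimal pure-Python sketch. `Scratch` and `rewrite` are hypothetical stand-ins, lists stand in for torch tensors, and object identity plays the role of a stable `data_ptr()`.

```python
# Hedged sketch: the scratch object allocates its buffers once (the analogue
# of rebuild()), and each decode step mutates them in place -- the stand-in
# for torch.lt(out=) / copy_ / masked_fill_ -- so buffer identity never
# changes across steps, which is what graph capture requires.

class Scratch:
    def __init__(self, batch):
        self.valid_mask = [False] * batch  # allocated once, reused forever
        self.valid_lens = [0] * batch

    def rewrite(self, kv_lens, topk):
        for i, length in enumerate(kv_lens):
            self.valid_mask[i] = length < topk     # in-place, per step
            self.valid_lens[i] = min(length, topk)

s = Scratch(2)
before = id(s.valid_mask)
s.rewrite([3, 5], topk=4)
assert id(s.valid_mask) == before  # same buffer across steps
```

The key property is the final assertion: contents change every step, but the buffer a captured graph points at does not.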
Follow-up to PR vllm-project#40889; remaining cudagraph blocker is the per-step
`blocked_k.to(fp8_e4m3fn)` cast on the dequantised KV cache, which is
tracked separately together with the dequantise-into-FP8 fast path
suggested in code review.
Test plan
---------
* `python -c "import ast; ast.parse(open('vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py').read())"` passes.
* On MI355X with `chuali_glm51` container (rocm/vllm-dev:nightly +
torch 2.10/HIP 7.2): manual smoke test pending; will follow up with
cudagraph-mode (non-eager) startup log + decode parity vs eager.
AI assistance: drafted with Cursor agent; human-reviewed before
submission.
Signed-off-by: ChuanLi1101 <chuanli1101@gmail.com>
`aiter.get_mla_metadata_v1` produces a `work_*`/`reduce_*` plan that is
keyed on the *actual* per-batch kv lengths, not just on shapes. The
persistent ASM `mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps` kernel reads
out of bounds (causing a GPU memory access fault) if those buffers are
left stale across steps with different kv lengths.
Fix the cudagraph-clean refactor so the metadata is rewritten in-place
on every per-step call against the current `kv_indptr`. The buffer
sizes returned by `get_mla_metadata_info_v1` are determined by shapes
+ `max_split_per_batch` only, so they remain large enough for any kv
length distribution and the data pointers stay stable for graph capture.
* `AiterSparseScratch.rebuild()` now only allocates buffers and stores
the static gqa/topk/dtype parameters; it no longer requires a
`kv_indptr_seed` and no longer runs the metadata builder itself.
* New `AiterSparseScratch.refresh_metadata()` reruns
`get_mla_metadata_v1` writing into the same `work_*`/`reduce_*` slots.
* `_aiter_decode_one_scope` writes `valid_mask`/`valid_lens`/
`kv_indptr`/`kv_indices_2d`/`q_fp8` directly into scratch every
step, then calls `refresh_metadata()` and `mla.mla_decode_fwd`.
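The refresh-per-step contract described above (buffer sizes are fixed by shape-level parameters, so allocate once; contents depend on actual kv lengths, so rewrite every call) can be sketched in pure Python. `MetadataPlan`, its `work` buffer, and the fill rule are illustrative stand-ins, not the AITER metadata layout.

```python
# Hedged sketch: the plan buffer is sized from shapes + max_split_per_batch
# only (like get_mla_metadata_info_v1), so it is allocated once and stays
# pointer-stable, while refresh_metadata rewrites its contents from the
# *current* kv lengths on every step (like rerunning get_mla_metadata_v1
# into the same work_*/reduce_* slots).

class MetadataPlan:
    def __init__(self, batch, max_split_per_batch):
        # size depends on shape-level parameters only -> allocate once
        self.work = [0] * (batch * max_split_per_batch)

    def refresh_metadata(self, kv_lens, max_split_per_batch):
        # rewrite the plan in place from the current per-batch kv lengths
        for i, n in enumerate(kv_lens):
            for j in range(max_split_per_batch):
                self.work[i * max_split_per_batch + j] = min(n, j + 1)

plan = MetadataPlan(batch=2, max_split_per_batch=2)
buf = id(plan.work)
plan.refresh_metadata([3, 2], max_split_per_batch=2)
assert id(plan.work) == buf  # pointer-stable across refreshes
```

Skipping the refresh would leave the plan describing the previous step's kv lengths, which is exactly the stale-metadata fault the commit message describes.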
Validated with the standalone `bench_remote/_unit_test_cudagraph.py`
harness on MI355X:
- Call 1 (lens=[3,2]): success, scratch key set.
- Call 2 (same lens): rebuild skipped, all data_ptrs stable, output
bit-identical to call 1.
- Call 3 (lens=[4,1]): all data_ptrs still stable, output differs as
expected (max abs diff = 2.39 vs identical-input call), no fault.
- Parity check vs the original non-cudagraph implementation:
max abs diff = 0.000000.
Signed-off-by: Chuan Li <chuanli1101@gmail.com>
Co-authored-by: Cursor
Signed-off-by: Li <chuali@amd.com>
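The keyed-rebuild behaviour these commits rely on (rebuild only when the scratch key changes, reuse the cached buffers otherwise) can be sketched as follows. `AiterScratchCache` is a hypothetical stand-in for `AiterSparseScratch`, with a dict of lists standing in for the real buffers.

```python
# Hedged sketch: ensure() reallocates only when the shape/dtype key changes;
# repeated steps with the same key get back the exact same buffer objects,
# which is what keeps every data_ptr() stable across decode steps.

class AiterScratchCache:
    def __init__(self):
        self._key = None
        self.rebuilds = 0
        self.buffers = None

    def ensure(self, key):
        if key != self._key:                     # e.g. (total_q, h_q, topk)
            self.buffers = {"out": [0.0] * key[0]}  # reallocate on key change
            self._key = key
            self.rebuilds += 1
        return self.buffers                      # same objects on repeats

cache = AiterScratchCache()
a = cache.ensure((128, 16, 2048))
b = cache.ensure((128, 16, 2048))  # rebuild skipped
assert a is b and cache.rebuilds == 1
```

This mirrors the harness result above: call 2 with the same key skips the rebuild and returns identical buffers, while a key change triggers exactly one reallocation.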
Force-pushed 450b838 to f4778f8
Force-pushed 9a8d252 to 14a3f64
Summary
Rebased version of vllm-project#40892 for review against the
tj/dsv4prrebase integration branch. Cudagraph-clean follow-up to
vllm-project#40889 (which is rebased in the parallel
chuali/aiter-mla-dsv4-decode-rebased PR against this same branch).

Per Tun Jian's note that vLLM upstream is now planning to merge DSv4
via the rebased PR vllm-project#40860, the AITER decode stack needs to
land on top of that newer base. This PR cherry-picks the 2 commits from
vllm-project#40892 on top of the rebased vllm-project#40889 commits,
all on ROCm/vllm:tj/dsv4prrebase.

This does not duplicate vllm-project#40892: that PR will stay open as a
draft against vllm-project/vllm:main until vllm-project#40860 lands;
this PR exists so reviewers can look at the cudagraph-clean refactor on
top of the new base today.

Stacking
This PR includes 4 commits in its diff: 2 from rebased
vllm-project#40889 + 2 from rebased vllm-project#40892. Once the prior
PR lands, this one will reduce to just the 2 cudagraph-clean commits.
Happy to rebase on request.

What changed during rebase
The cherry-picked commits only touch rocm_aiter_dsv4_decode.py (a new
file from "[ROCm] Add AITER-accelerated MLA decode for DeepSeek V4 on
MI355X", vllm-project/vllm#40889), so they apply cleanly on the new
base.

Original PR description
The persistent-mode AITER sparse MLA decode kernel
(mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps) is the only LSE-returning
sparse attention path on MI355X / gfx950, but the original integration
in vllm-project#40889 allocates fresh per-step indexing tensors
(qo_indptr, kv_indptr, kv_indices, kv_last_page_lens, q_scale,
kv_scale) and a fresh FP8 query / output buffer on every layer call.
That is incompatible with HIP/CUDA-graph capture, so DSv4 currently has
to run in eager mode on ROCm.

This PR makes the per-step decode path completely allocation-free and
pointer-stable so the decode loop can be wrapped in a HIP graph once
the model wires it up.
Changes

AiterSparseScratch now owns every per-step buffer in addition to the
AITER work-plan / reduce buffers:
- qo_indptr, kv_indptr, kv_indices_2d, kv_last_page_lens
- valid_mask, valid_lens, col_arange
- q_fp8 (FP8 query buffer), out_buf (BF16 output)
- q_scale, kv_scale (constant 1.0 tensors)

_aiter_decode_one_scope rewrites all of those in-place every step
(torch.lt(out=), copy_, masked_fill_, torch.cumsum(out=)).

AiterSparseScratch.refresh_metadata() re-runs aiter.get_mla_metadata_v1
against the current kv_indptr every step and writes the new work plan
into the same work_*/reduce_* slots. The persistent ASM kernel encodes
per-batch lengths into that plan, so leaving it stale across steps with
different kv lengths causes a GPU memory-access fault.

rebuild() now only allocates buffers and stores static gqa/topk/dtype
parameters; it no longer re-runs the metadata builder itself.

Net effect: across decode steps with the same
(total_q, nhead, topk, d_qk, d_v, dtype, kvtype) key, every data_ptr()
is stable, so a graph captured on step N can be replayed for any step
N+k.

Test plan
- Standalone harness (bench_remote/_unit_test_cudagraph.py): all 9
  tracked data_ptr()s stable across 3 decode steps with different lens;
  output bit-identical when lens repeats; max abs diff vs the original
  "[ROCm] Add AITER-accelerated MLA decode for DeepSeek V4 on MI355X"
  (vllm-project/vllm#40889) = 0.0.
- gsm8k on tj/dsv4prrebase: running now alongside the rebased
  "[ROCm] Add AITER-accelerated MLA decode for DeepSeek V4 on MI355X"
  (vllm-project/vllm#40889).

AI assistance disclosure
AI assistance (Cursor agent) was used for this rebase. The submitter
has reviewed every cherry-picked diff line. The microbench above was
run on MI355X against the original base; will re-run on
tj/dsv4prrebase once the parity sweep completes.

Related