[CK_TILE] fix(fmha): clamp paged KV lookups in batch prefill #3733

Open
zhenhantech wants to merge 1 commit into ROCm:develop from zhenhantech:develop
zhenhantech commented Apr 29, 2026

Motivation

FMHA batch prefill can prefetch V pages beyond the valid KV sequence length. For SGLang-style page tables, `kv_indptr` defines the logical page-table range, so reading `page_idx[128]` is invalid when only `page_idx[0..127]` exists.

In the failing repro, V prefetch reaches this one-past-end lookup. It may read adjacent memory silently or fault at an allocation boundary, depending on the runtime allocation layout.
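A minimal sketch of the condition described above, assuming each page holds `page_size` KV tokens and the page table covers exactly `seqlen_k` tokens. The helper names `num_logical_pages` and `lookup_out_of_range` are hypothetical and are not part of CK_TILE; the page size of 1 in the usage note is also an assumption chosen to match the `page_idx[0..127]` example.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper, not CK_TILE code: number of page-table entries needed
// to cover seqlen_k KV tokens when each page holds page_size tokens.
inline std::int32_t num_logical_pages(std::int32_t seqlen_k, std::int32_t page_size)
{
    assert(seqlen_k >= 0 && page_size > 0);
    return (seqlen_k + page_size - 1) / page_size; // ceiling division
}

// Hypothetical helper: true when the page-table lookup for KV token `token`
// would index past the last valid page-table entry.
inline bool lookup_out_of_range(std::int32_t token,
                                std::int32_t seqlen_k,
                                std::int32_t page_size)
{
    assert(token >= 0 && page_size > 0);
    return token / page_size >= num_logical_pages(seqlen_k, page_size);
}
```

With `seqlen_k = 128` and `page_size = 1`, token 127 maps to the last valid entry, while a prefetch that looks one token ahead maps to entry 128, which does not exist.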

Technical Details

Expose the mask's KV sequence length through `GetXTotal()` and pass it to `load_physical_pages()`. Clamp each page-table lookup token to the last valid KV token before reading `page_idx`.
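The clamp can be sketched on the host as follows. This is an illustration under stated assumptions, not the kernel code: `lookup_physical_page` is a hypothetical stand-in for the lookup inside `load_physical_pages()`, the page table is modelled as a `std::vector`, and `seqlen_k` stands for the value the PR obtains from `GetXTotal()`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the page-table lookup in load_physical_pages().
// The token is clamped to the last valid KV token (seqlen_k - 1) before it is
// converted to a page-table index, so a lookahead past the end of the
// sequence re-reads the last valid entry instead of reading out of bounds.
inline std::int32_t lookup_physical_page(const std::vector<std::int32_t>& page_idx,
                                         std::int32_t token,
                                         std::int32_t seqlen_k,
                                         std::int32_t page_size)
{
    assert(seqlen_k > 0 && page_size > 0 && token >= 0);
    const std::int32_t clamped_token = std::min(token, seqlen_k - 1);
    const std::int32_t entry         = clamped_token / page_size;
    assert(static_cast<std::size_t>(entry) < page_idx.size());
    return page_idx[static_cast<std::size_t>(entry)];
}
```

The page fetched for a clamped lookup holds no data for the out-of-range token; the sketch only shows that the table read itself stays in bounds.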

Test Plan

  • Rebuild the JIT kernel and rerun the ROCm MHA batch prefill repro.
  • Validate the debug repro that confirmed the one-past-end `kv_page_indices` read.

Test Result

The original and debug repros pass after the fix. A pre-fix diagnostic trap on `page_id == 128` confirmed the illegal V prefetch lookup.

@Jeff-Huang (Contributor)

Thanks for catching this and putting up a fix. The OOB read in the `load_physical_pages` lookahead is exactly the AICK-1171 issue we were debugging on MI-308X.

We ended up landing the V2 fix in a slightly different shape to stay consistent with the V3 batch prefill pipeline currently under development (ROCm/rocm-libraries#6054), which already threads `max_page_table_idx` through `load_physical_pages` instead of querying `mask.GetXTotal()`. To keep the V2 and V3 APIs aligned, the V2 patch we opened (ROCm/rocm-libraries#6932) uses the same `max_page_table_idx` parameter and clamps with `ck_tile::min(page_id, max_page_table_idx)` in all four branches.

It is functionally equivalent to your fix; the parameter just comes from `kargs.seqlen_k` in the kernel layer rather than from the mask. Really appreciate you isolating the bug, and happy to credit your investigation in the commit if you'd like.
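A minimal host-side sketch of why the two clamps agree. Several things here are assumptions rather than facts from either patch: `std::min` stands in for `ck_tile::min`, `max_page_table_idx` is taken to be `(seqlen_k - 1) / page_size`, and all helper names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Assumption: the largest valid page-table index for seqlen_k KV tokens.
inline std::int32_t max_page_table_idx_for(std::int32_t seqlen_k, std::int32_t page_size)
{
    assert(seqlen_k > 0 && page_size > 0);
    return (seqlen_k - 1) / page_size;
}

// Page-index clamp; std::min stands in for ck_tile::min here.
inline std::int32_t clamp_page_id(std::int32_t page_id, std::int32_t max_page_table_idx)
{
    return std::min(page_id, max_page_table_idx);
}

// Token clamp: clamp the token first, then convert it to a page index.
inline std::int32_t page_id_after_token_clamp(std::int32_t token,
                                              std::int32_t seqlen_k,
                                              std::int32_t page_size)
{
    return std::min(token, seqlen_k - 1) / page_size;
}

// Returns true when both clamps select the same page-table entry for every
// token in [0, max_token].
inline bool clamps_agree(std::int32_t seqlen_k, std::int32_t page_size, std::int32_t max_token)
{
    const std::int32_t max_idx = max_page_table_idx_for(seqlen_k, page_size);
    for (std::int32_t token = 0; token <= max_token; ++token)
    {
        if (clamp_page_id(token / page_size, max_idx) !=
            page_id_after_token_clamp(token, seqlen_k, page_size))
        {
            return false;
        }
    }
    return true;
}
```

The agreement follows from integer division by a positive page size being monotone for non-negative tokens, so clamping before or after the division picks the same entry.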
