[CK_TILE] fix(fmha): clamp paged KV lookups in batch prefill#3733
[CK_TILE] fix(fmha): clamp paged KV lookups in batch prefill#3733zhenhantech wants to merge 1 commit intoROCm:developfrom
Conversation
## Motivation FMHA batch prefill can prefetch V pages beyond the valid KV sequence length. For SGLang-style page tables, `kv_indptr` defines the logical page-table range, so reading `page_idx[128]` is invalid when only `page_idx[0..127]` exists. In the failing repro, V prefetch reaches this one-past-end lookup. It may read adjacent memory silently or fault at an allocation boundary, depending on the runtime allocation layout. ## Technical Details Expose the mask's KV sequence length through `GetXTotal()` and pass it to `load_physical_pages()`. Clamp each page-table lookup token to the last valid KV token before reading `page_idx`. ## Test Plan - Rebuild the JIT kernel and rerun the ROCm MHA batch prefill repro. - Validate the debug repro that confirmed the one-past-end `kv_page_indices` read. ## Test Result The original and debug repros pass after the fix. A pre-fix diagnostic trap on `page_id == 128` confirmed the illegal V prefetch lookup.
|
Thanks for catching this and putting up a fix — the OOB read in We ended up landing the V2 fix via a slightly different shape to stay consistent with the V3 batch prefill pipeline currently under development (ROCm/rocm-libraries#6054), which already threads Functionally equivalent to your fix; the parameter just comes from |
Motivation
FMHA batch prefill can prefetch V pages beyond the valid KV sequence length. For SGLang-style page tables,
kv_indptrdefines the logical page-table range, so readingpage_idx[128]is invalid when onlypage_idx[0..127]exists.In the failing repro, V prefetch reaches this one-past-end lookup. It may read adjacent memory silently or fault at an allocation boundary, depending on the runtime allocation layout.
Technical Details
Expose the mask's KV sequence length through
GetXTotal()and pass it toload_physical_pages(). Clamp each page-table lookup token to the last valid KV token before readingpage_idx.Test Plan
kv_page_indicesread.Test Result
The original and debug repros pass after the fix. A pre-fix diagnostic trap on
page_id == 128confirmed the illegal V prefetch lookup.Proposed changes
Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.
Checklist
Please put an
xinto the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-formaton all changed filesDiscussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered