Fix(spmd_paged_attention): per-slot sync events for K/V prefetch#708
Open
chenshengxin2026 wants to merge 1 commit intohw-native-sys:mainfrom
Open
Fix(spmd_paged_attention): per-slot sync events for K/V prefetch#708chenshengxin2026 wants to merge 1 commit intohw-native-sys:mainfrom
chenshengxin2026 wants to merge 1 commit intohw-native-sys:mainfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements ping-pong buffering and per-slot synchronization for the QK and PV steps in the paged attention kernel. Key changes include the introduction of event base constants, the use of tile arrays for double buffering, and the addition of explicit flag initialization and cleanup to manage pipeline dependencies. I have no feedback to provide as the existing review comments were purely explanatory or validating.
d2fba00 to
8b61747
Compare
- Replace shared EVENT_ID0/EVENT_ID1 with per-slot events for QK/PV
L1 (MTE1<->MTE2) and L0 (MTE1<->M) so each ping-pong slot has its
own RAW/WAR sync.
- Split QK/PV L0 left/right-tile addresses into two-entry arrays
with disjoint offsets so the slot index selects an independent
L0 region per iteration.
- Add a dedicated PV_PIJ_EVENT for the TPOP(pij) -> TMOV(aTile_PV)
path, decoupling pij synchronization from the V-load ping-pong.
- Move the next-block K TLOAD in the QK step to after sij record()
so the prefetch stays outside the C2V notification critical path.
- Set and drain all eight per-slot events at function entry/exit to
keep AIC pipeline state consistent across calls.
Verification:
task-submit --device auto --run "python -m pytest \
tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention \
--platform a2a3 --rounds 20"
8b61747 to
2f57f06
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fixes the intermittent precision failure in
tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention.MTE1<->MTE2) andL0 (
MTE1<->M) so each ping-pong slot has its own RAW/WAR sync.TPUSH(sij)viaPIPE_FIX -> PIPE_Sbeforerecord()soAIV
TPOP(sij)only observes a fully written GM FIFO.TLOADin the QK step to afterrecord()toavoid reintroducing a coarse
PIPE_ALLbarrier.Verification
task-submit --device auto --run "python -m pytest tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention --platform a2a3 --rounds 20"Performance
After fix:
Before fix:
Related: #704