diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/mix/paged_attention_parallel.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/mix/paged_attention_parallel.cpp
index 32e9e3302..3acb28a1f 100644
--- a/tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/mix/paged_attention_parallel.cpp
+++ b/tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention/kernels/mix/paged_attention_parallel.cpp
@@ -106,6 +106,11 @@ static constexpr uint16_t SIJ_FLAG_ID = 0;
 static constexpr uint16_t PIJ_FLAG_ID = 2;
 static constexpr uint16_t OI_FLAG_ID = 4;
 static constexpr uint8_t FIFO_DEPTH = 2;
+static constexpr int QK_L1_EVENT_BASE = 0;
+static constexpr int PV_L1_EVENT_BASE = 2;
+static constexpr int QK_L0_EVENT_BASE = 0;
+static constexpr int PV_L0_EVENT_BASE = 2;
+static constexpr int PV_PIJ_EVENT = 4;
 
 // Per-q_tile compile-time configuration: pipe types, slot sizes, UB/L1 layouts.
 // QT must be 16 or 64. SUB_QT = QT / 2 (each of AIV0/AIV1 handles half the rows).
@@ -159,51 +164,62 @@ template <
     typename LeftTile_QK, typename RightTile_QK, typename AccTile_QK>
 static __aicore__ void aic_qk_step(
     __gm__ bfloat16_t *key_base, uint64_t kv_block_id, uint64_t i, TileMatA_QK &aMatTile_QK, TileMatB_QK &bMatTile_QK_A,
-    TileMatB_QK &bMatTile_QK_B, LeftTile_QK &aTile_QK, RightTile_QK &bTile_QK, AccTile_QK &cTile_QK, SijPipeT &sij_pipe,
-    bool current_loaded = false, bool has_next = false, uint64_t next_kv_block_id = 0
+    TileMatB_QK &bMatTile_QK_B, LeftTile_QK (&aTile_QK)[2], RightTile_QK (&bTile_QK)[2], AccTile_QK &cTile_QK,
+    SijPipeT &sij_pipe, bool current_loaded = false, bool has_next = false, uint64_t next_kv_block_id = 0
 )
 {
+    int cur = static_cast<int>(i % 2);
+    event_t cur_l1_event = static_cast<event_t>(QK_L1_EVENT_BASE + cur);
+    event_t cur_l0_event = static_cast<event_t>(QK_L0_EVENT_BASE + cur);
     if (!current_loaded) {
         GlobalB_QK kjGlobal(key_base + kv_block_id * N * K);
-        if (i % 2 == 0) {
+        wait_flag(PIPE_MTE1, PIPE_MTE2, cur_l1_event);
+        if (cur == 0) {
             TLOAD(bMatTile_QK_A, kjGlobal);
         } else {
             TLOAD(bMatTile_QK_B, kjGlobal);
         }
+        set_flag(PIPE_MTE2, PIPE_MTE1, cur_l1_event);
     }
 
-    set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
-    wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
+    wait_flag(PIPE_M, PIPE_MTE1, cur_l0_event);
+    wait_flag(PIPE_MTE2, PIPE_MTE1, cur_l1_event);
 
-    TMOV(aTile_QK, aMatTile_QK);
-    if (i % 2 == 0) {
-        TMOV(bTile_QK, bMatTile_QK_A);
+    TMOV(aTile_QK[cur], aMatTile_QK);
+    if (cur == 0) {
+        TMOV(bTile_QK[cur], bMatTile_QK_A);
     } else {
-        TMOV(bTile_QK, bMatTile_QK_B);
+        TMOV(bTile_QK[cur], bMatTile_QK_B);
     }
+    set_flag(PIPE_MTE1, PIPE_MTE2, cur_l1_event);
+    set_flag(PIPE_MTE1, PIPE_M, cur_l0_event);
 
-    if (has_next) {
-        GlobalB_QK kjGlobalNext(key_base + next_kv_block_id * N * K);
-        if ((i + 1) % 2 == 0) {
-            TLOAD(bMatTile_QK_A, kjGlobalNext);
-        } else {
-            TLOAD(bMatTile_QK_B, kjGlobalNext);
-        }
-    }
+    wait_flag(PIPE_MTE1, PIPE_M, cur_l0_event);
 
-    set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
-    wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
-
-    TMATMUL(cTile_QK, aTile_QK, bTile_QK);
+    TMATMUL(cTile_QK, aTile_QK[cur], bTile_QK[cur]);
+    set_flag(PIPE_M, PIPE_MTE1, cur_l0_event);
 
     set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
     wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
-    // TPUSH sij (C2V): AccTile L0C -> GM. Ensure prior MTE3 is done,
-    // then push, then wait for MTE3 DMA to complete before signaling consumer.
+    // TPUSH sij (C2V): AccTile L0C -> GM. The C2V ready signal must
+    // be emitted only after the GM write is complete and visible.
    TPUSH(sij_pipe, cTile_QK);
     set_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
     wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
     sij_pipe.prod.record();
+
+    if (has_next) {
+        int next = static_cast<int>((i + 1) % 2);
+        event_t next_l1_event = static_cast<event_t>(QK_L1_EVENT_BASE + next);
+        GlobalB_QK kjGlobalNext(key_base + next_kv_block_id * N * K);
+        wait_flag(PIPE_MTE1, PIPE_MTE2, next_l1_event);
+        if (next == 0) {
+            TLOAD(bMatTile_QK_A, kjGlobalNext);
+        } else {
+            TLOAD(bMatTile_QK_B, kjGlobalNext);
+        }
+        set_flag(PIPE_MTE2, PIPE_MTE1, next_l1_event);
+    }
 }
 
 // Helper: PV matmul for block i — TPOP pij, load value, move to L0, matmul, TPUSH oi
@@ -212,45 +228,57 @@ template <
     typename TileMatB_PV, typename LeftTile_PV, typename RightTile_PV, typename AccTile_PV>
 static __aicore__ void aic_pv_step(
     __gm__ bfloat16_t *val_base, uint64_t kv_block_id, uint64_t i, PijMatTile &pijMatTile, TileMatB_PV &bMatTile_PV_A,
-    TileMatB_PV &bMatTile_PV_B, LeftTile_PV &aTile_PV, RightTile_PV &bTile_PV, AccTile_PV &cTile_PV, PijPipeT &pij_pipe,
-    OiPipeT &oi_pipe, bool current_loaded = false, bool has_next = false, uint64_t next_kv_block_id = 0
+    TileMatB_PV &bMatTile_PV_B, LeftTile_PV (&aTile_PV)[2], RightTile_PV (&bTile_PV)[2], AccTile_PV &cTile_PV,
+    PijPipeT &pij_pipe, OiPipeT &oi_pipe, bool current_loaded = false, bool has_next = false,
+    uint64_t next_kv_block_id = 0
 )
 {
+    int cur = static_cast<int>(i % 2);
+    event_t cur_l1_event = static_cast<event_t>(PV_L1_EVENT_BASE + cur);
+    event_t cur_l0_event = static_cast<event_t>(PV_L0_EVENT_BASE + cur);
     if (!current_loaded) {
         GlobalB_PV vjGlobal(val_base + kv_block_id * N * K);
-        if (i % 2 == 0) {
+        wait_flag(PIPE_MTE1, PIPE_MTE2, cur_l1_event);
+        if (cur == 0) {
             TLOAD(bMatTile_PV_A, vjGlobal);
         } else {
             TLOAD(bMatTile_PV_B, vjGlobal);
         }
+        set_flag(PIPE_MTE2, PIPE_MTE1, cur_l1_event);
     }
 
     TPOP(pij_pipe, pijMatTile);
-    // PV step uses EVENT_ID1 (QK step uses EVENT_ID0) to avoid flag aliasing
-    // when pipe_barrier(PIPE_ALL) is removed between steps.
-    set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
-    wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
+    set_flag(PIPE_MTE2, PIPE_MTE1, static_cast<event_t>(PV_PIJ_EVENT));
+    wait_flag(PIPE_M, PIPE_MTE1, cur_l0_event);
+    wait_flag(PIPE_MTE2, PIPE_MTE1, cur_l1_event);
+    wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast<event_t>(PV_PIJ_EVENT));
 
-    TMOV(aTile_PV, pijMatTile);
-    if (i % 2 == 0) {
-        TMOV(bTile_PV, bMatTile_PV_A);
+    TMOV(aTile_PV[cur], pijMatTile);
+    if (cur == 0) {
+        TMOV(bTile_PV[cur], bMatTile_PV_A);
     } else {
-        TMOV(bTile_PV, bMatTile_PV_B);
+        TMOV(bTile_PV[cur], bMatTile_PV_B);
     }
+    set_flag(PIPE_MTE1, PIPE_MTE2, cur_l1_event);
+    set_flag(PIPE_MTE1, PIPE_M, cur_l0_event);
 
     if (has_next) {
+        int next = static_cast<int>((i + 1) % 2);
+        event_t next_l1_event = static_cast<event_t>(PV_L1_EVENT_BASE + next);
         GlobalB_PV vjGlobalNext(val_base + next_kv_block_id * N * K);
-        if ((i + 1) % 2 == 0) {
+        wait_flag(PIPE_MTE1, PIPE_MTE2, next_l1_event);
+        if (next == 0) {
             TLOAD(bMatTile_PV_A, vjGlobalNext);
         } else {
             TLOAD(bMatTile_PV_B, vjGlobalNext);
         }
+        set_flag(PIPE_MTE2, PIPE_MTE1, next_l1_event);
     }
 
-    set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
-    wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
+    wait_flag(PIPE_MTE1, PIPE_M, cur_l0_event);
 
-    TMATMUL(cTile_PV, aTile_PV, bTile_PV);
+    TMATMUL(cTile_PV, aTile_PV[cur], bTile_PV[cur]);
+    set_flag(PIPE_M, PIPE_MTE1, cur_l0_event);
 
     set_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
     wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
@@ -297,11 +325,13 @@ static __aicore__ void aic_process_blocks(
     TASSIGN(bMatTile_QK_A, 0x20000);
     TASSIGN(bMatTile_QK_B, 0x20000 + kQKBBytes);
 
-    LeftTile_QK aTile_QK;
-    RightTile_QK bTile_QK;
+    LeftTile_QK aTile_QK[2];
+    RightTile_QK bTile_QK[2];
     AccTile_QK cTile_QK;
-    TASSIGN(aTile_QK, 0x0);
-    TASSIGN(bTile_QK, 0x0);
+    TASSIGN(aTile_QK[0], 0x0);
+    TASSIGN(aTile_QK[1], M * K * static_cast<uint32_t>(sizeof(bfloat16_t)));
+    TASSIGN(bTile_QK[0], 0x0);
+    TASSIGN(bTile_QK[1], kQKBBytes);
     TASSIGN(cTile_QK, 0x0);
 
     PijMatTile pijMatTile;
@@ -309,16 +339,27 @@ static __aicore__ void aic_process_blocks(
     TASSIGN(bMatTile_PV_A, Cfg::PIJ_L1_BASE + Cfg::PIJ_L1_SIZE);
     TASSIGN(bMatTile_PV_B, Cfg::PIJ_L1_BASE + Cfg::PIJ_L1_SIZE + kPVBBytes);
 
-    LeftTile_PV aTile_PV;
-    RightTile_PV bTile_PV;
+    LeftTile_PV aTile_PV[2];
+    RightTile_PV bTile_PV[2];
     AccTile_PV cTile_PV;
-    TASSIGN(aTile_PV, 0x0);
-    TASSIGN(bTile_PV, 0x0);
+    TASSIGN(aTile_PV[0], 0x0);
+    TASSIGN(aTile_PV[1], M * N * static_cast<uint32_t>(sizeof(bfloat16_t)));
+    TASSIGN(bTile_PV[0], 0x0);
+    TASSIGN(bTile_PV[1], kPVBBytes);
     TASSIGN(cTile_PV, 0x0);
 
     GlobalA_QK qiGlobal(qi_base);
     TLOAD(aMatTile_QK, qiGlobal);
 
+    set_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(QK_L1_EVENT_BASE));
+    set_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(QK_L1_EVENT_BASE + 1));
+    set_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(PV_L1_EVENT_BASE));
+    set_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(PV_L1_EVENT_BASE + 1));
+    set_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(QK_L0_EVENT_BASE));
+    set_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(QK_L0_EVENT_BASE + 1));
+    set_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(PV_L0_EVENT_BASE));
+    set_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(PV_L0_EVENT_BASE + 1));
+
     if (n_blocks == 1) {
         // Degenerate case: no pipeline overlap possible
         uint64_t block_id = static_cast<uint64_t>(bt[bt_offset]);
@@ -358,6 +399,15 @@ static __aicore__ void aic_process_blocks(
             cTile_PV, pij_pipe, oi_pipe, n_blocks > 1
         );
     }
+
+    wait_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(QK_L1_EVENT_BASE));
+    wait_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(QK_L1_EVENT_BASE + 1));
+    wait_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(PV_L1_EVENT_BASE));
+    wait_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(PV_L1_EVENT_BASE + 1));
+    wait_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(QK_L0_EVENT_BASE));
+    wait_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(QK_L0_EVENT_BASE + 1));
+    wait_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(PV_L0_EVENT_BASE));
+    wait_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(PV_L0_EVENT_BASE + 1));
 }
 
 // ============================================================================