Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ static constexpr uint16_t SIJ_FLAG_ID = 0;
static constexpr uint16_t PIJ_FLAG_ID = 2;
static constexpr uint16_t OI_FLAG_ID = 4;
static constexpr uint8_t FIFO_DEPTH = 2;
static constexpr int QK_L1_EVENT_BASE = 0;
static constexpr int PV_L1_EVENT_BASE = 2;
static constexpr int QK_L0_EVENT_BASE = 0;
static constexpr int PV_L0_EVENT_BASE = 2;
static constexpr int PV_PIJ_EVENT = 4;

// Per-q_tile compile-time configuration: pipe types, slot sizes, UB/L1 layouts.
// QT must be 16 or 64. SUB_QT = QT / 2 (each of AIV0/AIV1 handles half the rows).
Expand Down Expand Up @@ -159,51 +164,62 @@ template <
typename LeftTile_QK, typename RightTile_QK, typename AccTile_QK>
static __aicore__ void aic_qk_step(
__gm__ bfloat16_t *key_base, uint64_t kv_block_id, uint64_t i, TileMatA_QK &aMatTile_QK, TileMatB_QK &bMatTile_QK_A,
TileMatB_QK &bMatTile_QK_B, LeftTile_QK &aTile_QK, RightTile_QK &bTile_QK, AccTile_QK &cTile_QK, SijPipeT &sij_pipe,
bool current_loaded = false, bool has_next = false, uint64_t next_kv_block_id = 0
TileMatB_QK &bMatTile_QK_B, LeftTile_QK (&aTile_QK)[2], RightTile_QK (&bTile_QK)[2], AccTile_QK &cTile_QK,
SijPipeT &sij_pipe, bool current_loaded = false, bool has_next = false, uint64_t next_kv_block_id = 0
) {
int cur = static_cast<int>(i % 2);
event_t cur_l1_event = static_cast<event_t>(QK_L1_EVENT_BASE + cur);
event_t cur_l0_event = static_cast<event_t>(QK_L0_EVENT_BASE + cur);
if (!current_loaded) {
GlobalB_QK kjGlobal(key_base + kv_block_id * N * K);
if (i % 2 == 0) {
wait_flag(PIPE_MTE1, PIPE_MTE2, cur_l1_event);
if (cur == 0) {
TLOAD(bMatTile_QK_A, kjGlobal);
} else {
TLOAD(bMatTile_QK_B, kjGlobal);
}
set_flag(PIPE_MTE2, PIPE_MTE1, cur_l1_event);
}

set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
wait_flag(PIPE_M, PIPE_MTE1, cur_l0_event);
wait_flag(PIPE_MTE2, PIPE_MTE1, cur_l1_event);

TMOV(aTile_QK, aMatTile_QK);
if (i % 2 == 0) {
TMOV(bTile_QK, bMatTile_QK_A);
TMOV(aTile_QK[cur], aMatTile_QK);
if (cur == 0) {
TMOV(bTile_QK[cur], bMatTile_QK_A);
} else {
TMOV(bTile_QK, bMatTile_QK_B);
TMOV(bTile_QK[cur], bMatTile_QK_B);
}
set_flag(PIPE_MTE1, PIPE_MTE2, cur_l1_event);
set_flag(PIPE_MTE1, PIPE_M, cur_l0_event);

if (has_next) {
GlobalB_QK kjGlobalNext(key_base + next_kv_block_id * N * K);
if ((i + 1) % 2 == 0) {
TLOAD(bMatTile_QK_A, kjGlobalNext);
} else {
TLOAD(bMatTile_QK_B, kjGlobalNext);
}
}
wait_flag(PIPE_MTE1, PIPE_M, cur_l0_event);

set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);

TMATMUL(cTile_QK, aTile_QK, bTile_QK);
TMATMUL(cTile_QK, aTile_QK[cur], bTile_QK[cur]);
set_flag(PIPE_M, PIPE_MTE1, cur_l0_event);

set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);

// TPUSH sij (C2V): AccTile L0C -> GM. Ensure prior MTE3 is done,
// then push, then wait for MTE3 DMA to complete before signaling consumer.
// TPUSH sij (C2V): AccTile L0C -> GM. The C2V ready signal must
// be emitted only after the GM write is complete and visible.
TPUSH<SijPipeT, AccTile_QK, TileSplitAxis::TILE_UP_DOWN>(sij_pipe, cTile_QK);
set_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7);
sij_pipe.prod.record();

if (has_next) {
int next = static_cast<int>((i + 1) % 2);
event_t next_l1_event = static_cast<event_t>(QK_L1_EVENT_BASE + next);
GlobalB_QK kjGlobalNext(key_base + next_kv_block_id * N * K);
wait_flag(PIPE_MTE1, PIPE_MTE2, next_l1_event);
if (next == 0) {
TLOAD(bMatTile_QK_A, kjGlobalNext);
} else {
TLOAD(bMatTile_QK_B, kjGlobalNext);
}
set_flag(PIPE_MTE2, PIPE_MTE1, next_l1_event);
}
}

// Helper: PV matmul for block i — TPOP pij, load value, move to L0, matmul, TPUSH oi
Expand All @@ -212,45 +228,57 @@ template <
typename TileMatB_PV, typename LeftTile_PV, typename RightTile_PV, typename AccTile_PV>
static __aicore__ void aic_pv_step(
__gm__ bfloat16_t *val_base, uint64_t kv_block_id, uint64_t i, PijMatTile &pijMatTile, TileMatB_PV &bMatTile_PV_A,
TileMatB_PV &bMatTile_PV_B, LeftTile_PV &aTile_PV, RightTile_PV &bTile_PV, AccTile_PV &cTile_PV, PijPipeT &pij_pipe,
OiPipeT &oi_pipe, bool current_loaded = false, bool has_next = false, uint64_t next_kv_block_id = 0
TileMatB_PV &bMatTile_PV_B, LeftTile_PV (&aTile_PV)[2], RightTile_PV (&bTile_PV)[2], AccTile_PV &cTile_PV,
PijPipeT &pij_pipe, OiPipeT &oi_pipe, bool current_loaded = false, bool has_next = false,
uint64_t next_kv_block_id = 0
) {
int cur = static_cast<int>(i % 2);
event_t cur_l1_event = static_cast<event_t>(PV_L1_EVENT_BASE + cur);
event_t cur_l0_event = static_cast<event_t>(PV_L0_EVENT_BASE + cur);
if (!current_loaded) {
GlobalB_PV vjGlobal(val_base + kv_block_id * N * K);
if (i % 2 == 0) {
wait_flag(PIPE_MTE1, PIPE_MTE2, cur_l1_event);
if (cur == 0) {
TLOAD(bMatTile_PV_A, vjGlobal);
} else {
TLOAD(bMatTile_PV_B, vjGlobal);
}
set_flag(PIPE_MTE2, PIPE_MTE1, cur_l1_event);
}

TPOP<PijPipeT, PijMatTile, TileSplitAxis::TILE_NO_SPLIT>(pij_pipe, pijMatTile);

// PV step uses EVENT_ID1 (QK step uses EVENT_ID0) to avoid flag aliasing
// when pipe_barrier(PIPE_ALL) is removed between steps.
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
set_flag(PIPE_MTE2, PIPE_MTE1, static_cast<event_t>(PV_PIJ_EVENT));
wait_flag(PIPE_M, PIPE_MTE1, cur_l0_event);
wait_flag(PIPE_MTE2, PIPE_MTE1, cur_l1_event);
wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast<event_t>(PV_PIJ_EVENT));

TMOV(aTile_PV, pijMatTile);
if (i % 2 == 0) {
TMOV(bTile_PV, bMatTile_PV_A);
TMOV(aTile_PV[cur], pijMatTile);
if (cur == 0) {
TMOV(bTile_PV[cur], bMatTile_PV_A);
} else {
TMOV(bTile_PV, bMatTile_PV_B);
TMOV(bTile_PV[cur], bMatTile_PV_B);
}
set_flag(PIPE_MTE1, PIPE_MTE2, cur_l1_event);
set_flag(PIPE_MTE1, PIPE_M, cur_l0_event);

if (has_next) {
int next = static_cast<int>((i + 1) % 2);
event_t next_l1_event = static_cast<event_t>(PV_L1_EVENT_BASE + next);
GlobalB_PV vjGlobalNext(val_base + next_kv_block_id * N * K);
if ((i + 1) % 2 == 0) {
wait_flag(PIPE_MTE1, PIPE_MTE2, next_l1_event);
if (next == 0) {
TLOAD(bMatTile_PV_A, vjGlobalNext);
} else {
TLOAD(bMatTile_PV_B, vjGlobalNext);
}
set_flag(PIPE_MTE2, PIPE_MTE1, next_l1_event);
}

set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
wait_flag(PIPE_MTE1, PIPE_M, cur_l0_event);

TMATMUL(cTile_PV, aTile_PV, bTile_PV);
TMATMUL(cTile_PV, aTile_PV[cur], bTile_PV[cur]);
set_flag(PIPE_M, PIPE_MTE1, cur_l0_event);

set_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
Expand Down Expand Up @@ -297,28 +325,41 @@ static __aicore__ void aic_process_blocks(
TASSIGN(bMatTile_QK_A, 0x20000);
TASSIGN(bMatTile_QK_B, 0x20000 + kQKBBytes);

LeftTile_QK aTile_QK;
RightTile_QK bTile_QK;
LeftTile_QK aTile_QK[2];
RightTile_QK bTile_QK[2];
AccTile_QK cTile_QK;
TASSIGN(aTile_QK, 0x0);
TASSIGN(bTile_QK, 0x0);
TASSIGN(aTile_QK[0], 0x0);
TASSIGN(aTile_QK[1], M * K * static_cast<int>(sizeof(bfloat16_t)));
TASSIGN(bTile_QK[0], 0x0);
TASSIGN(bTile_QK[1], kQKBBytes);
TASSIGN(cTile_QK, 0x0);

PijMatTile pijMatTile;
TileMatB_PV bMatTile_PV_A, bMatTile_PV_B;
TASSIGN(bMatTile_PV_A, Cfg::PIJ_L1_BASE + Cfg::PIJ_L1_SIZE);
TASSIGN(bMatTile_PV_B, Cfg::PIJ_L1_BASE + Cfg::PIJ_L1_SIZE + kPVBBytes);

LeftTile_PV aTile_PV;
RightTile_PV bTile_PV;
LeftTile_PV aTile_PV[2];
RightTile_PV bTile_PV[2];
AccTile_PV cTile_PV;
TASSIGN(aTile_PV, 0x0);
TASSIGN(bTile_PV, 0x0);
TASSIGN(aTile_PV[0], 0x0);
TASSIGN(aTile_PV[1], M * N * static_cast<int>(sizeof(bfloat16_t)));
TASSIGN(bTile_PV[0], 0x0);
TASSIGN(bTile_PV[1], kPVBBytes);
TASSIGN(cTile_PV, 0x0);

GlobalA_QK qiGlobal(qi_base);
TLOAD(aMatTile_QK, qiGlobal);

set_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(QK_L1_EVENT_BASE));
set_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(QK_L1_EVENT_BASE + 1));
set_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(PV_L1_EVENT_BASE));
set_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(PV_L1_EVENT_BASE + 1));
set_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(QK_L0_EVENT_BASE));
set_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(QK_L0_EVENT_BASE + 1));
set_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(PV_L0_EVENT_BASE));
set_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(PV_L0_EVENT_BASE + 1));

if (n_blocks == 1) {
// Degenerate case: no pipeline overlap possible
uint64_t block_id = static_cast<uint64_t>(bt[bt_offset]);
Expand Down Expand Up @@ -358,6 +399,15 @@ static __aicore__ void aic_process_blocks(
cTile_PV, pij_pipe, oi_pipe, n_blocks > 1
);
}

wait_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(QK_L1_EVENT_BASE));
wait_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(QK_L1_EVENT_BASE + 1));
wait_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(PV_L1_EVENT_BASE));
wait_flag(PIPE_MTE1, PIPE_MTE2, static_cast<event_t>(PV_L1_EVENT_BASE + 1));
wait_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(QK_L0_EVENT_BASE));
wait_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(QK_L0_EVENT_BASE + 1));
wait_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(PV_L0_EVENT_BASE));
wait_flag(PIPE_M, PIPE_MTE1, static_cast<event_t>(PV_L0_EVENT_BASE + 1));
}

// ============================================================================
Expand Down
Loading