Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/ck_tile/ops/fmha/block/block_masking.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,8 @@ struct GenericAttentionMask
}
}

// Read-only accessor: returns the stored x_total member by value.
// NOTE(review): x_total is presumably the full extent of the mask along its
// x axis — confirm against the constructor, which is not visible here.
CK_TILE_HOST_DEVICE constexpr index_t GetXTotal() const
{
    return x_total;
}

private:
index_t y, x, sink;
index_t y_total, x_total;
Expand Down Expand Up @@ -536,6 +538,8 @@ struct SimplifiedGenericAttentionMask
}
}

// Read-only accessor: returns the stored x_total member by value.
// NOTE(review): x_total is presumably the full extent of the mask along its
// x axis — confirm against the constructor, which is not visible here.
CK_TILE_HOST_DEVICE constexpr index_t GetXTotal() const
{
    return x_total;
}

private:
index_t y, x, sink;
index_t y_total, x_total;
Expand Down Expand Up @@ -722,6 +726,8 @@ struct SimplifiedRatioAttentionMask
}
}

// Read-only accessor: returns the stored x_total member by value.
// NOTE(review): x_total is presumably the full extent of the mask along its
// x axis — confirm against the constructor, which is not visible here.
CK_TILE_HOST_DEVICE constexpr index_t GetXTotal() const
{
    return x_total;
}

private:
index_t y, x;
index_t y_total, x_total;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ template <typename IndexArrayType,
CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
const CoordVecType& coord_vec,
index_t global_seq_offset,
index_t valid_seqlen_k,
IndexArrayType& physical_pages)
{
static constexpr index_t kLog2PageSize = [] {
Expand All @@ -49,14 +50,18 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
}();

const index_t& thread_coord_start = coord_vec[kCoordAxis];
const index_t last_valid_token_idx = valid_seqlen_k > 0 ? valid_seqlen_k - 1 : 0;
const auto clamp_token_idx = [&](index_t token_idx) {
return ck_tile::min(token_idx, last_valid_token_idx);
};

if constexpr(kIsKcache)
{
// K cache: per-token lookup (all tokens may be on different pages)
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t page_id = global_token_idx >> kLog2PageSize;
const index_t page_id = clamp_token_idx(global_token_idx) >> kLog2PageSize;
physical_pages[k0] = page_idx[page_id];
});
}
Expand All @@ -75,7 +80,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
physical_pages[k0] = page_idx[global_token_idx];
physical_pages[k0] = page_idx[clamp_token_idx(global_token_idx)];
});
}
else if constexpr(kVTileCrossesPages)
Expand All @@ -85,16 +90,16 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t page_id = global_token_idx >> kLog2PageSize;
const index_t page_id = clamp_token_idx(global_token_idx) >> kLog2PageSize;
physical_pages[k0] = page_idx[page_id];
});
}
else
{
// V tile fully contained in one page: lane0 lookup, broadcast to all
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
const index_t lane0_page_id =
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
const index_t lane0_token_idx = global_seq_offset + lane0_start + kLoopStart;
const index_t lane0_page_id = clamp_token_idx(lane0_token_idx) >> kLog2PageSize;
const index_t shared_physical_page = page_idx[lane0_page_id];

static_for<0, kLoopCount, 1>{}(
Expand Down Expand Up @@ -559,6 +564,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const index_t valid_seqlen_k = mask.GetXTotal();
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
const auto num_total_loop =
Expand Down Expand Up @@ -611,7 +617,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kN0 / NRepeat,
kKVMemoryLayout,
true,
kN0>(page_idx, k_coord, current_seq_k, k_physical_pages);
kN0>(
page_idx, k_coord, current_seq_k, valid_seqlen_k, k_physical_pages);

kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
decltype(k_coord),
Expand Down Expand Up @@ -839,7 +846,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
1,
kKVMemoryLayout,
false,
kN0>(page_idx, v_coord, current_seq_k, v_physical_pages_k2);
kN0>(page_idx,
v_coord,
current_seq_k,
valid_seqlen_k,
v_physical_pages_k2);

// Copy to merged array
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
Expand All @@ -859,7 +870,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
1,
kKVMemoryLayout,
false,
kN0>(page_idx, v_coord, current_seq_k, v_physical_pages);
kN0>(
page_idx, v_coord, current_seq_k, valid_seqlen_k, v_physical_pages);
}
};

Expand Down Expand Up @@ -1516,7 +1528,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kN0 / NRepeat,
kKVMemoryLayout,
true,
kN0>(page_idx, k_coord, current_seq_k, k_physical_pages);
kN0>(
page_idx, k_coord, current_seq_k, valid_seqlen_k, k_physical_pages);

kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
decltype(k_coord),
Expand Down