diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 4ffb303812..06a0539ebf 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -299,6 +299,8 @@ struct GenericAttentionMask } } + CK_TILE_HOST_DEVICE constexpr auto GetXTotal() const { return x_total; } + private: index_t y, x, sink; index_t y_total, x_total; @@ -536,6 +538,8 @@ struct SimplifiedGenericAttentionMask } } + CK_TILE_HOST_DEVICE constexpr auto GetXTotal() const { return x_total; } + private: index_t y, x, sink; index_t y_total, x_total; @@ -722,6 +726,8 @@ struct SimplifiedRatioAttentionMask } } + CK_TILE_HOST_DEVICE constexpr auto GetXTotal() const { return x_total; } + private: index_t y, x; index_t y_total, x_total; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 8aa6d17dc3..09a8354e7d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -35,6 +35,7 @@ template 0 ? valid_seqlen_k - 1 : 0; + const auto clamp_token_idx = [&](index_t token_idx) { + return ck_tile::min(token_idx, last_valid_token_idx); + }; if constexpr(kIsKcache) { @@ -56,7 +61,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; + const index_t page_id = clamp_token_idx(global_token_idx) >> kLog2PageSize; physical_pages[k0] = page_idx[page_id]; }); } @@ -75,7 +80,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - physical_pages[k0] = page_idx[global_token_idx]; + physical_pages[k0] = page_idx[clamp_token_idx(global_token_idx)]; }); } else if constexpr(kVTileCrossesPages) @@ -85,7 +90,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; + const index_t page_id = clamp_token_idx(global_token_idx) >> kLog2PageSize; physical_pages[k0] = page_idx[page_id]; }); } @@ -93,8 +98,8 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, { // V tile fully contained in one page: lane0 lookup, broadcast to all const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start); - const index_t lane0_page_id = - (global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize; + const index_t lane0_token_idx = global_seq_offset + lane0_start + kLoopStart; + const index_t lane0_page_id = clamp_token_idx(lane0_token_idx) >> kLog2PageSize; const index_t shared_physical_page = page_idx[lane0_page_id]; static_for<0, kLoopCount, 1>{}( @@ -559,6 +564,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); + const index_t valid_seqlen_k = mask.GetXTotal(); const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0; const auto num_total_loop = @@ -611,7 +617,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kN0 / NRepeat, kKVMemoryLayout, true, - kN0>(page_idx, k_coord, current_seq_k, k_physical_pages); + kN0>( + page_idx, k_coord, current_seq_k, valid_seqlen_k, k_physical_pages); kv_offset_array_transform, decltype(k_coord), @@ -839,7 +846,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync 1, kKVMemoryLayout, false, - kN0>(page_idx, v_coord, current_seq_k, v_physical_pages_k2); + kN0>(page_idx, + v_coord, + current_seq_k, + valid_seqlen_k, + v_physical_pages_k2); // Copy to merged array static_for<0, V_KIterInner, 1>{}([&](auto k1) { @@ -859,7 +870,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync 1, kKVMemoryLayout, false, - kN0>(page_idx, v_coord, current_seq_k, v_physical_pages); + kN0>( + page_idx, v_coord, current_seq_k, valid_seqlen_k, v_physical_pages); } }; @@ -1516,7 +1528,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kN0 / NRepeat, kKVMemoryLayout, true, - kN0>(page_idx, k_coord, current_seq_k, k_physical_pages); + kN0>( + page_idx, k_coord, current_seq_k, valid_seqlen_k, k_physical_pages); kv_offset_array_transform, decltype(k_coord),