From 1973307c2c7e73452e37872580f46d6f83da0e76 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 3 Mar 2026 09:48:47 +0000 Subject: [PATCH 01/11] Implement pre-sorting, caching and contigous warp processing in group_index_select kernel --- fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h | 7 +- .../utils/rocm/sparse_group_utils.h | 108 +++++ .../src/sparse_ops/sparse_group_index.cu | 379 +++++++++++++----- fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp | 26 +- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 123 +++++- 5 files changed, 508 insertions(+), 135 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h index 0e7bd37234..edf9c2ee57 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h +++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h @@ -1078,6 +1078,8 @@ void group_index_select_or_add_cuda( const int64_t* input_ptrs, const int64_t* output_ptrs, const int64_t* indices_ptrs, + const int64_t* sorted_indices_ptrs, + const int64_t* reverse_indices_ptrs, const int64_t* warp_offsets_group, const int32_t* num_cols_group, const c10::ScalarType& input_scalar_type, @@ -1087,7 +1089,10 @@ void group_index_select_or_add_cuda( const int64_t total_num_warps, const int group_size, const bool use_index_select, - const bool use_var_cols); + const bool use_var_cols, + const bool use_contiguous_warps, + const bool use_cache, + const bool use_packed_rows); int get_group_index_select_cols_per_warp(); diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h index e32d346897..5ab1d1a607 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h @@ -11,6 +11,12 @@ #include #include +#include +#include + +#include +#include + #include "fbgemm_gpu/utils/cuda_prelude.cuh" namespace fbgemm_gpu::rocm { @@ -68,5 +74,107 @@ __device__ __forceinline__ void 
warp_upper_bound( *cached_boundary = cached_result; } +std::tuple sort_indices_with_rocprim(const at::Tensor& indices) { + TORCH_CHECK( + indices.dim() == 1, + "sort_indices_with_rocprim expects a 1D tensor, got ", + indices.dim()); + TORCH_CHECK( + indices.is_cuda(), + "sort_indices_with_rocprim expects a CUDA tensor for indices"); + + CUDA_DEVICE_GUARD(indices); + auto contiguous_indices = indices.contiguous(); + auto sorted_indices = at::empty_like(contiguous_indices); + auto reverse_indices = at::empty( + contiguous_indices.sizes(), + contiguous_indices.options().dtype(at::kLong)); + auto original_positions = at::arange( + contiguous_indices.numel(), + contiguous_indices.options().dtype(at::kLong)); + + const auto numel = contiguous_indices.numel(); + if (numel == 0) { + return {sorted_indices, reverse_indices}; + } + + TORCH_CHECK( + numel <= static_cast(std::numeric_limits::max()), + "sort_indices_with_rocprim only supports up to INT_MAX elements"); + + const int num_items = static_cast(numel); + auto stream = at::cuda::getCurrentCUDAStream(); + + const auto scalar_type = contiguous_indices.scalar_type(); + auto dispatch = [&](auto index_value_placeholder) { + using index_t = decltype(index_value_placeholder); + auto keys_in = contiguous_indices.data_ptr(); + auto keys_out = sorted_indices.data_ptr(); + auto values_in = original_positions.data_ptr(); + auto values_out = reverse_indices.data_ptr(); + + size_t temp_storage_bytes = 0; + // Selected empirically + constexpr int k_merge_sort_threshold = 400'000; + + using sort_config = rocprim::radix_sort_config< + rocprim::default_config, + rocprim::default_config, + rocprim::default_config, + k_merge_sort_threshold>; + AT_CUDA_CHECK(rocprim::radix_sort_pairs( + nullptr, + temp_storage_bytes, + keys_in, + keys_out, + values_in, + values_out, + num_items, + 0, + sizeof(index_t) * 8, + stream, + false)); + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, + 
contiguous_indices.options().dtype(at::kByte)); + AT_CUDA_CHECK(rocprim::radix_sort_pairs( + temp_storage.data_ptr(), + temp_storage_bytes, + keys_in, + keys_out, + values_in, + values_out, + num_items, + 0, + sizeof(index_t) * 8, + stream, + false)); + }; + + switch (scalar_type) { + case at::ScalarType::Byte: + dispatch(uint8_t{}); + break; + case at::ScalarType::Char: + dispatch(int8_t{}); + break; + case at::ScalarType::Short: + dispatch(int16_t{}); + break; + case at::ScalarType::Int: + dispatch(int32_t{}); + break; + case at::ScalarType::Long: + dispatch(int64_t{}); + break; + default: + TORCH_CHECK( + false, + "sort_indices_with_rocprim only supports integral index dtypes, got ", + scalar_type); + } + + return {sorted_indices, reverse_indices}; +} } // namespace } // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index ddec5d0e01..1a1f122f77 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -5,6 +5,10 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
*/ +#include +#include +#include +#include #include "common.cuh" #ifdef USE_ROCM @@ -45,6 +49,10 @@ template < typename scalar_t, bool USE_INDEX_SELECT, bool USE_VAR_COLS, + bool USE_CONTIGUOUS_WARPS, + bool USE_SORTED_INDICES, + bool USE_CACHE, + bool USE_PACKED_ROWS, int UNROLL_FACTOR, int COLS_PER_WARP, int LOG_COLS_PER_WARP> @@ -53,10 +61,17 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const int64_t* input_ptrs, const int64_t* output_ptrs, const int64_t* indices_ptrs, + const int64_t* reverse_indices_ptrs, const int64_t* warp_offsets_group, const int32_t* num_cols_group, const int64_t num_work_rows, // number of rows to work on per member const int64_t group_size) { + static_assert( + !USE_CACHE || UNROLL_FACTOR == 1, + "Cache path only supports UNROLL_FACTOR == 1"); + + constexpr index_t kInvalidIdx = std::numeric_limits::max(); + const auto total_num_warps = warp_offsets_group[group_size]; int32_t num_cols = 0; int32_t warps_per_row = 0; @@ -66,14 +81,50 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; } -#ifdef USE_ROCM - int cached_member_id = -1; - int64_t cached_upper_bound = -1; -#endif + [[maybe_unused]] int cached_member_id = -1; + [[maybe_unused]] int64_t cached_upper_bound = -1; + [[maybe_unused]] int32_t last_member_id_for_accum = -1; + [[maybe_unused]] int32_t last_member_num_cols = 0; + [[maybe_unused]] scalar_t* last_member_output_tile = nullptr; + + int64_t start_warp_id = 0; + int64_t warp_end = 0; + int64_t warp_stride = 0; + + if constexpr (USE_CONTIGUOUS_WARPS) { + const int64_t linear_warp_id = threadIdx.y * gridDim.x + blockIdx.x; + const int64_t warps_per_launch = gridDim.x * blockDim.y; + const int64_t chunk_size = + (total_num_warps + warps_per_launch - 1) / warps_per_launch; + start_warp_id = linear_warp_id * chunk_size; + warp_end = start_warp_id + chunk_size < total_num_warps + ? 
start_warp_id + chunk_size + : total_num_warps; + warp_stride = 1; + } else { + start_warp_id = threadIdx.y * gridDim.x + blockIdx.x; + warp_end = total_num_warps; + warp_stride = gridDim.x * blockDim.y; + } - for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; - warp_id < total_num_warps; - warp_id += gridDim.x * blockDim.y) { + auto storage = scalar_t(0); + auto cached_idx = kInvalidIdx; + // TODO: Account for UNROLL_FACTOR + auto flush_cache_accumulator = [&](scalar_t* target_output, + int32_t target_num_cols) { + if constexpr (!USE_INDEX_SELECT && USE_CACHE) { + if (target_output && cached_idx != kInvalidIdx) { + gpuAtomicAddNoReturn( + &target_output[cached_idx * target_num_cols], + storage); + cached_idx = kInvalidIdx; + } + } + }; + + for (int64_t warp_id = start_warp_id; warp_id < warp_end; warp_id += warp_stride) { + bool use_small_dim_path = false; + int rows_per_warp_small = 0; int32_t member_id = 0; int32_t member_warp_id = 0; if constexpr (USE_VAR_COLS) { @@ -108,92 +159,149 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( // All columns are the same member_id = warp_id / (warps_per_row * num_work_rows); member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); -#ifdef USE_ROCM + } + + if constexpr (USE_PACKED_ROWS) { if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) { // Need to ensure that [member_id] and [member_warp_id] are calculated // correctly for the small embedding dimension path below - const auto rows_per_warp = COLS_PER_WARP / num_cols; - const auto warps_per_member = - DIV_ROUND_UP(num_work_rows, rows_per_warp); - member_id = warp_id / warps_per_member; - member_warp_id = warp_id % warps_per_member; + rows_per_warp_small = COLS_PER_WARP / num_cols; + if constexpr (!USE_VAR_COLS) { + const auto warps_per_member = + (num_work_rows + rows_per_warp_small - 1) / rows_per_warp_small; + member_id = warp_id / warps_per_member; + member_warp_id = warp_id % warps_per_member; + } + 
use_small_dim_path = true; } -#endif // USE_ROCM } -#ifdef USE_ROCM - if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) { - // Optimized path for small embedding dimensions - // Each warp processes 'rows_per_warp' rows - const auto rows_per_warp = COLS_PER_WARP / num_cols; - const int64_t start_row = member_warp_id * rows_per_warp; - - // Since we are processing multiple rows within the warp, we need to - // map each lane to a specific row, in addition to the column - const auto local_row = (threadIdx.x * UNROLL_FACTOR) / - num_cols; // the row ID within the set of rows handled by this warp - const auto col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols; - const int64_t current_row = start_row + - local_row; // the actual row within the table processed by this lane - - // local_row may be out of bounds for the last few lanes in the warp if - // [COLS_PER_WARP % num_cols != 0] and we also need to confirm that we are - // within num_work_rows - if (local_row < rows_per_warp && current_row < num_work_rows) { - scalar_t* input = - reinterpret_cast(input_ptrs[member_id]) + col_offset; - scalar_t* output = - reinterpret_cast(output_ptrs[member_id]) + col_offset; - - index_t* indices = reinterpret_cast(indices_ptrs[member_id]); - const index_t idx = indices[current_row]; -#pragma unroll - for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { - // Compile time conditional - if constexpr (USE_INDEX_SELECT) { - output[current_row * num_cols + i] = - LDG(&input[idx * num_cols + i]); - } else { - gpuAtomicAddNoReturn( - &output[idx * num_cols + i], input[current_row * num_cols + i]); - } + index_t* indices = reinterpret_cast(indices_ptrs[member_id]); + const int64_t* reverse_indices = USE_SORTED_INDICES + ? 
reinterpret_cast(reverse_indices_ptrs[member_id]) + : nullptr; + + int64_t logical_row = 0; + int64_t row = 0; + int64_t col_offset = 0; + bool handled_small_dim_path = false; + + if constexpr (USE_PACKED_ROWS) { + if (use_small_dim_path) { + // Optimized path for small embedding dimensions + // Each warp processes 'rows_per_warp' rows + const int rows_per_warp = rows_per_warp_small; + const int64_t start_row = member_warp_id * rows_per_warp; + // Since we are processing multiple rows within the warp, we need to + // map each lane to a specific row, in addition to the column + const int local_row = (threadIdx.x * UNROLL_FACTOR) / + num_cols; // the row ID within the set of rows handled by this warp + const int64_t current_row = start_row + + local_row; // the actual row within the table processed by this lane + const int col_offset_small = (threadIdx.x * UNROLL_FACTOR) % num_cols; + // local_row may be out of bounds for the last few lanes in the warp if + // [COLS_PER_WARP % num_cols != 0] and we also need to confirm that we are + // within num_work_rows + if (local_row < rows_per_warp && current_row < num_work_rows) { + logical_row = current_row; + row = USE_SORTED_INDICES ? reverse_indices[current_row] : current_row; + col_offset = col_offset_small; + handled_small_dim_path = true; + } else { + flush_cache_accumulator(last_member_output_tile, last_member_num_cols); + continue; } } - } else { - // Large embedding dimensions use >= 1 warp per row - // which is the default codepath for non-ROCm as well -#endif // USE_ROCM - const auto row = member_warp_id / warps_per_row; - const auto col_offset = - ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) + + } + + if (!handled_small_dim_path) { + int64_t row_in_member = 0; + int64_t col_tile = 0; + if constexpr (USE_CONTIGUOUS_WARPS) { + // Contiguous warp traversal: iterate rows sequentially while column tiles + // remain strided so each warp processes a different tile for successive rows. 
+ row_in_member = member_warp_id % num_work_rows; + col_tile = member_warp_id / num_work_rows; + } else { + // Original strided mapping: each warp walks tiles first, distributing rows round-robin. + row_in_member = member_warp_id / warps_per_row; + col_tile = member_warp_id % warps_per_row; + } + + logical_row = row_in_member; + row = USE_SORTED_INDICES ? reverse_indices[row_in_member] : row_in_member; + col_offset = + (static_cast(col_tile) << LOG_COLS_PER_WARP) + (threadIdx.x * UNROLL_FACTOR); - scalar_t* input = - reinterpret_cast(input_ptrs[member_id]) + col_offset; - scalar_t* output = - reinterpret_cast(output_ptrs[member_id]) + col_offset; + } + scalar_t* input = + reinterpret_cast(input_ptrs[member_id]) + col_offset; + scalar_t* output = + reinterpret_cast(output_ptrs[member_id]) + col_offset; - index_t* indices = reinterpret_cast(indices_ptrs[member_id]); - const index_t idx = indices[row]; + const index_t idx = indices[logical_row]; + // TODO: Account for UNROLL_FACTOR #pragma unroll - for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { - // Compile time conditional - if constexpr (USE_INDEX_SELECT) { + for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { + // Compile time conditional + if constexpr (USE_INDEX_SELECT) { + if constexpr (USE_CACHE) { + if (cached_idx != idx) { + storage = LDG(&input[idx * num_cols + i]); + cached_idx = idx; + } + + output[row * num_cols + i] = storage; + } else { output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); + } + } else { + if constexpr (USE_CACHE) { + const bool member_changed = (last_member_id_for_accum != -1 && + member_id != last_member_id_for_accum); + // Probably might be merged into following if-else cascade + if (member_changed) { + flush_cache_accumulator(last_member_output_tile, last_member_num_cols); + } + + const bool is_first_warp = member_changed || (warp_id == start_warp_id); + const bool is_last_warp = (warp_id + warp_stride >= warp_end); + if 
(is_first_warp) { + storage = input[row * num_cols + i]; + cached_idx = idx; + } else if (cached_idx != idx) { + // Flush using the output tile that owns the cached accumulator. + flush_cache_accumulator(last_member_output_tile, last_member_num_cols); + storage = input[row * num_cols + i]; + cached_idx = idx; + } else { + storage += input[row * num_cols + i]; + } + + if (is_last_warp) { + flush_cache_accumulator(output, num_cols); + } + + last_member_output_tile = output; + last_member_num_cols = num_cols; + last_member_id_for_accum = member_id; } else { gpuAtomicAddNoReturn( &output[idx * num_cols + i], input[row * num_cols + i]); } } -#ifdef USE_ROCM } -#endif // USE_ROCM } + + flush_cache_accumulator(last_member_output_tile, last_member_num_cols); } DLL_PUBLIC void group_index_select_or_add_cuda( const int64_t* input_ptrs, const int64_t* output_ptrs, const int64_t* indices_ptrs, + const int64_t* sorted_indices_ptrs, + const int64_t* reverse_indices_ptrs, const int64_t* warp_offsets_group, const int32_t* num_cols_group, const c10::ScalarType& input_scalar_type, @@ -203,12 +311,16 @@ DLL_PUBLIC void group_index_select_or_add_cuda( const int64_t total_num_warps, const int group_size, const bool use_index_select, - const bool use_var_cols) { + const bool use_var_cols, + const bool use_contiguous_warps, + const bool use_cache, + const bool use_packed_rows) { if (group_size == 0) { return; } at::cuda::OptionalCUDAGuard device_guard(device); + const bool use_sorted_indices = (sorted_indices_ptrs && reverse_indices_ptrs); // Partition work based on num_work_rows uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE; @@ -219,49 +331,106 @@ DLL_PUBLIC void group_index_select_or_add_cuda( max_grid_size); dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1); -#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \ - FBGEMM_LAUNCH_KERNEL( \ - (group_index_select_or_add_2d_kernel< \ - index_t, \ - scalar_t, \ - USE_INDEX_SELECT, 
\ - USE_VAR_COLS, \ - GROUP_INDEX_SELECT_UNROLL_FACTOR, \ - GROUP_INDEX_SELECT_COLS_PER_WARP, \ - GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), \ - grid_size, \ - block_size, \ - 0, \ - at::cuda::getCurrentCUDAStream(), \ - input_ptrs, \ - output_ptrs, \ - indices_ptrs, \ - warp_offsets_group, \ - num_cols_group, \ - num_work_rows, \ - group_size) + auto invoke_group_index_select_or_add = [&]() { + FBGEMM_LAUNCH_KERNEL( + (group_index_select_or_add_2d_kernel< + index_t, + scalar_t, + USE_INDEX_SELECT, + USE_VAR_COLS, + USE_CONTIGUOUS_WARPS, + USE_SORTED_INDICES, + USE_CACHE, + USE_PACKED_ROWS, + GROUP_INDEX_SELECT_UNROLL_FACTOR, + GROUP_INDEX_SELECT_COLS_PER_WARP, + GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), + grid_size, + block_size, + 0, + at::cuda::getCurrentCUDAStream(), + input_ptrs, + output_ptrs, + use_sorted_indices ? sorted_indices_ptrs : indices_ptrs, + reverse_indices_ptrs, + warp_offsets_group, + num_cols_group, + num_work_rows, + group_size); + }; + + using bool_variant_t = std::variant; + +// Split is needed to avoid additional code generation for unsupported +// algorithms on CUDA +#ifdef USE_ROCM + using platform_bool_variant_t = std::variant; +#else + using platform_bool_variant_t = std::variant; +#endif + + auto get_bool_type = [](const bool var) -> bool_variant_t { + if (var) { + return std::true_type{}; + } else { + return std::false_type{}; + } + }; + + const bool_variant_t use_index_select_variant = + get_bool_type(use_index_select); + const bool_variant_t use_var_cols_variant = get_bool_type(use_var_cols); +#ifdef USE_ROCM + const platform_bool_variant_t use_contiguous_warps_variant = + get_bool_type(use_contiguous_warps); + const platform_bool_variant_t use_sorted_indices_variant = + get_bool_type(use_sorted_indices); + const platform_bool_variant_t use_cache_variant = get_bool_type(use_cache); + const platform_bool_variant_t use_packed_rows_variant = + get_bool_type(use_packed_rows); +#else + const platform_bool_variant_t 
use_contiguous_warps_variant{std::false_type{}}; + const platform_bool_variant_t use_sorted_indices_variant{std::false_type{}}; + const platform_bool_variant_t use_cache_variant{std::false_type{}}; + const platform_bool_variant_t use_packed_rows_variant{std::false_type{}}; +#endif AT_DISPATCH_INDEX_TYPES( indices_scalar_type, "group_index_select_2d_wrapper_1", [&] { FBGEMM_DISPATCH_FLOATING_TYPES( input_scalar_type, "group_index_select_2d_wrapper_2", [&] { - if (use_index_select) { - if (use_var_cols) { - INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true); - } else { - INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, false); - } - } else { - if (use_var_cols) { - INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, true); - } else { - INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false); - } - } + std::visit( + [&](auto use_index_select_arg, + auto use_var_cols_arg, + auto use_contiguous_warps_arg, + auto use_sorted_indices_arg, + auto use_cache_arg, + auto use_packed_rows_arg) { + invoke_group_index_select_or_add.template operator()< + index_t, + scalar_t, + use_index_select_arg.value, + use_var_cols_arg.value, + use_contiguous_warps_arg.value, + use_sorted_indices_arg.value, + use_cache_arg.value, + use_packed_rows_arg.value>(); + }, + use_index_select_variant, + use_var_cols_variant, + use_contiguous_warps_variant, + use_sorted_indices_variant, + use_cache_variant, + use_packed_rows_variant); }); }); - -#undef INVOKE_GROUP_INDEX_SELECT_OR_ADD } } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index 411da4fcc1..3792940bf5 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -3523,7 +3523,7 @@ torch::autograd::variable_list group_index_select_dim0( .typed(); auto res = forward_op.call( all_indices_input_tensor, static_cast(group_size)); - TORCH_CHECK(res.size() == group_size + 2); + TORCH_CHECK(res.size() == group_size + 4); // only return the outputs (the 
first group_size elements) res.resize(group_size); return res; @@ -3537,14 +3537,16 @@ torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu( group_index_select_dim0_unpack(all_indices_input, group_size); std::vector output_group; - output_group.reserve(group_size + 2); + output_group.reserve(group_size + 4); for (const auto i : c10::irange(group_size)) { output_group.push_back( at::index_select(input_group[i], 0, indices_group[i])); } // to match return format in CUDA implementation - // (group_size outputs, 1 args_tensor, 1 saved_data) + // (group_size outputs, 1 args_tensor, 1 saved_data, 1 sorted tensor, 1 reverse tensor) + output_group.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); + output_group.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); output_group.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); output_group.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); return output_group; @@ -3553,10 +3555,10 @@ torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu( torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu( at::TensorList all_inputs, c10::SymIntArrayRef output_shape_group_ref) { - TORCH_CHECK(all_inputs.size() > 2); + TORCH_CHECK(all_inputs.size() > 4); // all input size = group_size * 2 (from grads, indices) // + 1 args_tensor + 1 saved_data + 1 first output - const int64_t group_size = static_cast((all_inputs.size() - 3) / 2); + const int64_t group_size = static_cast((all_inputs.size() - 5) / 2); auto grad_output_group = std::vector( all_inputs.cbegin(), all_inputs.cbegin() + group_size); @@ -3564,7 +3566,7 @@ torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu( auto indices_group = std::vector( all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size); - const Tensor& fwd_input = all_inputs[2 * group_size + 2]; + const Tensor& fwd_input = all_inputs[2 * group_size + 4]; const int64_t 
output_dim = fwd_input.dim(); std::vector output_shape_group; @@ -3635,7 +3637,7 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward( .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") .typed(); auto result = forward_op.call(all_indices_input, group_size); - TORCH_CHECK(static_cast(result.size()) == group_size + 2); + TORCH_CHECK(static_cast(result.size()) == group_size + 4); auto [input_group, indices_group] = group_index_select_dim0_unpack(all_indices_input, group_size); @@ -3655,7 +3657,7 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward( ctx->saved_data["input_shape_group_" + std::to_string(i)] = input_shape_group[i]; } - // save indices, args_tensor, saved_data + // save indices, args_tensor, saved_data, sorted tensor, reverse tensor auto saved_tensors = std::vector(indices_group); saved_tensors.insert( saved_tensors.end(), result.cbegin() + group_size, result.cend()); @@ -3668,17 +3670,17 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward( torch::autograd::variable_list GroupIndexSelectDim0Op::backward( torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output_group) { - TORCH_CHECK(grad_output_group.size() >= 2); - if (grad_output_group.size() == 2) { + TORCH_CHECK(grad_output_group.size() >= 4); + if (grad_output_group.size() == 4) { // empty outputs return torch::autograd::variable_list(1); } // remove redundant grads - auto group_size = grad_output_group.size() - 2; + auto group_size = grad_output_group.size() - 4; grad_output_group.resize(group_size); const auto saved_tensors = ctx->get_saved_variables(); - TORCH_CHECK(saved_tensors.size() == group_size + 3); + TORCH_CHECK(saved_tensors.size() == group_size + 5); std::vector output_shape_group; int i = 0; while (true) { diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index d85eb3cba7..5db4042e9a 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ 
b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -12,14 +12,19 @@ #include "fbgemm_gpu/sparse_ops.h" #include "fbgemm_gpu/utils/ops_utils.h" #include "fbgemm_gpu/utils/tensor_utils.h" +#ifdef USE_ROCM +#include "fbgemm_gpu/utils/rocm/sparse_group_utils.h" +#endif #include +#include #include #include #include #include #include #include // for logic_error +#include using Tensor = at::Tensor; @@ -27,13 +32,15 @@ namespace fbgemm_gpu { namespace { -constexpr int32_t NUM_ARGS = 5; +constexpr int32_t NUM_ARGS = 7; enum args_pos { P_input_ptrs = 0, P_output_ptrs = 1, P_indices_ptrs = 2, - P_warp_offsets_group_ptrs = 3, - P_num_cols_group_ptrs = 4 + P_sorted_indices_ptrs = 3, + P_reverse_indices_ptrs = 4, + P_warp_offsets_group_ptrs = 5, + P_num_cols_group_ptrs = 6, }; template @@ -47,6 +54,8 @@ void offset_args( int64_t** input_ptrs, int64_t** output_ptrs, int64_t** indices_ptrs, + int64_t** sorted_indices_ptrs, + int64_t** reverse_indices_ptrs, int64_t** warp_offsets_group, int32_t** num_cols_group, int64_t* base_addr, @@ -54,6 +63,8 @@ void offset_args( *input_ptrs = base_addr + ptr_offsets[P_input_ptrs]; *output_ptrs = base_addr + ptr_offsets[P_output_ptrs]; *indices_ptrs = base_addr + ptr_offsets[P_indices_ptrs]; + *sorted_indices_ptrs = base_addr + ptr_offsets[P_sorted_indices_ptrs]; + *reverse_indices_ptrs = base_addr + ptr_offsets[P_reverse_indices_ptrs]; *warp_offsets_group = base_addr + ptr_offsets[P_warp_offsets_group_ptrs]; *num_cols_group = reinterpret_cast( base_addr + ptr_offsets[P_num_cols_group_ptrs]); @@ -213,6 +224,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( // input_ptrs (group_size int64_t elements) // output_ptrs (group_size int64_t elements) // indices_ptrs (group_size int64_t elements) + // sorted_indices_ptrs (group_size int64_t elements) + // reverse_indices_ptrs (group_size int64_t elements) // warp_offsets_group (group_size + 1 int64_t elements) // num_cols_group (group_size int32_t elements) int64_t 
args_ptrs_offsets[NUM_ARGS + 1]; @@ -224,6 +237,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( args_ptrs_offsets[P_input_ptrs] = group_size; args_ptrs_offsets[P_output_ptrs] = group_size; args_ptrs_offsets[P_indices_ptrs] = group_size; + args_ptrs_offsets[P_sorted_indices_ptrs] = group_size; + args_ptrs_offsets[P_reverse_indices_ptrs] = group_size; args_ptrs_offsets[P_warp_offsets_group_ptrs] = group_size + 1; args_ptrs_offsets[P_num_cols_group_ptrs] = numels_num_cols_group_64; @@ -251,6 +266,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( int64_t* input_ptrs = nullptr; int64_t* output_ptrs = nullptr; int64_t* indices_ptrs = nullptr; + int64_t* sorted_indices_ptrs = nullptr; + int64_t* reverse_indices_ptrs = nullptr; int64_t* warp_offsets_group = nullptr; int32_t* num_cols_group = nullptr; @@ -259,6 +276,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( &input_ptrs, &output_ptrs, &indices_ptrs, + &sorted_indices_ptrs, + &reverse_indices_ptrs, &warp_offsets_group, &num_cols_group, reinterpret_cast(args_tensor.mutable_data_ptr()), @@ -278,9 +297,14 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( int64_t warp_offset = 0; bool use_var_cols = false; + Tensor sorted_indices_storage = + at::empty({0}, first_indices.options()); + Tensor reverse_indices_storage = + at::empty({0}, first_indices.options()); + // Allocate memory for output_group std::vector output_group; - output_group.reserve(group_size + 2); + output_group.reserve(group_size + 4); // We need to store contiguous inputs and indices outside the for-loop to // guarantee that the contiguous tensors will outlive the kernel @@ -290,6 +314,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( input_contigs.reserve(group_size); index_contigs.reserve(group_size); + bool use_packed_rows = false; + size_t num_total_indices = 0; // For each group, 
copy input to output for (const auto i : c10::irange(group_size)) { const auto& input = input_group[i]; @@ -335,6 +361,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input, indices); auto num_output_rows_ = indices.size(0); + num_total_indices += num_output_rows_; // Verify that all input tensors have the same shape[0] TORCH_CHECK_VALUE( @@ -358,6 +385,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( // Optimization: Pack multiple rows into one warp int rows_per_warp = cols_per_warp / num_cols_; warps_needed = (num_output_rows_ + rows_per_warp - 1) / rows_per_warp; + use_packed_rows = true; } else { // Standard: One or more warps per row int warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp; @@ -400,6 +428,17 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( warp_offset += warps_per_row * num_output_rows; #endif // USE_ROCM } + +#ifdef USE_ROCM + // The value is selected empirically. Potential + // place for optimization. 
+ constexpr size_t kSortIndicesThreshold = 15'000'000; + const bool use_sorted_indices_for_bwd = + (num_total_indices < kSortIndicesThreshold); +#else + const bool use_sorted_indices_for_bwd = false; + (void)num_total_indices; +#endif // Store the last offset warp_offsets_group[group_size] = warp_offset; @@ -414,6 +453,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( &input_ptrs, &output_ptrs, &indices_ptrs, + &sorted_indices_ptrs, + &reverse_indices_ptrs, &warp_offsets_group, &num_cols_group, reinterpret_cast(args_tensor.mutable_data_ptr()), @@ -422,6 +463,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( int64_t saved_data[] = { static_cast(group_size), use_var_cols, + use_packed_rows, + use_sorted_indices_for_bwd, reinterpret_cast(warp_offsets_group), reinterpret_cast(num_cols_group), warp_offset, @@ -438,6 +481,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( input_ptrs, output_ptrs, indices_ptrs, + /*sorted_indices_ptrs=*/nullptr, + /*reverse_indices_ptrs=*/nullptr, warp_offsets_group, num_cols_group, first_input.scalar_type(), @@ -447,13 +492,18 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( /*total_num_warps=*/warp_offset, group_size, /*use_index_select=*/true, - use_var_cols); + use_var_cols, + /*use_contiguous_warps=*/false, + /*use_cache=*/false, + use_packed_rows); output_group.push_back(args_tensor); output_group.push_back(saved_data_t); + output_group.push_back(sorted_indices_storage); + output_group.push_back(reverse_indices_storage); // return format: - // (group_size outputs, 1 args_tensor, 1 saved_data) + // (group_size outputs, 1 args_tensor, 1 saved_data, 1 sorted tensor, 1 reverse tensor) return output_group; } @@ -461,17 +511,19 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( at::TensorList all_inputs, c10::SymIntArrayRef output_shape_group_ref) { TORCH_CHECK_VALUE( - 
all_inputs.size() > 2, - "all_inputs size must be larger than 2, but got ", + all_inputs.size() > 4, + "all_inputs size must be larger than 4, but got ", all_inputs.size()); // all_input size = group_size * 2 (from grads, indices) // + 1 args_tensor + 1 saved_data + 1 first input - const int64_t group_size = (all_inputs.size() - 3) / 2; + const int64_t group_size = (all_inputs.size() - 5) / 2; - const Tensor& fwd_input = all_inputs[2 * group_size + 2]; + const Tensor& fwd_input = all_inputs[2 * group_size + 4]; const int64_t output_dim = fwd_input.dim(); const Tensor& saved_data = all_inputs[2 * group_size + 1]; + const Tensor& sorted_indices_storage = all_inputs[2 * group_size + 2]; + const Tensor& reverse_indices_storage = all_inputs[2 * group_size + 3]; const Tensor& first_indices = all_inputs[group_size]; auto grad_output_group = std::vector( @@ -499,11 +551,11 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( " but got ", saved_data_ptr[0]); const bool use_var_cols = saved_data_ptr[1]; - const int64_t* warp_offsets_group = - reinterpret_cast(saved_data_ptr[2]); - const int32_t* num_cols_group = - reinterpret_cast(saved_data_ptr[3]); - int64_t total_num_warps = saved_data_ptr[4]; + const bool use_packed_rows = saved_data_ptr[2]; + const bool use_sorted_indices = saved_data_ptr[3]; + int64_t* warp_offsets_group = reinterpret_cast(saved_data_ptr[4]); + int32_t* num_cols_group = reinterpret_cast(saved_data_ptr[5]); + int64_t total_num_warps = saved_data_ptr[6]; // We checked in forward that all output rows are the same for all member // in the group @@ -528,7 +580,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( // Allocate Tensor for ptrs of grad output and input, and indices Tensor args_tensor = at::empty( - {group_size * 3}, + {group_size * 5}, at::TensorOptions().dtype(at::kLong).pinned_memory(true)); // Ensure that args_tensor is contiguous TORCH_CHECK( @@ -538,6 +590,8 @@ static 
torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( args_tensor.mutable_data_ptr() + group_size; int64_t* indices_ptrs = args_tensor.mutable_data_ptr() + 2 * group_size; + int64_t* sorted_indices_ptrs = args_tensor.data_ptr() + 3 * group_size; + int64_t* reverse_indices_ptrs = args_tensor.data_ptr() + 4 * group_size; int64_t group_grad_input_numel = 0; std::vector grad_input_numels; @@ -607,21 +661,53 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( // Calculate indices_ptrs std::vector> index_contigs; + std::vector sorted_indices_contigs; + std::vector reverse_indices_contigs; index_contigs.reserve(group_size); + if (use_sorted_indices) { + sorted_indices_contigs.reserve(group_size); + reverse_indices_contigs.reserve(group_size); + } for (const auto i : c10::irange(group_size)) { const auto& indices = indices_group[i]; index_contigs.push_back(indices.expect_contiguous()); indices_ptrs[i] = reinterpret_cast(index_contigs[i]->const_data_ptr()); + reverse_indices_ptrs[i] = 0; +#ifdef USE_ROCM + if (use_sorted_indices) { + auto [sorted_tensor, reverse_tensor] = + rocm::sort_indices_with_rocprim(*index_contigs[i]); + const auto stream = at::cuda::getCurrentCUDAStream(); + sorted_tensor.record_stream(stream); + reverse_tensor.record_stream(stream); + sorted_indices_contigs.push_back(std::move(sorted_tensor)); + reverse_indices_contigs.push_back(std::move(reverse_tensor)); + sorted_indices_ptrs[i] = reinterpret_cast( + sorted_indices_contigs.back().data_ptr()); + reverse_indices_ptrs[i] = reinterpret_cast( + reverse_indices_contigs.back().data_ptr()); + } +#endif } // Transfer grad output pointers to GPU args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true); +#ifdef USE_ROCM + constexpr bool use_contiguous_warps = true; + constexpr bool use_cache = true; +#else + constexpr bool use_contiguous_warps = false; + constexpr bool use_cache = false; +#endif + group_index_select_or_add_cuda( 
args_tensor.const_data_ptr(), args_tensor.const_data_ptr() + group_size, args_tensor.const_data_ptr() + 2 * group_size, + use_sorted_indices ? args_tensor.data_ptr() + 3 * group_size : nullptr, + use_sorted_indices ? args_tensor.data_ptr() + 4 * group_size : nullptr, warp_offsets_group, num_cols_group, fwd_input.scalar_type(), @@ -631,7 +717,10 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( total_num_warps, group_size, /*use_index_select=*/false, - use_var_cols); + use_var_cols, + use_contiguous_warps, + use_cache, + use_packed_rows); return outputs; } From 37b4078b683945473f41588d047ec9175a173c8f Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 11 Mar 2026 11:37:21 +0000 Subject: [PATCH 02/11] Remove redundant int32_t numel restriction --- .../include/fbgemm_gpu/utils/rocm/sparse_group_utils.h | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h index 5ab1d1a607..bea20c9160 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h @@ -98,11 +98,7 @@ std::tuple sort_indices_with_rocprim(const at::Tensor& i return {sorted_indices, reverse_indices}; } - TORCH_CHECK( - numel <= static_cast(std::numeric_limits::max()), - "sort_indices_with_rocprim only supports up to INT_MAX elements"); - - const int num_items = static_cast(numel); + const auto num_items = static_cast(numel); auto stream = at::cuda::getCurrentCUDAStream(); const auto scalar_type = contiguous_indices.scalar_type(); From 4a004b9eeca23b9989f561509b4a7eaffdcd1966 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 11 Mar 2026 11:47:09 +0000 Subject: [PATCH 03/11] Use AT_DISPATCH_INTEGRAL_TYPES macro for sort dispatch --- .../utils/rocm/sparse_group_utils.h | 112 +++++++----------- 1 file changed, 45 insertions(+), 67 deletions(-) 
diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h index bea20c9160..961f87cd95 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h @@ -102,73 +102,51 @@ std::tuple sort_indices_with_rocprim(const at::Tensor& i auto stream = at::cuda::getCurrentCUDAStream(); const auto scalar_type = contiguous_indices.scalar_type(); - auto dispatch = [&](auto index_value_placeholder) { - using index_t = decltype(index_value_placeholder); - auto keys_in = contiguous_indices.data_ptr(); - auto keys_out = sorted_indices.data_ptr(); - auto values_in = original_positions.data_ptr(); - auto values_out = reverse_indices.data_ptr(); - - size_t temp_storage_bytes = 0; - // Selected empirically - constexpr int k_merge_sort_threshold = 400'000; - - using sort_config = rocprim::radix_sort_config< - rocprim::default_config, - rocprim::default_config, - rocprim::default_config, - k_merge_sort_threshold>; - AT_CUDA_CHECK(rocprim::radix_sort_pairs( - nullptr, - temp_storage_bytes, - keys_in, - keys_out, - values_in, - values_out, - num_items, - 0, - sizeof(index_t) * 8, - stream, - false)); - auto temp_storage = at::empty( - {static_cast(temp_storage_bytes)}, - contiguous_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(rocprim::radix_sort_pairs( - temp_storage.data_ptr(), - temp_storage_bytes, - keys_in, - keys_out, - values_in, - values_out, - num_items, - 0, - sizeof(index_t) * 8, - stream, - false)); - }; - - switch (scalar_type) { - case at::ScalarType::Byte: - dispatch(uint8_t{}); - break; - case at::ScalarType::Char: - dispatch(int8_t{}); - break; - case at::ScalarType::Short: - dispatch(int16_t{}); - break; - case at::ScalarType::Int: - dispatch(int32_t{}); - break; - case at::ScalarType::Long: - dispatch(int64_t{}); - break; - default: - TORCH_CHECK( - false, - "sort_indices_with_rocprim only supports 
integral index dtypes, got ", - scalar_type); - } + AT_DISPATCH_INTEGRAL_TYPES( + scalar_type, "sort_indices_with_rocprim", [&] { + using index_t = scalar_t; + auto keys_in = contiguous_indices.data_ptr(); + auto keys_out = sorted_indices.data_ptr(); + auto values_in = original_positions.data_ptr(); + auto values_out = reverse_indices.data_ptr(); + + size_t temp_storage_bytes = 0; + // Selected empirically + constexpr int k_merge_sort_threshold = 400'000; + + using sort_config = rocprim::radix_sort_config< + rocprim::default_config, + rocprim::default_config, + rocprim::default_config, + k_merge_sort_threshold>; + AT_CUDA_CHECK(rocprim::radix_sort_pairs( + nullptr, + temp_storage_bytes, + keys_in, + keys_out, + values_in, + values_out, + num_items, + 0, + sizeof(index_t) * 8, + stream, + false)); + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, + contiguous_indices.options().dtype(at::kByte)); + AT_CUDA_CHECK(rocprim::radix_sort_pairs( + temp_storage.data_ptr(), + temp_storage_bytes, + keys_in, + keys_out, + values_in, + values_out, + num_items, + 0, + sizeof(index_t) * 8, + stream, + false)); + }); return {sorted_indices, reverse_indices}; } From 48ffb6c57d4fd6682a762a282cd9841d2293adda Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 11 Mar 2026 12:11:41 +0000 Subject: [PATCH 04/11] Revert returning reverse_indices from forward, save group_size through AutogradContext --- fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp | 27 ++++++++++---------- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 21 +++++---------- 2 files changed, 19 insertions(+), 29 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index 3792940bf5..79f9ba04ee 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -3523,7 +3523,7 @@ torch::autograd::variable_list group_index_select_dim0( .typed(); auto res = forward_op.call( all_indices_input_tensor, 
static_cast(group_size)); - TORCH_CHECK(res.size() == group_size + 4); + TORCH_CHECK(res.size() >= group_size + 2); // only return the outputs (the first group_size elements) res.resize(group_size); return res; @@ -3537,16 +3537,14 @@ torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu( group_index_select_dim0_unpack(all_indices_input, group_size); std::vector output_group; - output_group.reserve(group_size + 4); + output_group.reserve(group_size + 2); for (const auto i : c10::irange(group_size)) { output_group.push_back( at::index_select(input_group[i], 0, indices_group[i])); } // to match return format in CUDA implementation - // (group_size outputs, 1 args_tensor, 1 saved_data, 1 sorted tensor, 1 reverse tensor) - output_group.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); - output_group.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); + // (group_size outputs, 1 args_tensor, 1 saved_data) output_group.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); output_group.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); return output_group; @@ -3555,10 +3553,10 @@ torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu( torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu( at::TensorList all_inputs, c10::SymIntArrayRef output_shape_group_ref) { - TORCH_CHECK(all_inputs.size() > 4); + TORCH_CHECK(all_inputs.size() > 2); // all input size = group_size * 2 (from grads, indices) // + 1 args_tensor + 1 saved_data + 1 first output - const int64_t group_size = static_cast((all_inputs.size() - 5) / 2); + const int64_t group_size = static_cast((all_inputs.size() - 3) / 2); auto grad_output_group = std::vector( all_inputs.cbegin(), all_inputs.cbegin() + group_size); @@ -3566,7 +3564,7 @@ torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu( auto indices_group = std::vector( all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * 
group_size); - const Tensor& fwd_input = all_inputs[2 * group_size + 4]; + const Tensor& fwd_input = all_inputs[2 * group_size + 2]; const int64_t output_dim = fwd_input.dim(); std::vector output_shape_group; @@ -3637,7 +3635,8 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward( .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") .typed(); auto result = forward_op.call(all_indices_input, group_size); - TORCH_CHECK(static_cast(result.size()) == group_size + 4); + TORCH_CHECK(static_cast(result.size()) >= group_size + 2); + ctx->saved_data["group_size"] = group_size; auto [input_group, indices_group] = group_index_select_dim0_unpack(all_indices_input, group_size); @@ -3657,7 +3656,7 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward( ctx->saved_data["input_shape_group_" + std::to_string(i)] = input_shape_group[i]; } - // save indices, args_tensor, saved_data, sorted tensor, reverse tensor + // save indices, args_tensor, saved_data auto saved_tensors = std::vector(indices_group); saved_tensors.insert( saved_tensors.end(), result.cbegin() + group_size, result.cend()); @@ -3670,17 +3669,17 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward( torch::autograd::variable_list GroupIndexSelectDim0Op::backward( torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output_group) { - TORCH_CHECK(grad_output_group.size() >= 4); - if (grad_output_group.size() == 4) { + TORCH_CHECK(grad_output_group.size() >= 2); + if (grad_output_group.size() == 2) { // empty outputs return torch::autograd::variable_list(1); } // remove redundant grads - auto group_size = grad_output_group.size() - 4; + const auto group_size = ctx->saved_data["group_size"].toInt(); grad_output_group.resize(group_size); const auto saved_tensors = ctx->get_saved_variables(); - TORCH_CHECK(saved_tensors.size() == group_size + 5); + TORCH_CHECK(saved_tensors.size() == group_size + 3); std::vector output_shape_group; int i = 0; while 
(true) { diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index 5db4042e9a..8708044055 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -297,14 +297,9 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( int64_t warp_offset = 0; bool use_var_cols = false; - Tensor sorted_indices_storage = - at::empty({0}, first_indices.options()); - Tensor reverse_indices_storage = - at::empty({0}, first_indices.options()); - // Allocate memory for output_group std::vector output_group; - output_group.reserve(group_size + 4); + output_group.reserve(group_size + 2); // We need to store contiguous inputs and indices outside the for-loop to // guarantee that the contiguous tensors will outlive the kernel @@ -499,11 +494,9 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( output_group.push_back(args_tensor); output_group.push_back(saved_data_t); - output_group.push_back(sorted_indices_storage); - output_group.push_back(reverse_indices_storage); // return format: - // (group_size outputs, 1 args_tensor, 1 saved_data, 1 sorted tensor, 1 reverse tensor) + // (group_size outputs, 1 args_tensor, 1 saved_data) return output_group; } @@ -511,19 +504,17 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( at::TensorList all_inputs, c10::SymIntArrayRef output_shape_group_ref) { TORCH_CHECK_VALUE( - all_inputs.size() > 4, - "all_inputs size must be larger than 4, but got ", + all_inputs.size() > 2, + "all_inputs size must be larger than 2, but got ", all_inputs.size()); // all_input size = group_size * 2 (from grads, indices) // + 1 args_tensor + 1 saved_data + 1 first input - const int64_t group_size = (all_inputs.size() - 5) / 2; + const int64_t group_size = (all_inputs.size() - 3) / 2; - const Tensor& fwd_input = all_inputs[2 * group_size + 4]; + const Tensor& fwd_input = 
all_inputs[2 * group_size + 2]; const int64_t output_dim = fwd_input.dim(); const Tensor& saved_data = all_inputs[2 * group_size + 1]; - const Tensor& sorted_indices_storage = all_inputs[2 * group_size + 2]; - const Tensor& reverse_indices_storage = all_inputs[2 * group_size + 3]; const Tensor& first_indices = all_inputs[group_size]; auto grad_output_group = std::vector( From c468070ccb6b026293430a13908d7b83fa586675 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 11 Mar 2026 12:12:35 +0000 Subject: [PATCH 05/11] Add const qualifier to packed helper buffers --- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index 8708044055..1f0ffd4453 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -544,8 +544,10 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( const bool use_var_cols = saved_data_ptr[1]; const bool use_packed_rows = saved_data_ptr[2]; const bool use_sorted_indices = saved_data_ptr[3]; - int64_t* warp_offsets_group = reinterpret_cast(saved_data_ptr[4]); - int32_t* num_cols_group = reinterpret_cast(saved_data_ptr[5]); + const int64_t* warp_offsets_group = + reinterpret_cast(saved_data_ptr[4]); + const int32_t* num_cols_group = + reinterpret_cast(saved_data_ptr[5]); int64_t total_num_warps = saved_data_ptr[6]; // We checked in forward that all output rows are the same for all member From 40a3593a2f3679e385943b075d724d0916283cf8 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 11 Mar 2026 12:58:54 +0000 Subject: [PATCH 06/11] Revert to stronger result size check --- fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index 
79f9ba04ee..da3a717025 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -3523,7 +3523,7 @@ torch::autograd::variable_list group_index_select_dim0( .typed(); auto res = forward_op.call( all_indices_input_tensor, static_cast(group_size)); - TORCH_CHECK(res.size() >= group_size + 2); + TORCH_CHECK(res.size() == group_size + 2); // only return the outputs (the first group_size elements) res.resize(group_size); return res; @@ -3635,7 +3635,7 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward( .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") .typed(); auto result = forward_op.call(all_indices_input, group_size); - TORCH_CHECK(static_cast(result.size()) >= group_size + 2); + TORCH_CHECK(static_cast(result.size()) == group_size + 2); ctx->saved_data["group_size"] = group_size; auto [input_group, indices_group] = From 75a2c5b3038cb0e1600737b9fdf1b6307da44848 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 17 Mar 2026 10:19:37 +0000 Subject: [PATCH 07/11] Fix missing sorted_indices_ptrs initialization --- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index 1f0ffd4453..d219f21ce5 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -666,6 +666,7 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( index_contigs.push_back(indices.expect_contiguous()); indices_ptrs[i] = reinterpret_cast(index_contigs[i]->const_data_ptr()); + sorted_indices_ptrs[i] = 0; reverse_indices_ptrs[i] = 0; #ifdef USE_ROCM if (use_sorted_indices) { From 6496d73cd25646fcd6f95b723c4c41bd41b1309f Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Fri, 20 Mar 2026 11:30:24 +0000 Subject: [PATCH 08/11] Split sort_indices_with_rocprim definition and implementation --- 
fbgemm_gpu/FbgemmGpu.cmake | 1 + .../utils/rocm/sparse_group_utils.h | 79 +--------------- .../utils/rocm/sparse_group_utils.cu | 91 +++++++++++++++++++ 3 files changed, 95 insertions(+), 76 deletions(-) create mode 100644 fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake index 73a17572ef..da84ce3edc 100644 --- a/fbgemm_gpu/FbgemmGpu.cmake +++ b/fbgemm_gpu/FbgemmGpu.cmake @@ -110,6 +110,7 @@ if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_CPU) src/quantize_ops/quantize_msfp.cu src/quantize_ops/quantize_padded_fp8_rowwise.cu src/quantize_ops/quantize_mx.cu + src/sparse_ops/utils/rocm/sparse_group_utils.cu src/sparse_ops/sparse_async_batched_cumsum.cu src/sparse_ops/sparse_block_bucketize_features.cu src/sparse_ops/sparse_bucketize_features.cu diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h index 961f87cd95..57e791aa2d 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h @@ -18,6 +18,7 @@ #include #include "fbgemm_gpu/utils/cuda_prelude.cuh" +#include "fbgemm_gpu/utils/function_types.h" namespace fbgemm_gpu::rocm { namespace { @@ -73,82 +74,8 @@ __device__ __forceinline__ void warp_upper_bound( *found = result; *cached_boundary = cached_result; } +} // namespace -std::tuple sort_indices_with_rocprim(const at::Tensor& indices) { - TORCH_CHECK( - indices.dim() == 1, - "sort_indices_with_rocprim expects a 1D tensor, got ", - indices.dim()); - TORCH_CHECK( - indices.is_cuda(), - "sort_indices_with_rocprim expects a CUDA tensor for indices"); - - CUDA_DEVICE_GUARD(indices); - auto contiguous_indices = indices.contiguous(); - auto sorted_indices = at::empty_like(contiguous_indices); - auto reverse_indices = at::empty( - contiguous_indices.sizes(), - contiguous_indices.options().dtype(at::kLong)); - auto 
original_positions = at::arange( - contiguous_indices.numel(), - contiguous_indices.options().dtype(at::kLong)); - - const auto numel = contiguous_indices.numel(); - if (numel == 0) { - return {sorted_indices, reverse_indices}; - } - - const auto num_items = static_cast(numel); - auto stream = at::cuda::getCurrentCUDAStream(); - - const auto scalar_type = contiguous_indices.scalar_type(); - AT_DISPATCH_INTEGRAL_TYPES( - scalar_type, "sort_indices_with_rocprim", [&] { - using index_t = scalar_t; - auto keys_in = contiguous_indices.data_ptr(); - auto keys_out = sorted_indices.data_ptr(); - auto values_in = original_positions.data_ptr(); - auto values_out = reverse_indices.data_ptr(); - - size_t temp_storage_bytes = 0; - // Selected empirically - constexpr int k_merge_sort_threshold = 400'000; - - using sort_config = rocprim::radix_sort_config< - rocprim::default_config, - rocprim::default_config, - rocprim::default_config, - k_merge_sort_threshold>; - AT_CUDA_CHECK(rocprim::radix_sort_pairs( - nullptr, - temp_storage_bytes, - keys_in, - keys_out, - values_in, - values_out, - num_items, - 0, - sizeof(index_t) * 8, - stream, - false)); - auto temp_storage = at::empty( - {static_cast(temp_storage_bytes)}, - contiguous_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(rocprim::radix_sort_pairs( - temp_storage.data_ptr(), - temp_storage_bytes, - keys_in, - keys_out, - values_in, - values_out, - num_items, - 0, - sizeof(index_t) * 8, - stream, - false)); - }); +std::tuple sort_indices_with_rocprim(const at::Tensor& indices); - return {sorted_indices, reverse_indices}; -} -} // namespace } // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu b/fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu new file mode 100644 index 0000000000..c0754b46ed --- /dev/null +++ b/fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifdef USE_ROCM +#include "fbgemm_gpu/utils/rocm/sparse_group_utils.h" + +namespace fbgemm_gpu::rocm { +DLL_PUBLIC std::tuple sort_indices_with_rocprim(const at::Tensor& indices) { + TORCH_CHECK( + indices.dim() == 1, + "sort_indices_with_rocprim expects a 1D tensor, got ", + indices.dim()); + TORCH_CHECK( + indices.is_cuda(), + "sort_indices_with_rocprim expects a CUDA tensor for indices"); + + CUDA_DEVICE_GUARD(indices); + auto contiguous_indices = indices.contiguous(); + auto sorted_indices = at::empty_like(contiguous_indices); + auto reverse_indices = at::empty( + contiguous_indices.sizes(), + contiguous_indices.options().dtype(at::kLong)); + auto original_positions = at::arange( + contiguous_indices.numel(), + contiguous_indices.options().dtype(at::kLong)); + + const auto numel = contiguous_indices.numel(); + if (numel == 0) { + return {sorted_indices, reverse_indices}; + } + + const auto num_items = static_cast(numel); + auto stream = at::cuda::getCurrentCUDAStream(); + + const auto scalar_type = contiguous_indices.scalar_type(); + AT_DISPATCH_INTEGRAL_TYPES( + scalar_type, "sort_indices_with_rocprim", [&] { + using index_t = scalar_t; + auto keys_in = contiguous_indices.data_ptr(); + auto keys_out = sorted_indices.data_ptr(); + auto values_in = original_positions.data_ptr(); + auto values_out = reverse_indices.data_ptr(); + + size_t temp_storage_bytes = 0; + // Selected empirically + constexpr int k_merge_sort_threshold = 400'000; + + using sort_config = rocprim::radix_sort_config< + rocprim::default_config, + rocprim::default_config, + rocprim::default_config, + k_merge_sort_threshold>; + AT_CUDA_CHECK(rocprim::radix_sort_pairs( + nullptr, + temp_storage_bytes, + keys_in, + keys_out, + values_in, + values_out, + num_items, + 0, + sizeof(index_t) * 8, + stream, + false)); + auto 
temp_storage = at::empty(
+ {static_cast(temp_storage_bytes)},
+ contiguous_indices.options().dtype(at::kByte));
+ AT_CUDA_CHECK(rocprim::radix_sort_pairs(
+ temp_storage.data_ptr(),
+ temp_storage_bytes,
+ keys_in,
+ keys_out,
+ values_in,
+ values_out,
+ num_items,
+ 0,
+ sizeof(index_t) * 8,
+ stream,
+ false));
+ });
+
+ return {sorted_indices, reverse_indices};
+}
+} // namespace fbgemm_gpu::rocm
+
+#endif

From f5f9e70b4cb17fd64f81cfa9d167f149b081b851 Mon Sep 17 00:00:00 2001
From: Andrey Bokovoy
Date: Wed, 1 Apr 2026 13:12:08 +0000
Subject: [PATCH 09/11] Implement segmented_sort and reduce CPU overhead

---
 .../utils/rocm/sparse_group_utils.h | 56 +++++-
 fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 119 +++++++++---
 .../utils/rocm/sparse_group_utils.cu | 175 ++++++++++++------
 3 files changed, 263 insertions(+), 87 deletions(-)

diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h
index 57e791aa2d..6cb740911e 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h
@@ -16,11 +16,21 @@
 #include
 #include
+#include

 #include "fbgemm_gpu/utils/cuda_prelude.cuh"
 #include "fbgemm_gpu/utils/function_types.h"

 namespace fbgemm_gpu::rocm {
+// Selected empirically: rocprim uses merge sort when num_items < this threshold,
+// which is faster for small inputs. Must match across sizing and sort calls.
+constexpr unsigned int k_sort_merge_threshold = 400'000;
+using sort_config = rocprim::radix_sort_config<
+ rocprim::default_config,
+ rocprim::default_config,
+ rocprim::default_config,
+ k_sort_merge_threshold>;
+
 namespace {

 template
@@ -76,6 +86,48 @@ __device__ __forceinline__ void warp_upper_bound(
 }
 } // namespace

-std::tuple sort_indices_with_rocprim(const at::Tensor& indices);
-
+// Returns temp storage size for a single-segment sort of num_items elements. 
+size_t get_sort_temp_storage_bytes( + const size_t num_items, + const c10::ScalarType scalar_type, + const at::cuda::CUDAStream& stream); +// Returns temp storage size for segmented sort of num_groups segments each +// with num_items_per_segment elements. +size_t get_segmented_sort_temp_storage_bytes( + const size_t num_items_per_segment, + const int64_t num_groups, + const c10::ScalarType scalar_type, + const at::cuda::CUDAStream& stream); +// Sort all groups' indices with one rocprim::segmented_radix_sort_pairs call, +// eliminating all per-group CPU launch overhead. +// +// Inputs must be contiguous across groups: +// all_keys_in : [num_groups * num_items_per_segment] — packed input indices +// all_values_in : [num_groups * num_items_per_segment] — tiled 0..N-1 per segment +// segment_offsets: [num_groups + 1] device tensor — [0, N, 2N, ..., K*N] +// all_keys_out / all_values_out: pre-allocated output buffers (same shape) +// temp_storage : pre-allocated via get_segmented_sort_temp_storage_bytes() +void sort_indices_segmented_rocprim( + const at::Tensor& all_keys_in, + at::Tensor& all_keys_out, + const at::Tensor& all_values_in, + at::Tensor& all_values_out, + const at::Tensor& segment_offsets, + const size_t num_items_per_segment, + const int64_t num_groups, + at::Tensor& temp_storage, + const at::cuda::CUDAStream& stream); +// Sort all groups in a batch with one AT_DISPATCH and one stream lookup. +// Uses radix_sort_pairs per group, preserving the merge sort +// fallback for small segment sizes (num_items < k_sort_merge_threshold). 
+void sort_indices_batch_rocprim( + const int64_t* keys_in_ptrs, + void* keys_out_base, + int64_t* values_out_base, + const int64_t* values_in, + const size_t num_items, + const int64_t num_groups, + at::Tensor& temp_storage, + const c10::ScalarType scalar_type, + const at::cuda::CUDAStream& stream); } // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index d219f21ce5..e31d5a3e01 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -562,13 +562,16 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( outputs.reserve(group_size * 2 + 1); // 1) Add group_size Variable()'s for indices - // c10::irange cannot be used in here as it - // triggers a build error of i being an unused variable. // Add empty tensor with zero size here to make __torch_dispatch__ work for // the backward op. Those empty tensors will be replaced with // torch::autograd::Variable() outside of the op call. - for (auto i = 0; i < group_size; i++) { - outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); + // Reuse a single placeholder tensor to avoid N separate allocations. 
+ { + const auto placeholder = + at::empty({0}, at::TensorOptions().dtype(at::kLong)); + for (auto i = 0; i < group_size; i++) { + outputs.push_back(placeholder); + } } // Allocate Tensor for ptrs of grad output and input, and indices @@ -634,10 +637,8 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( // Reshape grad inputs and obtain their pointers for (int i = 0; i < group_size; i++) { - const auto grad_input_shape = std::vector( - output_shape_group.begin() + i * output_dim, - output_shape_group.begin() + (i + 1) * output_dim); - output_group[i] = output_group[i].reshape(grad_input_shape); + output_group[i] = output_group[i].reshape(c10::IntArrayRef( + output_shape_group.data() + i * output_dim, output_dim)); TORCH_CHECK( output_group[i].is_contiguous(), "Tensor output_group ", @@ -654,13 +655,43 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( // Calculate indices_ptrs std::vector> index_contigs; - std::vector sorted_indices_contigs; - std::vector reverse_indices_contigs; index_contigs.reserve(group_size); +#ifdef USE_ROCM + // Pre-allocate sort scratch and output buffers once for all groups. + // All groups share the same index dtype and size (enforced by forward checks). + at::Tensor sort_temp_storage; + at::Tensor sort_original_positions; + at::Tensor all_sorted_indices; + at::Tensor all_reverse_indices; + int64_t sort_num_items = static_cast(indices_group[0].numel()); + const bool use_segmented_sort = + static_cast(sort_num_items) <= rocm::k_sort_merge_threshold; if (use_sorted_indices) { - sorted_indices_contigs.reserve(group_size); - reverse_indices_contigs.reserve(group_size); + const auto stream = at::cuda::getCurrentCUDAStream(); + const size_t temp_bytes = use_segmented_sort + ? 
rocm::get_segmented_sort_temp_storage_bytes( + static_cast(sort_num_items), group_size, + first_indices.scalar_type(), stream) + : rocm::get_sort_temp_storage_bytes( + static_cast(sort_num_items), + first_indices.scalar_type(), stream); + sort_temp_storage = at::empty( + {static_cast(temp_bytes)}, + first_indices.options().dtype(at::kByte)); + // original_positions is always 0..N-1 and read-only, shared across groups. + sort_original_positions = at::arange( + sort_num_items, first_indices.options().dtype(at::kLong)); + // Single contiguous allocation for all groups' sort outputs. + // record_stream called twice total instead of 2*group_size. + all_sorted_indices = at::empty( + {group_size * sort_num_items}, first_indices.options()); + all_reverse_indices = at::empty( + {group_size * sort_num_items}, + first_indices.options().dtype(at::kLong)); + all_sorted_indices.record_stream(stream); + all_reverse_indices.record_stream(stream); } +#endif for (const auto i : c10::irange(group_size)) { const auto& indices = indices_group[i]; index_contigs.push_back(indices.expect_contiguous()); @@ -668,22 +699,60 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( reinterpret_cast(index_contigs[i]->const_data_ptr()); sorted_indices_ptrs[i] = 0; reverse_indices_ptrs[i] = 0; + } #ifdef USE_ROCM - if (use_sorted_indices) { - auto [sorted_tensor, reverse_tensor] = - rocm::sort_indices_with_rocprim(*index_contigs[i]); - const auto stream = at::cuda::getCurrentCUDAStream(); - sorted_tensor.record_stream(stream); - reverse_tensor.record_stream(stream); - sorted_indices_contigs.push_back(std::move(sorted_tensor)); - reverse_indices_contigs.push_back(std::move(reverse_tensor)); - sorted_indices_ptrs[i] = reinterpret_cast( - sorted_indices_contigs.back().data_ptr()); - reverse_indices_ptrs[i] = reinterpret_cast( - reverse_indices_contigs.back().data_ptr()); + if (use_sorted_indices) { + // Fill sorted/reverse ptr tables via direct pointer arithmetic. 
+ const int64_t idx_elem_bytes = first_indices.element_size(); + auto* sorted_base = static_cast(all_sorted_indices.data_ptr()); + auto* reverse_base = static_cast(all_reverse_indices.data_ptr()); + for (int64_t i = 0; i < group_size; ++i) { + sorted_indices_ptrs[i] = + reinterpret_cast(sorted_base + i * sort_num_items * idx_elem_bytes); + reverse_indices_ptrs[i] = + reinterpret_cast(reverse_base + i * sort_num_items * sizeof(int64_t)); + } + + const auto stream = at::cuda::getCurrentCUDAStream(); + if (use_segmented_sort) { + std::vector index_tensors; + index_tensors.reserve(group_size); + for (const auto& ic : index_contigs) { + index_tensors.push_back(*ic); + } + const auto all_keys_in = at::cat(index_tensors, 0); + const auto all_values_in = + sort_original_positions.unsqueeze(0) + .expand({group_size, sort_num_items}) + .contiguous() + .view({-1}); + const auto segment_offsets = + at::arange(group_size + 1, first_indices.options().dtype(at::kLong)) * + sort_num_items; + rocm::sort_indices_segmented_rocprim( + all_keys_in, + all_sorted_indices, + all_values_in, + all_reverse_indices, + segment_offsets, + static_cast(sort_num_items), + group_size, + sort_temp_storage, + stream); + } else { + rocm::sort_indices_batch_rocprim( + indices_ptrs, + all_sorted_indices.data_ptr(), + all_reverse_indices.data_ptr(), + sort_original_positions.data_ptr(), + static_cast(sort_num_items), + group_size, + sort_temp_storage, + first_indices.scalar_type(), + stream); } -#endif } +#endif // Transfer grad output pointers to GPU args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true); diff --git a/fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu b/fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu index c0754b46ed..f1a962de94 100644 --- a/fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu +++ b/fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu @@ -10,81 +10,136 @@ #include "fbgemm_gpu/utils/rocm/sparse_group_utils.h" namespace 
fbgemm_gpu::rocm { -DLL_PUBLIC std::tuple sort_indices_with_rocprim(const at::Tensor& indices) { - TORCH_CHECK( - indices.dim() == 1, - "sort_indices_with_rocprim expects a 1D tensor, got ", - indices.dim()); - TORCH_CHECK( - indices.is_cuda(), - "sort_indices_with_rocprim expects a CUDA tensor for indices"); +DLL_PUBLIC size_t get_sort_temp_storage_bytes( + const size_t num_items, + const c10::ScalarType scalar_type, + const at::cuda::CUDAStream& stream) { + size_t temp_storage_bytes = 0; + AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "get_sort_temp_storage_bytes", [&] { + using index_t = scalar_t; + AT_CUDA_CHECK(rocprim::radix_sort_pairs( + nullptr, + temp_storage_bytes, + static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + num_items, + 0, + sizeof(index_t) * 8, + stream, + false)); + }); + return temp_storage_bytes; +} - CUDA_DEVICE_GUARD(indices); - auto contiguous_indices = indices.contiguous(); - auto sorted_indices = at::empty_like(contiguous_indices); - auto reverse_indices = at::empty( - contiguous_indices.sizes(), - contiguous_indices.options().dtype(at::kLong)); - auto original_positions = at::arange( - contiguous_indices.numel(), - contiguous_indices.options().dtype(at::kLong)); +DLL_PUBLIC size_t get_segmented_sort_temp_storage_bytes( + const size_t num_items_per_segment, + const int64_t num_groups, + const c10::ScalarType scalar_type, + const at::cuda::CUDAStream& stream) { + size_t temp_storage_bytes = 0; + const size_t total_items = num_items_per_segment * static_cast(num_groups); + AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "get_segmented_sort_temp_storage_bytes", [&] { + using index_t = scalar_t; + // segmented_radix_sort_pairs requires segmented_radix_sort_config, not + // radix_sort_config — use default config (radix sort, no merge fallback). 
+ AT_CUDA_CHECK(rocprim::segmented_radix_sort_pairs( + nullptr, + temp_storage_bytes, + static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + total_items, + static_cast(num_groups), + static_cast(nullptr), + static_cast(nullptr), + 0, + sizeof(index_t) * 8, + stream, + false)); + }); + return temp_storage_bytes; +} - const auto numel = contiguous_indices.numel(); - if (numel == 0) { - return {sorted_indices, reverse_indices}; +DLL_PUBLIC void sort_indices_segmented_rocprim( + const at::Tensor& all_keys_in, + at::Tensor& all_keys_out, + const at::Tensor& all_values_in, + at::Tensor& all_values_out, + const at::Tensor& segment_offsets, + const size_t num_items_per_segment, + const int64_t num_groups, + at::Tensor& temp_storage, + const at::cuda::CUDAStream& stream) { + if (num_items_per_segment == 0 || num_groups == 0) { + return; } - const auto num_items = static_cast(numel); - auto stream = at::cuda::getCurrentCUDAStream(); - - const auto scalar_type = contiguous_indices.scalar_type(); - AT_DISPATCH_INTEGRAL_TYPES( - scalar_type, "sort_indices_with_rocprim", [&] { - using index_t = scalar_t; - auto keys_in = contiguous_indices.data_ptr(); - auto keys_out = sorted_indices.data_ptr(); - auto values_in = original_positions.data_ptr(); - auto values_out = reverse_indices.data_ptr(); + size_t temp_storage_bytes = static_cast(temp_storage.numel()); + const size_t total_items = num_items_per_segment * static_cast(num_groups); + // segment_offsets is [0, N, 2N, ..., K*N]: begin[i] = ptr[i], end[i] = ptr[i+1] + const auto* begin_offsets = segment_offsets.const_data_ptr(); + const auto* end_offsets = begin_offsets + 1; - size_t temp_storage_bytes = 0; - // Selected empirically - constexpr int k_merge_sort_threshold = 400'000; + AT_DISPATCH_INTEGRAL_TYPES(all_keys_in.scalar_type(), "sort_indices_segmented_rocprim", [&] { + using index_t = scalar_t; + // segmented_radix_sort_pairs requires segmented_radix_sort_config — + // 
radix_sort_config is not accepted here, so default config is used. + // Only call this path when num_items_per_segment >= k_sort_merge_threshold + // so there is no regression vs the per-group merge sort path. + AT_CUDA_CHECK(rocprim::segmented_radix_sort_pairs( + temp_storage.data_ptr(), + temp_storage_bytes, + all_keys_in.const_data_ptr(), + all_keys_out.data_ptr(), + all_values_in.const_data_ptr(), + all_values_out.data_ptr(), + total_items, + static_cast(num_groups), + begin_offsets, + end_offsets, + 0, + sizeof(index_t) * 8, + stream, + false)); + }); +} - using sort_config = rocprim::radix_sort_config< - rocprim::default_config, - rocprim::default_config, - rocprim::default_config, - k_merge_sort_threshold>; - AT_CUDA_CHECK(rocprim::radix_sort_pairs( - nullptr, - temp_storage_bytes, - keys_in, - keys_out, - values_in, - values_out, - num_items, - 0, - sizeof(index_t) * 8, - stream, - false)); - auto temp_storage = at::empty( - {static_cast(temp_storage_bytes)}, - contiguous_indices.options().dtype(at::kByte)); +DLL_PUBLIC void sort_indices_batch_rocprim( + const int64_t* keys_in_ptrs, + void* keys_out_base, + int64_t* values_out_base, + const int64_t* values_in, + const size_t num_items, + const int64_t num_groups, + at::Tensor& temp_storage, + const c10::ScalarType scalar_type, + const at::cuda::CUDAStream& stream) { + if (num_items == 0 || num_groups == 0) { + return; + } + size_t temp_storage_bytes = static_cast(temp_storage.numel()); + void* temp_ptr = temp_storage.data_ptr(); + AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "sort_indices_batch_rocprim", [&] { + using index_t = scalar_t; + auto* keys_out = static_cast(keys_out_base); + for (int64_t i = 0; i < num_groups; ++i) { AT_CUDA_CHECK(rocprim::radix_sort_pairs( - temp_storage.data_ptr(), + temp_ptr, temp_storage_bytes, - keys_in, - keys_out, + reinterpret_cast(keys_in_ptrs[i]), + keys_out + i * num_items, values_in, - values_out, + values_out_base + i * num_items, num_items, 0, sizeof(index_t) * 8, 
stream, false)); + } }); - - return {sorted_indices, reverse_indices}; } } // namespace fbgemm::rocm From dc9abe4467963d8419360eba1f6f67d2facbb9d3 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 8 Apr 2026 11:45:49 +0000 Subject: [PATCH 10/11] Switch to AT_DISPATCH_INDEX_TYPES macro --- .../src/sparse_ops/utils/rocm/sparse_group_utils.cu | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu b/fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu index f1a962de94..01a6030b15 100644 --- a/fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu +++ b/fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu @@ -15,8 +15,7 @@ DLL_PUBLIC size_t get_sort_temp_storage_bytes( const c10::ScalarType scalar_type, const at::cuda::CUDAStream& stream) { size_t temp_storage_bytes = 0; - AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "get_sort_temp_storage_bytes", [&] { - using index_t = scalar_t; + AT_DISPATCH_INDEX_TYPES(scalar_type, "get_sort_temp_storage_bytes", [&] { AT_CUDA_CHECK(rocprim::radix_sort_pairs( nullptr, temp_storage_bytes, @@ -40,8 +39,7 @@ DLL_PUBLIC size_t get_segmented_sort_temp_storage_bytes( const at::cuda::CUDAStream& stream) { size_t temp_storage_bytes = 0; const size_t total_items = num_items_per_segment * static_cast(num_groups); - AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "get_segmented_sort_temp_storage_bytes", [&] { - using index_t = scalar_t; + AT_DISPATCH_INDEX_TYPES(scalar_type, "get_segmented_sort_temp_storage_bytes", [&] { // segmented_radix_sort_pairs requires segmented_radix_sort_config, not // radix_sort_config — use default config (radix sort, no merge fallback). 
AT_CUDA_CHECK(rocprim::segmented_radix_sort_pairs( @@ -83,8 +81,7 @@ DLL_PUBLIC void sort_indices_segmented_rocprim( const auto* begin_offsets = segment_offsets.const_data_ptr(); const auto* end_offsets = begin_offsets + 1; - AT_DISPATCH_INTEGRAL_TYPES(all_keys_in.scalar_type(), "sort_indices_segmented_rocprim", [&] { - using index_t = scalar_t; + AT_DISPATCH_INDEX_TYPES(all_keys_in.scalar_type(), "sort_indices_segmented_rocprim", [&] { // segmented_radix_sort_pairs requires segmented_radix_sort_config — // radix_sort_config is not accepted here, so default config is used. // Only call this path when num_items_per_segment >= k_sort_merge_threshold @@ -122,8 +119,7 @@ DLL_PUBLIC void sort_indices_batch_rocprim( } size_t temp_storage_bytes = static_cast(temp_storage.numel()); void* temp_ptr = temp_storage.data_ptr(); - AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "sort_indices_batch_rocprim", [&] { - using index_t = scalar_t; + AT_DISPATCH_INDEX_TYPES(scalar_type, "sort_indices_batch_rocprim", [&] { auto* keys_out = static_cast(keys_out_base); for (int64_t i = 0; i < num_groups; ++i) { AT_CUDA_CHECK(rocprim::radix_sort_pairs( From 948e39a711f82b5698ea6506b3ba74d6280a22c5 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Wed, 22 Apr 2026 14:21:12 +0000 Subject: [PATCH 11/11] sparse_ops_gpu.cpp: switches off sorting for small batch sizes --- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index e31d5a3e01..da46b95100 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -425,11 +425,22 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( } #ifdef USE_ROCM - // The value is selected empirically. Potential + // The values are selected empirically. Potential // place for optimization. 
- constexpr size_t kSortIndicesThreshold = 15'000'000; + constexpr size_t kSortIndicesUpperThreshold = 15'000'000; + + // Sorting only pays off when there are enough indices to amortize + // the sorting cost, and the crossover point depends on the dtype. + constexpr size_t kSortIndicesLowerThresholdLowPrec = 1'000'000; + constexpr size_t kSortIndicesLowerThresholdFullPrec = 2'000'000; + + const bool is_low_precision = first_input.dtype().itemsize() <= 2; + const size_t kSortIndicesLowerThreshold = is_low_precision + ? kSortIndicesLowerThresholdLowPrec + : kSortIndicesLowerThresholdFullPrec; const bool use_sorted_indices_for_bwd = - (num_total_indices < kSortIndicesThreshold); + (num_total_indices >= kSortIndicesLowerThreshold) && + (num_total_indices < kSortIndicesUpperThreshold); #else const bool use_sorted_indices_for_bwd = false; (void)num_total_indices;