diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index 73a17572ef..da84ce3edc 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -110,6 +110,7 @@ if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_CPU)
     src/quantize_ops/quantize_msfp.cu
     src/quantize_ops/quantize_padded_fp8_rowwise.cu
     src/quantize_ops/quantize_mx.cu
+    src/sparse_ops/utils/rocm/sparse_group_utils.cu
     src/sparse_ops/sparse_async_batched_cumsum.cu
     src/sparse_ops/sparse_block_bucketize_features.cu
     src/sparse_ops/sparse_bucketize_features.cu
diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
index 0e7bd37234..edf9c2ee57 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -1078,6 +1078,8 @@ void group_index_select_or_add_cuda(
     const int64_t* input_ptrs,
     const int64_t* output_ptrs,
     const int64_t* indices_ptrs,
+    const int64_t* sorted_indices_ptrs,
+    const int64_t* reverse_indices_ptrs,
     const int64_t* warp_offsets_group,
     const int32_t* num_cols_group,
     const c10::ScalarType& input_scalar_type,
@@ -1087,7 +1089,10 @@ void group_index_select_or_add_cuda(
     const int64_t total_num_warps,
     const int group_size,
     const bool use_index_select,
-    const bool use_var_cols);
+    const bool use_var_cols,
+    const bool use_contiguous_warps,
+    const bool use_cache,
+    const bool use_packed_rows);
 
 int get_group_index_select_cols_per_warp();
diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h
index e32d346897..6cb740911e 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h
@@ -11,9 +11,26 @@
 #include <cstddef>
 #include <cstdint>
+#include <rocprim/rocprim.hpp>
+#include <rocprim/device/device_segmented_radix_sort.hpp>
+
+#include <ATen/core/Tensor.h>
+#include <ATen/cuda/CUDAStream.h>
+#include <c10/core/ScalarType.h>
+
 #include "fbgemm_gpu/utils/cuda_prelude.cuh"
+#include "fbgemm_gpu/utils/function_types.h"
 
 namespace fbgemm_gpu::rocm {
+// Selected empirically: rocprim falls back to merge sort when
+// num_items < this threshold, which is faster for small inputs. The value
+// must match across the sizing and sort calls.
+constexpr unsigned int k_sort_merge_threshold = 400'000;
+using sort_config = rocprim::radix_sort_config<
+    rocprim::default_config,
+    rocprim::default_config,
+    rocprim::default_config,
+    k_sort_merge_threshold>;
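Aside: a minimal sketch (not part of the patch) of how a config like sort_config threads into a rocprim call. rocprim's radix_sort_pairs accepts the config as its first template parameter; with radix_sort_config, inputs below the merge threshold take the merge-sort fallback, so the sizing call and the sorting call must name the same config or the temp-storage sizes may disagree. All names besides sort_config are invented here:

// Hypothetical helper, illustrative only.
inline size_t query_sort_bytes(
    const int64_t* keys_in,
    int64_t* keys_out,
    const int64_t* values_in,
    int64_t* values_out,
    size_t num_items,
    hipStream_t stream) {
  size_t bytes = 0;
  // Null temp storage => rocprim only writes the required size into `bytes`.
  rocprim::radix_sort_pairs<sort_config>(
      nullptr, bytes, keys_in, keys_out, values_in, values_out, num_items,
      0, sizeof(int64_t) * 8, stream);
  return bytes;
}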
+
 namespace {
 template <typename index_t>
 __device__ __forceinline__ void warp_upper_bound(
@@ ... @@
   *found = result;
   *cached_boundary = cached_result;
 }
-
 } // namespace
+
+// Returns temp storage size for a single-segment sort of num_items elements.
+size_t get_sort_temp_storage_bytes(
+    const size_t num_items,
+    const c10::ScalarType scalar_type,
+    const at::cuda::CUDAStream& stream);
+// Returns temp storage size for a segmented sort of num_groups segments,
+// each with num_items_per_segment elements.
+size_t get_segmented_sort_temp_storage_bytes(
+    const size_t num_items_per_segment,
+    const int64_t num_groups,
+    const c10::ScalarType scalar_type,
+    const at::cuda::CUDAStream& stream);
+// Sort all groups' indices with one rocprim::segmented_radix_sort_pairs
+// call, eliminating all per-group CPU launch overhead.
+//
+// Inputs must be contiguous across groups:
+//   all_keys_in    : [num_groups * num_items_per_segment] — packed input indices
+//   all_values_in  : [num_groups * num_items_per_segment] — tiled 0..N-1 per segment
+//   segment_offsets: [num_groups + 1] device tensor — [0, N, 2N, ..., K*N]
+//   all_keys_out / all_values_out: pre-allocated output buffers (same shape)
+//   temp_storage   : pre-allocated via get_segmented_sort_temp_storage_bytes()
+void sort_indices_segmented_rocprim(
+    const at::Tensor& all_keys_in,
+    at::Tensor& all_keys_out,
+    const at::Tensor& all_values_in,
+    at::Tensor& all_values_out,
+    const at::Tensor& segment_offsets,
+    const size_t num_items_per_segment,
+    const int64_t num_groups,
+    at::Tensor& temp_storage,
+    const at::cuda::CUDAStream& stream);
+// Sort all groups in a batch with one AT_DISPATCH and one stream lookup.
+// Uses radix_sort_pairs per group, preserving the merge sort fallback for
+// small segment sizes (num_items < k_sort_merge_threshold).
+void sort_indices_batch_rocprim(
+    const int64_t* keys_in_ptrs,
+    void* keys_out_base,
+    int64_t* values_out_base,
+    const int64_t* values_in,
+    const size_t num_items,
+    const int64_t num_groups,
+    at::Tensor& temp_storage,
+    const c10::ScalarType scalar_type,
+    const at::cuda::CUDAStream& stream);
 } // namespace fbgemm_gpu::rocm
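Aside: a hedged sketch of how a caller is expected to drive the two segmented-sort entry points declared above, mirroring the backward-pass usage later in this patch. Tensor names and shapes are assumed:

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

void sort_all_groups(
    const at::Tensor& all_keys_in,     // [K * N] packed indices
    at::Tensor& all_keys_out,
    const at::Tensor& all_values_in,   // [K * N] tiled 0..N-1
    at::Tensor& all_values_out,
    const at::Tensor& segment_offsets, // [K + 1] = [0, N, 2N, ..., K*N]
    size_t n_per_segment,
    int64_t num_groups) {
  const auto stream = at::cuda::getCurrentCUDAStream();
  // Phase 1: size the scratch buffer once for all segments.
  const size_t bytes = fbgemm_gpu::rocm::get_segmented_sort_temp_storage_bytes(
      n_per_segment, num_groups, all_keys_in.scalar_type(), stream);
  auto tmp = at::empty(
      {static_cast<int64_t>(bytes)}, all_keys_in.options().dtype(at::kByte));
  // Phase 2: one device-wide segmented sort for every group.
  fbgemm_gpu::rocm::sort_indices_segmented_rocprim(
      all_keys_in, all_keys_out, all_values_in, all_values_out,
      segment_offsets, n_per_segment, num_groups, tmp, stream);
}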
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index ddec5d0e01..1a1f122f77 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -5,6 +5,10 @@
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
+#include <cstdint>
+#include <limits>
+#include <type_traits>
+#include <variant>
 #include "common.cuh"
 
 #ifdef USE_ROCM
@@ -45,6 +49,10 @@ template <
     typename scalar_t,
     bool USE_INDEX_SELECT,
     bool USE_VAR_COLS,
+    bool USE_CONTIGUOUS_WARPS,
+    bool USE_SORTED_INDICES,
+    bool USE_CACHE,
+    bool USE_PACKED_ROWS,
     int UNROLL_FACTOR,
     int COLS_PER_WARP,
     int LOG_COLS_PER_WARP>
@@ -53,10 +61,17 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     const int64_t* input_ptrs,
     const int64_t* output_ptrs,
     const int64_t* indices_ptrs,
+    const int64_t* reverse_indices_ptrs,
     const int64_t* warp_offsets_group,
     const int32_t* num_cols_group,
     const int64_t num_work_rows, // number of rows to work on per member
     const int64_t group_size) {
+  static_assert(
+      !USE_CACHE || UNROLL_FACTOR == 1,
+      "Cache path only supports UNROLL_FACTOR == 1");
+
+  constexpr index_t kInvalidIdx = std::numeric_limits<index_t>::max();
+
   const auto total_num_warps = warp_offsets_group[group_size];
   int32_t num_cols = 0;
   int32_t warps_per_row = 0;
@@ -66,14 +81,50 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
   }
 
-#ifdef USE_ROCM
-  int cached_member_id = -1;
-  int64_t cached_upper_bound = -1;
-#endif
+  [[maybe_unused]] int cached_member_id = -1;
+  [[maybe_unused]] int64_t cached_upper_bound = -1;
+  [[maybe_unused]] int32_t last_member_id_for_accum = -1;
+  [[maybe_unused]] int32_t last_member_num_cols = 0;
+  [[maybe_unused]] scalar_t* last_member_output_tile = nullptr;
+
+  int64_t start_warp_id = 0;
+  int64_t warp_end = 0;
+  int64_t warp_stride = 0;
+
+  if constexpr (USE_CONTIGUOUS_WARPS) {
+    const int64_t linear_warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+    const int64_t warps_per_launch = gridDim.x * blockDim.y;
+    const int64_t chunk_size =
+        (total_num_warps + warps_per_launch - 1) / warps_per_launch;
+    start_warp_id = linear_warp_id * chunk_size;
+    warp_end = start_warp_id + chunk_size < total_num_warps
+        ? start_warp_id + chunk_size
+        : total_num_warps;
+    warp_stride = 1;
+  } else {
+    start_warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+    warp_end = total_num_warps;
+    warp_stride = gridDim.x * blockDim.y;
+  }
 
-  for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
-       warp_id < total_num_warps;
-       warp_id += gridDim.x * blockDim.y) {
+  auto storage = scalar_t(0);
+  auto cached_idx = kInvalidIdx;
+  // TODO: Account for UNROLL_FACTOR
+  auto flush_cache_accumulator = [&](scalar_t* target_output,
+                                     int32_t target_num_cols) {
+    if constexpr (!USE_INDEX_SELECT && USE_CACHE) {
+      if (target_output && cached_idx != kInvalidIdx) {
+        gpuAtomicAddNoReturn(
+            &target_output[cached_idx * target_num_cols], storage);
+        cached_idx = kInvalidIdx;
+      }
+    }
+  };
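// Annotation (not part of the patch): worked example of the contiguous
// chunking above, with invented values:
//   total_num_warps = 10, warps_per_launch = 4
//   chunk_size = (10 + 4 - 1) / 4 = 3
//   warp 0 -> [0, 3), warp 1 -> [3, 6), warp 2 -> [6, 9), warp 3 -> [9, 10)
// Each warp then advances with stride 1, so successive iterations touch
// adjacent warp_ids; the sorted-indices cache path relies on this to see
// runs of equal destination indices.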
 
+  for (int64_t warp_id = start_warp_id; warp_id < warp_end;
+       warp_id += warp_stride) {
+    bool use_small_dim_path = false;
+    int rows_per_warp_small = 0;
     int32_t member_id = 0;
     int32_t member_warp_id = 0;
     if constexpr (USE_VAR_COLS) {
@@ -108,92 +159,149 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
       // All columns are the same
       member_id = warp_id / (warps_per_row * num_work_rows);
       member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+    }
+
+    if constexpr (USE_PACKED_ROWS) {
+      if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
+        // Need to ensure that [member_id] and [member_warp_id] are calculated
+        // correctly for the small embedding dimension path below
+        rows_per_warp_small = COLS_PER_WARP / num_cols;
+        if constexpr (!USE_VAR_COLS) {
+          const auto warps_per_member =
+              (num_work_rows + rows_per_warp_small - 1) / rows_per_warp_small;
+          member_id = warp_id / warps_per_member;
+          member_warp_id = warp_id % warps_per_member;
+        }
+        use_small_dim_path = true;
+      }
+    }
+
-#ifdef USE_ROCM
-    if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
-      // Need to ensure that [member_id] and [member_warp_id] are calculated
-      // correctly for the small embedding dimension path below
-      const auto rows_per_warp = COLS_PER_WARP / num_cols;
-      const auto warps_per_member =
-          DIV_ROUND_UP(num_work_rows, rows_per_warp);
-      member_id = warp_id / warps_per_member;
-      member_warp_id = warp_id % warps_per_member;
-
-      // Optimized path for small embedding dimensions
-      // Each warp processes 'rows_per_warp' rows
-      const int64_t start_row = member_warp_id * rows_per_warp;
-
-      // Since we are processing multiple rows within the warp, we need to
-      // map each lane to a specific row, in addition to the column
-      const auto local_row = (threadIdx.x * UNROLL_FACTOR) /
-          num_cols; // the row ID within the set of rows handled by this warp
-      const auto col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols;
-      const int64_t current_row = start_row +
-          local_row; // the actual row within the table processed by this lane
-
-      // local_row may be out of bounds for the last few lanes in the warp if
-      // [COLS_PER_WARP % num_cols != 0] and we also need to confirm that we
-      // are within num_work_rows
-      if (local_row < rows_per_warp && current_row < num_work_rows) {
-        scalar_t* input =
-            reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
-        scalar_t* output =
-            reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
-
-        index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
-        const index_t idx = indices[current_row];
-#pragma unroll
-        for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
-          // Compile time conditional
-          if constexpr (USE_INDEX_SELECT) {
-            output[current_row * num_cols + i] =
-                LDG(&input[idx * num_cols + i]);
-          } else {
-            gpuAtomicAddNoReturn(
-                &output[idx * num_cols + i],
-                input[current_row * num_cols + i]);
-          }
-        }
-      }
-    } else {
-      // Large embedding dimensions use >= 1 warp per row
-      // which is the default codepath for non-ROCm as well
-#endif // USE_ROCM
-      const auto row = member_warp_id / warps_per_row;
-      const auto col_offset =
-          ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
-          (threadIdx.x * UNROLL_FACTOR);
-
-      scalar_t* input =
-          reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
-      scalar_t* output =
-          reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
-
-      index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
-      const index_t idx = indices[row];
-#pragma unroll
-      for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
-        // Compile time conditional
-        if constexpr (USE_INDEX_SELECT) {
-          output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
-        } else {
-          gpuAtomicAddNoReturn(
-              &output[idx * num_cols + i], input[row * num_cols + i]);
-        }
-      }
+    index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+    const int64_t* reverse_indices = USE_SORTED_INDICES
+        ? reinterpret_cast<const int64_t*>(reverse_indices_ptrs[member_id])
+        : nullptr;
+
+    int64_t logical_row = 0;
+    int64_t row = 0;
+    int64_t col_offset = 0;
+    bool handled_small_dim_path = false;
+
+    if constexpr (USE_PACKED_ROWS) {
+      if (use_small_dim_path) {
+        // Optimized path for small embedding dimensions
+        // Each warp processes 'rows_per_warp' rows
+        const int rows_per_warp = rows_per_warp_small;
+        const int64_t start_row = member_warp_id * rows_per_warp;
+        // Since we are processing multiple rows within the warp, we need to
+        // map each lane to a specific row, in addition to the column
+        const int local_row = (threadIdx.x * UNROLL_FACTOR) /
+            num_cols; // the row ID within the set of rows handled by this warp
+        const int64_t current_row = start_row +
+            local_row; // the actual row within the table processed by this lane
+        const int col_offset_small = (threadIdx.x * UNROLL_FACTOR) % num_cols;
+        // local_row may be out of bounds for the last few lanes in the warp if
+        // [COLS_PER_WARP % num_cols != 0] and we also need to confirm that we
+        // are within num_work_rows
+        if (local_row < rows_per_warp && current_row < num_work_rows) {
+          logical_row = current_row;
+          row = USE_SORTED_INDICES ? reverse_indices[current_row]
+                                   : current_row;
+          col_offset = col_offset_small;
+          handled_small_dim_path = true;
+        } else {
+          flush_cache_accumulator(
+              last_member_output_tile, last_member_num_cols);
+          continue;
+        }
+      }
+    }
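// Annotation (not part of the patch): worked example of the packed-rows lane
// mapping above, with invented values and UNROLL_FACTOR = 1:
//   COLS_PER_WARP = 32, num_cols = 8  ->  rows_per_warp = 4
//   lanes  0..7  -> local_row 0, col_offset 0..7
//   lanes  8..15 -> local_row 1, col_offset 0..7
//   lanes 16..23 -> local_row 2, col_offset 0..7
//   lanes 24..31 -> local_row 3, col_offset 0..7
// If num_cols does not divide COLS_PER_WARP (say num_cols = 5, so
// rows_per_warp = 6), lanes 30..31 compute local_row = 6 and are skipped by
// the bounds check above.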
 
+    if (!handled_small_dim_path) {
+      int64_t row_in_member = 0;
+      int64_t col_tile = 0;
+      if constexpr (USE_CONTIGUOUS_WARPS) {
+        // Contiguous warp traversal: rows vary fastest, so each warp's
+        // consecutive warp_ids walk successive rows within one column tile.
+        row_in_member = member_warp_id % num_work_rows;
+        col_tile = member_warp_id / num_work_rows;
+      } else {
+        // Original strided mapping: column tiles vary fastest and rows are
+        // distributed round-robin across warps.
+        row_in_member = member_warp_id / warps_per_row;
+        col_tile = member_warp_id % warps_per_row;
+      }
+
+      logical_row = row_in_member;
+      row = USE_SORTED_INDICES ? reverse_indices[row_in_member]
+                               : row_in_member;
+      col_offset =
+          (static_cast<int64_t>(col_tile) << LOG_COLS_PER_WARP) +
+          (threadIdx.x * UNROLL_FACTOR);
+    }
+
+    scalar_t* input =
+        reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
+    scalar_t* output =
+        reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
+
+    const index_t idx = indices[logical_row];
+    // TODO: Account for UNROLL_FACTOR
 #pragma unroll
+    for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+      // Compile time conditional
+      if constexpr (USE_INDEX_SELECT) {
+        if constexpr (USE_CACHE) {
+          if (cached_idx != idx) {
+            storage = LDG(&input[idx * num_cols + i]);
+            cached_idx = idx;
+          }
+
+          output[row * num_cols + i] = storage;
+        } else {
+          output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+        }
+      } else {
+        if constexpr (USE_CACHE) {
+          const bool member_changed = (last_member_id_for_accum != -1 &&
+              member_id != last_member_id_for_accum);
+          // Could probably be merged into the if-else cascade below
+          if (member_changed) {
+            flush_cache_accumulator(
+                last_member_output_tile, last_member_num_cols);
+          }
+
+          const bool is_first_warp =
+              member_changed || (warp_id == start_warp_id);
+          const bool is_last_warp = (warp_id + warp_stride >= warp_end);
+          if (is_first_warp) {
+            storage = input[row * num_cols + i];
+            cached_idx = idx;
+          } else if (cached_idx != idx) {
+            // Flush using the output tile that owns the cached accumulator.
+            flush_cache_accumulator(
+                last_member_output_tile, last_member_num_cols);
+            storage = input[row * num_cols + i];
+            cached_idx = idx;
+          } else {
+            storage += input[row * num_cols + i];
+          }
+
+          if (is_last_warp) {
+            flush_cache_accumulator(output, num_cols);
+          }
+
+          last_member_output_tile = output;
+          last_member_num_cols = num_cols;
+          last_member_id_for_accum = member_id;
+        } else {
+          gpuAtomicAddNoReturn(
+              &output[idx * num_cols + i], input[row * num_cols + i]);
+        }
+      }
+    }
-#ifdef USE_ROCM
-    }
-#endif // USE_ROCM
+  }
+
+  flush_cache_accumulator(last_member_output_tile, last_member_num_cols);
 }
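Aside: the USE_CACHE accumulation above reduces to the following single-column sketch. With sorted indices, equal destinations are adjacent, so a lane folds each run into a register and issues one atomic per run instead of one per row. This is illustrative only; it assumes gpuAtomicAddNoReturn and <limits> are available, and flattens the warp/tile structure into a plain loop:

template <typename scalar_t, typename index_t>
__device__ void add_rows_with_run_cache(
    const scalar_t* in, scalar_t* out, const index_t* sorted_idx, int n) {
  constexpr index_t kInvalid = std::numeric_limits<index_t>::max();
  scalar_t acc = 0;
  index_t cur = kInvalid;
  for (int r = 0; r < n; ++r) {
    if (sorted_idx[r] != cur) {
      // Destination changed: emit the finished run, start a new one.
      if (cur != kInvalid) {
        gpuAtomicAddNoReturn(&out[cur], acc);
      }
      acc = in[r];
      cur = sorted_idx[r];
    } else {
      acc += in[r];
    }
  }
  // Final run still lives in the register.
  if (cur != kInvalid) {
    gpuAtomicAddNoReturn(&out[cur], acc);
  }
}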
 
 DLL_PUBLIC void group_index_select_or_add_cuda(
     const int64_t* input_ptrs,
     const int64_t* output_ptrs,
     const int64_t* indices_ptrs,
+    const int64_t* sorted_indices_ptrs,
+    const int64_t* reverse_indices_ptrs,
     const int64_t* warp_offsets_group,
     const int32_t* num_cols_group,
     const c10::ScalarType& input_scalar_type,
@@ -203,12 +311,16 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
     const int64_t total_num_warps,
     const int group_size,
     const bool use_index_select,
-    const bool use_var_cols) {
+    const bool use_var_cols,
+    const bool use_contiguous_warps,
+    const bool use_cache,
+    const bool use_packed_rows) {
   if (group_size == 0) {
     return;
   }
 
   at::cuda::OptionalCUDAGuard device_guard(device);
+  const bool use_sorted_indices =
+      (sorted_indices_ptrs && reverse_indices_ptrs);
 
   // Partition work based on num_work_rows
   uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE;
@@ -219,49 +331,106 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
       max_grid_size);
   dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);
 
-#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
-  FBGEMM_LAUNCH_KERNEL(                                                  \
-      (group_index_select_or_add_2d_kernel<                              \
-          index_t,                                                       \
-          scalar_t,                                                      \
-          USE_INDEX_SELECT,                                              \
-          USE_VAR_COLS,                                                  \
-          GROUP_INDEX_SELECT_UNROLL_FACTOR,                              \
-          GROUP_INDEX_SELECT_COLS_PER_WARP,                              \
-          GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>),                        \
-      grid_size,                                                         \
-      block_size,                                                        \
-      0,                                                                 \
-      at::cuda::getCurrentCUDAStream(),                                  \
-      input_ptrs,                                                        \
-      output_ptrs,                                                       \
-      indices_ptrs,                                                      \
-      warp_offsets_group,                                                \
-      num_cols_group,                                                    \
-      num_work_rows,                                                     \
-      group_size)
+  auto invoke_group_index_select_or_add = [&]<typename index_t,
+                                              typename scalar_t,
+                                              bool USE_INDEX_SELECT,
+                                              bool USE_VAR_COLS,
+                                              bool USE_CONTIGUOUS_WARPS,
+                                              bool USE_SORTED_INDICES,
+                                              bool USE_CACHE,
+                                              bool USE_PACKED_ROWS>() {
+    FBGEMM_LAUNCH_KERNEL(
+        (group_index_select_or_add_2d_kernel<
+            index_t,
+            scalar_t,
+            USE_INDEX_SELECT,
+            USE_VAR_COLS,
+            USE_CONTIGUOUS_WARPS,
+            USE_SORTED_INDICES,
+            USE_CACHE,
+            USE_PACKED_ROWS,
+            GROUP_INDEX_SELECT_UNROLL_FACTOR,
+            GROUP_INDEX_SELECT_COLS_PER_WARP,
+            GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>),
+        grid_size,
+        block_size,
+        0,
+        at::cuda::getCurrentCUDAStream(),
+        input_ptrs,
+        output_ptrs,
+        use_sorted_indices ? sorted_indices_ptrs : indices_ptrs,
+        reverse_indices_ptrs,
+        warp_offsets_group,
+        num_cols_group,
+        num_work_rows,
+        group_size);
+  };
+
+  using bool_variant_t = std::variant<std::true_type, std::false_type>;
+
+// Split is needed to avoid additional code generation for unsupported
+// algorithms on CUDA
+#ifdef USE_ROCM
+  using platform_bool_variant_t =
+      std::variant<std::true_type, std::false_type>;
+#else
+  using platform_bool_variant_t = std::variant<std::false_type>;
+#endif
+
+  auto get_bool_type = [](const bool var) -> bool_variant_t {
+    if (var) {
+      return std::true_type{};
+    } else {
+      return std::false_type{};
+    }
+  };
+
+  const bool_variant_t use_index_select_variant =
+      get_bool_type(use_index_select);
+  const bool_variant_t use_var_cols_variant = get_bool_type(use_var_cols);
+#ifdef USE_ROCM
+  const platform_bool_variant_t use_contiguous_warps_variant =
+      get_bool_type(use_contiguous_warps);
+  const platform_bool_variant_t use_sorted_indices_variant =
+      get_bool_type(use_sorted_indices);
+  const platform_bool_variant_t use_cache_variant = get_bool_type(use_cache);
+  const platform_bool_variant_t use_packed_rows_variant =
+      get_bool_type(use_packed_rows);
+#else
+  const platform_bool_variant_t use_contiguous_warps_variant{
+      std::false_type{}};
+  const platform_bool_variant_t use_sorted_indices_variant{std::false_type{}};
+  const platform_bool_variant_t use_cache_variant{std::false_type{}};
+  const platform_bool_variant_t use_packed_rows_variant{std::false_type{}};
+#endif
 
   AT_DISPATCH_INDEX_TYPES(
       indices_scalar_type, "group_index_select_2d_wrapper_1", [&] {
         FBGEMM_DISPATCH_FLOATING_TYPES(
             input_scalar_type, "group_index_select_2d_wrapper_2", [&] {
-              if (use_index_select) {
-                if (use_var_cols) {
-                  INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true);
-                } else {
-                  INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, false);
-                }
-              } else {
-                if (use_var_cols) {
-                  INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, true);
-                } else {
-                  INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false);
-                }
-              }
+              std::visit(
+                  [&](auto use_index_select_arg,
+                      auto use_var_cols_arg,
+                      auto use_contiguous_warps_arg,
+                      auto use_sorted_indices_arg,
+                      auto use_cache_arg,
+                      auto use_packed_rows_arg) {
+                    invoke_group_index_select_or_add.template operator()<
+                        index_t,
+                        scalar_t,
+                        use_index_select_arg.value,
+                        use_var_cols_arg.value,
+                        use_contiguous_warps_arg.value,
+                        use_sorted_indices_arg.value,
+                        use_cache_arg.value,
+                        use_packed_rows_arg.value>();
+                  },
+                  use_index_select_variant,
+                  use_var_cols_variant,
+                  use_contiguous_warps_variant,
+                  use_sorted_indices_variant,
+                  use_cache_variant,
+                  use_packed_rows_variant);
             });
       });
-
-#undef INVOKE_GROUP_INDEX_SELECT_OR_ADD
 }
 
 } // namespace fbgemm_gpu
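Aside: a standalone sketch of the variant-based bool dispatch used above. Each runtime bool becomes a std::true_type/std::false_type alternative, and std::visit instantiates the target once per combination, so the flags are compile-time constants inside the kernel. All names here are invented for illustration; only two flags are shown:

#include <cstdio>
#include <type_traits>
#include <variant>

using bool_v = std::variant<std::true_type, std::false_type>;

bool_v to_variant(bool b) {
  if (b) {
    return std::true_type{};
  }
  return std::false_type{};
}

template <bool A, bool B>
void run() {
  // Each (A, B) pair is a distinct instantiation.
  std::printf("A=%d B=%d\n", int(A), int(B));
}

void dispatch(bool a, bool b) {
  std::visit(
      // .value on a std::integral_constant object is a constant expression,
      // so it is usable directly as a template argument.
      [](auto a_arg, auto b_arg) { run<a_arg.value, b_arg.value>(); },
      to_variant(a),
      to_variant(b));
}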
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
index 411da4fcc1..da3a717025 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -3636,6 +3636,7 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::forward(
           .typed();
   auto result = forward_op.call(all_indices_input, group_size);
   TORCH_CHECK(static_cast<int64_t>(result.size()) == group_size + 2);
+  ctx->saved_data["group_size"] = group_size;
 
   auto [input_group, indices_group] =
       group_index_select_dim0_unpack(all_indices_input, group_size);
@@ -3674,7 +3675,7 @@ torch::autograd::variable_list GroupIndexSelectDim0Op::backward(
     return torch::autograd::variable_list(1);
   }
   // remove redundant grads
-  auto group_size = grad_output_group.size() - 2;
+  const auto group_size = ctx->saved_data["group_size"].toInt();
   grad_output_group.resize(group_size);
 
   const auto saved_tensors = ctx->get_saved_variables();
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index d85eb3cba7..da46b95100 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -12,14 +12,19 @@
 #include "fbgemm_gpu/sparse_ops.h"
 #include "fbgemm_gpu/utils/ops_utils.h"
 #include "fbgemm_gpu/utils/tensor_utils.h"
+#ifdef USE_ROCM
+#include "fbgemm_gpu/utils/rocm/sparse_group_utils.h"
+#endif
 
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <ATen/core/op_registration/op_registration.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/util/irange.h>
 #include <torch/csrc/autograd/custom_function.h>
 #include <torch/library.h>
 #include <stdexcept> // for logic_error
+#include <vector>
 
 using Tensor = at::Tensor;
 
@@ -27,13 +32,15 @@ namespace fbgemm_gpu {
 
 namespace {
 
-constexpr int32_t NUM_ARGS = 5;
+constexpr int32_t NUM_ARGS = 7;
 enum args_pos {
   P_input_ptrs = 0,
   P_output_ptrs = 1,
   P_indices_ptrs = 2,
-  P_warp_offsets_group_ptrs = 3,
-  P_num_cols_group_ptrs = 4
+  P_sorted_indices_ptrs = 3,
+  P_reverse_indices_ptrs = 4,
+  P_warp_offsets_group_ptrs = 5,
+  P_num_cols_group_ptrs = 6,
 };
 
 template
 void offset_args(
     int64_t** input_ptrs,
     int64_t** output_ptrs,
     int64_t** indices_ptrs,
+    int64_t** sorted_indices_ptrs,
+    int64_t** reverse_indices_ptrs,
     int64_t** warp_offsets_group,
     int32_t** num_cols_group,
     int64_t* base_addr,
@@ -54,6 +63,8 @@
   *input_ptrs = base_addr + ptr_offsets[P_input_ptrs];
   *output_ptrs = base_addr + ptr_offsets[P_output_ptrs];
   *indices_ptrs = base_addr + ptr_offsets[P_indices_ptrs];
+  *sorted_indices_ptrs = base_addr + ptr_offsets[P_sorted_indices_ptrs];
+  *reverse_indices_ptrs = base_addr + ptr_offsets[P_reverse_indices_ptrs];
   *warp_offsets_group = base_addr + ptr_offsets[P_warp_offsets_group_ptrs];
   *num_cols_group = reinterpret_cast<int32_t*>(
       base_addr + ptr_offsets[P_num_cols_group_ptrs]);
@@ -213,6 +224,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   // input_ptrs (group_size int64_t elements)
   // output_ptrs (group_size int64_t elements)
   // indices_ptrs (group_size int64_t elements)
+  // sorted_indices_ptrs (group_size int64_t elements)
+  // reverse_indices_ptrs (group_size int64_t elements)
   // warp_offsets_group (group_size + 1 int64_t elements)
   // num_cols_group (group_size int32_t elements)
   int64_t args_ptrs_offsets[NUM_ARGS + 1];
@@ -224,6 +237,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   args_ptrs_offsets[P_input_ptrs] = group_size;
   args_ptrs_offsets[P_output_ptrs] = group_size;
   args_ptrs_offsets[P_indices_ptrs] = group_size;
+  args_ptrs_offsets[P_sorted_indices_ptrs] = group_size;
+  args_ptrs_offsets[P_reverse_indices_ptrs] = group_size;
   args_ptrs_offsets[P_warp_offsets_group_ptrs] = group_size + 1;
   args_ptrs_offsets[P_num_cols_group_ptrs] = numels_num_cols_group_64;
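Aside: with NUM_ARGS = 7 and G = group_size, the single pinned args tensor described above is laid out in int64_t slots as:

  [0,  G)       input_ptrs
  [G,  2G)      output_ptrs
  [2G, 3G)      indices_ptrs
  [3G, 4G)      sorted_indices_ptrs
  [4G, 5G)      reverse_indices_ptrs
  [5G, 6G + 1)  warp_offsets_group   (G + 1 entries)
  then          num_cols_group       (G int32_t entries)

offset_args() simply adds these precomputed prefix offsets to the tensor's base address to recover each table.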
@@ -251,6 +266,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   int64_t* input_ptrs = nullptr;
   int64_t* output_ptrs = nullptr;
   int64_t* indices_ptrs = nullptr;
+  int64_t* sorted_indices_ptrs = nullptr;
+  int64_t* reverse_indices_ptrs = nullptr;
   int64_t* warp_offsets_group = nullptr;
   int32_t* num_cols_group = nullptr;
 
@@ -259,6 +276,8 @@
       &input_ptrs,
       &output_ptrs,
       &indices_ptrs,
+      &sorted_indices_ptrs,
+      &reverse_indices_ptrs,
       &warp_offsets_group,
       &num_cols_group,
       reinterpret_cast<int64_t*>(args_tensor.mutable_data_ptr()),
@@ -290,6 +309,8 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   input_contigs.reserve(group_size);
   index_contigs.reserve(group_size);
 
+  bool use_packed_rows = false;
+  size_t num_total_indices = 0;
   // For each group, copy input to output
   for (const auto i : c10::irange(group_size)) {
     const auto& input = input_group[i];
@@ -335,6 +356,7 @@
     TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input, indices);
 
     auto num_output_rows_ = indices.size(0);
+    num_total_indices += num_output_rows_;
 
     // Verify that all input tensors have the same shape[0]
     TORCH_CHECK_VALUE(
@@ -358,6 +380,7 @@
       // Optimization: Pack multiple rows into one warp
       int rows_per_warp = cols_per_warp / num_cols_;
       warps_needed = (num_output_rows_ + rows_per_warp - 1) / rows_per_warp;
+      use_packed_rows = true;
     } else {
       // Standard: One or more warps per row
      int warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
@@ -400,6 +423,28 @@
     warp_offset += warps_per_row * num_output_rows;
 #endif // USE_ROCM
   }
+
+#ifdef USE_ROCM
+  // The values are selected empirically and are a potential place for
+  // further tuning.
+  constexpr size_t kSortIndicesUpperThreshold = 15'000'000;
+
+  // Sorting only pays off when there are enough indices to amortize the
+  // sorting cost, and the crossover point depends on the dtype.
+  constexpr size_t kSortIndicesLowerThresholdLowPrec = 1'000'000;
+  constexpr size_t kSortIndicesLowerThresholdFullPrec = 2'000'000;
+
+  const bool is_low_precision = first_input.dtype().itemsize() <= 2;
+  const size_t kSortIndicesLowerThreshold = is_low_precision
+      ? kSortIndicesLowerThresholdLowPrec
+      : kSortIndicesLowerThresholdFullPrec;
+  const bool use_sorted_indices_for_bwd =
+      (num_total_indices >= kSortIndicesLowerThreshold) &&
+      (num_total_indices < kSortIndicesUpperThreshold);
+#else
+  const bool use_sorted_indices_for_bwd = false;
+  (void)num_total_indices;
+#endif
 
   // Store the last offset
   warp_offsets_group[group_size] = warp_offset;
@@ -414,6 +459,8 @@
       &input_ptrs,
       &output_ptrs,
       &indices_ptrs,
+      &sorted_indices_ptrs,
+      &reverse_indices_ptrs,
       &warp_offsets_group,
       &num_cols_group,
       reinterpret_cast<int64_t*>(args_tensor.mutable_data_ptr()),
@@ ... @@
   int64_t saved_data[] = {
       static_cast<int64_t>(group_size),
       use_var_cols,
+      use_packed_rows,
+      use_sorted_indices_for_bwd,
       reinterpret_cast<int64_t>(warp_offsets_group),
       reinterpret_cast<int64_t>(num_cols_group),
       warp_offset,
@@ -438,6 +487,8 @@
       input_ptrs,
       output_ptrs,
       indices_ptrs,
+      /*sorted_indices_ptrs=*/nullptr,
+      /*reverse_indices_ptrs=*/nullptr,
       warp_offsets_group,
       num_cols_group,
       first_input.scalar_type(),
@@ -447,7 +498,10 @@
       /*total_num_warps=*/warp_offset,
       group_size,
       /*use_index_select=*/true,
-      use_var_cols);
+      use_var_cols,
+      /*use_contiguous_warps=*/false,
+      /*use_cache=*/false,
+      use_packed_rows);
 
   output_group.push_back(args_tensor);
   output_group.push_back(saved_data_t);
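Aside: a hedged restatement of the forward-pass heuristic above as a free function. The thresholds are the empirical constants from this diff; the function itself is invented for illustration:

#include <cstddef>

bool should_sort_indices_for_bwd(size_t num_total_indices, size_t itemsize) {
  // Low-precision dtypes (<= 2 bytes) amortize the sort sooner.
  const size_t lower = itemsize <= 2 ? 1'000'000 : 2'000'000;
  // Empirical upper cutoff from the diff.
  const size_t upper = 15'000'000;
  return num_total_indices >= lower && num_total_indices < upper;
}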
@@ -499,11 +553,13 @@ static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
       " but got ",
       saved_data_ptr[0]);
   const bool use_var_cols = saved_data_ptr[1];
+  const bool use_packed_rows = saved_data_ptr[2];
+  const bool use_sorted_indices = saved_data_ptr[3];
   const int64_t* warp_offsets_group =
-      reinterpret_cast<const int64_t*>(saved_data_ptr[2]);
+      reinterpret_cast<const int64_t*>(saved_data_ptr[4]);
   const int32_t* num_cols_group =
-      reinterpret_cast<const int32_t*>(saved_data_ptr[3]);
-  int64_t total_num_warps = saved_data_ptr[4];
+      reinterpret_cast<const int32_t*>(saved_data_ptr[5]);
+  int64_t total_num_warps = saved_data_ptr[6];
 
   // We checked in forward that all output rows are the same for all member
   // in the group
@@ -517,18 +573,21 @@
   outputs.reserve(group_size * 2 + 1);
 
   // 1) Add group_size Variable()'s for indices
-  // c10::irange cannot be used in here as it
-  // triggers a build error of i being an unused variable.
   // Add empty tensor with zero size here to make __torch_dispatch__ work for
   // the backward op. Those empty tensors will be replaced with
   // torch::autograd::Variable() outside of the op call.
-  for (auto i = 0; i < group_size; i++) {
-    outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong)));
+  // Reuse a single placeholder tensor to avoid N separate allocations.
+  {
+    const auto placeholder =
+        at::empty({0}, at::TensorOptions().dtype(at::kLong));
+    for (auto i = 0; i < group_size; i++) {
+      outputs.push_back(placeholder);
+    }
   }
 
   // Allocate Tensor for ptrs of grad output and input, and indices
   Tensor args_tensor = at::empty(
-      {group_size * 3},
+      {group_size * 5},
       at::TensorOptions().dtype(at::kLong).pinned_memory(true));
   // Ensure that args_tensor is contiguous
   TORCH_CHECK(
@@ ... @@
       args_tensor.mutable_data_ptr<int64_t>() + group_size;
   int64_t* indices_ptrs =
       args_tensor.mutable_data_ptr<int64_t>() + 2 * group_size;
+  int64_t* sorted_indices_ptrs =
+      args_tensor.data_ptr<int64_t>() + 3 * group_size;
+  int64_t* reverse_indices_ptrs =
+      args_tensor.data_ptr<int64_t>() + 4 * group_size;
 
   int64_t group_grad_input_numel = 0;
   std::vector<int64_t> grad_input_numels;
@@ -587,10 +648,8 @@
   // Reshape grad inputs and obtain their pointers
   for (int i = 0; i < group_size; i++) {
-    const auto grad_input_shape = std::vector<int64_t>(
-        output_shape_group.begin() + i * output_dim,
-        output_shape_group.begin() + (i + 1) * output_dim);
-    output_group[i] = output_group[i].reshape(grad_input_shape);
+    output_group[i] = output_group[i].reshape(c10::IntArrayRef(
+        output_shape_group.data() + i * output_dim, output_dim));
     TORCH_CHECK(
         output_group[i].is_contiguous(),
         "Tensor output_group ",
@@ -608,20 +667,121 @@
   // Calculate indices_ptrs
   std::vector<c10::MaybeOwned<Tensor>> index_contigs;
   index_contigs.reserve(group_size);
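// Annotation (not part of the patch): worked example, with invented values,
// of the sorted/reverse pair built below:
//   indices (sort keys)         = [5, 2, 7, 2]
//   original positions (values) = [0, 1, 2, 3]
//   sorted_indices  (keys out)  = [2, 2, 5, 7]
//   reverse_indices (vals out)  = [1, 3, 0, 2]   // source row of each key
// The backward kernel reads grad row reverse_indices[r] and accumulates into
// output row sorted_indices[r]; equal destinations land adjacent, which is
// exactly what the kernel's register-cached accumulation exploits.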
+#ifdef USE_ROCM
+  // Pre-allocate sort scratch and output buffers once for all groups.
+  // All groups share the same index dtype and size (enforced by the checks
+  // in the forward pass).
+  at::Tensor sort_temp_storage;
+  at::Tensor sort_original_positions;
+  at::Tensor all_sorted_indices;
+  at::Tensor all_reverse_indices;
+  int64_t sort_num_items = static_cast<int64_t>(indices_group[0].numel());
+  const bool use_segmented_sort =
+      static_cast<size_t>(sort_num_items) >= rocm::k_sort_merge_threshold;
+  if (use_sorted_indices) {
+    const auto stream = at::cuda::getCurrentCUDAStream();
+    const size_t temp_bytes = use_segmented_sort
+        ? rocm::get_segmented_sort_temp_storage_bytes(
+              static_cast<size_t>(sort_num_items),
+              group_size,
+              first_indices.scalar_type(),
+              stream)
+        : rocm::get_sort_temp_storage_bytes(
+              static_cast<size_t>(sort_num_items),
+              first_indices.scalar_type(),
+              stream);
+    sort_temp_storage = at::empty(
+        {static_cast<int64_t>(temp_bytes)},
+        first_indices.options().dtype(at::kByte));
+    // original_positions is always 0..N-1 and read-only, shared across
+    // groups.
+    sort_original_positions =
+        at::arange(sort_num_items, first_indices.options().dtype(at::kLong));
+    // Single contiguous allocation for all groups' sort outputs.
+    // record_stream is called twice in total instead of 2 * group_size times.
+    all_sorted_indices =
+        at::empty({group_size * sort_num_items}, first_indices.options());
+    all_reverse_indices = at::empty(
+        {group_size * sort_num_items},
+        first_indices.options().dtype(at::kLong));
+    all_sorted_indices.record_stream(stream);
+    all_reverse_indices.record_stream(stream);
+  }
+#endif
 
   for (const auto i : c10::irange(group_size)) {
     const auto& indices = indices_group[i];
     index_contigs.push_back(indices.expect_contiguous());
     indices_ptrs[i] =
         reinterpret_cast<int64_t>(index_contigs[i]->const_data_ptr());
+    sorted_indices_ptrs[i] = 0;
+    reverse_indices_ptrs[i] = 0;
   }
+
+#ifdef USE_ROCM
+  if (use_sorted_indices) {
+    // Fill the sorted/reverse pointer tables via direct pointer arithmetic.
+    const int64_t idx_elem_bytes = first_indices.element_size();
+    auto* sorted_base = static_cast<char*>(all_sorted_indices.data_ptr());
+    auto* reverse_base = static_cast<char*>(all_reverse_indices.data_ptr());
+    for (int64_t i = 0; i < group_size; ++i) {
+      sorted_indices_ptrs[i] = reinterpret_cast<int64_t>(
+          sorted_base + i * sort_num_items * idx_elem_bytes);
+      reverse_indices_ptrs[i] = reinterpret_cast<int64_t>(
+          reverse_base + i * sort_num_items * sizeof(int64_t));
+    }
+
+    const auto stream = at::cuda::getCurrentCUDAStream();
+    if (use_segmented_sort) {
+      std::vector<Tensor> index_tensors;
+      index_tensors.reserve(group_size);
+      for (const auto& ic : index_contigs) {
+        index_tensors.push_back(*ic);
+      }
+      const auto all_keys_in = at::cat(index_tensors, 0);
+      const auto all_values_in = sort_original_positions.unsqueeze(0)
+                                     .expand({group_size, sort_num_items})
+                                     .contiguous()
+                                     .view({-1});
+      const auto segment_offsets =
+          at::arange(
+              group_size + 1, first_indices.options().dtype(at::kLong)) *
+          sort_num_items;
+      rocm::sort_indices_segmented_rocprim(
+          all_keys_in,
+          all_sorted_indices,
+          all_values_in,
+          all_reverse_indices,
+          segment_offsets,
+          static_cast<size_t>(sort_num_items),
+          group_size,
+          sort_temp_storage,
+          stream);
+    } else {
+      rocm::sort_indices_batch_rocprim(
+          indices_ptrs,
+          all_sorted_indices.data_ptr(),
+          all_reverse_indices.data_ptr<int64_t>(),
+          sort_original_positions.data_ptr<int64_t>(),
+          static_cast<size_t>(sort_num_items),
+          group_size,
+          sort_temp_storage,
+          first_indices.scalar_type(),
+          stream);
+    }
+  }
+#endif
 
   // Transfer grad output pointers to GPU
   args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true);
 
+#ifdef USE_ROCM
+  constexpr bool use_contiguous_warps = true;
+  constexpr bool use_cache = true;
+#else
+  constexpr bool use_contiguous_warps = false;
+  constexpr bool use_cache = false;
+#endif
+
   group_index_select_or_add_cuda(
       args_tensor.const_data_ptr<int64_t>(),
       args_tensor.const_data_ptr<int64_t>() + group_size,
       args_tensor.const_data_ptr<int64_t>() + 2 * group_size,
+      use_sorted_indices
+          ? args_tensor.data_ptr<int64_t>() + 3 * group_size
+          : nullptr,
+      use_sorted_indices
+          ? args_tensor.data_ptr<int64_t>() + 4 * group_size
+          : nullptr,
       warp_offsets_group,
       num_cols_group,
       fwd_input.scalar_type(),
@@ -631,7 +791,10 @@
       total_num_warps,
       group_size,
       /*use_index_select=*/false,
-      use_var_cols);
+      use_var_cols,
+      use_contiguous_warps,
+      use_cache,
+      use_packed_rows);
 
   return outputs;
 }
diff --git a/fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu b/fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu
new file mode 100644
index 0000000000..01a6030b15
--- /dev/null
+++ b/fbgemm_gpu/src/sparse_ops/utils/rocm/sparse_group_utils.cu
@@ -0,0 +1,142 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#ifdef USE_ROCM
+#include "fbgemm_gpu/utils/rocm/sparse_group_utils.h"
+
+namespace fbgemm_gpu::rocm {
+DLL_PUBLIC size_t get_sort_temp_storage_bytes(
+    const size_t num_items,
+    const c10::ScalarType scalar_type,
+    const at::cuda::CUDAStream& stream) {
+  size_t temp_storage_bytes = 0;
+  AT_DISPATCH_INDEX_TYPES(scalar_type, "get_sort_temp_storage_bytes", [&] {
+    AT_CUDA_CHECK(rocprim::radix_sort_pairs<sort_config>(
+        nullptr,
+        temp_storage_bytes,
+        static_cast<index_t*>(nullptr),
+        static_cast<index_t*>(nullptr),
+        static_cast<int64_t*>(nullptr),
+        static_cast<int64_t*>(nullptr),
+        num_items,
+        0,
+        sizeof(index_t) * 8,
+        stream,
+        false));
+  });
+  return temp_storage_bytes;
+}
+
+DLL_PUBLIC size_t get_segmented_sort_temp_storage_bytes(
+    const size_t num_items_per_segment,
+    const int64_t num_groups,
+    const c10::ScalarType scalar_type,
+    const at::cuda::CUDAStream& stream) {
+  size_t temp_storage_bytes = 0;
+  const size_t total_items =
+      num_items_per_segment * static_cast<size_t>(num_groups);
+  AT_DISPATCH_INDEX_TYPES(
+      scalar_type, "get_segmented_sort_temp_storage_bytes", [&] {
+        // segmented_radix_sort_pairs requires segmented_radix_sort_config,
+        // not radix_sort_config — use the default config (radix sort, no
+        // merge fallback).
+        AT_CUDA_CHECK(rocprim::segmented_radix_sort_pairs(
+            nullptr,
+            temp_storage_bytes,
+            static_cast<index_t*>(nullptr),
+            static_cast<index_t*>(nullptr),
+            static_cast<int64_t*>(nullptr),
+            static_cast<int64_t*>(nullptr),
+            total_items,
+            static_cast<unsigned int>(num_groups),
+            static_cast<const int64_t*>(nullptr),
+            static_cast<const int64_t*>(nullptr),
+            0,
+            sizeof(index_t) * 8,
+            stream,
+            false));
+      });
+  return temp_storage_bytes;
+}
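+// Annotation (not part of the patch): the begin/end offset iterators the
+// segmented sort below consumes can alias a single [K + 1] array, so no
+// second offsets array is ever materialized. With invented values:
+//   segment_offsets = [0, 4, 8, 12]        // K = 3 segments of N = 4
+//   begin_offsets   = segment_offsets      // begin[i] = i * N
+//   end_offsets     = segment_offsets + 1  // end[i]   = (i + 1) * N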
+
+DLL_PUBLIC void sort_indices_segmented_rocprim(
+    const at::Tensor& all_keys_in,
+    at::Tensor& all_keys_out,
+    const at::Tensor& all_values_in,
+    at::Tensor& all_values_out,
+    const at::Tensor& segment_offsets,
+    const size_t num_items_per_segment,
+    const int64_t num_groups,
+    at::Tensor& temp_storage,
+    const at::cuda::CUDAStream& stream) {
+  if (num_items_per_segment == 0 || num_groups == 0) {
+    return;
+  }
+
+  size_t temp_storage_bytes = static_cast<size_t>(temp_storage.numel());
+  const size_t total_items =
+      num_items_per_segment * static_cast<size_t>(num_groups);
+  // segment_offsets is [0, N, 2N, ..., K*N]: begin[i] = ptr[i],
+  // end[i] = ptr[i + 1]
+  const auto* begin_offsets = segment_offsets.const_data_ptr<int64_t>();
+  const auto* end_offsets = begin_offsets + 1;
+
+  AT_DISPATCH_INDEX_TYPES(
+      all_keys_in.scalar_type(), "sort_indices_segmented_rocprim", [&] {
+        // segmented_radix_sort_pairs requires segmented_radix_sort_config —
+        // radix_sort_config is not accepted here, so the default config is
+        // used. Only call this path when num_items_per_segment >=
+        // k_sort_merge_threshold so there is no regression vs the per-group
+        // merge sort path.
+        AT_CUDA_CHECK(rocprim::segmented_radix_sort_pairs(
+            temp_storage.data_ptr(),
+            temp_storage_bytes,
+            all_keys_in.const_data_ptr<index_t>(),
+            all_keys_out.data_ptr<index_t>(),
+            all_values_in.const_data_ptr<int64_t>(),
+            all_values_out.data_ptr<int64_t>(),
+            total_items,
+            static_cast<unsigned int>(num_groups),
+            begin_offsets,
+            end_offsets,
+            0,
+            sizeof(index_t) * 8,
+            stream,
+            false));
+      });
+}
+
+DLL_PUBLIC void sort_indices_batch_rocprim(
+    const int64_t* keys_in_ptrs,
+    void* keys_out_base,
+    int64_t* values_out_base,
+    const int64_t* values_in,
+    const size_t num_items,
+    const int64_t num_groups,
+    at::Tensor& temp_storage,
+    const c10::ScalarType scalar_type,
+    const at::cuda::CUDAStream& stream) {
+  if (num_items == 0 || num_groups == 0) {
+    return;
+  }
+  size_t temp_storage_bytes = static_cast<size_t>(temp_storage.numel());
+  void* temp_ptr = temp_storage.data_ptr();
+  AT_DISPATCH_INDEX_TYPES(scalar_type, "sort_indices_batch_rocprim", [&] {
+    auto* keys_out = static_cast<index_t*>(keys_out_base);
+    for (int64_t i = 0; i < num_groups; ++i) {
+      AT_CUDA_CHECK(rocprim::radix_sort_pairs<sort_config>(
+          temp_ptr,
+          temp_storage_bytes,
+          reinterpret_cast<const index_t*>(keys_in_ptrs[i]),
+          keys_out + i * num_items,
+          values_in,
+          values_out_base + i * num_items,
+          num_items,
+          0,
+          sizeof(index_t) * 8,
+          stream,
+          false));
+    }
+  });
+}
+} // namespace fbgemm_gpu::rocm
+
+#endif