Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fbgemm_gpu/FbgemmGpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_CPU)
src/quantize_ops/quantize_msfp.cu
src/quantize_ops/quantize_padded_fp8_rowwise.cu
src/quantize_ops/quantize_mx.cu
src/sparse_ops/utils/rocm/sparse_group_utils.cu
src/sparse_ops/sparse_async_batched_cumsum.cu
src/sparse_ops/sparse_block_bucketize_features.cu
src/sparse_ops/sparse_bucketize_features.cu
Expand Down
7 changes: 6 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,8 @@ void group_index_select_or_add_cuda(
const int64_t* input_ptrs,
const int64_t* output_ptrs,
const int64_t* indices_ptrs,
const int64_t* sorted_indices_ptrs,
const int64_t* reverse_indices_ptrs,
const int64_t* warp_offsets_group,
const int32_t* num_cols_group,
const c10::ScalarType& input_scalar_type,
Expand All @@ -1087,7 +1089,10 @@ void group_index_select_or_add_cuda(
const int64_t total_num_warps,
const int group_size,
const bool use_index_select,
const bool use_var_cols);
const bool use_var_cols,
const bool use_contiguous_warps,
const bool use_cache,
const bool use_packed_rows);

// Returns the number of columns each warp handles in the group index-select
// kernels (exact value is defined in the .cu implementation). The result is
// only useful when consumed, so discarding it is flagged as a bug.
[[nodiscard]] int get_group_index_select_cols_per_warp();

Expand Down
63 changes: 62 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,26 @@
#include <cstdint>
#include <limits>

#include <ATen/Dispatch.h>
#include <ATen/ATen.h>

#include <hip/hip_runtime.h>
#include <rocprim/device/device_radix_sort.hpp>
#include <rocprim/device/device_segmented_radix_sort.hpp>

#include "fbgemm_gpu/utils/cuda_prelude.cuh"
#include "fbgemm_gpu/utils/function_types.h"

namespace fbgemm_gpu::rocm {
// Selected empirically: rocprim uses merge sort when num_items < this threshold,
// which is faster for small inputs. Must match across sizing and sort calls.
// `inline constexpr` so the header-defined constant has exactly one definition
// across all translation units (C++17 inline variable) instead of one internal
// copy per TU.
inline constexpr unsigned int k_sort_merge_threshold = 400'000;
// Radix-sort tuning config: the first three parameters stay at rocprim
// defaults; the fourth carries the merge-sort threshold above.
using sort_config = rocprim::radix_sort_config<
    rocprim::default_config,
    rocprim::default_config,
    rocprim::default_config,
    k_sort_merge_threshold>;

namespace {

template <typename scalar_t, int kLogicalWarpSize = kWarpSize>
Expand Down Expand Up @@ -67,6 +84,50 @@ __device__ __forceinline__ void warp_upper_bound(
*found = result;
*cached_boundary = cached_result;
}

} // namespace

// Returns the temporary-storage size (in bytes) required by a single-segment
// radix sort of num_items elements of scalar_type on the given stream. The
// caller must allocate a scratch buffer of at least this size before the sort.
// [[nodiscard]]: ignoring the result leaves the scratch buffer unsized, which
// is always a bug.
[[nodiscard]] size_t get_sort_temp_storage_bytes(
    const size_t num_items,
    const c10::ScalarType scalar_type,
    const at::cuda::CUDAStream& stream);
// Returns the temporary-storage size (in bytes) required by a segmented radix
// sort of num_groups segments, each holding num_items_per_segment elements of
// scalar_type (consumed by sort_indices_segmented_rocprim below).
// [[nodiscard]]: discarding the result leaves the scratch buffer unsized.
[[nodiscard]] size_t get_segmented_sort_temp_storage_bytes(
    const size_t num_items_per_segment,
    const int64_t num_groups,
    const c10::ScalarType scalar_type,
    const at::cuda::CUDAStream& stream);
// Sort all groups' indices with one rocprim::segmented_radix_sort_pairs call,
// eliminating all per-group CPU launch overhead.
//
// Inputs must be contiguous across groups (N = num_items_per_segment,
// K = num_groups):
//   all_keys_in    : [K * N] — packed input indices
//   all_values_in  : [K * N] — tiled 0..N-1 per segment
//   segment_offsets: [K + 1] device tensor — [0, N, 2N, ..., K*N]
//   all_keys_out / all_values_out: pre-allocated output buffers (same shape)
//   temp_storage   : pre-allocated via get_segmented_sort_temp_storage_bytes()
//   stream         : HIP/CUDA stream the sort is enqueued on
void sort_indices_segmented_rocprim(
const at::Tensor& all_keys_in,
at::Tensor& all_keys_out,
const at::Tensor& all_values_in,
at::Tensor& all_values_out,
const at::Tensor& segment_offsets,
const size_t num_items_per_segment,
const int64_t num_groups,
at::Tensor& temp_storage,
const at::cuda::CUDAStream& stream);
// Sort all groups in a batch with one AT_DISPATCH and one stream lookup.
// Uses radix_sort_pairs<sort_config> per group, preserving the merge sort
// fallback for small segment sizes (num_items < k_sort_merge_threshold).
//
// NOTE(review): parameter semantics below are inferred from names only —
// confirm against the .cu implementation:
//   keys_in_ptrs    : presumably a device array of per-group input-key pointers
//   keys_out_base   : base of the packed per-group key output (typed via
//                     scalar_type, hence void*)
//   values_out_base : base of the packed per-group value output
//   values_in       : shared input values, presumably 0..num_items-1
//   temp_storage    : pre-allocated scratch; see get_sort_temp_storage_bytes()
void sort_indices_batch_rocprim(
const int64_t* keys_in_ptrs,
void* keys_out_base,
int64_t* values_out_base,
const int64_t* values_in,
const size_t num_items,
const int64_t num_groups,
at::Tensor& temp_storage,
const c10::ScalarType scalar_type,
const at::cuda::CUDAStream& stream);
} // namespace fbgemm_gpu::rocm
Loading