From 5e5098b77549f4768fa4ad31291306dd0502e189 Mon Sep 17 00:00:00 2001 From: gongchensu Date: Fri, 20 Mar 2026 16:52:17 +0800 Subject: [PATCH 1/3] feat(hygon-add): add Hygon backend support for `Add` - Add `WITH_HYGON` build support and a Hygon `Add` backend that reuses the shared CUDA implementation. - Detect DTK `nvcc` from the Hygon toolkit layout and auto-detect the GPU arch from `rocminfo`. - Treat Hygon as a CUDA-like backend in shared data type, cast, and kernel helper headers. - Skip the Hygon `gemm` example for now and ignore `build-*` temporary directories. - Verified with `pip install -e .[dev]` and `pytest tests/test_add.py`. --- .gitignore | 2 ++ CMakeLists.txt | 74 +++++++++++++++++++++++++++++++++++++++-- README.md | 6 ++++ examples/CMakeLists.txt | 4 +++ examples/runtime_api.h | 9 +++++ src/CMakeLists.txt | 30 ++++++++++++++++- src/hygon/add/kernel.h | 43 ++++++++++++++++++++++++ src/hygon/device_.h | 68 +++++++++++++++++++++++++++++++++++++ 8 files changed, 232 insertions(+), 4 deletions(-) create mode 100644 src/hygon/add/kernel.h create mode 100644 src/hygon/device_.h diff --git a/.gitignore b/.gitignore index 2effaff..4a5d4b7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ # Generated files build/ +build-*/ +cmake-build-*/ generated/ # Prerequisites diff --git a/CMakeLists.txt b/CMakeLists.txt index b9e2deb..9741887 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,7 @@ set(PYBIND11_ENABLE_EXTRAS ON) option(WITH_CPU "Enable CPU backend" OFF) option(WITH_NVIDIA "Enable CUDA backend" OFF) option(WITH_ILUVATAR "Enable Iluvatar GPU backend" OFF) +option(WITH_HYGON "Enable Hygon GPU backend" OFF) option(WITH_METAX "Enable MetaX backend" OFF) option(WITH_CAMBRICON "Enable Cambricon backend" OFF) option(WITH_MOORE "Enable Moore backend" OFF) @@ -18,6 +19,8 @@ option(WITH_MOORE "Enable Moore backend" OFF) option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF) option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF) +set(_DEFAULT_HYGON_DTK_ROOT "/opt/dtk") + if(AUTO_DETECT_DEVICES) message(STATUS "Auto-detecting available devices...") @@ -37,6 +40,13 @@ if(AUTO_DETECT_DEVICES) message(STATUS "Auto-detected Iluvatar environment.") endif() + if(DEFINED ENV{DTK_ROOT} OR + EXISTS "${_DEFAULT_HYGON_DTK_ROOT}/cuda/bin/nvcc" OR + EXISTS "${_DEFAULT_HYGON_DTK_ROOT}/cuda/cuda/bin/nvcc") + set(WITH_HYGON ON) + message(STATUS "Auto-detected Hygon environment.") + endif() + if(DEFINED ENV{MACA_PATH}) set(WITH_METAX ON) message(STATUS "Auto-detected MetaX environment from MACA_PATH") @@ -77,14 +87,14 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) # Only one CUDA-like GPU backend can be enabled at a time. set(_gpu_backend_count 0) -foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_METAX WITH_MOORE) +foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_HYGON WITH_METAX WITH_MOORE) if(${_gpu_backend}) math(EXPR _gpu_backend_count "${_gpu_backend_count} + 1") endif() endforeach() if(_gpu_backend_count GREATER 1) - message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, and `WITH_MOORE` are mutually exclusive. Build one GPU backend at a time.") + message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_HYGON`, `WITH_METAX`, and `WITH_MOORE` are mutually exclusive. Build one GPU backend at a time.") endif() if(WITH_NVIDIA) @@ -111,6 +121,64 @@ if(WITH_ILUVATAR) find_package(CUDAToolkit REQUIRED) endif() +if(WITH_HYGON) + add_compile_definitions(WITH_HYGON=1) + set(DTK_ROOT $ENV{DTK_ROOT}) + if(NOT DTK_ROOT) + set(DTK_ROOT "${_DEFAULT_HYGON_DTK_ROOT}") + endif() + if(NOT EXISTS "${DTK_ROOT}") + message(FATAL_ERROR "`WITH_HYGON` is `ON` but `DTK_ROOT` (`${DTK_ROOT}`) does not exist.") + endif() + + set(_HYGON_ARCH_DEFAULT "gfx906") + if(DEFINED ENV{HYGON_ARCH} AND NOT "$ENV{HYGON_ARCH}" STREQUAL "") + set(_HYGON_ARCH_DEFAULT "$ENV{HYGON_ARCH}") + else() + find_program(HYGON_ROCMINFO_EXECUTABLE NAMES rocminfo HINTS "${DTK_ROOT}/bin") + if(HYGON_ROCMINFO_EXECUTABLE) + execute_process( + COMMAND ${HYGON_ROCMINFO_EXECUTABLE} + OUTPUT_VARIABLE _HYGON_ROCMINFO_OUTPUT + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + string(REGEX MATCH "gfx[0-9]+" _HYGON_ARCH_AUTO "${_HYGON_ROCMINFO_OUTPUT}") + if(_HYGON_ARCH_AUTO) + set(_HYGON_ARCH_DEFAULT "${_HYGON_ARCH_AUTO}") + endif() + endif() + endif() + + set(HYGON_ARCH "${_HYGON_ARCH_DEFAULT}" CACHE STRING "Hygon GPU architecture") + set(HYGON_CUDA_ROOT "${DTK_ROOT}/cuda") + if(EXISTS "${DTK_ROOT}/cuda/cuda/bin/nvcc") + set(HYGON_CUDA_ROOT "${DTK_ROOT}/cuda/cuda") + endif() + + if(NOT EXISTS "${HYGON_CUDA_ROOT}/bin/nvcc") + message(FATAL_ERROR "`WITH_HYGON` is `ON` but `${HYGON_CUDA_ROOT}/bin/nvcc` was not found. Checked `${DTK_ROOT}/cuda/bin/nvcc` and `${DTK_ROOT}/cuda/cuda/bin/nvcc`.") + endif() + + set(CMAKE_CUDA_COMPILER "${HYGON_CUDA_ROOT}/bin/nvcc" CACHE FILEPATH "Hygon CUDA compiler (DTK nvcc)") + set(CUDAToolkit_ROOT "${HYGON_CUDA_ROOT}" CACHE PATH "Hygon CUDA toolkit root") + set(CMAKE_CUDA_ARCHITECTURES OFF CACHE STRING "Disable default CUDA arch flags for Hygon" FORCE) + set(CMAKE_CUDA_FLAGS "-std=c++17 -fPIC -arch=${HYGON_ARCH} -Wno-return-type -Wno-error=unused-private-field" CACHE STRING "Hygon CUDA flags") + set(CMAKE_CUDA_SEPARABLE_COMPILATION OFF CACHE BOOL "Disable RDC for Hygon") + + # DTK's nvcc wrapper may invoke `nvcc` by name during compiler checks. + set(ENV{PATH} "${HYGON_CUDA_ROOT}/bin:$ENV{PATH}") + + include_directories("${DTK_ROOT}/include") + include_directories("${HYGON_CUDA_ROOT}/include") + link_directories("${DTK_ROOT}/lib") + link_directories("${HYGON_CUDA_ROOT}/lib64") + + message(STATUS "Hygon: CUDA compiler ${CMAKE_CUDA_COMPILER}, arch ${HYGON_ARCH}, DTK root ${DTK_ROOT}") + enable_language(CUDA) + find_package(CUDAToolkit REQUIRED) +endif() + if(WITH_METAX) add_compile_definitions(WITH_METAX=1) @@ -179,7 +247,7 @@ if(WITH_CAMBRICON) endif() # If all other platforms are not enabled, CPU is enabled by default. -if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON) +if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_HYGON AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON) add_compile_definitions(WITH_CPU=1) endif() diff --git a/README.md b/README.md index 875a936..c61bc02 100644 --- a/README.md +++ b/README.md @@ -38,11 +38,17 @@ For the ``: |----------------------------------------|------------------------------------|:-: | `-DWITH_CPU=[ON\|OFF]` | Compile the CPU implementation | n | `-DWITH_NVIDIA=[ON\|OFF]` | Compile the NVIDIA implementation | n +| `-DWITH_ILUVATAR=[ON\|OFF]` | Compile the Iluvatar implementation| n +| `-DWITH_HYGON=[ON\|OFF]` | Compile the Hygon implementation | n | `-DWITH_METAX=[ON\|OFF]` | Compile the MetaX implementation | n | `-DGENERATE_PYTHON_BINDINGS=[ON\|OFF]` | Generate Python bindings | n *Note: If no accelerator options are provided, `WITH_CPU` is enabled by default.* +For Hygon builds, set `DTK_ROOT` to the DTK installation root if it is not +installed at `/opt/dtk`. You can override the default DCU arch with +`-DHYGON_ARCH=` when configuring CMake. + ## 🚀 Running Examples After a successful build, the executables are located in the `build/examples` directory. diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 68ebc1b..c4039c4 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -2,6 +2,10 @@ file(GLOB_RECURSE EXAMPLE_SOURCES CONFIGURE_DEPENDS "*.cc") # Iterate through each file and create an executable. foreach(source_file ${EXAMPLE_SOURCES}) + if(WITH_HYGON AND source_file MATCHES "/gemm\\.cc$") + continue() + endif() + get_filename_component(example_name ${source_file} NAME_WE) add_executable(${example_name} ${source_file}) diff --git a/examples/runtime_api.h b/examples/runtime_api.h index c5b7597..bc7398f 100644 --- a/examples/runtime_api.h +++ b/examples/runtime_api.h @@ -19,6 +19,15 @@ #define DEVICE_MEMCPY_HOST_TO_DEVICE cudaMemcpyHostToDevice #define DEVICE_MEMCPY_DEVICE_TO_HOST cudaMemcpyDeviceToHost #define DEFAULT_DEVICE_TYPE Device::Type::kIluvatar +#elif WITH_HYGON +#include +#define DEVICE_MALLOC cudaMalloc +#define DEVICE_FREE cudaFree +#define DEVICE_MEMCPY cudaMemcpy +#define DEVICE_MEMSET cudaMemset +#define DEVICE_MEMCPY_HOST_TO_DEVICE cudaMemcpyHostToDevice +#define DEVICE_MEMCPY_DEVICE_TO_HOST cudaMemcpyDeviceToHost +#define DEFAULT_DEVICE_TYPE Device::Type::kHygon #elif WITH_METAX #include #define DEVICE_MALLOC mcMalloc diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 585e3ab..130a7fc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -77,6 +77,34 @@ if(WITH_ILUVATAR) list(APPEND DEVICE_LIST "iluvatar") endif() +if(WITH_HYGON) + set(HYGON_PATTERNS + "cuda/*.cc" + "cuda/*.cpp" + "cuda/*.cu" + "hygon/*.cc" + "hygon/*.cpp" + "hygon/*.cu" + ) + + file(GLOB_RECURSE HYGON_SOURCES CONFIGURE_DEPENDS ${HYGON_PATTERNS}) + + enable_language(CUDA) + + target_compile_definitions(infiniops PUBLIC WITH_HYGON=1) + target_sources(infiniops PRIVATE ${HYGON_SOURCES}) + + find_package(CUDAToolkit REQUIRED) + target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas) + + set_target_properties(infiniops PROPERTIES + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + ) + + list(APPEND DEVICE_LIST "hygon") +endif() + if(WITH_METAX) set(METAX_PATTERNS "cuda/*.cc" @@ -191,7 +219,7 @@ if(GENERATE_PYTHON_BINDINGS) set(PYBIND11_SOURCES "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc") # TODO: There might be a better solution. - if(WITH_NVIDIA OR WITH_ILUVATAR) + if(WITH_NVIDIA OR WITH_ILUVATAR OR WITH_HYGON) set_source_files_properties(${PYBIND11_SOURCES} PROPERTIES LANGUAGE CUDA) endif() diff --git a/src/hygon/add/kernel.h b/src/hygon/add/kernel.h new file mode 100644 index 0000000..808c68c --- /dev/null +++ b/src/hygon/add/kernel.h @@ -0,0 +1,43 @@ +#ifndef INFINI_OPS_HYGON_ADD_KERNEL_H_ +#define INFINI_OPS_HYGON_ADD_KERNEL_H_ + +#include + +#include "cuda/add/kernel.h" +#include "hygon/device_.h" + +namespace infini::ops { + +namespace add { + +struct HygonBackend { + using stream_t = cudaStream_t; + + static constexpr Device::Type kDeviceType = Device::Type::kHygon; + + static constexpr auto malloc = [](auto&&... args) { + return cudaMalloc(std::forward(args)...); + }; + + static constexpr auto memcpy = cudaMemcpy; + + static constexpr auto free = cudaFree; + + static constexpr auto memcpyH2D = cudaMemcpyHostToDevice; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } +}; + +} // namespace add + +template <> +class Operator : public CudaAdd { + public: + using CudaAdd::CudaAdd; +}; + +} // namespace infini::ops + +#endif diff --git a/src/hygon/device_.h b/src/hygon/device_.h new file mode 100644 index 0000000..839cc36 --- /dev/null +++ b/src/hygon/device_.h @@ -0,0 +1,68 @@ +#ifndef INFINI_OPS_HYGON_DEVICE__H_ +#define INFINI_OPS_HYGON_DEVICE__H_ + +#include +#include + +// clang-format off +#include +#include +#include +// clang-format on + +#include "cuda/caster_.h" +#include "data_type.h" +#include "device.h" + +namespace infini::ops { + +using cuda_bfloat16 = nv_bfloat16; + +using cuda_bfloat162 = nv_bfloat162; + +template <> +struct TypeMap { + using type = half; +}; + +template <> +struct TypeMap { + using type = __nv_bfloat16; +}; + +// Caches `cudaDeviceProp` per device, initialized once at first access. +class DevicePropertyCache { + public: + static const cudaDeviceProp& GetCurrentDeviceProps() { + int device_id = 0; + cudaGetDevice(&device_id); + return GetDeviceProps(device_id); + } + + static const cudaDeviceProp& GetDeviceProps(int device_id) { + static std::vector cache = []() { + int count = 0; + cudaGetDeviceCount(&count); + if (count == 0) return std::vector{}; + std::vector props(count); + for (int i = 0; i < count; ++i) { + cudaGetDeviceProperties(&props[i], i); + } + return props; + }(); + + assert(device_id >= 0 && device_id < static_cast(cache.size())); + return cache[device_id]; + } +}; + +inline int QueryMaxThreadsPerBlock() { + return DevicePropertyCache::GetCurrentDeviceProps().maxThreadsPerBlock; +} + +template <> +struct Caster : CudaCasterImpl {}; + +} // namespace infini::ops + +#endif From f57a5c692cd59d8290b9c9a2ce1006990d171e1b Mon Sep 17 00:00:00 2001 From: gongchensu Date: Mon, 23 Mar 2026 18:13:23 +0800 Subject: [PATCH 2/3] feat(hygon-gemm): add Hygon backend support for `Gemm` - add a Hygon `Gemm` backend on top of the shared CUDA BLAS path - use DTK-friendly compute and algo settings for fp32/fp16 gemm - fall back to `cublasGemmEx` for single-batch Hygon gemm to avoid DTK crashes - release Hygon cublas handles after each call and re-enable the `gemm` example - verified with `pip install -e .[dev]`, `pytest tests/test_gemm.py -k cuda`, and `pytest tests/test_gemm.py` --- examples/CMakeLists.txt | 4 -- examples/gemm/gemm.cc | 3 + src/hygon/gemm/cublas.h | 154 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 157 insertions(+), 4 deletions(-) create mode 100644 src/hygon/gemm/cublas.h diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index c4039c4..68ebc1b 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -2,10 +2,6 @@ file(GLOB_RECURSE EXAMPLE_SOURCES CONFIGURE_DEPENDS "*.cc") # Iterate through each file and create an executable. foreach(source_file ${EXAMPLE_SOURCES}) - if(WITH_HYGON AND source_file MATCHES "/gemm\\.cc$") - continue() - endif() - get_filename_component(example_name ${source_file} NAME_WE) add_executable(${example_name} ${source_file}) diff --git a/examples/gemm/gemm.cc b/examples/gemm/gemm.cc index 4664740..9e1593c 100644 --- a/examples/gemm/gemm.cc +++ b/examples/gemm/gemm.cc @@ -11,6 +11,9 @@ #if WITH_ILUVATAR #include "iluvatar/gemm/cublas.h" #endif +#if WITH_HYGON +#include "hygon/gemm/cublas.h" +#endif #if WITH_METAX #include "metax/gemm/mcblas.h" #endif diff --git a/src/hygon/gemm/cublas.h b/src/hygon/gemm/cublas.h new file mode 100644 index 0000000..60eb2c5 --- /dev/null +++ b/src/hygon/gemm/cublas.h @@ -0,0 +1,154 @@ +#ifndef INFINI_OPS_HYGON_GEMM_CUBLAS_H_ +#define INFINI_OPS_HYGON_GEMM_CUBLAS_H_ + +#include + +// clang-format off +#include "cublas_v2.h" +// clang-format on + +#include "cuda/gemm/blas.h" + +namespace infini::ops { + +namespace gemm { + +struct HygonBackend { + using blasHandle_t = cublasHandle_t; + + using stream_t = cudaStream_t; + + static constexpr auto BLAS_OP_N = CUBLAS_OP_N; + + static constexpr auto BLAS_OP_T = CUBLAS_OP_T; + + static constexpr auto R_16F = CUDA_R_16F; + + static constexpr auto R_16BF = CUDA_R_16BF; + + static constexpr auto R_32F = CUDA_R_32F; + + static constexpr auto BLAS_COMPUTE_32F = CUBLAS_COMPUTE_32F; + + // DTK exposes the TF32 enum for compatibility, but BW/GFX9-class Hygon + // devices do not provide a working TF32 GEMM fast path. + static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = CUBLAS_COMPUTE_32F; + + static constexpr auto BLAS_GEMM_DEFAULT = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + + static constexpr auto blasCreate = cublasCreate; + + static constexpr auto blasSetStream = cublasSetStream; + + static constexpr auto blasDestroy = cublasDestroy; + + static constexpr auto blasGemmEx = [](auto&&... args) { + return cublasGemmEx(std::forward(args)...); + }; + + static constexpr auto blasGemmStridedBatchedEx = [](auto&&... args) { + return cublasGemmStridedBatchedEx(std::forward(args)...); + }; + + static auto GetDataType(DataType dtype) { + if (dtype == DataType::kFloat16) return R_16F; + if (dtype == DataType::kBFloat16) return R_16BF; + return R_32F; + } + + static auto GetComputeType(DataType dtype) { + if (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16) + return BLAS_COMPUTE_32F; + return BLAS_COMPUTE_32F_FAST_TF32; + } +}; + +} // namespace gemm + +template <> +class Operator : public Blas { + public: + using Blas::Blas; + + void operator()(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) const override { + const bool a_is_col_major = a.stride(-1) == 1; + const bool b_is_col_major = b.stride(-1) == 1; + const bool swap_a_and_b = c.stride(-1) == 1; + + auto get_op_a = [&](int trans_a_value, int trans_b_value) { + if (swap_a_and_b) { + return (b_is_col_major == trans_b_value) ? gemm::HygonBackend::BLAS_OP_T + : gemm::HygonBackend::BLAS_OP_N; + } + return (a_is_col_major != trans_a_value) ? gemm::HygonBackend::BLAS_OP_T + : gemm::HygonBackend::BLAS_OP_N; + }; + + auto get_op_b = [&](int trans_a_value, int trans_b_value) { + if (swap_a_and_b) { + return (a_is_col_major == trans_a_value) ? gemm::HygonBackend::BLAS_OP_T + : gemm::HygonBackend::BLAS_OP_N; + } + return (b_is_col_major != trans_b_value) ? gemm::HygonBackend::BLAS_OP_T + : gemm::HygonBackend::BLAS_OP_N; + }; + + gemm::HygonBackend::blasHandle_t handle{}; + gemm::HygonBackend::blasCreate(&handle); + gemm::HygonBackend::blasSetStream( + handle, static_cast(this->stream_)); + + const auto& alpha_value{alpha.value_or(this->alpha_)}; + const auto& beta_value{beta.value_or(this->beta_)}; + + const auto& trans_a_value{trans_a.value_or(this->trans_a_)}; + const auto& trans_b_value{trans_b.value_or(this->trans_b_)}; + auto op_a{get_op_a(trans_a_value, trans_b_value)}; + auto op_b{get_op_b(trans_a_value, trans_b_value)}; + const void* alpha_ptr{this->GetAlphaPtr(alpha_value, c.dtype())}; + const void* beta_ptr{this->GetBetaPtr(beta_value, c.dtype())}; + + if (this->batch_count_ == 1) { + gemm::HygonBackend::blasGemmEx( + handle, op_a, op_b, swap_a_and_b ? this->n_ : this->m_, + swap_a_and_b ? this->m_ : this->n_, this->k_, alpha_ptr, + swap_a_and_b ? b.data() : a.data(), + gemm::HygonBackend::GetDataType(swap_a_and_b ? b.dtype() + : a.dtype()), + swap_a_and_b ? this->ldb_ : this->lda_, + swap_a_and_b ? a.data() : b.data(), + gemm::HygonBackend::GetDataType(swap_a_and_b ? a.dtype() + : b.dtype()), + swap_a_and_b ? this->lda_ : this->ldb_, beta_ptr, c.data(), + gemm::HygonBackend::GetDataType(c.dtype()), this->ldc_, + gemm::HygonBackend::GetComputeType(c.dtype()), + gemm::HygonBackend::BLAS_GEMM_DEFAULT); + } else { + gemm::HygonBackend::blasGemmStridedBatchedEx( + handle, op_a, op_b, swap_a_and_b ? this->n_ : this->m_, + swap_a_and_b ? this->m_ : this->n_, this->k_, alpha_ptr, + swap_a_and_b ? b.data() : a.data(), + gemm::HygonBackend::GetDataType(swap_a_and_b ? b.dtype() + : a.dtype()), + swap_a_and_b ? this->ldb_ : this->lda_, + swap_a_and_b ? this->batch_stride_b_ : this->batch_stride_a_, + swap_a_and_b ? a.data() : b.data(), + gemm::HygonBackend::GetDataType(swap_a_and_b ? a.dtype() + : b.dtype()), + swap_a_and_b ? this->lda_ : this->ldb_, + swap_a_and_b ? this->batch_stride_a_ : this->batch_stride_b_, + beta_ptr, c.data(), gemm::HygonBackend::GetDataType(c.dtype()), + this->ldc_, this->batch_stride_c_, this->batch_count_, + gemm::HygonBackend::GetComputeType(c.dtype()), + gemm::HygonBackend::BLAS_GEMM_DEFAULT); + } + + gemm::HygonBackend::blasDestroy(handle); + } +}; + +} // namespace infini::ops + +#endif From 43809e7014b3cef70b6d681f6e9d7f3569cf2c92 Mon Sep 17 00:00:00 2001 From: gongchensu Date: Wed, 25 Mar 2026 16:37:27 +0800 Subject: [PATCH 3/3] feat(ops): implement CausalSoftmax operator with Hygon backend. --- src/cuda/causal_softmax/kernel.h | 6 +++-- src/cuda/kernel_commons.h | 42 ++++++++++++++++++++++++++++++ src/hygon/causal_softmax/kernel.h | 43 +++++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 2 deletions(-) create mode 100644 src/hygon/causal_softmax/kernel.h diff --git a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h index 7ca0135..f59914a 100644 --- a/src/cuda/causal_softmax/kernel.h +++ b/src/cuda/causal_softmax/kernel.h @@ -1,6 +1,7 @@ #ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_ #define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_ +#include #include #include @@ -31,10 +32,11 @@ class CudaCausalSoftmax : public CausalSoftmax { assert(out.dtype() == input.dtype()); - int block_size = Backend::GetOptimalBlockSize(); + constexpr int kMaxBlockSize = BackendMaxBlockSize::value; + int block_size = std::min(Backend::GetOptimalBlockSize(), kMaxBlockSize); DispatchFunc, ReducedFloatTypes>, - AllCudaBlockSizes>( + SupportedCudaBlockSizesType::value>>( // TODO: Output dtype should use the one passed in during construction. {static_cast(out.dtype()), block_size}, [&](auto list_tag) { diff --git a/src/cuda/kernel_commons.h b/src/cuda/kernel_commons.h index bb25fad..8f7c1a7 100644 --- a/src/cuda/kernel_commons.h +++ b/src/cuda/kernel_commons.h @@ -1,12 +1,54 @@ #ifndef INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_ #define INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_ +#include + #include "caster.h" namespace infini::ops { using AllCudaBlockSizes = List<128, 256, 512, 1024, 2048>; +template +struct BackendMaxBlockSize : std::integral_constant {}; + +template +struct BackendMaxBlockSize> + : std::integral_constant {}; + +template +struct SupportedCudaBlockSizes; + +template <> +struct SupportedCudaBlockSizes<2048> { + using type = AllCudaBlockSizes; +}; + +template <> +struct SupportedCudaBlockSizes<1024> { + using type = List<128, 256, 512, 1024>; +}; + +template <> +struct SupportedCudaBlockSizes<512> { + using type = List<128, 256, 512>; +}; + +template <> +struct SupportedCudaBlockSizes<256> { + using type = List<128, 256>; +}; + +template <> +struct SupportedCudaBlockSizes<128> { + using type = List<128>; +}; + +template +using SupportedCudaBlockSizesType = + typename SupportedCudaBlockSizes::type; + __forceinline__ __device__ __host__ size_t IndexToOffset(size_t flat_index, size_t ndim, const size_t* shape, const ptrdiff_t* strides) { diff --git a/src/hygon/causal_softmax/kernel.h b/src/hygon/causal_softmax/kernel.h new file mode 100644 index 0000000..c9e054a --- /dev/null +++ b/src/hygon/causal_softmax/kernel.h @@ -0,0 +1,43 @@ +#ifndef INFINI_OPS_HYGON_CAUSAL_SOFTMAX_KERNEL_H_ +#define INFINI_OPS_HYGON_CAUSAL_SOFTMAX_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +// clang-format off +#include "hygon/device_.h" +// clang-format on + +#include "cuda/causal_softmax/kernel.h" + +namespace infini::ops { + +namespace causal_softmax { + +struct HygonBackend { + using stream_t = cudaStream_t; + + static constexpr Device::Type kDeviceType = Device::Type::kHygon; + + static constexpr int max_block_size = 256; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } +}; + +} // namespace causal_softmax + +template <> +class Operator + : public CudaCausalSoftmax { + public: + using CudaCausalSoftmax::CudaCausalSoftmax; +}; + +} // namespace infini::ops + +#endif