Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Generated files
build/
build-*/
cmake-build-*/
generated/

# Prerequisites
Expand Down
74 changes: 71 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@ set(PYBIND11_ENABLE_EXTRAS ON)
# Backend selection switches. All default to OFF; the CPU compile definition is
# added later in this file when no accelerator backend is enabled.
option(WITH_CPU "Enable CPU backend" OFF)
option(WITH_NVIDIA "Enable CUDA backend" OFF)
option(WITH_ILUVATAR "Enable Iluvatar GPU backend" OFF)
option(WITH_HYGON "Enable Hygon GPU backend" OFF)
option(WITH_METAX "Enable MetaX backend" OFF)
option(WITH_CAMBRICON "Enable Cambricon backend" OFF)
option(WITH_MOORE "Enable Moore backend" OFF)

# AUTO_DETECT_DEVICES lets the configure step switch backends ON based on what
# it finds in the environment instead of requiring explicit -DWITH_* flags.
option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF)
option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF)

# Fallback Hygon DTK installation root. It is probed by the auto-detection step
# and used by the Hygon backend setup when the `DTK_ROOT` environment variable
# does not provide a location.
set(_DEFAULT_HYGON_DTK_ROOT "/opt/dtk")

if(AUTO_DETECT_DEVICES)
message(STATUS "Auto-detecting available devices...")

Expand All @@ -37,6 +40,13 @@ if(AUTO_DETECT_DEVICES)
message(STATUS "Auto-detected Iluvatar environment.")
endif()

# Hygon: enable the backend when `DTK_ROOT` is exported, or when a DTK `nvcc`
# exists under the default DTK root in either the flat (`cuda/bin`) or the
# nested (`cuda/cuda/bin`) layout.
# NOTE(review): `DEFINED ENV{DTK_ROOT}` is true even if the variable points to a
# missing directory; in that case the later `WITH_HYGON` setup fails with a
# fatal error rather than silently skipping the backend.
if(DEFINED ENV{DTK_ROOT} OR
EXISTS "${_DEFAULT_HYGON_DTK_ROOT}/cuda/bin/nvcc" OR
EXISTS "${_DEFAULT_HYGON_DTK_ROOT}/cuda/cuda/bin/nvcc")
set(WITH_HYGON ON)
message(STATUS "Auto-detected Hygon environment.")
endif()

if(DEFINED ENV{MACA_PATH})
set(WITH_METAX ON)
message(STATUS "Auto-detected MetaX environment from MACA_PATH")
Expand Down Expand Up @@ -77,14 +87,14 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)

# Only one CUDA-like GPU backend can be enabled at a time.
set(_gpu_backend_count 0)
foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_METAX WITH_MOORE)
foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_HYGON WITH_METAX WITH_MOORE)
if(${_gpu_backend})
math(EXPR _gpu_backend_count "${_gpu_backend_count} + 1")
endif()
endforeach()

if(_gpu_backend_count GREATER 1)
message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, and `WITH_MOORE` are mutually exclusive. Build one GPU backend at a time.")
message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_HYGON`, `WITH_METAX`, and `WITH_MOORE` are mutually exclusive. Build one GPU backend at a time.")
endif()

if(WITH_NVIDIA)
Expand All @@ -111,6 +121,64 @@ if(WITH_ILUVATAR)
find_package(CUDAToolkit REQUIRED)
endif()

# Hygon DCU backend: builds the CUDA-style sources with the `nvcc` shipped in
# the Hygon DTK.
#
# Inputs:
#   ENV{DTK_ROOT}   DTK installation root (falls back to a pre-set `DTK_ROOT`
#                   variable, then to `_DEFAULT_HYGON_DTK_ROOT`).
#   ENV{HYGON_ARCH} Default architecture; otherwise taken from `rocminfo`,
#                   otherwise `gfx906`.
#   HYGON_ARCH      Cache variable holding the architecture that is compiled
#                   for; override with `-DHYGON_ARCH=<arch>`.
if(WITH_HYGON)
  add_compile_definitions(WITH_HYGON=1)

  # Resolve the DTK root. A non-empty environment variable wins; otherwise keep
  # any `DTK_ROOT` that is already set (e.g. `-DDTK_ROOT=...`), and finally
  # fall back to the default location.
  if(DEFINED ENV{DTK_ROOT} AND NOT "$ENV{DTK_ROOT}" STREQUAL "")
    set(DTK_ROOT "$ENV{DTK_ROOT}")
  elseif(NOT DTK_ROOT)
    set(DTK_ROOT "${_DEFAULT_HYGON_DTK_ROOT}")
  endif()
  if(NOT EXISTS "${DTK_ROOT}")
    message(FATAL_ERROR "`WITH_HYGON` is `ON` but `DTK_ROOT` (`${DTK_ROOT}`) does not exist.")
  endif()

  # Pick the default architecture: environment override, then the first
  # `gfx<digits>` token printed by `rocminfo`, then the hard-coded fallback.
  set(_HYGON_ARCH_DEFAULT "gfx906")
  if(DEFINED ENV{HYGON_ARCH} AND NOT "$ENV{HYGON_ARCH}" STREQUAL "")
    set(_HYGON_ARCH_DEFAULT "$ENV{HYGON_ARCH}")
  else()
    find_program(HYGON_ROCMINFO_EXECUTABLE NAMES rocminfo HINTS "${DTK_ROOT}/bin")
    if(HYGON_ROCMINFO_EXECUTABLE)
      # A failing `rocminfo` simply yields no match and keeps the fallback.
      execute_process(
        COMMAND "${HYGON_ROCMINFO_EXECUTABLE}"
        OUTPUT_VARIABLE _HYGON_ROCMINFO_OUTPUT
        ERROR_QUIET
        OUTPUT_STRIP_TRAILING_WHITESPACE
      )
      string(REGEX MATCH "gfx[0-9]+" _HYGON_ARCH_AUTO "${_HYGON_ROCMINFO_OUTPUT}")
      if(_HYGON_ARCH_AUTO)
        set(_HYGON_ARCH_DEFAULT "${_HYGON_ARCH_AUTO}")
      endif()
    endif()
  endif()

  # Only seeds the cache: a value passed with `-DHYGON_ARCH=...` or kept from a
  # previous configure run takes precedence over the detected default.
  set(HYGON_ARCH "${_HYGON_ARCH_DEFAULT}" CACHE STRING "Hygon GPU architecture")

  # DTK ships the CUDA-compatible toolchain under `cuda/` or `cuda/cuda/`;
  # the nested layout is preferred when its `nvcc` exists.
  set(HYGON_CUDA_ROOT "${DTK_ROOT}/cuda")
  if(EXISTS "${DTK_ROOT}/cuda/cuda/bin/nvcc")
    set(HYGON_CUDA_ROOT "${DTK_ROOT}/cuda/cuda")
  endif()

  if(NOT EXISTS "${HYGON_CUDA_ROOT}/bin/nvcc")
    message(FATAL_ERROR "`WITH_HYGON` is `ON` but `${HYGON_CUDA_ROOT}/bin/nvcc` was not found. Checked `${DTK_ROOT}/cuda/bin/nvcc` and `${DTK_ROOT}/cuda/cuda/bin/nvcc`.")
  endif()

  set(CMAKE_CUDA_COMPILER "${HYGON_CUDA_ROOT}/bin/nvcc" CACHE FILEPATH "Hygon CUDA compiler (DTK nvcc)")
  set(CUDAToolkit_ROOT "${HYGON_CUDA_ROOT}" CACHE PATH "Hygon CUDA toolkit root")
  set(CMAKE_CUDA_ARCHITECTURES OFF CACHE STRING "Disable default CUDA arch flags for Hygon" FORCE)
  set(CMAKE_CUDA_FLAGS "-std=c++17 -fPIC -arch=${HYGON_ARCH} -Wno-return-type -Wno-error=unused-private-field" CACHE STRING "Hygon CUDA flags")

  # The `set(... CACHE ...)` above is a no-op once `CMAKE_CUDA_FLAGS` is in the
  # cache, so re-configuring with a different `HYGON_ARCH` would keep compiling
  # for the old architecture. `HYGON_ARCH` is the single source of truth:
  # rewrite only the `-arch=` token and leave every other cached flag alone.
  string(REGEX REPLACE "(^| )-arch=[^ ]*" "\\1-arch=${HYGON_ARCH}"
    _hygon_cuda_flags "${CMAKE_CUDA_FLAGS}")
  if(NOT "${_hygon_cuda_flags}" STREQUAL "${CMAKE_CUDA_FLAGS}")
    set(CMAKE_CUDA_FLAGS "${_hygon_cuda_flags}" CACHE STRING "Hygon CUDA flags" FORCE)
  endif()
  unset(_hygon_cuda_flags)

  set(CMAKE_CUDA_SEPARABLE_COMPILATION OFF CACHE BOOL "Disable RDC for Hygon")

  # DTK's nvcc wrapper may invoke `nvcc` by name during compiler checks.
  # This only changes PATH for the configure step, not for the build.
  set(ENV{PATH} "${HYGON_CUDA_ROOT}/bin:$ENV{PATH}")

  include_directories("${DTK_ROOT}/include")
  include_directories("${HYGON_CUDA_ROOT}/include")
  link_directories("${DTK_ROOT}/lib")
  link_directories("${HYGON_CUDA_ROOT}/lib64")

  message(STATUS "Hygon: CUDA compiler ${CMAKE_CUDA_COMPILER}, arch ${HYGON_ARCH}, DTK root ${DTK_ROOT}")
  enable_language(CUDA)
  find_package(CUDAToolkit REQUIRED)
endif()

if(WITH_METAX)
add_compile_definitions(WITH_METAX=1)

Expand Down Expand Up @@ -179,7 +247,7 @@ if(WITH_CAMBRICON)
endif()

# If all other platforms are not enabled, CPU is enabled by default.
if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON)
if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_HYGON AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON)
add_compile_definitions(WITH_CPU=1)
endif()

Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,17 @@ For the `<OPTIONS>`:
|----------------------------------------|------------------------------------|:-:
| `-DWITH_CPU=[ON\|OFF]` | Compile the CPU implementation | n
| `-DWITH_NVIDIA=[ON\|OFF]` | Compile the NVIDIA implementation | n
| `-DWITH_ILUVATAR=[ON\|OFF]` | Compile the Iluvatar implementation| n
| `-DWITH_HYGON=[ON\|OFF]` | Compile the Hygon implementation | n
| `-DWITH_METAX=[ON\|OFF]` | Compile the MetaX implementation | n
| `-DGENERATE_PYTHON_BINDINGS=[ON\|OFF]` | Generate Python bindings | n

*Note: If no accelerator options are provided, `WITH_CPU` is enabled by default.*

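For Hygon builds, set the `DTK_ROOT` environment variable to the DTK
installation root if DTK is not installed at `/opt/dtk`. The default DCU arch
is taken from the `HYGON_ARCH` environment variable if it is set, otherwise
detected with `rocminfo`, and falls back to `gfx906`. You can override it with
`-DHYGON_ARCH=<arch>` when configuring CMake.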

## 🚀 Running Examples
After a successful build, the executables are located in the `build/examples` directory.

Expand Down
3 changes: 3 additions & 0 deletions examples/gemm/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
#if WITH_ILUVATAR
#include "iluvatar/gemm/cublas.h"
#endif
#if WITH_HYGON
#include "hygon/gemm/cublas.h"
#endif
#if WITH_METAX
#include "metax/gemm/mcblas.h"
#endif
Expand Down
9 changes: 9 additions & 0 deletions examples/runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@
#define DEVICE_MEMCPY_HOST_TO_DEVICE cudaMemcpyHostToDevice
#define DEVICE_MEMCPY_DEVICE_TO_HOST cudaMemcpyDeviceToHost
#define DEFAULT_DEVICE_TYPE Device::Type::kIluvatar
#elif WITH_HYGON
// Hygon: the example helpers map directly onto the CUDA runtime API declared
// in `<cuda_runtime.h>`; only the default device type is Hygon-specific.
#include <cuda_runtime.h>
#define DEVICE_MALLOC cudaMalloc
#define DEVICE_FREE cudaFree
#define DEVICE_MEMCPY cudaMemcpy
#define DEVICE_MEMSET cudaMemset
#define DEVICE_MEMCPY_HOST_TO_DEVICE cudaMemcpyHostToDevice
#define DEVICE_MEMCPY_DEVICE_TO_HOST cudaMemcpyDeviceToHost
#define DEFAULT_DEVICE_TYPE Device::Type::kHygon
#elif WITH_METAX
#include <mcr/mc_runtime.h>
#define DEVICE_MALLOC mcMalloc
Expand Down
30 changes: 29 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,34 @@ if(WITH_ILUVATAR)
list(APPEND DEVICE_LIST "iluvatar")
endif()

# Hygon: compile the shared CUDA-style sources together with the Hygon-specific
# ones into `infiniops`, and link the CUDA runtime and cuBLAS imported targets.
if(WITH_HYGON)
  # Toolchain first: the CUDA language and toolkit must be available.
  enable_language(CUDA)
  find_package(CUDAToolkit REQUIRED)

  # Source discovery (patterns are relative to this directory).
  set(HYGON_PATTERNS
    "cuda/*.cc" "cuda/*.cpp" "cuda/*.cu"
    "hygon/*.cc" "hygon/*.cpp" "hygon/*.cu"
  )
  file(GLOB_RECURSE HYGON_SOURCES CONFIGURE_DEPENDS ${HYGON_PATTERNS})
  target_sources(infiniops PRIVATE ${HYGON_SOURCES})

  # Usage requirements that consumers of `infiniops` inherit.
  target_compile_definitions(infiniops PUBLIC WITH_HYGON=1)
  target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas)

  # Device code is built as C++17, without falling back to an older standard.
  set_property(TARGET infiniops PROPERTY CUDA_STANDARD 17)
  set_property(TARGET infiniops PROPERTY CUDA_STANDARD_REQUIRED ON)

  list(APPEND DEVICE_LIST "hygon")
endif()

if(WITH_METAX)
set(METAX_PATTERNS
"cuda/*.cc"
Expand Down Expand Up @@ -191,7 +219,7 @@ if(GENERATE_PYTHON_BINDINGS)
set(PYBIND11_SOURCES "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc")

# TODO: There might be a better solution.
if(WITH_NVIDIA OR WITH_ILUVATAR)
if(WITH_NVIDIA OR WITH_ILUVATAR OR WITH_HYGON)
set_source_files_properties(${PYBIND11_SOURCES} PROPERTIES LANGUAGE CUDA)
endif()

Expand Down
6 changes: 4 additions & 2 deletions src/cuda/causal_softmax/kernel.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
#define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_

#include <algorithm>
#include <cassert>
#include <cstdint>

Expand Down Expand Up @@ -31,10 +32,11 @@ class CudaCausalSoftmax : public CausalSoftmax {

assert(out.dtype() == input.dtype());

int block_size = Backend::GetOptimalBlockSize();
constexpr int kMaxBlockSize = BackendMaxBlockSize<Backend>::value;
int block_size = std::min(Backend::GetOptimalBlockSize(), kMaxBlockSize);

DispatchFunc<ConcatType<List<DataType::kFloat32>, ReducedFloatTypes>,
AllCudaBlockSizes>(
SupportedCudaBlockSizesType<BackendMaxBlockSize<Backend>::value>>(
// TODO: Output dtype should use the one passed in during construction.
{static_cast<int64_t>(out.dtype()), block_size},
[&](auto list_tag) {
Expand Down
42 changes: 42 additions & 0 deletions src/cuda/kernel_commons.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,54 @@
#ifndef INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_
#define INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_

#include <type_traits>

#include "caster.h"

namespace infini::ops {

// Compile-time list of the thread-block sizes (128 through 2048) that the
// CUDA-style kernels can be dispatched on.
using AllCudaBlockSizes = List<128, 256, 512, 1024, 2048>;

// Compile-time upper bound on the thread-block size for a backend.
//
// Primary template: backends that do not declare `max_block_size` get 2048.
template <typename B, typename Enable = void>
struct BackendMaxBlockSize : std::integral_constant<int, 2048> {};

// Selected only when the expression `B::max_block_size` is well-formed; the
// bound is then whatever the backend declares.
template <typename B>
struct BackendMaxBlockSize<B, decltype(void(B::max_block_size))>
    : std::integral_constant<int, B::max_block_size> {};

// Maps a maximum block size to the list of block sizes that do not exceed it.
// The primary template is declared but never defined, so only the limits
// specialized below (128, 256, 512, 1024, 2048) can be used.
template <int kLimit>
struct SupportedCudaBlockSizes;

template <>
struct SupportedCudaBlockSizes<128> {
  using type = List<128>;
};

template <>
struct SupportedCudaBlockSizes<256> {
  using type = List<128, 256>;
};

template <>
struct SupportedCudaBlockSizes<512> {
  using type = List<128, 256, 512>;
};

template <>
struct SupportedCudaBlockSizes<1024> {
  using type = List<128, 256, 512, 1024>;
};

// The largest limit admits every known block size.
template <>
struct SupportedCudaBlockSizes<2048> {
  using type = AllCudaBlockSizes;
};

// Shorthand for `SupportedCudaBlockSizes<kLimit>::type`.
template <int kLimit>
using SupportedCudaBlockSizesType =
    typename SupportedCudaBlockSizes<kLimit>::type;

__forceinline__ __device__ __host__ size_t
IndexToOffset(size_t flat_index, size_t ndim, const size_t* shape,
const ptrdiff_t* strides) {
Expand Down
43 changes: 43 additions & 0 deletions src/hygon/add/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef INFINI_OPS_HYGON_ADD_KERNEL_H_
#define INFINI_OPS_HYGON_ADD_KERNEL_H_

#include <utility>

#include "cuda/add/kernel.h"
#include "hygon/device_.h"

namespace infini::ops {

namespace add {

// Backend traits that bind the shared `CudaAdd` implementation to Hygon
// devices. Every runtime hook below forwards to the CUDA runtime API.
struct HygonBackend {
// Stream handle type used by this backend.
using stream_t = cudaStream_t;

// Device type this backend is registered for.
static constexpr Device::Type kDeviceType = Device::Type::kHygon;

// Generic lambda that perfectly forwards its arguments to `cudaMalloc`.
// NOTE(review): presumably a lambda rather than a plain function pointer
// because `cudaMalloc` is overloaded — confirm.
static constexpr auto malloc = [](auto&&... args) {
return cudaMalloc(std::forward<decltype(args)>(args)...);
};

static constexpr auto memcpy = cudaMemcpy;

static constexpr auto free = cudaFree;

// Copy direction passed to `memcpy` for host-to-device transfers.
static constexpr auto memcpyH2D = cudaMemcpyHostToDevice;

// Returns `ComputeOptimalBlockSize(QueryMaxThreadsPerBlock())`. Both helpers
// are declared outside this header (presumably in "hygon/device_.h" —
// confirm).
static int GetOptimalBlockSize() {
return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock());
}
};

} // namespace add

// Hygon specialization of the `Add` operator: the shared `CudaAdd`
// implementation instantiated with `add::HygonBackend`. It adds no members of
// its own and only inherits the base-class constructors.
template <>
class Operator<Add, Device::Type::kHygon> : public CudaAdd<add::HygonBackend> {
public:
using CudaAdd<add::HygonBackend>::CudaAdd;
};

} // namespace infini::ops

#endif
43 changes: 43 additions & 0 deletions src/hygon/causal_softmax/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef INFINI_OPS_HYGON_CAUSAL_SOFTMAX_KERNEL_H_
#define INFINI_OPS_HYGON_CAUSAL_SOFTMAX_KERNEL_H_

#include <utility>

// clang-format off
#include <cuda_runtime.h>
// clang-format on

// clang-format off
#include "hygon/device_.h"
// clang-format on

#include "cuda/causal_softmax/kernel.h"

namespace infini::ops {

namespace causal_softmax {

// Backend traits that bind the shared `CudaCausalSoftmax` implementation to
// Hygon devices.
struct HygonBackend {
// Stream handle type used by this backend.
using stream_t = cudaStream_t;

// Device type this backend is registered for.
static constexpr Device::Type kDeviceType = Device::Type::kHygon;

// Compile-time cap on the thread-block size for this backend. Traits that
// look for a `max_block_size` member (such as `BackendMaxBlockSize`) pick this
// value up instead of their default.
static constexpr int max_block_size = 256;

// Returns `ComputeOptimalBlockSize(QueryMaxThreadsPerBlock())`. Both helpers
// are declared outside this header (presumably in "hygon/device_.h" —
// confirm). The result is not clamped to `max_block_size` here.
static int GetOptimalBlockSize() {
return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock());
}
};

} // namespace causal_softmax

// Hygon specialization of the `CausalSoftmax` operator: the shared
// `CudaCausalSoftmax` implementation instantiated with
// `causal_softmax::HygonBackend`. It adds no members of its own and only
// inherits the base-class constructors.
template <>
class Operator<CausalSoftmax, Device::Type::kHygon>
: public CudaCausalSoftmax<causal_softmax::HygonBackend> {
public:
using CudaCausalSoftmax<causal_softmax::HygonBackend>::CudaCausalSoftmax;
};

} // namespace infini::ops

#endif
Loading