From 11a496045b1931688db0a10a28bd0d3146a3eea7 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 4 Feb 2026 03:31:36 +0000 Subject: [PATCH 01/93] chore: add `.clang-format` --- .clang-format | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .clang-format diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..2296f7d --- /dev/null +++ b/.clang-format @@ -0,0 +1,3 @@ +--- +BasedOnStyle: Google +... From 25de6c83327ef4c4df820bf97b926a652d5b6ad4 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 4 Feb 2026 08:51:29 +0000 Subject: [PATCH 02/93] feat: add `DataType` --- src/data_type.h | 57 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 src/data_type.h diff --git a/src/data_type.h b/src/data_type.h new file mode 100644 index 0000000..60cd0db --- /dev/null +++ b/src/data_type.h @@ -0,0 +1,57 @@ +#ifndef INFINI_OPS_DATA_TYPE_H_ +#define INFINI_OPS_DATA_TYPE_H_ + +#include +#include +#include + +namespace infini::ops { + +class DataType { + public: + constexpr DataType(int index, std::size_t element_size, const char* name) + : index_{index}, element_size_{element_size}, name_{name} {} + + constexpr bool operator==(const DataType& other) const { + return index_ == other.index_; + } + + constexpr std::size_t element_size() const { return element_size_; } + + constexpr const char* name() const { return name_; } + + private: + int index_{0}; + + std::size_t element_size_{0}; + + const char* name_{nullptr}; +}; + +constexpr DataType kInt8{0, sizeof(int8_t), "int8"}; + +constexpr DataType kInt16{1, sizeof(int16_t), "int16"}; + +constexpr DataType kInt32{2, sizeof(int32_t), "int32"}; + +constexpr DataType kInt64{3, sizeof(int64_t), "int64"}; + +constexpr DataType kUInt8{4, sizeof(uint8_t), "uint8"}; + +constexpr DataType kUInt16{5, sizeof(uint16_t), "uint16"}; + +constexpr DataType kUInt32{6, sizeof(uint32_t), "uint32"}; + +constexpr DataType kUInt64{7, sizeof(uint64_t), "uint64"}; 
+ +constexpr DataType kFloat16{8, 2, "float16"}; + +constexpr DataType kBFloat16{9, 2, "bfloat16"}; + +constexpr DataType kFloat32{10, sizeof(float), "float32"}; + +constexpr DataType kFloat64{11, sizeof(double), "float64"}; + +} // namespace infini::ops + +#endif From 3e1bb6f0f83f825e5a1d166eb40514a16dd72ad9 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 5 Feb 2026 13:18:31 +0800 Subject: [PATCH 03/93] test: add an example for `DataType` --- examples/data_type.cc | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 examples/data_type.cc diff --git a/examples/data_type.cc b/examples/data_type.cc new file mode 100644 index 0000000..0b9d010 --- /dev/null +++ b/examples/data_type.cc @@ -0,0 +1,25 @@ +#include "data_type.h" + +#include +#include +#include + +static void PrintDataTypeInfo(const infini::ops::DataType& dtype) {} + +int main() { + using namespace infini::ops; + + static const std::vector kDataTypes{ + kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, + kUInt32, kUInt64, kFloat16, kBFloat16, kFloat32, kFloat64}; + + std::cout << std::left << std::setw(10) << "Name" << std::left + << std::setw(10) << "Element Size\n"; + + for (const auto& dtype : kDataTypes) { + std::cout << std::left << std::setw(10) << dtype.name() << std::left + << std::setw(10) << dtype.element_size() << '\n'; + } + + return 0; +} From 96127b379f72ecc7fec94c2e25c47748ffe9a61d Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 4 Feb 2026 06:00:59 +0000 Subject: [PATCH 04/93] feat: add `Tensor` --- src/tensor.cc | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/tensor.h | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+) create mode 100644 src/tensor.cc create mode 100644 src/tensor.h diff --git a/src/tensor.cc b/src/tensor.cc new file mode 100644 index 0000000..9634c77 --- /dev/null +++ b/src/tensor.cc @@ -0,0 +1,89 @@ +#include "tensor.h" + +#include + +namespace infini::ops { + 
+Tensor::Tensor(void* data, std::initializer_list shape, + const DataType& dtype, std::initializer_list strides) + : Tensor{data, decltype(shape_){shape}, dtype, + decltype(strides_){strides}} {} + +Tensor Tensor::operator[](const Index& index) const { + return {reinterpret_cast( + reinterpret_cast(data_) + + index * strides_[0] * element_size()), + Shape{shape_.cbegin() + 1, shape_.cend()}, dtype_, + Strides{strides_.cbegin() + 1, strides_.cend()}}; +} + +void*& Tensor::data() { return data_; } + +const void* const& Tensor::data() const { return data_; } + +const Tensor::Shape& Tensor::shape() const { return shape_; } + +const DataType& Tensor::dtype() const { return dtype_; } + +const Tensor::Strides& Tensor::strides() const { return strides_; } + +Tensor::Size Tensor::size(const Index& index) const { return shape_[index]; } + +Tensor::Stride Tensor::stride(const Index& index) const { + return shape_[index]; +} + +Tensor::Size Tensor::ndim() const { return shape_.size(); } + +Tensor::Size Tensor::element_size() const { return dtype_.element_size(); } + +Tensor Tensor::T() const { + return {data_, {shape_[1], shape_[0]}, dtype_, {strides_[1], strides_[0]}}; +} + +std::string Tensor::ToString() const { + return "tensor(" + ToStringHelper() + ", dtype=" + dtype_.name() + ")"; +} + +const DataType& Tensor::DefaultDataType() { return kFloat32; } + +Tensor::Strides Tensor::DefaultStrides(const Shape& shape) { + if (shape.empty()) { + return {}; + } + + Strides strides(shape.size()); + + strides.back() = 1; + + for (auto i{shape.size() - 2}; i != -1; --i) { + strides[i] = strides[i + 1] * shape[i + 1]; + } + + return strides; +} + +std::string Tensor::ToStringHelper() const { + if (ndim() == 0) { + if (dtype_ == kFloat32) { + return std::to_string(*static_cast(data_)); + } + + // TODO: Handle more data types here. 
+ + assert(false && "string conversion not implemented for this data type"); + } + + std::string result{"["}; + + for (auto i{Index{0}}; i < shape_[0]; ++i) { + result += operator[](i).ToStringHelper() + ", "; + } + + result.pop_back(); + result.back() = ']'; + + return result; +} + +} // namespace infini::ops diff --git a/src/tensor.h b/src/tensor.h new file mode 100644 index 0000000..2c03a8f --- /dev/null +++ b/src/tensor.h @@ -0,0 +1,88 @@ +#ifndef INFINI_OPS_TENSOR_H_ +#define INFINI_OPS_TENSOR_H_ + +#include +#include +#include + +#include "data_type.h" + +namespace infini::ops { + +class Tensor { + public: + using Size = std::uint64_t; + + using Stride = std::int64_t; + + using Index = Stride; + + using Shape = std::vector; + + using Strides = std::vector; + + template + Tensor(void* data, const Shape& shape) + : data_{data}, + shape_{shape}, + dtype_{DefaultDataType()}, + strides_{DefaultStrides(shape)} {} + + template + Tensor(void* data, const Shape& shape, const DataType& dtype) + : data_{data}, + shape_{shape}, + dtype_{dtype}, + strides_{DefaultStrides(shape)} {} + + template + Tensor(void* data, const Shape& shape, const DataType& dtype, + const Strides& strides) + : data_{data}, shape_{shape}, dtype_{dtype}, strides_{strides} {} + + Tensor(void* data, std::initializer_list shape, const DataType& dtype, + std::initializer_list strides); + + Tensor operator[](const Index& index) const; + + void*& data(); + + const void* const& data() const; + + const DataType& dtype() const; + + const Shape& shape() const; + + const Strides& strides() const; + + Size size(const Index& index) const; + + Stride stride(const Index& index) const; + + Size ndim() const; + + Size element_size() const; + + Tensor T() const; + + std::string ToString() const; + + private: + static const DataType& DefaultDataType(); + + static Strides DefaultStrides(const Shape& shape); + + std::string ToStringHelper() const; + + void* data_{nullptr}; + + Shape shape_; + + const DataType& dtype_; 
+ + Strides strides_; +}; + +} // namespace infini::ops + +#endif From 9a5077b95e5959e7f1345343da14b574e0fa41f9 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 6 Feb 2026 00:05:29 +0800 Subject: [PATCH 05/93] test: add an example for `Tensor` --- examples/tensor.cc | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 examples/tensor.cc diff --git a/examples/tensor.cc b/examples/tensor.cc new file mode 100644 index 0000000..ff768bd --- /dev/null +++ b/examples/tensor.cc @@ -0,0 +1,25 @@ +#include "tensor.h" + +#include +#include +#include +#include + +int main() { + using namespace infini::ops; + + const Tensor::Shape shape{2, 3, 4}; + + const auto num_elements{ + std::accumulate(shape.cbegin(), shape.cend(), 1, std::multiplies())}; + + std::vector elems(num_elements); + + std::iota(elems.begin(), elems.end(), 0); + + Tensor x{elems.data(), shape}; + + std::cout << x.ToString() << '\n'; + + return 0; +} From 91daa44cc463ec6605c7683efd7504aefe07fd91 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 6 Feb 2026 00:09:56 +0800 Subject: [PATCH 06/93] feat: add `Device` --- src/device.h | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 src/device.h diff --git a/src/device.h b/src/device.h new file mode 100644 index 0000000..24b3086 --- /dev/null +++ b/src/device.h @@ -0,0 +1,10 @@ +#ifndef INFINI_OPS_DEVICE_H_ +#define INFINI_OPS_DEVICE_H_ + +namespace infini::ops { + +enum class Device { kCpu, kNvidia, kCount }; + +} // namespace infini::ops + +#endif From fb016f2d66fb2159339887da7ba0f33997a4cb21 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 6 Feb 2026 00:10:11 +0800 Subject: [PATCH 07/93] feat: add `Handle` --- src/handle.h | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 src/handle.h diff --git a/src/handle.h b/src/handle.h new file mode 100644 index 0000000..37a91b6 --- /dev/null +++ b/src/handle.h @@ -0,0 +1,20 @@ +#ifndef INFINI_OPS_HANDLE_H_ +#define 
INFINI_OPS_HANDLE_H_ + +#include "device.h" + +namespace infini::ops { + +class Handle { + public: + Handle(Device device) : device_{device} {} + + const Device& device() const { return device_; } + + private: + Device device_; +}; + +} // namespace infini::ops + +#endif From 128e7398387af4171eff1d36b98ef9b3fed2808d Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 6 Feb 2026 00:44:15 +0800 Subject: [PATCH 08/93] feat: add `Operator` --- src/operator.h | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 src/operator.h diff --git a/src/operator.h b/src/operator.h new file mode 100644 index 0000000..87d1716 --- /dev/null +++ b/src/operator.h @@ -0,0 +1,50 @@ +#ifndef INFINI_OPS_OPERATOR_H_ +#define INFINI_OPS_OPERATOR_H_ + +#include +#include + +#include "handle.h" + +namespace infini::ops { + +template +class Operator { + public: + template + static auto make(const Handle& handle, Args&&... args) { + std::unique_ptr op_ptr; + + switch (handle.device()) { + case Device::kNvidia: + op_ptr = std::make_unique>( + std::forward(args)...); + break; + default: + assert(false && + "constructor dispatching not implemented for this device"); + } + + op_ptr->device_ = handle.device(); + + return op_ptr; + } + + template + auto operator()(Args&&... 
args) const { + switch (device_) { + case Device::kNvidia: + return (*static_cast*>(this))( + std::forward(args)...); + } + + assert(false && "`operator()` dispatching not implemented for this device"); + } + + private: + Device device_; +}; + +} // namespace infini::ops + +#endif From 92a02b08aeaa6d230912101ed88b9ee9cfc18537 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 6 Feb 2026 13:43:45 +0800 Subject: [PATCH 09/93] feat: add `Gemm` --- src/base/gemm.h | 62 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 src/base/gemm.h diff --git a/src/base/gemm.h b/src/base/gemm.h new file mode 100644 index 0000000..9990612 --- /dev/null +++ b/src/base/gemm.h @@ -0,0 +1,62 @@ +#ifndef INFINI_OPS_BASE_GEMM_H_ +#define INFINI_OPS_BASE_GEMM_H_ + +#include + +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +class Gemm : public Operator { + public: + Gemm(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) + : alpha_{alpha.value_or(1.0)}, + beta_{beta.value_or(1.0)}, + trans_a_{static_cast(trans_a.value_or(false))}, + trans_b_{static_cast(trans_b.value_or(false))}, + m_{c.size(0)}, + n_{c.size(1)}, + k_{trans_a_ ? a.size(0) : a.size(1)}, + a_type_{a.dtype()}, + b_type_{b.dtype()}, + c_type_{c.dtype()}, + a_strides_{a.strides()}, + b_strides_{b.strides()}, + c_strides_{c.strides()} { + // TODO: Check constraints. 
+ } + + protected: + float alpha_{1.0}; + + float beta_{1.0}; + + bool trans_a_{false}; + + bool trans_b_{false}; + + Tensor::Size m_{0}; + + Tensor::Size n_{0}; + + Tensor::Size k_{0}; + + const DataType& a_type_; + + const DataType& b_type_; + + const DataType& c_type_; + + Tensor::Strides a_strides_; + + Tensor::Strides b_strides_; + + Tensor::Strides c_strides_; +}; + +} // namespace infini::ops + +#endif From 75b4d99bebbb2987886c01ac6e399a1315e627b3 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 6 Feb 2026 13:59:43 +0800 Subject: [PATCH 10/93] feat: add `Operator` --- src/nvidia/gemm/cublas.h | 61 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 src/nvidia/gemm/cublas.h diff --git a/src/nvidia/gemm/cublas.h b/src/nvidia/gemm/cublas.h new file mode 100644 index 0000000..80f881c --- /dev/null +++ b/src/nvidia/gemm/cublas.h @@ -0,0 +1,61 @@ +#ifndef INFINI_OPS_NVIDIA_GEMM_CUBLAS_H_ +#define INFINI_OPS_NVIDIA_GEMM_CUBLAS_H_ + +#include + +// clang-format off +#include "cublas_v2.h" +// clang-format on + +#include "base/gemm.h" + +namespace infini::ops { + +template <> +class Operator : public Gemm { + public: + Operator(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) + : Gemm{a.stride(0) == 1 ? a : b.T(), + a.stride(0) == 1 ? b : a.T(), + alpha, + beta, + trans_a, + trans_b, + a.stride(0) == 1 ? c : c.T()}, + lda_{a_strides_[1]}, + ldb_{b_strides_[1]}, + ldc_{c_strides_[1]} { + // TODO: Check constraints. + } + + void operator()(void* stream, const void* a, const void* b, void* c) const { + cublasHandle_t handle; + cublasCreate(&handle); + + cublasSetStream(handle, static_cast(stream)); + + // TODO: Add support for more data types. + assert(a_type_ == kFloat32 && b_type_ == kFloat32 && c_type_ == kFloat32 && + "`operator()` not implemented for this data type"); + + cublasGemmEx(handle, trans_a_ ? 
CUBLAS_OP_T : CUBLAS_OP_N, + trans_b_ ? CUBLAS_OP_T : CUBLAS_OP_N, m_, n_, k_, &alpha_, b, + CUDA_R_32F, lda_, a, CUDA_R_32F, ldb_, &beta_, c, CUDA_R_32F, + ldc_, CUBLAS_COMPUTE_32F_FAST_TF32, CUBLAS_GEMM_DEFAULT); + + cublasDestroy(handle); + } + + private: + Tensor::Stride lda_{0}; + + Tensor::Stride ldb_{0}; + + Tensor::Stride ldc_{0}; +}; + +} // namespace infini::ops + +#endif From 08fc18933ee6f517ff5350eb4da5533cff9ade3c Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 6 Feb 2026 14:28:20 +0800 Subject: [PATCH 11/93] test: add an example for `Operator` --- examples/nvidia/gemm/cublas.cc | 76 ++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 examples/nvidia/gemm/cublas.cc diff --git a/examples/nvidia/gemm/cublas.cc b/examples/nvidia/gemm/cublas.cc new file mode 100644 index 0000000..94c084f --- /dev/null +++ b/examples/nvidia/gemm/cublas.cc @@ -0,0 +1,76 @@ +#include "nvidia/gemm/cublas.h" + +#include + +#include +#include + +#include "tensor.h" + +int main() { + using namespace infini::ops; + + constexpr auto m{2}; + constexpr auto k{3}; + constexpr auto n{4}; + + std::vector a_shape{m, k}; + std::vector b_shape{k, n}; + std::vector c_shape{m, n}; + + const auto a_num_elements{std::accumulate(a_shape.cbegin(), a_shape.cend(), 1, + std::multiplies())}; + const auto b_num_elements{std::accumulate(b_shape.cbegin(), b_shape.cend(), 1, + std::multiplies())}; + const auto c_num_elements{std::accumulate(c_shape.cbegin(), c_shape.cend(), 1, + std::multiplies())}; + + std::vector a_vec(a_num_elements); + std::vector b_vec(b_num_elements); + std::vector c_vec(c_num_elements); + + std::iota(a_vec.begin(), a_vec.end(), 0); + std::iota(b_vec.begin(), b_vec.end(), 0); + + Tensor a_host{a_vec.data(), a_shape}; + Tensor b_host{b_vec.data(), b_shape}; + Tensor c_host{c_vec.data(), c_shape}; + + Tensor a_device{nullptr, a_host.shape(), a_host.dtype(), a_host.strides()}; + Tensor b_device{nullptr, b_host.shape(), b_host.dtype(), 
b_host.strides()}; + Tensor c_device{nullptr, c_host.shape(), c_host.dtype(), c_host.strides()}; + + const auto a_size{a_num_elements * a_device.dtype().element_size()}; + const auto b_size{b_num_elements * b_device.dtype().element_size()}; + const auto c_size{c_num_elements * c_device.dtype().element_size()}; + + cudaMalloc(&a_device.data(), a_size); + cudaMalloc(&b_device.data(), b_size); + cudaMalloc(&c_device.data(), c_size); + + cudaMemcpy(a_device.data(), a_vec.data(), a_size, cudaMemcpyHostToDevice); + cudaMemcpy(b_device.data(), b_vec.data(), b_size, cudaMemcpyHostToDevice); + cudaMemset(c_device.data(), 0, c_size); + + const Handle handle{Device::kNvidia}; + auto gemm_ptr{Operator::make(handle, a_device, b_device, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, + c_device)}; + const auto& gemm{*gemm_ptr}; + + gemm(nullptr, a_device.data(), b_device.data(), c_device.data()); + + cudaMemcpy(a_host.data(), a_device.data(), a_size, cudaMemcpyDeviceToHost); + cudaMemcpy(b_host.data(), b_device.data(), b_size, cudaMemcpyDeviceToHost); + cudaMemcpy(c_host.data(), c_device.data(), c_size, cudaMemcpyDeviceToHost); + + cudaFree(a_device.data()); + cudaFree(b_device.data()); + cudaFree(c_device.data()); + + std::cout << "A: " << a_host.ToString() << "\n"; + std::cout << "B: " << b_host.ToString() << "\n"; + std::cout << "C: " << c_host.ToString() << "\n"; + + return 0; +} From c590c2a2bb4e87dcc0365d8f63e4557dd539ebdd Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 10 Feb 2026 17:48:17 +0800 Subject: [PATCH 12/93] feat: add `DataType::FromString` for string-to-dtype conversion --- src/data_type.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/data_type.h b/src/data_type.h index 60cd0db..7cf8045 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -1,9 +1,11 @@ #ifndef INFINI_OPS_DATA_TYPE_H_ #define INFINI_OPS_DATA_TYPE_H_ +#include #include #include #include +#include namespace infini::ops { @@ -12,6 +14,8 @@ class 
DataType { constexpr DataType(int index, std::size_t element_size, const char* name) : index_{index}, element_size_{element_size}, name_{name} {} + static const DataType& FromString(const std::string& name); + constexpr bool operator==(const DataType& other) const { return index_ == other.index_; } @@ -52,6 +56,18 @@ constexpr DataType kFloat32{10, sizeof(float), "float32"}; constexpr DataType kFloat64{11, sizeof(double), "float64"}; +inline const DataType& DataType::FromString(const std::string& name) { + static std::unordered_map name_to_dtype{ + {kInt8.name(), kInt8}, {kInt16.name(), kInt16}, + {kInt32.name(), kInt32}, {kInt64.name(), kInt64}, + {kUInt8.name(), kUInt8}, {kUInt16.name(), kUInt16}, + {kUInt32.name(), kUInt32}, {kUInt64.name(), kUInt64}, + {kFloat16.name(), kFloat16}, {kBFloat16.name(), kBFloat16}, + {kFloat32.name(), kFloat32}, {kFloat64.name(), kFloat64}}; + + return name_to_dtype.at(name); +} + } // namespace infini::ops #endif From fb434a1387dafa6202e4066977e9c3b05963c0c1 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 10 Feb 2026 17:55:39 +0800 Subject: [PATCH 13/93] refactor: make `Device` a `class` and integrate it into `Tensor` --- examples/nvidia/gemm/cublas.cc | 18 +++++++++------- src/base/gemm.h | 1 - src/device.h | 19 ++++++++++++++++- src/nvidia/gemm/cublas.h | 2 +- src/operator.h | 24 ++++++++++----------- src/tensor.cc | 17 +++++++++++---- src/tensor.h | 38 ++++++++++++++++++++++++++++++---- 7 files changed, 88 insertions(+), 31 deletions(-) diff --git a/examples/nvidia/gemm/cublas.cc b/examples/nvidia/gemm/cublas.cc index 94c084f..8b40bf0 100644 --- a/examples/nvidia/gemm/cublas.cc +++ b/examples/nvidia/gemm/cublas.cc @@ -32,13 +32,16 @@ int main() { std::iota(a_vec.begin(), a_vec.end(), 0); std::iota(b_vec.begin(), b_vec.end(), 0); - Tensor a_host{a_vec.data(), a_shape}; - Tensor b_host{b_vec.data(), b_shape}; - Tensor c_host{c_vec.data(), c_shape}; + Tensor a_host{a_vec.data(), a_shape, Device{Device::Type::kNvidia}}; + 
Tensor b_host{b_vec.data(), b_shape, Device{Device::Type::kNvidia}}; + Tensor c_host{c_vec.data(), c_shape, Device{Device::Type::kNvidia}}; - Tensor a_device{nullptr, a_host.shape(), a_host.dtype(), a_host.strides()}; - Tensor b_device{nullptr, b_host.shape(), b_host.dtype(), b_host.strides()}; - Tensor c_device{nullptr, c_host.shape(), c_host.dtype(), c_host.strides()}; + Tensor a_device{nullptr, a_host.shape(), a_host.dtype(), a_host.device(), + a_host.strides()}; + Tensor b_device{nullptr, b_host.shape(), b_host.dtype(), a_host.device(), + b_host.strides()}; + Tensor c_device{nullptr, c_host.shape(), c_host.dtype(), a_host.device(), + c_host.strides()}; const auto a_size{a_num_elements * a_device.dtype().element_size()}; const auto b_size{b_num_elements * b_device.dtype().element_size()}; @@ -52,8 +55,7 @@ int main() { cudaMemcpy(b_device.data(), b_vec.data(), b_size, cudaMemcpyHostToDevice); cudaMemset(c_device.data(), 0, c_size); - const Handle handle{Device::kNvidia}; - auto gemm_ptr{Operator::make(handle, a_device, b_device, std::nullopt, + auto gemm_ptr{Operator::make(a_device, b_device, std::nullopt, std::nullopt, std::nullopt, std::nullopt, c_device)}; const auto& gemm{*gemm_ptr}; diff --git a/src/base/gemm.h b/src/base/gemm.h index 9990612..fc90752 100644 --- a/src/base/gemm.h +++ b/src/base/gemm.h @@ -4,7 +4,6 @@ #include #include "operator.h" -#include "tensor.h" namespace infini::ops { diff --git a/src/device.h b/src/device.h index 24b3086..46464e4 100644 --- a/src/device.h +++ b/src/device.h @@ -3,7 +3,24 @@ namespace infini::ops { -enum class Device { kCpu, kNvidia, kCount }; +class Device { + public: + // TODO: Complete the list. 
+ enum class Type { kCpu, kNvidia, kCount }; + + Device() = default; + + Device(const Type& type, const int& index = 0) : type_{type}, index_{index} {} + + const Type& type() const { return type_; } + + const int& index() const { return index_; } + + private: + Type type_{Type::kCpu}; + + int index_{0}; +}; } // namespace infini::ops diff --git a/src/nvidia/gemm/cublas.h b/src/nvidia/gemm/cublas.h index 80f881c..6949afe 100644 --- a/src/nvidia/gemm/cublas.h +++ b/src/nvidia/gemm/cublas.h @@ -12,7 +12,7 @@ namespace infini::ops { template <> -class Operator : public Gemm { +class Operator : public Gemm { public: Operator(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, std::optional trans_a, diff --git a/src/operator.h b/src/operator.h index 87d1716..b322fb6 100644 --- a/src/operator.h +++ b/src/operator.h @@ -4,38 +4,38 @@ #include #include -#include "handle.h" +#include "tensor.h" namespace infini::ops { -template +template class Operator { public: template - static auto make(const Handle& handle, Args&&... args) { + static auto make(const Tensor tensor, Args&&... args) { std::unique_ptr op_ptr; - switch (handle.device()) { - case Device::kNvidia: - op_ptr = std::make_unique>( - std::forward(args)...); + switch (tensor.device().type()) { + case Device::Type::kNvidia: + op_ptr = std::make_unique>( + tensor, std::forward(args)...); break; default: assert(false && "constructor dispatching not implemented for this device"); } - op_ptr->device_ = handle.device(); + op_ptr->device_ = tensor.device(); return op_ptr; } template auto operator()(Args&&... 
args) const { - switch (device_) { - case Device::kNvidia: - return (*static_cast*>(this))( - std::forward(args)...); + switch (device_.type()) { + case Device::Type::kNvidia: + return (*static_cast*>( + this))(std::forward(args)...); } assert(false && "`operator()` dispatching not implemented for this device"); diff --git a/src/tensor.cc b/src/tensor.cc index 9634c77..b4c4269 100644 --- a/src/tensor.cc +++ b/src/tensor.cc @@ -5,15 +5,16 @@ namespace infini::ops { Tensor::Tensor(void* data, std::initializer_list shape, - const DataType& dtype, std::initializer_list strides) - : Tensor{data, decltype(shape_){shape}, dtype, + const DataType& dtype, const Device& device, + std::initializer_list strides) + : Tensor{data, decltype(shape_){shape}, dtype, device, decltype(strides_){strides}} {} Tensor Tensor::operator[](const Index& index) const { return {reinterpret_cast( reinterpret_cast(data_) + index * strides_[0] * element_size()), - Shape{shape_.cbegin() + 1, shape_.cend()}, dtype_, + Shape{shape_.cbegin() + 1, shape_.cend()}, dtype_, device_, Strides{strides_.cbegin() + 1, strides_.cend()}}; } @@ -25,6 +26,8 @@ const Tensor::Shape& Tensor::shape() const { return shape_; } const DataType& Tensor::dtype() const { return dtype_; } +const Device& Tensor::device() const { return device_; } + const Tensor::Strides& Tensor::strides() const { return strides_; } Tensor::Size Tensor::size(const Index& index) const { return shape_[index]; } @@ -38,7 +41,11 @@ Tensor::Size Tensor::ndim() const { return shape_.size(); } Tensor::Size Tensor::element_size() const { return dtype_.element_size(); } Tensor Tensor::T() const { - return {data_, {shape_[1], shape_[0]}, dtype_, {strides_[1], strides_[0]}}; + return {data_, + {shape_[1], shape_[0]}, + dtype_, + device_, + {strides_[1], strides_[0]}}; } std::string Tensor::ToString() const { @@ -47,6 +54,8 @@ std::string Tensor::ToString() const { const DataType& Tensor::DefaultDataType() { return kFloat32; } +Device 
Tensor::DefaultDevice() { return Device{Device::Type::kCpu}; } + Tensor::Strides Tensor::DefaultStrides(const Shape& shape) { if (shape.empty()) { return {}; diff --git a/src/tensor.h b/src/tensor.h index 2c03a8f..66ca60b 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -6,6 +6,7 @@ #include #include "data_type.h" +#include "device.h" namespace infini::ops { @@ -26,22 +27,45 @@ class Tensor { : data_{data}, shape_{shape}, dtype_{DefaultDataType()}, + device_{DefaultDevice()}, strides_{DefaultStrides(shape)} {} - template + template Tensor(void* data, const Shape& shape, const DataType& dtype) : data_{data}, shape_{shape}, dtype_{dtype}, + device_{DefaultDevice()}, + strides_{DefaultStrides(shape)} {} + + template + Tensor(void* data, const Shape& shape, const Device& device) + : data_{data}, + shape_{shape}, + dtype_{DefaultDataType()}, + device_{device}, + strides_{DefaultStrides(shape)} {} + + template + Tensor(void* data, const Shape& shape, const DataType& dtype, + const Device& device) + : data_{data}, + shape_{shape}, + dtype_{dtype}, + device_{device}, strides_{DefaultStrides(shape)} {} template Tensor(void* data, const Shape& shape, const DataType& dtype, - const Strides& strides) - : data_{data}, shape_{shape}, dtype_{dtype}, strides_{strides} {} + const Device& device, const Strides& strides) + : data_{data}, + shape_{shape}, + dtype_{dtype}, + device_{device}, + strides_{strides} {} Tensor(void* data, std::initializer_list shape, const DataType& dtype, - std::initializer_list strides); + const Device& device, std::initializer_list strides); Tensor operator[](const Index& index) const; @@ -51,6 +75,8 @@ class Tensor { const DataType& dtype() const; + const Device& device() const; + const Shape& shape() const; const Strides& strides() const; @@ -70,6 +96,8 @@ class Tensor { private: static const DataType& DefaultDataType(); + static Device DefaultDevice(); + static Strides DefaultStrides(const Shape& shape); std::string ToStringHelper() const; @@ -80,6 
+108,8 @@ class Tensor { const DataType& dtype_; + Device device_; + Strides strides_; }; From 6453be96327c7901b152b82dedd85cc6eed69f47 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 11 Feb 2026 11:24:48 +0800 Subject: [PATCH 14/93] feat: add `Device::TypeFromString` for string-to-device-type conversion --- src/device.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/device.h b/src/device.h index 46464e4..0eb4db1 100644 --- a/src/device.h +++ b/src/device.h @@ -12,6 +12,14 @@ class Device { Device(const Type& type, const int& index = 0) : type_{type}, index_{index} {} + static const Type& TypeFromString(const std::string& name) { + // TODO: Handle `"cuda"` dispatching. + static std::unordered_map name_to_type{ + {"cpu", Type::kCpu}, {"cuda", Type::kNvidia}}; + + return name_to_type.at(name); + } + const Type& type() const { return type_; } const int& index() const { return index_; } From 66f1d3c27ccca889742dc66bfb03ab7555db43b5 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 11 Feb 2026 11:29:14 +0800 Subject: [PATCH 15/93] feat: add a script to generate pybind11 bindings --- .gitignore | 3 + requirements.txt | 1 + scripts/generate_wrappers.py | 176 +++++++++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+) create mode 100644 requirements.txt create mode 100644 scripts/generate_wrappers.py diff --git a/.gitignore b/.gitignore index d4fb281..80b21fd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Generated files +generated/ + # Prerequisites *.d diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..add49ef --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +libclang diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py new file mode 100644 index 0000000..2d945f2 --- /dev/null +++ b/scripts/generate_wrappers.py @@ -0,0 +1,176 @@ +import json +import pathlib +import textwrap + +import clang.cindex +from clang.cindex import CursorKind + +_GENERATION_DIR = 
pathlib.Path("generated") + +_BINDINGS_DIR = _GENERATION_DIR / "bindings" + +_INDENTATION = " " + + +class _OperatorExtractor: + def __init__(self, path): + index = clang.cindex.Index.create() + args = ("-std=c++17", "-x", "c++", "-I", "src") + self._translation_unit = index.parse(path, args=args) + + def __call__(self, op_name): + nodes = tuple(type(self)._find(self._translation_unit.cursor, op_name)) + + constructor = None + call = None + + for node in nodes: + if node.kind == CursorKind.CONSTRUCTOR: + constructor = node + elif node.kind == CursorKind.CXX_METHOD and node.spelling == "operator()": + call = node + + return _Operator(op_name, constructor, call) + + @staticmethod + def _find(node, op_name): + if node.semantic_parent and node.semantic_parent.spelling == "Operator": + for child in node.semantic_parent.get_children(): + if ( + child.kind == CursorKind.CXX_BASE_SPECIFIER + and child.spelling == op_name + ): + yield node + + for child in node.get_children(): + yield from _OperatorExtractor._find(child, op_name) + + +class _Operator: + def __init__(self, name, constructor, call): + self.name = name + + self.constructor = constructor + + self.call = call + + +def _generate_pybind11(operator): + def _generate_params(node): + return ", ".join( + f"{arg.type.spelling} {arg.spelling}" for arg in node.get_arguments() + ) + + def _generate_constructor_arguments(node): + return ", ".join( + _generate_tensor_caster(arg.spelling) + if "Tensor" in arg.type.spelling + else arg.spelling + for arg in node.get_arguments() + ) + + def _generate_call_arguments(node): + return ", ".join( + _generate_data_getter(arg.spelling) + if arg.spelling != "stream" + else "reinterpret_cast(stream)" + for arg in node.get_arguments() + ) + + def _generate_tensor_caster(name): + return f'Tensor{{reinterpret_cast({name}.attr("data_ptr")().cast()), {name}.attr("shape").cast(), DataType::FromString(py::str({name}.attr("dtype")).attr("split")(".").attr("__getitem__")(-1).cast()), 
Device{{Device::TypeFromString({name}.attr("device").attr("type").cast()), {name}.attr("device").attr("index").is_none() ? 0 : {name}.attr("device").attr("index").cast()}}, {name}.attr("stride")().cast()}}' + + def _generate_data_getter(name): + return ( + f'reinterpret_cast({name}.attr("data_ptr")().cast())' + ) + + op_name = operator.name + + constructor_params = _generate_params(operator.constructor) + constructor_params = constructor_params.replace( + "const Tensor", "py::object" + ).replace("Tensor", "py::object") + + call_params = _generate_params(operator.call) + call_params = ( + call_params.replace("void * stream", "std::uintptr_t stream") + .replace("const void *", "py::object") + .replace("void *", "py::object") + ) + + return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_ +#define INFINI_OPS_BINDINGS_{op_name.upper()}_H_ + +#include +#include + +#include "base/{op_name.lower()}.h" + +namespace py = pybind11; + +namespace infini::ops {{ + +void Bind{op_name}(py::module& m) {{ + using Self = {op_name}; + + py::class_(m, "{op_name}") + .def(py::init([]({constructor_params}) {{ + return std::unique_ptr{{static_cast(Self::make({_generate_constructor_arguments(operator.constructor)}).release())}}; + }})) + .def("__call__", [](const Self& self, {call_params}) {{ + return self.operator()({_generate_call_arguments(operator.call)}); + }}); +}} + +}} // namespace infini::ops + +#endif +""" + + +if __name__ == "__main__": + _BINDINGS_DIR.mkdir(parents=True, exist_ok=True) + + with open("ops.json") as f: + ops = json.load(f) + + header_paths = [] + bind_func_names = [] + + for op_name, op_path in ops.items(): + extractor = _OperatorExtractor(op_path) + operator = extractor(op_name) + + header_name = f"{op_name.lower()}.h" + bind_func_name = f"Bind{op_name}" + + (_BINDINGS_DIR / header_name).write_text(_generate_pybind11(operator)) + + header_paths.append(header_name) + bind_func_names.append(bind_func_name) + + impl_includes = "\n".join( + f'#include 
"{header_path}"' for header_path in ops.values() + ) + op_includes = "\n".join(f'#include "{header_path}"' for header_path in header_paths) + bind_func_calls = "\n".join( + f"{bind_func_name}(m);" for bind_func_name in bind_func_names + ) + + (_BINDINGS_DIR / "ops.cc").write_text(f"""#include + +// clang-format off +{impl_includes} +// clang-format on + +{op_includes} + +namespace infini::ops {{ + +PYBIND11_MODULE(ops, m) {{ +{textwrap.indent(bind_func_calls, _INDENTATION)} +}} + +}} // namespace infini::ops +""") From cb5ccbcf4dcc75499f6ee59a894db4e717c06b53 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 11 Feb 2026 16:53:01 +0800 Subject: [PATCH 16/93] fix: add `virtual ~Operator() = default;` --- src/operator.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/operator.h b/src/operator.h index b322fb6..a81aef4 100644 --- a/src/operator.h +++ b/src/operator.h @@ -11,6 +11,8 @@ namespace infini::ops { template class Operator { public: + virtual ~Operator() = default; + template static auto make(const Tensor tensor, Args&&... args) { std::unique_ptr op_ptr; From d5c106723f67413436eb621ee6dc93d03bec84d4 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 11 Feb 2026 17:51:32 +0800 Subject: [PATCH 17/93] refactor: Simplify `operator()` dispatching --- src/base/gemm.h | 3 +++ src/operator.h | 13 +------------ 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/src/base/gemm.h b/src/base/gemm.h index fc90752..ad67f16 100644 --- a/src/base/gemm.h +++ b/src/base/gemm.h @@ -28,6 +28,9 @@ class Gemm : public Operator { // TODO: Check constraints. 
} + virtual void operator()(void* stream, const void* a, const void* b, + void* c) const = 0; + protected: float alpha_{1.0}; diff --git a/src/operator.h b/src/operator.h index a81aef4..ca3098e 100644 --- a/src/operator.h +++ b/src/operator.h @@ -27,24 +27,13 @@ class Operator { "constructor dispatching not implemented for this device"); } - op_ptr->device_ = tensor.device(); - return op_ptr; } template auto operator()(Args&&... args) const { - switch (device_.type()) { - case Device::Type::kNvidia: - return (*static_cast*>( - this))(std::forward(args)...); - } - - assert(false && "`operator()` dispatching not implemented for this device"); + return (*static_cast(this))(std::forward(args)...); } - - private: - Device device_; }; } // namespace infini::ops From 45c3e62bd6add3150ef92e3517d3d222e586ee82 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 11 Feb 2026 18:25:03 +0800 Subject: [PATCH 18/93] feat: add naive support for single-stage interfaces --- examples/nvidia/gemm/cublas.cc | 7 +-- scripts/generate_wrappers.py | 92 +++++++++++++++++++--------------- src/base/gemm.h | 12 ++++- src/nvidia/gemm/cublas.h | 23 +++++++-- src/operator.h | 7 +++ 5 files changed, 87 insertions(+), 54 deletions(-) diff --git a/examples/nvidia/gemm/cublas.cc b/examples/nvidia/gemm/cublas.cc index 8b40bf0..7a7723a 100644 --- a/examples/nvidia/gemm/cublas.cc +++ b/examples/nvidia/gemm/cublas.cc @@ -55,12 +55,7 @@ int main() { cudaMemcpy(b_device.data(), b_vec.data(), b_size, cudaMemcpyHostToDevice); cudaMemset(c_device.data(), 0, c_size); - auto gemm_ptr{Operator::make(a_device, b_device, std::nullopt, - std::nullopt, std::nullopt, std::nullopt, - c_device)}; - const auto& gemm{*gemm_ptr}; - - gemm(nullptr, a_device.data(), b_device.data(), c_device.data()); + Gemm::call(nullptr, a_device, b_device, c_device); cudaMemcpy(a_host.data(), a_device.data(), a_size, cudaMemcpyDeviceToHost); cudaMemcpy(b_host.data(), b_device.data(), b_size, cudaMemcpyDeviceToHost); diff --git 
a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 2d945f2..73aff3e 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -13,46 +13,40 @@ class _OperatorExtractor: - def __init__(self, path): + def __call__(self, op_name): index = clang.cindex.Index.create() args = ("-std=c++17", "-x", "c++", "-I", "src") - self._translation_unit = index.parse(path, args=args) + translation_unit = index.parse(f"src/base/{op_name.lower()}.h", args=args) - def __call__(self, op_name): - nodes = tuple(type(self)._find(self._translation_unit.cursor, op_name)) + nodes = tuple(type(self)._find(translation_unit.cursor, op_name)) - constructor = None - call = None + constructors = [] + calls = [] for node in nodes: if node.kind == CursorKind.CONSTRUCTOR: - constructor = node + constructors.append(node) elif node.kind == CursorKind.CXX_METHOD and node.spelling == "operator()": - call = node + calls.append(node) - return _Operator(op_name, constructor, call) + return _Operator(op_name, constructors, calls) @staticmethod def _find(node, op_name): - if node.semantic_parent and node.semantic_parent.spelling == "Operator": - for child in node.semantic_parent.get_children(): - if ( - child.kind == CursorKind.CXX_BASE_SPECIFIER - and child.spelling == op_name - ): - yield node + if node.semantic_parent and node.semantic_parent.spelling == op_name: + yield node for child in node.get_children(): yield from _OperatorExtractor._find(child, op_name) class _Operator: - def __init__(self, name, constructor, call): + def __init__(self, name, constructors, calls): self.name = name - self.constructor = constructor + self.constructors = constructors - self.call = call + self.calls = calls def _generate_pybind11(operator): @@ -71,7 +65,11 @@ def _generate_constructor_arguments(node): def _generate_call_arguments(node): return ", ".join( - _generate_data_getter(arg.spelling) + ( + _generate_tensor_caster(arg.spelling) + if "Tensor" in arg.type.spelling + else 
arg.spelling + ) if arg.spelling != "stream" else "reinterpret_cast(stream)" for arg in node.get_arguments() @@ -80,24 +78,38 @@ def _generate_call_arguments(node): def _generate_tensor_caster(name): return f'Tensor{{reinterpret_cast({name}.attr("data_ptr")().cast()), {name}.attr("shape").cast(), DataType::FromString(py::str({name}.attr("dtype")).attr("split")(".").attr("__getitem__")(-1).cast()), Device{{Device::TypeFromString({name}.attr("device").attr("type").cast()), {name}.attr("device").attr("index").is_none() ? 0 : {name}.attr("device").attr("index").cast()}}, {name}.attr("stride")().cast()}}' - def _generate_data_getter(name): - return ( - f'reinterpret_cast({name}.attr("data_ptr")().cast())' + op_name = operator.name + + def _generate_init(constructor): + constructor_params = _generate_params(constructor) + constructor_params = constructor_params.replace( + "const Tensor", "py::object" + ).replace("Tensor", "py::object") + + return f""" .def(py::init([]({constructor_params}) {{ + return std::unique_ptr{{static_cast(Self::make({_generate_constructor_arguments(constructor)}).release())}}; + }}))""" + + def _generate_call(call, method=True): + call_params = _generate_params(call) + call_params = ( + call_params.replace("void * stream", "std::uintptr_t stream") + .replace("const Tensor", "py::object") + .replace("Tensor", "py::object") ) - op_name = operator.name + if not method: + return f""" m.def("gemm", []({call_params}) {{ return Self::call({_generate_call_arguments(call)}); }});""" - constructor_params = _generate_params(operator.constructor) - constructor_params = constructor_params.replace( - "const Tensor", "py::object" - ).replace("Tensor", "py::object") + return f""" .def("__call__", [](const Self& self, {call_params}) {{ + return self({_generate_call_arguments(call)}); + }})""" - call_params = _generate_params(operator.call) - call_params = ( - call_params.replace("void * stream", "std::uintptr_t stream") - .replace("const void *", "py::object") - 
.replace("void *", "py::object") + inits = "\n".join( + _generate_init(constructor) for constructor in operator.constructors ) + calls = "\n".join(_generate_call(call) for call in operator.calls) + callers = "\n".join(_generate_call(call, method=False) for call in operator.calls) return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_ #define INFINI_OPS_BINDINGS_{op_name.upper()}_H_ @@ -115,12 +127,10 @@ def _generate_data_getter(name): using Self = {op_name}; py::class_(m, "{op_name}") - .def(py::init([]({constructor_params}) {{ - return std::unique_ptr{{static_cast(Self::make({_generate_constructor_arguments(operator.constructor)}).release())}}; - }})) - .def("__call__", [](const Self& self, {call_params}) {{ - return self.operator()({_generate_call_arguments(operator.call)}); - }}); +{inits} +{calls}; + +{callers} }} }} // namespace infini::ops @@ -138,8 +148,8 @@ def _generate_data_getter(name): header_paths = [] bind_func_names = [] - for op_name, op_path in ops.items(): - extractor = _OperatorExtractor(op_path) + for op_name in ops: + extractor = _OperatorExtractor() operator = extractor(op_name) header_name = f"{op_name.lower()}.h" diff --git a/src/base/gemm.h b/src/base/gemm.h index ad67f16..7602f65 100644 --- a/src/base/gemm.h +++ b/src/base/gemm.h @@ -28,8 +28,16 @@ class Gemm : public Operator { // TODO: Check constraints. 
} - virtual void operator()(void* stream, const void* a, const void* b, - void* c) const = 0; + virtual void operator()(void* stream, const Tensor a, const Tensor b, + std::optional alpha, std::optional beta, + std::optional trans_a, + std::optional trans_b, Tensor c) const = 0; + + virtual void operator()(void* stream, const Tensor a, const Tensor b, + Tensor c) const { + return operator()(stream, a, b, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, c); + } protected: float alpha_{1.0}; diff --git a/src/nvidia/gemm/cublas.h b/src/nvidia/gemm/cublas.h index 6949afe..aa77772 100644 --- a/src/nvidia/gemm/cublas.h +++ b/src/nvidia/gemm/cublas.h @@ -30,20 +30,33 @@ class Operator : public Gemm { // TODO: Check constraints. } - void operator()(void* stream, const void* a, const void* b, void* c) const { + Operator(const Tensor a, const Tensor b, Tensor c) + : Operator{a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + c} {} + + void operator()(void* stream, const Tensor a, const Tensor b, + std::optional alpha, std::optional beta, + std::optional trans_a, std::optional trans_b, + Tensor c) const override { cublasHandle_t handle; cublasCreate(&handle); cublasSetStream(handle, static_cast(stream)); + const auto& alpha_value{alpha.value_or(alpha_)}; + const auto& beta_value{beta.value_or(beta_)}; + const auto& trans_a_value{alpha.value_or(trans_a_)}; + const auto& trans_b_value{beta.value_or(trans_b_)}; + // TODO: Add support for more data types. assert(a_type_ == kFloat32 && b_type_ == kFloat32 && c_type_ == kFloat32 && "`operator()` not implemented for this data type"); - cublasGemmEx(handle, trans_a_ ? CUBLAS_OP_T : CUBLAS_OP_N, - trans_b_ ? CUBLAS_OP_T : CUBLAS_OP_N, m_, n_, k_, &alpha_, b, - CUDA_R_32F, lda_, a, CUDA_R_32F, ldb_, &beta_, c, CUDA_R_32F, - ldc_, CUBLAS_COMPUTE_32F_FAST_TF32, CUBLAS_GEMM_DEFAULT); + cublasGemmEx(handle, trans_a_value ? CUBLAS_OP_T : CUBLAS_OP_N, + trans_b_value ? 
CUBLAS_OP_T : CUBLAS_OP_N, m_, n_, k_, + &alpha_value, b.data(), CUDA_R_32F, lda_, a.data(), CUDA_R_32F, + ldb_, &beta_value, c.data(), CUDA_R_32F, ldc_, + CUBLAS_COMPUTE_32F_FAST_TF32, CUBLAS_GEMM_DEFAULT); cublasDestroy(handle); } diff --git a/src/operator.h b/src/operator.h index ca3098e..29ee03e 100644 --- a/src/operator.h +++ b/src/operator.h @@ -30,6 +30,13 @@ class Operator { return op_ptr; } + template + static auto call(void* stream, Args&&... args) { + // TODO: Cache the created `Operator`. + return (*make(std::forward(args)...))(stream, + std::forward(args)...); + } + template auto operator()(Args&&... args) const { return (*static_cast(this))(std::forward(args)...); From dbe3f4c4733c5ee02964bfb1f9e0ae68623f50cc Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 11 Feb 2026 23:15:36 +0800 Subject: [PATCH 19/93] feat: add stream handling --- scripts/generate_wrappers.py | 38 ++++++++++++------------------------ src/operator.h | 29 ++++++++++++++++++++++----- 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 73aff3e..9fe5cc9 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -51,28 +51,23 @@ def __init__(self, name, constructors, calls): def _generate_pybind11(operator): def _generate_params(node): - return ", ".join( - f"{arg.type.spelling} {arg.spelling}" for arg in node.get_arguments() + return ( + ", ".join( + f"{arg.type.spelling} {arg.spelling}" + for arg in node.get_arguments() + if arg.spelling != "stream" + ) + .replace("const Tensor", "py::object") + .replace("Tensor", "py::object") ) - def _generate_constructor_arguments(node): + def _generate_arguments(node): return ", ".join( _generate_tensor_caster(arg.spelling) if "Tensor" in arg.type.spelling else arg.spelling for arg in node.get_arguments() - ) - - def _generate_call_arguments(node): - return ", ".join( - ( - _generate_tensor_caster(arg.spelling) - if "Tensor" in 
arg.type.spelling - else arg.spelling - ) if arg.spelling != "stream" - else "reinterpret_cast(stream)" - for arg in node.get_arguments() ) def _generate_tensor_caster(name): @@ -82,27 +77,19 @@ def _generate_tensor_caster(name): def _generate_init(constructor): constructor_params = _generate_params(constructor) - constructor_params = constructor_params.replace( - "const Tensor", "py::object" - ).replace("Tensor", "py::object") return f""" .def(py::init([]({constructor_params}) {{ - return std::unique_ptr{{static_cast(Self::make({_generate_constructor_arguments(constructor)}).release())}}; + return std::unique_ptr{{static_cast(Self::make({_generate_arguments(constructor)}).release())}}; }}))""" def _generate_call(call, method=True): call_params = _generate_params(call) - call_params = ( - call_params.replace("void * stream", "std::uintptr_t stream") - .replace("const Tensor", "py::object") - .replace("Tensor", "py::object") - ) if not method: - return f""" m.def("gemm", []({call_params}) {{ return Self::call({_generate_call_arguments(call)}); }});""" + return f""" m.def("gemm", []({call_params}) {{ return Self::call({_generate_arguments(call)}); }});""" return f""" .def("__call__", [](const Self& self, {call_params}) {{ - return self({_generate_call_arguments(call)}); + return static_cast&>(self)({_generate_arguments(call)}); }})""" inits = "\n".join( @@ -179,6 +166,7 @@ def _generate_call(call, method=True): namespace infini::ops {{ PYBIND11_MODULE(ops, m) {{ +{_INDENTATION}m.def("set_stream", [](std::uintptr_t stream) {{ OperatorBase::set_stream(reinterpret_cast(stream)); }}); {textwrap.indent(bind_func_calls, _INDENTATION)} }} diff --git a/src/operator.h b/src/operator.h index 29ee03e..5c68fb2 100644 --- a/src/operator.h +++ b/src/operator.h @@ -8,11 +8,19 @@ namespace infini::ops { -template -class Operator { +class OperatorBase { public: - virtual ~Operator() = default; + virtual ~OperatorBase() = default; + + static void set_stream(void* stream) { stream_ = 
stream; } + + protected: + inline static thread_local void* stream_{nullptr}; +}; +template +class Operator : public OperatorBase { + public: template static auto make(const Tensor tensor, Args&&... args) { std::unique_ptr op_ptr; @@ -38,8 +46,19 @@ class Operator { } template - auto operator()(Args&&... args) const { - return (*static_cast(this))(std::forward(args)...); + static auto call(const Tensor tensor, Args&&... args) { + return call(stream_, tensor, std::forward(args)...); + } + + template + auto operator()(void* stream, Args&&... args) const { + return (*static_cast(this))(stream, + std::forward(args)...); + } + + template + auto operator()(const Tensor tensor, Args&&... args) const { + return operator()(stream_, tensor, std::forward(args)...); } }; From ba533cda4743917c914e4fe53bf66360623509c4 Mon Sep 17 00:00:00 2001 From: Ziminli Date: Tue, 10 Feb 2026 07:25:36 +0000 Subject: [PATCH 20/93] feat: extend Device enum and refactor GEMM support - Add additional entries to the Device enum class to support new hardware targets. - Adapt GEMM mcblas implementation to use MetaX backend and add the test example. - Extract common BLAS interfaces into a new blas.h abstraction for GEMM implementations to share. 
--- examples/metax/gemm/mcblas.cc | 73 +++++++++++++++++++++++++++++++++++ src/cuda/gemm/blas.h | 65 +++++++++++++++++++++++++++++++ src/device.h | 14 ++++++- src/metax/gemm/mcblas.h | 45 +++++++++++++++++++++ src/nvidia/gemm/cublas.h | 71 ++++++++++------------------------ src/operator.h | 9 +++++ 6 files changed, 226 insertions(+), 51 deletions(-) create mode 100644 examples/metax/gemm/mcblas.cc create mode 100644 src/cuda/gemm/blas.h create mode 100644 src/metax/gemm/mcblas.h diff --git a/examples/metax/gemm/mcblas.cc b/examples/metax/gemm/mcblas.cc new file mode 100644 index 0000000..7348fca --- /dev/null +++ b/examples/metax/gemm/mcblas.cc @@ -0,0 +1,73 @@ +#include "metax/gemm/mcblas.h" + +#include + +#include +#include + +#include "tensor.h" + +int main() { + using namespace infini::ops; + + constexpr auto m{2}; + constexpr auto k{3}; + constexpr auto n{4}; + + std::vector a_shape{m, k}; + std::vector b_shape{k, n}; + std::vector c_shape{m, n}; + + const auto a_num_elements{std::accumulate(a_shape.cbegin(), a_shape.cend(), 1, + std::multiplies())}; + const auto b_num_elements{std::accumulate(b_shape.cbegin(), b_shape.cend(), 1, + std::multiplies())}; + const auto c_num_elements{std::accumulate(c_shape.cbegin(), c_shape.cend(), 1, + std::multiplies())}; + + std::vector a_vec(a_num_elements); + std::vector b_vec(b_num_elements); + std::vector c_vec(c_num_elements); + + std::iota(a_vec.begin(), a_vec.end(), 0); + std::iota(b_vec.begin(), b_vec.end(), 0); + + Tensor a_host{a_vec.data(), a_shape, Device{Device::Type::kMetax}}; + Tensor b_host{b_vec.data(), b_shape, Device{Device::Type::kMetax}}; + Tensor c_host{c_vec.data(), c_shape, Device{Device::Type::kMetax}}; + + Tensor a_device{nullptr, a_host.shape(), a_host.dtype(), a_host.device(), + a_host.strides()}; + Tensor b_device{nullptr, b_host.shape(), b_host.dtype(), a_host.device(), + b_host.strides()}; + Tensor c_device{nullptr, c_host.shape(), c_host.dtype(), a_host.device(), + c_host.strides()}; + + const 
auto a_size{a_num_elements * a_device.dtype().element_size()}; + const auto b_size{b_num_elements * b_device.dtype().element_size()}; + const auto c_size{c_num_elements * c_device.dtype().element_size()}; + + mcMalloc(&a_device.data(), a_size); + mcMalloc(&b_device.data(), b_size); + mcMalloc(&c_device.data(), c_size); + + mcMemcpy(a_device.data(), a_vec.data(), a_size, mcMemcpyHostToDevice); + mcMemcpy(b_device.data(), b_vec.data(), b_size, mcMemcpyHostToDevice); + mcMemset(c_device.data(), 0, c_size); + + Gemm::call(nullptr, a_device, b_device, c_device); + + mcMemcpy(a_host.data(), a_device.data(), a_size, mcMemcpyDeviceToHost); + mcMemcpy(b_host.data(), b_device.data(), b_size, mcMemcpyDeviceToHost); + mcMemcpy(c_host.data(), c_device.data(), c_size, mcMemcpyDeviceToHost); + + mcFree(a_device.data()); + mcFree(b_device.data()); + mcFree(c_device.data()); + + std::cout << "A: " << a_host.ToString() << "\n"; + std::cout << "B: " << b_host.ToString() << "\n"; + std::cout << "C: " << c_host.ToString() << "\n"; + + return 0; +} diff --git a/src/cuda/gemm/blas.h b/src/cuda/gemm/blas.h new file mode 100644 index 0000000..a26588a --- /dev/null +++ b/src/cuda/gemm/blas.h @@ -0,0 +1,65 @@ +#ifndef INFINI_OPS_CUDA_GEMM_BLAS_H_ +#define INFINI_OPS_CUDA_GEMM_BLAS_H_ + +#include + +#include "base/gemm.h" + +namespace infini::ops { + +template +class Blas : public Gemm { + public: + Blas(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) + : Gemm{a.stride(0) == 1 ? a : b.T(), + a.stride(0) == 1 ? b : a.T(), + alpha, + beta, + trans_a, + trans_b, + a.stride(0) == 1 ? c : c.T()}, + lda_{a_strides_[1]}, + ldb_{b_strides_[1]}, + ldc_{c_strides_[1]} { + // TODO: Check constraints. 
+ } + + Blas(const Tensor a, const Tensor b, Tensor c) + : Blas{a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt, c} {} + + void operator()(void* stream, const Tensor a, const Tensor b, + std::optional alpha, std::optional beta, + std::optional trans_a, std::optional trans_b, + Tensor c) const override { + typename Backend::blasHandle_t handle; + Backend::blasCreate(&handle); + + Backend::blasSetStream(handle, + static_cast(stream)); + + const auto& alpha_value{alpha.value_or(alpha_)}; + const auto& beta_value{beta.value_or(beta_)}; + const auto& trans_a_value{alpha.value_or(trans_a_)}; + const auto& trans_b_value{beta.value_or(trans_b_)}; + + assert(a_type_ == kFloat32 && b_type_ == kFloat32 && c_type_ == kFloat32 && + "`operator()` not implemented for this data type"); + + Backend::blasGemmEx(handle, trans_a_value, trans_b_value, m_, n_, k_, + &alpha_value, b.data(), lda_, a.data(), ldb_, + &beta_value, c.data(), ldc_); + + Backend::blasDestroy(handle); + } + + private: + Tensor::Stride lda_{0}; + Tensor::Stride ldb_{0}; + Tensor::Stride ldc_{0}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/device.h b/src/device.h index 0eb4db1..22e3bc0 100644 --- a/src/device.h +++ b/src/device.h @@ -6,7 +6,19 @@ namespace infini::ops { class Device { public: // TODO: Complete the list. 
- enum class Type { kCpu, kNvidia, kCount }; + enum class Type { + kCpu = 0, + kNvidia = 1, + kCambricon = 2, + kAscend = 3, + kMetax = 4, + kMoore = 5, + kIluvatar = 6, + kKunlun = 7, + kHygon = 8, + kQy = 9, + kCount + }; Device() = default; diff --git a/src/metax/gemm/mcblas.h b/src/metax/gemm/mcblas.h new file mode 100644 index 0000000..cc1f081 --- /dev/null +++ b/src/metax/gemm/mcblas.h @@ -0,0 +1,45 @@ +#ifndef INFINI_OPS_METAX_GEMM_MCBLAS_H_ +#define INFINI_OPS_METAX_GEMM_MCBLAS_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/gemm/blas.h" + +namespace infini::ops { + +struct MetaxBackend { + using blasHandle_t = mcblasHandle_t; + using stream_t = mcStream_t; + + static void blasCreate(blasHandle_t* handle) { mcblasCreate(handle); } + + static void blasSetStream(blasHandle_t handle, stream_t stream) { + mcblasSetStream(handle, stream); + } + + static void blasGemmEx(blasHandle_t handle, bool transA, bool transB, int m, + int n, int k, const float* alpha, const void* B, + int ldb, const void* A, int lda, const float* beta, + void* C, int ldc) { + mcblasGemmEx(handle, transA ? MCBLAS_OP_T : MCBLAS_OP_N, + transB ? 
MCBLAS_OP_T : MCBLAS_OP_N, m, n, k, alpha, B, + MACA_R_32F, ldb, A, MACA_R_32F, lda, beta, C, MACA_R_32F, ldc, + MCBLAS_COMPUTE_32F_FAST_TF32, MCBLAS_GEMM_DEFAULT); + } + + static void blasDestroy(blasHandle_t handle) { mcblasDestroy(handle); } +}; + +template <> +class Operator : public Blas { + public: + using Blas::Blas; +}; + +} // namespace infini::ops + +#endif diff --git a/src/nvidia/gemm/cublas.h b/src/nvidia/gemm/cublas.h index aa77772..9f7b8b6 100644 --- a/src/nvidia/gemm/cublas.h +++ b/src/nvidia/gemm/cublas.h @@ -7,66 +7,37 @@ #include "cublas_v2.h" // clang-format on -#include "base/gemm.h" +#include "cuda/gemm/blas.h" namespace infini::ops { -template <> -class Operator : public Gemm { - public: - Operator(const Tensor a, const Tensor b, std::optional alpha, - std::optional beta, std::optional trans_a, - std::optional trans_b, Tensor c) - : Gemm{a.stride(0) == 1 ? a : b.T(), - a.stride(0) == 1 ? b : a.T(), - alpha, - beta, - trans_a, - trans_b, - a.stride(0) == 1 ? c : c.T()}, - lda_{a_strides_[1]}, - ldb_{b_strides_[1]}, - ldc_{c_strides_[1]} { - // TODO: Check constraints. - } - - Operator(const Tensor a, const Tensor b, Tensor c) - : Operator{a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - c} {} - - void operator()(void* stream, const Tensor a, const Tensor b, - std::optional alpha, std::optional beta, - std::optional trans_a, std::optional trans_b, - Tensor c) const override { - cublasHandle_t handle; - cublasCreate(&handle); - - cublasSetStream(handle, static_cast(stream)); +struct CudaBackend { + using blasHandle_t = cublasHandle_t; + using stream_t = cudaStream_t; - const auto& alpha_value{alpha.value_or(alpha_)}; - const auto& beta_value{beta.value_or(beta_)}; - const auto& trans_a_value{alpha.value_or(trans_a_)}; - const auto& trans_b_value{beta.value_or(trans_b_)}; + static void blasCreate(blasHandle_t* handle) { cublasCreate(handle); } - // TODO: Add support for more data types. 
- assert(a_type_ == kFloat32 && b_type_ == kFloat32 && c_type_ == kFloat32 && - "`operator()` not implemented for this data type"); + static void blasSetStream(blasHandle_t handle, stream_t stream) { + cublasSetStream(handle, stream); + } - cublasGemmEx(handle, trans_a_value ? CUBLAS_OP_T : CUBLAS_OP_N, - trans_b_value ? CUBLAS_OP_T : CUBLAS_OP_N, m_, n_, k_, - &alpha_value, b.data(), CUDA_R_32F, lda_, a.data(), CUDA_R_32F, - ldb_, &beta_value, c.data(), CUDA_R_32F, ldc_, + static void blasGemmEx(blasHandle_t handle, bool transA, bool transB, int m, + int n, int k, const float* alpha, const void* B, + int ldb, const void* A, int lda, const float* beta, + void* C, int ldc) { + cublasGemmEx(handle, transA ? CUBLAS_OP_T : CUBLAS_OP_N, + transB ? CUBLAS_OP_T : CUBLAS_OP_N, m, n, k, alpha, B, + CUDA_R_32F, ldb, A, CUDA_R_32F, lda, beta, C, CUDA_R_32F, ldc, CUBLAS_COMPUTE_32F_FAST_TF32, CUBLAS_GEMM_DEFAULT); - - cublasDestroy(handle); } - private: - Tensor::Stride lda_{0}; - - Tensor::Stride ldb_{0}; + static void blasDestroy(blasHandle_t handle) { cublasDestroy(handle); } +}; - Tensor::Stride ldc_{0}; +template <> +class Operator : public Blas { + public: + using Blas::Blas; }; } // namespace infini::ops diff --git a/src/operator.h b/src/operator.h index 5c68fb2..45e4f15 100644 --- a/src/operator.h +++ b/src/operator.h @@ -26,10 +26,19 @@ class Operator : public OperatorBase { std::unique_ptr op_ptr; switch (tensor.device().type()) { + // TODO(lzm): use dispatcher to conditionally compile and dispatch + // the devices. 
This is only a temporary solution +#ifdef USE_CUDA case Device::Type::kNvidia: op_ptr = std::make_unique>( tensor, std::forward(args)...); break; +#elif USE_MACA + case Device::Type::kMetax: + op_ptr = std::make_unique>( + tensor, std::forward(args)...); + break; +#endif default: assert(false && "constructor dispatching not implemented for this device"); From 8da1d18fa0e44289eaf732d84958362fe6fd919b Mon Sep 17 00:00:00 2001 From: Ziminli Date: Wed, 11 Feb 2026 12:12:58 +0000 Subject: [PATCH 21/93] feat: add generic dispatcher, compile-time traits/constructs and CPU GEMM implementation - Add `ConstexprMap` and compile-time traits in `common/` for efficient type-to-metadata mapping and relevant operations. - Implement a generic dispatcher to reduce boilerplate for dispatching, especially for data types and devices. - Add the CPU implementation for the GEMM - Update `DataType` definitions and type lists to support wide dispatching. Follow-up: support for fp16 and bf16 kernels is pending. --- examples/metax/gemm/mcblas.cc | 10 +- examples/nvidia/gemm/cublas.cc | 10 +- src/base/gemm.h | 6 +- src/common/constexpr_map.h | 28 ++++++ src/common/traits.h | 108 ++++++++++++++++++++ src/cpu/gemm/gemm.h | 64 ++++++++++++ src/cuda/gemm/blas.h | 7 +- src/data_type.h | 158 +++++++++++++++++++----------- src/device.h | 117 ++++++++++++++++++++-- src/dispatcher.h | 174 +++++++++++++++++++++++++++++++++ src/operator.h | 26 ++--- src/tensor.cc | 22 ++--- src/tensor.h | 8 +- 13 files changed, 628 insertions(+), 110 deletions(-) create mode 100644 src/common/constexpr_map.h create mode 100644 src/common/traits.h create mode 100644 src/cpu/gemm/gemm.h create mode 100644 src/dispatcher.h diff --git a/examples/metax/gemm/mcblas.cc b/examples/metax/gemm/mcblas.cc index 7348fca..af60f4d 100644 --- a/examples/metax/gemm/mcblas.cc +++ b/examples/metax/gemm/mcblas.cc @@ -5,6 +5,10 @@ #include #include +#ifdef USE_CPU +#include "cpu/gemm/gemm.h" +#endif + #include "tensor.h" int main() { @@ 
-43,9 +47,9 @@ int main() { Tensor c_device{nullptr, c_host.shape(), c_host.dtype(), a_host.device(), c_host.strides()}; - const auto a_size{a_num_elements * a_device.dtype().element_size()}; - const auto b_size{b_num_elements * b_device.dtype().element_size()}; - const auto c_size{c_num_elements * c_device.dtype().element_size()}; + const auto a_size{a_num_elements * kDataTypeToSize.at(a_device.dtype())}; + const auto b_size{b_num_elements * kDataTypeToSize.at(b_device.dtype())}; + const auto c_size{c_num_elements * kDataTypeToSize.at(c_device.dtype())}; mcMalloc(&a_device.data(), a_size); mcMalloc(&b_device.data(), b_size); diff --git a/examples/nvidia/gemm/cublas.cc b/examples/nvidia/gemm/cublas.cc index 7a7723a..7885d3a 100644 --- a/examples/nvidia/gemm/cublas.cc +++ b/examples/nvidia/gemm/cublas.cc @@ -5,6 +5,10 @@ #include #include +#ifdef USE_CPU +#include "cpu/gemm/gemm.h" +#endif + #include "tensor.h" int main() { @@ -43,9 +47,9 @@ int main() { Tensor c_device{nullptr, c_host.shape(), c_host.dtype(), a_host.device(), c_host.strides()}; - const auto a_size{a_num_elements * a_device.dtype().element_size()}; - const auto b_size{b_num_elements * b_device.dtype().element_size()}; - const auto c_size{c_num_elements * c_device.dtype().element_size()}; + const auto a_size{a_num_elements * kDataTypeToSize.at(a_device.dtype())}; + const auto b_size{b_num_elements * kDataTypeToSize.at(b_device.dtype())}; + const auto c_size{c_num_elements * kDataTypeToSize.at(c_device.dtype())}; cudaMalloc(&a_device.data(), a_size); cudaMalloc(&b_device.data(), b_size); diff --git a/src/base/gemm.h b/src/base/gemm.h index 7602f65..ab96c92 100644 --- a/src/base/gemm.h +++ b/src/base/gemm.h @@ -54,11 +54,11 @@ class Gemm : public Operator { Tensor::Size k_{0}; - const DataType& a_type_; + const DataType a_type_; - const DataType& b_type_; + const DataType b_type_; - const DataType& c_type_; + const DataType c_type_; Tensor::Strides a_strides_; diff --git a/src/common/constexpr_map.h 
b/src/common/constexpr_map.h new file mode 100644 index 0000000..011d974 --- /dev/null +++ b/src/common/constexpr_map.h @@ -0,0 +1,28 @@ +#ifndef INFINI_OPS_COMMON_CONSTEXPR_MAP_H_ +#define INFINI_OPS_COMMON_CONSTEXPR_MAP_H_ + +#include +#include +#include +#include + +namespace infini::ops { + +template +struct ConstexprMap { + std::array, N> data; + + constexpr Value at(Key key) const { + for (const auto &pr : data) { + if (pr.first == key) return pr.second; + } + // TODO(lzm): change to logging + std::cerr << "ConstexprMap's key not found at " << __FILE__ << ":" + << __LINE__ << std::endl; + std::abort(); + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/common/traits.h b/src/common/traits.h new file mode 100644 index 0000000..83de106 --- /dev/null +++ b/src/common/traits.h @@ -0,0 +1,108 @@ +#ifndef INFINI_OPS_COMMON_TRAITS_H_ +#define INFINI_OPS_COMMON_TRAITS_H_ + +#include +#include + +namespace infini::ops { + +// A generic container for a sequence of compile-time values. +template +struct List {}; + +// ----------------------------------------------------------------------------- +// List Queries +// ----------------------------------------------------------------------------- + +// Check at compile-time if a Value exists within a construct (e.g., List<>). +// Example: static_assert(Contains_v); +template +struct Contains; + +template +struct Contains, Value> + : std::disjunction...> {}; + +template +inline constexpr bool Contains_v = Contains::value; + +// Check at compile-time if a type T is present in a variadic list of types Ts. +// Example: static_assert(IsTypeInList); +template +inline constexpr bool IsTypeInList = (std::is_same_v || ...); + +// ----------------------------------------------------------------------------- +// List Operations +// ----------------------------------------------------------------------------- + +// Concatenates two List types into a single List. +// Example: Concat_t, List<3, 4>> is List<1, 2, 3, 4>. 
+template +struct Concat; + +template +struct Concat, List> { + using type = List; +}; + +template +using Concat_t = typename Concat::type; + +// ----------------------------------------------------------------------------- +// Invocability Detection (SFINAE) +// ----------------------------------------------------------------------------- + +// Checks if a Functor's template operator() can be called with Args. +template +struct IsInvocable : std::false_type {}; + +template +struct IsInvocable< + Functor, Value, + std::void_t().template operator()( + std::declval()...))>, + Args...> : std::true_type {}; + +template +inline constexpr bool IsInvocable_v = + IsInvocable::value; + +// ----------------------------------------------------------------------------- +// Filtering Logic +// ----------------------------------------------------------------------------- + +// Recursive template to filter values based on Functor support at compile-time. +template +struct Filter; + +// Base case: All values processed. +template +struct Filter, List> { + using type = List; +}; + +// Recursive step: Test the 'Head' value and accumulate if supported. +template +struct Filter, List, Head, Tail...> { + using type = typename std::conditional_t< + IsInvocable_v && + !Contains_v, Head>, + Filter, List, Tail...>, + Filter, List, Tail...>>::type; +}; + +// Interface to filter a List type directly. 
+template +struct FilterList; + +template +struct FilterList, List> { + using type = + typename Filter, List<>, Items...>::type; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/gemm/gemm.h b/src/cpu/gemm/gemm.h new file mode 100644 index 0000000..fbec039 --- /dev/null +++ b/src/cpu/gemm/gemm.h @@ -0,0 +1,64 @@ +#ifndef INFINI_OPS_CPU_GEMM_H_ +#define INFINI_OPS_CPU_GEMM_H_ + +#include + +#include "base/gemm.h" + +namespace infini::ops { + +template <> +class Operator : public Gemm { + public: + Operator(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) + : Gemm{a, b, alpha, beta, trans_a, trans_b, c}, + lda_{a_strides_[1]}, + ldb_{b_strides_[1]}, + ldc_{c_strides_[1]} { + // TODO: Check constraints. + } + + Operator(const Tensor a, const Tensor b, Tensor c) + : Operator{a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + c} {} + + void operator()(void* stream, const Tensor a, const Tensor b, + std::optional alpha, std::optional beta, + std::optional trans_a, std::optional trans_b, + Tensor c) const override { + const float* A = static_cast(a.data()); + const float* B = static_cast(b.data()); + float* C = static_cast(c.data()); + + const auto& alpha_value{alpha.value_or(alpha_)}; + const auto& beta_value{beta.value_or(beta_)}; + const auto& trans_a_value{trans_a.value_or(trans_a_)}; + const auto& trans_b_value{trans_b.value_or(trans_b_)}; + + for (Tensor::Size i = 0; i < m_; ++i) { + for (Tensor::Size j = 0; j < n_; ++j) { + float sum = 0.0f; + + for (Tensor::Size l = 0; l < k_; ++l) { + float a_val = trans_a_value ? A[l * m_ + i] : A[i * k_ + l]; + float b_val = trans_b_value ? 
B[j * k_ + l] : B[l * n_ + j]; + sum += a_val * b_val; + } + + Tensor::Size idx = i * n_ + j; + C[idx] = alpha_value * sum + beta_value * C[idx]; + } + } + } + + private: + Tensor::Stride lda_{0}; + Tensor::Stride ldb_{0}; + Tensor::Stride ldc_{0}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/gemm/blas.h b/src/cuda/gemm/blas.h index a26588a..c3c89a4 100644 --- a/src/cuda/gemm/blas.h +++ b/src/cuda/gemm/blas.h @@ -41,10 +41,11 @@ class Blas : public Gemm { const auto& alpha_value{alpha.value_or(alpha_)}; const auto& beta_value{beta.value_or(beta_)}; - const auto& trans_a_value{alpha.value_or(trans_a_)}; - const auto& trans_b_value{beta.value_or(trans_b_)}; + const auto& trans_a_value{trans_a.value_or(trans_a_)}; + const auto& trans_b_value{trans_b.value_or(trans_b_)}; - assert(a_type_ == kFloat32 && b_type_ == kFloat32 && c_type_ == kFloat32 && + assert(a_type_ == DataType::kFloat32 && b_type_ == DataType::kFloat32 && + c_type_ == DataType::kFloat32 && "`operator()` not implemented for this data type"); Backend::blasGemmEx(handle, trans_a_value, trans_b_value, m_, n_, k_, diff --git a/src/data_type.h b/src/data_type.h index 7cf8045..0a2a248 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -1,72 +1,112 @@ #ifndef INFINI_OPS_DATA_TYPE_H_ #define INFINI_OPS_DATA_TYPE_H_ -#include -#include #include #include -#include -namespace infini::ops { - -class DataType { - public: - constexpr DataType(int index, std::size_t element_size, const char* name) - : index_{index}, element_size_{element_size}, name_{name} {} - - static const DataType& FromString(const std::string& name); - - constexpr bool operator==(const DataType& other) const { - return index_ == other.index_; - } - - constexpr std::size_t element_size() const { return element_size_; } - - constexpr const char* name() const { return name_; } +#include "common/constexpr_map.h" +#include "common/traits.h" - private: - int index_{0}; - - std::size_t element_size_{0}; +namespace infini::ops { - 
const char* name_{nullptr}; +enum class DataType : int8_t { + kInt8, + kInt16, + kInt32, + kInt64, + kUInt8, + kUInt16, + kUInt32, + kUInt64, + kFloat16, + kBFloat16, + kFloat32, + kFloat64 }; -constexpr DataType kInt8{0, sizeof(int8_t), "int8"}; - -constexpr DataType kInt16{1, sizeof(int16_t), "int16"}; - -constexpr DataType kInt32{2, sizeof(int32_t), "int32"}; - -constexpr DataType kInt64{3, sizeof(int64_t), "int64"}; - -constexpr DataType kUInt8{4, sizeof(uint8_t), "uint8"}; - -constexpr DataType kUInt16{5, sizeof(uint16_t), "uint16"}; - -constexpr DataType kUInt32{6, sizeof(uint32_t), "uint32"}; - -constexpr DataType kUInt64{7, sizeof(uint64_t), "uint64"}; - -constexpr DataType kFloat16{8, 2, "float16"}; - -constexpr DataType kBFloat16{9, 2, "bfloat16"}; - -constexpr DataType kFloat32{10, sizeof(float), "float32"}; - -constexpr DataType kFloat64{11, sizeof(double), "float64"}; - -inline const DataType& DataType::FromString(const std::string& name) { - static std::unordered_map name_to_dtype{ - {kInt8.name(), kInt8}, {kInt16.name(), kInt16}, - {kInt32.name(), kInt32}, {kInt64.name(), kInt64}, - {kUInt8.name(), kUInt8}, {kUInt16.name(), kUInt16}, - {kUInt32.name(), kUInt32}, {kUInt64.name(), kUInt64}, - {kFloat16.name(), kFloat16}, {kBFloat16.name(), kBFloat16}, - {kFloat32.name(), kFloat32}, {kFloat64.name(), kFloat64}}; - - return name_to_dtype.at(name); -} +constexpr ConstexprMap kDataTypeToSize{{{ + {DataType::kInt8, 1}, + {DataType::kInt16, 2}, + {DataType::kInt32, 4}, + {DataType::kInt64, 8}, + {DataType::kUInt8, 1}, + {DataType::kUInt16, 2}, + {DataType::kUInt32, 4}, + {DataType::kUInt64, 8}, + {DataType::kFloat16, 2}, + {DataType::kBFloat16, 2}, + {DataType::kFloat32, 4}, + {DataType::kFloat64, 8}, +}}}; + +constexpr ConstexprMap kDataTypeToDesc{{{ + {DataType::kInt8, "int8"}, + {DataType::kInt16, "int16"}, + {DataType::kInt32, "int32"}, + {DataType::kInt64, "int64"}, + {DataType::kUInt8, "uint8"}, + {DataType::kUInt16, "uint16"}, + {DataType::kUInt32, 
"uint32"}, + {DataType::kUInt64, "uint64"}, + {DataType::kFloat16, "float16"}, + {DataType::kBFloat16, "bfloat16"}, + {DataType::kFloat32, "float32"}, + {DataType::kFloat64, "float64"}, +}}}; + +template +struct TypeMap; + +template +using TypeMap_t = typename TypeMap::type; + +template +struct DataTypeMap; + +template +inline constexpr DataType DataTypeMap_v = DataTypeMap::value; + +#define DEFINE_DATA_TYPE_MAPPING(ENUM_VALUE, CPP_TYPE) \ + template <> \ + struct TypeMap { \ + using type = CPP_TYPE; \ + }; \ + template <> \ + struct DataTypeMap { \ + static constexpr DataType value = DataType::ENUM_VALUE; \ + }; + +DEFINE_DATA_TYPE_MAPPING(kUInt8, uint8_t) +DEFINE_DATA_TYPE_MAPPING(kInt8, int8_t) +DEFINE_DATA_TYPE_MAPPING(kUInt16, uint16_t) +DEFINE_DATA_TYPE_MAPPING(kInt16, int16_t) +DEFINE_DATA_TYPE_MAPPING(kUInt32, uint32_t) +DEFINE_DATA_TYPE_MAPPING(kInt32, int32_t) +DEFINE_DATA_TYPE_MAPPING(kUInt64, uint64_t) +DEFINE_DATA_TYPE_MAPPING(kInt64, int64_t) +DEFINE_DATA_TYPE_MAPPING(kFloat32, float) +DEFINE_DATA_TYPE_MAPPING(kFloat64, double) +// TODO(lzm): support fp16 and bf16 + +// Defines the common categories of data types using List +using FloatingTypes = List; +using ReducedFloatingTypes = List; +using SignedIntegralTypes = + List; +using UnsignedIntegralTypes = List; + +using BitTypes8 = List; +using BitTypes16 = List; +using BitTypes32 = + List; +using BitTypes64 = + List; + +using AllFloatingTypes = Concat_t; +using AllIntegralTypes = Concat_t; +using AllTypes = Concat_t; } // namespace infini::ops diff --git a/src/device.h b/src/device.h index 22e3bc0..c07b53c 100644 --- a/src/device.h +++ b/src/device.h @@ -1,11 +1,13 @@ #ifndef INFINI_OPS_DEVICE_H_ #define INFINI_OPS_DEVICE_H_ +#include "common/constexpr_map.h" +#include "common/traits.h" + namespace infini::ops { class Device { public: - // TODO: Complete the list. 
enum class Type { kCpu = 0, kNvidia = 1, @@ -24,12 +26,12 @@ class Device { Device(const Type& type, const int& index = 0) : type_{type}, index_{index} {} - static const Type& TypeFromString(const std::string& name) { - // TODO: Handle `"cuda"` dispatching. - static std::unordered_map name_to_type{ - {"cpu", Type::kCpu}, {"cuda", Type::kNvidia}}; + static const Type TypeFromString(const std::string& name) { + return kDescToDevice.at(name); + } - return name_to_type.at(name); + static const std::string_view StringFromType(const Type& type) { + return kDeviceToDesc.at(type); } const Type& type() const { return type_; } @@ -39,9 +41,112 @@ class Device { private: Type type_{Type::kCpu}; + static constexpr ConstexprMap + kDeviceToDesc{{{ + {Device::Type::kCpu, "CPU"}, + {Device::Type::kNvidia, "NVIDIA"}, + {Device::Type::kCambricon, "Cambricon"}, + {Device::Type::kAscend, "Ascend"}, + {Device::Type::kMetax, "Metax"}, + {Device::Type::kMoore, "Moore"}, + {Device::Type::kIluvatar, "Iluvatar"}, + {Device::Type::kKunlun, "Kunlun"}, + {Device::Type::kHygon, "Hygon"}, + {Device::Type::kQy, "QY"}, + }}}; + + static constexpr ConstexprMap + kDescToDevice{{{ + {"CPU", Device::Type::kCpu}, + {"NVIDIA", Device::Type::kNvidia}, + {"Cambricon", Device::Type::kCambricon}, + {"Ascend", Device::Type::kAscend}, + {"Metax", Device::Type::kMetax}, + {"Moore", Device::Type::kMoore}, + {"Iluvatar", Device::Type::kIluvatar}, + {"Kunlun", Device::Type::kKunlun}, + {"Hygon", Device::Type::kHygon}, + {"QY", Device::Type::kQy}, + }}}; + int index_{0}; }; +struct EnabledDeviceFilter { + // Each block defines a template operator() specialized for a specific + // Device. If the macro is NOT defined, the specialization is not compiled, + // and FilterList will exclude it from ActiveDevices. 
+ +#ifdef USE_CPU + template = 0> + void operator()() const {} +#endif + +#ifdef USE_NVIDIA + template = 0> + void operator()() const {} +#endif + +#ifdef USE_CAMBRICON + template = 0> + void operator()() const {} +#endif + +#ifdef USE_ASCEND + template = 0> + void operator()() const {} +#endif + +#ifdef USE_METAX + template = 0> + void operator()() const {} +#endif + +#ifdef USE_MOORE + template = 0> + void operator()() const {} +#endif + +#ifdef USE_ILUVATAR + template = 0> + void operator()() const {} +#endif + +#ifdef USE_KUNLUN + template = 0> + void operator()() const {} +#endif + +#ifdef USE_HYGON + template = 0> + void operator()() const {} +#endif + +#ifdef USE_QY + template = 0> + void operator()() const {} +#endif +}; + +// Defines the common categories of devices using List +using AllDeviceTypes = + List; + +using ActiveDevices = + typename infini::ops::FilterList, + AllDeviceTypes>::type; + } // namespace infini::ops #endif diff --git a/src/dispatcher.h b/src/dispatcher.h new file mode 100644 index 0000000..a2a0aa1 --- /dev/null +++ b/src/dispatcher.h @@ -0,0 +1,174 @@ +#ifndef INFINI_OPS_DISPATCHER_H_ +#define INFINI_OPS_DISPATCHER_H_ + +#include +#include +#include + +#include "common/traits.h" +#include "data_type.h" +#include "device.h" + +namespace infini::ops { + +// ----------------------------------------------------------------------------- +// Core Generic Runtime Dispatchers +// ----------------------------------------------------------------------------- + +// (Single Dispatch) Dispatches a runtime value to a compile-time functor. +template +auto DispatchFunc(ValueType value, Functor&& func, + std::string_view context_str = "", Args&&... 
args) { + using FilteredPack = + typename Filter, List<>, AllValues...>::type; + + return [&](List) { + using ReturnType = + decltype(std::forward(func) + .template operator()(0)>( + std::forward(args)...)); + + bool handled = false; + + if constexpr (std::is_void_v) { + handled = + ((value == static_cast(Pruned) + ? (std::forward(func).template operator()( + std::forward(args)...), + true) + : false) || + ...); + } else { + std::optional result; + handled = + ((value == static_cast(Pruned) + ? (result.emplace( + std::forward(func).template operator()( + std::forward(args)...)), + true) + : false) || + ...); + return *result; + } + if (!handled) { + // TODO(lzm): change to logging + std::cerr << "Dispatch error: Value " << static_cast(value) + << " not supported in the context: " << context_str << "\n"; + std::abort(); + } + }(FilteredPack{}); +} + +// (Multi-Dispatch) Dispatches a vector of runtime values to a compile-time +// functor. +// Base Case: All dimensions resolved +template +auto DispatchFunc(const std::vector& values, size_t index, + Functor&& func, std::string_view context_str, List, + Args&&... args) { + return std::forward(func).template operator()( + std::forward(args)...); +} + +// (Multi-Dispatch) Recursive Case +template +auto DispatchFunc(const std::vector& values, size_t index, + Functor&& func, std::string_view context_str, List, + Args&&... args) { + return [&](List) { + static_assert(sizeof...(Allowed) > 0, + "DispatchFunc dimension list is empty!"); + using EnumType = std::common_type_t; + + return DispatchFunc( + static_cast(values.at(index)), + [&](Args&&... 
inner_args) { + return DispatchFunc( + values, index + 1, std::forward(func), context_str, + List{}, std::forward(inner_args)...); + }, + context_str, std::forward(args)...); + }(FirstList{}); +} + +// ----------------------------------------------------------------------------- +// High-Level Specialized Dispatchers +// ----------------------------------------------------------------------------- +// These provide cleaner and more convenient APIs for common InfiniOps types. + +// DataType Dispatch +template +auto DispatchFunc(DataType dtype, Functor&& func, + std::string_view context_str = "", Args&&... args) { + return DispatchFunc( + dtype, + [&](Args&&... inner_args) { + using T = TypeMap_t
; + return std::forward(func).template operator()( + std::forward(inner_args)...); + }, + context_str, std::forward(args)...); +} + +// DataType Multi-Dispatch +template +auto DispatchFunc(std::initializer_list dtypes, Functor&& func, + std::string_view context_str = "", Args&&... args) { + std::vector v; + for (auto d : dtypes) v.push_back(static_cast(d)); + + return DispatchFunc( + v, 0, + [&func](Args&&... inner_args) { + return std::forward(func).template + operator()...>(std::forward(inner_args)...); + }, + context_str, List<>{}, std::forward(args)...); +} + +// Device Dispatch +template +auto DispatchFunc(Device::Type device, Functor&& func, + std::string_view context_str = "", Args&&... args) { + return DispatchFunc( + device, + [&](Args&&... inner_args) { + return std::forward(func).template operator()( + std::forward(inner_args)...); + }, + context_str, std::forward(args)...); +} + +// Device Multi-Dispatch +template +auto DispatchFunc(std::initializer_list devices, Functor&& func, + std::string_view context_str = "", Args&&... args) { + std::vector v; + for (auto d : devices) v.push_back(static_cast(d)); + + return DispatchFunc( + v, 0, + [&func](Args&&... inner_args) { + return std::forward(func).template operator()( + std::forward(inner_args)...); + }, + context_str, List<>{}, std::forward(args)...); +} + +// Interface for generic List Aliases, which unpacks a list +template +auto DispatchFunc(ValueType value, Functor&& func, + std::string_view context_str = "", Args&&... 
args) { + return [&](List) { + return DispatchFunc>(Is)...>( + value, std::forward(func), context_str, + std::forward(args)...); + }(ListType{}); +} + +} // namespace infini::ops + +#endif diff --git a/src/operator.h b/src/operator.h index 45e4f15..33f8a42 100644 --- a/src/operator.h +++ b/src/operator.h @@ -4,6 +4,7 @@ #include #include +#include "dispatcher.h" #include "tensor.h" namespace infini::ops { @@ -25,24 +26,13 @@ class Operator : public OperatorBase { static auto make(const Tensor tensor, Args&&... args) { std::unique_ptr op_ptr; - switch (tensor.device().type()) { - // TODO(lzm): use dispatcher to conditionally compile and dispatch - // the devices. This is only a temporary solution -#ifdef USE_CUDA - case Device::Type::kNvidia: - op_ptr = std::make_unique>( - tensor, std::forward(args)...); - break; -#elif USE_MACA - case Device::Type::kMetax: - op_ptr = std::make_unique>( - tensor, std::forward(args)...); - break; -#endif - default: - assert(false && - "constructor dispatching not implemented for this device"); - } + DispatchFunc( + tensor.device().type(), + [&]() { + op_ptr = std::make_unique>( + tensor, std::forward(args)...); + }, + "Operator make"); return op_ptr; } diff --git a/src/tensor.cc b/src/tensor.cc index b4c4269..3a8d70b 100644 --- a/src/tensor.cc +++ b/src/tensor.cc @@ -2,6 +2,8 @@ #include +#include "dispatcher.h" + namespace infini::ops { Tensor::Tensor(void* data, std::initializer_list shape, @@ -24,7 +26,7 @@ const void* const& Tensor::data() const { return data_; } const Tensor::Shape& Tensor::shape() const { return shape_; } -const DataType& Tensor::dtype() const { return dtype_; } +const DataType Tensor::dtype() const { return dtype_; } const Device& Tensor::device() const { return device_; } @@ -38,7 +40,7 @@ Tensor::Stride Tensor::stride(const Index& index) const { Tensor::Size Tensor::ndim() const { return shape_.size(); } -Tensor::Size Tensor::element_size() const { return dtype_.element_size(); } +Tensor::Size 
Tensor::element_size() const { return kDataTypeToSize.at(dtype_); } Tensor Tensor::T() const { return {data_, @@ -49,10 +51,11 @@ Tensor Tensor::T() const { } std::string Tensor::ToString() const { - return "tensor(" + ToStringHelper() + ", dtype=" + dtype_.name() + ")"; + return "tensor(" + ToStringHelper() + + ", dtype=" + std::string(kDataTypeToDesc.at(dtype_)) + ")"; } -const DataType& Tensor::DefaultDataType() { return kFloat32; } +const DataType Tensor::DefaultDataType() { return DataType::kFloat32; } Device Tensor::DefaultDevice() { return Device{Device::Type::kCpu}; } @@ -74,13 +77,10 @@ Tensor::Strides Tensor::DefaultStrides(const Shape& shape) { std::string Tensor::ToStringHelper() const { if (ndim() == 0) { - if (dtype_ == kFloat32) { - return std::to_string(*static_cast(data_)); - } - - // TODO: Handle more data types here. - - assert(false && "string conversion not implemented for this data type"); + return DispatchFunc( + dtype_, + [&]() { return std::to_string(*static_cast(data_)); }, + "ToStringHelper"); } std::string result{"["}; diff --git a/src/tensor.h b/src/tensor.h index 66ca60b..743f908 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -31,7 +31,7 @@ class Tensor { strides_{DefaultStrides(shape)} {} template - Tensor(void* data, const Shape& shape, const DataType& dtype) + Tensor(void* data, const Shape& shape, const DataType dtype) : data_{data}, shape_{shape}, dtype_{dtype}, @@ -73,7 +73,7 @@ class Tensor { const void* const& data() const; - const DataType& dtype() const; + const DataType dtype() const; const Device& device() const; @@ -94,7 +94,7 @@ class Tensor { std::string ToString() const; private: - static const DataType& DefaultDataType(); + static const DataType DefaultDataType(); static Device DefaultDevice(); @@ -106,7 +106,7 @@ class Tensor { Shape shape_; - const DataType& dtype_; + const DataType dtype_; Device device_; From b0293588cf6720b8af361badf901664563a545e0 Mon Sep 17 00:00:00 2001 From: Ziminli Date: Thu, 12 Feb 2026 
06:52:31 +0000 Subject: [PATCH 22/93] fix: fix dispatcher default to kCpu issue, various naming issues and further abstract `blas.h` - further abstract `blas.h`, backends now only do name change - fix various naming issues and small issues - combined the `gemm` example programs across the platforms, now only one program for all platforms --- .../{metax/gemm/mcblas.cc => gemm/gemm.cc} | 69 ++++++++------- examples/nvidia/gemm/cublas.cc | 77 ----------------- examples/runtime_api.h | 26 ++++++ src/base/gemm.h | 11 ++- src/common/constexpr_map.h | 14 ++-- src/common/traits.h | 14 ++-- src/cpu/gemm/gemm.h | 16 +--- src/cuda/gemm/blas.h | 19 +++-- src/data_type.h | 42 +++++++--- src/device.h | 84 ++++++++++--------- src/dispatcher.h | 64 +++++++++----- src/metax/gemm/mcblas.h | 33 ++++---- src/nvidia/gemm/cublas.h | 39 +++++---- src/operator.h | 2 +- src/tensor.cc | 4 +- src/tensor.h | 4 +- 16 files changed, 265 insertions(+), 253 deletions(-) rename examples/{metax/gemm/mcblas.cc => gemm/gemm.cc} (51%) delete mode 100644 examples/nvidia/gemm/cublas.cc create mode 100644 examples/runtime_api.h diff --git a/examples/metax/gemm/mcblas.cc b/examples/gemm/gemm.cc similarity index 51% rename from examples/metax/gemm/mcblas.cc rename to examples/gemm/gemm.cc index af60f4d..57a7a37 100644 --- a/examples/metax/gemm/mcblas.cc +++ b/examples/gemm/gemm.cc @@ -1,14 +1,16 @@ -#include "metax/gemm/mcblas.h" - -#include - #include #include +#include #ifdef USE_CPU #include "cpu/gemm/gemm.h" +#elif USE_NVIDIA +#include "nvidia/gemm/cublas.h" +#elif USE_METAX +#include "metax/gemm/mcblas.h" #endif +#include "../runtime_api.h" #include "tensor.h" int main() { @@ -36,38 +38,47 @@ int main() { std::iota(a_vec.begin(), a_vec.end(), 0); std::iota(b_vec.begin(), b_vec.end(), 0); - Tensor a_host{a_vec.data(), a_shape, Device{Device::Type::kMetax}}; - Tensor b_host{b_vec.data(), b_shape, Device{Device::Type::kMetax}}; - Tensor c_host{c_vec.data(), c_shape, Device{Device::Type::kMetax}}; + 
Device dev{DEFAULT_DEVICE_TYPE}; - Tensor a_device{nullptr, a_host.shape(), a_host.dtype(), a_host.device(), - a_host.strides()}; - Tensor b_device{nullptr, b_host.shape(), b_host.dtype(), a_host.device(), - b_host.strides()}; - Tensor c_device{nullptr, c_host.shape(), c_host.dtype(), a_host.device(), - c_host.strides()}; + Tensor a_host{a_vec.data(), a_shape, dev}; + Tensor b_host{b_vec.data(), b_shape, dev}; + Tensor c_host{c_vec.data(), c_shape, dev}; - const auto a_size{a_num_elements * kDataTypeToSize.at(a_device.dtype())}; - const auto b_size{b_num_elements * kDataTypeToSize.at(b_device.dtype())}; - const auto c_size{c_num_elements * kDataTypeToSize.at(c_device.dtype())}; + const auto a_size{a_num_elements * kDataTypeToSize.at(a_host.dtype())}; + const auto b_size{b_num_elements * kDataTypeToSize.at(b_host.dtype())}; + const auto c_size{c_num_elements * kDataTypeToSize.at(c_host.dtype())}; - mcMalloc(&a_device.data(), a_size); - mcMalloc(&b_device.data(), b_size); - mcMalloc(&c_device.data(), c_size); + void *a_ptr, *b_ptr, *c_ptr; - mcMemcpy(a_device.data(), a_vec.data(), a_size, mcMemcpyHostToDevice); - mcMemcpy(b_device.data(), b_vec.data(), b_size, mcMemcpyHostToDevice); - mcMemset(c_device.data(), 0, c_size); +#ifdef USE_CPU + a_ptr = a_vec.data(); + b_ptr = b_vec.data(); + c_ptr = c_vec.data(); +#else + DEVICE_MALLOC(&a_ptr, a_size); + DEVICE_MALLOC(&b_ptr, b_size); + DEVICE_MALLOC(&c_ptr, c_size); + + DEVICE_MEMCPY(a_ptr, a_vec.data(), a_size, DEVICE_MEMCPY_HOST_TO_DEVICE); + DEVICE_MEMCPY(b_ptr, b_vec.data(), b_size, DEVICE_MEMCPY_HOST_TO_DEVICE); + DEVICE_MEMSET(c_ptr, 0, c_size); +#endif - Gemm::call(nullptr, a_device, b_device, c_device); + Tensor a_device{a_ptr, a_host.shape(), a_host.dtype(), a_host.device(), + a_host.strides()}; + Tensor b_device{b_ptr, b_host.shape(), b_host.dtype(), a_host.device(), + b_host.strides()}; + Tensor c_device{c_ptr, c_host.shape(), c_host.dtype(), a_host.device(), + c_host.strides()}; - mcMemcpy(a_host.data(), 
a_device.data(), a_size, mcMemcpyDeviceToHost); - mcMemcpy(b_host.data(), b_device.data(), b_size, mcMemcpyDeviceToHost); - mcMemcpy(c_host.data(), c_device.data(), c_size, mcMemcpyDeviceToHost); + Gemm::call(nullptr, a_device, b_device, c_device); - mcFree(a_device.data()); - mcFree(b_device.data()); - mcFree(c_device.data()); +#ifndef USE_CPU + DEVICE_MEMCPY(c_vec.data(), c_ptr, c_size, DEVICE_MEMCPY_DEVICE_TO_HOST); + DEVICE_FREE(a_ptr); + DEVICE_FREE(b_ptr); + DEVICE_FREE(c_ptr); +#endif std::cout << "A: " << a_host.ToString() << "\n"; std::cout << "B: " << b_host.ToString() << "\n"; diff --git a/examples/nvidia/gemm/cublas.cc b/examples/nvidia/gemm/cublas.cc deleted file mode 100644 index 7885d3a..0000000 --- a/examples/nvidia/gemm/cublas.cc +++ /dev/null @@ -1,77 +0,0 @@ -#include "nvidia/gemm/cublas.h" - -#include - -#include -#include - -#ifdef USE_CPU -#include "cpu/gemm/gemm.h" -#endif - -#include "tensor.h" - -int main() { - using namespace infini::ops; - - constexpr auto m{2}; - constexpr auto k{3}; - constexpr auto n{4}; - - std::vector a_shape{m, k}; - std::vector b_shape{k, n}; - std::vector c_shape{m, n}; - - const auto a_num_elements{std::accumulate(a_shape.cbegin(), a_shape.cend(), 1, - std::multiplies())}; - const auto b_num_elements{std::accumulate(b_shape.cbegin(), b_shape.cend(), 1, - std::multiplies())}; - const auto c_num_elements{std::accumulate(c_shape.cbegin(), c_shape.cend(), 1, - std::multiplies())}; - - std::vector a_vec(a_num_elements); - std::vector b_vec(b_num_elements); - std::vector c_vec(c_num_elements); - - std::iota(a_vec.begin(), a_vec.end(), 0); - std::iota(b_vec.begin(), b_vec.end(), 0); - - Tensor a_host{a_vec.data(), a_shape, Device{Device::Type::kNvidia}}; - Tensor b_host{b_vec.data(), b_shape, Device{Device::Type::kNvidia}}; - Tensor c_host{c_vec.data(), c_shape, Device{Device::Type::kNvidia}}; - - Tensor a_device{nullptr, a_host.shape(), a_host.dtype(), a_host.device(), - a_host.strides()}; - Tensor b_device{nullptr, 
b_host.shape(), b_host.dtype(), a_host.device(), - b_host.strides()}; - Tensor c_device{nullptr, c_host.shape(), c_host.dtype(), a_host.device(), - c_host.strides()}; - - const auto a_size{a_num_elements * kDataTypeToSize.at(a_device.dtype())}; - const auto b_size{b_num_elements * kDataTypeToSize.at(b_device.dtype())}; - const auto c_size{c_num_elements * kDataTypeToSize.at(c_device.dtype())}; - - cudaMalloc(&a_device.data(), a_size); - cudaMalloc(&b_device.data(), b_size); - cudaMalloc(&c_device.data(), c_size); - - cudaMemcpy(a_device.data(), a_vec.data(), a_size, cudaMemcpyHostToDevice); - cudaMemcpy(b_device.data(), b_vec.data(), b_size, cudaMemcpyHostToDevice); - cudaMemset(c_device.data(), 0, c_size); - - Gemm::call(nullptr, a_device, b_device, c_device); - - cudaMemcpy(a_host.data(), a_device.data(), a_size, cudaMemcpyDeviceToHost); - cudaMemcpy(b_host.data(), b_device.data(), b_size, cudaMemcpyDeviceToHost); - cudaMemcpy(c_host.data(), c_device.data(), c_size, cudaMemcpyDeviceToHost); - - cudaFree(a_device.data()); - cudaFree(b_device.data()); - cudaFree(c_device.data()); - - std::cout << "A: " << a_host.ToString() << "\n"; - std::cout << "B: " << b_host.ToString() << "\n"; - std::cout << "C: " << c_host.ToString() << "\n"; - - return 0; -} diff --git a/examples/runtime_api.h b/examples/runtime_api.h new file mode 100644 index 0000000..4e1a8cf --- /dev/null +++ b/examples/runtime_api.h @@ -0,0 +1,26 @@ +#ifndef INFINI_OPS_EXAMPLES_RUNTIME_API_H_ +#define INFINI_OPS_EXAMPLES_RUNTIME_API_H_ + +#ifdef USE_NVIDIA +#include +#define DEVICE_MALLOC cudaMalloc +#define DEVICE_FREE cudaFree +#define DEVICE_MEMCPY cudaMemcpy +#define DEVICE_MEMSET cudaMemset +#define DEVICE_MEMCPY_HOST_TO_DEVICE cudaMemcpyHostToDevice +#define DEVICE_MEMCPY_DEVICE_TO_HOST cudaMemcpyDeviceToHost +#define DEFAULT_DEVICE_TYPE Device::Type::kNvidia +#elif USE_METAX +#include +#define DEVICE_MALLOC mcMalloc +#define DEVICE_FREE mcFree +#define DEVICE_MEMCPY mcMemcpy +#define DEVICE_MEMSET 
mcMemset +#define DEVICE_MEMCPY_HOST_TO_DEVICE mcMemcpyHostToDevice +#define DEVICE_MEMCPY_DEVICE_TO_HOST mcMemcpyDeviceToHost +#define DEFAULT_DEVICE_TYPE Device::Type::kMetax +#elif USE_CPU +#define DEFAULT_DEVICE_TYPE Device::Type::kCpu +#endif + +#endif diff --git a/src/base/gemm.h b/src/base/gemm.h index ab96c92..e6ba2e5 100644 --- a/src/base/gemm.h +++ b/src/base/gemm.h @@ -24,7 +24,10 @@ class Gemm : public Operator { c_type_{c.dtype()}, a_strides_{a.strides()}, b_strides_{b.strides()}, - c_strides_{c.strides()} { + c_strides_{c.strides()}, + lda_{a_strides_[1]}, + ldb_{b_strides_[1]}, + ldc_{c_strides_[1]} { // TODO: Check constraints. } @@ -65,6 +68,12 @@ class Gemm : public Operator { Tensor::Strides b_strides_; Tensor::Strides c_strides_; + + Tensor::Stride lda_{0}; + + Tensor::Stride ldb_{0}; + + Tensor::Stride ldc_{0}; }; } // namespace infini::ops diff --git a/src/common/constexpr_map.h b/src/common/constexpr_map.h index 011d974..921744a 100644 --- a/src/common/constexpr_map.h +++ b/src/common/constexpr_map.h @@ -2,25 +2,29 @@ #define INFINI_OPS_COMMON_CONSTEXPR_MAP_H_ #include +#include #include -#include #include namespace infini::ops { template struct ConstexprMap { - std::array, N> data; + constexpr ConstexprMap(std::array, N> data) + : data_(data) {} constexpr Value at(Key key) const { - for (const auto &pr : data_) { + for (const auto &pr : data_) { if (pr.first == key) return pr.second; } // TODO(lzm): change to logging - std::cerr << "ConstexprMap's key not found at " << __FILE__ << ":" - << __LINE__ << std::endl; + assert(false && "ConstexprMap's key is not found!"); + // Reached only in NDEBUG builds where assert compiles to a no-op std::abort(); } + + private: + std::array, N> data_; }; } // namespace infini::ops diff --git a/src/common/traits.h b/src/common/traits.h index 83de106..6f75e9f 100644 --- a/src/common/traits.h +++ b/src/common/traits.h @@ -15,7 +15,7 @@ struct List {}; //
----------------------------------------------------------------------------- // Check at compile-time if a Value exists within a construct (e.g., List<>). -// Example: static_assert(Contains_v); +// Example: static_assert(ContainsValue); template struct Contains; @@ -24,7 +24,7 @@ struct Contains, Value> : std::disjunction...> {}; template -inline constexpr bool Contains_v = Contains::value; +inline constexpr bool ContainsValue = Contains::value; // Check at compile-time if a type T is present in a variadic list of types Ts. // Example: static_assert(IsTypeInList); @@ -36,7 +36,7 @@ inline constexpr bool IsTypeInList = (std::is_same_v || ...); // ----------------------------------------------------------------------------- // Concatenates two List types into a single List. -// Example: Concat_t, List<3, 4>> is List<1, 2, 3, 4>. +// Example: ConcatType, List<3, 4>> is List<1, 2, 3, 4>. template struct Concat; @@ -46,7 +46,7 @@ struct Concat, List> { }; template -using Concat_t = typename Concat::type; +using ConcatType = typename Concat::type; // ----------------------------------------------------------------------------- // Invocability Detection (SFINAE) @@ -64,7 +64,7 @@ struct IsInvocable< Args...> : std::true_type {}; template -inline constexpr bool IsInvocable_v = +inline constexpr bool IsInvocableValue = IsInvocable::value; // ----------------------------------------------------------------------------- @@ -87,8 +87,8 @@ template struct Filter, List, Head, Tail...> { using type = typename std::conditional_t< - IsInvocable_v && - !Contains_v, Head>, + IsInvocableValue && + !ContainsValue, Head>, Filter, List, Tail...>, Filter, List, Tail...>>::type; }; diff --git a/src/cpu/gemm/gemm.h b/src/cpu/gemm/gemm.h index fbec039..32c123f 100644 --- a/src/cpu/gemm/gemm.h +++ b/src/cpu/gemm/gemm.h @@ -13,10 +13,7 @@ class Operator : public Gemm { Operator(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, std::optional trans_a, std::optional 
trans_b, Tensor c) - : Gemm{a, b, alpha, beta, trans_a, trans_b, c}, - lda_{a_strides_[1]}, - ldb_{b_strides_[1]}, - ldc_{c_strides_[1]} { + : Gemm{a, b, alpha, beta, trans_a, trans_b, c} { // TODO: Check constraints. } @@ -28,9 +25,9 @@ class Operator : public Gemm { std::optional alpha, std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c) const override { - const float* A = static_cast(a.data()); - const float* B = static_cast(b.data()); - float* C = static_cast(c.data()); + const auto* A = static_cast(a.data()); + const auto* B = static_cast(b.data()); + auto* C = static_cast(c.data()); const auto& alpha_value{alpha.value_or(alpha_)}; const auto& beta_value{beta.value_or(beta_)}; @@ -52,11 +49,6 @@ class Operator : public Gemm { } } } - - private: - Tensor::Stride lda_{0}; - Tensor::Stride ldb_{0}; - Tensor::Stride ldc_{0}; }; } // namespace infini::ops diff --git a/src/cuda/gemm/blas.h b/src/cuda/gemm/blas.h index c3c89a4..7d8e229 100644 --- a/src/cuda/gemm/blas.h +++ b/src/cuda/gemm/blas.h @@ -23,6 +23,7 @@ class Blas : public Gemm { lda_{a_strides_[1]}, ldb_{b_strides_[1]}, ldc_{c_strides_[1]} { + Backend::blasCreate(&handle); // TODO: Check constraints. } @@ -33,9 +34,6 @@ class Blas : public Gemm { std::optional alpha, std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c) const override { - typename Backend::blasHandle_t handle; - Backend::blasCreate(&handle); - Backend::blasSetStream(handle, static_cast(stream)); @@ -48,9 +46,16 @@ class Blas : public Gemm { c_type_ == DataType::kFloat32 && "`operator()` not implemented for this data type"); - Backend::blasGemmEx(handle, trans_a_value, trans_b_value, m_, n_, k_, - &alpha_value, b.data(), lda_, a.data(), ldb_, - &beta_value, c.data(), ldc_); + auto op_a = static_cast( + trans_a_value ? Backend::BLAS_OP_T : Backend::BLAS_OP_N); + auto op_b = static_cast( + trans_b_value ? 
Backend::BLAS_OP_T : Backend::BLAS_OP_N); + + Backend::blasGemmEx(handle, op_a, op_b, m_, n_, k_, &alpha_value, b.data(), + Backend::R_32F, lda_, a.data(), Backend::R_32F, ldb_, + &beta_value, c.data(), Backend::R_32F, ldc_, + Backend::BLAS_COMPUTE_32F_FAST_TF32, + Backend::BLAS_GEMM_DEFAULT); Backend::blasDestroy(handle); } @@ -59,6 +64,8 @@ class Blas : public Gemm { Tensor::Stride lda_{0}; Tensor::Stride ldb_{0}; Tensor::Stride ldc_{0}; + + typename Backend::blasHandle_t handle; }; } // namespace infini::ops diff --git a/src/data_type.h b/src/data_type.h index 0a2a248..1ccfce3 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -9,7 +9,7 @@ namespace infini::ops { -enum class DataType : int8_t { +enum class DataType : std::int8_t { kInt8, kInt16, kInt32, @@ -54,23 +54,39 @@ constexpr ConstexprMap kDataTypeToDesc{{{ {DataType::kFloat64, "float64"}, }}}; -template +constexpr ConstexprMap kStringToDataType{{{ + {"int8", DataType::kInt8}, + {"int16", DataType::kInt16}, + {"int32", DataType::kInt32}, + {"int64", DataType::kInt64}, + {"uint8", DataType::kUInt8}, + {"uint16", DataType::kUInt16}, + {"uint32", DataType::kUInt32}, + {"uint64", DataType::kUInt64}, + {"float16", DataType::kFloat16}, + {"bfloat16", DataType::kBFloat16}, + {"float32", DataType::kFloat32}, + {"float64", DataType::kFloat64}, +}}}; + +template struct TypeMap; -template -using TypeMap_t = typename TypeMap::type; +template +using TypeMapType = typename TypeMap::type; template struct DataTypeMap; template -inline constexpr DataType DataTypeMap_v = DataTypeMap::value; +inline constexpr DataType DataTypeMapValue = DataTypeMap::value; #define DEFINE_DATA_TYPE_MAPPING(ENUM_VALUE, CPP_TYPE) \ template <> \ struct TypeMap { \ using type = CPP_TYPE; \ }; \ + \ template <> \ struct DataTypeMap { \ static constexpr DataType value = DataType::ENUM_VALUE; \ @@ -89,12 +105,12 @@ DEFINE_DATA_TYPE_MAPPING(kFloat64, double) // TODO(lzm): support fp16 and bf16 // Defines the common categories of data types using 
List -using FloatingTypes = List; -using ReducedFloatingTypes = List; -using SignedIntegralTypes = +using FloatTypes = List; +using ReducedFloatTypes = List; +using IntTypes = List; -using UnsignedIntegralTypes = List; +using UIntTypes = List; using BitTypes8 = List; using BitTypes16 = List; -using AllFloatingTypes = Concat_t; -using AllIntegralTypes = Concat_t; -using AllTypes = Concat_t; +using AllFloatingTypes = ConcatType; +using AllIntegralTypes = ConcatType; +using AllTypes = ConcatType; } // namespace infini::ops diff --git a/src/device.h b/src/device.h index c07b53c..83b4faa 100644 --- a/src/device.h +++ b/src/device.h @@ -41,32 +41,34 @@ class Device { private: Type type_{Type::kCpu}; - static constexpr ConstexprMap + static constexpr ConstexprMap(Device::Type::kCount)> kDeviceToDesc{{{ - {Device::Type::kCpu, "CPU"}, - {Device::Type::kNvidia, "NVIDIA"}, - {Device::Type::kCambricon, "Cambricon"}, - {Device::Type::kAscend, "Ascend"}, - {Device::Type::kMetax, "Metax"}, - {Device::Type::kMoore, "Moore"}, - {Device::Type::kIluvatar, "Iluvatar"}, - {Device::Type::kKunlun, "Kunlun"}, - {Device::Type::kHygon, "Hygon"}, - {Device::Type::kQy, "QY"}, + {Type::kCpu, "CPU"}, + {Type::kNvidia, "NVIDIA"}, + {Type::kCambricon, "Cambricon"}, + {Type::kAscend, "Ascend"}, + {Type::kMetax, "Metax"}, + {Type::kMoore, "Moore"}, + {Type::kIluvatar, "Iluvatar"}, + {Type::kKunlun, "Kunlun"}, + {Type::kHygon, "Hygon"}, + {Type::kQy, "QY"}, }}}; - static constexpr ConstexprMap + static constexpr ConstexprMap(Device::Type::kCount)> kDescToDevice{{{ - {"CPU", Device::Type::kCpu}, - {"NVIDIA", Device::Type::kNvidia}, - {"Cambricon", Device::Type::kCambricon}, - {"Ascend", Device::Type::kAscend}, - {"Metax", Device::Type::kMetax}, - {"Moore", Device::Type::kMoore}, - {"Iluvatar", Device::Type::kIluvatar}, - {"Kunlun", Device::Type::kKunlun}, - {"Hygon", Device::Type::kHygon}, - {"QY", Device::Type::kQy}, + {"CPU", Type::kCpu}, + {"NVIDIA", Type::kNvidia}, + {"Cambricon", 
Type::kCambricon}, + {"Ascend", Type::kAscend}, + {"Metax", Type::kMetax}, + {"Moore", Type::kMoore}, + {"Iluvatar", Type::kIluvatar}, + {"Kunlun", Type::kKunlun}, + {"Hygon", Type::kHygon}, + {"QY", Type::kQy}, }}}; int index_{0}; @@ -78,60 +80,62 @@ struct EnabledDeviceFilter { // and FilterList will exclude it from ActiveDevices. #ifdef USE_CPU - template = 0> + template = 0> void operator()() const {} #endif #ifdef USE_NVIDIA - template = 0> + template = 0> void operator()() const {} #endif #ifdef USE_CAMBRICON - template = 0> + template = 0> void operator()() const {} #endif #ifdef USE_ASCEND - template = 0> + template = 0> void operator()() const {} #endif #ifdef USE_METAX - template = 0> + template = 0> void operator()() const {} #endif #ifdef USE_MOORE - template = 0> + template = 0> void operator()() const {} #endif #ifdef USE_ILUVATAR - template = 0> + template = 0> void operator()() const {} #endif #ifdef USE_KUNLUN - template = 0> + template = 0> void operator()() const {} #endif #ifdef USE_HYGON - template = 0> + template = 0> void operator()() const {} #endif #ifdef USE_QY - template = 0> + template = 0> void operator()() const {} #endif }; diff --git a/src/dispatcher.h b/src/dispatcher.h index a2a0aa1..e95ec47 100644 --- a/src/dispatcher.h +++ b/src/dispatcher.h @@ -1,6 +1,7 @@ #ifndef INFINI_OPS_DISPATCHER_H_ #define INFINI_OPS_DISPATCHER_H_ +#include #include #include #include @@ -23,39 +24,60 @@ auto DispatchFunc(ValueType value, Functor&& func, using FilteredPack = typename Filter, List<>, AllValues...>::type; - return [&](List) { + return [&](List) { using ReturnType = decltype(std::forward(func) - .template operator()(0)>( + .template operator()(Head)>( std::forward(args)...)); - bool handled = false; - + // Path for Void Functions if constexpr (std::is_void_v) { - handled = - ((value == static_cast(Pruned) - ? (std::forward(func).template operator()( + bool handled = + ((value == static_cast(Tail) + ? 
(std::forward(func).template operator()( std::forward(args)...), true) : false) || - ...); - } else { + ... || + (value == static_cast(Head) + ? (std::forward(func).template operator()( + std::forward(args)...), + true) + : false)); + + if (!handled) { + std::cerr << "Dispatch error (void): Value " << static_cast(value) + << " not supported in context: " << context_str << "\n"; + std::abort(); + } + } + // Path for Non-Void Functions + else { std::optional result; - handled = - ((value == static_cast(Pruned) + bool handled = + ((value == static_cast(Tail) ? (result.emplace( - std::forward(func).template operator()( + std::forward(func).template operator()( std::forward(args)...)), true) : false) || - ...); - return *result; - } - if (!handled) { + ... || + (value == static_cast(Head) + ? (result.emplace( + std::forward(func).template operator()( + std::forward(args)...)), + true) + : false)); + + if (handled) { + return *result; + } // TODO(lzm): change to logging - std::cerr << "Dispatch error: Value " << static_cast(value) - << " not supported in the context: " << context_str << "\n"; + std::cerr << "Dispatch error (non-void): Value " + << static_cast(value) + << " not supported in context: " << context_str << "\n"; std::abort(); + return ReturnType{}; } }(FilteredPack{}); } @@ -79,7 +101,7 @@ auto DispatchFunc(const std::vector& values, size_t index, Args&&... args) { return [&](List) { static_assert(sizeof...(Allowed) > 0, - "DispatchFunc dimension list is empty!"); + "`DispatchFunc` dimension list is empty"); using EnumType = std::common_type_t; return DispatchFunc( @@ -105,7 +127,7 @@ auto DispatchFunc(DataType dtype, Functor&& func, return DispatchFunc( dtype, [&](Args&&... inner_args) { - using T = TypeMap_t
; + using T = TypeMapType
; return std::forward(func).template operator()( std::forward(inner_args)...); }, @@ -123,7 +145,7 @@ auto DispatchFunc(std::initializer_list dtypes, Functor&& func, v, 0, [&func](Args&&... inner_args) { return std::forward(func).template - operator()...>(std::forward(inner_args)...); + operator()...>(std::forward(inner_args)...); }, context_str, List<>{}, std::forward(args)...); } diff --git a/src/metax/gemm/mcblas.h b/src/metax/gemm/mcblas.h index cc1f081..9c6bb3e 100644 --- a/src/metax/gemm/mcblas.h +++ b/src/metax/gemm/mcblas.h @@ -15,23 +15,22 @@ struct MetaxBackend { using blasHandle_t = mcblasHandle_t; using stream_t = mcStream_t; - static void blasCreate(blasHandle_t* handle) { mcblasCreate(handle); } - - static void blasSetStream(blasHandle_t handle, stream_t stream) { - mcblasSetStream(handle, stream); - } - - static void blasGemmEx(blasHandle_t handle, bool transA, bool transB, int m, - int n, int k, const float* alpha, const void* B, - int ldb, const void* A, int lda, const float* beta, - void* C, int ldc) { - mcblasGemmEx(handle, transA ? MCBLAS_OP_T : MCBLAS_OP_N, - transB ? 
MCBLAS_OP_T : MCBLAS_OP_N, m, n, k, alpha, B, - MACA_R_32F, ldb, A, MACA_R_32F, lda, beta, C, MACA_R_32F, ldc, - MCBLAS_COMPUTE_32F_FAST_TF32, MCBLAS_GEMM_DEFAULT); - } - - static void blasDestroy(blasHandle_t handle) { mcblasDestroy(handle); } + static constexpr auto BLAS_OP_N = MCBLAS_OP_N; + static constexpr auto BLAS_OP_T = MCBLAS_OP_T; + static constexpr auto R_32F = MACA_R_32F; + static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = + MCBLAS_COMPUTE_32F_FAST_TF32; + static constexpr auto BLAS_GEMM_DEFAULT = MCBLAS_GEMM_DEFAULT; + + static constexpr auto blasCreate = mcblasCreate; + static constexpr auto blasSetStream = mcblasSetStream; + static constexpr auto blasDestroy = mcblasDestroy; + + static constexpr mcblasStatus_t (*blasGemmEx)( + mcblasHandle_t, mcblasOperation_t, mcblasOperation_t, int, int, int, + const void*, const void*, macaDataType_t, int, const void*, + macaDataType_t, int, const void*, void*, macaDataType_t, int, + mcblasComputeType_t, mcblasGemmAlgo_t) = mcblasGemmEx; }; template <> diff --git a/src/nvidia/gemm/cublas.h b/src/nvidia/gemm/cublas.h index 9f7b8b6..6b4b016 100644 --- a/src/nvidia/gemm/cublas.h +++ b/src/nvidia/gemm/cublas.h @@ -11,33 +11,32 @@ namespace infini::ops { -struct CudaBackend { +struct NvidiaBackend { using blasHandle_t = cublasHandle_t; using stream_t = cudaStream_t; - static void blasCreate(blasHandle_t* handle) { cublasCreate(handle); } - - static void blasSetStream(blasHandle_t handle, stream_t stream) { - cublasSetStream(handle, stream); - } - - static void blasGemmEx(blasHandle_t handle, bool transA, bool transB, int m, - int n, int k, const float* alpha, const void* B, - int ldb, const void* A, int lda, const float* beta, - void* C, int ldc) { - cublasGemmEx(handle, transA ? CUBLAS_OP_T : CUBLAS_OP_N, - transB ? 
CUBLAS_OP_T : CUBLAS_OP_N, m, n, k, alpha, B, - CUDA_R_32F, ldb, A, CUDA_R_32F, lda, beta, C, CUDA_R_32F, ldc, - CUBLAS_COMPUTE_32F_FAST_TF32, CUBLAS_GEMM_DEFAULT); - } - - static void blasDestroy(blasHandle_t handle) { cublasDestroy(handle); } + static constexpr auto BLAS_OP_N = CUBLAS_OP_N; + static constexpr auto BLAS_OP_T = CUBLAS_OP_T; + static constexpr auto R_32F = CUDA_R_32F; + static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = + CUBLAS_COMPUTE_32F_FAST_TF32; + static constexpr auto BLAS_GEMM_DEFAULT = CUBLAS_GEMM_DEFAULT; + + static constexpr auto blasCreate = cublasCreate; + static constexpr auto blasSetStream = cublasSetStream; + static constexpr auto blasDestroy = cublasDestroy; + + static constexpr cublasStatus_t (*blasGemmEx)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const void*, const void*, cudaDataType_t, int, const void*, + cudaDataType_t, int, const void*, void*, cudaDataType_t, int, + cublasComputeType_t, cublasGemmAlgo_t) = cublasGemmEx; }; template <> -class Operator : public Blas { +class Operator : public Blas { public: - using Blas::Blas; + using Blas::Blas; }; } // namespace infini::ops diff --git a/src/operator.h b/src/operator.h index 33f8a42..8377351 100644 --- a/src/operator.h +++ b/src/operator.h @@ -32,7 +32,7 @@ class Operator : public OperatorBase { op_ptr = std::make_unique>( tensor, std::forward(args)...); }, - "Operator make"); + "Operator::make"); return op_ptr; } diff --git a/src/tensor.cc b/src/tensor.cc index 3a8d70b..9f2161e 100644 --- a/src/tensor.cc +++ b/src/tensor.cc @@ -26,7 +26,7 @@ const void* const& Tensor::data() const { return data_; } const Tensor::Shape& Tensor::shape() const { return shape_; } -const DataType Tensor::dtype() const { return dtype_; } +const DataType& Tensor::dtype() const { return dtype_; } const Device& Tensor::device() const { return device_; } @@ -77,7 +77,7 @@ Tensor::Strides Tensor::DefaultStrides(const Shape& shape) { std::string Tensor::ToStringHelper() 
const { if (ndim() == 0) { - return DispatchFunc( + return DispatchFunc( dtype_, [&]() { return std::to_string(*static_cast(data_)); }, "ToStringHelper"); diff --git a/src/tensor.h b/src/tensor.h index 743f908..0feb4c4 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -31,7 +31,7 @@ class Tensor { strides_{DefaultStrides(shape)} {} template - Tensor(void* data, const Shape& shape, const DataType dtype) + Tensor(void* data, const Shape& shape, const DataType& dtype) : data_{data}, shape_{shape}, dtype_{dtype}, @@ -73,7 +73,7 @@ class Tensor { const void* const& data() const; - const DataType dtype() const; + const DataType& dtype() const; const Device& device() const; From 3ab24bd499f73536942371d27861c9fdc4aa9b4f Mon Sep 17 00:00:00 2001 From: Ziminli Date: Thu, 12 Feb 2026 07:31:27 +0000 Subject: [PATCH 23/93] refactor: further simplify `blasGemmEx()`, unify comment formatting and example header file inclusion --- examples/gemm/gemm.cc | 2 +- src/common/constexpr_map.h | 4 ++-- src/data_type.h | 10 +++++----- src/device.h | 2 +- src/dispatcher.h | 10 +++++----- src/metax/gemm/mcblas.h | 8 +++----- src/nvidia/gemm/cublas.h | 8 +++----- 7 files changed, 20 insertions(+), 24 deletions(-) diff --git a/examples/gemm/gemm.cc b/examples/gemm/gemm.cc index 57a7a37..e051a8f 100644 --- a/examples/gemm/gemm.cc +++ b/examples/gemm/gemm.cc @@ -10,7 +10,7 @@ #include "metax/gemm/mcblas.h" #endif -#include "../runtime_api.h" +#include "runtime_api.h" #include "tensor.h" int main() { diff --git a/src/common/constexpr_map.h b/src/common/constexpr_map.h index 921744a..3db1275 100644 --- a/src/common/constexpr_map.h +++ b/src/common/constexpr_map.h @@ -17,9 +17,9 @@ struct ConstexprMap { for (const auto &pr : data_) { if (pr.first == key) return pr.second; } - // TODO(lzm): change to logging + // TODO(lzm): change to logging. 
assert("ConstexprMap's key is not found!"); - // Unreachable, provided to satisfy the compiler's requirement + // Unreachable, provided to satisfy the compiler's requirement. std::abort(); } diff --git a/src/data_type.h b/src/data_type.h index 1ccfce3..b9e2fea 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -102,9 +102,9 @@ DEFINE_DATA_TYPE_MAPPING(kUInt64, uint64_t) DEFINE_DATA_TYPE_MAPPING(kInt64, int64_t) DEFINE_DATA_TYPE_MAPPING(kFloat32, float) DEFINE_DATA_TYPE_MAPPING(kFloat64, double) -// TODO(lzm): support fp16 and bf16 +// TODO(lzm): Support fp16 and bf16. -// Defines the common categories of data types using List +// Defines the common categories of data types using List. using FloatTypes = List; using ReducedFloatTypes = List; using IntTypes = @@ -120,9 +120,9 @@ using BitTypes32 = using BitTypes64 = List; -using AllFloatingTypes = ConcatType; -using AllIntegralTypes = ConcatType; -using AllTypes = ConcatType; +using AllFloatTypes = ConcatType; +using AllIntTypes = ConcatType; +using AllTypes = ConcatType; } // namespace infini::ops diff --git a/src/device.h b/src/device.h index 83b4faa..cc5ab31 100644 --- a/src/device.h +++ b/src/device.h @@ -140,7 +140,7 @@ struct EnabledDeviceFilter { #endif }; -// Defines the common categories of devices using List +// Defines the common categories of devices using List. using AllDeviceTypes = List(value) << " not supported in context: " << context_str << "\n"; @@ -84,7 +84,7 @@ auto DispatchFunc(ValueType value, Functor&& func, // (Multi-Dispatch) Dispatches a vector of runtime values to a compile-time // functor. -// Base Case: All dimensions resolved +// Base Case: All dimensions resolved. template auto DispatchFunc(const std::vector& values, size_t index, Functor&& func, std::string_view context_str, List, @@ -121,10 +121,10 @@ auto DispatchFunc(const std::vector& values, size_t index, // These provide cleaner and more convenient APIs for common InfiniOps types. 
// DataType Dispatch -template +template auto DispatchFunc(DataType dtype, Functor&& func, std::string_view context_str = "", Args&&... args) { - return DispatchFunc( + return DispatchFunc( dtype, [&](Args&&... inner_args) { using T = TypeMapType
; @@ -179,7 +179,7 @@ auto DispatchFunc(std::initializer_list devices, Functor&& func, context_str, List<>{}, std::forward(args)...); } -// Interface for generic List Aliases, which unpacks a list +// Interface for generic List Aliases, which unpacks a list. template auto DispatchFunc(ValueType value, Functor&& func, diff --git a/src/metax/gemm/mcblas.h b/src/metax/gemm/mcblas.h index 9c6bb3e..659d5fc 100644 --- a/src/metax/gemm/mcblas.h +++ b/src/metax/gemm/mcblas.h @@ -26,11 +26,9 @@ struct MetaxBackend { static constexpr auto blasSetStream = mcblasSetStream; static constexpr auto blasDestroy = mcblasDestroy; - static constexpr mcblasStatus_t (*blasGemmEx)( - mcblasHandle_t, mcblasOperation_t, mcblasOperation_t, int, int, int, - const void*, const void*, macaDataType_t, int, const void*, - macaDataType_t, int, const void*, void*, macaDataType_t, int, - mcblasComputeType_t, mcblasGemmAlgo_t) = mcblasGemmEx; + static constexpr auto blasGemmEx = [](auto&&... args) { + return mcblasGemmEx(std::forward(args)...); + }; }; template <> diff --git a/src/nvidia/gemm/cublas.h b/src/nvidia/gemm/cublas.h index 6b4b016..03a585c 100644 --- a/src/nvidia/gemm/cublas.h +++ b/src/nvidia/gemm/cublas.h @@ -26,11 +26,9 @@ struct NvidiaBackend { static constexpr auto blasSetStream = cublasSetStream; static constexpr auto blasDestroy = cublasDestroy; - static constexpr cublasStatus_t (*blasGemmEx)( - cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, - const void*, const void*, cudaDataType_t, int, const void*, - cudaDataType_t, int, const void*, void*, cudaDataType_t, int, - cublasComputeType_t, cublasGemmAlgo_t) = cublasGemmEx; + static constexpr auto blasGemmEx = [](auto&&... 
args) { + return cublasGemmEx(std::forward(args)...); + }; }; template <> From 64ce1845af51d93949fa24ae44b8fec9f8f375bd Mon Sep 17 00:00:00 2001 From: Ziminli Date: Thu, 12 Feb 2026 07:52:22 +0000 Subject: [PATCH 24/93] fix: fix the typo for `cudaMemset()` in `runtime_api.h` --- examples/runtime_api.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/runtime_api.h b/examples/runtime_api.h index 4e1a8cf..e71386d 100644 --- a/examples/runtime_api.h +++ b/examples/runtime_api.h @@ -6,7 +6,7 @@ #define DEVICE_MALLOC cudaMalloc #define DEVICE_FREE cudaFree #define DEVICE_MEMCPY cudaMemcpy -#define DEVICE_MEMSET cudaSet +#define DEVICE_MEMSET cudaMemset #define DEVICE_MEMCPY_HOST_TO_DEVICE cudaMemcpyHostToDevice #define DEVICE_MEMCPY_DEVICE_TO_HOST cudaMemcpyDeviceToHost #define DEFAULT_DEVICE_TYPE Device::Type::kNvidia From 3df4832498858069d8dc8eac655ce626367ed29e Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 12 Feb 2026 16:14:38 +0800 Subject: [PATCH 25/93] feat: add `Device::ToString` --- src/device.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/device.h b/src/device.h index cc5ab31..70146bc 100644 --- a/src/device.h +++ b/src/device.h @@ -38,6 +38,10 @@ class Device { const int& index() const { return index_; } + std::string ToString() const { + return std::string{StringFromType(type_)} + ":" + std::to_string(index_); + } + private: Type type_{Type::kCpu}; From ecf030e02f55709de65bdb13bf81502e61bdbc6f Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 12 Feb 2026 16:15:17 +0800 Subject: [PATCH 26/93] feat: use `Device::ToString` in `Tensor::ToString` --- src/tensor.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tensor.cc b/src/tensor.cc index 9f2161e..2c471fb 100644 --- a/src/tensor.cc +++ b/src/tensor.cc @@ -52,7 +52,8 @@ Tensor Tensor::T() const { std::string Tensor::ToString() const { return "tensor(" + ToStringHelper() + - ", dtype=" + std::string(kDataTypeToDesc.at(dtype_)) + ")"; 
+ ", dtype=" + std::string(kDataTypeToDesc.at(dtype_)) + ", device='" + + device_.ToString() + "')"; } const DataType Tensor::DefaultDataType() { return DataType::kFloat32; } From 1f871cbebf1006c0f553d39796c451fbbab3e18c Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 12 Feb 2026 16:18:10 +0800 Subject: [PATCH 27/93] feat: use lowercase words in `Device::kDeviceToDesc` and `kDescToDevice` --- src/device.h | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/device.h b/src/device.h index 70146bc..cceab1e 100644 --- a/src/device.h +++ b/src/device.h @@ -48,31 +48,31 @@ class Device { static constexpr ConstexprMap(Device::Type::kCount)> kDeviceToDesc{{{ - {Type::kCpu, "CPU"}, - {Type::kNvidia, "NVIDIA"}, - {Type::kCambricon, "Cambricon"}, - {Type::kAscend, "Ascend"}, - {Type::kMetax, "Metax"}, - {Type::kMoore, "Moore"}, - {Type::kIluvatar, "Iluvatar"}, - {Type::kKunlun, "Kunlun"}, - {Type::kHygon, "Hygon"}, - {Type::kQy, "QY"}, + {Type::kCpu, "cpu"}, + {Type::kNvidia, "nvidia"}, + {Type::kCambricon, "cambricon"}, + {Type::kAscend, "ascend"}, + {Type::kMetax, "metax"}, + {Type::kMoore, "moore"}, + {Type::kIluvatar, "iluvatar"}, + {Type::kKunlun, "kunlun"}, + {Type::kHygon, "hygon"}, + {Type::kQy, "qy"}, }}}; static constexpr ConstexprMap(Device::Type::kCount)> kDescToDevice{{{ - {"CPU", Type::kCpu}, - {"NVIDIA", Type::kNvidia}, - {"Cambricon", Type::kCambricon}, - {"Ascend", Type::kAscend}, - {"Metax", Type::kMetax}, - {"Moore", Type::kMoore}, - {"Iluvatar", Type::kIluvatar}, - {"Kunlun", Type::kKunlun}, - {"Hygon", Type::kHygon}, - {"QY", Type::kQy}, + {"cpu", Type::kCpu}, + {"nvidia", Type::kNvidia}, + {"cambricon", Type::kCambricon}, + {"ascend", Type::kAscend}, + {"metax", Type::kMetax}, + {"moore", Type::kMoore}, + {"iluvatar", Type::kIluvatar}, + {"kunlun", Type::kKunlun}, + {"hygon", Type::kHygon}, + {"qy", Type::kQy}, }}}; int index_{0}; From 9de33b38c4beeedde6d1527de721f43eb130f9d0 Mon Sep 17 
00:00:00 2001 From: Jiacheng Huang Date: Thu, 12 Feb 2026 17:05:05 +0800 Subject: [PATCH 28/93] fix: update `scripts/generate_wrappers.py` to adapt to the latest changes --- scripts/generate_wrappers.py | 41 +++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 9fe5cc9..ad01181 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -71,7 +71,7 @@ def _generate_arguments(node): ) def _generate_tensor_caster(name): - return f'Tensor{{reinterpret_cast({name}.attr("data_ptr")().cast()), {name}.attr("shape").cast(), DataType::FromString(py::str({name}.attr("dtype")).attr("split")(".").attr("__getitem__")(-1).cast()), Device{{Device::TypeFromString({name}.attr("device").attr("type").cast()), {name}.attr("device").attr("index").is_none() ? 0 : {name}.attr("device").attr("index").cast()}}, {name}.attr("stride")().cast()}}' + return f'Tensor{{reinterpret_cast({name}.attr("data_ptr")().cast()), {name}.attr("shape").cast(), DataTypeFromString(py::str({name}.attr("dtype")).attr("split")(".").attr("__getitem__")(-1).cast()), Device{{DeviceTypeFromString({name}.attr("device").attr("type").cast()), {name}.attr("device").attr("index").is_none() ? 
0 : {name}.attr("device").attr("index").cast()}}, {name}.attr("stride")().cast()}}' op_name = operator.name @@ -104,12 +104,51 @@ def _generate_call(call, method=True): #include #include +#include + #include "base/{op_name.lower()}.h" namespace py = pybind11; namespace infini::ops {{ +inline DataType DataTypeFromString(const std::string& name) {{ + return kStringToDataType.at(name); +}} + +inline Device::Type DeviceTypeFromString(const std::string& name) {{ + static const std::unordered_map kTorchNameToTypes{{ + {{"cpu", Device::Type::kCpu}}, +#ifdef USE_NVIDIA + {{"cuda", Device::Type::kNvidia}}, +#endif +#ifdef USE_METAX + {{"cuda", Device::Type::kMetax}}, +#endif +#ifdef USE_ILUVATAR + {{"cuda", Device::Type::kIluvatar}}, +#endif +#ifdef USE_KUNLUN + {{"cuda", Device::Type::kKunlun}}, +#endif +#ifdef USE_HYGON + {{"cuda", Device::Type::kHygon}}, +#endif +#ifdef USE_QY + {{"cuda", Device::Type::kQy}}, +#endif + {{"mlu", Device::Type::kCambricon}}, {{"npu", Device::Type::kAscend}}, + {{"musa", Device::Type::kMoore}}}}; + + auto it{{kTorchNameToTypes.find(name)}}; + + if (it != kTorchNameToTypes.cend()) {{ + return it->second; + }} + + return Device::TypeFromString(name); +}} + void Bind{op_name}(py::module& m) {{ using Self = {op_name}; From 41af1dc100b4d88f732f0771bbf2ea640b92475f Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 12 Feb 2026 18:23:57 +0800 Subject: [PATCH 29/93] feat: add support for legacy c code generation --- scripts/generate_wrappers.py | 192 ++++++++++++++++++++++++++++++++++- src/base/gemm.h | 9 ++ src/cuda/gemm/blas.h | 4 + 3 files changed, 203 insertions(+), 2 deletions(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index ad01181..ca5dddf 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -9,6 +9,10 @@ _BINDINGS_DIR = _GENERATION_DIR / "bindings" +_SRC_DIR = _GENERATION_DIR / "src" + +_INCLUDE_DIR = _GENERATION_DIR / "include" + _INDENTATION = " " @@ -165,8 +169,184 @@ def 
_generate_call(call, method=True): """ +def _generate_legacy_c(operator, paths): + def _generate_source(operator): + impl_includes = "\n".join( + f'#include "{path.removeprefix("src/")}"' for path in paths + ) + + return f"""#include "../../handle.h" +#include "../../tensor.h" +#include "infiniop/ops/{operator.name.lower()}.h" +{impl_includes} + +static infini::ops::DataType DataTypeFromInfiniDType( + const infiniDtype_t& dtype) {{ + static constexpr infini::ops::ConstexprMap + kInfiniDTypeToDataType{{ + {{{{{{INFINI_DTYPE_I8, infini::ops::DataType::kInt8}}, + {{INFINI_DTYPE_I16, infini::ops::DataType::kInt16}}, + {{INFINI_DTYPE_I32, infini::ops::DataType::kInt32}}, + {{INFINI_DTYPE_I64, infini::ops::DataType::kInt64}}, + {{INFINI_DTYPE_U8, infini::ops::DataType::kUInt8}}, + {{INFINI_DTYPE_U16, infini::ops::DataType::kUInt16}}, + {{INFINI_DTYPE_U32, infini::ops::DataType::kUInt32}}, + {{INFINI_DTYPE_U64, infini::ops::DataType::kUInt64}}, + {{INFINI_DTYPE_F16, infini::ops::DataType::kFloat16}}, + {{INFINI_DTYPE_BF16, infini::ops::DataType::kBFloat16}}, + {{INFINI_DTYPE_F32, infini::ops::DataType::kFloat32}}, + {{INFINI_DTYPE_F64, infini::ops::DataType::kFloat64}}}}}}}}; + + return kInfiniDTypeToDataType.at(dtype); +}} + +static infini::ops::Device::Type DeviceTypeFromInfiniDevice( + const infiniDevice_t& device) {{ + static constexpr infini::ops::ConstexprMap< + infiniDevice_t, infini::ops::Device::Type, + static_cast(INFINI_DEVICE_TYPE_COUNT)> + kInfiniDeviceToDeviceType{{ + {{{{{{INFINI_DEVICE_CPU, infini::ops::Device::Type::kCpu}}, + {{INFINI_DEVICE_NVIDIA, infini::ops::Device::Type::kNvidia}}, + {{INFINI_DEVICE_CAMBRICON, infini::ops::Device::Type::kCambricon}}, + {{INFINI_DEVICE_ASCEND, infini::ops::Device::Type::kAscend}}, + {{INFINI_DEVICE_METAX, infini::ops::Device::Type::kMetax}}, + {{INFINI_DEVICE_MOORE, infini::ops::Device::Type::kMoore}}, + {{INFINI_DEVICE_ILUVATAR, infini::ops::Device::Type::kIluvatar}}, + {{INFINI_DEVICE_KUNLUN, 
infini::ops::Device::Type::kKunlun}}, + {{INFINI_DEVICE_HYGON, infini::ops::Device::Type::kHygon}}, + {{INFINI_DEVICE_QY, infini::ops::Device::Type::kQy}}}}}}}}; + + return kInfiniDeviceToDeviceType.at(device); +}} + +__C {_generate_create_func_def(operator)} + +__C {_generate_get_workspace_size_func_def(operator)} + +__C {_generate_call_func_def(operator)} + +__C {_generate_destroy_func_def(operator)} +""" + + def _generate_header(operator): + return f"""#ifndef __INFINIOP_{operator.name.upper()}_API_H__ +#define __INFINIOP_{operator.name.upper()}_API_H__ + +#include "base/{operator.name.lower()}.h" + +typedef struct infini::ops::Operator *infiniop{operator.name}Descriptor_t; + +__C __export {_generate_create_func_decl(operator)}; + +__C __export {_generate_get_workspace_size_func_decl(operator)}; + +__C __export {_generate_call_func_decl(operator)}; + +__C __export {_generate_destroy_func_decl(operator)}; + +#endif +""" + + def _generate_create_func_def(operator): + name = operator.name + constructor = operator.constructors[-1] + + return f"""{_generate_create_func_decl(operator)} {{ + *desc_ptr = infini::ops::Operator::make({_generate_arguments(constructor)}).release(); + + return INFINI_STATUS_SUCCESS; +}}""" + + def _generate_get_workspace_size_func_def(operator): + return f"""{_generate_get_workspace_size_func_decl(operator)} {{ + *size = 0; // desc->workspace_size(); + + return INFINI_STATUS_SUCCESS; +}}""" + + def _generate_call_func_def(operator): + call = operator.calls[-1] + + return f"""{_generate_call_func_decl(operator)} {{ + (*desc)(stream, {_generate_arguments(call, is_data=True)}); + + return INFINI_STATUS_SUCCESS; +}}""" + + def _generate_destroy_func_def(operator): + return f"""{_generate_destroy_func_decl(operator)} {{ + delete desc; + + return INFINI_STATUS_SUCCESS; +}}""" + + def _generate_create_func_decl(operator): + name = operator.name + constructor = operator.constructors[-1] + params = _generate_params(constructor) + + return 
f"infiniStatus_t infiniopCreate{name}Descriptor(infiniopHandle_t handle, infiniop{name}Descriptor_t *desc_ptr, {params})" + + def _generate_get_workspace_size_func_decl(operator): + name = operator.name + + return f"infiniStatus_t infiniopGet{name}WorkspaceSize(infiniop{name}Descriptor_t desc, size_t *size)" + + def _generate_call_func_decl(operator): + name = operator.name + call = operator.calls[-1] + params = _generate_params(call, call=True) + params = params.replace("void * stream, ", "") + + return f"infiniStatus_t infiniop{name}(infiniop{name}Descriptor_t desc, void *workspace, size_t workspace_size, {params}, void *stream)" + + def _generate_destroy_func_decl(operator): + name = operator.name + + return f"infiniStatus_t infiniopDestroy{name}Descriptor(infiniop{name}Descriptor_t desc)" + + def _generate_params(node, call=False): + arguments = tuple(node.get_arguments()) + + arguments = (arguments[-1], *arguments[:-1]) + + def _handle_tensor(spelling): + if call: + return spelling.replace("Tensor", "void *") + return spelling.replace("Tensor", "infiniopTensorDescriptor_t") + + def _handle_std_optional(spelling): + return spelling.replace("std::optional<", "").replace(">", "") + + return ", ".join( + f"{_handle_std_optional(_handle_tensor(arg.type.spelling))} {arg.spelling}" + for arg in arguments + ) + + def _generate_arguments(node, is_data=False): + return ", ".join( + _generate_tensor_caster(arg.spelling, is_data=is_data) + if "Tensor" in arg.type.spelling + else arg.spelling + for arg in node.get_arguments() + if arg.spelling != "handle" and arg.spelling != "stream" + ) + + def _generate_tensor_caster(name, is_data=False): + if is_data: + return f"infini::ops::Tensor(const_cast({name}), infini::ops::Tensor::Shape{{}})" + + return f"infini::ops::Tensor{{nullptr, {name}->shape(), DataTypeFromInfiniDType({name}->dtype()), infini::ops::Device{{DeviceTypeFromInfiniDevice(handle->device), handle->device_id}}, {name}->strides()}}" + + return 
_generate_source(operator), _generate_header(operator) + + if __name__ == "__main__": _BINDINGS_DIR.mkdir(parents=True, exist_ok=True) + _SRC_DIR.mkdir(parents=True, exist_ok=True) + _INCLUDE_DIR.mkdir(parents=True, exist_ok=True) with open("ops.json") as f: ops = json.load(f) @@ -174,20 +354,28 @@ def _generate_call(call, method=True): header_paths = [] bind_func_names = [] - for op_name in ops: + for op_name, impl_paths in ops.items(): extractor = _OperatorExtractor() operator = extractor(op_name) + source_path = _SRC_DIR / op_name.lower() header_name = f"{op_name.lower()}.h" bind_func_name = f"Bind{op_name}" (_BINDINGS_DIR / header_name).write_text(_generate_pybind11(operator)) + legacy_c_source, legacy_c_header = _generate_legacy_c(operator, impl_paths) + source_path.mkdir(exist_ok=True) + (_SRC_DIR / op_name.lower() / "operator.cc").write_text(legacy_c_source) + (_INCLUDE_DIR / header_name).write_text(legacy_c_header) + header_paths.append(header_name) bind_func_names.append(bind_func_name) impl_includes = "\n".join( - f'#include "{header_path}"' for header_path in ops.values() + f'#include "{impl_path}"' + for impl_paths in ops.values() + for impl_path in impl_paths ) op_includes = "\n".join(f'#include "{header_path}"' for header_path in header_paths) bind_func_calls = "\n".join( diff --git a/src/base/gemm.h b/src/base/gemm.h index e6ba2e5..1258844 100644 --- a/src/base/gemm.h +++ b/src/base/gemm.h @@ -31,6 +31,9 @@ class Gemm : public Operator { // TODO: Check constraints. 
} + Gemm(const Tensor a, const Tensor b, Tensor c) + : Gemm{a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt, c} {} + virtual void operator()(void* stream, const Tensor a, const Tensor b, std::optional alpha, std::optional beta, std::optional trans_a, @@ -42,6 +45,12 @@ class Gemm : public Operator { std::nullopt, c); } + virtual void operator()(void* stream, const Tensor a, const Tensor b, + std::optional alpha, std::optional beta, + Tensor c) const { + return operator()(stream, a, b, alpha, beta, std::nullopt, std::nullopt, c); + } + protected: float alpha_{1.0}; diff --git a/src/cuda/gemm/blas.h b/src/cuda/gemm/blas.h index 7d8e229..b15e689 100644 --- a/src/cuda/gemm/blas.h +++ b/src/cuda/gemm/blas.h @@ -27,6 +27,10 @@ class Blas : public Gemm { // TODO: Check constraints. } + Blas(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, Tensor c) + : Blas{a, b, alpha, beta, std::nullopt, std::nullopt, c} {} + Blas(const Tensor a, const Tensor b, Tensor c) : Blas{a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt, c} {} From 632dea2471a9c070e026c99f25836de79aa4e247 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 12 Feb 2026 18:26:01 +0800 Subject: [PATCH 30/93] fix: remove unintended white space in `DeviceTypeFromString` --- scripts/generate_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index ca5dddf..71ca0f6 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -143,7 +143,7 @@ def _generate_call(call, method=True): #endif {{"mlu", Device::Type::kCambricon}}, {{"npu", Device::Type::kAscend}}, {{"musa", Device::Type::kMoore}}}}; - + auto it{{kTorchNameToTypes.find(name)}}; if (it != kTorchNameToTypes.cend()) {{ From 94a9cf213c89c663da8e42464af78beb33c4134b Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 12 Feb 2026 18:28:20 +0800 Subject: [PATCH 31/93] fix: use `op_name.lower()` in 
`_generate_call` --- scripts/generate_wrappers.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 71ca0f6..e5ed541 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -86,11 +86,11 @@ def _generate_init(constructor): return std::unique_ptr{{static_cast(Self::make({_generate_arguments(constructor)}).release())}}; }}))""" - def _generate_call(call, method=True): + def _generate_call(op_name, call, method=True): call_params = _generate_params(call) if not method: - return f""" m.def("gemm", []({call_params}) {{ return Self::call({_generate_arguments(call)}); }});""" + return f""" m.def("{op_name.lower()}", []({call_params}) {{ return Self::call({_generate_arguments(call)}); }});""" return f""" .def("__call__", [](const Self& self, {call_params}) {{ return static_cast&>(self)({_generate_arguments(call)}); @@ -99,8 +99,10 @@ def _generate_call(call, method=True): inits = "\n".join( _generate_init(constructor) for constructor in operator.constructors ) - calls = "\n".join(_generate_call(call) for call in operator.calls) - callers = "\n".join(_generate_call(call, method=False) for call in operator.calls) + calls = "\n".join(_generate_call(operator.name, call) for call in operator.calls) + callers = "\n".join( + _generate_call(operator.name, call, method=False) for call in operator.calls + ) return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_ #define INFINI_OPS_BINDINGS_{op_name.upper()}_H_ From 186071aa14bef4b5d36c983b780a084e77ef2caf Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 12 Feb 2026 19:17:18 +0800 Subject: [PATCH 32/93] fix: add a constructor to `Operator` to support all `operator()` --- src/cpu/gemm/gemm.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/cpu/gemm/gemm.h b/src/cpu/gemm/gemm.h index 32c123f..e61a174 100644 --- a/src/cpu/gemm/gemm.h +++ b/src/cpu/gemm/gemm.h @@ -21,6 +21,10 @@ class Operator : public 
Gemm { : Operator{a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt, c} {} + Operator(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, Tensor c) + : Operator{a, b, alpha, beta, std::nullopt, std::nullopt, c} {} + void operator()(void* stream, const Tensor a, const Tensor b, std::optional alpha, std::optional beta, std::optional trans_a, std::optional trans_b, From d6b725be9556e0ba45e253eda9a7a3fb984cc221 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 12 Feb 2026 19:19:13 +0800 Subject: [PATCH 33/93] feat: add operator searching to `scripts/generate_wrappers.py` --- scripts/generate_wrappers.py | 58 +++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 7 deletions(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index e5ed541..66d7efd 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -1,3 +1,4 @@ +import argparse import json import pathlib import textwrap @@ -5,11 +6,15 @@ import clang.cindex from clang.cindex import CursorKind +_SRC_DIR = pathlib.Path("src") + +_BASE_DIR = _SRC_DIR / "base" + _GENERATION_DIR = pathlib.Path("generated") _BINDINGS_DIR = _GENERATION_DIR / "bindings" -_SRC_DIR = _GENERATION_DIR / "src" +_GENERATED_SRC_DIR = _GENERATION_DIR / "src" _INCLUDE_DIR = _GENERATION_DIR / "include" @@ -174,7 +179,7 @@ def _generate_call(op_name, call, method=True): def _generate_legacy_c(operator, paths): def _generate_source(operator): impl_includes = "\n".join( - f'#include "{path.removeprefix("src/")}"' for path in paths + f'#include "{str(path).removeprefix("src/")}"' for path in paths ) return f"""#include "../../handle.h" @@ -345,13 +350,50 @@ def _generate_tensor_caster(name, is_data=False): return _generate_source(operator), _generate_header(operator) +def _get_all_ops(devices): + ops = {} + + for file_path in _BASE_DIR.iterdir(): + if not file_path.is_file(): + continue + + op_name = "".join(word.capitalize() for word in 
file_path.stem.split("_")) + + ops[op_name] = [] + + for file_path in _SRC_DIR.rglob("*"): + if not file_path.is_file() or file_path.parent.parent.name not in devices: + continue + + if f"class Operator<{op_name}" in file_path.read_text(): + ops[op_name].append(file_path) + + return ops + + if __name__ == "__main__": + parser = argparse.ArgumentParser(description="An automatic wrapper generator.") + + parser.add_argument( + "--devices", + nargs="+", + default="cpu", + type=str, + help="Devices to use. Please pick from cpu, nvidia, cambricon, ascend, metax, moore, iluvatar, kunlun, hygon, and qy. (default: cpu)", + ) + + args = parser.parse_args() + _BINDINGS_DIR.mkdir(parents=True, exist_ok=True) - _SRC_DIR.mkdir(parents=True, exist_ok=True) + _GENERATED_SRC_DIR.mkdir(parents=True, exist_ok=True) _INCLUDE_DIR.mkdir(parents=True, exist_ok=True) - with open("ops.json") as f: - ops = json.load(f) + ops_json = pathlib.Path("ops.json") + + if ops_json.exists(): + ops = json.loads(ops_json.read_text()) + else: + ops = _get_all_ops(args.devices) header_paths = [] bind_func_names = [] @@ -360,7 +402,7 @@ def _generate_tensor_caster(name, is_data=False): extractor = _OperatorExtractor() operator = extractor(op_name) - source_path = _SRC_DIR / op_name.lower() + source_path = _GENERATED_SRC_DIR / op_name.lower() header_name = f"{op_name.lower()}.h" bind_func_name = f"Bind{op_name}" @@ -368,7 +410,9 @@ def _generate_tensor_caster(name, is_data=False): legacy_c_source, legacy_c_header = _generate_legacy_c(operator, impl_paths) source_path.mkdir(exist_ok=True) - (_SRC_DIR / op_name.lower() / "operator.cc").write_text(legacy_c_source) + (_GENERATED_SRC_DIR / op_name.lower() / "operator.cc").write_text( + legacy_c_source + ) (_INCLUDE_DIR / header_name).write_text(legacy_c_header) header_paths.append(header_name) From 87de39780056330a205f7efa0cc56f3f132def8b Mon Sep 17 00:00:00 2001 From: Ziminli <70735843+Ziminli@users.noreply.github.com> Date: Fri, 13 Feb 2026 20:57:47 +0800 
Subject: [PATCH 34/93] build: add CMake build system and README (#2) * build: add CMake build system and `README.md` * fix: add cudart and cuda driver's linking * fix: update `README.md` to remove unnecessary information and fix some formatting issues * fix: move `build/` to an earlier position in `.gitignore` --------- Co-authored-by: Jiacheng Huang --- .gitignore | 1 + CMakeLists.txt | 40 +++++++++++++++++++++++++ README.md | 64 +++++++++++++++++++++++++++++++++++++++ examples/CMakeLists.txt | 16 ++++++++++ examples/data_type.cc | 11 ++++--- src/CMakeLists.txt | 66 +++++++++++++++++++++++++++++++++++++++++ 6 files changed, 194 insertions(+), 4 deletions(-) create mode 100644 CMakeLists.txt create mode 100644 README.md create mode 100644 examples/CMakeLists.txt create mode 100644 src/CMakeLists.txt diff --git a/.gitignore b/.gitignore index 80b21fd..99bca7d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ # Generated files +build/ generated/ # Prerequisites diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..5a264e1 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,40 @@ +cmake_minimum_required(VERSION 3.18) +project(InfiniOps LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# Options for backends. +option(USE_CPU "Enable CPU backend" OFF) +option(USE_NVIDIA "Enable CUDA backend" OFF) +option(USE_METAX "Enable MetaX backend" OFF) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) + +if(USE_NVIDIA) + add_compile_definitions(USE_NVIDIA=1) + enable_language(CUDA) + find_package(CUDAToolkit REQUIRED) +elseif(USE_METAX) + add_compile_definitions(USE_METAX=1) + + # Normally can be found at: `/opt/maca/`. + set(MACA_PATH $ENV{MACA_PATH}) + set(CMAKE_C_COMPILER ${MACA_PATH}/mxgpu_llvm/bin/mxcc) + set(CMAKE_CXX_COMPILER ${MACA_PATH}/mxgpu_llvm/bin/mxcc) + + include_directories("${MACA_PATH}/include") + link_directories("${MACA_PATH}/lib") + + # Libraries: mcruntime / mcdnn / mcblas. 
+ find_library(MACA_RUNTIME_LIB NAMES mcruntime HINTS "${MACA_PATH}/lib" REQUIRED) + find_library(MACA_DNN_LIB NAMES mcdnn HINTS "${MACA_PATH}/lib" REQUIRED) + find_library(MACA_BLAS_LIB NAMES mcblas HINTS "${MACA_PATH}/lib" REQUIRED) +# If all other platforms are not enabled, CPU is enabled by default. +else() + add_compile_definitions(USE_CPU=1) +endif() + +add_subdirectory(src) + +add_subdirectory(examples) diff --git a/README.md b/README.md new file mode 100644 index 0000000..0c8eb10 --- /dev/null +++ b/README.md @@ -0,0 +1,64 @@ +# InfiniOps + +InfiniOps is a high-performance, hardware-agnostic operator library. + +## 🛠️ Prerequisites + +Ensure your environment meets the following requirements based on your target backend: + + - C++17 compatible compiler + - CMake 3.18+ + - Hardware-specific SDKs (e.g., CUDA Toolkit) + +--- + +## ⚙️ Installation & Building + +InfiniOps uses CMake to manage backends. + +### 1. Setup Environment + +Ensure you have the corresponding SDK installed and environment variables set up for the platform/accelerator you are working on. + +### 2. Configure and Build + +Using these commands at the root directory of this project: + +```bash +mkdir build && cd build + +cmake .. + +make -j$(nproc) +``` + +For the ``: + +| Option | Functionality | 默认值 +|-----------------------------|-----------------------------------|:-: +| `-DUSE_CPU=[ON\|OFF]` | Compile the CPU implementation | n* +| `-DUSE_NVIDIA=[ON\|OFF]` | Compile the NVIDIA implementation | n +| `-DUSE_METAX=[ON\|OFF]` | Compile the MetaX implementation | n + +*Note: If no accelerator options are provided, `USE_CPU` is enabled by default.* + +## 🚀 Running Examples +After a successful build, the executables are located in the `build/examples` directory. 
+ +Run the GEMM example: + +```bash +./examples/gemm +``` + +Run the data_type example: + +```bash +./examples/data_type +``` + +Run the tensor example: + +```bash +./examples/tensor +``` diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt new file mode 100644 index 0000000..68ebc1b --- /dev/null +++ b/examples/CMakeLists.txt @@ -0,0 +1,16 @@ +file(GLOB_RECURSE EXAMPLE_SOURCES CONFIGURE_DEPENDS "*.cc") + +# Iterate through each file and create an executable. +foreach(source_file ${EXAMPLE_SOURCES}) + get_filename_component(example_name ${source_file} NAME_WE) + + add_executable(${example_name} ${source_file}) + + target_link_libraries(${example_name} PRIVATE infiniops) + + target_include_directories(${example_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + + get_filename_component(example_dir ${source_file} DIRECTORY) + + target_include_directories(${example_name} PRIVATE ${example_dir}) +endforeach() diff --git a/examples/data_type.cc b/examples/data_type.cc index 0b9d010..f937123 100644 --- a/examples/data_type.cc +++ b/examples/data_type.cc @@ -10,15 +10,18 @@ int main() { using namespace infini::ops; static const std::vector kDataTypes{ - kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, - kUInt32, kUInt64, kFloat16, kBFloat16, kFloat32, kFloat64}; + DataType::kInt8, DataType::kInt16, DataType::kInt32, + DataType::kInt64, DataType::kUInt8, DataType::kUInt16, + DataType::kUInt32, DataType::kUInt64, DataType::kFloat16, + DataType::kBFloat16, DataType::kFloat32, DataType::kFloat64}; std::cout << std::left << std::setw(10) << "Name" << std::left << std::setw(10) << "Element Size\n"; for (const auto& dtype : kDataTypes) { - std::cout << std::left << std::setw(10) << dtype.name() << std::left - << std::setw(10) << dtype.element_size() << '\n'; + std::cout << std::left << std::setw(10) << kDataTypeToDesc.at(dtype) + << std::left << std::setw(10) << kDataTypeToSize.at(dtype) + << '\n'; } return 0; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 
100644 index 0000000..777059f --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,66 @@ +add_library(infiniops SHARED) + +file(GLOB BASE_SRCS CONFIGURE_DEPENDS "*.cc") +target_sources(infiniops PRIVATE ${BASE_SRCS}) + +if(USE_CPU) + set(CPU_PATTERNS + "cpu/*.cc" + "cpu/*.cpp" + ) + + file(GLOB_RECURSE CPU_SOURCES CONFIGURE_DEPENDS ${CPU_PATTERNS}) + list(APPEND CORE_SOURCES ${CPU_SOURCES}) + + target_compile_definitions(infiniops PRIVATE USE_CPU=1) + + # Reserve for OpenMP. + # find_package(OpenMP REQUIRED) + # target_link_libraries(infiniops PRIVATE OpenMP::OpenMP_CXX) +endif() + +if(USE_NVIDIA) + set(NVIDIA_PATTERNS + "cuda/*.cc" + "cuda/*.cpp" + "cuda/*.cu" + "nvidia/*.cc" + "nvidia/*.cpp" + "nvidia/*.cu" + ) + + file(GLOB_RECURSE NVIDIA_SOURCES CONFIGURE_DEPENDS ${NVIDIA_PATTERNS}) + + enable_language(CUDA) + + target_compile_definitions(infiniops PRIVATE USE_NVIDIA=1) + target_sources(infiniops PRIVATE ${NVIDIA_SOURCES}) + + find_package(CUDAToolkit REQUIRED) + target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver) +endif() + +if(USE_METAX) + set(METAX_PATTERNS + "cuda/*.cc" + "cuda/*.cpp" + "metax/*.cc" + "metax/*.maca" + ) + + file(GLOB_RECURSE METAX_SOURCES CONFIGURE_DEPENDS ${METAX_PATTERNS}) + + set_source_files_properties(${METAX_SOURCES} PROPERTIES LANGUAGE CXX) + + target_compile_definitions(infiniops PRIVATE USE_METAX=1) + target_sources(infiniops PRIVATE ${METAX_SOURCES}) + + target_include_directories(infiniops PUBLIC "${MACA_PATH}/include") + target_link_libraries(infiniops PUBLIC + ${MACA_RUNTIME_LIB} + ${MACA_DNN_LIB} + ${MACA_BLAS_LIB} + ) +endif() + +target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) From 2a5ab4fbce33d32f2641849cb18b482d02b48ed6 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 13 Feb 2026 21:01:58 +0800 Subject: [PATCH 35/93] fix: remove the `*` after `n` --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 
0c8eb10..cbba3ab 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ For the ``: | Option | Functionality | 默认值 |-----------------------------|-----------------------------------|:-: -| `-DUSE_CPU=[ON\|OFF]` | Compile the CPU implementation | n* +| `-DUSE_CPU=[ON\|OFF]` | Compile the CPU implementation | n | `-DUSE_NVIDIA=[ON\|OFF]` | Compile the NVIDIA implementation | n | `-DUSE_METAX=[ON\|OFF]` | Compile the MetaX implementation | n From b5b613643f4af899385f52aa2dca1dbeb3d7c529 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 25 Feb 2026 11:01:05 +0800 Subject: [PATCH 36/93] build: rename `USE_` options to `WITH_` for backend selection --- CMakeLists.txt | 16 ++++++++-------- README.md | 8 ++++---- examples/gemm/gemm.cc | 10 +++++----- examples/runtime_api.h | 6 +++--- scripts/generate_wrappers.py | 12 ++++++------ src/CMakeLists.txt | 12 ++++++------ src/device.h | 20 ++++++++++---------- 7 files changed, 42 insertions(+), 42 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5a264e1..283bd6f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,18 +5,18 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) # Options for backends. -option(USE_CPU "Enable CPU backend" OFF) -option(USE_NVIDIA "Enable CUDA backend" OFF) -option(USE_METAX "Enable MetaX backend" OFF) +option(WITH_CPU "Enable CPU backend" OFF) +option(WITH_NVIDIA "Enable CUDA backend" OFF) +option(WITH_METAX "Enable MetaX backend" OFF) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) -if(USE_NVIDIA) - add_compile_definitions(USE_NVIDIA=1) +if(WITH_NVIDIA) + add_compile_definitions(WITH_NVIDIA=1) enable_language(CUDA) find_package(CUDAToolkit REQUIRED) -elseif(USE_METAX) - add_compile_definitions(USE_METAX=1) +elseif(WITH_METAX) + add_compile_definitions(WITH_METAX=1) # Normally can be found at: `/opt/maca/`. 
set(MACA_PATH $ENV{MACA_PATH}) @@ -32,7 +32,7 @@ elseif(USE_METAX) find_library(MACA_BLAS_LIB NAMES mcblas HINTS "${MACA_PATH}/lib" REQUIRED) # If all other platforms are not enabled, CPU is enabled by default. else() - add_compile_definitions(USE_CPU=1) + add_compile_definitions(WITH_CPU=1) endif() add_subdirectory(src) diff --git a/README.md b/README.md index cbba3ab..72803d0 100644 --- a/README.md +++ b/README.md @@ -36,11 +36,11 @@ For the ``: | Option | Functionality | 默认值 |-----------------------------|-----------------------------------|:-: -| `-DUSE_CPU=[ON\|OFF]` | Compile the CPU implementation | n -| `-DUSE_NVIDIA=[ON\|OFF]` | Compile the NVIDIA implementation | n -| `-DUSE_METAX=[ON\|OFF]` | Compile the MetaX implementation | n +| `-DWITH_CPU=[ON\|OFF]` | Compile the CPU implementation | n +| `-DWITH_NVIDIA=[ON\|OFF]` | Compile the NVIDIA implementation | n +| `-DWITH_METAX=[ON\|OFF]` | Compile the MetaX implementation | n -*Note: If no accelerator options are provided, `USE_CPU` is enabled by default.* +*Note: If no accelerator options are provided, `WITH_CPU` is enabled by default.* ## 🚀 Running Examples After a successful build, the executables are located in the `build/examples` directory. 
diff --git a/examples/gemm/gemm.cc b/examples/gemm/gemm.cc index e051a8f..7622afa 100644 --- a/examples/gemm/gemm.cc +++ b/examples/gemm/gemm.cc @@ -2,11 +2,11 @@ #include #include -#ifdef USE_CPU +#ifdef WITH_CPU #include "cpu/gemm/gemm.h" -#elif USE_NVIDIA +#elif WITH_NVIDIA #include "nvidia/gemm/cublas.h" -#elif USE_METAX +#elif WITH_METAX #include "metax/gemm/mcblas.h" #endif @@ -50,7 +50,7 @@ int main() { void *a_ptr, *b_ptr, *c_ptr; -#ifdef USE_CPU +#ifdef WITH_CPU a_ptr = a_vec.data(); b_ptr = b_vec.data(); c_ptr = c_vec.data(); @@ -73,7 +73,7 @@ int main() { Gemm::call(nullptr, a_device, b_device, c_device); -#ifndef USE_CPU +#ifndef WITH_CPU DEVICE_MEMCPY(c_vec.data(), c_ptr, c_size, DEVICE_MEMCPY_DEVICE_TO_HOST); DEVICE_FREE(a_ptr); DEVICE_FREE(b_ptr); diff --git a/examples/runtime_api.h b/examples/runtime_api.h index e71386d..021f24c 100644 --- a/examples/runtime_api.h +++ b/examples/runtime_api.h @@ -1,7 +1,7 @@ #ifndef INFINI_OPS_EXAMPLES_RUNTIME_API_H_ #define INFINI_OPS_EXAMPLES_RUNTIME_API_H_ -#ifdef USE_NVIDIA +#ifdef WITH_NVIDIA #include #define DEVICE_MALLOC cudaMalloc #define DEVICE_FREE cudaFree @@ -10,7 +10,7 @@ #define DEVICE_MEMCPY_HOST_TO_DEVICE cudaMemcpyHostToDevice #define DEVICE_MEMCPY_DEVICE_TO_HOST cudaMemcpyDeviceToHost #define DEFAULT_DEVICE_TYPE Device::Type::kNvidia -#elif USE_METAX +#elif WITH_METAX #include #define DEVICE_MALLOC mcMalloc #define DEVICE_FREE mcFree @@ -19,7 +19,7 @@ #define DEVICE_MEMCPY_HOST_TO_DEVICE mcMemcpyHostToDevice #define DEVICE_MEMCPY_DEVICE_TO_HOST mcMemcpyDeviceToHost #define DEFAULT_DEVICE_TYPE Device::Type::kMetax -#elif USE_CPU +#elif WITH_CPU #define DEFAULT_DEVICE_TYPE Device::Type::kCpu #endif diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 66d7efd..5755df1 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -130,22 +130,22 @@ def _generate_call(op_name, call, method=True): inline Device::Type DeviceTypeFromString(const std::string& 
name) {{ static const std::unordered_map kTorchNameToTypes{{ {{"cpu", Device::Type::kCpu}}, -#ifdef USE_NVIDIA +#ifdef WITH_NVIDIA {{"cuda", Device::Type::kNvidia}}, #endif -#ifdef USE_METAX +#ifdef WITH_METAX {{"cuda", Device::Type::kMetax}}, #endif -#ifdef USE_ILUVATAR +#ifdef WITH_ILUVATAR {{"cuda", Device::Type::kIluvatar}}, #endif -#ifdef USE_KUNLUN +#ifdef WITH_KUNLUN {{"cuda", Device::Type::kKunlun}}, #endif -#ifdef USE_HYGON +#ifdef WITH_HYGON {{"cuda", Device::Type::kHygon}}, #endif -#ifdef USE_QY +#ifdef WITH_QY {{"cuda", Device::Type::kQy}}, #endif {{"mlu", Device::Type::kCambricon}}, {{"npu", Device::Type::kAscend}}, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 777059f..a4937eb 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,7 +3,7 @@ add_library(infiniops SHARED) file(GLOB BASE_SRCS CONFIGURE_DEPENDS "*.cc") target_sources(infiniops PRIVATE ${BASE_SRCS}) -if(USE_CPU) +if(WITH_CPU) set(CPU_PATTERNS "cpu/*.cc" "cpu/*.cpp" @@ -12,14 +12,14 @@ if(USE_CPU) file(GLOB_RECURSE CPU_SOURCES CONFIGURE_DEPENDS ${CPU_PATTERNS}) list(APPEND CORE_SOURCES ${CPU_SOURCES}) - target_compile_definitions(infiniops PRIVATE USE_CPU=1) + target_compile_definitions(infiniops PRIVATE WITH_CPU=1) # Reserve for OpenMP. 
# find_package(OpenMP REQUIRED) # target_link_libraries(infiniops PRIVATE OpenMP::OpenMP_CXX) endif() -if(USE_NVIDIA) +if(WITH_NVIDIA) set(NVIDIA_PATTERNS "cuda/*.cc" "cuda/*.cpp" @@ -33,14 +33,14 @@ if(USE_NVIDIA) enable_language(CUDA) - target_compile_definitions(infiniops PRIVATE USE_NVIDIA=1) + target_compile_definitions(infiniops PRIVATE WITH_NVIDIA=1) target_sources(infiniops PRIVATE ${NVIDIA_SOURCES}) find_package(CUDAToolkit REQUIRED) target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver) endif() -if(USE_METAX) +if(WITH_METAX) set(METAX_PATTERNS "cuda/*.cc" "cuda/*.cpp" @@ -52,7 +52,7 @@ if(USE_METAX) set_source_files_properties(${METAX_SOURCES} PROPERTIES LANGUAGE CXX) - target_compile_definitions(infiniops PRIVATE USE_METAX=1) + target_compile_definitions(infiniops PRIVATE WITH_METAX=1) target_sources(infiniops PRIVATE ${METAX_SOURCES}) target_include_directories(infiniops PUBLIC "${MACA_PATH}/include") diff --git a/src/device.h b/src/device.h index cceab1e..61a8d7e 100644 --- a/src/device.h +++ b/src/device.h @@ -83,61 +83,61 @@ struct EnabledDeviceFilter { // Device. If the macro is NOT defined, the specialization is not compiled, // and FilterList will exclude it from ActiveDevices. 
-#ifdef USE_CPU +#ifdef WITH_CPU template = 0> void operator()() const {} #endif -#ifdef USE_NVIDIA +#ifdef WITH_NVIDIA template = 0> void operator()() const {} #endif -#ifdef USE_CAMBRICON +#ifdef WITH_CAMBRICON template = 0> void operator()() const {} #endif -#ifdef USE_ASCEND +#ifdef WITH_ASCEND template = 0> void operator()() const {} #endif -#ifdef USE_METAX +#ifdef WITH_METAX template = 0> void operator()() const {} #endif -#ifdef USE_MOORE +#ifdef WITH_MOORE template = 0> void operator()() const {} #endif -#ifdef USE_ILUVATAR +#ifdef WITH_ILUVATAR template = 0> void operator()() const {} #endif -#ifdef USE_KUNLUN +#ifdef WITH_KUNLUN template = 0> void operator()() const {} #endif -#ifdef USE_HYGON +#ifdef WITH_HYGON template = 0> void operator()() const {} #endif -#ifdef USE_QY +#ifdef WITH_QY template = 0> void operator()() const {} From e5b3aeaa07ad18e49c61e2b069ead2d272f624de Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 25 Feb 2026 14:53:40 +0800 Subject: [PATCH 37/93] build: add pybind11 support to generate python bindings --- examples/gemm/gemm.cc | 6 ++++-- src/CMakeLists.txt | 34 ++++++++++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/examples/gemm/gemm.cc b/examples/gemm/gemm.cc index 7622afa..71891ef 100644 --- a/examples/gemm/gemm.cc +++ b/examples/gemm/gemm.cc @@ -4,9 +4,11 @@ #ifdef WITH_CPU #include "cpu/gemm/gemm.h" -#elif WITH_NVIDIA +#endif +#if WITH_NVIDIA #include "nvidia/gemm/cublas.h" -#elif WITH_METAX +#endif +#if WITH_METAX #include "metax/gemm/mcblas.h" #endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a4937eb..620c506 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,6 +3,8 @@ add_library(infiniops SHARED) file(GLOB BASE_SRCS CONFIGURE_DEPENDS "*.cc") target_sources(infiniops PRIVATE ${BASE_SRCS}) +set(DEVICE_LIST "") + if(WITH_CPU) set(CPU_PATTERNS "cpu/*.cc" @@ -12,11 +14,13 @@ if(WITH_CPU) file(GLOB_RECURSE CPU_SOURCES CONFIGURE_DEPENDS 
${CPU_PATTERNS}) list(APPEND CORE_SOURCES ${CPU_SOURCES}) - target_compile_definitions(infiniops PRIVATE WITH_CPU=1) + target_compile_definitions(infiniops PUBLIC WITH_CPU=1) # Reserve for OpenMP. # find_package(OpenMP REQUIRED) # target_link_libraries(infiniops PRIVATE OpenMP::OpenMP_CXX) + + list(APPEND DEVICE_LIST "cpu") endif() if(WITH_NVIDIA) @@ -33,11 +37,13 @@ if(WITH_NVIDIA) enable_language(CUDA) - target_compile_definitions(infiniops PRIVATE WITH_NVIDIA=1) + target_compile_definitions(infiniops PUBLIC WITH_NVIDIA=1) target_sources(infiniops PRIVATE ${NVIDIA_SOURCES}) find_package(CUDAToolkit REQUIRED) target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver) + + list(APPEND DEVICE_LIST "nvidia") endif() if(WITH_METAX) @@ -61,6 +67,30 @@ if(WITH_METAX) ${MACA_DNN_LIB} ${MACA_BLAS_LIB} ) + + list(APPEND DEVICE_LIST "metax") endif() target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +if(GENERATE_PYTHON_BINDINGS) + execute_process( + COMMAND python ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py --devices ${DEVICE_LIST} + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + RESULT_VARIABLE script_result + ) + + if(NOT script_result EQUAL 0) + message(FATAL_ERROR "Generating wrappers - failed") + else() + message(STATUS "Generating wrappers - done") + endif() + + find_package(Python COMPONENTS Interpreter Development) + find_package(pybind11 CONFIG) + + pybind11_add_module(ops "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc") + + target_include_directories(ops PRIVATE ${PROJECT_SOURCE_DIR}) + target_link_libraries(ops PRIVATE infiniops) +endif() From eea1bdbd160412ee4ed772dfb5d3adafdb95f3e2 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 25 Feb 2026 15:11:11 +0800 Subject: [PATCH 38/93] build: add `GENERATE_PYTHON_BINDINGS` option to `CMakeLists.txt` --- CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 283bd6f..cc671b9 100644 --- a/CMakeLists.txt +++ 
b/CMakeLists.txt @@ -9,6 +9,8 @@ option(WITH_CPU "Enable CPU backend" OFF) option(WITH_NVIDIA "Enable CUDA backend" OFF) option(WITH_METAX "Enable MetaX backend" OFF) +option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) if(WITH_NVIDIA) From ee7999a6a21f28528929aadb6c6eb53a4d2dbbbd Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 25 Feb 2026 15:12:52 +0800 Subject: [PATCH 39/93] docs: document `GENERATE_PYTHON_BINDINGS` in `README.md` --- README.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 72803d0..e0befd6 100644 --- a/README.md +++ b/README.md @@ -34,11 +34,12 @@ make -j$(nproc) For the ``: -| Option | Functionality | 默认值 -|-----------------------------|-----------------------------------|:-: -| `-DWITH_CPU=[ON\|OFF]` | Compile the CPU implementation | n -| `-DWITH_NVIDIA=[ON\|OFF]` | Compile the NVIDIA implementation | n -| `-DWITH_METAX=[ON\|OFF]` | Compile the MetaX implementation | n +| Option | Functionality | Default +|----------------------------------------|------------------------------------|:-: +| `-DWITH_CPU=[ON\|OFF]` | Compile the CPU implementation | n +| `-DWITH_NVIDIA=[ON\|OFF]` | Compile the NVIDIA implementation | n +| `-DWITH_METAX=[ON\|OFF]` | Compile the MetaX implementation | n +| `-DGENERATE_PYTHON_BINDINGS=[ON\|OFF]` | Generate Python bindings | n *Note: If no accelerator options are provided, `WITH_CPU` is enabled by default.* From 1d07c8dd0564f40a2df2b168aa00b2d54f7e38f2 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 25 Feb 2026 15:57:32 +0800 Subject: [PATCH 40/93] feat: unify runtime API for CPU backend in `examples/runtime_api.h` --- examples/gemm/gemm.cc | 8 -------- examples/runtime_api.h | 8 ++++++++ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/gemm/gemm.cc b/examples/gemm/gemm.cc index 71891ef..2bdbe6d 100644 --- a/examples/gemm/gemm.cc +++ 
b/examples/gemm/gemm.cc @@ -52,11 +52,6 @@ int main() { void *a_ptr, *b_ptr, *c_ptr; -#ifdef WITH_CPU - a_ptr = a_vec.data(); - b_ptr = b_vec.data(); - c_ptr = c_vec.data(); -#else DEVICE_MALLOC(&a_ptr, a_size); DEVICE_MALLOC(&b_ptr, b_size); DEVICE_MALLOC(&c_ptr, c_size); @@ -64,7 +59,6 @@ int main() { DEVICE_MEMCPY(a_ptr, a_vec.data(), a_size, DEVICE_MEMCPY_HOST_TO_DEVICE); DEVICE_MEMCPY(b_ptr, b_vec.data(), b_size, DEVICE_MEMCPY_HOST_TO_DEVICE); DEVICE_MEMSET(c_ptr, 0, c_size); -#endif Tensor a_device{a_ptr, a_host.shape(), a_host.dtype(), a_host.device(), a_host.strides()}; @@ -75,12 +69,10 @@ int main() { Gemm::call(nullptr, a_device, b_device, c_device); -#ifndef WITH_CPU DEVICE_MEMCPY(c_vec.data(), c_ptr, c_size, DEVICE_MEMCPY_DEVICE_TO_HOST); DEVICE_FREE(a_ptr); DEVICE_FREE(b_ptr); DEVICE_FREE(c_ptr); -#endif std::cout << "A: " << a_host.ToString() << "\n"; std::cout << "B: " << b_host.ToString() << "\n"; diff --git a/examples/runtime_api.h b/examples/runtime_api.h index 021f24c..63ff95e 100644 --- a/examples/runtime_api.h +++ b/examples/runtime_api.h @@ -20,6 +20,14 @@ #define DEVICE_MEMCPY_DEVICE_TO_HOST mcMemcpyDeviceToHost #define DEFAULT_DEVICE_TYPE Device::Type::kMetax #elif WITH_CPU +#include +#include +#define DEVICE_MALLOC(ptr, size) (*(ptr) = std::malloc(size)) +#define DEVICE_FREE std::free +#define DEVICE_MEMCPY(dst, src, size, kind) std::memcpy(dst, src, size) +#define DEVICE_MEMSET std::memset +#define DEVICE_MEMCPY_HOST_TO_DEVICE 0 +#define DEVICE_MEMCPY_DEVICE_TO_HOST 1 #define DEFAULT_DEVICE_TYPE Device::Type::kCpu #endif From 8651326618d99d8818a44fb37e8873b01b209f70 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 25 Feb 2026 16:08:21 +0800 Subject: [PATCH 41/93] test: add an example for Python binding generation --- README.md | 6 ++++++ examples/gemm.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+) create mode 100644 examples/gemm.py diff --git a/README.md b/README.md index e0befd6..875a936 100644 --- a/README.md +++ 
b/README.md @@ -63,3 +63,9 @@ Run the tensor example: ```bash ./examples/tensor ``` + +Run the pybind11 example: + +```bash +PYTHONPATH=src python ../examples/gemm.py +``` diff --git a/examples/gemm.py b/examples/gemm.py new file mode 100644 index 0000000..9c2d1a4 --- /dev/null +++ b/examples/gemm.py @@ -0,0 +1,15 @@ +import ops +import torch + +m, n, k = 2, 3, 4 + +x = torch.randn(m, k, device="cpu") +y = torch.randn(k, n, device="cpu") +z = torch.empty(m, n, device="cpu") + +ops.gemm(x, y, z) + +print(x) +print(y) +print(z) +print(torch.mm(x, y)) From e3718675ef1fcb3e5a22ab9ff05e3bb463af1a6b Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 25 Feb 2026 19:08:25 +0800 Subject: [PATCH 42/93] fix: return `strides_[index]` instead of `shape_[index]` in `stride` --- src/tensor.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensor.cc b/src/tensor.cc index 2c471fb..7d03ac9 100644 --- a/src/tensor.cc +++ b/src/tensor.cc @@ -35,7 +35,7 @@ const Tensor::Strides& Tensor::strides() const { return strides_; } Tensor::Size Tensor::size(const Index& index) const { return shape_[index]; } Tensor::Stride Tensor::stride(const Index& index) const { - return shape_[index]; + return strides_[index]; } Tensor::Size Tensor::ndim() const { return shape_.size(); } From 61cfe665f0d485f135a990f55f79a59ab7a8d606 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 25 Feb 2026 19:57:31 +0800 Subject: [PATCH 43/93] refactor: improve GEMM stride handling --- src/base/gemm.h | 7 +++-- src/cuda/gemm/blas.h | 72 ++++++++++++++++++++++++-------------------- 2 files changed, 44 insertions(+), 35 deletions(-) diff --git a/src/base/gemm.h b/src/base/gemm.h index 1258844..c4965cf 100644 --- a/src/base/gemm.h +++ b/src/base/gemm.h @@ -1,6 +1,7 @@ #ifndef INFINI_OPS_BASE_GEMM_H_ #define INFINI_OPS_BASE_GEMM_H_ +#include #include #include "operator.h" @@ -25,9 +26,9 @@ class Gemm : public Operator { a_strides_{a.strides()}, b_strides_{b.strides()}, 
c_strides_{c.strides()}, - lda_{a_strides_[1]}, - ldb_{b_strides_[1]}, - ldc_{c_strides_[1]} { + lda_{std::max(a_strides_[0], a_strides_[1])}, + ldb_{std::max(b_strides_[0], b_strides_[1])}, + ldc_{std::max(c_strides_[0], c_strides_[1])} { // TODO: Check constraints. } diff --git a/src/cuda/gemm/blas.h b/src/cuda/gemm/blas.h index b15e689..15edb47 100644 --- a/src/cuda/gemm/blas.h +++ b/src/cuda/gemm/blas.h @@ -13,20 +13,16 @@ class Blas : public Gemm { Blas(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c) - : Gemm{a.stride(0) == 1 ? a : b.T(), - a.stride(0) == 1 ? b : a.T(), - alpha, - beta, - trans_a, - trans_b, - a.stride(0) == 1 ? c : c.T()}, - lda_{a_strides_[1]}, - ldb_{b_strides_[1]}, - ldc_{c_strides_[1]} { - Backend::blasCreate(&handle); + : Gemm{a, b, alpha, beta, trans_a, trans_b, c}, + swapped_a_and_b_{c_strides_[1] == 1}, + op_a_{InitOpA()}, + op_b_{InitOpB()} { + Backend::blasCreate(&handle_); // TODO: Check constraints. } + ~Blas() { Backend::blasDestroy(handle_); } + Blas(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, Tensor c) : Blas{a, b, alpha, beta, std::nullopt, std::nullopt, c} {} @@ -38,38 +34,50 @@ class Blas : public Gemm { std::optional alpha, std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c) const override { - Backend::blasSetStream(handle, + Backend::blasSetStream(handle_, static_cast(stream)); const auto& alpha_value{alpha.value_or(alpha_)}; const auto& beta_value{beta.value_or(beta_)}; - const auto& trans_a_value{trans_a.value_or(trans_a_)}; - const auto& trans_b_value{trans_b.value_or(trans_b_)}; - assert(a_type_ == DataType::kFloat32 && b_type_ == DataType::kFloat32 && - c_type_ == DataType::kFloat32 && - "`operator()` not implemented for this data type"); + Backend::blasGemmEx( + handle_, op_a_, op_b_, swapped_a_and_b_ ? n_ : m_, + swapped_a_and_b_ ? m_ : n_, k_, &alpha_value, + swapped_a_and_b_ ? 
b.data() : a.data(), Backend::R_32F, + swapped_a_and_b_ ? ldb_ : lda_, swapped_a_and_b_ ? a.data() : b.data(), + Backend::R_32F, swapped_a_and_b_ ? lda_ : ldb_, &beta_value, c.data(), + Backend::R_32F, ldc_, Backend::BLAS_COMPUTE_32F_FAST_TF32, + Backend::BLAS_GEMM_DEFAULT); + } - auto op_a = static_cast( - trans_a_value ? Backend::BLAS_OP_T : Backend::BLAS_OP_N); - auto op_b = static_cast( - trans_b_value ? Backend::BLAS_OP_T : Backend::BLAS_OP_N); + private: + auto InitOpA() const { + if (swapped_a_and_b_) { + return ((b_strides_[1] == 1) == trans_b_) ? Backend::BLAS_OP_T + : Backend::BLAS_OP_N; + } + + return ((a_strides_[1] == 1) != trans_a_) ? Backend::BLAS_OP_T + : Backend::BLAS_OP_N; + } - Backend::blasGemmEx(handle, op_a, op_b, m_, n_, k_, &alpha_value, b.data(), - Backend::R_32F, lda_, a.data(), Backend::R_32F, ldb_, - &beta_value, c.data(), Backend::R_32F, ldc_, - Backend::BLAS_COMPUTE_32F_FAST_TF32, - Backend::BLAS_GEMM_DEFAULT); + auto InitOpB() const { + if (swapped_a_and_b_) { + return ((a_strides_[1] == 1) == trans_a_) ? Backend::BLAS_OP_T + : Backend::BLAS_OP_N; + } - Backend::blasDestroy(handle); + return ((b_strides_[1] == 1) != trans_b_) ? 
Backend::BLAS_OP_T + : Backend::BLAS_OP_N; } - private: - Tensor::Stride lda_{0}; - Tensor::Stride ldb_{0}; - Tensor::Stride ldc_{0}; + bool swapped_a_and_b_{false}; + + decltype(Backend::BLAS_OP_T) op_a_; + + decltype(Backend::BLAS_OP_T) op_b_; - typename Backend::blasHandle_t handle; + typename Backend::blasHandle_t handle_; }; } // namespace infini::ops From 783373c479a0089ce71f0f7eabc8242380701a3d Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 26 Feb 2026 09:28:53 +0800 Subject: [PATCH 44/93] fix: update `Blas` to use `trans_a` and `trans_b` parameters --- src/cuda/gemm/blas.h | 38 +++++++++++++++++--------------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/src/cuda/gemm/blas.h b/src/cuda/gemm/blas.h index 15edb47..d9085c5 100644 --- a/src/cuda/gemm/blas.h +++ b/src/cuda/gemm/blas.h @@ -14,9 +14,7 @@ class Blas : public Gemm { std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c) : Gemm{a, b, alpha, beta, trans_a, trans_b, c}, - swapped_a_and_b_{c_strides_[1] == 1}, - op_a_{InitOpA()}, - op_b_{InitOpB()} { + swapped_a_and_b_{c_strides_[1] == 1} { Backend::blasCreate(&handle_); // TODO: Check constraints. } @@ -40,8 +38,13 @@ class Blas : public Gemm { const auto& alpha_value{alpha.value_or(alpha_)}; const auto& beta_value{beta.value_or(beta_)}; + const auto& trans_a_value{trans_a.value_or(trans_a_)}; + const auto& trans_b_value{trans_b.value_or(trans_b_)}; + auto op_a{GetOpA(trans_a_value, trans_b_value)}; + auto op_b{GetOpB(trans_a_value, trans_b_value)}; + Backend::blasGemmEx( - handle_, op_a_, op_b_, swapped_a_and_b_ ? n_ : m_, + handle_, op_a, op_b, swapped_a_and_b_ ? n_ : m_, swapped_a_and_b_ ? m_ : n_, k_, &alpha_value, swapped_a_and_b_ ? b.data() : a.data(), Backend::R_32F, swapped_a_and_b_ ? ldb_ : lda_, swapped_a_and_b_ ? 
a.data() : b.data(), @@ -51,32 +54,25 @@ class Blas : public Gemm { } private: - auto InitOpA() const { + auto GetOpA(int trans_a, int trans_b) const { if (swapped_a_and_b_) { - return ((b_strides_[1] == 1) == trans_b_) ? Backend::BLAS_OP_T - : Backend::BLAS_OP_N; + return ((b_strides_[1] == 1) == trans_b) ? Backend::BLAS_OP_T + : Backend::BLAS_OP_N; } - - return ((a_strides_[1] == 1) != trans_a_) ? Backend::BLAS_OP_T - : Backend::BLAS_OP_N; + return ((a_strides_[1] == 1) != trans_a) ? Backend::BLAS_OP_T + : Backend::BLAS_OP_N; } - auto InitOpB() const { + auto GetOpB(int trans_a, int trans_b) const { if (swapped_a_and_b_) { - return ((a_strides_[1] == 1) == trans_a_) ? Backend::BLAS_OP_T - : Backend::BLAS_OP_N; + return ((a_strides_[1] == 1) == trans_a) ? Backend::BLAS_OP_T + : Backend::BLAS_OP_N; } - - return ((b_strides_[1] == 1) != trans_b_) ? Backend::BLAS_OP_T - : Backend::BLAS_OP_N; + return ((b_strides_[1] == 1) != trans_b) ? Backend::BLAS_OP_T + : Backend::BLAS_OP_N; } bool swapped_a_and_b_{false}; - - decltype(Backend::BLAS_OP_T) op_a_; - - decltype(Backend::BLAS_OP_T) op_b_; - typename Backend::blasHandle_t handle_; }; From ec8a99aafa64a26ebdc60665e813ae7e5b81fc26 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 26 Feb 2026 10:01:21 +0800 Subject: [PATCH 45/93] feat: add negative indexing support to `Tensor` --- src/tensor.cc | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/tensor.cc b/src/tensor.cc index 7d03ac9..5203ac8 100644 --- a/src/tensor.cc +++ b/src/tensor.cc @@ -6,6 +6,10 @@ namespace infini::ops { +static Tensor::Index GetEffectiveIndex(Tensor::Index index, Tensor::Size size) { + return index < 0 ? 
index + size : index; +} + Tensor::Tensor(void* data, std::initializer_list shape, const DataType& dtype, const Device& device, std::initializer_list strides) @@ -13,11 +17,12 @@ Tensor::Tensor(void* data, std::initializer_list shape, decltype(strides_){strides}} {} Tensor Tensor::operator[](const Index& index) const { - return {reinterpret_cast( - reinterpret_cast(data_) + - index * strides_[0] * element_size()), - Shape{shape_.cbegin() + 1, shape_.cend()}, dtype_, device_, - Strides{strides_.cbegin() + 1, strides_.cend()}}; + return { + reinterpret_cast( + reinterpret_cast(data_) + + GetEffectiveIndex(index, shape_[0]) * strides_[0] * element_size()), + Shape{shape_.cbegin() + 1, shape_.cend()}, dtype_, device_, + Strides{strides_.cbegin() + 1, strides_.cend()}}; } void*& Tensor::data() { return data_; } @@ -32,10 +37,12 @@ const Device& Tensor::device() const { return device_; } const Tensor::Strides& Tensor::strides() const { return strides_; } -Tensor::Size Tensor::size(const Index& index) const { return shape_[index]; } +Tensor::Size Tensor::size(const Index& index) const { + return shape_[GetEffectiveIndex(index, shape_.size())]; +} Tensor::Stride Tensor::stride(const Index& index) const { - return strides_[index]; + return strides_[GetEffectiveIndex(index, strides_.size())]; } Tensor::Size Tensor::ndim() const { return shape_.size(); } From f936f7182ef9628ac321248399738c9fcefa8f4f Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 26 Feb 2026 10:40:56 +0800 Subject: [PATCH 46/93] feat: add batched GEMM support --- src/base/gemm.h | 24 ++++++++++++++++++------ src/cuda/gemm/blas.h | 38 ++++++++++++++++++++++++-------------- src/metax/gemm/mcblas.h | 4 ++-- src/nvidia/gemm/cublas.h | 4 ++-- 4 files changed, 46 insertions(+), 24 deletions(-) diff --git a/src/base/gemm.h b/src/base/gemm.h index c4965cf..2918312 100644 --- a/src/base/gemm.h +++ b/src/base/gemm.h @@ -17,18 +17,22 @@ class Gemm : public Operator { beta_{beta.value_or(1.0)}, 
trans_a_{static_cast(trans_a.value_or(false))}, trans_b_{static_cast(trans_b.value_or(false))}, - m_{c.size(0)}, - n_{c.size(1)}, - k_{trans_a_ ? a.size(0) : a.size(1)}, + m_{c.size(-2)}, + n_{c.size(-1)}, + k_{trans_a_ ? a.size(-2) : a.size(-1)}, a_type_{a.dtype()}, b_type_{b.dtype()}, c_type_{c.dtype()}, a_strides_{a.strides()}, b_strides_{b.strides()}, c_strides_{c.strides()}, - lda_{std::max(a_strides_[0], a_strides_[1])}, - ldb_{std::max(b_strides_[0], b_strides_[1])}, - ldc_{std::max(c_strides_[0], c_strides_[1])} { + lda_{std::max(a.stride(-2), a.stride(-1))}, + ldb_{std::max(b.stride(-2), b.stride(-1))}, + ldc_{std::max(c.stride(-2), c.stride(-1))}, + batch_count_{c.strides().size() > 2 ? c.size(-3) : 1}, + batch_stride_a_{a.strides().size() > 2 ? a.stride(-3) : 0}, + batch_stride_b_{b.strides().size() > 2 ? b.stride(-3) : 0}, + batch_stride_c_{c.strides().size() > 2 ? c.stride(-3) : 0} { // TODO: Check constraints. } @@ -84,6 +88,14 @@ class Gemm : public Operator { Tensor::Stride ldb_{0}; Tensor::Stride ldc_{0}; + + Tensor::Size batch_count_{1}; + + Tensor::Stride batch_stride_a_{0}; + + Tensor::Stride batch_stride_b_{0}; + + Tensor::Stride batch_stride_c_{0}; }; } // namespace infini::ops diff --git a/src/cuda/gemm/blas.h b/src/cuda/gemm/blas.h index d9085c5..92072d7 100644 --- a/src/cuda/gemm/blas.h +++ b/src/cuda/gemm/blas.h @@ -14,7 +14,9 @@ class Blas : public Gemm { std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c) : Gemm{a, b, alpha, beta, trans_a, trans_b, c}, - swapped_a_and_b_{c_strides_[1] == 1} { + a_is_col_major_{a.stride(-1) == 1}, + b_is_col_major_{b.stride(-1) == 1}, + swapped_a_and_b_{c.stride(-1) == 1} { Backend::blasCreate(&handle_); // TODO: Check constraints. } @@ -43,36 +45,44 @@ class Blas : public Gemm { auto op_a{GetOpA(trans_a_value, trans_b_value)}; auto op_b{GetOpB(trans_a_value, trans_b_value)}; - Backend::blasGemmEx( + Backend::blasGemmStridedBatchedEx( handle_, op_a, op_b, swapped_a_and_b_ ? 
n_ : m_, swapped_a_and_b_ ? m_ : n_, k_, &alpha_value, swapped_a_and_b_ ? b.data() : a.data(), Backend::R_32F, - swapped_a_and_b_ ? ldb_ : lda_, swapped_a_and_b_ ? a.data() : b.data(), - Backend::R_32F, swapped_a_and_b_ ? lda_ : ldb_, &beta_value, c.data(), - Backend::R_32F, ldc_, Backend::BLAS_COMPUTE_32F_FAST_TF32, - Backend::BLAS_GEMM_DEFAULT); + swapped_a_and_b_ ? ldb_ : lda_, + swapped_a_and_b_ ? batch_stride_b_ : batch_stride_a_, + swapped_a_and_b_ ? a.data() : b.data(), Backend::R_32F, + swapped_a_and_b_ ? lda_ : ldb_, + swapped_a_and_b_ ? batch_stride_a_ : batch_stride_b_, &beta_value, + c.data(), Backend::R_32F, ldc_, batch_stride_c_, batch_count_, + Backend::BLAS_COMPUTE_32F_FAST_TF32, Backend::BLAS_GEMM_DEFAULT); } private: auto GetOpA(int trans_a, int trans_b) const { if (swapped_a_and_b_) { - return ((b_strides_[1] == 1) == trans_b) ? Backend::BLAS_OP_T - : Backend::BLAS_OP_N; + return (b_is_col_major_ == trans_b) ? Backend::BLAS_OP_T + : Backend::BLAS_OP_N; } - return ((a_strides_[1] == 1) != trans_a) ? Backend::BLAS_OP_T - : Backend::BLAS_OP_N; + return (a_is_col_major_ != trans_a) ? Backend::BLAS_OP_T + : Backend::BLAS_OP_N; } auto GetOpB(int trans_a, int trans_b) const { if (swapped_a_and_b_) { - return ((a_strides_[1] == 1) == trans_a) ? Backend::BLAS_OP_T - : Backend::BLAS_OP_N; + return (a_is_col_major_ == trans_a) ? Backend::BLAS_OP_T + : Backend::BLAS_OP_N; } - return ((b_strides_[1] == 1) != trans_b) ? Backend::BLAS_OP_T - : Backend::BLAS_OP_N; + return (b_is_col_major_ != trans_b) ? 
Backend::BLAS_OP_T + : Backend::BLAS_OP_N; } + bool a_is_col_major_{false}; + + bool b_is_col_major_{false}; + bool swapped_a_and_b_{false}; + typename Backend::blasHandle_t handle_; }; diff --git a/src/metax/gemm/mcblas.h b/src/metax/gemm/mcblas.h index 659d5fc..10bef3b 100644 --- a/src/metax/gemm/mcblas.h +++ b/src/metax/gemm/mcblas.h @@ -26,8 +26,8 @@ struct MetaxBackend { static constexpr auto blasSetStream = mcblasSetStream; static constexpr auto blasDestroy = mcblasDestroy; - static constexpr auto blasGemmEx = [](auto&&... args) { - return mcblasGemmEx(std::forward(args)...); + static constexpr auto blasGemmStridedBatchedEx = [](auto&&... args) { + return mcblasGemmStridedBatchedEx(std::forward(args)...); }; }; diff --git a/src/nvidia/gemm/cublas.h b/src/nvidia/gemm/cublas.h index 03a585c..d4e4b78 100644 --- a/src/nvidia/gemm/cublas.h +++ b/src/nvidia/gemm/cublas.h @@ -26,8 +26,8 @@ struct NvidiaBackend { static constexpr auto blasSetStream = cublasSetStream; static constexpr auto blasDestroy = cublasDestroy; - static constexpr auto blasGemmEx = [](auto&&... args) { - return cublasGemmEx(std::forward(args)...); + static constexpr auto blasGemmStridedBatchedEx = [](auto&&... 
args) { + return cublasGemmStridedBatchedEx(std::forward(args)...); }; }; From d73a0b7d9533f5f97215891ade5027155bf14178 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 26 Feb 2026 10:41:49 +0800 Subject: [PATCH 47/93] fix: rename `swapped_a_and_b_` to `swap_a_and_b_` --- src/cuda/gemm/blas.h | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/cuda/gemm/blas.h b/src/cuda/gemm/blas.h index 92072d7..47030cf 100644 --- a/src/cuda/gemm/blas.h +++ b/src/cuda/gemm/blas.h @@ -16,7 +16,7 @@ class Blas : public Gemm { : Gemm{a, b, alpha, beta, trans_a, trans_b, c}, a_is_col_major_{a.stride(-1) == 1}, b_is_col_major_{b.stride(-1) == 1}, - swapped_a_and_b_{c.stride(-1) == 1} { + swap_a_and_b_{c.stride(-1) == 1} { Backend::blasCreate(&handle_); // TODO: Check constraints. } @@ -46,21 +46,20 @@ class Blas : public Gemm { auto op_b{GetOpB(trans_a_value, trans_b_value)}; Backend::blasGemmStridedBatchedEx( - handle_, op_a, op_b, swapped_a_and_b_ ? n_ : m_, - swapped_a_and_b_ ? m_ : n_, k_, &alpha_value, - swapped_a_and_b_ ? b.data() : a.data(), Backend::R_32F, - swapped_a_and_b_ ? ldb_ : lda_, - swapped_a_and_b_ ? batch_stride_b_ : batch_stride_a_, - swapped_a_and_b_ ? a.data() : b.data(), Backend::R_32F, - swapped_a_and_b_ ? lda_ : ldb_, - swapped_a_and_b_ ? batch_stride_a_ : batch_stride_b_, &beta_value, + handle_, op_a, op_b, swap_a_and_b_ ? n_ : m_, swap_a_and_b_ ? m_ : n_, + k_, &alpha_value, swap_a_and_b_ ? b.data() : a.data(), Backend::R_32F, + swap_a_and_b_ ? ldb_ : lda_, + swap_a_and_b_ ? batch_stride_b_ : batch_stride_a_, + swap_a_and_b_ ? a.data() : b.data(), Backend::R_32F, + swap_a_and_b_ ? lda_ : ldb_, + swap_a_and_b_ ? 
batch_stride_a_ : batch_stride_b_, &beta_value, c.data(), Backend::R_32F, ldc_, batch_stride_c_, batch_count_, Backend::BLAS_COMPUTE_32F_FAST_TF32, Backend::BLAS_GEMM_DEFAULT); } private: auto GetOpA(int trans_a, int trans_b) const { - if (swapped_a_and_b_) { + if (swap_a_and_b_) { return (b_is_col_major_ == trans_b) ? Backend::BLAS_OP_T : Backend::BLAS_OP_N; } @@ -69,7 +68,7 @@ class Blas : public Gemm { } auto GetOpB(int trans_a, int trans_b) const { - if (swapped_a_and_b_) { + if (swap_a_and_b_) { return (a_is_col_major_ == trans_a) ? Backend::BLAS_OP_T : Backend::BLAS_OP_N; } @@ -81,7 +80,7 @@ class Blas : public Gemm { bool b_is_col_major_{false}; - bool swapped_a_and_b_{false}; + bool swap_a_and_b_{false}; typename Backend::blasHandle_t handle_; }; From ef07165ba5d2fb45a2d9c1ad4331705a916f77bb Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 26 Feb 2026 14:56:41 +0800 Subject: [PATCH 48/93] test: add basic testing infrastructure --- .gitignore | 217 ++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 4 + tests/conftest.py | 41 +++++++++ tests/utils.py | 26 ++++++ 4 files changed, 288 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/utils.py diff --git a/.gitignore b/.gitignore index 99bca7d..2effaff 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,220 @@ generated/ # debug information files *.dwo + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +# Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +# poetry.lock +# poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +# pdm.lock +# pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +# pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# Redis +*.rdb +*.aof +*.pid + +# RabbitMQ +mnesia/ +rabbitmq/ +rabbitmq-data/ + +# ActiveMQ +activemq-data/ + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +# .idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. 
+# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ + +# Streamlit +.streamlit/secrets.toml diff --git a/requirements.txt b/requirements.txt index add49ef..811e92f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,5 @@ libclang +pytest +pytest-cov +ruff +torch diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..83b575a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,41 @@ +import hashlib +import random + +import pytest +import torch + + +def pytest_configure(): + torch.backends.fp32_precision = "tf32" + + +def pytest_collectstart(collector): + if isinstance(collector, pytest.Module): + _set_random_seed(_hash(collector.name)) + + +@pytest.fixture(scope="module", autouse=True) +def set_seed_per_module(request): + _set_random_seed(_hash(_module_path_from_request(request))) + + +@pytest.fixture(autouse=True) +def set_seed_per_test(request): + _set_random_seed(_hash(_test_case_path_from_request(request))) + + +def _set_random_seed(seed): + random.seed(seed) + torch.manual_seed(seed) + + +def _test_case_path_from_request(request): + return f"{_module_path_from_request(request)}::{request.node.name}" + + +def _module_path_from_request(request): + return f"{request.module.__name__.replace('.', '/')}.py" + + +def _hash(string): + return int(hashlib.sha256(string.encode("utf-8")).hexdigest(), 16) % 2**32 diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..3a6ca17 --- /dev/null +++ 
b/tests/utils.py @@ -0,0 +1,26 @@ +import contextlib + +import torch + + +def get_available_devices(): + devices = [] + + if torch.cuda.is_available(): + devices.append("cuda") + + if hasattr(torch, "mlu") and torch.mlu.is_available(): + devices.append("mlu") + + return tuple(devices) + + +with contextlib.suppress(ImportError, ModuleNotFoundError): + import torch_mlu # noqa: F401 + + +def empty_strided(shape, strides, *, dtype=None, device=None): + if strides is None: + return torch.empty(shape, dtype=dtype, device=device) + + return torch.empty_strided(shape, strides, dtype=dtype, device=device) From f42ca61ae3a9e284b97f6e6186ed580c576b8e9e Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 26 Feb 2026 14:58:02 +0800 Subject: [PATCH 49/93] test: add test cases for `ops.gemm` --- tests/test_gemm.py | 75 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 tests/test_gemm.py diff --git a/tests/test_gemm.py b/tests/test_gemm.py new file mode 100644 index 0000000..41fabbf --- /dev/null +++ b/tests/test_gemm.py @@ -0,0 +1,75 @@ +import ops +import pytest +import torch + +from tests.utils import empty_strided, get_available_devices + + +@pytest.mark.parametrize("device", get_available_devices()) +# TODO: Add support for more data types. 
+@pytest.mark.parametrize("dtype, rtol, atol", ((torch.float32, 1e-3, 1e-3),)) +@pytest.mark.parametrize("trans_b", (False, True)) +@pytest.mark.parametrize("trans_a", (False, True)) +@pytest.mark.parametrize("beta", (-1, -0.5, 0, 0.5, 1)) +@pytest.mark.parametrize("alpha", (-1, -0.5, 0, 0.5, 1)) +@pytest.mark.parametrize( + "a_shape, b_shape, c_shape, a_strides, b_strides, c_strides", + ( + ((1, 2048), (2048, 2048), (1, 2048), None, None, None), + ((2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None), + ((1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)), + ((6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)), + ((4, 48, 64), (4, 64, 6), (4, 48, 6), None, None, None), + ), +) +def test_gemm( + a_shape, + b_shape, + c_shape, + a_strides, + b_strides, + c_strides, + alpha, + beta, + trans_a, + trans_b, + dtype, + device, + rtol, + atol, +): + a = empty_strided(a_shape, a_strides, dtype=dtype, device=device) + b = empty_strided(b_shape, b_strides, dtype=dtype, device=device) + + if trans_a: + a = a.transpose(-2, -1) + + if trans_b: + b = b.transpose(-2, -1) + + output = empty_strided(c_shape, c_strides, dtype=dtype, device=device) + expected = output.clone() + + a.normal_() + b.normal_() + + # TODO: Add keyword argument support. 
+ ops.gemm(a, b, alpha, beta, trans_a, trans_b, output) + _torch_gemm( + a, b, alpha=alpha, beta=beta, trans_a=trans_a, trans_b=trans_b, c=expected + ) + + assert torch.allclose(output, expected, rtol=rtol, atol=atol) + + +def _torch_gemm(a, b, *, alpha=1.0, beta=1.0, trans_a=False, trans_b=False, c=None): + if trans_a: + a = a.transpose(-2, -1) + + if trans_b: + b = b.transpose(-2, -1) + + if a.ndim == 2: + return torch.addmm(c, a, b, beta=beta, alpha=alpha, out=c) + + return torch.baddbmm(c, a, b, beta=beta, alpha=alpha, out=c) From 5367b7a49909ea6636c5487a49d4dd1cd4c14f8d Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 26 Feb 2026 20:49:14 +0800 Subject: [PATCH 50/93] test: add `tests/__init__.py` --- tests/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/__init__.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 From 761912e748a00c1d60192da0c4070842f64c2796 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 26 Feb 2026 20:56:42 +0800 Subject: [PATCH 51/93] build: configure Python packaging --- pyproject.toml | 13 +++++++++++++ src/CMakeLists.txt | 8 ++++++++ tests/test_gemm.py | 4 ++-- 3 files changed, 23 insertions(+), 2 deletions(-) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d3a2811 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,13 @@ +[build-system] +requires = ["scikit-build-core", "pybind11", "libclang"] +build-backend = "scikit_build_core.build" + +[project] +name = "InfiniOps" +version = "0.1.0" + +[tool.scikit-build.wheel] +install-dir = "infini" + +[tool.scikit-build.cmake.define] +GENERATE_PYTHON_BINDINGS = "ON" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 620c506..16059c7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -93,4 +93,12 @@ if(GENERATE_PYTHON_BINDINGS) target_include_directories(ops PRIVATE ${PROJECT_SOURCE_DIR}) target_link_libraries(ops 
PRIVATE infiniops) + + set_target_properties(infiniops PROPERTIES INSTALL_RPATH "$ORIGIN") + set_target_properties(ops PROPERTIES INSTALL_RPATH "$ORIGIN") + + install(TARGETS infiniops ops DESTINATION .) + + file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/__init__.py" "") + install(FILES "${CMAKE_CURRENT_BINARY_DIR}/__init__.py" DESTINATION .) endif() diff --git a/tests/test_gemm.py b/tests/test_gemm.py index 41fabbf..a169dcb 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -1,4 +1,4 @@ -import ops +import infini.ops import pytest import torch @@ -54,7 +54,7 @@ def test_gemm( b.normal_() # TODO: Add keyword argument support. - ops.gemm(a, b, alpha, beta, trans_a, trans_b, output) + infini.ops.gemm(a, b, alpha, beta, trans_a, trans_b, output) _torch_gemm( a, b, alpha=alpha, beta=beta, trans_a=trans_a, trans_b=trans_b, c=expected ) From 03c372b8e44ffd3e28a65b7035ac778a90e55775 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 26 Feb 2026 21:08:55 +0800 Subject: [PATCH 52/93] build: move dependencies to `pyproject.toml` --- pyproject.toml | 3 +++ requirements.txt | 5 ----- 2 files changed, 3 insertions(+), 5 deletions(-) delete mode 100644 requirements.txt diff --git a/pyproject.toml b/pyproject.toml index d3a2811..3f50a62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,9 @@ build-backend = "scikit_build_core.build" name = "InfiniOps" version = "0.1.0" +[project.optional-dependencies] +dev = ["pytest", "pytest-cov", "ruff", "torch"] + [tool.scikit-build.wheel] install-dir = "infini" diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 811e92f..0000000 --- a/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -libclang -pytest -pytest-cov -ruff -torch From 1af292c98bdcdeb34f083ee86032b3550e9b6681 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 26 Feb 2026 22:28:30 +0800 Subject: [PATCH 53/93] build: add support for automatically detecting available devices --- CMakeLists.txt | 21 +++++++++++++++++++++ pyproject.toml 
| 1 + 2 files changed, 22 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index cc671b9..a35cbd8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,8 +9,29 @@ option(WITH_CPU "Enable CPU backend" OFF) option(WITH_NVIDIA "Enable CUDA backend" OFF) option(WITH_METAX "Enable MetaX backend" OFF) +option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF) option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF) +if(AUTO_DETECT_DEVICES) + message(STATUS "Auto-detecting available devices...") + + set(WITH_CPU ON) + + include(CheckLanguage) + check_language(CUDA) + + if(CMAKE_CUDA_COMPILER) + set(WITH_NVIDIA ON) + message(STATUS "Auto-detected NVIDIA environment.") + endif() + + # TODO: Please test and uncomment/update the auto-detection for MetaX. + # if(DEFINED ENV{MACA_PATH}) + # set(WITH_METAX ON) + # message(STATUS "Auto-detected MetaX environment.") + # endif() +endif() + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) if(WITH_NVIDIA) diff --git a/pyproject.toml b/pyproject.toml index 3f50a62..96a3d61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,4 +13,5 @@ dev = ["pytest", "pytest-cov", "ruff", "torch"] install-dir = "infini" [tool.scikit-build.cmake.define] +AUTO_DETECT_DEVICES = "ON" GENERATE_PYTHON_BINDINGS = "ON" From c96ad60666fc45d0ce265fdb51b264640689338f Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 27 Feb 2026 03:18:36 +0000 Subject: [PATCH 54/93] feat: auto-detect system include paths in `scripts/generate_wrappers.py` --- scripts/generate_wrappers.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 5755df1..c0a7a31 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -1,6 +1,7 @@ import argparse import json import pathlib +import subprocess import textwrap import clang.cindex @@ -23,8 +24,24 @@ class _OperatorExtractor: def __call__(self, op_name): + def 
_get_system_include_flags(): + system_include_flags = [] + + for line in subprocess.getoutput( + "clang++ -E -x c++ -v /dev/null" + ).splitlines(): + if not line.startswith(" "): + continue + + system_include_flags.append("-isystem") + system_include_flags.append(line.strip()) + + return system_include_flags + + system_include_flags = _get_system_include_flags() + index = clang.cindex.Index.create() - args = ("-std=c++17", "-x", "c++", "-I", "src") + args = ("-std=c++17", "-x", "c++", "-I", "src") + tuple(system_include_flags) translation_unit = index.parse(f"src/base/{op_name.lower()}.h", args=args) nodes = tuple(type(self)._find(translation_unit.cursor, op_name)) From 60b47f156279d74025d58936940656a6f7ea5a91 Mon Sep 17 00:00:00 2001 From: Ziminli <70735843+Ziminli@users.noreply.github.com> Date: Fri, 27 Feb 2026 19:29:59 +0800 Subject: [PATCH 55/93] feat: add the implementation of `Add` operator on CPU, NVIDIA, and MetaX (#4) * feat: add the `Add` operator with related common utils - Add the CPU and CUDA (NVIDIA, MetaX) implementation of `Add` - Add common CUDA-compatiable utils under `common/cuda/` - Add common host utils in `common/generic_utils.h` - Modify `src/CMakeLists.txt` to correctly interpret device syntax in related files and enable OpenMPI requirement - Add some `Tensor` methods * feat: support fp16 and bf16 on NVIDIA and MetaX and partially support it on CPU * refactor: move pybind11 helpers to `utils.h` * fix: wrap backends in operator-specific namespaces * fix: wrap `cudaMalloc` in a lambda to fix function resolution * build: set `LANGUAGE CUDA` for pybind11 sources when `WITH_NVIDIA` is enabled * refactor: change the variable names for the `Add` operator * test: add test cases for `infini.ops.add` * test: add `randn_strided` helper and use in `test_add.py` and `test_gemm.py` * fix: update variable name change to the CPU implementation of `Add` * fix: format `Add` and `Gemm` kernel files --------- Co-authored-by: Jiacheng Huang --- 
scripts/generate_wrappers.py | 92 ++++++++++++--------- src/CMakeLists.txt | 15 +++- src/base/add.h | 71 ++++++++++++++++ src/common/cuda/kernel_commons.h | 25 ++++++ src/common/generic_utils.h | 26 ++++++ src/cpu/add/add.h | 55 +++++++++++++ src/cuda/add/kernel.h | 136 +++++++++++++++++++++++++++++++ src/data_type.h | 30 ++++++- src/metax/add/kernel.h | 38 +++++++++ src/metax/gemm/mcblas.h | 15 +++- src/nvidia/add/kernel.h | 41 ++++++++++ src/nvidia/gemm/cublas.h | 15 +++- src/tensor.cc | 48 ++++++++++- src/tensor.h | 12 ++- tests/test_add.py | 45 ++++++++++ tests/test_gemm.py | 9 +- tests/utils.py | 10 +++ 17 files changed, 625 insertions(+), 58 deletions(-) create mode 100644 src/base/add.h create mode 100644 src/common/cuda/kernel_commons.h create mode 100644 src/common/generic_utils.h create mode 100644 src/cpu/add/add.h create mode 100644 src/cuda/add/kernel.h create mode 100644 src/metax/add/kernel.h create mode 100644 src/nvidia/add/kernel.h create mode 100644 tests/test_add.py diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index c0a7a31..2a18752 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -21,6 +21,56 @@ _INDENTATION = " " +_UTILS_H_CONTENT = """#ifndef INFINI_OPS_BINDINGS_UTILS_H_ +#define INFINI_OPS_BINDINGS_UTILS_H_ + +#include +#include + +namespace infini::ops { + +inline DataType DataTypeFromString(const std::string& name) { + return kStringToDataType.at(name); +} + +inline Device::Type DeviceTypeFromString(const std::string& name) { + static const std::unordered_map kTorchNameToTypes{ + {"cpu", Device::Type::kCpu}, +#ifdef WITH_NVIDIA + {"cuda", Device::Type::kNvidia}, +#endif +#ifdef WITH_METAX + {"cuda", Device::Type::kMetax}, +#endif +#ifdef WITH_ILUVATAR + {"cuda", Device::Type::kIluvatar}, +#endif +#ifdef WITH_KUNLUN + {"cuda", Device::Type::kKunlun}, +#endif +#ifdef WITH_HYGON + {"cuda", Device::Type::kHygon}, +#endif +#ifdef WITH_QY + {"cuda", Device::Type::kQy}, +#endif + {"mlu", 
Device::Type::kCambricon}, {"npu", Device::Type::kAscend}, + {"musa", Device::Type::kMoore}}; + + auto it{kTorchNameToTypes.find(name)}; + + if (it != kTorchNameToTypes.cend()) { + return it->second; + } + + return Device::TypeFromString(name); +} + +} // namespace infini::ops + +#endif +""" + class _OperatorExtractor: def __call__(self, op_name): @@ -132,51 +182,13 @@ def _generate_call(op_name, call, method=True): #include #include -#include - #include "base/{op_name.lower()}.h" +#include "utils.h" namespace py = pybind11; namespace infini::ops {{ -inline DataType DataTypeFromString(const std::string& name) {{ - return kStringToDataType.at(name); -}} - -inline Device::Type DeviceTypeFromString(const std::string& name) {{ - static const std::unordered_map kTorchNameToTypes{{ - {{"cpu", Device::Type::kCpu}}, -#ifdef WITH_NVIDIA - {{"cuda", Device::Type::kNvidia}}, -#endif -#ifdef WITH_METAX - {{"cuda", Device::Type::kMetax}}, -#endif -#ifdef WITH_ILUVATAR - {{"cuda", Device::Type::kIluvatar}}, -#endif -#ifdef WITH_KUNLUN - {{"cuda", Device::Type::kKunlun}}, -#endif -#ifdef WITH_HYGON - {{"cuda", Device::Type::kHygon}}, -#endif -#ifdef WITH_QY - {{"cuda", Device::Type::kQy}}, -#endif - {{"mlu", Device::Type::kCambricon}}, {{"npu", Device::Type::kAscend}}, - {{"musa", Device::Type::kMoore}}}}; - - auto it{{kTorchNameToTypes.find(name)}}; - - if (it != kTorchNameToTypes.cend()) {{ - return it->second; - }} - - return Device::TypeFromString(name); -}} - void Bind{op_name}(py::module& m) {{ using Self = {op_name}; @@ -415,6 +427,8 @@ def _get_all_ops(devices): header_paths = [] bind_func_names = [] + (_BINDINGS_DIR / "utils.h").write_text(_UTILS_H_CONTENT) + for op_name, impl_paths in ops.items(): extractor = _OperatorExtractor() operator = extractor(op_name) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 16059c7..2f1f5cc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -16,9 +16,8 @@ if(WITH_CPU) target_compile_definitions(infiniops PUBLIC 
WITH_CPU=1) - # Reserve for OpenMP. - # find_package(OpenMP REQUIRED) - # target_link_libraries(infiniops PRIVATE OpenMP::OpenMP_CXX) + find_package(OpenMP REQUIRED) + target_link_libraries(infiniops PRIVATE OpenMP::OpenMP_CXX) list(APPEND DEVICE_LIST "cpu") endif() @@ -59,6 +58,7 @@ if(WITH_METAX) set_source_files_properties(${METAX_SOURCES} PROPERTIES LANGUAGE CXX) target_compile_definitions(infiniops PRIVATE WITH_METAX=1) + target_compile_options(infiniops PUBLIC "-x" "maca") target_sources(infiniops PRIVATE ${METAX_SOURCES}) target_include_directories(infiniops PUBLIC "${MACA_PATH}/include") @@ -86,10 +86,17 @@ if(GENERATE_PYTHON_BINDINGS) message(STATUS "Generating wrappers - done") endif() + set(PYBIND11_SOURCES "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc") + + # TODO: There might be a better solution. + if(WITH_NVIDIA) + set_source_files_properties(${PYBIND11_SOURCES} PROPERTIES LANGUAGE CUDA) + endif() + find_package(Python COMPONENTS Interpreter Development) find_package(pybind11 CONFIG) - pybind11_add_module(ops "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc") + pybind11_add_module(ops ${PYBIND11_SOURCES}) target_include_directories(ops PRIVATE ${PROJECT_SOURCE_DIR}) target_link_libraries(ops PRIVATE infiniops) diff --git a/src/base/add.h b/src/base/add.h new file mode 100644 index 0000000..a2da9ef --- /dev/null +++ b/src/base/add.h @@ -0,0 +1,71 @@ +#ifndef INFINI_OPS_BASE_ADD_H_ +#define INFINI_OPS_BASE_ADD_H_ + +#include + +#include "operator.h" + +namespace infini::ops { + +class Add : public Operator { + public: + Add(const Tensor input, const Tensor other, Tensor out) + : ndim_{out.ndim()}, + output_size_{out.numel()}, + input_type_{input.dtype()}, + other_type_{other.dtype()}, + out_type_{out.dtype()}, + input_shape_{input.shape()}, + other_shape_{other.shape()}, + out_shape_{out.shape()}, + input_strides_{input.strides()}, + other_strides_{other.strides()}, + out_strides_{out.strides()}, + is_input_contiguous_{input.IsContiguous()}, + 
is_other_contiguous_{other.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + assert(!out.HasBroadcastDim() && + "The output of `Add` should NOT have broadcasted dim!"); + // TODO(lzm): support mix-precision later using the generic elementwise + // framework. + assert(input_type_ == other_type_ && other_type_ == out_type_ && + "Operator `Add` requires all input and output Tensors to have the " + "same dtype"); + } + + virtual void operator()(void* stream, const Tensor input, const Tensor other, + Tensor out) const = 0; + + protected: + Tensor::Size ndim_{0}; + + Tensor::Size output_size_{0}; + + const DataType input_type_; + + const DataType other_type_; + + const DataType out_type_; + + Tensor::Shape input_shape_; + + Tensor::Shape other_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides other_strides_; + + Tensor::Strides out_strides_; + + bool is_input_contiguous_{false}; + + bool is_other_contiguous_{false}; + + bool is_out_contiguous_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/common/cuda/kernel_commons.h b/src/common/cuda/kernel_commons.h new file mode 100644 index 0000000..4d92e00 --- /dev/null +++ b/src/common/cuda/kernel_commons.h @@ -0,0 +1,25 @@ +#ifndef INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_ +#define INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_ + +#ifdef WITH_NVIDIA +#include +#elif WITH_METAX +#include +#endif + +namespace infini::ops { + +__forceinline__ __device__ __host__ size_t +indexToOffset(size_t flat_index, size_t ndim, const size_t *shape, + const ptrdiff_t *strides) { + size_t res = 0; + for (size_t i = ndim; i-- > 0;) { + res += (flat_index % shape[i]) * strides[i]; + flat_index /= shape[i]; + } + return res; +} + +} // namespace infini::ops + +#endif diff --git a/src/common/generic_utils.h b/src/common/generic_utils.h new file mode 100644 index 0000000..6c82f49 --- /dev/null +++ b/src/common/generic_utils.h @@ -0,0 +1,26 @@ +#ifndef 
INFINI_OPS_COMMON_GENERIC_UTILS_H_ +#define INFINI_OPS_COMMON_GENERIC_UTILS_H_ + +#include + +namespace infini::ops::utils { + +std::size_t indexToOffset(std::size_t flat_index, std::size_t ndim, + const std::size_t* shape, + const std::ptrdiff_t* strides) { + std::size_t res = 0; + for (std::size_t i = ndim; i-- > 0;) { + res += (flat_index % shape[i]) * strides[i]; + flat_index /= shape[i]; + } + return res; +} + +template +constexpr auto CeilDiv(const Tx& x, const Ty& y) { + return (x + y - 1) / y; +} + +} // namespace infini::ops::utils + +#endif diff --git a/src/cpu/add/add.h b/src/cpu/add/add.h new file mode 100644 index 0000000..d9a456a --- /dev/null +++ b/src/cpu/add/add.h @@ -0,0 +1,55 @@ +#ifndef INFINI_OPS_CPU_ADD_ADD_H_ +#define INFINI_OPS_CPU_ADD_ADD_H_ + +#include + +#include "base/add.h" +#include "common/generic_utils.h" + +namespace infini::ops { + +template <> +class Operator : public Add { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Add{input, other, out} { + // TODO: Check constraints. + } + + void operator()(void* stream, const Tensor input, const Tensor other, + Tensor out) const override { + DispatchFunc>( + out_type_, [&]() { compute(stream, input, other, out); }, + "Operator::operator()"); + } + + private: + template + void compute(void* stream, const Tensor input, const Tensor other, + Tensor out) const { + const auto* input_ptr = static_cast(input.data()); + const auto* other_ptr = static_cast(other.data()); + auto* out_ptr = static_cast(out.data()); + + auto get_idx = [&](Tensor::Size i, bool is_contig, const auto* shape, + const auto* strides) { + return is_contig ? 
i : utils::indexToOffset(i, ndim_, shape, strides); + }; + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size_; ++i) { + auto input_idx = get_idx(i, is_input_contiguous_, input_shape_.data(), + input_strides_.data()); + auto other_idx = get_idx(i, is_other_contiguous_, other_shape_.data(), + other_strides_.data()); + auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), + out_strides_.data()); + + out_ptr[out_idx] = input_ptr[input_idx] + other_ptr[other_idx]; + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/add/kernel.h b/src/cuda/add/kernel.h new file mode 100644 index 0000000..664bcee --- /dev/null +++ b/src/cuda/add/kernel.h @@ -0,0 +1,136 @@ +#ifndef INFINI_OPS_CUDA_ADD_KERNEL_H_ +#define INFINI_OPS_CUDA_ADD_KERNEL_H_ + +#include + +#include "base/add.h" +#include "common/cuda/kernel_commons.h" +#include "common/generic_utils.h" + +namespace infini::ops { + +typedef struct AddOp { + public: + static constexpr std::size_t num_inputs = 2; + template + __device__ __forceinline__ T operator()(const T& input, + const T& other) const { + if constexpr (std::is_same_v) { + return __hadd2(input, other); + } else if constexpr (std::is_same_v || + std::is_same_v>) { + return __hadd(input, other); + } else if constexpr (std::is_same_v) { + return __fadd_rn(input, other); + } else { + return input + other; + } + } +} AddOp; + +template +__global__ void AddKernel( + T* out, const T* input, const T* other, const Tensor::Size* out_shape, + const Tensor::Size* input_shape, const Tensor::Size* other_shape, + const Tensor::Stride* out_strides, const Tensor::Stride* input_strides, + const Tensor::Stride* other_strides, size_t output_size, size_t ndim, + size_t offset, bool out_contiguous, bool input_contiguous, + bool other_contiguous) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + + if (idx < output_size) { + Tensor::Size out_idx = + out_contiguous ? 
idx : indexToOffset(idx, ndim, out_shape, out_strides); + Tensor::Size input_idx = + input_contiguous ? idx + : indexToOffset(idx, ndim, input_shape, input_strides); + Tensor::Size other_idx = + other_contiguous ? idx + : indexToOffset(idx, ndim, other_shape, other_strides); + + out[out_idx] = AddOp{}(input[input_idx], other[other_idx]); + } +} + +template +class CudaAdd : public Add { + public: + CudaAdd(const Tensor input, const Tensor other, Tensor out) + : Add{input, other, out} { + size_t shape_size = ndim_ * sizeof(*d_input_shape_); + size_t strides_size = ndim_ * sizeof(*d_input_strides_); + + Backend::malloc((void**)&d_input_shape_, shape_size); + Backend::malloc((void**)&d_other_shape_, shape_size); + Backend::malloc((void**)&d_out_shape_, shape_size); + Backend::malloc((void**)&d_input_strides_, strides_size); + Backend::malloc((void**)&d_other_strides_, strides_size); + Backend::malloc((void**)&d_out_strides_, strides_size); + + Backend::memcpy(d_input_shape_, input_shape_.data(), shape_size, + Backend::memcpyH2D); + Backend::memcpy(d_other_shape_, other_shape_.data(), shape_size, + Backend::memcpyH2D); + Backend::memcpy(d_out_shape_, out_shape_.data(), shape_size, + Backend::memcpyH2D); + Backend::memcpy(d_input_strides_, input_strides_.data(), strides_size, + Backend::memcpyH2D); + Backend::memcpy(d_other_strides_, other_strides_.data(), strides_size, + Backend::memcpyH2D); + Backend::memcpy(d_out_strides_, out_strides_.data(), strides_size, + Backend::memcpyH2D); + } + + ~CudaAdd() { + Backend::free(d_input_shape_); + Backend::free(d_other_shape_); + Backend::free(d_out_shape_); + Backend::free(d_input_strides_); + Backend::free(d_other_strides_); + Backend::free(d_out_strides_); + } + + void operator()(void* stream, const Tensor input, const Tensor other, + Tensor out) const override { + DispatchFunc( + out_type_, + [&]() { + // TODO(lzm): currently hard-code block_size to be 256. 
+ dim3 blockDims( + std::min(static_cast(256), output_size_)); + dim3 gridDims(utils::CeilDiv(output_size_, blockDims.x)); + size_t step = gridDims.x * blockDims.x; + + T* d_out = reinterpret_cast(out.data()); + const T* d_input = reinterpret_cast(input.data()); + const T* d_other = reinterpret_cast(other.data()); + + for (size_t i = 0; i < output_size_; i += step) { + AddKernel<<(stream)>>>( + d_out, d_input, d_other, d_out_shape_, d_input_shape_, + d_other_shape_, d_out_strides_, d_input_strides_, + d_other_strides_, output_size_, ndim_, i, is_out_contiguous_, + is_input_contiguous_, is_other_contiguous_); + } + }, + "CudaAdd::operator()"); + } + + private: + Tensor::Size* d_input_shape_{nullptr}; + + Tensor::Size* d_other_shape_{nullptr}; + + Tensor::Size* d_out_shape_{nullptr}; + + Tensor::Stride* d_input_strides_{nullptr}; + + Tensor::Stride* d_other_strides_{nullptr}; + + Tensor::Stride* d_out_strides_{nullptr}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/data_type.h b/src/data_type.h index b9e2fea..567a076 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -4,6 +4,14 @@ #include #include +#ifdef WITH_NVIDIA +#include +#include +#elif WITH_METAX +#include +#include +#endif + #include "common/constexpr_map.h" #include "common/traits.h" @@ -102,7 +110,27 @@ DEFINE_DATA_TYPE_MAPPING(kUInt64, uint64_t) DEFINE_DATA_TYPE_MAPPING(kInt64, int64_t) DEFINE_DATA_TYPE_MAPPING(kFloat32, float) DEFINE_DATA_TYPE_MAPPING(kFloat64, double) -// TODO(lzm): Support fp16 and bf16. + +#ifdef WITH_NVIDIA +DEFINE_DATA_TYPE_MAPPING(kFloat16, half) +DEFINE_DATA_TYPE_MAPPING(kBFloat16, __nv_bfloat16) +#elif WITH_METAX +DEFINE_DATA_TYPE_MAPPING(kFloat16, __half) +DEFINE_DATA_TYPE_MAPPING(kBFloat16, __maca_bfloat16) +#else +// TODO(lzm): currently there's an ambiguity of uint16_t mapping to both kUInt16 +// and kFloat16/kBFloat16 for CPU. When CPU custom bfloat16/float16 types are +// defined, this should be replaced. 
+template <> +struct TypeMap { + using type = uint16_t; +}; +template <> +struct TypeMap { + using type = uint16_t; +}; +#endif +#undef DEFINE_DATA_TYPE_MAPPING // Defines the common categories of data types using List. using FloatTypes = List; diff --git a/src/metax/add/kernel.h b/src/metax/add/kernel.h new file mode 100644 index 0000000..ce9ec01 --- /dev/null +++ b/src/metax/add/kernel.h @@ -0,0 +1,38 @@ +#ifndef INFINI_OPS_METAX_ADD_KERNEL_H_ +#define INFINI_OPS_METAX_ADD_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/add/kernel.h" + +namespace infini::ops { + +namespace add { + +struct MetaxBackend { + using stream_t = mcStream_t; + + static constexpr auto malloc = mcMalloc; + + static constexpr auto memcpy = mcMemcpy; + + static constexpr auto free = mcFree; + + static constexpr auto memcpyH2D = mcMemcpyHostToDevice; +}; + +} // namespace add + +template <> +class Operator : public CudaAdd { + public: + using CudaAdd::CudaAdd; +}; + +} // namespace infini::ops + +#endif diff --git a/src/metax/gemm/mcblas.h b/src/metax/gemm/mcblas.h index 10bef3b..1fd6f2f 100644 --- a/src/metax/gemm/mcblas.h +++ b/src/metax/gemm/mcblas.h @@ -11,19 +11,28 @@ namespace infini::ops { +namespace gemm { + struct MetaxBackend { using blasHandle_t = mcblasHandle_t; + using stream_t = mcStream_t; static constexpr auto BLAS_OP_N = MCBLAS_OP_N; + static constexpr auto BLAS_OP_T = MCBLAS_OP_T; + static constexpr auto R_32F = MACA_R_32F; + static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = MCBLAS_COMPUTE_32F_FAST_TF32; + static constexpr auto BLAS_GEMM_DEFAULT = MCBLAS_GEMM_DEFAULT; static constexpr auto blasCreate = mcblasCreate; + static constexpr auto blasSetStream = mcblasSetStream; + static constexpr auto blasDestroy = mcblasDestroy; static constexpr auto blasGemmStridedBatchedEx = [](auto&&... 
args) { @@ -31,10 +40,12 @@ struct MetaxBackend { }; }; +} // namespace gemm + template <> -class Operator : public Blas { +class Operator : public Blas { public: - using Blas::Blas; + using Blas::Blas; }; } // namespace infini::ops diff --git a/src/nvidia/add/kernel.h b/src/nvidia/add/kernel.h new file mode 100644 index 0000000..7e6c3e5 --- /dev/null +++ b/src/nvidia/add/kernel.h @@ -0,0 +1,41 @@ +#ifndef INFINI_OPS_NVIDIA_ADD_KERNEL_H_ +#define INFINI_OPS_NVIDIA_ADD_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/add/kernel.h" + +namespace infini::ops { + +namespace add { + +struct NvidiaBackend { + using stream_t = cudaStream_t; + + static constexpr auto malloc = [](auto&&... args) { + return cudaMalloc(std::forward(args)...); + }; + + static constexpr auto memcpy = cudaMemcpy; + + static constexpr auto free = cudaFree; + + static constexpr auto memcpyH2D = cudaMemcpyHostToDevice; +}; + +} // namespace add + +template <> +class Operator + : public CudaAdd { + public: + using CudaAdd::CudaAdd; +}; + +} // namespace infini::ops + +#endif diff --git a/src/nvidia/gemm/cublas.h b/src/nvidia/gemm/cublas.h index d4e4b78..16c1b7a 100644 --- a/src/nvidia/gemm/cublas.h +++ b/src/nvidia/gemm/cublas.h @@ -11,19 +11,28 @@ namespace infini::ops { +namespace gemm { + struct NvidiaBackend { using blasHandle_t = cublasHandle_t; + using stream_t = cudaStream_t; static constexpr auto BLAS_OP_N = CUBLAS_OP_N; + static constexpr auto BLAS_OP_T = CUBLAS_OP_T; + static constexpr auto R_32F = CUDA_R_32F; + static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = CUBLAS_COMPUTE_32F_FAST_TF32; + static constexpr auto BLAS_GEMM_DEFAULT = CUBLAS_GEMM_DEFAULT; static constexpr auto blasCreate = cublasCreate; + static constexpr auto blasSetStream = cublasSetStream; + static constexpr auto blasDestroy = cublasDestroy; static constexpr auto blasGemmStridedBatchedEx = [](auto&&... 
args) { @@ -31,10 +40,12 @@ struct NvidiaBackend { }; }; +} // namespace gemm + template <> -class Operator : public Blas { +class Operator : public Blas { public: - using Blas::Blas; + using Blas::Blas; }; } // namespace infini::ops diff --git a/src/tensor.cc b/src/tensor.cc index 5203ac8..8746c60 100644 --- a/src/tensor.cc +++ b/src/tensor.cc @@ -1,6 +1,8 @@ #include "tensor.h" +#include #include +#include #include "dispatcher.h" @@ -49,6 +51,12 @@ Tensor::Size Tensor::ndim() const { return shape_.size(); } Tensor::Size Tensor::element_size() const { return kDataTypeToSize.at(dtype_); } +Tensor::Size Tensor::numel() const { + return std::accumulate(shape_.begin(), shape_.end(), + static_cast(1), + [](Tensor::Size a, Tensor::Size b) { return a * b; }); +} + Tensor Tensor::T() const { return {data_, {shape_[1], shape_[0]}, @@ -63,6 +71,25 @@ std::string Tensor::ToString() const { device_.ToString() + "')"; } +bool Tensor::HasBroadcastDim() const { + return std::any_of(shape_.begin(), shape_.end(), + [&, i = 0](const auto&) mutable { + return shape_[i] != 1 && strides_[i++] == 0; + }); +} + +bool Tensor::IsContiguous() const { + if (ndim() == 0) { + return true; + } + + if (!IsMergeable(0, ndim() - 1)) { + return false; + } + + return stride(ndim() - 1) == 1; +} + const DataType Tensor::DefaultDataType() { return DataType::kFloat32; } Device Tensor::DefaultDevice() { return Device{Device::Type::kCpu}; } @@ -85,10 +112,10 @@ Tensor::Strides Tensor::DefaultStrides(const Shape& shape) { std::string Tensor::ToStringHelper() const { if (ndim() == 0) { - return DispatchFunc( + return DispatchFunc>( dtype_, [&]() { return std::to_string(*static_cast(data_)); }, - "ToStringHelper"); + "Tensor::ToStringHelper()"); } std::string result{"["}; @@ -103,4 +130,21 @@ std::string Tensor::ToStringHelper() const { return result; } +bool Tensor::IsMergeable(Tensor::Size dim_start, Tensor::Size dim_end) const { + if (dim_start == dim_end) { + return true; + } + + for (Tensor::Size i = 
dim_start; i < dim_end; ++i) { + if (size(i) == 1 && stride(i) == 0) { + return false; + } + if (stride(i) != size(i + 1) * stride(i + 1)) { + return false; + } + } + + return true; +} + } // namespace infini::ops diff --git a/src/tensor.h b/src/tensor.h index 0feb4c4..39d4f98 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -12,9 +12,9 @@ namespace infini::ops { class Tensor { public: - using Size = std::uint64_t; + using Size = std::size_t; - using Stride = std::int64_t; + using Stride = std::ptrdiff_t; using Index = Stride; @@ -89,10 +89,16 @@ class Tensor { Size element_size() const; + Size numel() const; + Tensor T() const; std::string ToString() const; + bool HasBroadcastDim() const; + + bool IsContiguous() const; + private: static const DataType DefaultDataType(); @@ -102,6 +108,8 @@ class Tensor { std::string ToStringHelper() const; + bool IsMergeable(Size dim_start, Size dim_end) const; + void* data_{nullptr}; Shape shape_; diff --git a/tests/test_add.py b/tests/test_add.py new file mode 100644 index 0000000..7a1351f --- /dev/null +++ b/tests/test_add.py @@ -0,0 +1,45 @@ +import infini.ops +import pytest +import torch + +from tests.utils import empty_strided, get_available_devices, randn_strided + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize( + "dtype, rtol, atol", + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-3, 1e-3), + ), +) +@pytest.mark.parametrize( + "shape, a_strides, b_strides, c_strides", + ( + ((13, 4), None, None, None), + ((13, 4), (10, 1), (10, 1), (10, 1)), + ((13, 4), (0, 1), None, None), + ((13, 4, 4), None, None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), (0, 4, 1), None), + ((16, 5632), None, None, None), + ((16, 5632), (13312, 1), (13312, 1), (13312, 1)), + ((13, 16, 2), (128, 4, 1), (0, 2, 1), (64, 4, 1)), + ((13, 16, 2), (128, 4, 1), (2, 0, 1), (64, 4, 1)), + ((4, 4, 5632), None, None, None), + ((4, 4, 5632), 
(45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), + ), +) +def test_add(shape, a_strides, b_strides, c_strides, dtype, device, rtol, atol): + a = randn_strided(shape, a_strides, dtype=dtype, device=device) + b = randn_strided(shape, b_strides, dtype=dtype, device=device) + + output = empty_strided(shape, c_strides, dtype=dtype, device=device) + expected = output.clone() + + # TODO: Add keyword argument support. + infini.ops.add(a, b, output) + torch.add(a, b, out=expected) + + assert torch.allclose(output, expected, rtol=rtol, atol=atol) diff --git a/tests/test_gemm.py b/tests/test_gemm.py index a169dcb..da97f53 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import empty_strided, get_available_devices +from tests.utils import empty_strided, get_available_devices, randn_strided @pytest.mark.parametrize("device", get_available_devices()) @@ -38,8 +38,8 @@ def test_gemm( rtol, atol, ): - a = empty_strided(a_shape, a_strides, dtype=dtype, device=device) - b = empty_strided(b_shape, b_strides, dtype=dtype, device=device) + a = randn_strided(a_shape, a_strides, dtype=dtype, device=device) + b = randn_strided(b_shape, b_strides, dtype=dtype, device=device) if trans_a: a = a.transpose(-2, -1) @@ -50,9 +50,6 @@ def test_gemm( output = empty_strided(c_shape, c_strides, dtype=dtype, device=device) expected = output.clone() - a.normal_() - b.normal_() - # TODO: Add keyword argument support. 
infini.ops.gemm(a, b, alpha, beta, trans_a, trans_b, output) _torch_gemm( diff --git a/tests/utils.py b/tests/utils.py index 3a6ca17..c9b8006 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -24,3 +24,13 @@ def empty_strided(shape, strides, *, dtype=None, device=None): return torch.empty(shape, dtype=dtype, device=device) return torch.empty_strided(shape, strides, dtype=dtype, device=device) + + +def randn_strided(shape, strides, *, dtype=None, device=None): + output = empty_strided(shape, strides, dtype=dtype, device=device) + + output.as_strided( + (output.untyped_storage().size() // output.element_size(),), (1,) + ).normal_() + + return output From 5fcd6453a08720e3dd6c1bd0a2e19370d4d320a3 Mon Sep 17 00:00:00 2001 From: zhangyue <138768300+zhangyue207@users.noreply.github.com> Date: Sat, 28 Feb 2026 15:34:18 +0800 Subject: [PATCH 56/93] feat(gemm-iluvatar): add Iluvatar GEMM backend support (#3) * feat(gemm): add Iluvatar backend * feat(example): delete cudaSetDevice(0) in gemm example * fix(comment): delete useless comments * fix: Close AUTO_DETECT_DEVICES in pyproject.toml and fix import in examples/gemm.py * fix(comment): update comments for `BLAS_COMPUTE_32F_FAST_TF32` and `BLAS_GEMM_DEFAULT` in `src/iluvatar/gemm/cublas.h` * fix: format `src/iluvatar/gemm/cublas.h` * fix: format `src/nvidia/gemm/cublas.h` * refactor: `import infini.ops` instead of `from infini import ops` * build: auto-detect devices via `/dev` file globbing * build: turn on `AUTO_DETECT_DEVICES` in `pyproject.toml` * build: add support for automatically detecting Iluvatar devices * fix: wrap `IluvatarBackend` in `gemm` namespace * fix: prevent compilation failures for missing device specializations * chore: add `pytest-xdist` to `dev` dependencies --------- Co-authored-by: Jiacheng Huang --- CMakeLists.txt | 31 +++++++++++++++++---- examples/gemm.py | 4 +-- examples/gemm/gemm.cc | 3 ++ examples/runtime_api.h | 9 ++++++ pyproject.toml | 2 +- src/CMakeLists.txt | 23 +++++++++++++++ 
src/iluvatar/gemm/cublas.h | 57 ++++++++++++++++++++++++++++++++++++++ src/operator.h | 10 +++++-- src/tensor.cc | 2 +- src/tensor.h | 2 +- 10 files changed, 130 insertions(+), 13 deletions(-) create mode 100644 src/iluvatar/gemm/cublas.h diff --git a/CMakeLists.txt b/CMakeLists.txt index a35cbd8..053a37b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,6 +7,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) # Options for backends. option(WITH_CPU "Enable CPU backend" OFF) option(WITH_NVIDIA "Enable CUDA backend" OFF) +option(WITH_ILUVATAR "Enable Iluvatar GPU backend" OFF) option(WITH_METAX "Enable MetaX backend" OFF) option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF) @@ -17,14 +18,20 @@ if(AUTO_DETECT_DEVICES) set(WITH_CPU ON) - include(CheckLanguage) - check_language(CUDA) + file(GLOB NVIDIA_DEV_FILES "/dev/nvidia*") - if(CMAKE_CUDA_COMPILER) + if(NVIDIA_DEV_FILES) set(WITH_NVIDIA ON) message(STATUS "Auto-detected NVIDIA environment.") endif() + file(GLOB ILUVATAR_DEV_FILES "/dev/iluvatar*") + + if(ILUVATAR_DEV_FILES) + set(WITH_ILUVATAR ON) + message(STATUS "Auto-detected Iluvatar environment.") + endif() + # TODO: Please test and uncomment/update the auto-detection for MetaX. # if(DEFINED ENV{MACA_PATH}) # set(WITH_METAX ON) @@ -38,7 +45,17 @@ if(WITH_NVIDIA) add_compile_definitions(WITH_NVIDIA=1) enable_language(CUDA) find_package(CUDAToolkit REQUIRED) -elseif(WITH_METAX) +endif() + +if(WITH_ILUVATAR) + add_compile_definitions(WITH_ILUVATAR=1) + if(NOT WITH_NVIDIA) + enable_language(CUDA) + find_package(CUDAToolkit REQUIRED) + endif() +endif() + +if(WITH_METAX) add_compile_definitions(WITH_METAX=1) # Normally can be found at: `/opt/maca/`. 
@@ -53,8 +70,10 @@ elseif(WITH_METAX) find_library(MACA_RUNTIME_LIB NAMES mcruntime HINTS "${MACA_PATH}/lib" REQUIRED) find_library(MACA_DNN_LIB NAMES mcdnn HINTS "${MACA_PATH}/lib" REQUIRED) find_library(MACA_BLAS_LIB NAMES mcblas HINTS "${MACA_PATH}/lib" REQUIRED) -# If all other platforms are not enabled, CPU is enabled by default. -else() +endif() + +# If no GPU platform is enabled, CPU is enabled by default. +if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX) add_compile_definitions(WITH_CPU=1) endif() diff --git a/examples/gemm.py b/examples/gemm.py index 9c2d1a4..cd707c1 100644 --- a/examples/gemm.py +++ b/examples/gemm.py @@ -1,4 +1,4 @@ -import ops +import infini.ops import torch m, n, k = 2, 3, 4 @@ -7,7 +7,7 @@ y = torch.randn(k, n, device="cpu") z = torch.empty(m, n, device="cpu") -ops.gemm(x, y, z) +infini.ops.gemm(x, y, z) print(x) print(y) diff --git a/examples/gemm/gemm.cc b/examples/gemm/gemm.cc index 2bdbe6d..c611264 100644 --- a/examples/gemm/gemm.cc +++ b/examples/gemm/gemm.cc @@ -8,6 +8,9 @@ #if WITH_NVIDIA #include "nvidia/gemm/cublas.h" #endif +#if WITH_ILUVATAR +#include "iluvatar/gemm/cublas.h" +#endif #if WITH_METAX #include "metax/gemm/mcblas.h" #endif diff --git a/examples/runtime_api.h b/examples/runtime_api.h index 63ff95e..896af64 100644 --- a/examples/runtime_api.h +++ b/examples/runtime_api.h @@ -10,6 +10,15 @@ #define DEVICE_MEMCPY_HOST_TO_DEVICE cudaMemcpyHostToDevice #define DEVICE_MEMCPY_DEVICE_TO_HOST cudaMemcpyDeviceToHost #define DEFAULT_DEVICE_TYPE Device::Type::kNvidia +#elif WITH_ILUVATAR +#include +#define DEVICE_MALLOC cudaMalloc +#define DEVICE_FREE cudaFree +#define DEVICE_MEMCPY cudaMemcpy +#define DEVICE_MEMSET cudaMemset +#define DEVICE_MEMCPY_HOST_TO_DEVICE cudaMemcpyHostToDevice +#define DEVICE_MEMCPY_DEVICE_TO_HOST cudaMemcpyDeviceToHost +#define DEFAULT_DEVICE_TYPE Device::Type::kIluvatar #elif WITH_METAX #include #define DEVICE_MALLOC mcMalloc diff --git a/pyproject.toml b/pyproject.toml index 
96a3d61..b5d2cdb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "InfiniOps" version = "0.1.0" [project.optional-dependencies] -dev = ["pytest", "pytest-cov", "ruff", "torch"] +dev = ["pytest", "pytest-cov", "pytest-xdist", "ruff", "torch"] [tool.scikit-build.wheel] install-dir = "infini" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2f1f5cc..2ae2b70 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -45,6 +45,29 @@ if(WITH_NVIDIA) list(APPEND DEVICE_LIST "nvidia") endif() +if(WITH_ILUVATAR) + set(ILUVATAR_PATTERNS + "cuda/*.cc" + "cuda/*.cpp" + "cuda/*.cu" + "iluvatar/*.cc" + "iluvatar/*.cpp" + "iluvatar/*.cu" + ) + + file(GLOB_RECURSE ILUVATAR_SOURCES CONFIGURE_DEPENDS ${ILUVATAR_PATTERNS}) + + enable_language(CUDA) + + target_compile_definitions(infiniops PUBLIC WITH_ILUVATAR=1) + target_sources(infiniops PRIVATE ${ILUVATAR_SOURCES}) + + find_package(CUDAToolkit REQUIRED) + target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver) + + list(APPEND DEVICE_LIST "iluvatar") +endif() + if(WITH_METAX) set(METAX_PATTERNS "cuda/*.cc" diff --git a/src/iluvatar/gemm/cublas.h b/src/iluvatar/gemm/cublas.h new file mode 100644 index 0000000..969df69 --- /dev/null +++ b/src/iluvatar/gemm/cublas.h @@ -0,0 +1,57 @@ +#ifndef INFINI_OPS_ILUVATAR_GEMM_CUBLAS_H_ +#define INFINI_OPS_ILUVATAR_GEMM_CUBLAS_H_ + +#include + +// clang-format off +#include "cublas_v2.h" +// clang-format on + +#include "cuda/gemm/blas.h" + +namespace infini::ops { + +namespace gemm { + +struct IluvatarBackend { + using blasHandle_t = cublasHandle_t; + + using stream_t = cudaStream_t; + + static constexpr auto BLAS_OP_N = CUBLAS_OP_N; + + static constexpr auto BLAS_OP_T = CUBLAS_OP_T; + + static constexpr auto R_32F = CUDA_R_32F; + + // Iluvatar uses `cudaDataType` for `computeType`, so we need to use + // `CUDA_R_32F` instead of `CUBLAS_COMPUTE_32F_FAST_TF32`. 
+ static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = CUDA_R_32F; + + // Iluvatar uses `CUBLAS_GEMM_DEFAULT_TENSOR_OP` instead of + // `CUBLAS_GEMM_DEFAULT`. + static constexpr auto BLAS_GEMM_DEFAULT = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + + static constexpr auto blasCreate = cublasCreate; + + static constexpr auto blasSetStream = cublasSetStream; + + static constexpr auto blasDestroy = cublasDestroy; + + static constexpr auto blasGemmStridedBatchedEx = [](auto&&... args) { + return cublasGemmStridedBatchedEx(std::forward(args)...); + }; +}; + +} // namespace gemm + +template <> +class Operator + : public Blas { + public: + using Blas::Blas; +}; + +} // namespace infini::ops + +#endif diff --git a/src/operator.h b/src/operator.h index 8377351..e3c7be8 100644 --- a/src/operator.h +++ b/src/operator.h @@ -3,6 +3,7 @@ #include #include +#include #include "dispatcher.h" #include "tensor.h" @@ -29,8 +30,13 @@ class Operator : public OperatorBase { DispatchFunc( tensor.device().type(), [&]() { - op_ptr = std::make_unique>( - tensor, std::forward(args)...); + if constexpr (std::is_constructible_v, + const Tensor&, Args...>) { + op_ptr = std::make_unique>( + tensor, std::forward(args)...); + } else { + assert("operator is not implemented for this device"); + } }, "Operator::make"); diff --git a/src/tensor.cc b/src/tensor.cc index 8746c60..fe6905d 100644 --- a/src/tensor.cc +++ b/src/tensor.cc @@ -29,7 +29,7 @@ Tensor Tensor::operator[](const Index& index) const { void*& Tensor::data() { return data_; } -const void* const& Tensor::data() const { return data_; } +const void* Tensor::data() const { return data_; } const Tensor::Shape& Tensor::shape() const { return shape_; } diff --git a/src/tensor.h b/src/tensor.h index 39d4f98..306be70 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -71,7 +71,7 @@ class Tensor { void*& data(); - const void* const& data() const; + const void* data() const; const DataType& dtype() const; From 4cc0f007537854b6a06b014a794867ba5c32724b Mon Sep 17 
00:00:00 2001 From: Jiacheng Huang Date: Sat, 28 Feb 2026 11:05:32 +0000 Subject: [PATCH 57/93] test: centralize Act/Assert logic --- tests/conftest.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++- tests/utils.py | 17 ++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 83b575a..f4fd945 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,9 +5,14 @@ import torch -def pytest_configure(): +def pytest_configure(config): torch.backends.fp32_precision = "tf32" + config.addinivalue_line( + "markers", + "auto_act_and_assert: automatically perform Act and Assert phases using the return values", + ) + def pytest_collectstart(collector): if isinstance(collector, pytest.Module): @@ -29,6 +34,34 @@ def _set_random_seed(seed): torch.manual_seed(seed) +@pytest.hookimpl(tryfirst=True) +def pytest_pyfunc_call(pyfuncitem): + if pyfuncitem.get_closest_marker("auto_act_and_assert"): + func_kwargs = { + arg: pyfuncitem.funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames + } + + payload = pyfuncitem.obj(**func_kwargs) + + func = payload.func + ref = payload.ref + args = payload.args + kwargs = payload.kwargs + + ref_args = _clone(args) + ref_kwargs = _clone(kwargs) + + output = func(*args, **kwargs) + expected = ref(*ref_args, **ref_kwargs) + + rtol = payload.rtol + atol = payload.atol + + assert torch.allclose(output, expected, rtol=rtol, atol=atol) + + return True + + def _test_case_path_from_request(request): return f"{_module_path_from_request(request)}::{request.node.name}" @@ -39,3 +72,19 @@ def _module_path_from_request(request): def _hash(string): return int(hashlib.sha256(string.encode("utf-8")).hexdigest(), 16) % 2**32 + + +def _clone(obj): + if isinstance(obj, torch.Tensor): + return obj.clone() + + if isinstance(obj, tuple): + return tuple(_clone(a) for a in obj) + + if isinstance(obj, list): + return [_clone(a) for a in obj] + + if isinstance(obj, dict): + return {key: _clone(value) 
for key, value in obj.items()} + + return obj diff --git a/tests/utils.py b/tests/utils.py index c9b8006..50faec1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,25 @@ import contextlib +import dataclasses +from collections.abc import Callable import torch +@dataclasses.dataclass +class Payload: + func: Callable + + ref: Callable + + args: tuple + + kwargs: dict + + rtol: float = 1e-5 + + atol: float = 1e-8 + + def get_available_devices(): devices = [] From fd708940429eb882b8c217cb61ff85e92802859e Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 28 Feb 2026 11:10:14 +0000 Subject: [PATCH 58/93] test: use `pytest.mark.auto_act_and_assert` in `tests/test_add.py` and `tests/test_gemm.py` --- tests/test_add.py | 19 ++++++++++++------- tests/test_gemm.py | 25 ++++++++++++++++--------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/tests/test_add.py b/tests/test_add.py index 7a1351f..a3bf6c4 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -2,9 +2,10 @@ import pytest import torch -from tests.utils import empty_strided, get_available_devices, randn_strided +from tests.utils import Payload, empty_strided, get_available_devices, randn_strided +@pytest.mark.auto_act_and_assert @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize( "dtype, rtol, atol", @@ -34,12 +35,16 @@ def test_add(shape, a_strides, b_strides, c_strides, dtype, device, rtol, atol): a = randn_strided(shape, a_strides, dtype=dtype, device=device) b = randn_strided(shape, b_strides, dtype=dtype, device=device) + c = empty_strided(shape, c_strides, dtype=dtype, device=device) - output = empty_strided(shape, c_strides, dtype=dtype, device=device) - expected = output.clone() + return Payload(_add, _torch_add, (a, b, c), {}, rtol=rtol, atol=atol) - # TODO: Add keyword argument support. 
- infini.ops.add(a, b, output) - torch.add(a, b, out=expected) - assert torch.allclose(output, expected, rtol=rtol, atol=atol) +def _add(a, b, c): + infini.ops.add(a, b, c) + + return c + + +def _torch_add(a, b, c): + return torch.add(a, b, out=c) diff --git a/tests/test_gemm.py b/tests/test_gemm.py index da97f53..21879cc 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -2,9 +2,10 @@ import pytest import torch -from tests.utils import empty_strided, get_available_devices, randn_strided +from tests.utils import Payload, empty_strided, get_available_devices, randn_strided +@pytest.mark.auto_act_and_assert @pytest.mark.parametrize("device", get_available_devices()) # TODO: Add support for more data types. @pytest.mark.parametrize("dtype, rtol, atol", ((torch.float32, 1e-3, 1e-3),)) @@ -47,19 +48,25 @@ def test_gemm( if trans_b: b = b.transpose(-2, -1) - output = empty_strided(c_shape, c_strides, dtype=dtype, device=device) - expected = output.clone() + c = empty_strided(c_shape, c_strides, dtype=dtype, device=device) - # TODO: Add keyword argument support. 
- infini.ops.gemm(a, b, alpha, beta, trans_a, trans_b, output) - _torch_gemm( - a, b, alpha=alpha, beta=beta, trans_a=trans_a, trans_b=trans_b, c=expected + return Payload( + _gemm, + _torch_gemm, + (a, b, alpha, beta, trans_a, trans_b, c), + {}, + rtol=rtol, + atol=atol, ) - assert torch.allclose(output, expected, rtol=rtol, atol=atol) +def _gemm(a, b, alpha, beta, trans_a, trans_b, c): + infini.ops.gemm(a, b, alpha, beta, trans_a, trans_b, c) -def _torch_gemm(a, b, *, alpha=1.0, beta=1.0, trans_a=False, trans_b=False, c=None): + return c + + +def _torch_gemm(a, b, alpha=1.0, beta=1.0, trans_a=False, trans_b=False, c=None): if trans_a: a = a.transpose(-2, -1) From fd800cb68308341025435928517c6ee1c0dd4126 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 2 Mar 2026 02:30:05 +0000 Subject: [PATCH 59/93] test: centralize `dtype` and `device` parametrization --- tests/conftest.py | 33 +++++++++++++++++++++++++++++++++ tests/test_add.py | 11 +---------- tests/test_gemm.py | 3 +-- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index f4fd945..7ea7cb8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,8 @@ import pytest import torch +from tests.utils import get_available_devices + def pytest_configure(config): torch.backends.fp32_precision = "tf32" @@ -34,6 +36,23 @@ def _set_random_seed(seed): torch.manual_seed(seed) +def pytest_generate_tests(metafunc): + already_parametrized = _get_parametrized_args(metafunc) + + if "dtype" in metafunc.fixturenames and "dtype" not in already_parametrized: + metafunc.parametrize( + "dtype, rtol, atol", + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-3, 1e-3), + ), + ) + + if "device" in metafunc.fixturenames and "device" not in already_parametrized: + metafunc.parametrize("device", get_available_devices()) + + @pytest.hookimpl(tryfirst=True) def pytest_pyfunc_call(pyfuncitem): if 
pyfuncitem.get_closest_marker("auto_act_and_assert"): @@ -62,6 +81,20 @@ def pytest_pyfunc_call(pyfuncitem): return True +def _get_parametrized_args(metafunc): + parametrized_args = set() + + for marker in metafunc.definition.iter_markers(name="parametrize"): + args = marker.args[0] + + if isinstance(args, str): + parametrized_args.update(x.strip() for x in args.split(",")) + elif isinstance(args, (list, tuple)): + parametrized_args.update(args) + + return parametrized_args + + def _test_case_path_from_request(request): return f"{_module_path_from_request(request)}::{request.node.name}" diff --git a/tests/test_add.py b/tests/test_add.py index a3bf6c4..badff9b 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -2,19 +2,10 @@ import pytest import torch -from tests.utils import Payload, empty_strided, get_available_devices, randn_strided +from tests.utils import Payload, empty_strided, randn_strided @pytest.mark.auto_act_and_assert -@pytest.mark.parametrize("device", get_available_devices()) -@pytest.mark.parametrize( - "dtype, rtol, atol", - ( - (torch.float32, 1e-7, 1e-7), - (torch.float16, 1e-3, 1e-3), - (torch.bfloat16, 1e-3, 1e-3), - ), -) @pytest.mark.parametrize( "shape, a_strides, b_strides, c_strides", ( diff --git a/tests/test_gemm.py b/tests/test_gemm.py index 21879cc..bb08975 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -2,11 +2,10 @@ import pytest import torch -from tests.utils import Payload, empty_strided, get_available_devices, randn_strided +from tests.utils import Payload, empty_strided, randn_strided @pytest.mark.auto_act_and_assert -@pytest.mark.parametrize("device", get_available_devices()) # TODO: Add support for more data types. 
@pytest.mark.parametrize("dtype, rtol, atol", ((torch.float32, 1e-3, 1e-3),)) @pytest.mark.parametrize("trans_b", (False, True)) From 54458fbe9fd8938a53a0229ba58848db6a17bd94 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 2 Mar 2026 02:32:55 +0000 Subject: [PATCH 60/93] test: reorder `pytest.mark.parametrize` decorators in `tests/test_gemm.py` --- tests/test_gemm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_gemm.py b/tests/test_gemm.py index bb08975..c091136 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -6,12 +6,6 @@ @pytest.mark.auto_act_and_assert -# TODO: Add support for more data types. -@pytest.mark.parametrize("dtype, rtol, atol", ((torch.float32, 1e-3, 1e-3),)) -@pytest.mark.parametrize("trans_b", (False, True)) -@pytest.mark.parametrize("trans_a", (False, True)) -@pytest.mark.parametrize("beta", (-1, -0.5, 0, 0.5, 1)) -@pytest.mark.parametrize("alpha", (-1, -0.5, 0, 0.5, 1)) @pytest.mark.parametrize( "a_shape, b_shape, c_shape, a_strides, b_strides, c_strides", ( @@ -22,6 +16,12 @@ ((4, 48, 64), (4, 64, 6), (4, 48, 6), None, None, None), ), ) +@pytest.mark.parametrize("alpha", (-1, -0.5, 0, 0.5, 1)) +@pytest.mark.parametrize("beta", (-1, -0.5, 0, 0.5, 1)) +@pytest.mark.parametrize("trans_a", (False, True)) +@pytest.mark.parametrize("trans_b", (False, True)) +# TODO: Add support for more data types. 
+@pytest.mark.parametrize(("dtype", "rtol", "atol"), ((torch.float32, 1e-3, 1e-3),)) def test_gemm( a_shape, b_shape, From 544825968993a061acd159a4233d7785801aeda0 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 2 Mar 2026 02:36:32 +0000 Subject: [PATCH 61/93] test: rename operands to `input`, `other`, and `out` in `tests/test_add.py` --- tests/test_add.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/test_add.py b/tests/test_add.py index badff9b..f5f84b4 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -7,7 +7,7 @@ @pytest.mark.auto_act_and_assert @pytest.mark.parametrize( - "shape, a_strides, b_strides, c_strides", + "shape, input_strides, other_strides, out_strides", ( ((13, 4), None, None, None), ((13, 4), (10, 1), (10, 1), (10, 1)), @@ -23,19 +23,21 @@ ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), ), ) -def test_add(shape, a_strides, b_strides, c_strides, dtype, device, rtol, atol): - a = randn_strided(shape, a_strides, dtype=dtype, device=device) - b = randn_strided(shape, b_strides, dtype=dtype, device=device) - c = empty_strided(shape, c_strides, dtype=dtype, device=device) +def test_add( + shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol +): + input = randn_strided(shape, input_strides, dtype=dtype, device=device) + other = randn_strided(shape, other_strides, dtype=dtype, device=device) + out = empty_strided(shape, out_strides, dtype=dtype, device=device) - return Payload(_add, _torch_add, (a, b, c), {}, rtol=rtol, atol=atol) + return Payload(_add, _torch_add, (input, other, out), {}, rtol=rtol, atol=atol) -def _add(a, b, c): - infini.ops.add(a, b, c) +def _add(input, other, out): + infini.ops.add(input, other, out) - return c + return out -def _torch_add(a, b, c): - return torch.add(a, b, out=c) +def _torch_add(input, other, out): + return torch.add(input, other, out=out) From 92bbe8aa8fc6f3924f7d03c43acc6cced0cf9549 Mon Sep 17 
00:00:00 2001 From: Jiacheng Huang Date: Mon, 2 Mar 2026 06:01:14 +0000 Subject: [PATCH 62/93] test: add benchmarking support --- tests/conftest.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 7ea7cb8..21d54f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,10 +3,17 @@ import pytest import torch +import torch.utils.benchmark as benchmark from tests.utils import get_available_devices +def pytest_addoption(parser): + parser.addoption( + "--benchmark", action="store_true", help="Run performance benchmarks." + ) + + def pytest_configure(config): torch.backends.fp32_precision = "tf32" @@ -73,6 +80,28 @@ def pytest_pyfunc_call(pyfuncitem): output = func(*args, **kwargs) expected = ref(*ref_args, **ref_kwargs) + if pyfuncitem.config.getoption("--benchmark"): + stmt = "func(*args, **kwargs)" + + func_timer = benchmark.Timer( + stmt=stmt, + globals={"func": func, "args": args, "kwargs": kwargs}, + label=func.__name__, + description="InfiniOps", + ) + + ref_timer = benchmark.Timer( + stmt=stmt, + globals={"func": ref, "args": ref_args, "kwargs": ref_kwargs}, + label=func.__name__, + description="Reference", + ) + + func_measurement = func_timer.blocked_autorange() + ref_measurement = ref_timer.blocked_autorange() + + benchmark.Compare((func_measurement, ref_measurement)).print() + rtol = payload.rtol atol = payload.atol From a3f6101b45e04effde46d3b041db4719f8a5b16e Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 2 Mar 2026 07:56:24 +0000 Subject: [PATCH 63/93] perf: cache `Operator` instances in `Operator::call` --- src/device.h | 14 ++++++++++++++ src/hash.h | 12 ++++++++++++ src/operator.h | 17 ++++++++++++++--- src/tensor.h | 22 ++++++++++++++++++++++ 4 files changed, 62 insertions(+), 3 deletions(-) create mode 100644 src/hash.h diff --git a/src/device.h b/src/device.h index 61a8d7e..daf31a6 100644 --- a/src/device.h +++ b/src/device.h @@ -3,6 +3,7 @@ #include 
"common/constexpr_map.h" #include "common/traits.h" +#include "hash.h" namespace infini::ops { @@ -157,4 +158,17 @@ using ActiveDevices = } // namespace infini::ops +template <> +struct std::hash { + std::size_t operator()(const infini::ops::Device& device) const { + std::size_t seed{0}; + + hash_combine(seed, device.type()); + + hash_combine(seed, device.index()); + + return seed; + } +}; + #endif diff --git a/src/hash.h b/src/hash.h new file mode 100644 index 0000000..aced9cf --- /dev/null +++ b/src/hash.h @@ -0,0 +1,12 @@ +#ifndef INFINI_OPS_HASH_H_ +#define INFINI_OPS_HASH_H_ + +#include + +template +inline void hash_combine(std::size_t& seed, const T& v) { + std::hash> hasher; + seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +#endif diff --git a/src/operator.h b/src/operator.h index e3c7be8..97fcebc 100644 --- a/src/operator.h +++ b/src/operator.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "dispatcher.h" #include "tensor.h" @@ -45,9 +46,19 @@ class Operator : public OperatorBase { template static auto call(void* stream, Args&&... args) { - // TODO: Cache the created `Operator`. 
- return (*make(std::forward(args)...))(stream, - std::forward(args)...); + static std::unordered_map> cache; + + std::size_t hash{0}; + + (hash_combine(hash, args), ...); + + auto it{cache.find(hash)}; + + if (it == cache.end()) { + it = cache.emplace(hash, make(std::forward(args)...)).first; + } + + return (*it->second)(stream, std::forward(args)...); } template diff --git a/src/tensor.h b/src/tensor.h index 306be70..cc2bf9e 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -7,6 +7,7 @@ #include "data_type.h" #include "device.h" +#include "hash.h" namespace infini::ops { @@ -123,4 +124,25 @@ class Tensor { } // namespace infini::ops +template <> +struct std::hash { + std::size_t operator()(const infini::ops::Tensor& tensor) const { + std::size_t seed{0}; + + for (const auto& size : tensor.shape()) { + hash_combine(seed, size); + } + + hash_combine(seed, tensor.dtype()); + + hash_combine(seed, tensor.device()); + + for (const auto& stride : tensor.strides()) { + hash_combine(seed, stride); + } + + return seed; + } +}; + #endif From 2e5cbc4c596021ac60583d1abc26173993d00e47 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 2 Mar 2026 07:58:01 +0000 Subject: [PATCH 64/93] fix: fix invalid string literal assertion in `Operator::make` --- src/operator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator.h b/src/operator.h index 97fcebc..0bbbc0c 100644 --- a/src/operator.h +++ b/src/operator.h @@ -36,7 +36,7 @@ class Operator : public OperatorBase { op_ptr = std::make_unique>( tensor, std::forward(args)...); } else { - assert("operator is not implemented for this device"); + assert(false && "operator is not implemented for this device"); } }, "Operator::make"); From a7f447c0b638c3298cd003d6085bb0855e7ba063 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 2 Mar 2026 08:33:42 +0000 Subject: [PATCH 65/93] test: use `clone_strided` in `_clone` to preserve tensor layout --- tests/conftest.py | 4 ++-- tests/utils.py | 12 ++++++++++++ 2 
files changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 21d54f9..44654c3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import torch import torch.utils.benchmark as benchmark -from tests.utils import get_available_devices +from tests.utils import clone_strided, get_available_devices def pytest_addoption(parser): @@ -138,7 +138,7 @@ def _hash(string): def _clone(obj): if isinstance(obj, torch.Tensor): - return obj.clone() + return clone_strided(obj) if isinstance(obj, tuple): return tuple(_clone(a) for a in obj) diff --git a/tests/utils.py b/tests/utils.py index 50faec1..11afcdf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -51,3 +51,15 @@ def randn_strided(shape, strides, *, dtype=None, device=None): ).normal_() return output + + +def clone_strided(input): + output = empty_strided( + input.size(), input.stride(), dtype=input.dtype, device=input.device + ) + + as_strided_args = (output.untyped_storage().size() // output.element_size(),), (1,) + + output.as_strided(*as_strided_args).copy_(input.as_strided(*as_strided_args)) + + return output From 45d8b9d31d9959792569bc14708b032d3e694529 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 2 Mar 2026 08:45:22 +0000 Subject: [PATCH 66/93] refactor: move pybind11 utilities from `scripts/generate_wrappers.py` to a static header `src/pybind11_utils.h` --- scripts/generate_wrappers.py | 54 +----------------------------------- src/pybind11_utils.h | 51 ++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 53 deletions(-) create mode 100644 src/pybind11_utils.h diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 2a18752..80ef582 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -21,56 +21,6 @@ _INDENTATION = " " -_UTILS_H_CONTENT = """#ifndef INFINI_OPS_BINDINGS_UTILS_H_ -#define INFINI_OPS_BINDINGS_UTILS_H_ - -#include -#include - -namespace infini::ops { - -inline 
DataType DataTypeFromString(const std::string& name) { - return kStringToDataType.at(name); -} - -inline Device::Type DeviceTypeFromString(const std::string& name) { - static const std::unordered_map kTorchNameToTypes{ - {"cpu", Device::Type::kCpu}, -#ifdef WITH_NVIDIA - {"cuda", Device::Type::kNvidia}, -#endif -#ifdef WITH_METAX - {"cuda", Device::Type::kMetax}, -#endif -#ifdef WITH_ILUVATAR - {"cuda", Device::Type::kIluvatar}, -#endif -#ifdef WITH_KUNLUN - {"cuda", Device::Type::kKunlun}, -#endif -#ifdef WITH_HYGON - {"cuda", Device::Type::kHygon}, -#endif -#ifdef WITH_QY - {"cuda", Device::Type::kQy}, -#endif - {"mlu", Device::Type::kCambricon}, {"npu", Device::Type::kAscend}, - {"musa", Device::Type::kMoore}}; - - auto it{kTorchNameToTypes.find(name)}; - - if (it != kTorchNameToTypes.cend()) { - return it->second; - } - - return Device::TypeFromString(name); -} - -} // namespace infini::ops - -#endif -""" - class _OperatorExtractor: def __call__(self, op_name): @@ -183,7 +133,7 @@ def _generate_call(op_name, call, method=True): #include #include "base/{op_name.lower()}.h" -#include "utils.h" +#include "pybind11_utils.h" namespace py = pybind11; @@ -427,8 +377,6 @@ def _get_all_ops(devices): header_paths = [] bind_func_names = [] - (_BINDINGS_DIR / "utils.h").write_text(_UTILS_H_CONTENT) - for op_name, impl_paths in ops.items(): extractor = _OperatorExtractor() operator = extractor(op_name) diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h new file mode 100644 index 0000000..93c1493 --- /dev/null +++ b/src/pybind11_utils.h @@ -0,0 +1,51 @@ +#ifndef INFINI_OPS_PYBIND11_UTILS_H_ +#define INFINI_OPS_PYBIND11_UTILS_H_ + +#include +#include + +#include "data_type.h" +#include "device.h" + +namespace infini::ops { + +inline DataType DataTypeFromString(const std::string& name) { + return kStringToDataType.at(name); +} + +inline Device::Type DeviceTypeFromString(const std::string& name) { + static const std::unordered_map kTorchNameToTypes{ + {"cpu", 
Device::Type::kCpu}, +#ifdef WITH_NVIDIA + {"cuda", Device::Type::kNvidia}, +#endif +#ifdef WITH_METAX + {"cuda", Device::Type::kMetax}, +#endif +#ifdef WITH_ILUVATAR + {"cuda", Device::Type::kIluvatar}, +#endif +#ifdef WITH_KUNLUN + {"cuda", Device::Type::kKunlun}, +#endif +#ifdef WITH_HYGON + {"cuda", Device::Type::kHygon}, +#endif +#ifdef WITH_QY + {"cuda", Device::Type::kQy}, +#endif + {"mlu", Device::Type::kCambricon}, {"npu", Device::Type::kAscend}, + {"musa", Device::Type::kMoore}}; + + auto it{kTorchNameToTypes.find(name)}; + + if (it != kTorchNameToTypes.cend()) { + return it->second; + } + + return Device::TypeFromString(name); +} + +} // namespace infini::ops + +#endif From 23950e1854657f6eef928893257f6f93e468f6bc Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 2 Mar 2026 11:02:27 +0000 Subject: [PATCH 67/93] refactor: move tensor conversion logic from `scripts/generate_wrappers.py` to `src/pybind11_utils.h` --- scripts/generate_wrappers.py | 5 +---- src/pybind11_utils.h | 29 +++++++++++++++++++++++++++-- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 80ef582..3f22702 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -89,16 +89,13 @@ def _generate_params(node): def _generate_arguments(node): return ", ".join( - _generate_tensor_caster(arg.spelling) + f"TensorFromPybind11Handle({arg.spelling})" if "Tensor" in arg.type.spelling else arg.spelling for arg in node.get_arguments() if arg.spelling != "stream" ) - def _generate_tensor_caster(name): - return f'Tensor{{reinterpret_cast({name}.attr("data_ptr")().cast()), {name}.attr("shape").cast(), DataTypeFromString(py::str({name}.attr("dtype")).attr("split")(".").attr("__getitem__")(-1).cast()), Device{{DeviceTypeFromString({name}.attr("device").attr("type").cast()), {name}.attr("device").attr("index").is_none() ? 
0 : {name}.attr("device").attr("index").cast()}}, {name}.attr("stride")().cast()}}' - op_name = operator.name def _generate_init(constructor): diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h index 93c1493..674b022 100644 --- a/src/pybind11_utils.h +++ b/src/pybind11_utils.h @@ -1,12 +1,14 @@ #ifndef INFINI_OPS_PYBIND11_UTILS_H_ #define INFINI_OPS_PYBIND11_UTILS_H_ -#include -#include +#include +#include #include "data_type.h" #include "device.h" +namespace py = pybind11; + namespace infini::ops { inline DataType DataTypeFromString(const std::string& name) { @@ -46,6 +48,29 @@ inline Device::Type DeviceTypeFromString(const std::string& name) { return Device::TypeFromString(name); } +inline Tensor TensorFromPybind11Handle(py::handle obj) { + auto data{ + reinterpret_cast(obj.attr("data_ptr")().cast())}; + + auto shape{obj.attr("shape").cast()}; + + auto dtype_str{py::str(obj.attr("dtype")).cast()}; + auto pos{dtype_str.find_last_of('.')}; + auto dtype{DataTypeFromString( + pos == std::string::npos ? dtype_str : dtype_str.substr(pos + 1))}; + + auto device_obj{obj.attr("device")}; + auto device_type_str{device_obj.attr("type").cast()}; + auto device_index_obj{device_obj.attr("index")}; + auto device_index{device_index_obj.is_none() ? 
0 + : device_index_obj.cast()}; + Device device{DeviceTypeFromString(device_type_str), device_index}; + + auto strides{obj.attr("stride")().cast()}; + + return Tensor{data, std::move(shape), dtype, device, std::move(strides)}; +} + } // namespace infini::ops #endif From 7cf4f85c55fa3146a7768535aaf31bb60473b155 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 3 Mar 2026 08:50:16 +0000 Subject: [PATCH 68/93] fix: use `g++` instead of `clang++` in `_get_system_include_flags` --- scripts/generate_wrappers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 3f22702..db2f14b 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -27,9 +27,7 @@ def __call__(self, op_name): def _get_system_include_flags(): system_include_flags = [] - for line in subprocess.getoutput( - "clang++ -E -x c++ -v /dev/null" - ).splitlines(): + for line in subprocess.getoutput("g++ -E -x c++ -v /dev/null").splitlines(): if not line.startswith(" "): continue From 049de3a36321ad2264eabf07e4e69129e51d4c3d Mon Sep 17 00:00:00 2001 From: zhangyue <138768300+zhangyue207@users.noreply.github.com> Date: Wed, 4 Mar 2026 16:30:58 +0800 Subject: [PATCH 69/93] feat(ops): add `RmsNorm` with Iluvatar, NVIDIA, CPU backends and fp16/bf16 support (#6) * feat(ops): add RmsNorm with Iluvatar, NVIDIA, CPU backends and fp16/bf16 support * refactor(test): align test_rms_norm with Payload pattern and add stride cases * refactor(ops): align RmsNorm with Add pattern, header-only CUDA backend * refactor(ops): restruct rmsnorm IluvatarBackend * refactor(gemm): update constructors to use std::optional for alpha, beta, and trans parameters * refactor(ops): update RmsNorm constructors and parameter naming for consistency * refactor(ops): remove std::optional from Gemm and Blas constructors for alpha, beta, and trans parameters * refactor(ops): standardize parameter naming in rms_norm tests and related 
functions * refactor(ops): remove unused lambda functions from IluvatarBackend in rms_norm * refactor(ops): clean up formatting and improve readability in rms_norm and related files * refactor: imporve PR changes * chore: revert changes in `src/common/traits.h` and `src/dispatcher.h` * refactor: use `snake_case` `op_name` in `scripts/generate_wrappers.py` --------- Co-authored-by: Jiacheng Huang --- CMakeLists.txt | 21 +++++++++-- pyproject.toml | 3 ++ scripts/generate_wrappers.py | 62 ++++++++++++++++++++---------- src/CMakeLists.txt | 11 +++++- src/base/rms_norm.h | 58 ++++++++++++++++++++++++++++ src/common/constexpr_map.h | 6 +-- src/common/cuda/kernel_commons.h | 4 +- src/common/generic_utils.h | 4 +- src/cpu/rms_norm/rms_norm.h | 61 ++++++++++++++++++++++++++++++ src/cuda/rms_norm/kernel.cuh | 59 +++++++++++++++++++++++++++++ src/cuda/rms_norm/kernel.h | 63 +++++++++++++++++++++++++++++++ src/data_type.h | 9 +++-- src/iluvatar/rms_norm/kernel.h | 31 +++++++++++++++ src/nvidia/rms_norm/kernel.h | 31 +++++++++++++++ tests/test_rms_norm.py | 65 ++++++++++++++++++++++++++++++++ 15 files changed, 455 insertions(+), 33 deletions(-) create mode 100644 src/base/rms_norm.h create mode 100644 src/cpu/rms_norm/rms_norm.h create mode 100644 src/cuda/rms_norm/kernel.cuh create mode 100644 src/cuda/rms_norm/kernel.h create mode 100644 src/iluvatar/rms_norm/kernel.h create mode 100644 src/nvidia/rms_norm/kernel.h create mode 100644 tests/test_rms_norm.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 053a37b..36aa295 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,18 +41,33 @@ endif() include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) +# NVIDIA and Iluvatar are parallel backends; only one GPU backend at a time. +if(WITH_NVIDIA AND WITH_ILUVATAR) + message(FATAL_ERROR "`WITH_NVIDIA` and `WITH_ILUVATAR` cannot both be `ON`. 
Build one GPU backend at a time.") +endif() + if(WITH_NVIDIA) add_compile_definitions(WITH_NVIDIA=1) enable_language(CUDA) find_package(CUDAToolkit REQUIRED) endif() +# Iluvatar: CUDA-compatible device, uses `clang++` with `-x ivcore` (not `nvcc`). +# Reference: `InfiniCore` `xmake/iluvatar.lua`. if(WITH_ILUVATAR) add_compile_definitions(WITH_ILUVATAR=1) - if(NOT WITH_NVIDIA) - enable_language(CUDA) - find_package(CUDAToolkit REQUIRED) + set(ILUVATAR_ARCH "ivcore20" CACHE STRING "Iluvatar GPU architecture") + find_program(CLANGXX NAMES clang++) + if(CLANGXX) + set(CMAKE_CUDA_COMPILER "${CLANGXX}" CACHE STRING "Iluvatar CUDA compiler (clang++)") + else() + set(CMAKE_CUDA_COMPILER "clang++" CACHE STRING "Iluvatar CUDA compiler (clang++)") endif() + set(CMAKE_CUDA_FLAGS "-x ivcore -std=c++17 --cuda-gpu-arch=${ILUVATAR_ARCH} -fPIC -Wno-error=unused-variable -Wno-error=unused-private-field -Wno-unused-variable" CACHE STRING "Iluvatar CUDA flags") + set(CMAKE_CUDA_SEPARABLE_COMPILATION OFF CACHE BOOL "Disable RDC for Iluvatar") + message(STATUS "Iluvatar: CUDA compiler ${CMAKE_CUDA_COMPILER}, arch ${ILUVATAR_ARCH}") + enable_language(CUDA) + find_package(CUDAToolkit REQUIRED) endif() if(WITH_METAX) diff --git a/pyproject.toml b/pyproject.toml index b5d2cdb..765b90a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,3 +15,6 @@ install-dir = "infini" [tool.scikit-build.cmake.define] AUTO_DETECT_DEVICES = "ON" GENERATE_PYTHON_BINDINGS = "ON" + +[tool.pytest.ini_options] +testpaths = ["tests"] diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index db2f14b..0458a39 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -1,6 +1,7 @@ import argparse import json import pathlib +import shutil import subprocess import textwrap @@ -25,14 +26,26 @@ class _OperatorExtractor: def __call__(self, op_name): def _get_system_include_flags(): + def _get_compilers(): + compilers = [] + + for compiler in ("clang++", "g++"): + if 
shutil.which(compiler) is not None: + compilers.append(compiler) + + return compilers + system_include_flags = [] - for line in subprocess.getoutput("g++ -E -x c++ -v /dev/null").splitlines(): - if not line.startswith(" "): - continue + for compiler in _get_compilers(): + for line in subprocess.getoutput( + f"{compiler} -E -x c++ -v /dev/null" + ).splitlines(): + if not line.startswith(" "): + continue - system_include_flags.append("-isystem") - system_include_flags.append(line.strip()) + system_include_flags.append("-isystem") + system_include_flags.append(line.strip()) return system_include_flags @@ -40,7 +53,7 @@ def _get_system_include_flags(): index = clang.cindex.Index.create() args = ("-std=c++17", "-x", "c++", "-I", "src") + tuple(system_include_flags) - translation_unit = index.parse(f"src/base/{op_name.lower()}.h", args=args) + translation_unit = index.parse(f"src/base/{op_name}.h", args=args) nodes = tuple(type(self)._find(translation_unit.cursor, op_name)) @@ -57,7 +70,12 @@ def _get_system_include_flags(): @staticmethod def _find(node, op_name): - if node.semantic_parent and node.semantic_parent.spelling == op_name: + pascal_case_op_name = _snake_to_pascal(op_name) + + if ( + node.semantic_parent + and node.semantic_parent.spelling == pascal_case_op_name + ): yield node for child in node.get_children(): @@ -107,7 +125,7 @@ def _generate_call(op_name, call, method=True): call_params = _generate_params(call) if not method: - return f""" m.def("{op_name.lower()}", []({call_params}) {{ return Self::call({_generate_arguments(call)}); }});""" + return f""" m.def("{op_name}", []({call_params}) {{ return Self::call({_generate_arguments(call)}); }});""" return f""" .def("__call__", [](const Self& self, {call_params}) {{ return static_cast&>(self)({_generate_arguments(call)}); @@ -121,23 +139,25 @@ def _generate_call(op_name, call, method=True): _generate_call(operator.name, call, method=False) for call in operator.calls ) + pascal_case_op_name = 
_snake_to_pascal(op_name) + return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_ #define INFINI_OPS_BINDINGS_{op_name.upper()}_H_ #include #include -#include "base/{op_name.lower()}.h" +#include "base/{op_name}.h" #include "pybind11_utils.h" namespace py = pybind11; namespace infini::ops {{ -void Bind{op_name}(py::module& m) {{ - using Self = {op_name}; +void Bind{pascal_case_op_name}(py::module& m) {{ + using Self = {pascal_case_op_name}; - py::class_(m, "{op_name}") + py::class_(m, "{pascal_case_op_name}") {inits} {calls}; @@ -324,6 +344,10 @@ def _generate_tensor_caster(name, is_data=False): return _generate_source(operator), _generate_header(operator) +def _snake_to_pascal(snake_str): + return "".join(word.capitalize() for word in snake_str.split("_")) + + def _get_all_ops(devices): ops = {} @@ -331,7 +355,7 @@ def _get_all_ops(devices): if not file_path.is_file(): continue - op_name = "".join(word.capitalize() for word in file_path.stem.split("_")) + op_name = file_path.stem ops[op_name] = [] @@ -339,7 +363,7 @@ def _get_all_ops(devices): if not file_path.is_file() or file_path.parent.parent.name not in devices: continue - if f"class Operator<{op_name}" in file_path.read_text(): + if f"class Operator<{_snake_to_pascal(op_name)}" in file_path.read_text(): ops[op_name].append(file_path) return ops @@ -376,17 +400,15 @@ def _get_all_ops(devices): extractor = _OperatorExtractor() operator = extractor(op_name) - source_path = _GENERATED_SRC_DIR / op_name.lower() - header_name = f"{op_name.lower()}.h" - bind_func_name = f"Bind{op_name}" + source_path = _GENERATED_SRC_DIR / op_name + header_name = f"{op_name}.h" + bind_func_name = f"Bind{_snake_to_pascal(op_name)}" (_BINDINGS_DIR / header_name).write_text(_generate_pybind11(operator)) legacy_c_source, legacy_c_header = _generate_legacy_c(operator, impl_paths) source_path.mkdir(exist_ok=True) - (_GENERATED_SRC_DIR / op_name.lower() / "operator.cc").write_text( - legacy_c_source - ) + (_GENERATED_SRC_DIR / 
op_name / "operator.cc").write_text(legacy_c_source) (_INCLUDE_DIR / header_name).write_text(legacy_c_header) header_paths.append(header_name) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2ae2b70..97cc0e3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -43,6 +43,10 @@ if(WITH_NVIDIA) target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver) list(APPEND DEVICE_LIST "nvidia") + set_target_properties(infiniops PROPERTIES + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + ) endif() if(WITH_ILUVATAR) @@ -65,6 +69,11 @@ if(WITH_ILUVATAR) find_package(CUDAToolkit REQUIRED) target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver) + set_target_properties(infiniops PROPERTIES + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + ) + list(APPEND DEVICE_LIST "iluvatar") endif() @@ -112,7 +121,7 @@ if(GENERATE_PYTHON_BINDINGS) set(PYBIND11_SOURCES "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc") # TODO: There might be a better solution. - if(WITH_NVIDIA) + if(WITH_NVIDIA OR WITH_ILUVATAR) set_source_files_properties(${PYBIND11_SOURCES} PROPERTIES LANGUAGE CUDA) endif() diff --git a/src/base/rms_norm.h b/src/base/rms_norm.h new file mode 100644 index 0000000..db9041c --- /dev/null +++ b/src/base/rms_norm.h @@ -0,0 +1,58 @@ +#ifndef INFINI_OPS_BASE_RMS_NORM_H_ +#define INFINI_OPS_BASE_RMS_NORM_H_ + +#include +#include + +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +class RmsNorm : public Operator { + public: + RmsNorm(const Tensor input, const Tensor weight, float eps, Tensor out) + : eps_{eps}, + out_shape_{out.shape()}, + input_shape_{input.shape()}, + out_strides_{out.strides()}, + input_strides_{input.strides()}, + dim_{out.size(-1)}, + ndim_{out.ndim()}, + batch_size_{ndim_ == 2 ? out.size(-2) : out.size(-3)}, + nhead_{ndim_ == 2 ? 
1 : out.size(-2)} {} + + RmsNorm(const Tensor input, const Tensor weight, Tensor out) + : RmsNorm{input, weight, 1e-6f, out} {} + + virtual void operator()(void* stream, const Tensor input, const Tensor weight, + float eps, Tensor out) const = 0; + + virtual void operator()(void* stream, const Tensor input, const Tensor weight, + Tensor out) const { + return operator()(stream, input, weight, eps_, out); + } + + protected: + Tensor::Shape input_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides out_strides_; + + float eps_{1e-6f}; + + Tensor::Size dim_{0}; + + Tensor::Size ndim_{0}; + + Tensor::Size batch_size_{0}; + + Tensor::Size nhead_{1}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/common/constexpr_map.h b/src/common/constexpr_map.h index 3db1275..0d142eb 100644 --- a/src/common/constexpr_map.h +++ b/src/common/constexpr_map.h @@ -8,9 +8,9 @@ namespace infini::ops { -template +template struct ConstexprMap { - constexpr ConstexprMap(std::array, N> data) + constexpr ConstexprMap(std::array, size> data) : data_(data) {} constexpr Value at(Key key) const { @@ -24,7 +24,7 @@ struct ConstexprMap { } private: - std::array, N> data_; + std::array, size> data_; }; } // namespace infini::ops diff --git a/src/common/cuda/kernel_commons.h b/src/common/cuda/kernel_commons.h index 4d92e00..98b9f48 100644 --- a/src/common/cuda/kernel_commons.h +++ b/src/common/cuda/kernel_commons.h @@ -3,7 +3,9 @@ #ifdef WITH_NVIDIA #include -#elif WITH_METAX +#elif defined(WITH_ILUVATAR) +#include +#elif WITH_METAX // TODO: Use `defined`. 
#include #endif diff --git a/src/common/generic_utils.h b/src/common/generic_utils.h index 6c82f49..36df934 100644 --- a/src/common/generic_utils.h +++ b/src/common/generic_utils.h @@ -16,8 +16,8 @@ std::size_t indexToOffset(std::size_t flat_index, std::size_t ndim, return res; } -template -constexpr auto CeilDiv(const Tx& x, const Ty& y) { +template +constexpr auto CeilDiv(const X& x, const Y& y) { return (x + y - 1) / y; } diff --git a/src/cpu/rms_norm/rms_norm.h b/src/cpu/rms_norm/rms_norm.h new file mode 100644 index 0000000..4656ba7 --- /dev/null +++ b/src/cpu/rms_norm/rms_norm.h @@ -0,0 +1,61 @@ +#ifndef INFINI_OPS_CPU_RMS_NORM_H_ +#define INFINI_OPS_CPU_RMS_NORM_H_ + +#include + +#include "base/rms_norm.h" +#include "data_type.h" +#include "tensor.h" + +namespace infini::ops { + +template <> +class Operator : public RmsNorm { + public: + using RmsNorm::RmsNorm; + + void operator()(void* stream, const Tensor input, const Tensor weight, + float eps, Tensor out) const override { + // CPU backend supports fp32 only; fp16/bf16 use GPU backends. + if (out.dtype() != DataType::kFloat32 || + input.dtype() != DataType::kFloat32 || + weight.dtype() != DataType::kFloat32) { + abort(); + } + + auto* out_ptr = static_cast(out.data()); + const auto* input_ptr = static_cast(input.data()); + const auto* weight_ptr = static_cast(weight.data()); + + auto stride_input_batch = input_strides_.size() > 1 ? input_strides_[0] : 0; + auto stride_input_nhead = + input_strides_.size() > 1 ? input_strides_[1] : input_strides_[0]; + auto stride_out_batch = out_strides_.size() > 1 ? out_strides_[0] : 0; + auto stride_out_nhead = + out_strides_.size() > 1 ? 
out_strides_[1] : out_strides_[0]; + + for (Tensor::Size bi = 0; bi < batch_size_; ++bi) { + for (Tensor::Size hi = 0; hi < nhead_; ++hi) { + const float* input_row = + input_ptr + bi * stride_input_batch + hi * stride_input_nhead; + float* out_row = + out_ptr + bi * stride_out_batch + hi * stride_out_nhead; + + float ss = 0; + for (Tensor::Size k = 0; k < dim_; ++k) { + float v = input_row[k]; + ss += v * v; + } + float rms = 1.f / std::sqrt(ss / static_cast(dim_) + eps_); + + for (Tensor::Size k = 0; k < dim_; ++k) { + out_row[k] = input_row[k] * weight_ptr[k] * rms; + } + } + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/rms_norm/kernel.cuh b/src/cuda/rms_norm/kernel.cuh new file mode 100644 index 0000000..09f20a8 --- /dev/null +++ b/src/cuda/rms_norm/kernel.cuh @@ -0,0 +1,59 @@ +#ifndef INFINI_OPS_CUDA_RMS_NORM_KERNEL_CUH_ +#define INFINI_OPS_CUDA_RMS_NORM_KERNEL_CUH_ + +#include +#include + +#include +#include +#include + +namespace infini::ops { + +namespace { + +template +__device__ __forceinline__ Compute SumSquared(const Data* data_ptr, + size_t count) { + Compute ss = 0; + for (size_t i = threadIdx.x; i < count; i += block_size) { + ss += Compute(data_ptr[i]) * Compute(data_ptr[i]); + } + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + return BlockReduce(temp_storage).Sum(ss); +} + +} // namespace + +template +__global__ void RmsNormKernel(Data* __restrict__ y, int64_t stride_y_batch, + int64_t stride_y_nhead, + const Data* __restrict__ x, + int64_t stride_x_batch, int64_t stride_x_nhead, + const Weight* __restrict__ w, size_t nhead, + size_t dim, float epsilon) { + size_t batch_idx = blockIdx.x / nhead; + size_t head_idx = blockIdx.x % nhead; + + auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead; + auto x_ptr = x + batch_idx * stride_x_batch + head_idx * stride_x_nhead; + auto w_ptr = w; + + Compute ss = SumSquared(x_ptr, dim); + + __shared__ Compute rms; 
+ if (threadIdx.x == 0) { + rms = Compute(rsqrtf(ss / Compute(dim) + epsilon)); + } + __syncthreads(); + + for (size_t i = threadIdx.x; i < dim; i += block_size) { + y_ptr[i] = Data(Compute(x_ptr[i]) * Compute(w_ptr[i]) * rms); + } +} + +} // namespace infini::ops + +#endif diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h new file mode 100644 index 0000000..dc0b2a7 --- /dev/null +++ b/src/cuda/rms_norm/kernel.h @@ -0,0 +1,63 @@ +#ifndef INFINI_OPS_CUDA_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_CUDA_RMS_NORM_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "base/rms_norm.h" +#include "cuda/rms_norm/kernel.cuh" +#include "data_type.h" +#include "dispatcher.h" + +namespace infini::ops { + +namespace { + +constexpr unsigned int kBlockSize = 256; + +} // namespace + +template +class CudaRmsNorm : public RmsNorm { + public: + using RmsNorm::RmsNorm; + + void operator()(void* stream, const Tensor input, const Tensor weight, + float eps, Tensor out) const override { + auto cuda_stream = + static_cast(stream ? stream : 0); + + auto stride_input_batch = input_strides_.size() > 1 ? input_strides_[0] : 0; + auto stride_input_nhead = + input_strides_.size() > 1 ? input_strides_[1] : input_strides_[0]; + auto stride_out_batch = out_strides_.size() > 1 ? out_strides_[0] : 0; + auto stride_out_nhead = + out_strides_.size() > 1 ? 
out_strides_[1] : out_strides_[0]; + + uint32_t num_blocks = static_cast(batch_size_ * nhead_); + + if (out.dtype() != input.dtype() || out.dtype() != weight.dtype()) { + std::abort(); + } + + DispatchFunc( + out.dtype(), + [&]() { + RmsNormKernel + <<>>( + reinterpret_cast(out.data()), stride_out_batch, + stride_out_nhead, reinterpret_cast(input.data()), + stride_input_batch, stride_input_nhead, + reinterpret_cast(weight.data()), nhead_, dim_, + eps_); + }, + "CudaRmsNorm::operator()"); + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/data_type.h b/src/data_type.h index 567a076..850dda2 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -7,7 +7,10 @@ #ifdef WITH_NVIDIA #include #include -#elif WITH_METAX +#elif defined(WITH_ILUVATAR) +#include +#include +#elif defined(WITH_METAX) #include #include #endif @@ -111,10 +114,10 @@ DEFINE_DATA_TYPE_MAPPING(kInt64, int64_t) DEFINE_DATA_TYPE_MAPPING(kFloat32, float) DEFINE_DATA_TYPE_MAPPING(kFloat64, double) -#ifdef WITH_NVIDIA +#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) DEFINE_DATA_TYPE_MAPPING(kFloat16, half) DEFINE_DATA_TYPE_MAPPING(kBFloat16, __nv_bfloat16) -#elif WITH_METAX +#elif defined(WITH_METAX) DEFINE_DATA_TYPE_MAPPING(kFloat16, __half) DEFINE_DATA_TYPE_MAPPING(kBFloat16, __maca_bfloat16) #else diff --git a/src/iluvatar/rms_norm/kernel.h b/src/iluvatar/rms_norm/kernel.h new file mode 100644 index 0000000..3971c3a --- /dev/null +++ b/src/iluvatar/rms_norm/kernel.h @@ -0,0 +1,31 @@ +#ifndef INFINI_OPS_ILUVATAR_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_ILUVATAR_RMS_NORM_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/rms_norm/kernel.h" + +namespace infini::ops { + +namespace rms_norm { + +struct IluvatarBackend { + using stream_t = cudaStream_t; +}; + +} // namespace rms_norm + +template <> +class Operator + : public CudaRmsNorm { + public: + using CudaRmsNorm::CudaRmsNorm; +}; + +} // namespace infini::ops + +#endif diff --git 
a/src/nvidia/rms_norm/kernel.h b/src/nvidia/rms_norm/kernel.h new file mode 100644 index 0000000..496bddd --- /dev/null +++ b/src/nvidia/rms_norm/kernel.h @@ -0,0 +1,31 @@ +#ifndef INFINI_OPS_NVIDIA_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_NVIDIA_RMS_NORM_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/rms_norm/kernel.h" + +namespace infini::ops { + +namespace rms_norm { + +struct NvidiaBackend { + using stream_t = cudaStream_t; +}; + +} // namespace rms_norm + +template <> +class Operator + : public CudaRmsNorm { + public: + using CudaRmsNorm::CudaRmsNorm; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py new file mode 100644 index 0000000..12ec7ee --- /dev/null +++ b/tests/test_rms_norm.py @@ -0,0 +1,65 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "input_shape, weight_shape, input_strides, weight_strides, out_strides", + ( + ((1, 64), (64,), None, None, None), + ((2, 128), (128,), None, None, None), + ((4, 48, 64), (64,), None, None, None), + ((2, 4, 2048), (2048,), None, None, None), + ((1, 64), (64,), (64, 1), (1,), (64, 1)), + ((4, 48, 64), (64,), (3072, 64, 1), (1,), (3072, 64, 1)), + ), +) +@pytest.mark.parametrize("eps", (1e-6, 1e-5)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-4, 1e-4), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 1e-2, 1e-2), + ), +) +def test_rms_norm( + input_shape, + weight_shape, + input_strides, + weight_strides, + out_strides, + eps, + dtype, + device, + rtol, + atol, +): + if device == "cpu" and dtype in (torch.float16, torch.bfloat16): + pytest.skip("CPU backend does not support fp16/bf16") + + input = randn_strided(input_shape, input_strides, dtype=dtype, device=device) + weight = randn_strided(weight_shape, weight_strides, dtype=dtype, device=device) + out 
= empty_strided(input_shape, out_strides, dtype=dtype, device=device) + + return Payload( + _rms_norm, + _torch_rms_norm, + (input, weight), + {"eps": eps, "out": out}, + rtol=rtol, + atol=atol, + ) + + +def _rms_norm(input, weight, *, eps=1e-6, out=None): + infini.ops.rms_norm(input, weight, eps, out) + + return out + + +def _torch_rms_norm(input, weight, *, eps=1e-6, out=None): + return torch.nn.functional.rms_norm(input, input.shape[-1:], weight=weight, eps=eps) From 59031f7894196224d4b09a5b953c9013f41c4564 Mon Sep 17 00:00:00 2001 From: zhangyue <138768300+zhangyue207@users.noreply.github.com> Date: Wed, 4 Mar 2026 17:05:14 +0800 Subject: [PATCH 70/93] feat(ops): add Iluvatar GPU backend for `Add` (#8) --- src/iluvatar/add/kernel.h | 41 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 src/iluvatar/add/kernel.h diff --git a/src/iluvatar/add/kernel.h b/src/iluvatar/add/kernel.h new file mode 100644 index 0000000..551544f --- /dev/null +++ b/src/iluvatar/add/kernel.h @@ -0,0 +1,41 @@ +#ifndef INFINI_OPS_ILUVATAR_ADD_KERNEL_H_ +#define INFINI_OPS_ILUVATAR_ADD_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/add/kernel.h" + +namespace infini::ops { + +namespace add { + +struct IluvatarBackend { + using stream_t = cudaStream_t; + + static constexpr auto malloc = [](auto&&... 
args) { + return cudaMalloc(std::forward(args)...); + }; + + static constexpr auto memcpy = cudaMemcpy; + + static constexpr auto free = cudaFree; + + static constexpr auto memcpyH2D = cudaMemcpyHostToDevice; +}; + +} // namespace add + +template <> +class Operator + : public CudaAdd { + public: + using CudaAdd::CudaAdd; +}; + +} // namespace infini::ops + +#endif From a6d915b0813fefb954e3e296befbd84cd5ef6613 Mon Sep 17 00:00:00 2001 From: Ziminli <70735843+Ziminli@users.noreply.github.com> Date: Thu, 5 Mar 2026 10:43:57 +0800 Subject: [PATCH 71/93] refactor: adapt dispatcher for full C++17 compatibility and support `pip install` on MetaX (#5) * refactor: adapt the dispatcher to be C++17-compatiable - dispatcher now does not depend on C++20 features - udpate the current dispatcher use cases - add some relevant constexpr traits in common/traits.h - add `PYBIND_ENABLE_EXTRAS` internal cmake variable for controlling the flags introduced by pybind * style: format some comments in common/traits.h * fix: support mxcc to use pytest by using `scripts/mxcc_wrapper.sh` * build: add auto-detection for MetaX * style: change the naming for types and variables in `common/traits.h`, `common/constexpr_map.h` and `dispatcher.h` * style: fix the method and context string naming in `src/add/add.h` * refactor: change the anonymous namespaces in `dispatcher.h` to namespace `detail` to comply with the styling rules * style: fix comment styling issues * fix: update `DispatchFunc` usage in `src/cuda/rms_norm/kernel.h` --------- Co-authored-by: Jiacheng Huang --- CMakeLists.txt | 36 +++- scripts/mxcc_wrapper.sh | 24 +++ src/CMakeLists.txt | 6 +- src/base/add.h | 4 +- src/common/constexpr_map.h | 2 +- src/common/traits.h | 148 +++++++++----- src/cpu/add/add.h | 10 +- src/cuda/add/kernel.h | 3 +- src/cuda/rms_norm/kernel.h | 3 +- src/device.h | 40 +--- src/dispatcher.h | 383 ++++++++++++++++++++++++------------- src/operator.h | 3 +- src/tensor.cc | 5 +- 13 files changed, 440 insertions(+), 
227 deletions(-) create mode 100755 scripts/mxcc_wrapper.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 36aa295..3a41236 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,6 +4,9 @@ project(InfiniOps LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) +# Internal variable to control pybind11's automatic optimization flags (like `-flto`). +set(PYBIND11_ENABLE_EXTRAS ON) + # Options for backends. option(WITH_CPU "Enable CPU backend" OFF) option(WITH_NVIDIA "Enable CUDA backend" OFF) @@ -32,11 +35,26 @@ if(AUTO_DETECT_DEVICES) message(STATUS "Auto-detected Iluvatar environment.") endif() - # TODO: Please test and uncomment/update the auto-detection for MetaX. - # if(DEFINED ENV{MACA_PATH}) - # set(WITH_METAX ON) - # message(STATUS "Auto-detected MetaX environment.") - # endif() + if(DEFINED ENV{MACA_PATH}) + set(WITH_METAX ON) + message(STATUS "Auto-detected MetaX environment from MACA_PATH") + else() + execute_process( + COMMAND sh -c "grep -h 9999 /sys/bus/pci/devices/*/vendor 2>/dev/null" + OUTPUT_VARIABLE _pci_vendor_output + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + + string(FIND "${_pci_vendor_output}" "9999" _found_pos) + + if(_found_pos GREATER -1) + set(WITH_METAX ON) + message(STATUS "Detected MetaX GPU from PCI vendor ID 0x9999") + else() + set(WITH_METAX OFF) + message(STATUS "No MetaX GPU detected") + endif() + endif() endif() include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) @@ -75,8 +93,8 @@ if(WITH_METAX) # Normally can be found at: `/opt/maca/`. 
set(MACA_PATH $ENV{MACA_PATH}) - set(CMAKE_C_COMPILER ${MACA_PATH}/mxgpu_llvm/bin/mxcc) - set(CMAKE_CXX_COMPILER ${MACA_PATH}/mxgpu_llvm/bin/mxcc) + set(CMAKE_C_COMPILER ${CMAKE_CURRENT_SOURCE_DIR}/scripts/mxcc_wrapper.sh) + set(CMAKE_CXX_COMPILER ${CMAKE_CURRENT_SOURCE_DIR}/scripts/mxcc_wrapper.sh) include_directories("${MACA_PATH}/include") link_directories("${MACA_PATH}/lib") @@ -92,6 +110,10 @@ if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX) add_compile_definitions(WITH_CPU=1) endif() +if(WITH_METAX) + set(PYBIND11_ENABLE_EXTRAS OFF) +endif() + add_subdirectory(src) add_subdirectory(examples) diff --git a/scripts/mxcc_wrapper.sh b/scripts/mxcc_wrapper.sh new file mode 100755 index 0000000..0010617 --- /dev/null +++ b/scripts/mxcc_wrapper.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Filter out flags unsupported by `mxcc`. +ARGS=() +skip_next=0 +for arg in "$@"; do + if [ $skip_next -eq 1 ]; then + skip_next=0 + continue + fi + case "$arg" in + -pthread) + ;; + -B) + skip_next=1 + ;; + -B*) + ;; + *) + ARGS+=("$arg") + ;; + esac +done + +exec ${MACA_PATH}/mxgpu_llvm/bin/mxcc "${ARGS[@]}" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 97cc0e3..02fbc4d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -128,7 +128,11 @@ if(GENERATE_PYTHON_BINDINGS) find_package(Python COMPONENTS Interpreter Development) find_package(pybind11 CONFIG) - pybind11_add_module(ops ${PYBIND11_SOURCES}) + if(PYBIND11_ENABLE_EXTRAS) + pybind11_add_module(ops ${PYBIND11_SOURCES}) + else() + pybind11_add_module(ops NO_EXTRAS ${PYBIND11_SOURCES}) + endif() target_include_directories(ops PRIVATE ${PROJECT_SOURCE_DIR}) target_link_libraries(ops PRIVATE infiniops) diff --git a/src/base/add.h b/src/base/add.h index a2da9ef..49ae1ae 100644 --- a/src/base/add.h +++ b/src/base/add.h @@ -25,11 +25,11 @@ class Add : public Operator { is_other_contiguous_{other.IsContiguous()}, is_out_contiguous_{out.IsContiguous()} { assert(!out.HasBroadcastDim() && - "The output of `Add` 
should NOT have broadcasted dim!"); + "the output of `Add` should NOT have broadcasted dim!"); // TODO(lzm): support mix-precision later using the generic elementwise // framework. assert(input_type_ == other_type_ && other_type_ == out_type_ && - "Operator `Add` requires all input and output Tensors to have the " + "operator `Add` requires all input and output Tensors to have the " "same dtype"); } diff --git a/src/common/constexpr_map.h b/src/common/constexpr_map.h index 0d142eb..27fdc67 100644 --- a/src/common/constexpr_map.h +++ b/src/common/constexpr_map.h @@ -18,7 +18,7 @@ struct ConstexprMap { if (pr.first == key) return pr.second; } // TODO(lzm): change to logging. - assert("ConstexprMap's key is not found!"); + assert("the key is not found in the `ConstexprMap`"); // Unreachable, provided to satisfy the compiler's requirement. std::abort(); } diff --git a/src/common/traits.h b/src/common/traits.h index 6f75e9f..c746f4c 100644 --- a/src/common/traits.h +++ b/src/common/traits.h @@ -6,101 +6,163 @@ namespace infini::ops { +// --------------------- List and TypePack --------------------- // A generic container for a sequence of compile-time values. -template +template struct List {}; +// `ListGet(List{})` extracts the `i`th value from a `List` +// tag. +template +constexpr auto ListGetImpl(List) { + if constexpr (index == 0) + return head; + else + return ListGetImpl(List{}); +} + +template +constexpr auto ListGet(List list) { + return ListGetImpl(list); +} + +template +struct TypePack {}; + +// ----------------------------------------------------------------------------- +// Tags +// ----------------------------------------------------------------------------- +// Tags are passed as regular function arguments to user functors instead of +// template parameters. This lets users write plain C++17 `[](auto tag)` lambdas +// rather than C++20 template lambdas (`[]()`). + +// `TypeTag`: carries a C++ type. Recover with `typename +// decltype(tag)::type`. 
+template +struct TypeTag { + using type = T; +}; + +// `ValueTag`: carries a compile-time value. Recover with +// `decltype(tag)::value`. +template +struct ValueTag { + using value_type = decltype(v); + static constexpr auto value = v; +}; + // ----------------------------------------------------------------------------- // List Queries // ----------------------------------------------------------------------------- -// Check at compile-time if a Value exists within a construct (e.g., List<>). -// Example: static_assert(ContainsValue); -template +// Check at compile-time if a value exists within a construct (e.g., `List<>`). +// Example: `static_assert(ContainsValue)`; +template struct Contains; -template -struct Contains, Value> - : std::disjunction...> {}; +template +struct Contains, value> + : std::disjunction...> {}; -template -inline constexpr bool ContainsValue = Contains::value; +template +inline constexpr bool ContainsValue = Contains::value; -// Check at compile-time if a type T is present in a variadic list of types Ts. -// Example: static_assert(IsTypeInList); +// Check at compile-time if a type `T` is present in a variadic list of types +// `Ts`. +// Example: `static_assert(IsTypeInList)`; template inline constexpr bool IsTypeInList = (std::is_same_v || ...); +// Trait to detect whether `T` is a `List<...>` specialization. +template +struct IsListType : std::false_type {}; + +template +struct IsListType> : std::true_type {}; + // ----------------------------------------------------------------------------- // List Operations // ----------------------------------------------------------------------------- -// Concatenates two List types into a single List. -// Example: ConcatType, List<3, 4>> is List<1, 2, 3, 4>. +// Concatenates two List types into a single `List`. +// Example: `ConcatType, List<3, 4>>` is `List<1, 2, 3, 4>`. 
template struct Concat; -template -struct Concat, List> { - using type = List; +template +struct Concat, List> { + using type = List; }; template using ConcatType = typename Concat::type; +template +struct Flatten; + +template +struct Flatten> { + using type = List; +}; + +template +struct Flatten { + using type = typename Flatten, Rest...>::type; +}; + // ----------------------------------------------------------------------------- // Invocability Detection (SFINAE) // ----------------------------------------------------------------------------- -// Checks if a Functor's template operator() can be called with Args. -template +// Checks if a `Functor` can be called with a `ValueTag` and `Args...`. +template struct IsInvocable : std::false_type {}; -template -struct IsInvocable< - Functor, Value, - std::void_t().template operator()( - std::declval()...))>, - Args...> : std::true_type {}; +template +struct IsInvocable()( + ValueTag{}, std::declval()...))>, + Args...> : std::true_type {}; -template +template inline constexpr bool IsInvocableValue = - IsInvocable::value; + IsInvocable::value; // ----------------------------------------------------------------------------- // Filtering Logic // ----------------------------------------------------------------------------- -// Recursive template to filter values based on Functor support at compile-time. +// Recursive template to filter values based on `Functor` support at +// compile-time. template + auto... remaining> struct Filter; // Base case: All values processed. -template -struct Filter, List> { - using type = List; +template +struct Filter, List> { + using type = List; }; -// Recursive step: Test the 'Head' value and accumulate if supported. -template -struct Filter, List, Head, Tail...> { +// Recursive step: Test the `head` value and accumulate if supported. 
+template +struct Filter, List, head, tail...> { using type = typename std::conditional_t< - IsInvocableValue && - !ContainsValue, Head>, - Filter, List, Tail...>, - Filter, List, Tail...>>::type; + IsInvocableValue && + !ContainsValue, head>, + Filter, List, tail...>, + Filter, List, tail...>>::type; }; -// Interface to filter a List type directly. +// Interface to filter a `List` type directly. template struct FilterList; -template -struct FilterList, List> { +template +struct FilterList, List> { using type = - typename Filter, List<>, Items...>::type; + typename Filter, List<>, items...>::type; }; } // namespace infini::ops diff --git a/src/cpu/add/add.h b/src/cpu/add/add.h index d9a456a..171c1ad 100644 --- a/src/cpu/add/add.h +++ b/src/cpu/add/add.h @@ -19,13 +19,17 @@ class Operator : public Add { void operator()(void* stream, const Tensor input, const Tensor other, Tensor out) const override { DispatchFunc>( - out_type_, [&]() { compute(stream, input, other, out); }, - "Operator::operator()"); + out_type_, + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(stream, input, other, out); + }, + "`Operator::operator()`"); } private: template - void compute(void* stream, const Tensor input, const Tensor other, + void Compute(void* stream, const Tensor input, const Tensor other, Tensor out) const { const auto* input_ptr = static_cast(input.data()); const auto* other_ptr = static_cast(other.data()); diff --git a/src/cuda/add/kernel.h b/src/cuda/add/kernel.h index 664bcee..d73a769 100644 --- a/src/cuda/add/kernel.h +++ b/src/cuda/add/kernel.h @@ -94,7 +94,8 @@ class CudaAdd : public Add { Tensor out) const override { DispatchFunc( out_type_, - [&]() { + [&](auto tag) { + using T = typename decltype(tag)::type; // TODO(lzm): currently hard-code block_size to be 256. 
dim3 blockDims( std::min(static_cast(256), output_size_)); diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h index dc0b2a7..6057be1 100644 --- a/src/cuda/rms_norm/kernel.h +++ b/src/cuda/rms_norm/kernel.h @@ -45,7 +45,8 @@ class CudaRmsNorm : public RmsNorm { DispatchFunc( out.dtype(), - [&]() { + [&](auto tag) { + using T = typename decltype(tag)::type; RmsNormKernel <<>>( reinterpret_cast(out.data()), stride_out_batch, diff --git a/src/device.h b/src/device.h index daf31a6..90fae55 100644 --- a/src/device.h +++ b/src/device.h @@ -85,63 +85,43 @@ struct EnabledDeviceFilter { // and FilterList will exclude it from ActiveDevices. #ifdef WITH_CPU - template = 0> - void operator()() const {} + void operator()(ValueTag) const {} #endif #ifdef WITH_NVIDIA - template = 0> - void operator()() const {} + void operator()(ValueTag) const {} #endif #ifdef WITH_CAMBRICON - template = 0> - void operator()() const {} + void operator()(ValueTag) const {} #endif #ifdef WITH_ASCEND - template = 0> - void operator()() const {} + void operator()(ValueTag) const {} #endif #ifdef WITH_METAX - template = 0> - void operator()() const {} + void operator()(ValueTag) const {} #endif #ifdef WITH_MOORE - template = 0> - void operator()() const {} + void operator()(ValueTag) const {} #endif #ifdef WITH_ILUVATAR - template = 0> - void operator()() const {} + void operator()(ValueTag) const {} #endif #ifdef WITH_KUNLUN - template = 0> - void operator()() const {} + void operator()(ValueTag) const {} #endif #ifdef WITH_HYGON - template = 0> - void operator()() const {} + void operator()(ValueTag) const {} #endif #ifdef WITH_QY - template = 0> - void operator()() const {} + void operator()(ValueTag) const {} #endif }; diff --git a/src/dispatcher.h b/src/dispatcher.h index 6b70da5..83b282c 100644 --- a/src/dispatcher.h +++ b/src/dispatcher.h @@ -16,103 +16,174 @@ namespace infini::ops { // Core Generic Runtime Dispatchers // 
----------------------------------------------------------------------------- -// (Single Dispatch) Dispatches a runtime value to a compile-time functor. -template -auto DispatchFunc(ValueType value, Functor&& func, - std::string_view context_str = "", Args&&... args) { - using FilteredPack = - typename Filter, List<>, AllValues...>::type; - - return [&](List) { - using ReturnType = - decltype(std::forward(func) - .template operator()(Head)>( - std::forward(args)...)); - - // Path for Void Functions - if constexpr (std::is_void_v) { - bool handled = - ((value == static_cast(Tail) - ? (std::forward(func).template operator()( - std::forward(args)...), - true) - : false) || - ... || - (value == static_cast(Head) - ? (std::forward(func).template operator()( - std::forward(args)...), - true) - : false)); - - if (!handled) { - std::cerr << "Dispatch error (void): Value " << static_cast(value) - << " not supported in context: " << context_str << "\n"; - std::abort(); - } - } - // Path for Non-Void Functions - else { - std::optional result; - bool handled = - ((value == static_cast(Tail) - ? (result.emplace( - std::forward(func).template operator()( - std::forward(args)...)), - true) - : false) || - ... || - (value == static_cast(Head) - ? (result.emplace( - std::forward(func).template operator()( - std::forward(args)...)), - true) - : false)); - - if (handled) { - return *result; - } +namespace detail { + +// Implements the dispatch body over a resolved `List`. +template +auto DispatchFuncImpl(ValueType value, Functor &&func, + std::string_view context_str, List, + Args &&...args) { + using ReturnType = decltype(std::forward(func)( + ValueTag(head)>{}, std::forward(args)...)); + + // Path for void functions. + if constexpr (std::is_void_v) { + bool handled = ((value == static_cast(tail) + ? (std::forward(func)( + ValueTag{}, std::forward(args)...), + true) + : false) || + ... || + (value == static_cast(head) + ? 
(std::forward(func)( + ValueTag{}, std::forward(args)...), + true) + : false)); + + if (!handled) { // TODO(lzm): change to logging. - std::cerr << "Dispatch error (non-void): Value " - << static_cast(value) - << " not supported in context: " << context_str << "\n"; + std::cerr << "dispatch error (void): value " << static_cast(value) + << " not supported in the context: " << context_str << "\n"; std::abort(); - return ReturnType{}; } - }(FilteredPack{}); + } + // Path for non-void functions. + else { + std::optional result; + bool handled = ((value == static_cast(tail) + ? (result.emplace(std::forward(func)( + ValueTag{}, std::forward(args)...)), + true) + : false) || + ... || + (value == static_cast(head) + ? (result.emplace(std::forward(func)( + ValueTag{}, std::forward(args)...)), + true) + : false)); + + if (handled) { + return *result; + } + // TODO(lzm): change to logging. + std::cerr << "dispatch error (non-void): value " << static_cast(value) + << " not supported in the context: " << context_str << "\n"; + std::abort(); + return ReturnType{}; + } +} + +// Deduces `head`/`tail` from a `List` type via partial specialization, +// then forwards to `DispatchFuncImpl`. +template +struct DispatchFuncUnwrap; + +template +struct DispatchFuncUnwrap, + std::tuple> { + static auto call(ValueType value, Functor &&func, + std::string_view context_str, Args &&...args) { + return DispatchFuncImpl(value, std::forward(func), context_str, + List{}, std::forward(args)...); + } +}; + +// Empty-list specialization +template +struct DispatchFuncUnwrap, std::tuple> { + static auto call(ValueType value, Functor &&, std::string_view context_str, + Args &&...) { + // TODO(lzm): change to logging. + std::cerr << "dispatch error: no allowed values registered for value " + << static_cast(value) + << " in the context: " << context_str << "\n"; + std::abort(); + } +}; + +} // namespace detail + +// (Single Dispatch) Dispatches a runtime value to a compile-time functor. 
+template +auto DispatchFunc(ValueType value, Functor &&func, + std::string_view context_str = "", Args &&...args) { + using FilteredPack = typename Filter, List<>, + all_values...>::type; + + return detail::DispatchFuncUnwrap< + ValueType, Functor, FilteredPack, + std::tuple>::call(value, std::forward(func), + context_str, std::forward(args)...); } // (Multi-Dispatch) Dispatches a vector of runtime values to a compile-time // functor. -// Base Case: All dimensions resolved. -template -auto DispatchFunc(const std::vector& values, size_t index, - Functor&& func, std::string_view context_str, List, - Args&&... args) { - return std::forward(func).template operator()( +// Base Case: All Dimensions Resolved +template +auto DispatchFunc(const std::vector &values, size_t /*index*/, + Functor &&func, std::string_view /*context_str*/, + List, Args &&...args) { + return std::forward(func)(List{}, + std::forward(args)...); +} + +// Forward declaration of the recursive multi-dispatch overload. +template +auto DispatchFunc(const std::vector &values, size_t index, + Functor &&func, std::string_view context_str, List, + Args &&...args); + +// Adapter used in the recursive multi-dispatch case: given a resolved value +// `val` recurse into the next dimension. 
+template +struct MultiDispatchRecurseAdapter; + +template +struct MultiDispatchRecurseAdapter, Functor, items...> { + const std::vector &values; + size_t next_index; + Functor &func; + std::string_view context_str; + + template + auto operator()(ValueTag, Args &&...args) const { + return DispatchFunc(values, next_index, func, context_str, + List{}, + std::forward(args)...); + } +}; + +template +auto MultiDispatchFirstDim(const std::vector &values, size_t index, + Functor &func, std::string_view context_str, + List, List, Args &&...args) { + static_assert(sizeof...(allowed) > 0, + "`DispatchFunc` dimension list is empty"); + using EnumType = std::common_type_t; + + MultiDispatchRecurseAdapter adapter{ + values, index + 1, func, context_str}; + + return DispatchFunc( + static_cast(values.at(index)), adapter, context_str, std::forward(args)...); } // (Multi-Dispatch) Recursive Case template -auto DispatchFunc(const std::vector& values, size_t index, - Functor&& func, std::string_view context_str, List, - Args&&... args) { - return [&](List) { - static_assert(sizeof...(Allowed) > 0, - "`DispatchFunc` dimension list is empty"); - using EnumType = std::common_type_t; - - return DispatchFunc( - static_cast(values.at(index)), - [&](Args&&... inner_args) { - return DispatchFunc( - values, index + 1, std::forward(func), context_str, - List{}, std::forward(inner_args)...); - }, - context_str, std::forward(args)...); - }(FirstList{}); + typename... Args, auto... 
items> +auto DispatchFunc(const std::vector &values, size_t index, + Functor &&func, std::string_view context_str, List, + Args &&...args) { + return MultiDispatchFirstDim>( + values, index, func, context_str, List{}, FirstList{}, + std::forward(args)...); } // ----------------------------------------------------------------------------- @@ -120,75 +191,115 @@ auto DispatchFunc(const std::vector& values, size_t index, // ----------------------------------------------------------------------------- // These provide cleaner and more convenient APIs for common InfiniOps types. -// DataType Dispatch +namespace detail { + +// Bridges the generic value dispatch layer to the `DataType`-specific type +// dispatch layer. +template +struct DataTypeAdapter { + Functor &func; + + template + auto operator()(ValueTag, Args &&...args) const { + using T = TypeMapType(dtype)>; + return func(TypeTag{}, std::forward(args)...); + } +}; + +template +struct DataTypeMultiAdapter { + Functor &func; + + template + auto operator()(List, Args &&...args) const { + return func(TypeTag(dtypes)>>{}..., + std::forward(args)...); + } +}; + +template +struct DeviceAdapter { + Functor &func; + + template + auto operator()(ValueTag, Args &&...args) const { + return func(ValueTag{}, std::forward(args)...); + } +}; + +template +struct DeviceMultiAdapter { + Functor &func; + + template + auto operator()(List, Args &&...args) const { + return func(ValueTag{}..., std::forward(args)...); + } +}; + +} // namespace detail + +// `DataType` Dispatch template -auto DispatchFunc(DataType dtype, Functor&& func, - std::string_view context_str = "", Args&&... args) { - return DispatchFunc( - dtype, - [&](Args&&... inner_args) { - using T = TypeMapType
; - return std::forward(func).template operator()( - std::forward(inner_args)...); - }, - context_str, std::forward(args)...); +auto DispatchFunc(DataType dtype, Functor &&func, + std::string_view context_str = "", Args &&...args) { + detail::DataTypeAdapter> adapter{func}; + return DispatchFunc(dtype, adapter, context_str, + std::forward(args)...); } -// DataType Multi-Dispatch +// `DataType` Multi-Dispatch template -auto DispatchFunc(std::initializer_list dtypes, Functor&& func, - std::string_view context_str = "", Args&&... args) { +auto DispatchFunc(std::initializer_list dtypes, Functor &&func, + std::string_view context_str = "", Args &&...args) { std::vector v; for (auto d : dtypes) v.push_back(static_cast(d)); - return DispatchFunc( - v, 0, - [&func](Args&&... inner_args) { - return std::forward(func).template - operator()...>(std::forward(inner_args)...); - }, - context_str, List<>{}, std::forward(args)...); + detail::DataTypeMultiAdapter> adapter{func}; + return DispatchFunc(v, 0, adapter, context_str, List<>{}, + std::forward(args)...); } -// Device Dispatch -template -auto DispatchFunc(Device::Type device, Functor&& func, - std::string_view context_str = "", Args&&... args) { - return DispatchFunc( - device, - [&](Args&&... inner_args) { - return std::forward(func).template operator()( - std::forward(inner_args)...); - }, - context_str, std::forward(args)...); +// `Device` Dispatch +template +auto DispatchFunc(Device::Type device, Functor &&func, + std::string_view context_str = "", Args &&...args) { + detail::DeviceAdapter> adapter{func}; + return DispatchFunc(allowed_devices)...>( + device, adapter, context_str, std::forward(args)...); } -// Device Multi-Dispatch +// `Device` Multi-Dispatch template -auto DispatchFunc(std::initializer_list devices, Functor&& func, - std::string_view context_str = "", Args&&... 
args) { +auto DispatchFunc(std::initializer_list devices, Functor &&func, + std::string_view context_str = "", Args &&...args) { std::vector v; for (auto d : devices) v.push_back(static_cast(d)); - return DispatchFunc( - v, 0, - [&func](Args&&... inner_args) { - return std::forward(func).template operator()( - std::forward(inner_args)...); - }, - context_str, List<>{}, std::forward(args)...); + detail::DeviceMultiAdapter> adapter{func}; + return DispatchFunc(v, 0, adapter, context_str, List<>{}, + std::forward(args)...); } -// Interface for generic List Aliases, which unpacks a list. +template +auto DispatchFuncListAliasImpl(ValueType value, Functor &&func, + std::string_view context_str, List, + Args &&...args) { + return DispatchFunc>(items)...>( + value, std::forward(func), context_str, + std::forward(args)...); +} + +// Interface for Generic `List` Aliases template -auto DispatchFunc(ValueType value, Functor&& func, - std::string_view context_str = "", Args&&... args) { - return [&](List) { - return DispatchFunc>(Is)...>( - value, std::forward(func), context_str, - std::forward(args)...); - }(ListType{}); + typename... 
Args, + typename = std::enable_if_t::value>> +auto DispatchFunc(ValueType value, Functor &&func, + std::string_view context_str = "", Args &&...args) { + return DispatchFuncListAliasImpl(value, std::forward(func), + context_str, ListType{}, + std::forward(args)...); } } // namespace infini::ops diff --git a/src/operator.h b/src/operator.h index 0bbbc0c..dccc327 100644 --- a/src/operator.h +++ b/src/operator.h @@ -30,7 +30,8 @@ class Operator : public OperatorBase { DispatchFunc( tensor.device().type(), - [&]() { + [&](auto tag) { + constexpr Device::Type dev = decltype(tag)::value; if constexpr (std::is_constructible_v, const Tensor&, Args...>) { op_ptr = std::make_unique>( diff --git a/src/tensor.cc b/src/tensor.cc index fe6905d..b4806a2 100644 --- a/src/tensor.cc +++ b/src/tensor.cc @@ -114,7 +114,10 @@ std::string Tensor::ToStringHelper() const { if (ndim() == 0) { return DispatchFunc>( dtype_, - [&]() { return std::to_string(*static_cast(data_)); }, + [&](auto tag) { + using T = typename decltype(tag)::type; + return std::to_string(*static_cast(data_)); + }, "Tensor::ToStringHelper()"); } From 0256d48b84335cf3f4d32b6431a4b67ec8417d26 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang <45955067+voltjia@users.noreply.github.com> Date: Thu, 5 Mar 2026 20:05:32 +0800 Subject: [PATCH 72/93] refactor: introduce handle and workspace (#13) * refactor: update `Operator::call` to accept `handle`, `stream`, `workspace`, and `workspace_size_in_bytes` * feat: add `workspace_size_in_bytes` virtual method to `OperatorBase` --- examples/gemm/gemm.cc | 2 +- scripts/generate_wrappers.py | 1 - src/base/add.h | 2 +- src/base/gemm.h | 11 ++++---- src/base/rms_norm.h | 8 +++--- src/cpu/add/add.h | 7 +++-- src/cpu/gemm/gemm.h | 7 +++-- src/cpu/rms_norm/rms_norm.h | 4 +-- src/cuda/add/kernel.h | 4 +-- src/cuda/gemm/blas.h | 9 +++---- src/cuda/rms_norm/kernel.h | 6 ++--- src/handle.h | 24 ++++++++++++++--- src/operator.h | 51 +++++++++++++++++++++++++++--------- 13 files changed, 86 
insertions(+), 50 deletions(-) diff --git a/examples/gemm/gemm.cc b/examples/gemm/gemm.cc index c611264..62779e6 100644 --- a/examples/gemm/gemm.cc +++ b/examples/gemm/gemm.cc @@ -70,7 +70,7 @@ int main() { Tensor c_device{c_ptr, c_host.shape(), c_host.dtype(), a_host.device(), c_host.strides()}; - Gemm::call(nullptr, a_device, b_device, c_device); + Gemm::call(a_device, b_device, c_device); DEVICE_MEMCPY(c_vec.data(), c_ptr, c_size, DEVICE_MEMCPY_DEVICE_TO_HOST); DEVICE_FREE(a_ptr); diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 0458a39..edde67c 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -435,7 +435,6 @@ def _get_all_ops(devices): namespace infini::ops {{ PYBIND11_MODULE(ops, m) {{ -{_INDENTATION}m.def("set_stream", [](std::uintptr_t stream) {{ OperatorBase::set_stream(reinterpret_cast(stream)); }}); {textwrap.indent(bind_func_calls, _INDENTATION)} }} diff --git a/src/base/add.h b/src/base/add.h index 49ae1ae..06bfa4c 100644 --- a/src/base/add.h +++ b/src/base/add.h @@ -33,7 +33,7 @@ class Add : public Operator { "same dtype"); } - virtual void operator()(void* stream, const Tensor input, const Tensor other, + virtual void operator()(const Tensor input, const Tensor other, Tensor out) const = 0; protected: diff --git a/src/base/gemm.h b/src/base/gemm.h index 2918312..0bb3502 100644 --- a/src/base/gemm.h +++ b/src/base/gemm.h @@ -39,21 +39,20 @@ class Gemm : public Operator { Gemm(const Tensor a, const Tensor b, Tensor c) : Gemm{a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt, c} {} - virtual void operator()(void* stream, const Tensor a, const Tensor b, + virtual void operator()(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c) const = 0; - virtual void operator()(void* stream, const Tensor a, const Tensor b, - Tensor c) const { - return operator()(stream, a, b, std::nullopt, std::nullopt, std::nullopt, + 
virtual void operator()(const Tensor a, const Tensor b, Tensor c) const { + return operator()(a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt, c); } - virtual void operator()(void* stream, const Tensor a, const Tensor b, + virtual void operator()(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, Tensor c) const { - return operator()(stream, a, b, alpha, beta, std::nullopt, std::nullopt, c); + return operator()(a, b, alpha, beta, std::nullopt, std::nullopt, c); } protected: diff --git a/src/base/rms_norm.h b/src/base/rms_norm.h index db9041c..3b40a1c 100644 --- a/src/base/rms_norm.h +++ b/src/base/rms_norm.h @@ -25,12 +25,12 @@ class RmsNorm : public Operator { RmsNorm(const Tensor input, const Tensor weight, Tensor out) : RmsNorm{input, weight, 1e-6f, out} {} - virtual void operator()(void* stream, const Tensor input, const Tensor weight, - float eps, Tensor out) const = 0; + virtual void operator()(const Tensor input, const Tensor weight, float eps, + Tensor out) const = 0; - virtual void operator()(void* stream, const Tensor input, const Tensor weight, + virtual void operator()(const Tensor input, const Tensor weight, Tensor out) const { - return operator()(stream, input, weight, eps_, out); + return operator()(input, weight, eps_, out); } protected: diff --git a/src/cpu/add/add.h b/src/cpu/add/add.h index 171c1ad..c76e4da 100644 --- a/src/cpu/add/add.h +++ b/src/cpu/add/add.h @@ -16,21 +16,20 @@ class Operator : public Add { // TODO: Check constraints. 
} - void operator()(void* stream, const Tensor input, const Tensor other, + void operator()(const Tensor input, const Tensor other, Tensor out) const override { DispatchFunc>( out_type_, [&](auto tag) { using T = typename decltype(tag)::type; - Compute(stream, input, other, out); + Compute(input, other, out); }, "`Operator::operator()`"); } private: template - void Compute(void* stream, const Tensor input, const Tensor other, - Tensor out) const { + void Compute(const Tensor input, const Tensor other, Tensor out) const { const auto* input_ptr = static_cast(input.data()); const auto* other_ptr = static_cast(other.data()); auto* out_ptr = static_cast(out.data()); diff --git a/src/cpu/gemm/gemm.h b/src/cpu/gemm/gemm.h index e61a174..9fe87b6 100644 --- a/src/cpu/gemm/gemm.h +++ b/src/cpu/gemm/gemm.h @@ -25,10 +25,9 @@ class Operator : public Gemm { std::optional beta, Tensor c) : Operator{a, b, alpha, beta, std::nullopt, std::nullopt, c} {} - void operator()(void* stream, const Tensor a, const Tensor b, - std::optional alpha, std::optional beta, - std::optional trans_a, std::optional trans_b, - Tensor c) const override { + void operator()(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) const override { const auto* A = static_cast(a.data()); const auto* B = static_cast(b.data()); auto* C = static_cast(c.data()); diff --git a/src/cpu/rms_norm/rms_norm.h b/src/cpu/rms_norm/rms_norm.h index 4656ba7..f032993 100644 --- a/src/cpu/rms_norm/rms_norm.h +++ b/src/cpu/rms_norm/rms_norm.h @@ -14,8 +14,8 @@ class Operator : public RmsNorm { public: using RmsNorm::RmsNorm; - void operator()(void* stream, const Tensor input, const Tensor weight, - float eps, Tensor out) const override { + void operator()(const Tensor input, const Tensor weight, float eps, + Tensor out) const override { // CPU backend supports fp32 only; fp16/bf16 use GPU backends. 
if (out.dtype() != DataType::kFloat32 || input.dtype() != DataType::kFloat32 || diff --git a/src/cuda/add/kernel.h b/src/cuda/add/kernel.h index d73a769..dcbf5f6 100644 --- a/src/cuda/add/kernel.h +++ b/src/cuda/add/kernel.h @@ -90,7 +90,7 @@ class CudaAdd : public Add { Backend::free(d_out_strides_); } - void operator()(void* stream, const Tensor input, const Tensor other, + void operator()(const Tensor input, const Tensor other, Tensor out) const override { DispatchFunc( out_type_, @@ -108,7 +108,7 @@ class CudaAdd : public Add { for (size_t i = 0; i < output_size_; i += step) { AddKernel<<(stream)>>>( + static_cast(stream_)>>>( d_out, d_input, d_other, d_out_shape_, d_input_shape_, d_other_shape_, d_out_strides_, d_input_strides_, d_other_strides_, output_size_, ndim_, i, is_out_contiguous_, diff --git a/src/cuda/gemm/blas.h b/src/cuda/gemm/blas.h index 47030cf..1a8f7a4 100644 --- a/src/cuda/gemm/blas.h +++ b/src/cuda/gemm/blas.h @@ -30,12 +30,11 @@ class Blas : public Gemm { Blas(const Tensor a, const Tensor b, Tensor c) : Blas{a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt, c} {} - void operator()(void* stream, const Tensor a, const Tensor b, - std::optional alpha, std::optional beta, - std::optional trans_a, std::optional trans_b, - Tensor c) const override { + void operator()(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) const override { Backend::blasSetStream(handle_, - static_cast(stream)); + static_cast(stream_)); const auto& alpha_value{alpha.value_or(alpha_)}; const auto& beta_value{beta.value_or(beta_)}; diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h index 6057be1..a2e27f2 100644 --- a/src/cuda/rms_norm/kernel.h +++ b/src/cuda/rms_norm/kernel.h @@ -25,10 +25,10 @@ class CudaRmsNorm : public RmsNorm { public: using RmsNorm::RmsNorm; - void operator()(void* stream, const Tensor input, const Tensor weight, - float eps, Tensor out) 
const override { + void operator()(const Tensor input, const Tensor weight, float eps, + Tensor out) const override { auto cuda_stream = - static_cast(stream ? stream : 0); + static_cast(stream_ ? stream_ : 0); auto stride_input_batch = input_strides_.size() > 1 ? input_strides_[0] : 0; auto stride_input_nhead = diff --git a/src/handle.h b/src/handle.h index 37a91b6..4deeb83 100644 --- a/src/handle.h +++ b/src/handle.h @@ -1,18 +1,34 @@ #ifndef INFINI_OPS_HANDLE_H_ #define INFINI_OPS_HANDLE_H_ -#include "device.h" +#include namespace infini::ops { class Handle { public: - Handle(Device device) : device_{device} {} + void* stream() const { return stream_; } - const Device& device() const { return device_; } + void* workspace() const { return workspace_; } + + std::size_t workspace_size_in_bytes() const { + return workspace_size_in_bytes_; + } + + void set_stream(void* stream) { stream_ = stream; } + + void set_workspace(void* workspace) { workspace_ = workspace; } + + void set_workspace_size_in_bytes(std::size_t workspace_size_in_bytes) { + workspace_size_in_bytes_ = workspace_size_in_bytes; + } private: - Device device_; + void* stream_{nullptr}; + + void* workspace_{nullptr}; + + std::size_t workspace_size_in_bytes_{0}; }; } // namespace infini::ops diff --git a/src/operator.h b/src/operator.h index dccc327..f40b976 100644 --- a/src/operator.h +++ b/src/operator.h @@ -7,6 +7,7 @@ #include #include "dispatcher.h" +#include "handle.h" #include "tensor.h" namespace infini::ops { @@ -15,10 +16,26 @@ class OperatorBase { public: virtual ~OperatorBase() = default; - static void set_stream(void* stream) { stream_ = stream; } + virtual std::size_t workspace_size_in_bytes() const { return 0; } + + void set_handle(const Handle& handle) { handle_ = handle; } + + void set_stream(void* stream) { stream_ = stream; } + + void set_workspace(void* workspace) { workspace_ = workspace; } + + void set_workspace_size_in_bytes(std::size_t workspace_size_in_bytes) { + 
workspace_size_in_bytes_ = workspace_size_in_bytes; + } protected: - inline static thread_local void* stream_{nullptr}; + Handle handle_; + + void* stream_{nullptr}; + + void* workspace_{nullptr}; + + std::size_t workspace_size_in_bytes_{0}; }; template @@ -46,7 +63,8 @@ class Operator : public OperatorBase { } template - static auto call(void* stream, Args&&... args) { + static auto call(const Handle& handle, void* stream, void* workspace, + std::size_t workspace_size_in_bytes, Args&&... args) { static std::unordered_map> cache; std::size_t hash{0}; @@ -59,23 +77,30 @@ class Operator : public OperatorBase { it = cache.emplace(hash, make(std::forward(args)...)).first; } - return (*it->second)(stream, std::forward(args)...); - } + auto& op{it->second}; - template - static auto call(const Tensor tensor, Args&&... args) { - return call(stream_, tensor, std::forward(args)...); + auto resolved_stream{stream ? stream : handle.stream()}; + auto resolved_workspace{workspace ? workspace : handle.workspace()}; + auto resolved_workspace_size{workspace_size_in_bytes + ? workspace_size_in_bytes + : handle.workspace_size_in_bytes()}; + + op->set_handle(handle); + op->set_stream(resolved_stream); + op->set_workspace(resolved_workspace); + op->set_workspace_size_in_bytes(resolved_workspace_size); + + return (*op)(std::forward(args)...); } template - auto operator()(void* stream, Args&&... args) const { - return (*static_cast(this))(stream, - std::forward(args)...); + static auto call(const Tensor tensor, Args&&... args) { + return call({}, nullptr, nullptr, 0, tensor, std::forward(args)...); } template - auto operator()(const Tensor tensor, Args&&... args) const { - return operator()(stream_, tensor, std::forward(args)...); + auto operator()(Args&&... 
args) const { + return (*static_cast(this))(std::forward(args)...); } }; From 24cc11a69b8baf54a703ccb2cc94ebc0d5f2d3c6 Mon Sep 17 00:00:00 2001 From: zhangyunze <93699316+bitzyz@users.noreply.github.com> Date: Thu, 5 Mar 2026 20:54:29 +0800 Subject: [PATCH 73/93] feat: add the implementation of `Gemm` operator on Cambricon (#7) * feat: add the implementation of operator on Cambricon * chore: format `src/cambricon/gemm/cnblas.h` with `clang-format` * refactor: update `src/cambricon/gemm/cnblas.h` to use latest `operator()` mechanism * refactor: update `src/cambricon/gemm/cnblas.h` to use `workspace_` from `OperatorBase` * chore: resolve PR comments * chore: reverse tensor descriptor destruction order --------- Co-authored-by: Jiacheng Huang --- CMakeLists.txt | 23 +++++- examples/gemm/gemm.cc | 3 + examples/runtime_api.h | 9 ++ src/CMakeLists.txt | 9 ++ src/cambricon/common.h | 25 ++++++ src/cambricon/gemm/cnblas.h | 159 ++++++++++++++++++++++++++++++++++++ tests/test_gemm.py | 4 + 7 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 src/cambricon/common.h create mode 100644 src/cambricon/gemm/cnblas.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 3a41236..a312238 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,6 +12,7 @@ option(WITH_CPU "Enable CPU backend" OFF) option(WITH_NVIDIA "Enable CUDA backend" OFF) option(WITH_ILUVATAR "Enable Iluvatar GPU backend" OFF) option(WITH_METAX "Enable MetaX backend" OFF) +option(WITH_CAMBRICON "Enable Cambricon backend" OFF) option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF) option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF) @@ -55,6 +56,11 @@ if(AUTO_DETECT_DEVICES) message(STATUS "No MetaX GPU detected") endif() endif() + + if(DEFINED ENV{NEUWARE_HOME}) + set(WITH_CAMBRICON ON) + message(STATUS "Auto-detected Cambricon environment.") + endif() endif() include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) @@ -105,7 +111,22 @@ if(WITH_METAX) 
find_library(MACA_BLAS_LIB NAMES mcblas HINTS "${MACA_PATH}/lib" REQUIRED) endif() -# If no GPU platform is enabled, CPU is enabled by default. +if(WITH_CAMBRICON) + add_compile_definitions(WITH_CAMBRICON=1) + set(NEUWARE_HOME $ENV{NEUWARE_HOME}) + + include_directories("${NEUWARE_HOME}/include") + link_directories("${NEUWARE_HOME}/lib") + link_directories("${NEUWARE_HOME}/lib64") + + # Libraries: `cnrt` / `cnnl` / `cnnl_extra` / `cnpapi`. + find_library(CAMBRICON_RUNTIME_LIB NAMES cnrt HINTS "${NEUWARE_HOME}/lib64" REQUIRED) + find_library(CAMBRICON_CNNL_LIB NAMES cnnl HINTS "${NEUWARE_HOME}/lib64" REQUIRED) + find_library(CAMBRICON_CNNL_EXTRA_LIB NAMES cnnl_extra HINTS "${NEUWARE_HOME}/lib64" REQUIRED) + find_library(CAMBRICON_PAPI_LIB NAMES cnpapi HINTS "${NEUWARE_HOME}/lib64" REQUIRED) +endif() + +# If all other platforms are not enabled, CPU is enabled by default. if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX) add_compile_definitions(WITH_CPU=1) endif() diff --git a/examples/gemm/gemm.cc b/examples/gemm/gemm.cc index 62779e6..bb82890 100644 --- a/examples/gemm/gemm.cc +++ b/examples/gemm/gemm.cc @@ -14,6 +14,9 @@ #if WITH_METAX #include "metax/gemm/mcblas.h" #endif +#if WITH_CAMBRICON +#include "cambricon/gemm/cnblas.h" +#endif #include "runtime_api.h" #include "tensor.h" diff --git a/examples/runtime_api.h b/examples/runtime_api.h index 896af64..b56a8fd 100644 --- a/examples/runtime_api.h +++ b/examples/runtime_api.h @@ -28,6 +28,15 @@ #define DEVICE_MEMCPY_HOST_TO_DEVICE mcMemcpyHostToDevice #define DEVICE_MEMCPY_DEVICE_TO_HOST mcMemcpyDeviceToHost #define DEFAULT_DEVICE_TYPE Device::Type::kMetax +#elif WITH_CAMBRICON +#include +#define DEVICE_MALLOC cnrtMalloc +#define DEVICE_FREE cnrtFree +#define DEVICE_MEMCPY cnrtMemcpy +#define DEVICE_MEMSET cnrtMemset +#define DEVICE_MEMCPY_HOST_TO_DEVICE cnrtMemcpyHostToDev +#define DEVICE_MEMCPY_DEVICE_TO_HOST cnrtMemcpyDevToHost +#define DEFAULT_DEVICE_TYPE Device::Type::kCambricon #elif WITH_CPU 
#include #include diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 02fbc4d..6eef5d3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -103,6 +103,15 @@ if(WITH_METAX) list(APPEND DEVICE_LIST "metax") endif() +if(WITH_CAMBRICON) + target_compile_definitions(infiniops PUBLIC WITH_CAMBRICON=1) + + target_include_directories(infiniops PUBLIC "${NEUWARE_HOME}/include") + target_link_libraries(infiniops PUBLIC ${CAMBRICON_RUNTIME_LIB} ${CAMBRICON_CNNL_LIB} ${CAMBRICON_CNNL_EXTRA_LIB} ${CAMBRICON_PAPI_LIB}) + + list(APPEND DEVICE_LIST "cambricon") +endif() + target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) if(GENERATE_PYTHON_BINDINGS) diff --git a/src/cambricon/common.h b/src/cambricon/common.h new file mode 100644 index 0000000..50775c2 --- /dev/null +++ b/src/cambricon/common.h @@ -0,0 +1,25 @@ +#ifndef INFINI_OPS_CAMBRICON_COMMON_H_ +#define INFINI_OPS_CAMBRICON_COMMON_H_ + +#include + +#include "data_type.h" + +namespace infini::ops::cnnl_utils { + +inline cnnlDataType_t GetDataType(DataType dtype) { + switch (dtype) { + case DataType::kInt32: + return CNNL_DTYPE_INT32; + case DataType::kFloat16: + return CNNL_DTYPE_HALF; + case DataType::kFloat32: + return CNNL_DTYPE_FLOAT; + default: + return CNNL_DTYPE_INVALID; + } +} + +} // namespace infini::ops::cnnl_utils + +#endif diff --git a/src/cambricon/gemm/cnblas.h b/src/cambricon/gemm/cnblas.h new file mode 100644 index 0000000..ac95bd5 --- /dev/null +++ b/src/cambricon/gemm/cnblas.h @@ -0,0 +1,159 @@ +#ifndef INFINI_OPS_CAMBRICON_GEMM_CNBLAS_H_ +#define INFINI_OPS_CAMBRICON_GEMM_CNBLAS_H_ + +#include +#include +#include + +// clang-format off +#include +#include +// clang-format on + +#include "base/gemm.h" +#include "cambricon/common.h" + +namespace infini::ops { + +template <> +class Operator : public Gemm { + public: + Operator(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) + : Gemm{a, b, 
alpha, beta, trans_a, trans_b, c}, + a_rows_{a.size(-2)}, + a_cols_{a.size(-1)}, + b_rows_{b.size(-2)}, + b_cols_{b.size(-1)}, + c_rows_{c.size(-2)}, + c_cols_{c.size(-1)} { + assert(!trans_a_ && "`trans_a` is not currently supported"); + assert(!trans_b_ && "`trans_b` is not currently supported"); + + cnnlCreate(&cnnl_handle_); + + cnnlCreateTensorDescriptor(&desc_a_); + cnnlCreateTensorDescriptor(&desc_b_); + cnnlCreateTensorDescriptor(&desc_c_); + + cnnlCreateMatMulDescriptor(&matmul_desc_); + cnnlCreateMatMulAlgo(&matmul_algo_); + cnnlCreateMatMulHeuristicResult(&heuristic_result_); + + int32_t use_stride = 1; + cnnlSetMatMulDescAttr(matmul_desc_, CNNL_MATMUL_USE_STRIDE, &use_stride, + sizeof(int32_t)); + + SetupTensorDescriptor(desc_a_, a_strides_, a_type_, a_rows_, a_cols_, + batch_count_, batch_stride_a_); + SetupTensorDescriptor(desc_b_, b_strides_, b_type_, b_rows_, b_cols_, + batch_count_, batch_stride_b_); + SetupTensorDescriptor(desc_c_, c_strides_, c_type_, c_rows_, c_cols_, + batch_count_, batch_stride_c_); + int count = 0; + cnnlGetBatchMatMulExAlgoHeuristic(cnnl_handle_, matmul_desc_, desc_a_, + desc_b_, desc_c_, NULL, 1, + &heuristic_result_, &count); + + cnrtMalloc(&default_workspace_, workspace_size_in_bytes()); + } + + Operator(const Tensor a, const Tensor b, Tensor c) + : Operator{a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + c} {} + + Operator(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, Tensor c) + : Operator{a, b, alpha, beta, std::nullopt, std::nullopt, c} {} + + ~Operator() { + cnrtFree(default_workspace_); + cnnlDestroyTensorDescriptor(desc_c_); + cnnlDestroyTensorDescriptor(desc_b_); + cnnlDestroyTensorDescriptor(desc_a_); + cnnlDestroyMatMulDescriptor(matmul_desc_); + cnnlDestroyMatMulAlgo(matmul_algo_); + cnnlDestroyMatMulHeuristicResult(heuristic_result_); + cnnlDestroy(cnnl_handle_); + } + + void operator()(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, 
std::optional trans_a, + std::optional trans_b, Tensor c) const override { + const auto& alpha_value{alpha.value_or(alpha_)}; + const auto& beta_value{beta.value_or(beta_)}; + + cnnlSetQueue(cnnl_handle_, (cnrtQueue_t)stream_); + + auto workspace{workspace_ ? workspace_ : default_workspace_}; + auto workspace_size{workspace_size_in_bytes_ ? workspace_size_in_bytes_ + : workspace_size_in_bytes()}; + + cnnlBatchMatMulEx(cnnl_handle_, matmul_desc_, matmul_algo_, &alpha_value, + desc_a_, a.data(), desc_b_, b.data(), &beta_value, + desc_c_, c.data(), workspace, workspace_size); + } + + std::size_t workspace_size_in_bytes() const override { + std::size_t size{0}; + + cnnlGetBatchMatMulExHeuristicResult(heuristic_result_, matmul_algo_, &size); + + return size; + } + + private: + void SetupTensorDescriptor(cnnlTensorDescriptor_t desc, + const Tensor::Strides& strides, DataType dtype, + Tensor::Size rows, Tensor::Size cols, + Tensor::Size batch, Tensor::Stride batch_stride) { + cnnlDataType_t cnnl_dtype = cnnl_utils::GetDataType(dtype); + + if (batch > 1) { + std::vector dims = {static_cast(batch), static_cast(rows), + static_cast(cols)}; + std::vector strides_arr = { + static_cast(batch_stride), + static_cast(strides[strides.size() - 2]), + static_cast(strides[strides.size() - 1])}; + cnnlSetTensorDescriptorEx(desc, CNNL_LAYOUT_ARRAY, cnnl_dtype, + dims.size(), dims.data(), strides_arr.data()); + } else { + std::vector dims = {static_cast(rows), static_cast(cols)}; + std::vector strides_arr = { + static_cast(strides[strides.size() - 2]), + static_cast(strides[strides.size() - 1])}; + cnnlSetTensorDescriptorEx(desc, CNNL_LAYOUT_ARRAY, cnnl_dtype, + dims.size(), dims.data(), strides_arr.data()); + } + } + + cnnlHandle_t cnnl_handle_; + + cnnlTensorDescriptor_t desc_a_; + + cnnlTensorDescriptor_t desc_b_; + + cnnlTensorDescriptor_t desc_c_; + + cnnlMatMulDescriptor_t matmul_desc_; + + cnnlMatMulAlgo_t matmul_algo_; + + cnnlMatMulHeuristicResult_t heuristic_result_; + + 
Tensor::Size a_rows_, a_cols_; + + Tensor::Size b_rows_, b_cols_; + + Tensor::Size c_rows_, c_cols_; + + // TODO: Remove the following member after default workspace mechanism has + // been introduced globally. + void* default_workspace_{nullptr}; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_gemm.py b/tests/test_gemm.py index c091136..faee9d5 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -38,6 +38,10 @@ def test_gemm( rtol, atol, ): + # Skip transposing test cases for MLU platform as transposing is not currently supported. + if device == "mlu" and (trans_a or trans_b): + pytest.skip("transposing is not currently supported on MLU") + a = randn_strided(a_shape, a_strides, dtype=dtype, device=device) b = randn_strided(b_shape, b_strides, dtype=dtype, device=device) From ea78d153957f8c47651ca9afc12c4fefea95500f Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 6 Mar 2026 06:12:00 +0800 Subject: [PATCH 74/93] fix: include `"tensor.h"` instead of `"data_type.h"` and `"device.h"` in `src/pybind11_utils.h` --- src/pybind11_utils.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h index 674b022..8f48bf2 100644 --- a/src/pybind11_utils.h +++ b/src/pybind11_utils.h @@ -4,8 +4,7 @@ #include #include -#include "data_type.h" -#include "device.h" +#include "tensor.h" namespace py = pybind11; From a671e3a7012adca0c2b92371ac58d50390ead14c Mon Sep 17 00:00:00 2001 From: zhangyue <138768300+zhangyue207@users.noreply.github.com> Date: Fri, 6 Mar 2026 13:08:00 +0800 Subject: [PATCH 75/93] feat(ops): implement `CausalSoftmax` operator with CPU and CUDA backends (#12) * feat(ops): implement CausalSoftmax operator with CPU and CUDA backends * refactor(ops): update CausalSoftmax constructor and method signatures for consistency * style: improve assertion messages in CausalSoftmax for clarity and consistency * chore: format files with `clang-format` * test: disable testing skipping for 
unpresent `infini.ops.causal_softmax` * refactor: update `causal_softmax` to use latest `operator()` mechanism --------- Co-authored-by: Jiacheng Huang --- src/base/causal_softmax.h | 52 +++++++++++ src/cpu/causal_softmax/causal_softmax.h | 71 +++++++++++++++ src/cuda/causal_softmax/kernel.cuh | 116 ++++++++++++++++++++++++ src/cuda/causal_softmax/kernel.h | 61 +++++++++++++ src/iluvatar/causal_softmax/kernel.h | 31 +++++++ src/nvidia/causal_softmax/kernel.h | 31 +++++++ tests/test_causal_softmax.py | 57 ++++++++++++ 7 files changed, 419 insertions(+) create mode 100644 src/base/causal_softmax.h create mode 100644 src/cpu/causal_softmax/causal_softmax.h create mode 100644 src/cuda/causal_softmax/kernel.cuh create mode 100644 src/cuda/causal_softmax/kernel.h create mode 100644 src/iluvatar/causal_softmax/kernel.h create mode 100644 src/nvidia/causal_softmax/kernel.h create mode 100644 tests/test_causal_softmax.py diff --git a/src/base/causal_softmax.h b/src/base/causal_softmax.h new file mode 100644 index 0000000..b8393d8 --- /dev/null +++ b/src/base/causal_softmax.h @@ -0,0 +1,52 @@ +#ifndef INFINI_OPS_BASE_CAUSAL_SOFTMAX_H_ +#define INFINI_OPS_BASE_CAUSAL_SOFTMAX_H_ + +#include +#include + +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +class CausalSoftmax : public Operator { + public: + CausalSoftmax(const Tensor input, Tensor out) + : dtype_{input.dtype()}, + ndim_{out.ndim()}, + batch_size_{ndim_ == 2 ? 
1 : out.size(-3)}, + seq_len_{out.size(-2)}, + total_seq_len_{out.size(-1)}, + input_strides_{input.strides()}, + out_strides_{out.strides()} { + assert(input.shape() == out.shape() && + "`CausalSoftmax` requires `input` and `out` same shape"); + assert(input.dtype() == out.dtype() && + "`CausalSoftmax` requires `input` and `out` same dtype"); + assert((ndim_ == 2 || ndim_ == 3) && + "`CausalSoftmax` requires 2D or 3D tensor"); + assert(seq_len_ <= total_seq_len_ && + "`CausalSoftmax` requires shape[-2] <= shape[-1]"); + } + + virtual void operator()(const Tensor input, Tensor out) const = 0; + + protected: + const DataType dtype_; + + Tensor::Size ndim_{0}; + + Tensor::Size batch_size_{0}; + + Tensor::Size seq_len_{0}; + + Tensor::Size total_seq_len_{0}; + + Tensor::Strides input_strides_; + + Tensor::Strides out_strides_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/causal_softmax/causal_softmax.h b/src/cpu/causal_softmax/causal_softmax.h new file mode 100644 index 0000000..0005c4f --- /dev/null +++ b/src/cpu/causal_softmax/causal_softmax.h @@ -0,0 +1,71 @@ +#ifndef INFINI_OPS_CPU_CAUSAL_SOFTMAX_H_ +#define INFINI_OPS_CPU_CAUSAL_SOFTMAX_H_ + +#include + +#include "base/causal_softmax.h" +#include "data_type.h" +#include "tensor.h" + +namespace infini::ops { + +template <> +class Operator : public CausalSoftmax { + public: + Operator(const Tensor input, Tensor out) : CausalSoftmax{input, out} {} + + void operator()(const Tensor input, Tensor out) const override { + if (out.dtype() != DataType::kFloat32 || + input.dtype() != DataType::kFloat32) { + std::abort(); + } + + auto* out_ptr = static_cast(out.data()); + const auto* input_ptr = static_cast(input.data()); + + auto out_stride_b = ndim_ == 3 ? out_strides_[0] : 0; + auto out_stride_i = out_strides_[ndim_ - 2]; + auto out_stride_j = out_strides_[ndim_ - 1]; + auto input_stride_b = ndim_ == 3 ? 
input_strides_[0] : 0; + auto input_stride_i = input_strides_[ndim_ - 2]; + auto input_stride_j = input_strides_[ndim_ - 1]; + + for (Tensor::Size bi = 0; bi < batch_size_; ++bi) { + for (Tensor::Size i = 0; i < seq_len_; ++i) { + ptrdiff_t out_offset = bi * out_stride_b + i * out_stride_i; + ptrdiff_t input_offset = bi * input_stride_b + i * input_stride_i; + float* out_row = out_ptr + out_offset; + const float* input_row = input_ptr + input_offset; + + Tensor::Size valid_len = total_seq_len_ - seq_len_ + i + 1; + + for (Tensor::Size j = valid_len; j < total_seq_len_; ++j) { + out_row[j * out_stride_j] = 0.0f; + } + + float max_val = input_row[0]; + for (Tensor::Size j = 1; j < valid_len; ++j) { + float v = input_row[j * input_stride_j]; + if (v > max_val) { + max_val = v; + } + } + + float sum = 0.0f; + for (Tensor::Size j = 0; j < valid_len; ++j) { + float v = std::exp(input_row[j * input_stride_j] - max_val); + out_row[j * out_stride_j] = v; + sum += v; + } + + for (Tensor::Size j = 0; j < valid_len; ++j) { + out_row[j * out_stride_j] /= sum; + } + } + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/causal_softmax/kernel.cuh b/src/cuda/causal_softmax/kernel.cuh new file mode 100644 index 0000000..d195237 --- /dev/null +++ b/src/cuda/causal_softmax/kernel.cuh @@ -0,0 +1,116 @@ +#ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_CUH_ +#define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_CUH_ + +#include +#include + +#include +#include +#include + +namespace infini::ops { + +namespace { + +template +__device__ __forceinline__ Data ExpAndCast(Compute x) { + Compute e = std::exp(x); + if constexpr (std::is_same_v) { + return __float2half(static_cast(e)); + } else if constexpr (std::is_same_v) { + return __float2bfloat16(static_cast(e)); + } else { + return static_cast(e); + } +} + +struct BlockMaxOp { + template + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + return (a > b) ? 
a : b; + } +}; + +template +__device__ __forceinline__ Data BlockMax(const Data* data_ptr, size_t count) { + Data thread_max = count > 0 ? data_ptr[0] : Data{}; + for (size_t i = threadIdx.x; i < count; i += block_size) { + Data v = data_ptr[i]; + thread_max = (v > thread_max) ? v : thread_max; + } + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + return BlockReduce(temp_storage).Reduce(thread_max, BlockMaxOp()); +} + +template +__device__ __forceinline__ Compute BlockSum(const Data* data_ptr, + size_t count) { + Compute thread_sum = 0; + for (size_t i = threadIdx.x; i < count; i += block_size) { + thread_sum += Compute(data_ptr[i]); + } + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + return BlockReduce(temp_storage).Sum(thread_sum); +} + +} // namespace + +template +__global__ void CausalSoftmaxKernel( + Data* __restrict__ out_ptr, const Data* __restrict__ input_ptr, + size_t batch_size, size_t seq_len, size_t total_seq_len, + int64_t stride_out_batch, int64_t stride_out_row, + int64_t stride_input_batch, int64_t stride_input_row) { + size_t row_idx = blockIdx.x; + size_t batch_idx = blockIdx.y; + + Data* out_row = + out_ptr + batch_idx * stride_out_batch + row_idx * stride_out_row; + const Data* input_row = + input_ptr + batch_idx * stride_input_batch + row_idx * stride_input_row; + + size_t valid_len = total_seq_len - seq_len + row_idx + 1; + + __shared__ Data max_val; + Data block_max = BlockMax(input_row, valid_len); + if (threadIdx.x == 0) { + max_val = block_max; + } + __syncthreads(); + + for (size_t col = threadIdx.x; col < total_seq_len; col += block_size) { + if (col < valid_len) { + Compute diff = + static_cast(input_row[col]) - static_cast(max_val); + out_row[col] = ExpAndCast(diff); + } else { + out_row[col] = Data(0); + } + } + __syncthreads(); + + __shared__ Compute sum_val; + Compute block_sum = + BlockSum(out_row, total_seq_len); + if 
(threadIdx.x == 0) { + sum_val = block_sum; + } + __syncthreads(); + + for (size_t col = threadIdx.x; col < total_seq_len; col += block_size) { + Compute quot = static_cast(out_row[col]) / sum_val; + if constexpr (std::is_same_v) { + out_row[col] = __float2half(static_cast(quot)); + } else if constexpr (std::is_same_v) { + out_row[col] = __float2bfloat16(static_cast(quot)); + } else { + out_row[col] = static_cast(quot); + } + } +} + +} // namespace infini::ops + +#endif diff --git a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h new file mode 100644 index 0000000..610b042 --- /dev/null +++ b/src/cuda/causal_softmax/kernel.h @@ -0,0 +1,61 @@ +#ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_ +#define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "base/causal_softmax.h" +#include "cuda/causal_softmax/kernel.cuh" +#include "data_type.h" +#include "dispatcher.h" + +namespace infini::ops { + +namespace causal_softmax { + +constexpr unsigned int kBlockSize = 256; + +} // namespace causal_softmax + +template +class CudaCausalSoftmax : public CausalSoftmax { + public: + using CausalSoftmax::CausalSoftmax; + + void operator()(const Tensor input, Tensor out) const override { + auto cuda_stream = + static_cast(stream_ ? stream_ : 0); + + auto stride_input_batch = ndim_ == 3 ? input_strides_[0] : 0; + auto stride_input_row = input_strides_[ndim_ - 2]; + auto stride_out_batch = ndim_ == 3 ? 
out_strides_[0] : 0; + auto stride_out_row = out_strides_[ndim_ - 2]; + + dim3 grid(static_cast(seq_len_), + static_cast(batch_size_)); + + if (out.dtype() != input.dtype()) { + std::abort(); + } + + DispatchFunc( + out.dtype(), + [&](auto tag) { + using T = typename decltype(tag)::type; + CausalSoftmaxKernel + <<>>( + reinterpret_cast(out.data()), + reinterpret_cast(input.data()), batch_size_, + seq_len_, total_seq_len_, stride_out_batch, stride_out_row, + stride_input_batch, stride_input_row); + }, + "CudaCausalSoftmax::operator()"); + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/iluvatar/causal_softmax/kernel.h b/src/iluvatar/causal_softmax/kernel.h new file mode 100644 index 0000000..d216815 --- /dev/null +++ b/src/iluvatar/causal_softmax/kernel.h @@ -0,0 +1,31 @@ +#ifndef INFINI_OPS_ILUVATAR_CAUSAL_SOFTMAX_KERNEL_H_ +#define INFINI_OPS_ILUVATAR_CAUSAL_SOFTMAX_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/causal_softmax/kernel.h" + +namespace infini::ops { + +namespace causal_softmax { + +struct IluvatarBackend { + using stream_t = cudaStream_t; +}; + +} // namespace causal_softmax + +template <> +class Operator + : public CudaCausalSoftmax { + public: + using CudaCausalSoftmax::CudaCausalSoftmax; +}; + +} // namespace infini::ops + +#endif diff --git a/src/nvidia/causal_softmax/kernel.h b/src/nvidia/causal_softmax/kernel.h new file mode 100644 index 0000000..5be316a --- /dev/null +++ b/src/nvidia/causal_softmax/kernel.h @@ -0,0 +1,31 @@ +#ifndef INFINI_OPS_NVIDIA_CAUSAL_SOFTMAX_KERNEL_H_ +#define INFINI_OPS_NVIDIA_CAUSAL_SOFTMAX_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/causal_softmax/kernel.h" + +namespace infini::ops { + +namespace causal_softmax { + +struct NvidiaBackend { + using stream_t = cudaStream_t; +}; + +} // namespace causal_softmax + +template <> +class Operator + : public CudaCausalSoftmax { + public: + using 
CudaCausalSoftmax::CudaCausalSoftmax; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_causal_softmax.py b/tests/test_causal_softmax.py new file mode 100644 index 0000000..81a64b4 --- /dev/null +++ b/tests/test_causal_softmax.py @@ -0,0 +1,57 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, out_strides", + ( + ((3, 3), None, None), + ((3, 5), None, None), + ((32, 512), None, None), + ((32, 512), (1024, 1), (1024, 1)), + ((4, 20, 512), None, None), + ((4, 20, 512), (20480, 512, 1), None), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-5, 1e-5), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 1e-2, 1e-2), + ), +) +def test_causal_softmax(shape, input_strides, out_strides, dtype, device, rtol, atol): + if device == "cpu" and dtype in (torch.float16, torch.bfloat16): + pytest.skip("CPU backend does not support fp16/bf16") + + input_tensor = randn_strided(shape, input_strides, dtype=dtype, device=device) + out = empty_strided(shape, out_strides, dtype=dtype, device=device) + + return Payload( + _causal_softmax, + _torch_causal_softmax, + (input_tensor, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _causal_softmax(input, out): + infini.ops.causal_softmax(input, out) + + return out + + +def _torch_causal_softmax(input, out): + mask = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1]) + masked = torch.where(mask == 1, -torch.inf, input.to(torch.float32)) + result = torch.nn.functional.softmax(masked, dim=-1, dtype=input.dtype) + out.copy_(result) + + return out From 8442eff22e5658da0ebc226b0ddf7fb04b116452 Mon Sep 17 00:00:00 2001 From: Ziminli <70735843+Ziminli@users.noreply.github.com> Date: Fri, 6 Mar 2026 13:16:58 +0800 Subject: [PATCH 76/93] feat: support casting and CPU bfloat16 and float16 (#11) * feat: add the CPU 
implementation of float16 and bfloat16 and the CPU `Cast()` function - add the CPU implementation of float16 and bfloat16 as `float16_t` and `bfloat16_t` - add the CPU `Cast()` function that support conversion between any two CPU supported types, including the custom `float16_t` and `bfloat16_t` * style: change `indexToOffset()` to `IndexToOffset()` to comply with the styling requirement * feat: add the CUDA `Cast()` function * refactor: refactor CUDA `Cast` utility with SFINAE-based hardware dispatch and move them into `common/cuda/cast.h` * style: change the naming of some types in `common/cast.h` and `common/cuda/cast.h` to better comply with the naming rules * chore: remove unused header `data_type.h` in `common/cuda/kernel_commons.h` * style: adjust comments for styling rule compliance * style: change `float_t` and `bfloat16_t` to `Float16` and `BFloat16` and fix various styling issues. --- src/common/cast.h | 57 +++++++++++++++ src/common/cuda/cast.h | 103 +++++++++++++++++++++++++++ src/common/cuda/kernel_commons.h | 6 +- src/common/generic_utils.h | 2 +- src/cpu/add/add.h | 2 +- src/cuda/add/kernel.h | 6 +- src/data_type.h | 117 ++++++++++++++++++++++++++----- 7 files changed, 267 insertions(+), 26 deletions(-) create mode 100644 src/common/cast.h create mode 100644 src/common/cuda/cast.h diff --git a/src/common/cast.h b/src/common/cast.h new file mode 100644 index 0000000..4129941 --- /dev/null +++ b/src/common/cast.h @@ -0,0 +1,57 @@ +#ifndef INFINI_OPS_COMMON_CAST_H_ +#define INFINI_OPS_COMMON_CAST_H_ + +#include "data_type.h" + +namespace infini::ops { + +namespace detail { + +template +constexpr float ToFloatHelper(T &&x) { + using PureSrc = std::remove_cv_t>; + if constexpr (IsBFloat16 || IsFP16) { + return std::forward(x).ToFloat(); + } else { + return static_cast(std::forward(x)); + } +} + +template +constexpr Dst FromFloatHelper(float f) { + using PureDst = std::remove_cv_t>; + if constexpr (IsBFloat16 || IsFP16) { + return PureDst::FromFloat(f); + 
} else { + return static_cast(f); + } +} + +} // namespace detail + +template +Dst Cast(Src &&x) { + static_assert(!std::is_reference_v, + "`Cast` cannot return reference types"); + + using PureDst = std::remove_cv_t>; + using PureSrc = std::remove_cv_t>; + + if constexpr (std::is_same_v) { + return std::forward(x); + } + + constexpr bool src_is_custom = IsBFloat16 || IsFP16; + constexpr bool dst_is_custom = IsBFloat16 || IsFP16; + + if constexpr (!src_is_custom && !dst_is_custom) { + return static_cast(std::forward(x)); + } else { + return detail::FromFloatHelper( + detail::ToFloatHelper(std::forward(x))); + } +} + +} // namespace infini::ops + +#endif diff --git a/src/common/cuda/cast.h b/src/common/cuda/cast.h new file mode 100644 index 0000000..c89982b --- /dev/null +++ b/src/common/cuda/cast.h @@ -0,0 +1,103 @@ +#ifndef INFINI_OPS_COMMON_CUDA_CAST_H_ +#define INFINI_OPS_COMMON_CUDA_CAST_H_ + +#ifdef WITH_NVIDIA +#include +#elif WITH_METAX +#include +#endif + +#include "data_type.h" + +namespace infini::ops { + +namespace detail { + +template +using PureType = std::remove_cv_t>; + +template +__host__ __device__ constexpr float ToFloatHelper(T&& x) { + using PureSrc = PureType; + if constexpr (IsBFloat16) { + return __bfloat162float(x); + } else if constexpr (IsFP16) { + return __half2float(x); + } else { + return static_cast(std::forward(x)); + } +} + +template +__host__ __device__ constexpr Dst FromFloatHelper(float f) { + using PureDst = PureType; + if constexpr (IsBFloat16) { + return __float2bfloat16(f); + } else if constexpr (IsFP16) { + return __float2half(f); + } else { + return static_cast(f); + } +} + +// Priority tags for overload resolution. +struct PriorityLow {}; + +struct PriorityHigh : PriorityLow {}; + +// Fallback: lowest priority. This always matches if nothing else does. 
+template +__host__ __device__ constexpr Dst HardwareCast(Src&& x, PriorityLow) { + return FromFloatHelper(ToFloatHelper(std::forward(x))); +} + +// Usage: `DEFINE_DIRECT_CAST(INTRINSIC, CONDITION)`. +#define DEFINE_DIRECT_CAST(INTRINSIC, ...) \ + template \ + __host__ __device__ auto HardwareCast(Src x, PriorityHigh) \ + ->std::enable_if_t<(__VA_ARGS__), \ + decltype(INTRINSIC(std::declval()))> { \ + return INTRINSIC(x); \ + } + +DEFINE_DIRECT_CAST( + __bfloat162int_rn, + std::is_same_v, int>&& IsBFloat16>) +DEFINE_DIRECT_CAST( + __bfloat162short_rn, + std::is_same_v, short>&& IsBFloat16>) +DEFINE_DIRECT_CAST( + __int2bfloat16_rn, + IsBFloat16>&& std::is_same_v, int>) +DEFINE_DIRECT_CAST(__int2half_rn, + IsFP16>&& std::is_same_v, int>) +DEFINE_DIRECT_CAST( + __double2bfloat16, + IsBFloat16>&& std::is_same_v, double>) +DEFINE_DIRECT_CAST( + __double2half, + IsFP16>&& std::is_same_v, double>) +DEFINE_DIRECT_CAST(__half, IsFP16>&& IsBFloat16>) +#undef DEFINE_DIRECT_CAST + +} // namespace detail + +template +__host__ __device__ Dst Cast(Src&& x) { + static_assert(!std::is_reference_v, + "`Cast` cannot return reference types"); + + using PureSrc = std::remove_cv_t>; + using PureDst = std::remove_cv_t>; + + if constexpr (std::is_same_v) { + return std::forward(x); + } else { + return detail::HardwareCast(std::forward(x), + detail::PriorityHigh{}); + } +} + +} // namespace infini::ops + +#endif diff --git a/src/common/cuda/kernel_commons.h b/src/common/cuda/kernel_commons.h index 98b9f48..c5deb35 100644 --- a/src/common/cuda/kernel_commons.h +++ b/src/common/cuda/kernel_commons.h @@ -9,11 +9,13 @@ #include #endif +#include "cast.h" + namespace infini::ops { __forceinline__ __device__ __host__ size_t -indexToOffset(size_t flat_index, size_t ndim, const size_t *shape, - const ptrdiff_t *strides) { +IndexToOffset(size_t flat_index, size_t ndim, const size_t* shape, + const ptrdiff_t* strides) { size_t res = 0; for (size_t i = ndim; i-- > 0;) { res += (flat_index % shape[i]) 
* strides[i]; diff --git a/src/common/generic_utils.h b/src/common/generic_utils.h index 36df934..795f2fb 100644 --- a/src/common/generic_utils.h +++ b/src/common/generic_utils.h @@ -5,7 +5,7 @@ namespace infini::ops::utils { -std::size_t indexToOffset(std::size_t flat_index, std::size_t ndim, +std::size_t IndexToOffset(std::size_t flat_index, std::size_t ndim, const std::size_t* shape, const std::ptrdiff_t* strides) { std::size_t res = 0; diff --git a/src/cpu/add/add.h b/src/cpu/add/add.h index c76e4da..a3e9aab 100644 --- a/src/cpu/add/add.h +++ b/src/cpu/add/add.h @@ -36,7 +36,7 @@ class Operator : public Add { auto get_idx = [&](Tensor::Size i, bool is_contig, const auto* shape, const auto* strides) { - return is_contig ? i : utils::indexToOffset(i, ndim_, shape, strides); + return is_contig ? i : utils::IndexToOffset(i, ndim_, shape, strides); }; #pragma omp parallel for diff --git a/src/cuda/add/kernel.h b/src/cuda/add/kernel.h index dcbf5f6..b481255 100644 --- a/src/cuda/add/kernel.h +++ b/src/cuda/add/kernel.h @@ -40,13 +40,13 @@ __global__ void AddKernel( if (idx < output_size) { Tensor::Size out_idx = - out_contiguous ? idx : indexToOffset(idx, ndim, out_shape, out_strides); + out_contiguous ? idx : IndexToOffset(idx, ndim, out_shape, out_strides); Tensor::Size input_idx = input_contiguous ? idx - : indexToOffset(idx, ndim, input_shape, input_strides); + : IndexToOffset(idx, ndim, input_shape, input_strides); Tensor::Size other_idx = other_contiguous ? 
idx - : indexToOffset(idx, ndim, other_shape, other_strides); + : IndexToOffset(idx, ndim, other_shape, other_strides); out[out_idx] = AddOp{}(input[input_idx], other[other_idx]); } diff --git a/src/data_type.h b/src/data_type.h index 850dda2..8a3e544 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -2,6 +2,7 @@ #define INFINI_OPS_DATA_TYPE_H_ #include +#include #include #ifdef WITH_NVIDIA @@ -80,6 +81,86 @@ constexpr ConstexprMap kStringToDataType{{{ {"float64", DataType::kFloat64}, }}}; +struct Float16 { + std::uint16_t bits; + + static inline Float16 FromFloat(float val) { + std::uint32_t f32; + std::memcpy(&f32, &val, sizeof(f32)); + std::uint16_t sign = (f32 >> 16) & 0x8000; + std::int32_t exponent = ((f32 >> 23) & 0xFF) - 127; + std::uint32_t mantissa = f32 & 0x7FFFFF; + + if (exponent >= 16) { + // NaN + if (exponent == 128 && mantissa != 0) { + return {static_cast(sign | 0x7E00)}; + } + // Inf + return {static_cast(sign | 0x7C00)}; + } else if (exponent >= -14) { + return {static_cast(sign | ((exponent + 15) << 10) | + (mantissa >> 13))}; + } else if (exponent >= -24) { + mantissa |= 0x800000; + mantissa >>= (-14 - exponent); + return {static_cast(sign | (mantissa >> 13))}; + } + // Too small for subnormal: return signed zero. 
+ return {sign}; + } + + inline float ToFloat() const { + std::uint32_t sign = (bits & 0x8000) << 16; + std::int32_t exponent = (bits >> 10) & 0x1F; + std::uint32_t mantissa = bits & 0x3FF; + std::uint32_t f32_bits; + + if (exponent == 31) { + f32_bits = sign | 0x7F800000 | (mantissa << 13); + } else if (exponent == 0) { + if (mantissa == 0) { + f32_bits = sign; + } else { + exponent = -14; + while ((mantissa & 0x400) == 0) { + mantissa <<= 1; + exponent--; + } + mantissa &= 0x3FF; + f32_bits = sign | ((exponent + 127) << 23) | (mantissa << 13); + } + } else { + f32_bits = sign | ((exponent + 127 - 15) << 23) | (mantissa << 13); + } + + float result; + std::memcpy(&result, &f32_bits, sizeof(result)); + return result; + } +}; + +struct BFloat16 { + std::uint16_t bits; + + static inline BFloat16 FromFloat(float val) { + std::uint32_t bits32; + std::memcpy(&bits32, &val, sizeof(bits32)); + + const std::uint32_t rounding_bias = 0x00007FFF + ((bits32 >> 16) & 1); + std::uint16_t bf16_bits = + static_cast((bits32 + rounding_bias) >> 16); + return {bf16_bits}; + } + + inline float ToFloat() const { + std::uint32_t bits32 = static_cast(bits) << 16; + float result; + std::memcpy(&result, &bits32, sizeof(result)); + return result; + } +}; + template struct TypeMap; @@ -103,14 +184,14 @@ inline constexpr DataType DataTypeMapValue = DataTypeMap::value; static constexpr DataType value = DataType::ENUM_VALUE; \ }; -DEFINE_DATA_TYPE_MAPPING(kUInt8, uint8_t) -DEFINE_DATA_TYPE_MAPPING(kInt8, int8_t) -DEFINE_DATA_TYPE_MAPPING(kUInt16, uint16_t) -DEFINE_DATA_TYPE_MAPPING(kInt16, int16_t) -DEFINE_DATA_TYPE_MAPPING(kUInt32, uint32_t) -DEFINE_DATA_TYPE_MAPPING(kInt32, int32_t) -DEFINE_DATA_TYPE_MAPPING(kUInt64, uint64_t) -DEFINE_DATA_TYPE_MAPPING(kInt64, int64_t) +DEFINE_DATA_TYPE_MAPPING(kUInt8, std::uint8_t) +DEFINE_DATA_TYPE_MAPPING(kInt8, std::int8_t) +DEFINE_DATA_TYPE_MAPPING(kUInt16, std::uint16_t) +DEFINE_DATA_TYPE_MAPPING(kInt16, std::int16_t) +DEFINE_DATA_TYPE_MAPPING(kUInt32, 
std::uint32_t) +DEFINE_DATA_TYPE_MAPPING(kInt32, std::int32_t) +DEFINE_DATA_TYPE_MAPPING(kUInt64, std::uint64_t) +DEFINE_DATA_TYPE_MAPPING(kInt64, std::int64_t) DEFINE_DATA_TYPE_MAPPING(kFloat32, float) DEFINE_DATA_TYPE_MAPPING(kFloat64, double) @@ -121,20 +202,18 @@ DEFINE_DATA_TYPE_MAPPING(kBFloat16, __nv_bfloat16) DEFINE_DATA_TYPE_MAPPING(kFloat16, __half) DEFINE_DATA_TYPE_MAPPING(kBFloat16, __maca_bfloat16) #else -// TODO(lzm): currently there's an ambiguity of uint16_t mapping to both kUInt16 -// and kFloat16/kBFloat16 for CPU. When CPU custom bfloat16/float16 types are -// defined, this should be replaced. -template <> -struct TypeMap { - using type = uint16_t; -}; -template <> -struct TypeMap { - using type = uint16_t; -}; +DEFINE_DATA_TYPE_MAPPING(kFloat16, Float16) +DEFINE_DATA_TYPE_MAPPING(kBFloat16, BFloat16) #endif #undef DEFINE_DATA_TYPE_MAPPING +// Define the traits to check whether a type is bfloat16 or float16. +template +inline constexpr bool IsBFloat16 = (DataTypeMapValue == DataType::kBFloat16); + +template +inline constexpr bool IsFP16 = (DataTypeMapValue == DataType::kFloat16); + // Defines the common categories of data types using List. 
using FloatTypes = List; using ReducedFloatTypes = List; From 42f1e205a96c565be34c8b25dd36f1cb50c5c94d Mon Sep 17 00:00:00 2001 From: zhangyunze <93699316+bitzyz@users.noreply.github.com> Date: Fri, 6 Mar 2026 16:40:12 +0800 Subject: [PATCH 77/93] feat: add `swiglu` op with NVIDIA and CPU backends (#10) * feat: add op with NVIDIA and CPU backends * fix: fix code as pr comment * chore: format `tests/test_swiglu.py` and `tests/utils.py` --------- Co-authored-by: Jiacheng Huang --- src/base/swiglu.h | 68 +++++++++++++++++++ src/common/cuda/kernel_commons.h | 41 +++++++++++ src/cpu/swiglu/swiglu.h | 58 ++++++++++++++++ src/cuda/swiglu/kernel.cuh | 104 ++++++++++++++++++++++++++++ src/cuda/swiglu/kernel.h | 113 +++++++++++++++++++++++++++++++ src/nvidia/swiglu/kernel.h | 41 +++++++++++ tests/test_swiglu.py | 52 ++++++++++++++ tests/utils.py | 10 +++ 8 files changed, 487 insertions(+) create mode 100644 src/base/swiglu.h create mode 100644 src/cpu/swiglu/swiglu.h create mode 100644 src/cuda/swiglu/kernel.cuh create mode 100644 src/cuda/swiglu/kernel.h create mode 100644 src/nvidia/swiglu/kernel.h create mode 100644 tests/test_swiglu.py diff --git a/src/base/swiglu.h b/src/base/swiglu.h new file mode 100644 index 0000000..023b14a --- /dev/null +++ b/src/base/swiglu.h @@ -0,0 +1,68 @@ +#ifndef INFINI_OPS_BASE_SWIGLU_H_ +#define INFINI_OPS_BASE_SWIGLU_H_ + +#include + +#include "operator.h" + +namespace infini::ops { + +class Swiglu : public Operator { + public: + Swiglu(const Tensor input, const Tensor gate, Tensor out) + : ndim_{out.ndim()}, + output_size_{out.numel()}, + input_type_{input.dtype()}, + gate_type_{gate.dtype()}, + out_type_{out.dtype()}, + input_shape_{input.shape()}, + gate_shape_{gate.shape()}, + out_shape_{out.shape()}, + input_strides_{input.strides()}, + gate_strides_{gate.strides()}, + out_strides_{out.strides()}, + is_input_contiguous_{input.IsContiguous()}, + is_gate_contiguous_{gate.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + 
assert( + input_type_ == gate_type_ && gate_type_ == out_type_ && + "operator `Swiglu` requires all input and output tensors to have the " + "same dtype"); + } + + virtual void operator()(const Tensor input, const Tensor gate, + Tensor out) const = 0; + + protected: + Tensor::Size ndim_{0}; + + Tensor::Size output_size_{0}; + + const DataType input_type_; + + const DataType gate_type_; + + const DataType out_type_; + + Tensor::Shape input_shape_; + + Tensor::Shape gate_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides gate_strides_; + + Tensor::Strides out_strides_; + + bool is_input_contiguous_{false}; + + bool is_gate_contiguous_{false}; + + bool is_out_contiguous_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/common/cuda/kernel_commons.h b/src/common/cuda/kernel_commons.h index c5deb35..43cce8f 100644 --- a/src/common/cuda/kernel_commons.h +++ b/src/common/cuda/kernel_commons.h @@ -2,17 +2,58 @@ #define INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_ #ifdef WITH_NVIDIA +#include +#include #include +using cuda_bfloat16 = nv_bfloat16; +using cuda_bfloat162 = nv_bfloat162; #elif defined(WITH_ILUVATAR) #include #elif WITH_METAX // TODO: Use `defined`. #include +using cuda_bfloat16 = maca_bfloat16; +using cuda_bfloat162 = maca_bfloat162; #endif #include "cast.h" namespace infini::ops { +constexpr int CUDA_BLOCK_SIZE_128 = 128; +constexpr int CUDA_BLOCK_SIZE_256 = 256; +constexpr int CUDA_BLOCK_SIZE_512 = 512; +constexpr int CUDA_BLOCK_SIZE_1024 = 1024; + +// Query the maximum threads per block for the current CUDA device. +inline int QueryMaxThreadsPerBlock() { +#ifdef WITH_NVIDIA + int device = 0; + cudaGetDevice(&device); + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, device); + return prop.maxThreadsPerBlock; +#elif WITH_METAX + // TODO: Add MCR device properties query for Metax. + return CUDA_BLOCK_SIZE_256; +#endif +} + +// Get optimal block size based on GPU hardware architecture. 
+inline int GetOptimalBlockSize() { + int max_threads = QueryMaxThreadsPerBlock(); + + // Select the largest supported block size for better performance. + if (max_threads >= CUDA_BLOCK_SIZE_1024) { + return CUDA_BLOCK_SIZE_1024; + } else if (max_threads >= CUDA_BLOCK_SIZE_512) { + return CUDA_BLOCK_SIZE_512; + } else if (max_threads >= CUDA_BLOCK_SIZE_256) { + return CUDA_BLOCK_SIZE_256; + } else { + return CUDA_BLOCK_SIZE_128; + } +} + __forceinline__ __device__ __host__ size_t IndexToOffset(size_t flat_index, size_t ndim, const size_t* shape, const ptrdiff_t* strides) { diff --git a/src/cpu/swiglu/swiglu.h b/src/cpu/swiglu/swiglu.h new file mode 100644 index 0000000..a01a4b6 --- /dev/null +++ b/src/cpu/swiglu/swiglu.h @@ -0,0 +1,58 @@ +#ifndef INFINI_OPS_CPU_SWIGLU_SWIGLU_H_ +#define INFINI_OPS_CPU_SWIGLU_SWIGLU_H_ + +#include + +#include "base/swiglu.h" +#include "common/generic_utils.h" + +namespace infini::ops { + +template <> +class Operator : public Swiglu { + public: + using Swiglu::Swiglu; + + void operator()(const Tensor input, const Tensor gate, + Tensor out) const override { + DispatchFunc( + out_type_, + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(input, gate, out); + }, + "Operator::operator()"); + } + + private: + template + void Compute(const Tensor input, const Tensor gate, Tensor out) const { + const auto* input_ptr = static_cast(input.data()); + const auto* gate_ptr = static_cast(gate.data()); + auto* out_ptr = static_cast(out.data()); + + auto get_idx = [&](Tensor::Size i, bool is_contig, const auto* shape, + const auto* strides) { + return is_contig ? 
i : utils::IndexToOffset(i, ndim_, shape, strides); + }; + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size_; ++i) { + auto input_idx = get_idx(i, is_input_contiguous_, input_shape_.data(), + input_strides_.data()); + auto gate_idx = get_idx(i, is_gate_contiguous_, gate_shape_.data(), + gate_strides_.data()); + auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), + out_strides_.data()); + const T x = input_ptr[input_idx]; + const T sigmoid_x = + static_cast(1.0 / (1.0 + std::exp(-static_cast(x)))); + const T swish_x = x * sigmoid_x; + out_ptr[out_idx] = swish_x * gate_ptr[gate_idx]; + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/swiglu/kernel.cuh b/src/cuda/swiglu/kernel.cuh new file mode 100644 index 0000000..f404450 --- /dev/null +++ b/src/cuda/swiglu/kernel.cuh @@ -0,0 +1,104 @@ +#ifndef INFINI_OPS_CUDA_SWIGLU_KERNEL_CUH_ +#define INFINI_OPS_CUDA_SWIGLU_KERNEL_CUH_ + +#include + +#include "common/cuda/kernel_commons.h" + +namespace infini::ops { + +// Optimized sigmoid function with support for vectorized types. 
+template +__device__ __forceinline__ T Sigmoid(const T& x) { + if constexpr (std::is_same_v) { + return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x)))); + } else if constexpr (std::is_same_v) { + return hrcp( + __hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x)))))); + } else if constexpr (std::is_same_v) { + float x0 = __bfloat162float(__low2bfloat16(x)); + float x1 = __bfloat162float(__high2bfloat16(x)); + float sig0 = __frcp_rn(__fadd_rn(1.0f, __expf(-x0))); + float sig1 = __frcp_rn(__fadd_rn(1.0f, __expf(-x1))); + return __floats2bfloat162_rn(sig0, sig1); + } else if constexpr (std::is_same_v) { + float xf = __bfloat162float(x); + return __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-xf)))); + } else if constexpr (std::is_same_v) { + return __frcp_rn(__fadd_rn(1.0f, __expf(-x))); + } else { + return 1.0f / (1.0f + expf(-x)); + } +} + +// SwiGLU(x, gate) = Swish(x) * gate = (x * sigmoid(x)) * gate. +template +__global__ void SwigluKernel(T* out, const T* a, const T* b, + const size_t* out_shape, const size_t* input_shape, + const size_t* gate_shape, + const ptrdiff_t* out_strides, + const ptrdiff_t* input_strides, + const ptrdiff_t* gate_strides, size_t output_size, + size_t ndim, size_t offset, bool out_contiguous, + bool input_contiguous, bool gate_contiguous) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + + if (idx < output_size) { + size_t out_idx, input_idx, gate_idx; + + if (out_contiguous) { + out_idx = idx; + } else { + out_idx = IndexToOffset(idx, ndim, out_shape, out_strides); + } + + if (input_contiguous) { + input_idx = idx; + } else { + input_idx = IndexToOffset(idx, ndim, input_shape, input_strides); + } + + if (gate_contiguous) { + gate_idx = idx; + } else { + gate_idx = IndexToOffset(idx, ndim, gate_shape, gate_strides); + } + + T up = a[input_idx]; + T gate = b[gate_idx]; + + if constexpr (std::is_same_v) { + // Vectorized `half2` computation for better performance. 
+ out[out_idx] = __hmul2(__hmul2(gate, Sigmoid(gate)), up); + } else if constexpr (std::is_same_v) { + // Optimized `half` precision computation. + out[out_idx] = __hmul(__hmul(gate, Sigmoid(gate)), up); + } else if constexpr (std::is_same_v) { + cuda_bfloat162 sig = Sigmoid(gate); + float gate0 = __bfloat162float(__low2bfloat16(gate)); + float gate1 = __bfloat162float(__high2bfloat16(gate)); + float sig0 = __bfloat162float(__low2bfloat16(sig)); + float sig1 = __bfloat162float(__high2bfloat16(sig)); + float up0 = __bfloat162float(__low2bfloat16(up)); + float up1 = __bfloat162float(__high2bfloat16(up)); + float res0 = __fmul_rn(__fmul_rn(gate0, sig0), up0); + float res1 = __fmul_rn(__fmul_rn(gate1, sig1), up1); + out[out_idx] = __floats2bfloat162_rn(res0, res1); + } else if constexpr (std::is_same_v) { + cuda_bfloat16 sig = Sigmoid(gate); + float gatef = __bfloat162float(gate); + float sigf = __bfloat162float(sig); + float upf = __bfloat162float(up); + out[out_idx] = + __float2bfloat16_rn(__fmul_rn(__fmul_rn(gatef, sigf), upf)); + } else if constexpr (std::is_same_v) { + out[out_idx] = __fmul_rn(__fmul_rn(gate, Sigmoid(gate)), up); + } else { + out[out_idx] = gate * Sigmoid(gate) * up; + } + } +} + +} // namespace infini::ops + +#endif diff --git a/src/cuda/swiglu/kernel.h b/src/cuda/swiglu/kernel.h new file mode 100644 index 0000000..7c459a6 --- /dev/null +++ b/src/cuda/swiglu/kernel.h @@ -0,0 +1,113 @@ +#ifndef INFINI_OPS_CUDA_SWIGLU_KERNEL_H_ +#define INFINI_OPS_CUDA_SWIGLU_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "base/swiglu.h" +#include "common/generic_utils.h" +#include "cuda/swiglu/kernel.cuh" + +namespace infini::ops { + +template +class CudaSwiglu : public Swiglu { + public: + CudaSwiglu(const Tensor input, const Tensor gate, Tensor out) + : Swiglu{input, gate, out} { + size_t shape_size = ndim_ * sizeof(*d_input_shape_); + size_t strides_size = ndim_ * sizeof(*d_input_strides_); + + 
Backend::malloc((void**)&d_input_shape_, shape_size); + Backend::malloc((void**)&d_gate_shape_, shape_size); + Backend::malloc((void**)&d_out_shape_, shape_size); + Backend::malloc((void**)&d_input_strides_, strides_size); + Backend::malloc((void**)&d_gate_strides_, strides_size); + Backend::malloc((void**)&d_out_strides_, strides_size); + + Backend::memcpy(d_input_shape_, input_shape_.data(), shape_size, + Backend::memcpyH2D); + Backend::memcpy(d_gate_shape_, gate_shape_.data(), shape_size, + Backend::memcpyH2D); + Backend::memcpy(d_out_shape_, out_shape_.data(), shape_size, + Backend::memcpyH2D); + Backend::memcpy(d_input_strides_, input_strides_.data(), strides_size, + Backend::memcpyH2D); + Backend::memcpy(d_gate_strides_, gate_strides_.data(), strides_size, + Backend::memcpyH2D); + Backend::memcpy(d_out_strides_, out_strides_.data(), strides_size, + Backend::memcpyH2D); + } + + ~CudaSwiglu() { + Backend::free(d_input_shape_); + Backend::free(d_gate_shape_); + Backend::free(d_out_shape_); + Backend::free(d_input_strides_); + Backend::free(d_gate_strides_); + Backend::free(d_out_strides_); + } + + void operator()(const Tensor input, const Tensor gate, + Tensor out) const override { + DispatchFunc( + out_type_, + [&](auto tag) { + using T = typename decltype(tag)::type; + auto cuda_stream = + static_cast(stream_ ? stream_ : 0); + int block_size = GetOptimalBlockSize(); + dim3 blockDims( + std::min(static_cast(block_size), output_size_)); + dim3 gridDims(utils::CeilDiv(output_size_, blockDims.x)); + size_t step = gridDims.x * blockDims.x; + + T* d_out = reinterpret_cast(out.data()); + const T* d_input = reinterpret_cast(input.data()); + const T* d_gate = reinterpret_cast(gate.data()); + +// Launch kernel with appropriate block size based on GPU architecture. 
+#define LAUNCH_SWIGLU_KERNEL(BLOCK_SIZE) \ + for (size_t i = 0; i < output_size_; i += step) { \ + SwigluKernel<<>>( \ + d_out, d_input, d_gate, d_out_shape_, d_input_shape_, d_gate_shape_, \ + d_out_strides_, d_input_strides_, d_gate_strides_, output_size_, \ + ndim_, i, is_out_contiguous_, is_input_contiguous_, \ + is_gate_contiguous_); \ + } + + if (block_size == CUDA_BLOCK_SIZE_1024) { + LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_1024) + } else if (block_size == CUDA_BLOCK_SIZE_512) { + LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_512) + } else if (block_size == CUDA_BLOCK_SIZE_256) { + LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_256) + } else { + LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_128) + } + +#undef LAUNCH_SWIGLU_KERNEL + }, + "CudaSwiglu::operator()"); + } + + private: + Tensor::Size* d_input_shape_{nullptr}; + + Tensor::Size* d_gate_shape_{nullptr}; + + Tensor::Size* d_out_shape_{nullptr}; + + Tensor::Stride* d_input_strides_{nullptr}; + + Tensor::Stride* d_gate_strides_{nullptr}; + + Tensor::Stride* d_out_strides_{nullptr}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/nvidia/swiglu/kernel.h b/src/nvidia/swiglu/kernel.h new file mode 100644 index 0000000..54644e5 --- /dev/null +++ b/src/nvidia/swiglu/kernel.h @@ -0,0 +1,41 @@ +#ifndef INFINI_OPS_NVIDIA_SWIGLU_KERNEL_H_ +#define INFINI_OPS_NVIDIA_SWIGLU_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/swiglu/kernel.h" + +namespace infini::ops { + +namespace swiglu { + +struct NvidiaBackend { + using stream_t = cudaStream_t; + + static constexpr auto malloc = [](auto&&... 
args) { + return cudaMalloc(std::forward(args)...); + }; + + static constexpr auto memcpy = cudaMemcpy; + + static constexpr auto free = cudaFree; + + static constexpr auto memcpyH2D = cudaMemcpyHostToDevice; +}; + +} // namespace swiglu + +template <> +class Operator + : public CudaSwiglu { + public: + using CudaSwiglu::CudaSwiglu; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py new file mode 100644 index 0000000..700537a --- /dev/null +++ b/tests/test_swiglu.py @@ -0,0 +1,52 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, rand_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, gate_strides, out_strides", + ( + ((13, 4), None, None, None), + ((13, 4), (10, 1), (10, 1), (10, 1)), + ((13, 4, 4), None, None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), + ((16, 5632), None, None, None), + ((16, 5632), (13312, 1), (13312, 1), (13312, 1)), + ((4, 4, 5632), None, None, None), + ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +def test_swiglu( + shape, input_strides, gate_strides, out_strides, dtype, device, rtol, atol +): + if device == "cpu" and dtype in (torch.float16, torch.bfloat16): + pytest.skip("CPU backend does not support fp16/bf16") + + input = rand_strided(shape, input_strides, dtype=dtype, device=device) + gate = rand_strided(shape, gate_strides, dtype=dtype, device=device) + out = empty_strided(shape, out_strides, dtype=dtype, device=device) + + return Payload(_swiglu, _torch_swiglu, (input, gate, out), {}, rtol=rtol, atol=atol) + + +def _swiglu(input, gate, out): + infini.ops.swiglu(input, gate, out) + + return out + + +def _torch_swiglu(input, gate, out): + swish_x = gate * torch.sigmoid(gate) + + 
return torch.mul(input, swish_x, out=out) diff --git a/tests/utils.py b/tests/utils.py index 11afcdf..45b1205 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -53,6 +53,16 @@ def randn_strided(shape, strides, *, dtype=None, device=None): return output +def rand_strided(shape, strides, *, dtype=None, device=None): + output = empty_strided(shape, strides, dtype=dtype, device=device) + + output.as_strided( + (output.untyped_storage().size() // output.element_size(),), (1,) + ).uniform_(0, 1) + + return output + + def clone_strided(input): output = empty_strided( input.size(), input.stride(), dtype=input.dtype, device=input.device From 6e93d39935e5254125b221ba68b9afe30ab9a519 Mon Sep 17 00:00:00 2001 From: zhangyue <138768300+zhangyue207@users.noreply.github.com> Date: Wed, 11 Mar 2026 11:10:44 +0800 Subject: [PATCH 78/93] feat: reorganize casting utilities and enhance CPU support (#16) * feat: reorganize casting utilities and enhance CPU support - Moved casting functions to separate CPU and CUDA headers for better organization. - Introduced a new `Cast()` function in the CPU implementation to handle type conversions, including support for custom types like `float16_t` and `bfloat16_t`. - Updated various operators to utilize the new casting utilities, ensuring consistent type handling across CPU and CUDA backends. - Enhanced test cases to cover additional data types and ensure compatibility with the new casting logic. * fix: update bfloat16 test tolerance in `test_rms_norm.py` - Increased the tolerance for `bfloat16` from `1e-2` to `2e-2` to better accommodate numerical precision in tests. 
* format: simplify type dispatching in `Add` operator and formating --- src/common/cast.h | 57 ++-------------------- src/common/cpu/cast.h | 57 ++++++++++++++++++++++ src/common/cuda/cast.h | 4 +- src/cpu/add/add.h | 9 +++- src/cpu/causal_softmax/causal_softmax.h | 39 +++++++++------ src/cpu/gemm/gemm.h | 64 ++++++++++++++++++++----- src/cpu/rms_norm/rms_norm.h | 38 +++++++++------ src/cpu/swiglu/swiglu.h | 15 ++++-- src/cuda/gemm/blas.h | 11 +++-- src/iluvatar/gemm/cublas.h | 18 +++++++ src/metax/gemm/mcblas.h | 18 +++++++ src/nvidia/gemm/cublas.h | 18 +++++++ tests/test_add.py | 49 ++++++++++++++++--- tests/test_causal_softmax.py | 3 -- tests/test_gemm.py | 20 ++++++-- tests/test_rms_norm.py | 5 +- tests/test_swiglu.py | 3 -- tests/utils.py | 12 ++++- 18 files changed, 312 insertions(+), 128 deletions(-) create mode 100644 src/common/cpu/cast.h diff --git a/src/common/cast.h b/src/common/cast.h index 4129941..4973764 100644 --- a/src/common/cast.h +++ b/src/common/cast.h @@ -1,57 +1,10 @@ #ifndef INFINI_OPS_COMMON_CAST_H_ #define INFINI_OPS_COMMON_CAST_H_ -#include "data_type.h" - -namespace infini::ops { - -namespace detail { - -template -constexpr float ToFloatHelper(T &&x) { - using PureSrc = std::remove_cv_t>; - if constexpr (IsBFloat16 || IsFP16) { - return std::forward(x).ToFloat(); - } else { - return static_cast(std::forward(x)); - } -} - -template -constexpr Dst FromFloatHelper(float f) { - using PureDst = std::remove_cv_t>; - if constexpr (IsBFloat16 || IsFP16) { - return PureDst::FromFloat(f); - } else { - return static_cast(f); - } -} - -} // namespace detail - -template -Dst Cast(Src &&x) { - static_assert(!std::is_reference_v, - "`Cast` cannot return reference types"); - - using PureDst = std::remove_cv_t>; - using PureSrc = std::remove_cv_t>; - - if constexpr (std::is_same_v) { - return std::forward(x); - } - - constexpr bool src_is_custom = IsBFloat16 || IsFP16; - constexpr bool dst_is_custom = IsBFloat16 || IsFP16; - - if constexpr 
(!src_is_custom && !dst_is_custom) { - return static_cast(std::forward(x)); - } else { - return detail::FromFloatHelper( - detail::ToFloatHelper(std::forward(x))); - } -} - -} // namespace infini::ops +#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) || defined(WITH_METAX) +#include "common/cuda/cast.h" +#else +#include "common/cpu/cast.h" +#endif #endif diff --git a/src/common/cpu/cast.h b/src/common/cpu/cast.h new file mode 100644 index 0000000..68b95fc --- /dev/null +++ b/src/common/cpu/cast.h @@ -0,0 +1,57 @@ +#ifndef INFINI_OPS_COMMON_CPU_CAST_H_ +#define INFINI_OPS_COMMON_CPU_CAST_H_ + +#include "data_type.h" + +namespace infini::ops { + +namespace detail { + +template +constexpr float ToFloatHelper(T &&x) { + using PureSrc = std::remove_cv_t >; + if constexpr (IsBFloat16 || IsFP16) { + return std::forward(x).ToFloat(); + } else { + return static_cast(std::forward(x)); + } +} + +template +constexpr Dst FromFloatHelper(float f) { + using PureDst = std::remove_cv_t >; + if constexpr (IsBFloat16 || IsFP16) { + return PureDst::FromFloat(f); + } else { + return static_cast(f); + } +} + +} // namespace detail + +template +Dst Cast(Src &&x) { + static_assert(!std::is_reference_v, + "`Cast` cannot return reference types"); + + using PureDst = std::remove_cv_t >; + using PureSrc = std::remove_cv_t >; + + if constexpr (std::is_same_v) { + return std::forward(x); + } + + constexpr bool src_is_custom = IsBFloat16 || IsFP16; + constexpr bool dst_is_custom = IsBFloat16 || IsFP16; + + if constexpr (!src_is_custom && !dst_is_custom) { + return static_cast(std::forward(x)); + } else { + return detail::FromFloatHelper( + detail::ToFloatHelper(std::forward(x))); + } +} + +} // namespace infini::ops + +#endif diff --git a/src/common/cuda/cast.h b/src/common/cuda/cast.h index c89982b..d3dcdb9 100644 --- a/src/common/cuda/cast.h +++ b/src/common/cuda/cast.h @@ -3,7 +3,9 @@ #ifdef WITH_NVIDIA #include -#elif WITH_METAX +#elif defined(WITH_ILUVATAR) +#include +#elif 
defined(WITH_METAX) #include #endif diff --git a/src/cpu/add/add.h b/src/cpu/add/add.h index a3e9aab..ec605c3 100644 --- a/src/cpu/add/add.h +++ b/src/cpu/add/add.h @@ -4,6 +4,7 @@ #include #include "base/add.h" +#include "common/cast.h" #include "common/generic_utils.h" namespace infini::ops { @@ -18,7 +19,7 @@ class Operator : public Add { void operator()(const Tensor input, const Tensor other, Tensor out) const override { - DispatchFunc>( + DispatchFunc( out_type_, [&](auto tag) { using T = typename decltype(tag)::type; @@ -30,6 +31,9 @@ class Operator : public Add { private: template void Compute(const Tensor input, const Tensor other, Tensor out) const { + using ComputeType = + std::conditional_t || IsFP16, float, T>; + const auto* input_ptr = static_cast(input.data()); const auto* other_ptr = static_cast(other.data()); auto* out_ptr = static_cast(out.data()); @@ -48,7 +52,8 @@ class Operator : public Add { auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), out_strides_.data()); - out_ptr[out_idx] = input_ptr[input_idx] + other_ptr[other_idx]; + out_ptr[out_idx] = Cast(Cast(input_ptr[input_idx]) + + Cast(other_ptr[other_idx])); } } }; diff --git a/src/cpu/causal_softmax/causal_softmax.h b/src/cpu/causal_softmax/causal_softmax.h index 0005c4f..ca207a2 100644 --- a/src/cpu/causal_softmax/causal_softmax.h +++ b/src/cpu/causal_softmax/causal_softmax.h @@ -4,6 +4,8 @@ #include #include "base/causal_softmax.h" +#include "common/cast.h" +#include "common/generic_utils.h" #include "data_type.h" #include "tensor.h" @@ -15,13 +17,20 @@ class Operator : public CausalSoftmax { Operator(const Tensor input, Tensor out) : CausalSoftmax{input, out} {} void operator()(const Tensor input, Tensor out) const override { - if (out.dtype() != DataType::kFloat32 || - input.dtype() != DataType::kFloat32) { - std::abort(); - } + DispatchFunc( + out.dtype(), + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(input, out); + }, + "`Operator::operator()`"); 
+ } - auto* out_ptr = static_cast(out.data()); - const auto* input_ptr = static_cast(input.data()); + private: + template + void Compute(const Tensor input, Tensor out) const { + auto* out_ptr = static_cast(out.data()); + const auto* input_ptr = static_cast(input.data()); auto out_stride_b = ndim_ == 3 ? out_strides_[0] : 0; auto out_stride_i = out_strides_[ndim_ - 2]; @@ -34,18 +43,18 @@ class Operator : public CausalSoftmax { for (Tensor::Size i = 0; i < seq_len_; ++i) { ptrdiff_t out_offset = bi * out_stride_b + i * out_stride_i; ptrdiff_t input_offset = bi * input_stride_b + i * input_stride_i; - float* out_row = out_ptr + out_offset; - const float* input_row = input_ptr + input_offset; + T* out_row = out_ptr + out_offset; + const T* input_row = input_ptr + input_offset; Tensor::Size valid_len = total_seq_len_ - seq_len_ + i + 1; for (Tensor::Size j = valid_len; j < total_seq_len_; ++j) { - out_row[j * out_stride_j] = 0.0f; + out_row[j * out_stride_j] = Cast(0.0f); } - float max_val = input_row[0]; + float max_val = Cast(input_row[0]); for (Tensor::Size j = 1; j < valid_len; ++j) { - float v = input_row[j * input_stride_j]; + float v = Cast(input_row[j * input_stride_j]); if (v > max_val) { max_val = v; } @@ -53,13 +62,15 @@ class Operator : public CausalSoftmax { float sum = 0.0f; for (Tensor::Size j = 0; j < valid_len; ++j) { - float v = std::exp(input_row[j * input_stride_j] - max_val); - out_row[j * out_stride_j] = v; + float v = + std::exp(Cast(input_row[j * input_stride_j]) - max_val); + out_row[j * out_stride_j] = Cast(v); sum += v; } for (Tensor::Size j = 0; j < valid_len; ++j) { - out_row[j * out_stride_j] /= sum; + out_row[j * out_stride_j] = + Cast(Cast(out_row[j * out_stride_j]) / sum); } } } diff --git a/src/cpu/gemm/gemm.h b/src/cpu/gemm/gemm.h index 9fe87b6..685a94a 100644 --- a/src/cpu/gemm/gemm.h +++ b/src/cpu/gemm/gemm.h @@ -4,6 +4,8 @@ #include #include "base/gemm.h" +#include "common/cast.h" +#include "common/generic_utils.h" namespace 
infini::ops { @@ -28,27 +30,63 @@ class Operator : public Gemm { void operator()(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c) const override { - const auto* A = static_cast(a.data()); - const auto* B = static_cast(b.data()); - auto* C = static_cast(c.data()); + DispatchFunc( + c.dtype(), + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(a, b, alpha, beta, trans_a, trans_b, c); + }, + "`Operator::operator()`"); + } + + private: + template + void Compute(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) const { + const auto* A = static_cast(a.data()); + const auto* B = static_cast(b.data()); + auto* C = static_cast(c.data()); const auto& alpha_value{alpha.value_or(alpha_)}; const auto& beta_value{beta.value_or(beta_)}; const auto& trans_a_value{trans_a.value_or(trans_a_)}; const auto& trans_b_value{trans_b.value_or(trans_b_)}; - for (Tensor::Size i = 0; i < m_; ++i) { - for (Tensor::Size j = 0; j < n_; ++j) { - float sum = 0.0f; + Tensor::Stride stride_a_m = trans_a_value + ? a_strides_[a_strides_.size() - 1] + : a_strides_[a_strides_.size() - 2]; + Tensor::Stride stride_a_k = trans_a_value + ? a_strides_[a_strides_.size() - 2] + : a_strides_[a_strides_.size() - 1]; + Tensor::Stride stride_b_k = trans_b_value + ? b_strides_[b_strides_.size() - 1] + : b_strides_[b_strides_.size() - 2]; + Tensor::Stride stride_b_n = trans_b_value + ? b_strides_[b_strides_.size() - 2] + : b_strides_[b_strides_.size() - 1]; + Tensor::Stride stride_c_m = c_strides_[c_strides_.size() - 2]; + Tensor::Stride stride_c_n = c_strides_[c_strides_.size() - 1]; - for (Tensor::Size l = 0; l < k_; ++l) { - float a_val = trans_a_value ? A[l * m_ + i] : A[i * k_ + l]; - float b_val = trans_b_value ? 
B[j * k_ + l] : B[l * n_ + j]; - sum += a_val * b_val; - } + for (Tensor::Size b = 0; b < batch_count_; ++b) { + const auto* A_batch = A + b * batch_stride_a_; + const auto* B_batch = B + b * batch_stride_b_; + auto* C_batch = C + b * batch_stride_c_; - Tensor::Size idx = i * n_ + j; - C[idx] = alpha_value * sum + beta_value * C[idx]; + for (Tensor::Size i = 0; i < m_; ++i) { + for (Tensor::Size j = 0; j < n_; ++j) { + float sum = 0.0f; + + for (Tensor::Size l = 0; l < k_; ++l) { + float a_val = Cast(A_batch[i * stride_a_m + l * stride_a_k]); + float b_val = Cast(B_batch[l * stride_b_k + j * stride_b_n]); + sum += a_val * b_val; + } + + Tensor::Size idx = i * stride_c_m + j * stride_c_n; + float c_val = beta_value == 0.0f ? 0.0f : Cast(C_batch[idx]); + C_batch[idx] = Cast(alpha_value * sum + beta_value * c_val); + } } } } diff --git a/src/cpu/rms_norm/rms_norm.h b/src/cpu/rms_norm/rms_norm.h index f032993..b3caeb0 100644 --- a/src/cpu/rms_norm/rms_norm.h +++ b/src/cpu/rms_norm/rms_norm.h @@ -4,6 +4,8 @@ #include #include "base/rms_norm.h" +#include "common/cast.h" +#include "common/generic_utils.h" #include "data_type.h" #include "tensor.h" @@ -16,16 +18,22 @@ class Operator : public RmsNorm { void operator()(const Tensor input, const Tensor weight, float eps, Tensor out) const override { - // CPU backend supports fp32 only; fp16/bf16 use GPU backends. 
- if (out.dtype() != DataType::kFloat32 || - input.dtype() != DataType::kFloat32 || - weight.dtype() != DataType::kFloat32) { - abort(); - } + DispatchFunc( + out.dtype(), + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(input, weight, eps, out); + }, + "`Operator::operator()`"); + } - auto* out_ptr = static_cast(out.data()); - const auto* input_ptr = static_cast(input.data()); - const auto* weight_ptr = static_cast(weight.data()); + private: + template + void Compute(const Tensor input, const Tensor weight, float eps, + Tensor out) const { + auto* out_ptr = static_cast(out.data()); + const auto* input_ptr = static_cast(input.data()); + const auto* weight_ptr = static_cast(weight.data()); auto stride_input_batch = input_strides_.size() > 1 ? input_strides_[0] : 0; auto stride_input_nhead = @@ -36,20 +44,20 @@ class Operator : public RmsNorm { for (Tensor::Size bi = 0; bi < batch_size_; ++bi) { for (Tensor::Size hi = 0; hi < nhead_; ++hi) { - const float* input_row = + const T* input_row = input_ptr + bi * stride_input_batch + hi * stride_input_nhead; - float* out_row = - out_ptr + bi * stride_out_batch + hi * stride_out_nhead; + T* out_row = out_ptr + bi * stride_out_batch + hi * stride_out_nhead; float ss = 0; for (Tensor::Size k = 0; k < dim_; ++k) { - float v = input_row[k]; + float v = Cast(input_row[k]); ss += v * v; } - float rms = 1.f / std::sqrt(ss / static_cast(dim_) + eps_); + float rms = 1.f / std::sqrt(ss / static_cast(dim_) + eps); for (Tensor::Size k = 0; k < dim_; ++k) { - out_row[k] = input_row[k] * weight_ptr[k] * rms; + out_row[k] = Cast(Cast(input_row[k]) * + Cast(weight_ptr[k]) * rms); } } } diff --git a/src/cpu/swiglu/swiglu.h b/src/cpu/swiglu/swiglu.h index a01a4b6..ac2b3b2 100644 --- a/src/cpu/swiglu/swiglu.h +++ b/src/cpu/swiglu/swiglu.h @@ -4,6 +4,7 @@ #include #include "base/swiglu.h" +#include "common/cast.h" #include "common/generic_utils.h" namespace infini::ops { @@ -27,6 +28,9 @@ class Operator : public Swiglu { 
private: template void Compute(const Tensor input, const Tensor gate, Tensor out) const { + using ComputeType = + std::conditional_t || IsFP16, float, T>; + const auto* input_ptr = static_cast(input.data()); const auto* gate_ptr = static_cast(gate.data()); auto* out_ptr = static_cast(out.data()); @@ -44,11 +48,12 @@ class Operator : public Swiglu { gate_strides_.data()); auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), out_strides_.data()); - const T x = input_ptr[input_idx]; - const T sigmoid_x = - static_cast(1.0 / (1.0 + std::exp(-static_cast(x)))); - const T swish_x = x * sigmoid_x; - out_ptr[out_idx] = swish_x * gate_ptr[gate_idx]; + const ComputeType gate_val = Cast(gate_ptr[gate_idx]); + const ComputeType sigmoid_gate = static_cast( + 1.0 / (1.0 + std::exp(-static_cast(gate_val)))); + const ComputeType swish_gate = gate_val * sigmoid_gate; + out_ptr[out_idx] = + Cast(Cast(input_ptr[input_idx]) * swish_gate); } } }; diff --git a/src/cuda/gemm/blas.h b/src/cuda/gemm/blas.h index 1a8f7a4..5f669ca 100644 --- a/src/cuda/gemm/blas.h +++ b/src/cuda/gemm/blas.h @@ -46,14 +46,17 @@ class Blas : public Gemm { Backend::blasGemmStridedBatchedEx( handle_, op_a, op_b, swap_a_and_b_ ? n_ : m_, swap_a_and_b_ ? m_ : n_, - k_, &alpha_value, swap_a_and_b_ ? b.data() : a.data(), Backend::R_32F, + k_, &alpha_value, swap_a_and_b_ ? b.data() : a.data(), + Backend::GetDataType(swap_a_and_b_ ? b.dtype() : a.dtype()), swap_a_and_b_ ? ldb_ : lda_, swap_a_and_b_ ? batch_stride_b_ : batch_stride_a_, - swap_a_and_b_ ? a.data() : b.data(), Backend::R_32F, + swap_a_and_b_ ? a.data() : b.data(), + Backend::GetDataType(swap_a_and_b_ ? a.dtype() : b.dtype()), swap_a_and_b_ ? lda_ : ldb_, swap_a_and_b_ ? 
batch_stride_a_ : batch_stride_b_, &beta_value, - c.data(), Backend::R_32F, ldc_, batch_stride_c_, batch_count_, - Backend::BLAS_COMPUTE_32F_FAST_TF32, Backend::BLAS_GEMM_DEFAULT); + c.data(), Backend::GetDataType(c.dtype()), ldc_, batch_stride_c_, + batch_count_, Backend::GetComputeType(c.dtype()), + Backend::BLAS_GEMM_DEFAULT); } private: diff --git a/src/iluvatar/gemm/cublas.h b/src/iluvatar/gemm/cublas.h index 969df69..cbf287a 100644 --- a/src/iluvatar/gemm/cublas.h +++ b/src/iluvatar/gemm/cublas.h @@ -22,10 +22,16 @@ struct IluvatarBackend { static constexpr auto BLAS_OP_T = CUBLAS_OP_T; + static constexpr auto R_16F = CUDA_R_16F; + + static constexpr auto R_16BF = CUDA_R_16BF; + static constexpr auto R_32F = CUDA_R_32F; // Iluvatar uses `cudaDataType` for `computeType`, so we need to use // `CUDA_R_32F` instead of `CUBLAS_COMPUTE_32F_FAST_TF32`. + static constexpr auto BLAS_COMPUTE_32F = CUDA_R_32F; + static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = CUDA_R_32F; // Iluvatar uses `CUBLAS_GEMM_DEFAULT_TENSOR_OP` instead of @@ -41,6 +47,18 @@ struct IluvatarBackend { static constexpr auto blasGemmStridedBatchedEx = [](auto&&... 
args) { return cublasGemmStridedBatchedEx(std::forward(args)...); }; + + static auto GetDataType(DataType dtype) { + if (dtype == DataType::kFloat16) return R_16F; + if (dtype == DataType::kBFloat16) return R_16BF; + return R_32F; + } + + static auto GetComputeType(DataType dtype) { + if (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16) + return BLAS_COMPUTE_32F; + return BLAS_COMPUTE_32F_FAST_TF32; + } }; } // namespace gemm diff --git a/src/metax/gemm/mcblas.h b/src/metax/gemm/mcblas.h index 1fd6f2f..4d5f313 100644 --- a/src/metax/gemm/mcblas.h +++ b/src/metax/gemm/mcblas.h @@ -22,8 +22,14 @@ struct MetaxBackend { static constexpr auto BLAS_OP_T = MCBLAS_OP_T; + static constexpr auto R_16F = MACA_R_16F; + + static constexpr auto R_16BF = MACA_R_16BF; + static constexpr auto R_32F = MACA_R_32F; + static constexpr auto BLAS_COMPUTE_32F = MCBLAS_COMPUTE_32F; + static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = MCBLAS_COMPUTE_32F_FAST_TF32; @@ -38,6 +44,18 @@ struct MetaxBackend { static constexpr auto blasGemmStridedBatchedEx = [](auto&&... 
args) { return mcblasGemmStridedBatchedEx(std::forward(args)...); }; + + static auto GetDataType(DataType dtype) { + if (dtype == DataType::kFloat16) return R_16F; + if (dtype == DataType::kBFloat16) return R_16BF; + return R_32F; + } + + static auto GetComputeType(DataType dtype) { + if (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16) + return BLAS_COMPUTE_32F; + return BLAS_COMPUTE_32F_FAST_TF32; + } }; } // namespace gemm diff --git a/src/nvidia/gemm/cublas.h b/src/nvidia/gemm/cublas.h index 16c1b7a..eaf3b40 100644 --- a/src/nvidia/gemm/cublas.h +++ b/src/nvidia/gemm/cublas.h @@ -22,8 +22,14 @@ struct NvidiaBackend { static constexpr auto BLAS_OP_T = CUBLAS_OP_T; + static constexpr auto R_16F = CUDA_R_16F; + + static constexpr auto R_16BF = CUDA_R_16BF; + static constexpr auto R_32F = CUDA_R_32F; + static constexpr auto BLAS_COMPUTE_32F = CUBLAS_COMPUTE_32F; + static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = CUBLAS_COMPUTE_32F_FAST_TF32; @@ -38,6 +44,18 @@ struct NvidiaBackend { static constexpr auto blasGemmStridedBatchedEx = [](auto&&... 
args) { return cublasGemmStridedBatchedEx(std::forward(args)...); }; + + static auto GetDataType(DataType dtype) { + if (dtype == DataType::kFloat16) return R_16F; + if (dtype == DataType::kBFloat16) return R_16BF; + return R_32F; + } + + static auto GetComputeType(DataType dtype) { + if (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16) + return BLAS_COMPUTE_32F; + return BLAS_COMPUTE_32F_FAST_TF32; + } }; } // namespace gemm diff --git a/tests/test_add.py b/tests/test_add.py index f5f84b4..afbce0d 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -2,7 +2,16 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randn_strided +from tests.utils import Payload, empty_strided, randint_strided, randn_strided + +_INT_DTYPES = ( + torch.int16, + torch.uint16, + torch.int32, + torch.uint32, + torch.int64, + torch.uint64, +) @pytest.mark.auto_act_and_assert @@ -23,11 +32,28 @@ ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), ), ) -def test_add( - shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol -): - input = randn_strided(shape, input_strides, dtype=dtype, device=device) - other = randn_strided(shape, other_strides, dtype=dtype, device=device) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + (torch.int16, 0, 0), + (torch.uint16, 0, 0), + (torch.int32, 0, 0), + (torch.uint32, 0, 0), + (torch.int64, 0, 0), + (torch.uint64, 0, 0), + ), +) +def test_add(shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol): + if dtype in _INT_DTYPES: + input = randint_strided(0, 100, shape, input_strides, dtype=dtype, device=device) + other = randint_strided(0, 100, shape, other_strides, dtype=dtype, device=device) + else: + input = randn_strided(shape, input_strides, dtype=dtype, device=device) + other = randn_strided(shape, other_strides, dtype=dtype, device=device) + out = 
empty_strided(shape, out_strides, dtype=dtype, device=device) return Payload(_add, _torch_add, (input, other, out), {}, rtol=rtol, atol=atol) @@ -40,4 +66,13 @@ def _add(input, other, out): def _torch_add(input, other, out): - return torch.add(input, other, out=out) + if input.dtype in (torch.uint16, torch.uint32, torch.uint64): + input = input.to(torch.int64) + + if other.dtype in (torch.uint16, torch.uint32, torch.uint64): + other = other.to(torch.int64) + + res = torch.add(input, other) + out.copy_(res.to(out.dtype)) + + return out diff --git a/tests/test_causal_softmax.py b/tests/test_causal_softmax.py index 81a64b4..8b35457 100644 --- a/tests/test_causal_softmax.py +++ b/tests/test_causal_softmax.py @@ -26,9 +26,6 @@ ), ) def test_causal_softmax(shape, input_strides, out_strides, dtype, device, rtol, atol): - if device == "cpu" and dtype in (torch.float16, torch.bfloat16): - pytest.skip("CPU backend does not support fp16/bf16") - input_tensor = randn_strided(shape, input_strides, dtype=dtype, device=device) out = empty_strided(shape, out_strides, dtype=dtype, device=device) diff --git a/tests/test_gemm.py b/tests/test_gemm.py index faee9d5..e5d00ed 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randn_strided +from tests.utils import Payload, randn_strided @pytest.mark.auto_act_and_assert @@ -20,8 +20,14 @@ @pytest.mark.parametrize("beta", (-1, -0.5, 0, 0.5, 1)) @pytest.mark.parametrize("trans_a", (False, True)) @pytest.mark.parametrize("trans_b", (False, True)) -# TODO: Add support for more data types. 
-@pytest.mark.parametrize(("dtype", "rtol", "atol"), ((torch.float32, 1e-3, 1e-3),)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-3, 1e-3), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 1e-2, 1e-2), + ), +) def test_gemm( a_shape, b_shape, @@ -51,7 +57,7 @@ def test_gemm( if trans_b: b = b.transpose(-2, -1) - c = empty_strided(c_shape, c_strides, dtype=dtype, device=device) + c = randn_strided(c_shape, c_strides, dtype=dtype, device=device) return Payload( _gemm, @@ -76,6 +82,12 @@ def _torch_gemm(a, b, alpha=1.0, beta=1.0, trans_a=False, trans_b=False, c=None) if trans_b: b = b.transpose(-2, -1) + # PyTorch `baddbmm`/`addmm` ignores `beta` when `alpha=0.0`. + if alpha == 0: + c.mul_(beta) + + return c + if a.ndim == 2: return torch.addmm(c, a, b, beta=beta, alpha=alpha, out=c) diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py index 12ec7ee..f447091 100644 --- a/tests/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -23,7 +23,7 @@ ( (torch.float32, 1e-4, 1e-4), (torch.float16, 1e-2, 1e-2), - (torch.bfloat16, 1e-2, 1e-2), + (torch.bfloat16, 2e-2, 1e-2), ), ) def test_rms_norm( @@ -38,9 +38,6 @@ def test_rms_norm( rtol, atol, ): - if device == "cpu" and dtype in (torch.float16, torch.bfloat16): - pytest.skip("CPU backend does not support fp16/bf16") - input = randn_strided(input_shape, input_strides, dtype=dtype, device=device) weight = randn_strided(weight_shape, weight_strides, dtype=dtype, device=device) out = empty_strided(input_shape, out_strides, dtype=dtype, device=device) diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py index 700537a..89c95f7 100644 --- a/tests/test_swiglu.py +++ b/tests/test_swiglu.py @@ -30,9 +30,6 @@ def test_swiglu( shape, input_strides, gate_strides, out_strides, dtype, device, rtol, atol ): - if device == "cpu" and dtype in (torch.float16, torch.bfloat16): - pytest.skip("CPU backend does not support fp16/bf16") - input = rand_strided(shape, input_strides, dtype=dtype, 
device=device) gate = rand_strided(shape, gate_strides, dtype=dtype, device=device) out = empty_strided(shape, out_strides, dtype=dtype, device=device) diff --git a/tests/utils.py b/tests/utils.py index 45b1205..78a350f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,7 +21,7 @@ class Payload: def get_available_devices(): - devices = [] + devices = ["cpu"] if torch.cuda.is_available(): devices.append("cuda") @@ -63,6 +63,16 @@ def rand_strided(shape, strides, *, dtype=None, device=None): return output +def randint_strided(low, high, shape, strides, *, dtype=None, device=None): + output = empty_strided(shape, strides, dtype=dtype, device=device) + + output.as_strided( + (output.untyped_storage().size() // output.element_size(),), (1,) + ).random_(low, high) + + return output + + def clone_strided(input): output = empty_strided( input.size(), input.stride(), dtype=input.dtype, device=input.device From 71fc3884b700bdbd58605bf49fc5698ab2304a33 Mon Sep 17 00:00:00 2001 From: zhangyue <138768300+zhangyue207@users.noreply.github.com> Date: Thu, 12 Mar 2026 17:22:02 +0800 Subject: [PATCH 79/93] fix: add equality operators and `CacheKey` `struct` (#18) * fix: add equality operators and CacheKey struct for improved operator caching - Implemented `operator==` and `operator!=` for the `Device` class to facilitate comparison. - Introduced `CacheKey` struct in `operator.h` to enhance caching mechanism with a hash and vector of tensors. - Updated the `Operator::call` method to utilize `CacheKey` for caching operators based on input arguments. - Added `MetaEqual` method in `Tensor` class for tensor comparison based on metadata. * refactor: move CacheKey struct to detail namespace and enhance Tensor comparison - Changed the namespace of `CacheKey` to `infini::ops::detail` for better organization. - Updated the hash and equality operators for `CacheKey` to reflect the new namespace. 
- Removed the `MetaEqual` method from the `Tensor` class and replaced it with a dedicated `std::equal_to` specialization for `Tensor` to improve comparison logic. * style: remove unnecessary blank line in cublas.h for improved readability --- src/device.h | 6 ++++ src/iluvatar/gemm/cublas.h | 2 +- src/operator.h | 67 ++++++++++++++++++++++++++++++++++---- src/tensor.h | 9 +++++ 4 files changed, 77 insertions(+), 7 deletions(-) diff --git a/src/device.h b/src/device.h index 90fae55..5d9b3ee 100644 --- a/src/device.h +++ b/src/device.h @@ -43,6 +43,12 @@ class Device { return std::string{StringFromType(type_)} + ":" + std::to_string(index_); } + bool operator==(const Device& other) const { + return type_ == other.type_ && index_ == other.index_; + } + + bool operator!=(const Device& other) const { return !(*this == other); } + private: Type type_{Type::kCpu}; diff --git a/src/iluvatar/gemm/cublas.h b/src/iluvatar/gemm/cublas.h index cbf287a..310d888 100644 --- a/src/iluvatar/gemm/cublas.h +++ b/src/iluvatar/gemm/cublas.h @@ -31,7 +31,7 @@ struct IluvatarBackend { // Iluvatar uses `cudaDataType` for `computeType`, so we need to use // `CUDA_R_32F` instead of `CUBLAS_COMPUTE_32F_FAST_TF32`. static constexpr auto BLAS_COMPUTE_32F = CUDA_R_32F; - + static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = CUDA_R_32F; // Iluvatar uses `CUBLAS_GEMM_DEFAULT_TENSOR_OP` instead of diff --git a/src/operator.h b/src/operator.h index f40b976..be6fb51 100644 --- a/src/operator.h +++ b/src/operator.h @@ -5,11 +5,66 @@ #include #include #include +#include #include "dispatcher.h" #include "handle.h" #include "tensor.h" +namespace infini::ops::detail { + +struct CacheKey { + std::size_t hash; + + std::vector tensors; + + std::size_t scalar_hash; + + template + static CacheKey Build(const Args&... 
args) { + CacheKey key; + key.hash = 0; + key.scalar_hash = 0; + (key.Absorb(args), ...); + return key; + } + + private: + void Absorb(const Tensor& t) { + hash_combine(hash, t); + tensors.push_back(t); + } + + template + void Absorb(const T& v) { + hash_combine(hash, v); + hash_combine(scalar_hash, v); + } +}; + +} // namespace infini::ops::detail + +template <> +struct std::hash { + std::size_t operator()(const infini::ops::detail::CacheKey& key) const { + return key.hash; + } +}; + +template <> +struct std::equal_to { + bool operator()(const infini::ops::detail::CacheKey& a, + const infini::ops::detail::CacheKey& b) const { + if (a.scalar_hash != b.scalar_hash) return false; + if (a.tensors.size() != b.tensors.size()) return false; + std::equal_to eq; + for (std::size_t i = 0; i < a.tensors.size(); ++i) { + if (!eq(a.tensors[i], b.tensors[i])) return false; + } + return true; + } +}; + namespace infini::ops { class OperatorBase { @@ -65,16 +120,16 @@ class Operator : public OperatorBase { template static auto call(const Handle& handle, void* stream, void* workspace, std::size_t workspace_size_in_bytes, Args&&... 
args) { - static std::unordered_map> cache; - - std::size_t hash{0}; + static std::unordered_map> + cache; - (hash_combine(hash, args), ...); + auto key = detail::CacheKey::Build(args...); - auto it{cache.find(hash)}; + auto it{cache.find(key)}; if (it == cache.end()) { - it = cache.emplace(hash, make(std::forward(args)...)).first; + it = cache.emplace(std::move(key), make(std::forward(args)...)) + .first; } auto& op{it->second}; diff --git a/src/tensor.h b/src/tensor.h index cc2bf9e..bbe72f8 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -145,4 +145,13 @@ struct std::hash { } }; +template <> +struct std::equal_to { + bool operator()(const infini::ops::Tensor& a, + const infini::ops::Tensor& b) const { + return a.dtype() == b.dtype() && a.device() == b.device() && + a.shape() == b.shape() && a.strides() == b.strides(); + } +}; + #endif From d094e1010af77cdae61fa2a5415294b1db892e4b Mon Sep 17 00:00:00 2001 From: gongchensu Date: Tue, 17 Mar 2026 16:09:36 +0800 Subject: [PATCH 80/93] feat(gemm-moore): add Moore GEMM backend support (#14) * feat(gemm-moore): add Moore (MUSA) GEMM backend support. * refactor(gemm-moore): reuse shared BLAS helper and specialize scalars. * build: use detected Python interpreter for wrapper generation. 
--------- Co-authored-by: zhuyue --- CMakeLists.txt | 31 +++++++++++++- examples/gemm/gemm.cc | 3 ++ examples/runtime_api.h | 9 +++++ src/CMakeLists.txt | 10 ++++- src/cuda/gemm/blas.h | 20 ++++++--- src/moore/gemm/mublas.h | 89 +++++++++++++++++++++++++++++++++++++++++ tests/test_gemm.py | 16 ++++++-- tests/utils.py | 3 ++ 8 files changed, 171 insertions(+), 10 deletions(-) create mode 100644 src/moore/gemm/mublas.h diff --git a/CMakeLists.txt b/CMakeLists.txt index a312238..18e98c9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,6 +13,7 @@ option(WITH_NVIDIA "Enable CUDA backend" OFF) option(WITH_ILUVATAR "Enable Iluvatar GPU backend" OFF) option(WITH_METAX "Enable MetaX backend" OFF) option(WITH_CAMBRICON "Enable Cambricon backend" OFF) +option(WITH_MOORE "Enable Moore backend" OFF) option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF) option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF) @@ -61,6 +62,34 @@ if(AUTO_DETECT_DEVICES) set(WITH_CAMBRICON ON) message(STATUS "Auto-detected Cambricon environment.") endif() + + if(DEFINED ENV{MUSA_ROOT} OR DEFINED ENV{MUSA_HOME} OR DEFINED ENV{MUSA_PATH}) + set(WITH_MOORE ON) + set(WITH_MOORE ON CACHE BOOL "Enable Moore backend" FORCE) + message(STATUS "Auto-detected Moore environment.") + else() + set(WITH_MOORE OFF) + set(WITH_MOORE OFF CACHE BOOL "Enable Moore backend" FORCE) + endif() +endif() + +if(WITH_MOORE) + set(MUSA_ROOT $ENV{MUSA_ROOT} $ENV{MUSA_HOME} $ENV{MUSA_PATH}) + list(FILTER MUSA_ROOT EXCLUDE REGEX "^$") + list(GET MUSA_ROOT 0 MUSA_ROOT) + if(NOT MUSA_ROOT) + message(FATAL_ERROR "`WITH_MOORE` is `ON` but `MUSA_ROOT`/`MUSA_HOME`/`MUSA_PATH` is not set.") + endif() + message(STATUS "Using Moore from `${MUSA_ROOT}`.") + list(PREPEND CMAKE_MODULE_PATH "${MUSA_ROOT}/cmake") + set(MUSA_TOOLKIT_ROOT_DIR "${MUSA_ROOT}" CACHE PATH "Toolkit location." 
FORCE) + find_package(MUSA REQUIRED) + add_compile_definitions(WITH_MOORE=1) + include_directories("${MUSA_ROOT}/include") + link_directories("${MUSA_ROOT}/lib") + find_library(MUSA_LIB NAMES musa HINTS "${MUSA_ROOT}/lib" REQUIRED) + find_library(MUSART_LIB NAMES musart HINTS "${MUSA_ROOT}/lib" REQUIRED) + find_library(MUBLAS_LIB NAMES mublas HINTS "${MUSA_ROOT}/lib" REQUIRED) endif() include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) @@ -127,7 +156,7 @@ if(WITH_CAMBRICON) endif() # If all other platforms are not enabled, CPU is enabled by default. -if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX) +if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE) add_compile_definitions(WITH_CPU=1) endif() diff --git a/examples/gemm/gemm.cc b/examples/gemm/gemm.cc index bb82890..4664740 100644 --- a/examples/gemm/gemm.cc +++ b/examples/gemm/gemm.cc @@ -17,6 +17,9 @@ #if WITH_CAMBRICON #include "cambricon/gemm/cnblas.h" #endif +#if WITH_MOORE +#include "moore/gemm/mublas.h" +#endif #include "runtime_api.h" #include "tensor.h" diff --git a/examples/runtime_api.h b/examples/runtime_api.h index b56a8fd..c5b7597 100644 --- a/examples/runtime_api.h +++ b/examples/runtime_api.h @@ -37,6 +37,15 @@ #define DEVICE_MEMCPY_HOST_TO_DEVICE cnrtMemcpyHostToDev #define DEVICE_MEMCPY_DEVICE_TO_HOST cnrtMemcpyDevToHost #define DEFAULT_DEVICE_TYPE Device::Type::kCambricon +#elif WITH_MOORE +#include +#define DEVICE_MALLOC musaMalloc +#define DEVICE_FREE musaFree +#define DEVICE_MEMCPY musaMemcpy +#define DEVICE_MEMSET musaMemset +#define DEVICE_MEMCPY_HOST_TO_DEVICE musaMemcpyHostToDevice +#define DEVICE_MEMCPY_DEVICE_TO_HOST musaMemcpyDeviceToHost +#define DEFAULT_DEVICE_TYPE Device::Type::kMoore #elif WITH_CPU #include #include diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6eef5d3..3a04144 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -112,11 +112,19 @@ if(WITH_CAMBRICON) list(APPEND DEVICE_LIST "cambricon") endif() 
+if(WITH_MOORE) + target_compile_definitions(infiniops PUBLIC WITH_MOORE=1) + target_include_directories(infiniops PUBLIC "${MUSA_ROOT}/include") + target_link_libraries(infiniops PUBLIC ${MUSA_LIB} ${MUSART_LIB} ${MUBLAS_LIB}) + list(APPEND DEVICE_LIST "moore") +endif() + target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) if(GENERATE_PYTHON_BINDINGS) + find_package(Python COMPONENTS Interpreter REQUIRED) execute_process( - COMMAND python ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py --devices ${DEVICE_LIST} + COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py --devices ${DEVICE_LIST} WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} RESULT_VARIABLE script_result ) diff --git a/src/cuda/gemm/blas.h b/src/cuda/gemm/blas.h index 5f669ca..5a7cf2f 100644 --- a/src/cuda/gemm/blas.h +++ b/src/cuda/gemm/blas.h @@ -43,20 +43,30 @@ class Blas : public Gemm { const auto& trans_b_value{trans_b.value_or(trans_b_)}; auto op_a{GetOpA(trans_a_value, trans_b_value)}; auto op_b{GetOpB(trans_a_value, trans_b_value)}; + const void* alpha_ptr{GetAlphaPtr(alpha_value, c.dtype())}; + const void* beta_ptr{GetBetaPtr(beta_value, c.dtype())}; Backend::blasGemmStridedBatchedEx( handle_, op_a, op_b, swap_a_and_b_ ? n_ : m_, swap_a_and_b_ ? m_ : n_, - k_, &alpha_value, swap_a_and_b_ ? b.data() : a.data(), + k_, alpha_ptr, swap_a_and_b_ ? b.data() : a.data(), Backend::GetDataType(swap_a_and_b_ ? b.dtype() : a.dtype()), swap_a_and_b_ ? ldb_ : lda_, swap_a_and_b_ ? batch_stride_b_ : batch_stride_a_, swap_a_and_b_ ? a.data() : b.data(), Backend::GetDataType(swap_a_and_b_ ? a.dtype() : b.dtype()), swap_a_and_b_ ? lda_ : ldb_, - swap_a_and_b_ ? batch_stride_a_ : batch_stride_b_, &beta_value, - c.data(), Backend::GetDataType(c.dtype()), ldc_, batch_stride_c_, - batch_count_, Backend::GetComputeType(c.dtype()), - Backend::BLAS_GEMM_DEFAULT); + swap_a_and_b_ ? 
batch_stride_a_ : batch_stride_b_, beta_ptr, c.data(), + Backend::GetDataType(c.dtype()), ldc_, batch_stride_c_, batch_count_, + Backend::GetComputeType(c.dtype()), Backend::BLAS_GEMM_DEFAULT); + } + + protected: + virtual const void* GetAlphaPtr(const float& alpha, DataType) const { + return α + } + + virtual const void* GetBetaPtr(const float& beta, DataType) const { + return β } private: diff --git a/src/moore/gemm/mublas.h b/src/moore/gemm/mublas.h new file mode 100644 index 0000000..8ec931f --- /dev/null +++ b/src/moore/gemm/mublas.h @@ -0,0 +1,89 @@ +#ifndef INFINI_OPS_MOORE_GEMM_MUBLAS_H_ +#define INFINI_OPS_MOORE_GEMM_MUBLAS_H_ + +#include +#include + +#include + +#include "cuda/gemm/blas.h" + +namespace infini::ops { + +namespace gemm { + +struct MooreBackend { + using blasHandle_t = mublasHandle_t; + + using stream_t = musaStream_t; + + static constexpr auto BLAS_OP_N = MUBLAS_OP_N; + + static constexpr auto BLAS_OP_T = MUBLAS_OP_T; + + static constexpr auto R_16F = MUSA_R_16F; + + static constexpr auto R_16BF = MUSA_R_16BF; + + static constexpr auto R_32F = MUSA_R_32F; + + static constexpr auto BLAS_GEMM_DEFAULT = MUBLAS_GEMM_DEFAULT; + + static constexpr auto blasCreate = mublasCreate; + + static constexpr auto blasSetStream = mublasSetStream; + + static constexpr auto blasDestroy = mublasDestroy; + + static constexpr auto blasGemmStridedBatchedEx = [](auto&&... 
args) { + return mublasGemmStridedBatchedEx(std::forward(args)...); + }; + + static musaDataType_t GetDataType(DataType dtype) { + if (dtype == DataType::kFloat16) return R_16F; + if (dtype == DataType::kBFloat16) return R_16BF; + return R_32F; + } + + static mublasComputeType_t GetComputeType(DataType dtype) { + if (dtype == DataType::kFloat16) return MUBLAS_COMPUTE_16F; + if (dtype == DataType::kBFloat16) return MUBLAS_COMPUTE_32F; + return MUBLAS_COMPUTE_32F; + } +}; + +} // namespace gemm + +template <> +class Operator : public Blas { + public: + using Blas::Blas; + + protected: + const void* GetAlphaPtr(const float& alpha, DataType dtype) const override { + if (gemm::MooreBackend::GetComputeType(dtype) == MUBLAS_COMPUTE_16F) { + alpha_fp16_ = Float16::FromFloat(alpha); + return &alpha_fp16_; + } + + return α + } + + const void* GetBetaPtr(const float& beta, DataType dtype) const override { + if (gemm::MooreBackend::GetComputeType(dtype) == MUBLAS_COMPUTE_16F) { + beta_fp16_ = Float16::FromFloat(beta); + return &beta_fp16_; + } + + return β + } + + private: + mutable Float16 alpha_fp16_{}; + + mutable Float16 beta_fp16_{}; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_gemm.py b/tests/test_gemm.py index e5d00ed..43a47b6 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -88,7 +88,17 @@ def _torch_gemm(a, b, alpha=1.0, beta=1.0, trans_a=False, trans_b=False, c=None) return c - if a.ndim == 2: - return torch.addmm(c, a, b, beta=beta, alpha=alpha, out=c) + # Some backends (e.g. `torch_musa`) may reject `addmm`/`baddbmm(out=...)` + # for certain strided outputs. Fall back to `matmul` plus fused `alpha`/`beta` + # update to keep reference coverage. 
+ try: + if a.ndim == 2: + return torch.addmm(c, a, b, beta=beta, alpha=alpha, out=c) + + return torch.baddbmm(c, a, b, beta=beta, alpha=alpha, out=c) + except RuntimeError: + c_original = c.clone() + torch.matmul(a, b, out=c) + c.mul_(alpha).add_(c_original, alpha=beta) - return torch.baddbmm(c, a, b, beta=beta, alpha=alpha, out=c) + return c diff --git a/tests/utils.py b/tests/utils.py index 78a350f..aa4ee42 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -29,6 +29,9 @@ def get_available_devices(): if hasattr(torch, "mlu") and torch.mlu.is_available(): devices.append("mlu") + if hasattr(torch, "musa") and torch.musa.is_available(): + devices.append("musa") + return tuple(devices) From 2650cd964715f4d94060b681c7a658383bfa9742 Mon Sep 17 00:00:00 2001 From: zhangyue <138768300+zhangyue207@users.noreply.github.com> Date: Wed, 18 Mar 2026 17:47:53 +0800 Subject: [PATCH 81/93] feat: optimize `BLOCK_SIZE` for CUDA kernels and support Iluvatar `SwigLU` (#17) * refactor: reorganize casting utilities and enhance CUDA kernel support - Moved CPU casting functions to a new file `common/cpu/cast.h` and updated the `Cast` function to utilize these utilities. - Updated CUDA kernel files to include the new casting utilities and improved block size handling in kernel launches. - Enhanced the `Add`, `CausalSoftmax`, `Gemm`, `RmsNorm`, and `Swiglu` operators to utilize the new casting mechanisms for better type handling. - Added support for additional data types in tests and adjusted test cases for consistency across CPU and GPU backends. * refactor: improve formatting * refactor: cache cudaDeviceProp per device via DevicePropertyCache Introduce DevicePropertyCache to query and cache all device properties once at first access, avoiding repeated cudaGetDeviceProperties calls. QueryMaxThreadsPerBlock and GetOptimalBlockSize are simplified to delegate to the cache. 
Also move block_size out of dispatch lambdas in add and swiglu kernels since it does not depend on the dispatched type. Co-Authored-By: Claude Sonnet 4.6 * perf: reduce redundant ops in Add, SwiGLU, and RmsNorm CUDA kernels - Add __restrict__ to all pointer params in AddKernel and SwigluKernel to enable compiler alias analysis, vectorization, and prefetch - Remove dead for-loop in Add/SwiGLU kernel launch (step >= output_size_ by construction, loop body always executed exactly once); drop offset param - Inline sigmoid in SwiGLU bfloat16/bfloat162 paths to eliminate redundant bf16<->float round-trips (8 -> 4 conversions for bfloat162, 2 -> 1 for bfloat16) - Use a temp variable in RmsNorm SumSquared to guarantee single global load Co-Authored-By: Claude Sonnet 4.6 * fix: abort with diagnostic on out-of-range device_id in GetDeviceProps Returning a default-constructed dummy cudaDeviceProp silently propagated incorrect device properties; now print an explicit error and abort so the bug is immediately visible. Co-Authored-By: Claude Sonnet 4.6 * style: improve code formatting and comments in CUDA kernel files - Updated comments for clarity in `kernel_commons.h`. - Reformatted kernel launch macro in `kernel.h` for better readability. - Enhanced line breaks in `SwigluKernel` implementation for improved code structure. - Adjusted test function formatting in `test_add.py` for consistency. * refactor: remove unnecessary cuda_runtime.h includes in kernel headers - Removed redundant `#include ` from `add`, `causal_softmax`, `rms_norm`, and `swiglu` kernel header files. - Added TODO comments for future removal of the remaining includes to improve code clarity and maintainability. 
--------- Co-authored-by: Claude Sonnet 4.6 --- src/common/cuda/kernel_commons.h | 62 +++++++++++++++++++------ src/cuda/add/kernel.cuh | 56 ++++++++++++++++++++++ src/cuda/add/kernel.h | 79 ++++++++++---------------------- src/cuda/causal_softmax/kernel.h | 39 ++++++++++------ src/cuda/rms_norm/kernel.cuh | 3 +- src/cuda/rms_norm/kernel.h | 40 ++++++++++------ src/cuda/swiglu/kernel.cuh | 35 +++++++------- src/cuda/swiglu/kernel.h | 23 ++++------ src/iluvatar/swiglu/kernel.h | 41 +++++++++++++++++ tests/test_add.py | 12 +++-- 10 files changed, 261 insertions(+), 129 deletions(-) create mode 100644 src/cuda/add/kernel.cuh create mode 100644 src/iluvatar/swiglu/kernel.h diff --git a/src/common/cuda/kernel_commons.h b/src/common/cuda/kernel_commons.h index 43cce8f..3c85031 100644 --- a/src/common/cuda/kernel_commons.h +++ b/src/common/cuda/kernel_commons.h @@ -8,13 +8,21 @@ using cuda_bfloat16 = nv_bfloat16; using cuda_bfloat162 = nv_bfloat162; #elif defined(WITH_ILUVATAR) +#include +#include #include -#elif WITH_METAX // TODO: Use `defined`. +using cuda_bfloat16 = nv_bfloat16; +using cuda_bfloat162 = nv_bfloat162; +#elif defined(WITH_METAX) #include using cuda_bfloat16 = maca_bfloat16; using cuda_bfloat162 = maca_bfloat162; #endif +#include +#include +#include + #include "cast.h" namespace infini::ops { @@ -23,27 +31,55 @@ constexpr int CUDA_BLOCK_SIZE_128 = 128; constexpr int CUDA_BLOCK_SIZE_256 = 256; constexpr int CUDA_BLOCK_SIZE_512 = 512; constexpr int CUDA_BLOCK_SIZE_1024 = 1024; +constexpr int CUDA_BLOCK_SIZE_2048 = 2048; + +#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) +// Cache `cudaDeviceProp` per device, initialized once at first access. 
+class DevicePropertyCache { + public: + static const cudaDeviceProp& GetCurrentDeviceProps() { + int device_id = 0; + cudaGetDevice(&device_id); + return GetDeviceProps(device_id); + } + + static const cudaDeviceProp& GetDeviceProps(int device_id) { + static std::vector cache = []() { + int count = 0; + cudaGetDeviceCount(&count); + if (count == 0) return std::vector{}; + std::vector props(count); + for (int i = 0; i < count; ++i) { + cudaGetDeviceProperties(&props[i], i); + } + return props; + }(); + + if (device_id < 0 || device_id >= static_cast(cache.size())) { + std::cerr << "error: `device_id` " << device_id << " is out of range [0, " + << cache.size() << ") in `GetDeviceProps`\n"; + std::abort(); + } + return cache[device_id]; + } +}; -// Query the maximum threads per block for the current CUDA device. inline int QueryMaxThreadsPerBlock() { -#ifdef WITH_NVIDIA - int device = 0; - cudaGetDevice(&device); - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, device); - return prop.maxThreadsPerBlock; -#elif WITH_METAX + return DevicePropertyCache::GetCurrentDeviceProps().maxThreadsPerBlock; +} +#elif defined(WITH_METAX) +inline int QueryMaxThreadsPerBlock() { // TODO: Add MCR device properties query for Metax. return CUDA_BLOCK_SIZE_256; -#endif } +#endif // Get optimal block size based on GPU hardware architecture. inline int GetOptimalBlockSize() { int max_threads = QueryMaxThreadsPerBlock(); - - // Select the largest supported block size for better performance. 
- if (max_threads >= CUDA_BLOCK_SIZE_1024) { + if (max_threads >= CUDA_BLOCK_SIZE_2048) { + return CUDA_BLOCK_SIZE_2048; + } else if (max_threads >= CUDA_BLOCK_SIZE_1024) { return CUDA_BLOCK_SIZE_1024; } else if (max_threads >= CUDA_BLOCK_SIZE_512) { return CUDA_BLOCK_SIZE_512; diff --git a/src/cuda/add/kernel.cuh b/src/cuda/add/kernel.cuh new file mode 100644 index 0000000..2d58809 --- /dev/null +++ b/src/cuda/add/kernel.cuh @@ -0,0 +1,56 @@ +#ifndef INFINI_OPS_CUDA_ADD_KERNEL_CUH_ +#define INFINI_OPS_CUDA_ADD_KERNEL_CUH_ + +#include "common/cuda/kernel_commons.h" + +namespace infini::ops { + +struct AddOp { + static constexpr std::size_t num_inputs = 2; + + template + __device__ __forceinline__ T operator()(const T& input, + const T& other) const { + if constexpr (std::is_same_v) { + return __hadd2(input, other); + } else if constexpr (std::is_same_v || + std::is_same_v>) { + return __hadd(input, other); + } else if constexpr (std::is_same_v) { + return __fadd_rn(input, other); + } else { + return input + other; + } + } +}; + +template +__global__ void AddKernel(T* __restrict__ out, const T* __restrict__ input, + const T* __restrict__ other, + const size_t* __restrict__ out_shape, + const size_t* __restrict__ input_shape, + const size_t* __restrict__ other_shape, + const ptrdiff_t* __restrict__ out_strides, + const ptrdiff_t* __restrict__ input_strides, + const ptrdiff_t* __restrict__ other_strides, + size_t output_size, size_t ndim, bool out_contiguous, + bool input_contiguous, bool other_contiguous) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx < output_size) { + size_t out_idx = + out_contiguous ? idx : IndexToOffset(idx, ndim, out_shape, out_strides); + size_t input_idx = + input_contiguous ? idx + : IndexToOffset(idx, ndim, input_shape, input_strides); + size_t other_idx = + other_contiguous ? 
idx + : IndexToOffset(idx, ndim, other_shape, other_strides); + + out[out_idx] = AddOp{}(input[input_idx], other[other_idx]); + } +} + +} // namespace infini::ops + +#endif diff --git a/src/cuda/add/kernel.h b/src/cuda/add/kernel.h index b481255..c174afb 100644 --- a/src/cuda/add/kernel.h +++ b/src/cuda/add/kernel.h @@ -1,57 +1,14 @@ #ifndef INFINI_OPS_CUDA_ADD_KERNEL_H_ #define INFINI_OPS_CUDA_ADD_KERNEL_H_ -#include +#include #include "base/add.h" -#include "common/cuda/kernel_commons.h" #include "common/generic_utils.h" +#include "cuda/add/kernel.cuh" namespace infini::ops { -typedef struct AddOp { - public: - static constexpr std::size_t num_inputs = 2; - template - __device__ __forceinline__ T operator()(const T& input, - const T& other) const { - if constexpr (std::is_same_v) { - return __hadd2(input, other); - } else if constexpr (std::is_same_v || - std::is_same_v>) { - return __hadd(input, other); - } else if constexpr (std::is_same_v) { - return __fadd_rn(input, other); - } else { - return input + other; - } - } -} AddOp; - -template -__global__ void AddKernel( - T* out, const T* input, const T* other, const Tensor::Size* out_shape, - const Tensor::Size* input_shape, const Tensor::Size* other_shape, - const Tensor::Stride* out_strides, const Tensor::Stride* input_strides, - const Tensor::Stride* other_strides, size_t output_size, size_t ndim, - size_t offset, bool out_contiguous, bool input_contiguous, - bool other_contiguous) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; - - if (idx < output_size) { - Tensor::Size out_idx = - out_contiguous ? idx : IndexToOffset(idx, ndim, out_shape, out_strides); - Tensor::Size input_idx = - input_contiguous ? idx - : IndexToOffset(idx, ndim, input_shape, input_strides); - Tensor::Size other_idx = - other_contiguous ? 
idx - : IndexToOffset(idx, ndim, other_shape, other_strides); - - out[out_idx] = AddOp{}(input[input_idx], other[other_idx]); - } -} - template class CudaAdd : public Add { public: @@ -92,28 +49,40 @@ class CudaAdd : public Add { void operator()(const Tensor input, const Tensor other, Tensor out) const override { + int block_size = GetOptimalBlockSize(); DispatchFunc( out_type_, [&](auto tag) { using T = typename decltype(tag)::type; - // TODO(lzm): currently hard-code block_size to be 256. + auto cuda_stream = + static_cast(stream_ ? stream_ : 0); dim3 blockDims( - std::min(static_cast(256), output_size_)); + std::min(static_cast(block_size), output_size_)); dim3 gridDims(utils::CeilDiv(output_size_, blockDims.x)); - size_t step = gridDims.x * blockDims.x; T* d_out = reinterpret_cast(out.data()); const T* d_input = reinterpret_cast(input.data()); const T* d_other = reinterpret_cast(other.data()); - for (size_t i = 0; i < output_size_; i += step) { - AddKernel<<(stream_)>>>( - d_out, d_input, d_other, d_out_shape_, d_input_shape_, - d_other_shape_, d_out_strides_, d_input_strides_, - d_other_strides_, output_size_, ndim_, i, is_out_contiguous_, - is_input_contiguous_, is_other_contiguous_); +#define LAUNCH_ADD_KERNEL(BLOCK_SIZE) \ + AddKernel<<>>( \ + d_out, d_input, d_other, d_out_shape_, d_input_shape_, d_other_shape_, \ + d_out_strides_, d_input_strides_, d_other_strides_, output_size_, ndim_, \ + is_out_contiguous_, is_input_contiguous_, is_other_contiguous_); + + if (block_size == CUDA_BLOCK_SIZE_2048) { + LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_2048) + } else if (block_size == CUDA_BLOCK_SIZE_1024) { + LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_1024) + } else if (block_size == CUDA_BLOCK_SIZE_512) { + LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_512) + } else if (block_size == CUDA_BLOCK_SIZE_256) { + LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_256) + } else { + LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_128) } + +#undef LAUNCH_ADD_KERNEL }, "CudaAdd::operator()"); } diff --git 
a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h index 610b042..23a040a 100644 --- a/src/cuda/causal_softmax/kernel.h +++ b/src/cuda/causal_softmax/kernel.h @@ -4,22 +4,17 @@ #include // clang-format off -#include +#include // TODO: Remove this // clang-format on #include "base/causal_softmax.h" +#include "common/cuda/kernel_commons.h" #include "cuda/causal_softmax/kernel.cuh" #include "data_type.h" #include "dispatcher.h" namespace infini::ops { -namespace causal_softmax { - -constexpr unsigned int kBlockSize = 256; - -} // namespace causal_softmax - template class CudaCausalSoftmax : public CausalSoftmax { public: @@ -41,16 +36,34 @@ class CudaCausalSoftmax : public CausalSoftmax { std::abort(); } + int block_size = GetOptimalBlockSize(); + DispatchFunc( out.dtype(), [&](auto tag) { using T = typename decltype(tag)::type; - CausalSoftmaxKernel - <<>>( - reinterpret_cast(out.data()), - reinterpret_cast(input.data()), batch_size_, - seq_len_, total_seq_len_, stride_out_batch, stride_out_row, - stride_input_batch, stride_input_row); + +#define LAUNCH_CAUSAL_SOFTMAX_KERNEL(BLOCK_SIZE) \ + CausalSoftmaxKernel \ + <<>>( \ + reinterpret_cast(out.data()), \ + reinterpret_cast(input.data()), batch_size_, seq_len_, \ + total_seq_len_, stride_out_batch, stride_out_row, \ + stride_input_batch, stride_input_row); + + if (block_size == CUDA_BLOCK_SIZE_2048) { + LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_2048) + } else if (block_size == CUDA_BLOCK_SIZE_1024) { + LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_1024) + } else if (block_size == CUDA_BLOCK_SIZE_512) { + LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_512) + } else if (block_size == CUDA_BLOCK_SIZE_256) { + LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_256) + } else { + LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_128) + } + +#undef LAUNCH_CAUSAL_SOFTMAX_KERNEL }, "CudaCausalSoftmax::operator()"); } diff --git a/src/cuda/rms_norm/kernel.cuh b/src/cuda/rms_norm/kernel.cuh index 09f20a8..98383f3 100644 
--- a/src/cuda/rms_norm/kernel.cuh +++ b/src/cuda/rms_norm/kernel.cuh @@ -17,7 +17,8 @@ __device__ __forceinline__ Compute SumSquared(const Data* data_ptr, size_t count) { Compute ss = 0; for (size_t i = threadIdx.x; i < count; i += block_size) { - ss += Compute(data_ptr[i]) * Compute(data_ptr[i]); + Compute val = Compute(data_ptr[i]); + ss += val * val; } using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h index a2e27f2..f450e91 100644 --- a/src/cuda/rms_norm/kernel.h +++ b/src/cuda/rms_norm/kernel.h @@ -4,22 +4,17 @@ #include // clang-format off -#include +#include // TODO: Remove this // clang-format on #include "base/rms_norm.h" +#include "common/cuda/kernel_commons.h" #include "cuda/rms_norm/kernel.cuh" #include "data_type.h" #include "dispatcher.h" namespace infini::ops { -namespace { - -constexpr unsigned int kBlockSize = 256; - -} // namespace - template class CudaRmsNorm : public RmsNorm { public: @@ -43,17 +38,34 @@ class CudaRmsNorm : public RmsNorm { std::abort(); } + int block_size = GetOptimalBlockSize(); + DispatchFunc( out.dtype(), [&](auto tag) { using T = typename decltype(tag)::type; - RmsNormKernel - <<>>( - reinterpret_cast(out.data()), stride_out_batch, - stride_out_nhead, reinterpret_cast(input.data()), - stride_input_batch, stride_input_nhead, - reinterpret_cast(weight.data()), nhead_, dim_, - eps_); + +#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \ + RmsNormKernel \ + <<>>( \ + reinterpret_cast(out.data()), stride_out_batch, \ + stride_out_nhead, reinterpret_cast(input.data()), \ + stride_input_batch, stride_input_nhead, \ + reinterpret_cast(weight.data()), nhead_, dim_, eps_); + + if (block_size == CUDA_BLOCK_SIZE_2048) { + LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_2048) + } else if (block_size == CUDA_BLOCK_SIZE_1024) { + LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_1024) + } else if (block_size == CUDA_BLOCK_SIZE_512) { + 
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_512) + } else if (block_size == CUDA_BLOCK_SIZE_256) { + LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_256) + } else { + LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_128) + } + +#undef LAUNCH_RMS_NORM_KERNEL }, "CudaRmsNorm::operator()"); } diff --git a/src/cuda/swiglu/kernel.cuh b/src/cuda/swiglu/kernel.cuh index f404450..8004b76 100644 --- a/src/cuda/swiglu/kernel.cuh +++ b/src/cuda/swiglu/kernel.cuh @@ -33,15 +33,18 @@ __device__ __forceinline__ T Sigmoid(const T& x) { // SwiGLU(x, gate) = Swish(x) * gate = (x * sigmoid(x)) * gate. template -__global__ void SwigluKernel(T* out, const T* a, const T* b, - const size_t* out_shape, const size_t* input_shape, - const size_t* gate_shape, - const ptrdiff_t* out_strides, - const ptrdiff_t* input_strides, - const ptrdiff_t* gate_strides, size_t output_size, - size_t ndim, size_t offset, bool out_contiguous, - bool input_contiguous, bool gate_contiguous) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; +__global__ void SwigluKernel(T* __restrict__ out, const T* __restrict__ a, + const T* __restrict__ b, + const size_t* __restrict__ out_shape, + const size_t* __restrict__ input_shape, + const size_t* __restrict__ gate_shape, + const ptrdiff_t* __restrict__ out_strides, + const ptrdiff_t* __restrict__ input_strides, + const ptrdiff_t* __restrict__ gate_strides, + size_t output_size, size_t ndim, + bool out_contiguous, bool input_contiguous, + bool gate_contiguous) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < output_size) { size_t out_idx, input_idx, gate_idx; @@ -74,21 +77,19 @@ __global__ void SwigluKernel(T* out, const T* a, const T* b, // Optimized `half` precision computation. 
out[out_idx] = __hmul(__hmul(gate, Sigmoid(gate)), up); } else if constexpr (std::is_same_v) { - cuda_bfloat162 sig = Sigmoid(gate); float gate0 = __bfloat162float(__low2bfloat16(gate)); float gate1 = __bfloat162float(__high2bfloat16(gate)); - float sig0 = __bfloat162float(__low2bfloat16(sig)); - float sig1 = __bfloat162float(__high2bfloat16(sig)); float up0 = __bfloat162float(__low2bfloat16(up)); float up1 = __bfloat162float(__high2bfloat16(up)); - float res0 = __fmul_rn(__fmul_rn(gate0, sig0), up0); - float res1 = __fmul_rn(__fmul_rn(gate1, sig1), up1); - out[out_idx] = __floats2bfloat162_rn(res0, res1); + float sig0 = __frcp_rn(__fadd_rn(1.0f, __expf(-gate0))); + float sig1 = __frcp_rn(__fadd_rn(1.0f, __expf(-gate1))); + out[out_idx] = + __floats2bfloat162_rn(__fmul_rn(__fmul_rn(gate0, sig0), up0), + __fmul_rn(__fmul_rn(gate1, sig1), up1)); } else if constexpr (std::is_same_v) { - cuda_bfloat16 sig = Sigmoid(gate); float gatef = __bfloat162float(gate); - float sigf = __bfloat162float(sig); float upf = __bfloat162float(up); + float sigf = __frcp_rn(__fadd_rn(1.0f, __expf(-gatef))); out[out_idx] = __float2bfloat16_rn(__fmul_rn(__fmul_rn(gatef, sigf), upf)); } else if constexpr (std::is_same_v) { diff --git a/src/cuda/swiglu/kernel.h b/src/cuda/swiglu/kernel.h index 7c459a6..d05a1af 100644 --- a/src/cuda/swiglu/kernel.h +++ b/src/cuda/swiglu/kernel.h @@ -4,7 +4,7 @@ #include // clang-format off -#include +#include // TODO: Remove this // clang-format on #include "base/swiglu.h" @@ -53,33 +53,30 @@ class CudaSwiglu : public Swiglu { void operator()(const Tensor input, const Tensor gate, Tensor out) const override { + int block_size = GetOptimalBlockSize(); DispatchFunc( out_type_, [&](auto tag) { using T = typename decltype(tag)::type; auto cuda_stream = static_cast(stream_ ? 
stream_ : 0); - int block_size = GetOptimalBlockSize(); dim3 blockDims( std::min(static_cast(block_size), output_size_)); dim3 gridDims(utils::CeilDiv(output_size_, blockDims.x)); - size_t step = gridDims.x * blockDims.x; T* d_out = reinterpret_cast(out.data()); const T* d_input = reinterpret_cast(input.data()); const T* d_gate = reinterpret_cast(gate.data()); // Launch kernel with appropriate block size based on GPU architecture. -#define LAUNCH_SWIGLU_KERNEL(BLOCK_SIZE) \ - for (size_t i = 0; i < output_size_; i += step) { \ - SwigluKernel<<>>( \ - d_out, d_input, d_gate, d_out_shape_, d_input_shape_, d_gate_shape_, \ - d_out_strides_, d_input_strides_, d_gate_strides_, output_size_, \ - ndim_, i, is_out_contiguous_, is_input_contiguous_, \ - is_gate_contiguous_); \ - } - - if (block_size == CUDA_BLOCK_SIZE_1024) { +#define LAUNCH_SWIGLU_KERNEL(BLOCK_SIZE) \ + SwigluKernel<<>>( \ + d_out, d_input, d_gate, d_out_shape_, d_input_shape_, d_gate_shape_, \ + d_out_strides_, d_input_strides_, d_gate_strides_, output_size_, ndim_, \ + is_out_contiguous_, is_input_contiguous_, is_gate_contiguous_); + if (block_size == CUDA_BLOCK_SIZE_2048) { + LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_2048) + } else if (block_size == CUDA_BLOCK_SIZE_1024) { LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_1024) } else if (block_size == CUDA_BLOCK_SIZE_512) { LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_512) diff --git a/src/iluvatar/swiglu/kernel.h b/src/iluvatar/swiglu/kernel.h new file mode 100644 index 0000000..cf5310c --- /dev/null +++ b/src/iluvatar/swiglu/kernel.h @@ -0,0 +1,41 @@ +#ifndef INFINI_OPS_ILUVATAR_SWIGLU_KERNEL_H_ +#define INFINI_OPS_ILUVATAR_SWIGLU_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/swiglu/kernel.h" + +namespace infini::ops { + +namespace swiglu { + +struct IluvatarBackend { + using stream_t = cudaStream_t; + + static constexpr auto malloc = [](auto&&... 
args) { + return cudaMalloc(std::forward(args)...); + }; + + static constexpr auto memcpy = cudaMemcpy; + + static constexpr auto free = cudaFree; + + static constexpr auto memcpyH2D = cudaMemcpyHostToDevice; +}; + +} // namespace swiglu + +template <> +class Operator + : public CudaSwiglu { + public: + using CudaSwiglu::CudaSwiglu; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_add.py b/tests/test_add.py index afbce0d..1c98d91 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -46,10 +46,16 @@ (torch.uint64, 0, 0), ), ) -def test_add(shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol): +def test_add( + shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol +): if dtype in _INT_DTYPES: - input = randint_strided(0, 100, shape, input_strides, dtype=dtype, device=device) - other = randint_strided(0, 100, shape, other_strides, dtype=dtype, device=device) + input = randint_strided( + 0, 100, shape, input_strides, dtype=dtype, device=device + ) + other = randint_strided( + 0, 100, shape, other_strides, dtype=dtype, device=device + ) else: input = randn_strided(shape, input_strides, dtype=dtype, device=device) other = randn_strided(shape, other_strides, dtype=dtype, device=device) From 9de2fd28e933f438beaff5d25a6ecd0992ebdc78 Mon Sep 17 00:00:00 2001 From: zhangyue <138768300+zhangyue207@users.noreply.github.com> Date: Thu, 19 Mar 2026 16:07:08 +0800 Subject: [PATCH 82/93] fix: filter out unsupported integer data types in `tests/test_add.py` and add `_torch_rms_norm` fallback (#20) * fix: remove uint16 test from test_add.py - Removed `torch.uint16` from the list of integer data types in the `_INT_DTYPES` tuple to streamline the code and eliminate redundancy. 
* refactor: enhance dtype handling in test_add.py * refactor: streamline dtype parameterization in test_add.py and enhance rms_norm fallback handling in test_rms_norm.py * refactor: add unsigned integer data types to test_add.py for enhanced dtype handling * refactor: simplify integer dtype filtering * refactor: simplify `_torch_rms_norm` fallback logic --------- Co-authored-by: Jiacheng Huang --- tests/test_add.py | 26 +++++++++----------------- tests/test_rms_norm.py | 17 ++++++++++++++++- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/tests/test_add.py b/tests/test_add.py index 1c98d91..d1ea0f8 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -4,13 +4,10 @@ from tests.utils import Payload, empty_strided, randint_strided, randn_strided -_INT_DTYPES = ( - torch.int16, - torch.uint16, - torch.int32, - torch.uint32, - torch.int64, - torch.uint64, +_INT_DTYPES = (torch.int16, torch.int32, torch.int64) + +_UINT_DTYPES = tuple( + filter(None, (getattr(torch, f"uint{bits}", None) for bits in (16, 32, 64))) ) @@ -38,18 +35,13 @@ (torch.float32, 1e-7, 1e-7), (torch.float16, 1e-3, 1e-3), (torch.bfloat16, 1e-2, 5e-3), - (torch.int16, 0, 0), - (torch.uint16, 0, 0), - (torch.int32, 0, 0), - (torch.uint32, 0, 0), - (torch.int64, 0, 0), - (torch.uint64, 0, 0), - ), + ) + + tuple((dtype, 0, 0) for dtype in _INT_DTYPES + _UINT_DTYPES), ) def test_add( shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol ): - if dtype in _INT_DTYPES: + if dtype in _INT_DTYPES or dtype in _UINT_DTYPES: input = randint_strided( 0, 100, shape, input_strides, dtype=dtype, device=device ) @@ -72,10 +64,10 @@ def _add(input, other, out): def _torch_add(input, other, out): - if input.dtype in (torch.uint16, torch.uint32, torch.uint64): + if input.dtype in _UINT_DTYPES: input = input.to(torch.int64) - if other.dtype in (torch.uint16, torch.uint32, torch.uint64): + if other.dtype in _UINT_DTYPES: other = other.to(torch.int64) res = torch.add(input, other) 
diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py index f447091..d6d4dff 100644 --- a/tests/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -59,4 +59,19 @@ def _rms_norm(input, weight, *, eps=1e-6, out=None): def _torch_rms_norm(input, weight, *, eps=1e-6, out=None): - return torch.nn.functional.rms_norm(input, input.shape[-1:], weight=weight, eps=eps) + # Fallback for `torch<2.3`: `rms_norm = (x / sqrt(mean(x^2) + eps)) * weight`. + def _fallback(input, _normalized_shape, weight, *, eps=1e-6): + rms = torch.sqrt(torch.mean(input * input, dim=-1, keepdim=True) + eps) + + return (input / rms) * weight + + rms_norm_fn = getattr(torch.nn.functional, "rms_norm", _fallback) + + result = rms_norm_fn(input, input.shape[-1:], weight=weight, eps=eps) + + if out is not None: + out.copy_(result) + else: + out = result + + return out From dc9f4403c13ece6cb6f90e9051df64d1e93853fc Mon Sep 17 00:00:00 2001 From: gongchensu Date: Thu, 19 Mar 2026 19:46:54 +0800 Subject: [PATCH 83/93] feat(moore): add Moore backend for `Add` (#26) * feat(moore): add add-op support with musa integration. 
* refactor: improve specialization logic * fix: fix `AttributeError` on Cambricon --------- Co-authored-by: Jiacheng Huang --- CMakeLists.txt | 69 +++++++++++++++++++++----------- scripts/mcc_wrapper.sh | 48 ++++++++++++++++++++++ src/CMakeLists.txt | 30 ++++++++++---- src/common/cast.h | 3 +- src/common/cuda/cast.h | 2 + src/common/cuda/kernel_commons.h | 14 +++++++ src/cuda/add/kernel.cuh | 24 ++++++++++- src/data_type.h | 6 +++ src/moore/add/kernel.h | 44 ++++++++++++++++++++ tests/test_add.py | 5 +++ 10 files changed, 212 insertions(+), 33 deletions(-) create mode 100755 scripts/mcc_wrapper.sh create mode 100644 src/moore/add/kernel.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 18e98c9..570b7d7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -73,30 +73,18 @@ if(AUTO_DETECT_DEVICES) endif() endif() -if(WITH_MOORE) - set(MUSA_ROOT $ENV{MUSA_ROOT} $ENV{MUSA_HOME} $ENV{MUSA_PATH}) - list(FILTER MUSA_ROOT EXCLUDE REGEX "^$") - list(GET MUSA_ROOT 0 MUSA_ROOT) - if(NOT MUSA_ROOT) - message(FATAL_ERROR "`WITH_MOORE` is `ON` but `MUSA_ROOT`/`MUSA_HOME`/`MUSA_PATH` is not set.") - endif() - message(STATUS "Using Moore from `${MUSA_ROOT}`.") - list(PREPEND CMAKE_MODULE_PATH "${MUSA_ROOT}/cmake") - set(MUSA_TOOLKIT_ROOT_DIR "${MUSA_ROOT}" CACHE PATH "Toolkit location." FORCE) - find_package(MUSA REQUIRED) - add_compile_definitions(WITH_MOORE=1) - include_directories("${MUSA_ROOT}/include") - link_directories("${MUSA_ROOT}/lib") - find_library(MUSA_LIB NAMES musa HINTS "${MUSA_ROOT}/lib" REQUIRED) - find_library(MUSART_LIB NAMES musart HINTS "${MUSA_ROOT}/lib" REQUIRED) - find_library(MUBLAS_LIB NAMES mublas HINTS "${MUSA_ROOT}/lib" REQUIRED) -endif() - include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) -# NVIDIA and Iluvatar are parallel backends; only one GPU backend at a time. -if(WITH_NVIDIA AND WITH_ILUVATAR) - message(FATAL_ERROR "`WITH_NVIDIA` and `WITH_ILUVATAR` cannot both be `ON`. 
Build one GPU backend at a time.") +# Only one CUDA-like GPU backend can be enabled at a time. +set(_gpu_backend_count 0) +foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_METAX WITH_MOORE) + if(${_gpu_backend}) + math(EXPR _gpu_backend_count "${_gpu_backend_count} + 1") + endif() +endforeach() + +if(_gpu_backend_count GREATER 1) + message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, and `WITH_MOORE` are mutually exclusive. Build one GPU backend at a time.") endif() if(WITH_NVIDIA) @@ -140,6 +128,41 @@ if(WITH_METAX) find_library(MACA_BLAS_LIB NAMES mcblas HINTS "${MACA_PATH}/lib" REQUIRED) endif() +if(WITH_MOORE) + add_compile_definitions(WITH_MOORE=1) + + set(MUSA_ROOT "") + foreach(_musa_env MUSA_ROOT MUSA_HOME MUSA_PATH) + if(NOT MUSA_ROOT AND DEFINED ENV{${_musa_env}} AND NOT "$ENV{${_musa_env}}" STREQUAL "") + set(MUSA_ROOT "$ENV{${_musa_env}}") + endif() + endforeach() + + if(NOT MUSA_ROOT AND EXISTS "/usr/local/musa") + set(MUSA_ROOT "/usr/local/musa") + endif() + + if(NOT MUSA_ROOT) + message(FATAL_ERROR "`WITH_MOORE` is `ON` but `MUSA_ROOT`/`MUSA_HOME`/`MUSA_PATH` is not set and `/usr/local/musa` was not found.") + endif() + + if(NOT EXISTS "${MUSA_ROOT}/bin/mcc") + message(FATAL_ERROR "Could not find `mcc` under `${MUSA_ROOT}/bin`.") + endif() + + message(STATUS "Using Moore from `${MUSA_ROOT}`.") + + set(CMAKE_C_COMPILER ${CMAKE_CURRENT_SOURCE_DIR}/scripts/mcc_wrapper.sh) + set(CMAKE_CXX_COMPILER ${CMAKE_CURRENT_SOURCE_DIR}/scripts/mcc_wrapper.sh) + + include_directories("${MUSA_ROOT}/include") + link_directories("${MUSA_ROOT}/lib") + + find_library(MUSA_LIB NAMES musa HINTS "${MUSA_ROOT}/lib" REQUIRED) + find_library(MUSART_LIB NAMES musart HINTS "${MUSA_ROOT}/lib" REQUIRED) + find_library(MUBLAS_LIB NAMES mublas HINTS "${MUSA_ROOT}/lib" REQUIRED) +endif() + if(WITH_CAMBRICON) add_compile_definitions(WITH_CAMBRICON=1) set(NEUWARE_HOME $ENV{NEUWARE_HOME}) @@ -160,7 +183,7 @@ if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX 
AND NOT WITH_MOORE) add_compile_definitions(WITH_CPU=1) endif() -if(WITH_METAX) +if(WITH_METAX OR WITH_MOORE) set(PYBIND11_ENABLE_EXTRAS OFF) endif() diff --git a/scripts/mcc_wrapper.sh b/scripts/mcc_wrapper.sh new file mode 100755 index 0000000..29ce5cd --- /dev/null +++ b/scripts/mcc_wrapper.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# Filter out flags unsupported by `mcc`. +ARGS=() +skip_next=0 +linking=1 +for arg in "$@"; do + if [ $skip_next -eq 1 ]; then + skip_next=0 + continue + fi + case "$arg" in + -c|-E|-S) + linking=0 + ARGS+=("$arg") + ;; + -pthread) + ;; + -B) + skip_next=1 + ;; + -B*) + ;; + *) + ARGS+=("$arg") + ;; + esac +done + +MUSA_ROOT_DIR="${MUSA_ROOT:-${MUSA_HOME:-${MUSA_PATH:-/usr/local/musa}}}" + +if command -v g++ >/dev/null 2>&1; then + GXX_MAJOR="$(g++ -dumpversion | cut -d. -f1)" + if [ -d "/usr/include/c++/${GXX_MAJOR}" ]; then + ARGS=( + "-isystem" "/usr/include/c++/${GXX_MAJOR}" + "-isystem" "/usr/include/x86_64-linux-gnu/c++/${GXX_MAJOR}" + "-isystem" "/usr/include/c++/${GXX_MAJOR}/backward" + "${ARGS[@]}" + ) + fi + + STDCPP_LIB="$(g++ -print-file-name=libstdc++.so)" + if [ $linking -eq 1 ] && [ -f "${STDCPP_LIB}" ]; then + ARGS=("-L$(dirname "${STDCPP_LIB}")" "${ARGS[@]}") + fi +fi + +exec "${MUSA_ROOT_DIR}/bin/mcc" "${ARGS[@]}" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3a04144..3ca0715 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -103,6 +103,29 @@ if(WITH_METAX) list(APPEND DEVICE_LIST "metax") endif() +if(WITH_MOORE) + set(MOORE_PATTERNS + "cuda/*.cc" + "cuda/*.cpp" + "moore/*.cc" + "moore/*.cpp" + "moore/*.mu" + ) + + file(GLOB_RECURSE MOORE_SOURCES CONFIGURE_DEPENDS ${MOORE_PATTERNS}) + + set_source_files_properties(${MOORE_SOURCES} PROPERTIES LANGUAGE CXX) + + target_compile_definitions(infiniops PRIVATE WITH_MOORE=1) + target_compile_options(infiniops PUBLIC "-x" "musa") + target_sources(infiniops PRIVATE ${MOORE_SOURCES}) + + target_include_directories(infiniops PUBLIC "${MUSA_ROOT}/include") + 
target_link_libraries(infiniops PUBLIC ${MUSA_LIB} ${MUSART_LIB} ${MUBLAS_LIB}) + + list(APPEND DEVICE_LIST "moore") +endif() + if(WITH_CAMBRICON) target_compile_definitions(infiniops PUBLIC WITH_CAMBRICON=1) @@ -112,13 +135,6 @@ if(WITH_CAMBRICON) list(APPEND DEVICE_LIST "cambricon") endif() -if(WITH_MOORE) - target_compile_definitions(infiniops PUBLIC WITH_MOORE=1) - target_include_directories(infiniops PUBLIC "${MUSA_ROOT}/include") - target_link_libraries(infiniops PUBLIC ${MUSA_LIB} ${MUSART_LIB} ${MUBLAS_LIB}) - list(APPEND DEVICE_LIST "moore") -endif() - target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) if(GENERATE_PYTHON_BINDINGS) diff --git a/src/common/cast.h b/src/common/cast.h index 4973764..a37fb94 100644 --- a/src/common/cast.h +++ b/src/common/cast.h @@ -1,7 +1,8 @@ #ifndef INFINI_OPS_COMMON_CAST_H_ #define INFINI_OPS_COMMON_CAST_H_ -#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) || defined(WITH_METAX) +#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) || defined(WITH_METAX) || \ + defined(WITH_MOORE) #include "common/cuda/cast.h" #else #include "common/cpu/cast.h" diff --git a/src/common/cuda/cast.h b/src/common/cuda/cast.h index d3dcdb9..1f67a44 100644 --- a/src/common/cuda/cast.h +++ b/src/common/cuda/cast.h @@ -7,6 +7,8 @@ #include #elif defined(WITH_METAX) #include +#elif defined(WITH_MOORE) +#include #endif #include "data_type.h" diff --git a/src/common/cuda/kernel_commons.h b/src/common/cuda/kernel_commons.h index 3c85031..e2ef107 100644 --- a/src/common/cuda/kernel_commons.h +++ b/src/common/cuda/kernel_commons.h @@ -17,6 +17,12 @@ using cuda_bfloat162 = nv_bfloat162; #include using cuda_bfloat16 = maca_bfloat16; using cuda_bfloat162 = maca_bfloat162; +#elif defined(WITH_MOORE) +#include +#include +#include +using cuda_bfloat16 = __mt_bfloat16; +using cuda_bfloat162 = __mt_bfloat162; #endif #include @@ -72,6 +78,14 @@ inline int QueryMaxThreadsPerBlock() { // TODO: Add MCR device properties query for Metax. 
return CUDA_BLOCK_SIZE_256; } +#elif defined(WITH_MOORE) +inline int QueryMaxThreadsPerBlock() { + int device = 0; + musaGetDevice(&device); + musaDeviceProp prop; + musaGetDeviceProperties(&prop, device); + return prop.maxThreadsPerBlock; +} #endif // Get optimal block size based on GPU hardware architecture. diff --git a/src/cuda/add/kernel.cuh b/src/cuda/add/kernel.cuh index 2d58809..6903925 100644 --- a/src/cuda/add/kernel.cuh +++ b/src/cuda/add/kernel.cuh @@ -5,6 +5,24 @@ namespace infini::ops { +namespace detail { + +template +struct HasHAdd : std::false_type {}; + +template +struct HasHAdd< + T, std::void_t< + decltype(__hadd(std::declval(), std::declval())), + std::enable_if_t(), std::declval())), T>>>> + : std::true_type {}; + +template +inline constexpr bool HasHAddValue = HasHAdd::value; + +} // namespace detail + struct AddOp { static constexpr std::size_t num_inputs = 2; @@ -13,8 +31,10 @@ struct AddOp { const T& other) const { if constexpr (std::is_same_v) { return __hadd2(input, other); - } else if constexpr (std::is_same_v || - std::is_same_v>) { + } else if constexpr ((std::is_same_v || + std::is_same_v>) && + detail::HasHAddValue) { return __hadd(input, other); } else if constexpr (std::is_same_v) { return __fadd_rn(input, other); diff --git a/src/data_type.h b/src/data_type.h index 8a3e544..af2aec7 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -14,6 +14,9 @@ #elif defined(WITH_METAX) #include #include +#elif defined(WITH_MOORE) +#include +#include #endif #include "common/constexpr_map.h" @@ -201,6 +204,9 @@ DEFINE_DATA_TYPE_MAPPING(kBFloat16, __nv_bfloat16) #elif defined(WITH_METAX) DEFINE_DATA_TYPE_MAPPING(kFloat16, __half) DEFINE_DATA_TYPE_MAPPING(kBFloat16, __maca_bfloat16) +#elif defined(WITH_MOORE) +DEFINE_DATA_TYPE_MAPPING(kFloat16, half) +DEFINE_DATA_TYPE_MAPPING(kBFloat16, __mt_bfloat16) #else DEFINE_DATA_TYPE_MAPPING(kFloat16, Float16) DEFINE_DATA_TYPE_MAPPING(kBFloat16, BFloat16) diff --git a/src/moore/add/kernel.h 
b/src/moore/add/kernel.h new file mode 100644 index 0000000..5092f25 --- /dev/null +++ b/src/moore/add/kernel.h @@ -0,0 +1,44 @@ +#ifndef INFINI_OPS_MOORE_ADD_KERNEL_H_ +#define INFINI_OPS_MOORE_ADD_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/add/kernel.h" + +namespace infini::ops { + +namespace add { + +struct MooreBackend { + using stream_t = musaStream_t; + + static constexpr auto malloc = [](auto&&... args) { + return musaMalloc(std::forward(args)...); + }; + + static constexpr auto memcpy = [](auto&&... args) { + return musaMemcpy(std::forward(args)...); + }; + + static constexpr auto free = [](auto&&... args) { + return musaFree(std::forward(args)...); + }; + + static constexpr auto memcpyH2D = musaMemcpyHostToDevice; +}; + +} // namespace add + +template <> +class Operator : public CudaAdd { + public: + using CudaAdd::CudaAdd; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_add.py b/tests/test_add.py index d1ea0f8..8b8166c 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -41,6 +41,11 @@ def test_add( shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol ): + if device == "musa" and dtype in _UINT_DTYPES: + pytest.skip( + "The `torch.musa` test cloning path does not support `uint16`, `uint32`, or `uint64`." 
+ ) + if dtype in _INT_DTYPES or dtype in _UINT_DTYPES: input = randint_strided( 0, 100, shape, input_strides, dtype=dtype, device=device From f0fccb124f91c66b78733d989c27af9544f13e57 Mon Sep 17 00:00:00 2001 From: gongchensu Date: Thu, 19 Mar 2026 21:50:11 +0800 Subject: [PATCH 84/93] feat(moore): add Moore SwiGLU (#24) * feat(moore): add swiglu-op support with musa integration - add a Moore swiglu backend on top of the shared CUDA-style path - extract shared swiglu compute into a reusable op for backend override - keep Moore-specific half and bfloat162 handling in the Moore backend only * refactor: introduce `src/moore/polyfills.cuh` * refactor: use polyfills for Moore SwiGLU --------- Co-authored-by: Jiacheng Huang --- src/cuda/add/kernel.cuh | 24 ++----------------- src/cuda/swiglu/kernel.h | 4 ---- src/moore/add/kernel.h | 4 ++++ src/moore/polyfills.cuh | 41 ++++++++++++++++++++++++++++++++ src/moore/swiglu/kernel.h | 49 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 96 insertions(+), 26 deletions(-) create mode 100644 src/moore/polyfills.cuh create mode 100644 src/moore/swiglu/kernel.h diff --git a/src/cuda/add/kernel.cuh b/src/cuda/add/kernel.cuh index 6903925..2d58809 100644 --- a/src/cuda/add/kernel.cuh +++ b/src/cuda/add/kernel.cuh @@ -5,24 +5,6 @@ namespace infini::ops { -namespace detail { - -template -struct HasHAdd : std::false_type {}; - -template -struct HasHAdd< - T, std::void_t< - decltype(__hadd(std::declval(), std::declval())), - std::enable_if_t(), std::declval())), T>>>> - : std::true_type {}; - -template -inline constexpr bool HasHAddValue = HasHAdd::value; - -} // namespace detail - struct AddOp { static constexpr std::size_t num_inputs = 2; @@ -31,10 +13,8 @@ struct AddOp { const T& other) const { if constexpr (std::is_same_v) { return __hadd2(input, other); - } else if constexpr ((std::is_same_v || - std::is_same_v>) && - detail::HasHAddValue) { + } else if constexpr (std::is_same_v || + std::is_same_v>) { return __hadd(input, 
other); } else if constexpr (std::is_same_v) { return __fadd_rn(input, other); diff --git a/src/cuda/swiglu/kernel.h b/src/cuda/swiglu/kernel.h index d05a1af..47849fe 100644 --- a/src/cuda/swiglu/kernel.h +++ b/src/cuda/swiglu/kernel.h @@ -3,10 +3,6 @@ #include -// clang-format off -#include // TODO: Remove this -// clang-format on - #include "base/swiglu.h" #include "common/generic_utils.h" #include "cuda/swiglu/kernel.cuh" diff --git a/src/moore/add/kernel.h b/src/moore/add/kernel.h index 5092f25..21a51f6 100644 --- a/src/moore/add/kernel.h +++ b/src/moore/add/kernel.h @@ -7,6 +7,10 @@ #include // clang-format on +// clang-format off +#include "moore/polyfills.cuh" +// clang-format on + #include "cuda/add/kernel.h" namespace infini::ops { diff --git a/src/moore/polyfills.cuh b/src/moore/polyfills.cuh new file mode 100644 index 0000000..b3c7e70 --- /dev/null +++ b/src/moore/polyfills.cuh @@ -0,0 +1,41 @@ +#ifndef INFINI_OPS_MOORE_POLYFILLS_CUH_ +#define INFINI_OPS_MOORE_POLYFILLS_CUH_ + +#include + +// clang-format off +#include +// clang-format on + +namespace infini::ops { + +template +__device__ __forceinline__ T __hadd(const T& a, const T& b) { + return a + b; +} + +template +__device__ __forceinline__ auto __high2bfloat16(const T& a) { + return __float2bfloat16_rn(::__high2float(a)); +} + +template +__device__ __forceinline__ T __hneg(const T& a) { + return -a; +} + +template +__device__ __forceinline__ auto __low2bfloat16(const T& a) { + return __float2bfloat16_rn(::__low2float(a)); +} + +template +__device__ __forceinline__ T hrcp(const T& a) { + return T(__frcp_rn(static_cast(a))); +} + +} // namespace infini::ops + +#define hrcp infini::ops::hrcp + +#endif diff --git a/src/moore/swiglu/kernel.h b/src/moore/swiglu/kernel.h new file mode 100644 index 0000000..a7759fb --- /dev/null +++ b/src/moore/swiglu/kernel.h @@ -0,0 +1,49 @@ +#ifndef INFINI_OPS_MOORE_SWIGLU_KERNEL_H_ +#define INFINI_OPS_MOORE_SWIGLU_KERNEL_H_ + +#include + +// clang-format off +#include 
+// clang-format on + +// clang-format off +#include "moore/polyfills.cuh" +// clang-format on + +#include "cuda/swiglu/kernel.h" + +namespace infini::ops { + +namespace swiglu { + +struct MooreBackend { + using stream_t = musaStream_t; + + static constexpr auto malloc = [](auto&&... args) { + return musaMalloc(std::forward(args)...); + }; + + static constexpr auto memcpy = [](auto&&... args) { + return musaMemcpy(std::forward(args)...); + }; + + static constexpr auto free = [](auto&&... args) { + return musaFree(std::forward(args)...); + }; + + static constexpr auto memcpyH2D = musaMemcpyHostToDevice; +}; + +} // namespace swiglu + +template <> +class Operator + : public CudaSwiglu { + public: + using CudaSwiglu::CudaSwiglu; +}; + +} // namespace infini::ops + +#endif From 1b0b5acce5ab15136010b6876c5716a50081c11b Mon Sep 17 00:00:00 2001 From: gongchensu Date: Fri, 20 Mar 2026 15:49:29 +0800 Subject: [PATCH 85/93] feat(ops): add MetaX `causal_softmax` (#27) - add MetaX operator specialization - make the shared CUDA-style kernel compatible with MetaX - reuse common casting utilities for fp16 and bf16 conversions Co-authored-by: gongchensu --- src/cuda/causal_softmax/kernel.cuh | 31 +++++++++--------------------- src/cuda/causal_softmax/kernel.h | 4 ---- src/metax/causal_softmax/kernel.h | 31 ++++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 26 deletions(-) create mode 100644 src/metax/causal_softmax/kernel.h diff --git a/src/cuda/causal_softmax/kernel.cuh b/src/cuda/causal_softmax/kernel.cuh index d195237..d578998 100644 --- a/src/cuda/causal_softmax/kernel.cuh +++ b/src/cuda/causal_softmax/kernel.cuh @@ -1,27 +1,20 @@ #ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_CUH_ #define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_CUH_ -#include -#include - #include #include #include +#include "common/cuda/cast.h" +#include "common/cuda/kernel_commons.h" + namespace infini::ops { namespace { template __device__ __forceinline__ Data ExpAndCast(Compute x) { - Compute 
e = std::exp(x); - if constexpr (std::is_same_v) { - return __float2half(static_cast(e)); - } else if constexpr (std::is_same_v) { - return __float2bfloat16(static_cast(e)); - } else { - return static_cast(e); - } + return Cast(expf(Cast(x))); } struct BlockMaxOp { @@ -48,7 +41,7 @@ __device__ __forceinline__ Compute BlockSum(const Data* data_ptr, size_t count) { Compute thread_sum = 0; for (size_t i = threadIdx.x; i < count; i += block_size) { - thread_sum += Compute(data_ptr[i]); + thread_sum += Cast(data_ptr[i]); } using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -83,10 +76,10 @@ __global__ void CausalSoftmaxKernel( for (size_t col = threadIdx.x; col < total_seq_len; col += block_size) { if (col < valid_len) { Compute diff = - static_cast(input_row[col]) - static_cast(max_val); + Cast(input_row[col]) - Cast(max_val); out_row[col] = ExpAndCast(diff); } else { - out_row[col] = Data(0); + out_row[col] = Cast(0.0f); } } __syncthreads(); @@ -100,14 +93,8 @@ __global__ void CausalSoftmaxKernel( __syncthreads(); for (size_t col = threadIdx.x; col < total_seq_len; col += block_size) { - Compute quot = static_cast(out_row[col]) / sum_val; - if constexpr (std::is_same_v) { - out_row[col] = __float2half(static_cast(quot)); - } else if constexpr (std::is_same_v) { - out_row[col] = __float2bfloat16(static_cast(quot)); - } else { - out_row[col] = static_cast(quot); - } + Compute quot = Cast(out_row[col]) / sum_val; + out_row[col] = Cast(quot); } } diff --git a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h index 23a040a..924be40 100644 --- a/src/cuda/causal_softmax/kernel.h +++ b/src/cuda/causal_softmax/kernel.h @@ -3,10 +3,6 @@ #include -// clang-format off -#include // TODO: Remove this -// clang-format on - #include "base/causal_softmax.h" #include "common/cuda/kernel_commons.h" #include "cuda/causal_softmax/kernel.cuh" diff --git a/src/metax/causal_softmax/kernel.h b/src/metax/causal_softmax/kernel.h 
new file mode 100644 index 0000000..d44919f --- /dev/null +++ b/src/metax/causal_softmax/kernel.h @@ -0,0 +1,31 @@ +#ifndef INFINI_OPS_METAX_CAUSAL_SOFTMAX_KERNEL_H_ +#define INFINI_OPS_METAX_CAUSAL_SOFTMAX_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/causal_softmax/kernel.h" + +namespace infini::ops { + +namespace causal_softmax { + +struct MetaxBackend { + using stream_t = mcStream_t; +}; + +} // namespace causal_softmax + +template <> +class Operator + : public CudaCausalSoftmax { + public: + using CudaCausalSoftmax::CudaCausalSoftmax; +}; + +} // namespace infini::ops + +#endif From f44be6fedf547468bec581736e9d4e5c96dc933b Mon Sep 17 00:00:00 2001 From: gongchensu Date: Fri, 20 Mar 2026 16:04:57 +0800 Subject: [PATCH 86/93] feat(ops): add MetaX backend for `RmsNorm` (#25) - add MetaX `RmsNorm` operator specialization - make the shared CUDA-style rms_norm kernel compatible with MetaX - forward runtime `eps` when launching the kernel Co-authored-by: gongchensu --- src/base/rms_norm.h | 1 + src/cuda/rms_norm/kernel.cuh | 39 ++++++++++++++++++------------------ src/cuda/rms_norm/kernel.h | 18 +++++++---------- src/metax/rms_norm/kernel.h | 31 ++++++++++++++++++++++++++++ 4 files changed, 59 insertions(+), 30 deletions(-) create mode 100644 src/metax/rms_norm/kernel.h diff --git a/src/base/rms_norm.h b/src/base/rms_norm.h index 3b40a1c..65f44b3 100644 --- a/src/base/rms_norm.h +++ b/src/base/rms_norm.h @@ -25,6 +25,7 @@ class RmsNorm : public Operator { RmsNorm(const Tensor input, const Tensor weight, Tensor out) : RmsNorm{input, weight, 1e-6f, out} {} + // TODO: Type of `eps` should be `std::optional` instead of `float`. 
virtual void operator()(const Tensor input, const Tensor weight, float eps, Tensor out) const = 0; diff --git a/src/cuda/rms_norm/kernel.cuh b/src/cuda/rms_norm/kernel.cuh index 98383f3..10228a6 100644 --- a/src/cuda/rms_norm/kernel.cuh +++ b/src/cuda/rms_norm/kernel.cuh @@ -1,39 +1,39 @@ #ifndef INFINI_OPS_CUDA_RMS_NORM_KERNEL_CUH_ #define INFINI_OPS_CUDA_RMS_NORM_KERNEL_CUH_ -#include -#include - #include #include #include +#include "common/cuda/cast.h" +#include "common/cuda/kernel_commons.h" + namespace infini::ops { namespace { -template -__device__ __forceinline__ Compute SumSquared(const Data* data_ptr, - size_t count) { - Compute ss = 0; +template +__device__ __forceinline__ TCompute SumSquared(const TData* data_ptr, + size_t count) { + TCompute ss = 0; for (size_t i = threadIdx.x; i < count; i += block_size) { - Compute val = Compute(data_ptr[i]); - ss += val * val; + TCompute value = Cast(data_ptr[i]); + ss += value * value; } - using BlockReduce = cub::BlockReduce; + using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; return BlockReduce(temp_storage).Sum(ss); } } // namespace -template -__global__ void RmsNormKernel(Data* __restrict__ y, int64_t stride_y_batch, +template +__global__ void RmsNormKernel(TData* __restrict__ y, int64_t stride_y_batch, int64_t stride_y_nhead, - const Data* __restrict__ x, + const TData* __restrict__ x, int64_t stride_x_batch, int64_t stride_x_nhead, - const Weight* __restrict__ w, size_t nhead, + const TWeight* __restrict__ w, size_t nhead, size_t dim, float epsilon) { size_t batch_idx = blockIdx.x / nhead; size_t head_idx = blockIdx.x % nhead; @@ -42,16 +42,17 @@ __global__ void RmsNormKernel(Data* __restrict__ y, int64_t stride_y_batch, auto x_ptr = x + batch_idx * stride_x_batch + head_idx * stride_x_nhead; auto w_ptr = w; - Compute ss = SumSquared(x_ptr, dim); + TCompute ss = SumSquared(x_ptr, dim); - __shared__ Compute rms; + __shared__ TCompute rms; if (threadIdx.x == 0) { - 
rms = Compute(rsqrtf(ss / Compute(dim) + epsilon)); + rms = Cast(rsqrtf(ss / Cast(dim) + epsilon)); } __syncthreads(); for (size_t i = threadIdx.x; i < dim; i += block_size) { - y_ptr[i] = Data(Compute(x_ptr[i]) * Compute(w_ptr[i]) * rms); + y_ptr[i] = + Cast(Cast(x_ptr[i]) * Cast(w_ptr[i]) * rms); } } diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h index f450e91..dc28ee5 100644 --- a/src/cuda/rms_norm/kernel.h +++ b/src/cuda/rms_norm/kernel.h @@ -3,10 +3,6 @@ #include -// clang-format off -#include // TODO: Remove this -// clang-format on - #include "base/rms_norm.h" #include "common/cuda/kernel_commons.h" #include "cuda/rms_norm/kernel.cuh" @@ -45,13 +41,13 @@ class CudaRmsNorm : public RmsNorm { [&](auto tag) { using T = typename decltype(tag)::type; -#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \ - RmsNormKernel \ - <<>>( \ - reinterpret_cast(out.data()), stride_out_batch, \ - stride_out_nhead, reinterpret_cast(input.data()), \ - stride_input_batch, stride_input_nhead, \ - reinterpret_cast(weight.data()), nhead_, dim_, eps_); +#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \ + RmsNormKernel \ + <<>>( \ + reinterpret_cast(out.data()), stride_out_batch, \ + stride_out_nhead, reinterpret_cast(input.data()), \ + stride_input_batch, stride_input_nhead, \ + reinterpret_cast(weight.data()), nhead_, dim_, eps); if (block_size == CUDA_BLOCK_SIZE_2048) { LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_2048) diff --git a/src/metax/rms_norm/kernel.h b/src/metax/rms_norm/kernel.h new file mode 100644 index 0000000..b724552 --- /dev/null +++ b/src/metax/rms_norm/kernel.h @@ -0,0 +1,31 @@ +#ifndef INFINI_OPS_METAX_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_METAX_RMS_NORM_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/rms_norm/kernel.h" + +namespace infini::ops { + +namespace rms_norm { + +struct MetaxBackend { + using stream_t = mcStream_t; +}; + +} // namespace rms_norm + +template <> +class Operator + : public CudaRmsNorm { + 
public: + using CudaRmsNorm::CudaRmsNorm; +}; + +} // namespace infini::ops + +#endif From 61fcdf783281bf86e0c744aedd2fdd01828e415c Mon Sep 17 00:00:00 2001 From: gongchensu Date: Fri, 20 Mar 2026 16:14:28 +0800 Subject: [PATCH 87/93] feat(ops): add MetaX backend for `Swiglu` (#28) Co-authored-by: gongchensu --- src/metax/swiglu/kernel.h | 41 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 src/metax/swiglu/kernel.h diff --git a/src/metax/swiglu/kernel.h b/src/metax/swiglu/kernel.h new file mode 100644 index 0000000..77416aa --- /dev/null +++ b/src/metax/swiglu/kernel.h @@ -0,0 +1,41 @@ +#ifndef INFINI_OPS_METAX_SWIGLU_KERNEL_H_ +#define INFINI_OPS_METAX_SWIGLU_KERNEL_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/swiglu/kernel.h" + +namespace infini::ops { + +namespace swiglu { + +struct MetaxBackend { + using stream_t = mcStream_t; + + static constexpr auto malloc = [](auto&&... args) { + return mcMalloc(std::forward(args)...); + }; + + static constexpr auto memcpy = mcMemcpy; + + static constexpr auto free = mcFree; + + static constexpr auto memcpyH2D = mcMemcpyHostToDevice; +}; + +} // namespace swiglu + +template <> +class Operator + : public CudaSwiglu { + public: + using CudaSwiglu::CudaSwiglu; +}; + +} // namespace infini::ops + +#endif From 3557dda43137a7bd217282d27299f62477767d3e Mon Sep 17 00:00:00 2001 From: zhangyue <138768300+zhangyue207@users.noreply.github.com> Date: Wed, 25 Mar 2026 16:58:48 +0800 Subject: [PATCH 88/93] feat: develop CI infrastructure (#21) * feat/nv ci test * feat: ci sys for nv platform * fix(ci): fix results dir permissions and reduce parallel workers - Pass host UID/GID into container and `chown` results after tests, so mounted `ci-results/` is accessible by the host user. - Limit `pytest-xdist` workers from `-n auto` to `-n 8` to prevent OOM worker crashes on high-core-count machines. 
Co-Authored-By: Claude Opus 4.6 * refactor(ci): Refactor code structure for improved readability and maintainability * docs: add multi-machine deployment guide for NVIDIA and Iluvatar platform * feat(ci): enhance CI configuration and agent functionality with platform detection and job resolution * feat(ci): add MetaX platform CI support Add Dockerfile, config, and mx-smi GPU detection for MetaX (MACA) platform. Co-Authored-By: Claude Opus 4.6 * feat(ci): improve job dispatch logging and handle job results more effectively * feat(ci): add Moore Threads (MUSA) platform CI support Add GPU detection via mthreads-gmi, Dockerfile, config, and update docs with Moore and MetaX platform deployment instructions. Co-Authored-By: Claude Opus 4.6 * feat(ci): capture Docker error output for remote job diagnostics * feat(ci): capture error output and improve CLI result display - Capture last 50 lines of Docker output via ring buffer so failed jobs return diagnostic info to the CLI client. - Store raw bytes during execution; decode only on the failure path. - Align job name columns in `<==` result lines for readability. - Show summary only when jobs fail, removing redundant all-pass output. 
Co-Authored-By: Claude * feat(ci): add Cambricon MLU platform CI support - Add .ci/images/cambricon/Dockerfile for AnolisOS-based Cambricon image - Add cambricon platform to config.yaml with MLU-style GPU passthrough - Add GPU_STYLE_MLU constant and MLU_VISIBLE_DEVICES support in run.py - Add cnmon-based GPU detection (_detect_gpus_cambricon) in ci_resource.py - Add --test CLI flag to override pytest test path at runtime - Skip empty stage run commands instead of erroring (compilation-only mode) - Fix _torch_gemm fallback for CPU float16/bfloat16 (upcast to float32) - Skip bfloat16 on MLU (cnnlBatchMatMulEx does not support it) - Hoist _PYTEST_VALUE_FLAGS to module level; add ValueError guard in cambricon parser - Remove redundant yaml import guard in agent.py (utils.py already handles it) Co-Authored-By: Claude Sonnet 4.6 * docs(ci): translate README and comments to English, use ngpus for NVIDIA scheduler - Rewrite README.md entirely in English; add Cambricon to platform table and directory tree. - Translate all inline comments in config.yaml to English. - Replace `gpu_ids: "0"` with `ngpus: 1` for NVIDIA platform so the scheduler auto-picks a free GPU rather than pinning to device 0. - Add `ngpus` support to `parse_gpu_requirement` in ci_resource.py so scheduler correctly counts NVIDIA GPU demand. - Replace deprecated `gpu_count` fallback with `ngpus` in run.py `build_docker_args`. 
Co-Authored-By: Claude * feat(ci): add --local flag to run.py for testing uncommitted changes - Mount current directory read-only into container via `-v cwd:/workspace/repo:ro` - Copy to writable `/tmp/src` inside container before setup runs, so host files are never modified by pip install or build artifacts - Simplify README: fix ngpus example, add gpu_style column, add --local docs Co-Authored-By: Claude * style(ci): normalize comments to complete English sentences with markdown - Backtick-quote tool/package names (`torch`, `pip`, `git`, `cmake`, `coreutils-single`, `conda`) and paths in Dockerfile comments. - Add explanatory comment to the commented-out `agents:` block in `config.yaml` describing when to uncomment it. - Convert all section-header banners in `.ci/tests/` to "Tests for `FunctionName`." sentence form; fix three docstrings in `test_agent.py`. - Backtick-quote identifiers in `tests/test_gemm.py` inline comments. Co-Authored-By: Claude * style(tests): backtick-quote identifiers in test_gemm.py skip message Co-Authored-By: Claude --------- Co-authored-by: Claude Opus 4.6 --- .ci/README.md | 386 +++++++++++++ .ci/agent.py | 988 ++++++++++++++++++++++++++++++++ .ci/build.py | 260 +++++++++ .ci/ci_resource.py | 478 +++++++++++++++ .ci/config.yaml | 146 +++++ .ci/github_status.py | 100 ++++ .ci/images/ascend/Dockerfile | 39 ++ .ci/images/cambricon/Dockerfile | 33 ++ .ci/images/iluvatar/Dockerfile | 53 ++ .ci/images/metax/Dockerfile | 46 ++ .ci/images/moore/Dockerfile | 38 ++ .ci/images/nvidia/Dockerfile | 46 ++ .ci/run.py | 411 +++++++++++++ .ci/tests/__init__.py | 0 .ci/tests/conftest.py | 46 ++ .ci/tests/test_agent.py | 535 +++++++++++++++++ .ci/tests/test_build.py | 186 ++++++ .ci/tests/test_github_status.py | 145 +++++ .ci/tests/test_resource.py | 327 +++++++++++ .ci/tests/test_run.py | 298 ++++++++++ .ci/tests/test_utils.py | 90 +++ .ci/utils.py | 112 ++++ pyproject.toml | 2 +- tests/test_gemm.py | 12 +- 24 files changed, 4773 insertions(+), 4 
deletions(-) create mode 100644 .ci/README.md create mode 100644 .ci/agent.py create mode 100644 .ci/build.py create mode 100644 .ci/ci_resource.py create mode 100644 .ci/config.yaml create mode 100644 .ci/github_status.py create mode 100644 .ci/images/ascend/Dockerfile create mode 100644 .ci/images/cambricon/Dockerfile create mode 100644 .ci/images/iluvatar/Dockerfile create mode 100644 .ci/images/metax/Dockerfile create mode 100644 .ci/images/moore/Dockerfile create mode 100644 .ci/images/nvidia/Dockerfile create mode 100644 .ci/run.py create mode 100644 .ci/tests/__init__.py create mode 100644 .ci/tests/conftest.py create mode 100644 .ci/tests/test_agent.py create mode 100644 .ci/tests/test_build.py create mode 100644 .ci/tests/test_github_status.py create mode 100644 .ci/tests/test_resource.py create mode 100644 .ci/tests/test_run.py create mode 100644 .ci/tests/test_utils.py create mode 100644 .ci/utils.py diff --git a/.ci/README.md b/.ci/README.md new file mode 100644 index 0000000..190d012 --- /dev/null +++ b/.ci/README.md @@ -0,0 +1,386 @@ +# .ci — CI Images and Pipeline + +``` +.ci/ +├── config.yaml # Unified config (images, jobs, agent definitions) +├── utils.py # Shared utilities (load_config, normalize_config, get_git_commit) +├── agent.py # Runner Agent (scheduler, webhooks, remote dispatch) +├── build.py # Image builder +├── run.py # CI pipeline runner (Docker layer) +├── ci_resource.py # GPU/memory detection and allocation +├── github_status.py # GitHub Commit Status reporting +├── images/ +│ ├── nvidia/Dockerfile +│ ├── iluvatar/Dockerfile +│ ├── metax/Dockerfile +│ ├── moore/Dockerfile +│ ├── cambricon/Dockerfile +│ └── ascend/Dockerfile +└── tests/ # Unit tests + ├── conftest.py + ├── test_agent.py + ├── test_build.py + ├── test_run.py + ├── test_resource.py + ├── test_github_status.py + └── test_utils.py +``` + +**Prerequisites**: Docker, Python 3.10+, `pip install pyyaml` + +--- + +## Configuration `config.yaml` + +Config uses a 
**platform-centric** top-level structure. Each platform defines its image, platform-level defaults, and job list. +At load time, jobs are flattened to `{platform}_{job}` format (e.g., `nvidia_gpu`). + +```yaml +repo: + url: https://github.com/InfiniTensor/InfiniOps.git + branch: master + +github: + status_context_prefix: "ci/infiniops" + +agents: # Remote agent URLs (used by CLI for cross-machine dispatch) + nvidia: + url: http://nvidia-host:8080 + iluvatar: + url: http://iluvatar-host:8080 + +platforms: + nvidia: + image: # Image definition + dockerfile: .ci/images/nvidia/ + build_args: + BASE_IMAGE: nvcr.io/nvidia/pytorch:24.10-py3 + setup: pip install .[dev] --no-build-isolation + jobs: + gpu: # Flattened as nvidia_gpu + resources: + ngpus: 1 # Scheduler auto-picks this many free GPUs + memory: 32GB + shm_size: 16g + timeout: 3600 + stages: + - name: test + run: pytest tests/ -n 8 -v --tb=short --junitxml=/workspace/results/test-results.xml + + iluvatar: + image: + dockerfile: .ci/images/iluvatar/ + build_args: + BASE_IMAGE: corex:qs_pj20250825 + APT_MIRROR: http://archive.ubuntu.com/ubuntu + PIP_INDEX_URL: https://pypi.org/simple + docker_args: # Platform-level docker args, inherited by all jobs + - "--privileged" + - "--cap-add=ALL" + - "--pid=host" + - "--ipc=host" + volumes: + - /dev:/dev + - /lib/firmware:/lib/firmware + - /usr/src:/usr/src + - /lib/modules:/lib/modules + setup: pip install .[dev] --no-build-isolation + jobs: + gpu: # Flattened as iluvatar_gpu + resources: + gpu_ids: "0" + gpu_style: none # CoreX: passthrough via --privileged + /dev mount + memory: 32GB + shm_size: 16g + timeout: 3600 + stages: + - name: test + run: pytest tests/ -n 8 -v --tb=short --junitxml=/workspace/results/test-results.xml +``` + +### Config hierarchy + +| Level | Field | Description | +|---|---|---| +| **Platform** | `image` | Image definition (dockerfile, build_args) | +| | `image_tag` | Default image tag (defaults to `latest`) | +| | `docker_args` | Extra `docker 
run` args (e.g., `--privileged`) | +| | `volumes` | Extra volume mounts | +| | `setup` | In-container setup command | +| | `env` | Injected container env vars | +| **Job** | `resources.ngpus` | Number of GPUs — scheduler auto-picks free ones (NVIDIA only) | +| | `resources.gpu_ids` | Static GPU device IDs (e.g., `"0"`, `"0,2"`) | +| | `resources.gpu_style` | GPU passthrough: `nvidia` (default), `none`, or `mlu` | +| | `resources.memory` | Container memory limit | +| | `resources.shm_size` | Shared memory size | +| | `resources.timeout` | Max run time in seconds | +| | `stages` | Execution stage list | +| | Any platform field | Jobs can override any platform-level default | + +--- + +## Image builder `build.py` + +| Flag | Description | +|---|---| +| `--platform nvidia\|iluvatar\|metax\|moore\|ascend\|all` | Target platform (default: `all`) | +| `--commit` | Use specific commit ref as image tag (default: HEAD) | +| `--force` | Skip Dockerfile change detection | +| `--dry-run` | Print commands without executing | + +```bash +# Build with change detection (skips if no Dockerfile changes) +python .ci/build.py --platform nvidia + +# Build Iluvatar image +python .ci/build.py --platform iluvatar --force + +# Force build all platforms +python .ci/build.py --force +``` + +Build artifacts are stored as local Docker image tags: `infiniops-ci/:` and `:latest`. +Proxy and `no_proxy` env vars are forwarded from the host to `docker build` automatically. + +> `--push` is reserved for future use; requires a `registry` section in `config.yaml`. + +--- + +## Pipeline runner `run.py` + +Platform is auto-detected (via `nvidia-smi`/`ixsmi`/`mx-smi`/`mthreads-gmi`/`cnmon` on PATH), no manual specification needed. + +| Flag | Description | +|---|---| +| `--config` | Config file path (default: `.ci/config.yaml`) | +| `--job` | Job name: short (`gpu`) or full (`nvidia_gpu`). 
Defaults to all jobs for the current platform | +| `--branch` | Override clone branch (default: config `repo.branch`) | +| `--stage` | Run only the specified stage | +| `--image-tag` | Override image tag | +| `--gpu-id` | Override GPU device IDs (nvidia via `--gpus`, others via `CUDA_VISIBLE_DEVICES`) | +| `--test` | Override pytest test path (e.g., `tests/test_gemm.py::test_gemm`) | +| `--results-dir` | Host directory mounted to `/workspace/results` inside the container | +| `--local` | Mount current directory (read-only) instead of cloning from git | +| `--dry-run` | Print docker command without executing | + +```bash +# Simplest usage: auto-detect platform, run all jobs, use config default branch +python .ci/run.py + +# Specify short job name +python .ci/run.py --job gpu + +# Full job name (backward compatible) +python .ci/run.py --job nvidia_gpu + +# Run only the test stage, preview mode +python .ci/run.py --job gpu --stage test --dry-run + +# Test local uncommitted changes without pushing +python .ci/run.py --local +``` + +Container execution flow: `git clone` → `checkout` → `setup` → stages. +With `--local`, the current directory is mounted read-only at `/workspace/repo` and copied to a writable temp directory inside the container before setup runs — host files are never modified. +Proxy vars are forwarded from the host. Test results are written to `--results-dir`. Each run uses a clean environment (no host pip cache mounted). 
+ +--- + +## Platform differences + +| Platform | GPU passthrough | `gpu_style` | Base image | Detection tool | +|---|---|---|---|---| +| NVIDIA | `--gpus` (NVIDIA Container Toolkit) | `nvidia` (default) | `nvcr.io/nvidia/pytorch:24.10-py3` | `nvidia-smi` | +| Iluvatar | `--privileged` + `/dev` mount | `none` | `corex:qs_pj20250825` | `ixsmi` | +| MetaX | `--privileged` | `none` | `maca-pytorch:3.2.1.4-...` | `mx-smi` | +| Moore | `--privileged` | `none` | `vllm_musa:20251112_hygon` | `mthreads-gmi` | +| Cambricon | `--privileged` | `mlu` | `cambricon/pytorch:v1.25.3` | `cnmon` | +| Ascend | TODO | — | `ascend-pytorch:24.0.0` | — | + +`gpu_style` controls the Docker device injection mechanism: `nvidia` uses `--gpus`, `none` uses `CUDA_VISIBLE_DEVICES` (or skips injection for Moore), `mlu` uses `MLU_VISIBLE_DEVICES`. + +--- + +## Runner Agent `agent.py` + +The Runner Agent supports CLI manual dispatch, GitHub webhook triggers, resource-aware dynamic scheduling, and cross-machine remote dispatch. + +### CLI manual execution + +```bash +# Run all jobs (dispatched to remote agents, using config default branch) +python .ci/agent.py run + +# Specify branch +python .ci/agent.py run --branch feat/xxx + +# Run a specific job +python .ci/agent.py run --job nvidia_gpu + +# Filter by platform +python .ci/agent.py run --platform nvidia + +# Preview mode +python .ci/agent.py run --dry-run +``` + +| Flag | Description | +|---|---| +| `--branch` | Test branch (default: config `repo.branch`) | +| `--job` | Specific job name | +| `--platform` | Filter jobs by platform | +| `--commit` | Override commit SHA used for GitHub status reporting | +| `--image-tag` | Override image tag | +| `--dry-run` | Preview mode | + +### Webhook server + +Deploy one Agent instance per platform machine (platform is auto-detected). 
On each machine: + +```bash +python .ci/agent.py serve --port 8080 +``` + +Additional `serve` flags: + +| Flag | Description | +|---|---| +| `--port` | Listen port (default: 8080) | +| `--host` | Listen address (default: `0.0.0.0`) | +| `--webhook-secret` | GitHub webhook signing secret (or `WEBHOOK_SECRET` env var) | +| `--api-token` | `/api/run` Bearer auth token (or `AGENT_API_TOKEN` env var) | +| `--results-dir` | Results directory (default: `ci-results`) | +| `--utilization-threshold` | GPU idle threshold percentage (default: 10) | + +| Endpoint | Method | Description | +|---|---|---| +| `/webhook` | POST | GitHub webhook (push/pull_request) | +| `/api/run` | POST | Remote job trigger | +| `/api/job/{id}` | GET | Query job status | +| `/health` | GET | Health check | +| `/status` | GET | Queue + resource status | + +Webhook supports `X-Hub-Signature-256` signature verification via `--webhook-secret` or `WEBHOOK_SECRET` env var. + +### Remote agent configuration + +Configure agent URLs in `config.yaml`; the CLI automatically dispatches remote jobs to the corresponding agents: + +```yaml +agents: + nvidia: + url: http://:8080 + iluvatar: + url: http://:8080 + metax: + url: http://:8080 + moore: + url: http://:8080 +``` + +### Resource scheduling + +The Agent auto-detects GPU utilization and system memory to dynamically determine parallelism: +- GPU utilization < threshold (default 10%) and not allocated by Agent → available +- When resources are insufficient, jobs are queued automatically; completed jobs release resources and trigger scheduling of queued tasks + +### GitHub Status + +Set the `GITHUB_TOKEN` env var and the Agent will automatically report commit status: +- `pending` — job started +- `success` / `failure` — job completed + +Status context format: `ci/infiniops/{job_name}` + +--- + +## Multi-machine deployment guide + +### Per-platform setup + +Each machine needs Docker installed, the platform runtime, and the base CI image built. 
+ +| Platform | Runtime check | Base image | Build command | +|---|---|---|---| +| NVIDIA | `nvidia-smi` (+ [Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)) | `nvcr.io/nvidia/pytorch:24.10-py3` (public) | `python .ci/build.py --platform nvidia` | +| Iluvatar | `ixsmi` | `corex:qs_pj20250825` (import in advance) | `python .ci/build.py --platform iluvatar` | +| MetaX | `mx-smi` | `maca-pytorch:3.2.1.4-...` (import in advance) | `python .ci/build.py --platform metax` | +| Moore | `mthreads-gmi` | `vllm_musa:20251112_hygon` (import in advance) | `python .ci/build.py --platform moore` | + +### Start Agent services + +On each machine (platform is auto-detected): + +```bash +python .ci/agent.py serve --port 8080 +``` + +### Configure remote agent URLs + +On the trigger machine, add the `agents` section to `config.yaml` (see [Remote agent configuration](#remote-agent-configuration) above for the format). + +### Trigger cross-platform tests + +```bash +# Run all platform jobs at once (using config default branch) +python .ci/agent.py run + +# Preview mode (no actual execution) +python .ci/agent.py run --dry-run + +# Run only a specific platform +python .ci/agent.py run --platform nvidia +``` + +### Optional configuration + +#### GitHub Status reporting + +Set the env var on all machines so each reports its own platform's test status: + +```bash +export GITHUB_TOKEN=ghp_xxxxxxxxxxxx +``` + +#### API Token authentication + +When agents are exposed on untrusted networks, enable token auth: + +```bash +python .ci/agent.py serve --port 8080 --api-token +# Or: export AGENT_API_TOKEN= +``` + +#### GitHub Webhook auto-trigger + +In GitHub repo → Settings → Webhooks, add a webhook for each machine: + +| Field | Value | +|---|---| +| Payload URL | `http://:8080/webhook` | +| Content type | `application/json` | +| Secret | Must match `--webhook-secret` | +| Events | `push` and `pull_request` | + +```bash +python .ci/agent.py 
serve --port 8080 --webhook-secret +# Or: export WEBHOOK_SECRET= +``` + +### Verification checklist + +```bash +# 1. Dry-run each machine individually +for platform in nvidia iluvatar metax moore; do + python .ci/agent.py run --platform $platform --dry-run +done + +# 2. Health and resource checks +for ip in ; do + curl http://$ip:8080/health + curl http://$ip:8080/status +done + +# 3. Cross-platform test +python .ci/agent.py run --branch master +``` diff --git a/.ci/agent.py b/.ci/agent.py new file mode 100644 index 0000000..3fb5d9e --- /dev/null +++ b/.ci/agent.py @@ -0,0 +1,988 @@ +#!/usr/bin/env python3 +"""CI Runner Agent: webhook server, resource-aware scheduler, GitHub status reporting. + +Usage: + # Run jobs locally (or dispatch to remote agents) + python .ci/agent.py run + python .ci/agent.py run --branch master --job nvidia_gpu --dry-run + + # Start webhook server (auto-detects platform) + python .ci/agent.py serve --port 8080 +""" + +import argparse +import collections +import hashlib +import hmac +import json +import os +import shlex +import subprocess +import sys +import threading +import time +import urllib.error +import urllib.request +import uuid +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from http.server import BaseHTTPRequestHandler, HTTPServer +from pathlib import Path + +import ci_resource as res +import github_status as gh +import run + +# Maximum POST body size (1 MB) to prevent memory exhaustion +MAX_CONTENT_LENGTH = 1 * 1024 * 1024 + +# Job states +STATE_QUEUED = "queued" +STATE_RUNNING = "running" +STATE_PENDING = "pending" +STATE_SUCCESS = "success" +STATE_FAILURE = "failure" +STATE_ERROR = "error" + +TAIL_LINES = 50 + +# urllib helpers (module-level for easier mocking in tests) +urllib_request = urllib.request.Request +urllib_urlopen = urllib.request.urlopen + + +# --------------------------------------------------------------------------- +# Data classes +# 
--------------------------------------------------------------------------- + + +class JobRequest: + """Describes a CI job to be executed.""" + + def __init__( + self, job_name, branch, commit_sha, config, image_tag=None, results_dir=None + ): + self.job_id = str(uuid.uuid4())[:8] + self.job_name = job_name + self.branch = branch + self.commit_sha = commit_sha + self.config = config + self.image_tag = image_tag + self.results_dir = results_dir or Path("ci-results") + self.created_at = datetime.now().isoformat() + + job = config["jobs"][job_name] + self.platform = job.get("platform", "nvidia") + + def to_dict(self): + return { + "job_id": self.job_id, + "job_name": self.job_name, + "branch": self.branch, + "commit_sha": self.commit_sha, + "platform": self.platform, + "created_at": self.created_at, + } + + +class JobResult: + """Outcome of a completed job.""" + + def __init__( + self, + job_id, + job_name, + commit_sha, + returncode, + results_dir, + duration, + error_tail=None, + ): + self.job_id = job_id + self.job_name = job_name + self.commit_sha = commit_sha + self.returncode = returncode + self.results_dir = results_dir + self.duration = duration + self.error_tail = error_tail or [] + + self.state = STATE_SUCCESS if returncode == 0 else STATE_FAILURE + + def to_dict(self): + d = { + "job_id": self.job_id, + "job_name": self.job_name, + "commit_sha": self.commit_sha, + "state": self.state, + "returncode": self.returncode, + "results_dir": str(self.results_dir), + "duration_seconds": round(self.duration, 1), + } + + if self.error_tail: + d["error_tail"] = self.error_tail + + return d + + +# --------------------------------------------------------------------------- +# Job selection and routing +# --------------------------------------------------------------------------- + + +def select_jobs(config, platform=None, job_name=None): + """Return list of job names to run.""" + jobs = config.get("jobs", {}) + + if job_name: + if job_name not in jobs: + raise 
ValueError(f"job {job_name!r} not in config") + + return [job_name] + + if platform: + return [name for name, job in jobs.items() if job.get("platform") == platform] + + return list(jobs.keys()) + + +# --------------------------------------------------------------------------- +# Scheduler +# --------------------------------------------------------------------------- + + +class Scheduler: + """Resource-aware job scheduler with dynamic parallelism.""" + + def __init__( + self, + config, + platform, + resource_pool, + results_dir=None, + max_workers=4, + no_status=False, + dry_run=False, + ): + self._config = config + self._platform = platform + self._resource_pool = resource_pool + self._results_dir = results_dir or Path("ci-results") + self._no_status = no_status + self._dry_run = dry_run + self._queue = collections.deque() + self._jobs: dict[str, dict] = {} # job_id -> {request, result, state, gpu_ids} + self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._lock = threading.Lock() + self._done_event = threading.Event() + + # GitHub config + github_cfg = config.get("github", {}) + self._status_prefix = github_cfg.get("status_context_prefix", "ci/infiniops") + repo = config.get("repo", {}) + repo_url = repo.get("url", "") + self._owner, self._repo = gh.parse_repo_url(repo_url) + + def submit(self, job_request): + """Add a job to the queue and attempt to schedule it. + + Returns the job_id. 
+ """ + with self._lock: + self._jobs[job_request.job_id] = { + "request": job_request, + "result": None, + "state": STATE_QUEUED, + "gpu_ids": [], + } + self._queue.append(job_request) + + self._try_schedule() + return job_request.job_id + + def get_job(self, job_id): + """Get job info by ID.""" + with self._lock: + entry = self._jobs.get(job_id) + + if not entry: + return None + + info = entry["request"].to_dict() + info["state"] = entry["state"] + + if entry["result"]: + info.update(entry["result"].to_dict()) + + return info + + def get_status(self): + """Return scheduler status for the /status endpoint.""" + with self._lock: + queued = [self._jobs[r.job_id]["request"].to_dict() for r in self._queue] + running = [] + completed = [] + + for entry in self._jobs.values(): + state = entry["state"] + + if state == STATE_RUNNING: + running.append( + {**entry["request"].to_dict(), "gpu_ids": entry["gpu_ids"]} + ) + elif state in (STATE_SUCCESS, STATE_FAILURE): + completed.append(entry["result"].to_dict()) + + return { + "queued": queued, + "running": running, + "completed": completed[-20:], # Last 20 + "resources": self._resource_pool.get_status(), + } + + def wait_all(self): + """Block until all submitted jobs are done. Returns list of JobResult.""" + while True: + with self._lock: + pending = any( + e["state"] in (STATE_QUEUED, STATE_RUNNING) + for e in self._jobs.values() + ) + + if not pending: + break + + self._done_event.wait(timeout=2.0) + self._done_event.clear() + + with self._lock: + return [e["result"] for e in self._jobs.values() if e["result"] is not None] + + def _try_schedule(self): + """Try to run queued jobs that have enough resources. + + Resource allocation and job submission are split: allocation decisions + are made under the lock, but executor.submit() happens outside to + prevent deadlock when the thread pool is saturated. + """ + to_launch = [] # [(req, gpu_ids), ...] 
+ + with self._lock: + remaining = collections.deque() + + while self._queue: + req = self._queue.popleft() + job_cfg = self._config["jobs"].get(req.job_name, {}) + gpu_count = res.parse_gpu_requirement(job_cfg) + memory_mb = res.parse_memory_requirement(job_cfg) + + if self._dry_run: + # In dry-run mode, skip resource checks + gpu_ids, ok = [], True + else: + gpu_ids, ok = self._resource_pool.allocate(gpu_count, memory_mb) + + if ok: + self._jobs[req.job_id]["state"] = STATE_RUNNING + self._jobs[req.job_id]["gpu_ids"] = gpu_ids + to_launch.append((req, gpu_ids)) + else: + remaining.append(req) + + self._queue = remaining + + # Submit outside the lock to avoid deadlock with ThreadPoolExecutor + for req, gpu_ids in to_launch: + self._executor.submit(self._run_job, req, gpu_ids) + + def _run_job(self, req, gpu_ids): + """Execute a single job in a worker thread. + + Wrapped in try/finally to guarantee GPU resources are always released + and job state is updated even on unexpected exceptions. 
+ """ + context = gh.build_status_context(self._status_prefix, req.job_name) + result = None + + try: + # Post pending status + if not self._no_status: + gh.post_commit_status( + self._owner, + self._repo, + req.commit_sha, + STATE_PENDING, + context, + f"Running {req.job_name}...", + ) + + job_cfg = self._config["jobs"][req.job_name] + all_stages = job_cfg.get("stages", []) + repo_url = self._config.get("repo", {}).get("url", "") + commit_short = ( + req.commit_sha[:7] if len(req.commit_sha) > 7 else req.commit_sha + ) + results_dir = run.build_results_dir( + req.results_dir, req.platform, all_stages, commit_short + ) + + gpu_id_str = ",".join(str(g) for g in gpu_ids) if gpu_ids else None + docker_args = run.build_docker_args( + self._config, + req.job_name, + repo_url, + req.branch, + all_stages, + "/workspace", + req.image_tag, + gpu_id_override=gpu_id_str, + results_dir=results_dir, + ) + + start = time.monotonic() + + if self._dry_run: + print(f"[dry-run] {req.job_name}: {shlex.join(docker_args)}") + returncode = 0 + error_tail = [] + else: + results_dir.mkdir(parents=True, exist_ok=True) + proc = subprocess.Popen( + docker_args, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + tail_buf = collections.deque(maxlen=TAIL_LINES) + + for line in proc.stdout: + sys.stdout.buffer.write(line) + tail_buf.append(line) + + proc.stdout.close() + returncode = proc.wait() + + if returncode != 0: + error_tail = [ + raw.decode("utf-8", errors="replace").rstrip("\n") + for raw in tail_buf + ] + else: + error_tail = [] + + duration = time.monotonic() - start + + result = JobResult( + job_id=req.job_id, + job_name=req.job_name, + commit_sha=req.commit_sha, + returncode=returncode, + results_dir=results_dir, + duration=duration, + error_tail=error_tail, + ) + + # Post final status + if not self._no_status: + gh.post_commit_status( + self._owner, + self._repo, + req.commit_sha, + result.state, + context, + f"{req.job_name}: {result.state} in {duration:.0f}s", + ) + 
except Exception as e: + print( + f"error: job {req.job_name} failed with exception: {e}", file=sys.stderr + ) + + if result is None: + result = JobResult( + job_id=req.job_id, + job_name=req.job_name, + commit_sha=req.commit_sha, + returncode=-1, + results_dir=req.results_dir, + duration=0, + error_tail=[str(e)], + ) + + if not self._no_status: + gh.post_commit_status( + self._owner, + self._repo, + req.commit_sha, + STATE_ERROR, + context, + f"{req.job_name}: internal error", + ) + finally: + # Always release resources and update state + self._resource_pool.release(gpu_ids) + + with self._lock: + self._jobs[req.job_id]["result"] = result + self._jobs[req.job_id]["state"] = ( + result.state if result else STATE_FAILURE + ) + + self._done_event.set() + self._try_schedule() + + return result + + +# --------------------------------------------------------------------------- +# Webhook server +# --------------------------------------------------------------------------- + + +def verify_signature(secret, body, signature_header): + """Verify GitHub webhook HMAC-SHA256 signature.""" + if not signature_header: + return False + + expected = ( + "sha256=" + hmac.new(secret.encode("utf-8"), body, hashlib.sha256).hexdigest() + ) + return hmac.compare_digest(expected, signature_header) + + +def _verify_api_token(handler): + """Check Bearer token for /api/run authentication. + + Returns True if authenticated, False (and sends 401) if not. + When no api_token is configured on the server, all requests are allowed. 
+ """ + api_token = getattr(handler.server, "api_token", None) + + if not api_token: + return True + + auth_header = handler.headers.get("Authorization", "") + + if auth_header == f"Bearer {api_token}": + return True + + handler._respond_json(401, {"error": "unauthorized"}) + return False + + +class WebhookHandler(BaseHTTPRequestHandler): + """HTTP handler for GitHub webhooks and API endpoints.""" + + def log_message(self, format, *args): + print(f"[agent] {args[0]}", file=sys.stderr) + + def do_GET(self): + if self.path == "/health": + self._respond_json(200, {"status": "ok", "platform": self.server.platform}) + elif self.path == "/status": + status = self.server.scheduler.get_status() + self._respond_json(200, status) + elif self.path.startswith("/api/job/"): + self._handle_api_job() + else: + self._respond_json(404, {"error": "not found"}) + + def do_POST(self): + content_length = int(self.headers.get("Content-Length", 0)) + + if content_length > MAX_CONTENT_LENGTH: + self._respond_json(413, {"error": "payload too large"}) + return + + body = self.rfile.read(content_length) + + if self.path == "/webhook": + self._handle_webhook(body) + elif self.path == "/api/run": + self._handle_api_run(body) + else: + self._respond_json(404, {"error": "not found"}) + + def _handle_webhook(self, body): + # Verify signature if secret is configured + if self.server.webhook_secret: + sig = self.headers.get("X-Hub-Signature-256", "") + + if not verify_signature(self.server.webhook_secret, body, sig): + self._respond_json(401, {"error": "invalid signature"}) + return + + event_type = self.headers.get("X-GitHub-Event", "") + + if event_type == "ping": + self._respond_json(200, {"msg": "pong"}) + return + + try: + payload = json.loads(body) + except json.JSONDecodeError: + self._respond_json(400, {"error": "invalid JSON"}) + return + + if event_type == "push": + branch, sha = self._parse_push(payload) + elif event_type == "pull_request": + action = payload.get("action", "") + + if 
action not in ("opened", "synchronize"): + self._respond_json(200, {"msg": f"ignored PR action: {action}"}) + return + + branch, sha = self._parse_pull_request(payload) + else: + self._respond_json(200, {"msg": f"ignored event: {event_type}"}) + return + + if not branch or not sha: + self._respond_json(400, {"error": "could not extract branch/sha"}) + return + + job_ids = self._submit_jobs(branch, sha) + self._respond_json(200, {"accepted": True, "job_ids": job_ids}) + + def _handle_api_run(self, body): + """Handle /api/run: remote job trigger (requires Bearer token auth).""" + if not _verify_api_token(self): + return + + try: + payload = json.loads(body) + except json.JSONDecodeError: + self._respond_json(400, {"error": "invalid JSON"}) + return + + branch = payload.get("branch", "") + sha = payload.get("commit_sha", "") + job_name = payload.get("job") + image_tag = payload.get("image_tag") + + if not branch: + self._respond_json(400, {"error": "branch is required"}) + return + + if not sha: + sha = run.get_git_commit() + + job_ids = self._submit_jobs(branch, sha, job_name=job_name, image_tag=image_tag) + self._respond_json(200, {"accepted": True, "job_ids": job_ids}) + + def _handle_api_job(self): + """Handle GET /api/job/{id}.""" + parts = self.path.split("/") + + if len(parts) < 4: + self._respond_json(400, {"error": "missing job_id"}) + return + + job_id = parts[3] + info = self.server.scheduler.get_job(job_id) + + if info is None: + self._respond_json(404, {"error": f"job {job_id} not found"}) + else: + self._respond_json(200, info) + + def _parse_push(self, payload): + branch = payload.get("ref", "").removeprefix("refs/heads/") + sha = payload.get("after", "") + return branch, sha + + def _parse_pull_request(self, payload): + pr = payload.get("pull_request", {}) + head = pr.get("head", {}) + branch = head.get("ref", "") + sha = head.get("sha", "") + return branch, sha + + def _submit_jobs(self, branch, sha, job_name=None, image_tag=None): + config = 
self.server.config + job_names = select_jobs( + config, platform=self.server.platform, job_name=job_name + ) + job_ids = [] + + for name in job_names: + req = JobRequest( + job_name=name, + branch=branch, + commit_sha=sha, + config=config, + image_tag=image_tag, + results_dir=self.server.results_dir, + ) + jid = self.server.scheduler.submit(req) + job_ids.append(jid) + + return job_ids + + def _respond_json(self, status_code, data): + body = json.dumps(data, indent=2).encode("utf-8") + self.send_response(status_code) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + +class AgentServer(HTTPServer): + """HTTP server with scheduler and config context.""" + + def __init__( + self, + host, + port, + config, + scheduler, + platform, + webhook_secret=None, + api_token=None, + results_dir=None, + ): + super().__init__((host, port), WebhookHandler) + self.config = config + self.scheduler = scheduler + self.platform = platform + self.webhook_secret = webhook_secret + self.api_token = api_token + self.results_dir = results_dir or Path("ci-results") + + +# --------------------------------------------------------------------------- +# Remote job dispatch (for CLI triggering remote agents) +# --------------------------------------------------------------------------- + + +def dispatch_remote_job( + agent_url, job_name, branch, commit_sha, image_tag=None, api_token=None +): + """Send a job to a remote agent via HTTP API. 
Returns job_id or None.""" + url = f"{agent_url.rstrip('/')}/api/run" + body = { + "branch": branch, + "commit_sha": commit_sha, + "job": job_name, + } + + if image_tag: + body["image_tag"] = image_tag + + data = json.dumps(body).encode("utf-8") + headers = {"Content-Type": "application/json"} + + if api_token: + headers["Authorization"] = f"Bearer {api_token}" + + req = urllib_request(url, data=data, headers=headers, method="POST") + + try: + with urllib_urlopen(req, timeout=30) as resp: + result = json.loads(resp.read()) + job_ids = result.get("job_ids", []) + return job_ids[0] if job_ids else None + except Exception as e: + print(f"error: failed to dispatch to {agent_url}: {e}", file=sys.stderr) + return None + + +def poll_remote_job(agent_url, job_id, interval=5.0, timeout=7200): + """Poll a remote agent for job completion. Returns final state dict or None.""" + url = f"{agent_url.rstrip('/')}/api/job/{job_id}" + deadline = time.monotonic() + timeout + + while time.monotonic() < deadline: + try: + req = urllib_request(url) + + with urllib_urlopen(req, timeout=10) as resp: + info = json.loads(resp.read()) + + state = info.get("state", "") + + if state in (STATE_SUCCESS, STATE_FAILURE): + return info + except Exception: + pass + + time.sleep(interval) + + return None + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def cmd_run(args): + """Handle 'run' subcommand: dispatch jobs to platform agents via HTTP.""" + config = run.load_config(args.config) + agents = config.get("agents", {}) + branch = args.branch or config.get("repo", {}).get("branch", "master") + commit_sha = args.commit or run.get_git_commit(short=False) + + # Determine which jobs to run + try: + job_names = select_jobs(config, platform=args.platform, job_name=args.job) + except ValueError as e: + print(f"error: {e}", file=sys.stderr) + sys.exit(1) + + if not job_names: + 
print("error: no matching jobs found", file=sys.stderr) + sys.exit(1) + + # Resolve agent URL for each job + jobs_to_dispatch = [] # [(name, agent_url)] + + for name in job_names: + job = config.get("jobs", {}).get(name, {}) + platform = job.get("platform", "") + agent_url = agents.get(platform, {}).get("url", "") + + if not agent_url: + print( + f"error: no agent URL configured for platform {platform!r} (job {name})", + file=sys.stderr, + ) + sys.exit(1) + + jobs_to_dispatch.append((name, agent_url)) + + api_token = os.environ.get("AGENT_API_TOKEN", "") + results = [] + + if args.dry_run: + for name, agent_url in jobs_to_dispatch: + platform, _, job = name.partition("_") + print(f"[dry-run] dispatch {platform} {job} job to {agent_url}") + else: + # Dispatch all jobs, then poll concurrently. + dispatched = [] # [(name, agent_url, job_id)] + + for name, agent_url in jobs_to_dispatch: + platform, _, job = name.partition("_") + print( + f"==> dispatching {platform} {job} job to {agent_url}", + file=sys.stderr, + ) + job_id = dispatch_remote_job( + agent_url, + name, + branch, + commit_sha, + args.image_tag, + api_token=api_token or None, + ) + + if job_id: + print(f" job_id: {job_id}", file=sys.stderr) + dispatched.append((name, agent_url, job_id)) + else: + print(f" failed to dispatch {name}", file=sys.stderr) + results.append({"job_name": name, "state": "error"}) + + if dispatched: + with ThreadPoolExecutor(max_workers=len(dispatched)) as executor: + futures = { + executor.submit(poll_remote_job, url, jid): (name, url, jid) + for name, url, jid in dispatched + } + + # Collect name lengths for column alignment. 
+ name_width = max(len(n) for n, _, _ in dispatched) + + for future in as_completed(futures): + name, _, _ = futures[future] + result = future.result() + + if result: + state = result.get("state", "unknown") + duration = result.get("duration_seconds", 0) + tag = "PASS" if state == STATE_SUCCESS else "FAIL" + print( + f"<== {tag} {name:<{name_width}} ({duration:.0f}s)", + file=sys.stderr, + ) + + error_tail = result.get("error_tail", []) + + if error_tail: + print( + f"--- error output (last {len(error_tail)} lines) ---", + file=sys.stderr, + ) + + for line in error_tail: + print(f" {line}", file=sys.stderr) + + print("---", file=sys.stderr) + + results.append(result) + else: + print( + f"<== TIMEOUT {name:<{name_width}}", + file=sys.stderr, + ) + results.append({"job_name": name, "state": "timeout"}) + + # Summary: only print when there are failures. + failed = [r for r in results if r.get("state") != STATE_SUCCESS] + + if failed: + print("\n========== Failed ==========", file=sys.stderr) + name_width = max(len(r.get("job_name", "?")) for r in failed) + + for r in failed: + name = r.get("job_name", "?") + state = r.get("state", "unknown") + duration = r.get("duration_seconds", 0) + print( + f" FAIL {name:<{name_width}} {state} ({duration:.0f}s)", + file=sys.stderr, + ) + + sys.exit(1) + + +def cmd_serve(args): + """Handle 'serve' subcommand: start webhook server.""" + config = run.load_config(args.config) + + platform = res.detect_platform() + + if not platform: + print( + "error: could not detect platform (no nvidia-smi or ixsmi found)", + file=sys.stderr, + ) + sys.exit(1) + + platform_jobs = select_jobs(config, platform=platform) + + if not platform_jobs: + print( + f"error: platform {platform!r} detected but no jobs defined in config", + file=sys.stderr, + ) + sys.exit(1) + + pool = res.ResourcePool( + platform, + utilization_threshold=args.utilization_threshold, + ) + scheduler = Scheduler( + config, + platform, + pool, + results_dir=args.results_dir, + ) + + 
webhook_secret = args.webhook_secret or os.environ.get("WEBHOOK_SECRET", "") + api_token = args.api_token or os.environ.get("AGENT_API_TOKEN", "") + + if not webhook_secret: + print( + "WARNING: No webhook secret configured. Webhook endpoint accepts " + "unsigned requests. Set --webhook-secret or WEBHOOK_SECRET for production.", + file=sys.stderr, + ) + + if not api_token: + print( + "WARNING: No API token configured. /api/run endpoint is unauthenticated. " + "Set --api-token or AGENT_API_TOKEN for production.", + file=sys.stderr, + ) + + server = AgentServer( + args.host, + args.port, + config, + scheduler, + platform, + webhook_secret=webhook_secret or None, + api_token=api_token or None, + results_dir=args.results_dir, + ) + + print( + f"Agent serving on {args.host}:{args.port} (platform={platform})", + file=sys.stderr, + ) + print(" POST /webhook — GitHub webhook", file=sys.stderr) + print(" POST /api/run — remote job trigger", file=sys.stderr) + print(" GET /health — health check", file=sys.stderr) + print(" GET /status — queue & resource status", file=sys.stderr) + print(" GET /api/job/{id} — job status", file=sys.stderr) + + try: + server.serve_forever() + except KeyboardInterrupt: + print("\nShutting down...", file=sys.stderr) + server.shutdown() + + +def main(): + parser = argparse.ArgumentParser( + description="CI Runner Agent: run jobs locally, dispatch remotely, or serve webhooks", + ) + subparsers = parser.add_subparsers(dest="command") + + # --- run subcommand --- + run_parser = subparsers.add_parser("run", help="Run CI jobs") + run_parser.add_argument( + "--config", + type=Path, + default=Path(__file__).resolve().parent / "config.yaml", + ) + run_parser.add_argument( + "--branch", type=str, help="Branch to test (default: config repo.branch)" + ) + run_parser.add_argument("--job", type=str, help="Specific job name") + run_parser.add_argument("--platform", type=str, help="Filter jobs by platform") + run_parser.add_argument("--image-tag", type=str, 
help="Override image tag") + run_parser.add_argument("--commit", type=str, help="Override commit SHA") + run_parser.add_argument("--dry-run", action="store_true") + + # --- serve subcommand --- + serve_parser = subparsers.add_parser("serve", help="Start webhook server") + serve_parser.add_argument( + "--config", + type=Path, + default=Path(__file__).resolve().parent / "config.yaml", + ) + serve_parser.add_argument("--port", type=int, default=8080) + serve_parser.add_argument("--host", type=str, default="0.0.0.0") + serve_parser.add_argument("--webhook-secret", type=str) + serve_parser.add_argument( + "--api-token", + type=str, + help="Bearer token for /api/run authentication (or AGENT_API_TOKEN env var)", + ) + serve_parser.add_argument( + "--results-dir", + type=Path, + default=Path("ci-results"), + ) + serve_parser.add_argument( + "--utilization-threshold", + type=int, + default=10, + ) + + args = parser.parse_args() + + if args.command == "run": + cmd_run(args) + elif args.command == "serve": + cmd_serve(args) + else: + parser.print_help() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/.ci/build.py b/.ci/build.py new file mode 100644 index 0000000..7953209 --- /dev/null +++ b/.ci/build.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 +"""CI image builder: detect changes, build, tag, and optionally push Docker images.""" + +import argparse +import json +import os +import shlex +import subprocess +import sys +from pathlib import Path + +from utils import get_git_commit, load_config + + +def has_dockerfile_changed(dockerfile_dir, base_ref="HEAD~1"): + """Check if any file under `dockerfile_dir` changed since `base_ref`.""" + result = subprocess.run( + ["git", "diff", "--name-only", base_ref, "--", dockerfile_dir], + capture_output=True, + text=True, + ) + + if result.returncode != 0: + print( + "warning: git diff failed (shallow clone or initial commit?);" + " assuming Dockerfile changed", + file=sys.stderr, + ) + return True + + return 
bool(result.stdout.strip()) + + +def docker_login(registry_cfg, dry_run): + """Log in to the registry using `credentials_env` token. + + Returns True on success. + + NOTE: Registry support is currently unused (`config.yaml` has no registry + section). Retained for future integration with an external image management + system. + """ + credentials_env = registry_cfg.get("credentials_env") + registry_url = registry_cfg.get("url", "") + + if not credentials_env or not registry_url: + return True + + token = os.environ.get(credentials_env) + + if not token: + print( + f"error: {credentials_env} not set, cannot login", + file=sys.stderr, + ) + return False + + if dry_run: + print( + f"[dry-run] echo | docker login {registry_url}" + " --username token --password-stdin" + ) + return True + + result = subprocess.run( + ["docker", "login", registry_url, "--username", "token", "--password-stdin"], + input=token, + text=True, + ) + + if result.returncode != 0: + print("error: docker login failed", file=sys.stderr) + return False + + return True + + +def build_image_tag(registry_url, project, platform, tag): + if registry_url: + return f"{registry_url}/{project}/{platform}:{tag}" + + return f"{project}-ci/{platform}:{tag}" + + +def build_image(platform, platform_cfg, registry_cfg, commit, push, dry_run, logged_in): + """Build a single platform image. 
Returns True on success.""" + registry_url = registry_cfg.get("url", "") + project = registry_cfg.get("project", "infiniops") + dockerfile_dir = platform_cfg["dockerfile"] + commit_tag = build_image_tag(registry_url, project, platform, commit) + latest_tag = build_image_tag(registry_url, project, platform, "latest") + + build_args_cfg = platform_cfg.get("build_args", {}) + build_cmd = ["docker", "build", "--network", "host"] + + for key, value in build_args_cfg.items(): + build_cmd.extend(["--build-arg", f"{key}={value}"]) + + for proxy_var in ("HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"): + proxy_val = os.environ.get(proxy_var) or os.environ.get(proxy_var.lower()) + + if proxy_val: + build_cmd.extend(["--build-arg", f"{proxy_var}={proxy_val}"]) + build_cmd.extend(["--build-arg", f"{proxy_var.lower()}={proxy_val}"]) + + private_sdk = platform_cfg.get("private_sdk", {}) + + if private_sdk: + source_env = private_sdk.get("source_env", "") + sdk_url = os.environ.get(source_env, "") if source_env else "" + + if sdk_url: + build_cmd.extend(["--build-arg", f"PRIVATE_SDK_URL={sdk_url}"]) + + build_cmd.extend(["-t", commit_tag, "-t", latest_tag, dockerfile_dir]) + + if dry_run: + print(f"[dry-run] {shlex.join(build_cmd)}") + + if push: + if not logged_in: + print("[dry-run] (skipping push: docker login failed)") + else: + print(f"[dry-run] docker push {commit_tag}") + print(f"[dry-run] docker push {latest_tag}") + + return True + + print(f"==> building {platform}: {commit_tag}", file=sys.stderr) + result = subprocess.run(build_cmd) + + if result.returncode != 0: + error = { + "stage": "build", + "platform": platform, + "tag": commit_tag, + "exit_code": result.returncode, + } + print(json.dumps(error), file=sys.stderr) + + return False + + if push: + if not logged_in: + print("error: docker login failed, cannot push", file=sys.stderr) + return False + + for tag in (commit_tag, latest_tag): + print(f"==> pushing {tag}", file=sys.stderr) + push_result = subprocess.run(["docker", 
"push", tag]) + + if push_result.returncode != 0: + error = { + "stage": "push", + "platform": platform, + "tag": tag, + "exit_code": push_result.returncode, + } + print(json.dumps(error), file=sys.stderr) + + return False + + return True + + +def main(): + parser = argparse.ArgumentParser(description="Build CI Docker images") + parser.add_argument( + "--platform", + type=str, + default="all", + help="Platform to build: nvidia, ascend, or all (default: all)", + ) + parser.add_argument( + "--config", + type=Path, + default=Path(__file__).resolve().parent / "config.yaml", + help="Path to config.yaml", + ) + parser.add_argument( + "--commit", + type=str, + default="HEAD", + help="Git ref for tagging the image (default: HEAD)", + ) + parser.add_argument( + "--push", + action="store_true", + help="Push images to registry after building (requires registry in config)", + ) + parser.add_argument( + "--force", + action="store_true", + help="Skip change detection and force build", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print commands without executing", + ) + args = parser.parse_args() + + config = load_config(args.config) + registry_cfg = config.get("registry", {}) + images_cfg = config.get("images", {}) + + if not images_cfg: + print("error: no `images` section in config", file=sys.stderr) + sys.exit(1) + + if args.platform == "all": + platforms = list(images_cfg.keys()) + else: + if args.platform not in images_cfg: + print( + f"error: platform `{args.platform}` not found in config", + file=sys.stderr, + ) + sys.exit(1) + platforms = [args.platform] + + commit = get_git_commit(args.commit) + logged_in = docker_login(registry_cfg, args.dry_run) if args.push else True + failed = False + + for platform in platforms: + platform_cfg = images_cfg[platform] + dockerfile_dir = platform_cfg["dockerfile"] + + if not Path(dockerfile_dir).is_dir(): + print( + f"warning: dockerfile directory `{dockerfile_dir}` does not exist," + f" skipping 
{platform}", + file=sys.stderr, + ) + continue + + if not args.force and not has_dockerfile_changed(dockerfile_dir): + print(f"==> {platform}: no changes detected, skipping", file=sys.stderr) + continue + + ok = build_image( + platform, + platform_cfg, + registry_cfg, + commit, + args.push, + args.dry_run, + logged_in=logged_in, + ) + + if not ok: + failed = True + + if failed: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/.ci/ci_resource.py b/.ci/ci_resource.py new file mode 100644 index 0000000..51b181f --- /dev/null +++ b/.ci/ci_resource.py @@ -0,0 +1,478 @@ +#!/usr/bin/env python3 +"""Resource detection and allocation for CI Runner Agent.""" + +import json +import operator +import os +import re +import shutil +import subprocess +import threading +from dataclasses import dataclass + +# GPU passthrough styles +GPU_STYLE_NVIDIA = "nvidia" +GPU_STYLE_NONE = "none" +GPU_STYLE_MLU = "mlu" + + +@dataclass +class GpuInfo: + index: int + memory_used_mb: float + memory_total_mb: float + utilization_pct: float + + +@dataclass +class SystemResources: + total_memory_mb: float + available_memory_mb: float + cpu_count: int + + +class ResourcePool: + """Thread-safe GPU and system resource manager. + + Detects available GPUs via platform-specific tools (nvidia-smi, ixsmi, mx-smi, mthreads-gmi) + and tracks allocations to enable dynamic parallel scheduling. 
+ """ + + GPU_QUERY_TOOLS = { + "nvidia": "nvidia-smi", + "iluvatar": "ixsmi", + "metax": "mx-smi", + "moore": "mthreads-gmi", + "cambricon": "cnmon", + } + + def __init__(self, platform, utilization_threshold=10): + self._platform = platform + self._utilization_threshold = utilization_threshold + self._allocated: set[int] = set() + self._lock = threading.Lock() + + @property + def platform(self): + return self._platform + + @property + def allocated(self): + with self._lock: + return set(self._allocated) + + def detect_gpus(self) -> list[GpuInfo]: + """Query GPU status via platform-specific CLI tool.""" + if self._platform == "metax": + return self._detect_gpus_metax() + + if self._platform == "moore": + return self._detect_gpus_moore() + + if self._platform == "cambricon": + return self._detect_gpus_cambricon() + + tool = self.GPU_QUERY_TOOLS.get(self._platform) + + if not tool: + return [] + + try: + result = subprocess.run( + [ + tool, + "--query-gpu=index,memory.used,memory.total,utilization.gpu", + "--format=csv,noheader,nounits", + ], + capture_output=True, + text=True, + timeout=10, + ) + except (FileNotFoundError, subprocess.TimeoutExpired): + return [] + + if result.returncode != 0: + return [] + + gpus = [] + + for line in result.stdout.strip().splitlines(): + parts = [p.strip() for p in line.split(",")] + + if len(parts) < 4: + continue + + try: + gpus.append( + GpuInfo( + index=int(parts[0]), + memory_used_mb=float(parts[1]), + memory_total_mb=float(parts[2]), + utilization_pct=float(parts[3]), + ) + ) + except (ValueError, IndexError): + continue + + return gpus + + def _detect_gpus_metax(self) -> list[GpuInfo]: + """Parse mx-smi output for MetaX GPUs. + + Runs --show-memory and --show-usage separately and merges results. 
+ Output format example: + GPU#0 MXC550 0000:1a:00.0 + Memory + vis_vram total : 67108864 KB + vis_vram used : 879032 KB + Utilization + GPU : 0 % + """ + + def run_mxsmi(flag): + try: + r = subprocess.run( + ["mx-smi", flag], + capture_output=True, + text=True, + timeout=10, + ) + return r.stdout if r.returncode == 0 else "" + except (FileNotFoundError, subprocess.TimeoutExpired): + return "" + + mem_out = run_mxsmi("--show-memory") + util_out = run_mxsmi("--show-usage") + + # Parse memory: collect {index: (used_kb, total_kb)} + mem = {} + current = None + for line in mem_out.splitlines(): + m = re.match(r"GPU#(\d+)", line.strip()) + if m: + current = int(m.group(1)) + mem[current] = [0.0, 0.0] + continue + if current is None: + continue + m = re.search(r"vis_vram total\s*:\s*([\d.]+)\s*KB", line) + if m: + mem[current][1] = float(m.group(1)) / 1024 # KB -> MB + m = re.search(r"vis_vram used\s*:\s*([\d.]+)\s*KB", line) + if m: + mem[current][0] = float(m.group(1)) / 1024 # KB -> MB + + # Parse utilization: collect {index: utilization_pct} + util = {} + current = None + in_util = False + for line in util_out.splitlines(): + m = re.match(r"GPU#(\d+)", line.strip()) + if m: + current = int(m.group(1)) + in_util = False + continue + if current is None: + continue + if "Utilization" in line: + in_util = True + continue + if in_util: + m = re.match(r"\s*GPU\s*:\s*([\d.]+)\s*%", line) + if m: + util[current] = float(m.group(1)) + in_util = False + + gpus = [] + for idx in sorted(mem): + used_mb, total_mb = mem[idx] + gpus.append( + GpuInfo( + index=idx, + memory_used_mb=used_mb, + memory_total_mb=total_mb, + utilization_pct=util.get(idx, 0.0), + ) + ) + return gpus + + def _detect_gpus_moore(self) -> list[GpuInfo]: + """Parse mthreads-gmi JSON output for Moore Threads GPUs. 
+ + Uses: mthreads-gmi -q --json + Expected JSON structure: + { + "Attached GPUs": { + "GPU 00000000:3B:00.0": { + "Minor Number": "0", + "Memory Usage": { + "Total": "24576 MiB", + "Used": "512 MiB" + }, + "Utilization": { + "Gpu": "5 %" + } + } + } + } + """ + + def extract_number(s): + m = re.search(r"([\d.]+)", str(s)) + return float(m.group(1)) if m else 0.0 + + try: + result = subprocess.run( + ["mthreads-gmi", "-q", "--json"], + capture_output=True, + text=True, + timeout=10, + ) + except (FileNotFoundError, subprocess.TimeoutExpired): + return [] + + if result.returncode != 0: + return [] + + try: + data = json.loads(result.stdout) + except json.JSONDecodeError: + return [] + + gpus = [] + attached = data.get("Attached GPUs", {}) + + for gpu_data in attached.values(): + try: + index = int(gpu_data.get("Minor Number", len(gpus))) + + mem = gpu_data.get("Memory Usage", {}) + total_mb = extract_number(mem.get("Total", "0 MiB")) + used_mb = extract_number(mem.get("Used", "0 MiB")) + util_pct = extract_number( + gpu_data.get("Utilization", {}).get("Gpu", "0 %") + ) + + gpus.append( + GpuInfo( + index=index, + memory_used_mb=used_mb, + memory_total_mb=total_mb, + utilization_pct=util_pct, + ) + ) + except (ValueError, AttributeError): + continue + + return sorted(gpus, key=operator.attrgetter("index")) + + def _detect_gpus_cambricon(self) -> list[GpuInfo]: + """Parse cnmon output for Cambricon MLU cards. + + Each card appears as two consecutive data rows: + Row 1: | {card} {vf} {name} {fw} | {bus_id} | {util}% {ecc} | + Row 2: | {fan}% {temp} {pwr} | {mem_used} MiB/ {mem_total} MiB | ... | + """ + try: + result = subprocess.run( + ["cnmon"], + capture_output=True, + text=True, + timeout=10, + ) + except (FileNotFoundError, subprocess.TimeoutExpired): + return [] + + if result.returncode != 0: + return [] + + gpus = [] + lines = result.stdout.splitlines() + i = 0 + + while i < len(lines): + line = lines[i] + # Row 1: "| {index} ... 
| {bus_id} | {util}% {ecc} |" + m1 = re.match(r"^\|\s+(\d+)\s+.*\|\s*([\d.]+)%", line) + + if m1 and i + 1 < len(lines): + try: + card_index = int(m1.group(1)) + util_pct = float(m1.group(2)) + row2 = lines[i + 1] + mem_m = re.search(r"([\d.]+)\s+MiB/\s*([\d.]+)\s+MiB", row2) + + if mem_m: + used_mb = float(mem_m.group(1)) + total_mb = float(mem_m.group(2)) + else: + used_mb, total_mb = 0.0, 0.0 + + gpus.append( + GpuInfo( + index=card_index, + memory_used_mb=used_mb, + memory_total_mb=total_mb, + utilization_pct=util_pct, + ) + ) + except (ValueError, AttributeError): + pass + i += 2 + continue + + i += 1 + + return sorted(gpus, key=operator.attrgetter("index")) + + def detect_system_resources(self) -> SystemResources: + """Read system memory from /proc/meminfo and CPU count.""" + total_mb = 0.0 + available_mb = 0.0 + + try: + with open("/proc/meminfo", encoding="utf-8") as f: + for line in f: + if line.startswith("MemTotal:"): + total_mb = float(line.split()[1]) / 1024 + elif line.startswith("MemAvailable:"): + available_mb = float(line.split()[1]) / 1024 + except OSError: + pass + + return SystemResources( + total_memory_mb=total_mb, + available_memory_mb=available_mb, + cpu_count=os.cpu_count() or 1, + ) + + def get_free_gpus(self) -> list[int]: + """Return GPU indices with utilization below threshold.""" + gpus = self.detect_gpus() + return [ + g.index for g in gpus if g.utilization_pct < self._utilization_threshold + ] + + def allocate(self, gpu_count, memory_mb=0) -> tuple[list[int], bool]: + """Try to allocate GPUs and check memory. + + Returns (allocated_gpu_ids, success). On failure returns ([], False). + GPU detection and memory checks run outside the lock to avoid blocking + other threads while subprocess.run (nvidia-smi) executes. 
+ """ + if gpu_count <= 0: + if memory_mb > 0: + sys_res = self.detect_system_resources() + + if sys_res.available_memory_mb < memory_mb: + return ([], False) + + return ([], True) + + # Detect GPUs and memory outside the lock (subprocess.run can block) + free_gpus = set(self.get_free_gpus()) + sys_res = self.detect_system_resources() if memory_mb > 0 else None + + with self._lock: + available = free_gpus - self._allocated + + if len(available) < gpu_count: + return ([], False) + + if sys_res is not None and sys_res.available_memory_mb < memory_mb: + return ([], False) + + selected = sorted(available)[:gpu_count] + self._allocated.update(selected) + return (selected, True) + + def release(self, gpu_ids): + """Return GPUs to the free pool.""" + with self._lock: + self._allocated -= set(gpu_ids) + + def get_status(self) -> dict: + """Return current resource status for API endpoints.""" + gpus = self.detect_gpus() + sys_res = self.detect_system_resources() + + with self._lock: + allocated = sorted(self._allocated) + + return { + "platform": self._platform, + "gpus": [ + { + "index": g.index, + "memory_used_mb": g.memory_used_mb, + "memory_total_mb": g.memory_total_mb, + "utilization_pct": g.utilization_pct, + "allocated_by_agent": g.index in allocated, + } + for g in gpus + ], + "allocated_gpu_ids": allocated, + "system": { + "total_memory_mb": round(sys_res.total_memory_mb, 1), + "available_memory_mb": round(sys_res.available_memory_mb, 1), + "cpu_count": sys_res.cpu_count, + }, + "utilization_threshold": self._utilization_threshold, + } + + +def parse_gpu_requirement(job_config) -> int: + """Extract GPU count requirement from a job config.""" + resources = job_config.get("resources", {}) + gpu_style = resources.get("gpu_style", GPU_STYLE_NVIDIA) + + if gpu_style == GPU_STYLE_NONE: + return 0 + + ngpus = resources.get("ngpus") + if ngpus is not None: + return int(ngpus) + + gpu_ids = str(resources.get("gpu_ids", "")) + + if not gpu_ids: + return 
resources.get("gpu_count", 0) + + if gpu_ids == "all": + return 0 # "all" means use all available, don't reserve specific count + + return len(gpu_ids.split(",")) + + +def parse_memory_requirement(job_config) -> float: + """Extract memory requirement in MB from a job config.""" + resources = job_config.get("resources", {}) + memory = str(resources.get("memory", "")) + + if not memory: + return 0 + + memory = memory.lower().strip() + + if memory.endswith("gb"): + return float(memory[:-2]) * 1024 + elif memory.endswith("g"): + return float(memory[:-1]) * 1024 + elif memory.endswith("mb"): + return float(memory[:-2]) + elif memory.endswith("m"): + return float(memory[:-1]) + + try: + return float(memory) * 1024 # Default: GB + except ValueError: + return 0 + + +def detect_platform(): + """Auto-detect the current platform by probing GPU query tools on PATH.""" + for platform, tool in ResourcePool.GPU_QUERY_TOOLS.items(): + if shutil.which(tool): + return platform + + return None diff --git a/.ci/config.yaml b/.ci/config.yaml new file mode 100644 index 0000000..b70e7df --- /dev/null +++ b/.ci/config.yaml @@ -0,0 +1,146 @@ +repo: + url: https://github.com/InfiniTensor/InfiniOps.git + branch: master + +github: + status_context_prefix: "ci/infiniops" + +# Uncomment and replace the URLs below with actual host IPs to dispatch jobs to remote +# machines via `agent.py run`. Required on the trigger machine when each platform's +# agent runs on a separate host. See the README for multi-machine deployment details. 
+# agents: +# nvidia: +# url: http://nvidia-host:8080 +# iluvatar: +# url: http://iluvatar-host:8080 +# metax: +# url: http://metax-host:8080 +# moore: +# url: http://moore-host:8080 +# cambricon: +# url: http://cambricon-host:8080 + +platforms: + nvidia: + image: + dockerfile: .ci/images/nvidia/ + build_args: + BASE_IMAGE: nvcr.io/nvidia/pytorch:24.10-py3 + setup: pip install .[dev] --no-build-isolation + jobs: + gpu: + resources: + ngpus: 1 # Scheduler auto-picks this many free GPUs + memory: 32GB + shm_size: 16g # Prevent PyTorch default 64MB shared memory limit + timeout: 3600 + # env: # Uncomment to inject extra env vars into the container. + # MY_VAR: value + stages: + - name: test + run: pytest tests/ -n 8 -v --tb=short --junitxml=/workspace/results/test-results.xml + + iluvatar: + image: + dockerfile: .ci/images/iluvatar/ + build_args: + BASE_IMAGE: corex:qs_pj20250825 + APT_MIRROR: http://archive.ubuntu.com/ubuntu + PIP_INDEX_URL: https://pypi.org/simple + docker_args: + - "--privileged" + - "--cap-add=ALL" + - "--pid=host" + - "--ipc=host" + volumes: + - /dev:/dev + - /lib/firmware:/lib/firmware + - /usr/src:/usr/src + - /lib/modules:/lib/modules + setup: pip install .[dev] --no-build-isolation + jobs: + gpu: + resources: + gpu_ids: "0" # GPU visibility via CUDA_VISIBLE_DEVICES + gpu_style: none # CoreX: passthrough via --privileged + /dev mount + memory: 32GB + shm_size: 16g + timeout: 3600 + stages: + - name: test + run: pytest tests/ -n 8 -v --tb=short --junitxml=/workspace/results/test-results.xml + + metax: + image: + dockerfile: .ci/images/metax/ + build_args: + BASE_IMAGE: cr.metax-tech.com/public-library/maca-pytorch:3.2.1.4-torch2.4-py310-ubuntu22.04-amd64 + APT_MIRROR: http://archive.ubuntu.com/ubuntu + PIP_INDEX_URL: https://pypi.org/simple + docker_args: + - "--privileged" + - "--ulimit=memlock=-1" + - "--ulimit=stack=67108864" + setup: pip install .[dev] --no-build-isolation + jobs: + gpu: + resources: + gpu_ids: "0" + gpu_style: none # 
MetaX: passthrough via --privileged, no CUDA_VISIBLE_DEVICES + memory: 32GB + shm_size: 16g + timeout: 3600 + stages: + - name: test + run: pytest tests/ -n 4 -v --tb=short --junitxml=/workspace/results/test-results.xml + + moore: + image: + dockerfile: .ci/images/moore/ + build_args: + BASE_IMAGE: sh-harbor.mthreads.com/mcctest/vllm_musa:20251112_hygon + APT_MIRROR: http://archive.ubuntu.com/ubuntu + PIP_INDEX_URL: https://pypi.org/simple + docker_args: + - "--privileged" + setup: pip install .[dev] --no-build-isolation + jobs: + gpu: + resources: + gpu_ids: "0" + gpu_style: none # Moore: passthrough via --privileged, MTHREADS_VISIBLE_DEVICES set by base image + memory: 32GB + shm_size: 16g + timeout: 3600 + stages: + - name: test + run: pytest tests/test_add.py tests/test_gemm.py tests/test_swiglu.py -n 4 -v --tb=short --junitxml=/workspace/results/test-results.xml + + cambricon: + image: + dockerfile: .ci/images/cambricon/ + build_args: + BASE_IMAGE: cambricon/pytorch:v1.25.3-torch2.1-anolisos8.8-py310 + PIP_INDEX_URL: https://pypi.org/simple + docker_args: + - "--privileged" + setup: pip install .[dev] --no-build-isolation + jobs: + gpu: + resources: + gpu_ids: "0" + gpu_style: mlu # Cambricon: passthrough via --privileged, MLU_VISIBLE_DEVICES for device control + memory: 32GB + shm_size: 16g + timeout: 3600 + stages: + - name: test + run: pytest tests/test_gemm.py -n 4 -v --tb=short --junitxml=/workspace/results/test-results.xml + + ascend: # TODO: Ascend image is not ready yet + image: + dockerfile: .ci/images/ascend/ + build_args: + BASE_IMAGE: ascendhub.huawei.com/public-ascendhub/ascend-pytorch:24.0.0 + private_sdk: + source_env: PRIVATE_SDK_URL diff --git a/.ci/github_status.py b/.ci/github_status.py new file mode 100644 index 0000000..f8f017f --- /dev/null +++ b/.ci/github_status.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +"""GitHub Commit Status API wrapper using urllib (zero external dependencies).""" + +import json +import os +import re +import sys 
+import urllib.error +import urllib.request + + +def parse_repo_url(url): + """Extract (owner, repo) from a GitHub URL. + + Handles: + - https://github.com/Owner/Repo.git + - git@github.com:Owner/Repo.git + """ + # HTTPS format + m = re.match(r"https?://[^/]+/([^/]+)/([^/]+?)(?:\.git)?$", url) + + if m: + return m.group(1), m.group(2) + + # SSH format + m = re.match(r"git@[^:]+:([^/]+)/([^/]+?)(?:\.git)?$", url) + + if m: + return m.group(1), m.group(2) + + return "", "" + + +def build_status_context(prefix, job_name): + """Build status context string, e.g. 'ci/infiniops/nvidia_gpu'.""" + return f"{prefix}/{job_name}" + + +def post_commit_status( + owner, + repo, + sha, + state, + context, + description, + target_url=None, + token=None, +): + """Post a commit status to GitHub. + + Args: + state: One of 'pending', 'success', 'failure', 'error'. + Returns True on success, False on failure. + """ + token = token or os.environ.get("GITHUB_TOKEN", "") + + if not token: + print("warning: GITHUB_TOKEN not set, skipping status update", file=sys.stderr) + return False + + if not owner or not repo or not sha: + print( + "warning: missing owner/repo/sha, skipping status update", file=sys.stderr + ) + return False + + url = f"https://api.github.com/repos/{owner}/{repo}/statuses/{sha}" + body = { + "state": state, + "context": context, + "description": description[:140], + } + + if target_url: + body["target_url"] = target_url + + data = json.dumps(body).encode("utf-8") + req = urllib.request.Request( + url, + data=data, + headers={ + "Authorization": f"token {token}", + "Accept": "application/vnd.github.v3+json", + "Content-Type": "application/json", + }, + method="POST", + ) + + try: + with urllib.request.urlopen(req, timeout=30) as resp: + return 200 <= resp.status < 300 + except urllib.error.HTTPError as e: + print( + f"warning: GitHub status API returned {e.code}: {e.reason}", + file=sys.stderr, + ) + return False + except urllib.error.URLError as e: + print(f"warning: 
GitHub status API error: {e.reason}", file=sys.stderr) + return False diff --git a/.ci/images/ascend/Dockerfile b/.ci/images/ascend/Dockerfile new file mode 100644 index 0000000..66392eb --- /dev/null +++ b/.ci/images/ascend/Dockerfile @@ -0,0 +1,39 @@ +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +ENV DEBIAN_FRONTEND=noninteractive + +ARG HTTP_PROXY +ARG HTTPS_PROXY +ARG NO_PROXY +ARG http_proxy +ARG https_proxy +ARG no_proxy + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + git \ + cmake \ + ninja-build \ + coreutils \ + curl \ + libclang-dev \ + && rm -rf /var/lib/apt/lists/* + +ARG PRIVATE_SDK_URL +RUN if [ -n "$PRIVATE_SDK_URL" ]; then \ + curl -fSL "$PRIVATE_SDK_URL" -o /tmp/sdk.run && \ + chmod +x /tmp/sdk.run && /tmp/sdk.run --quiet && \ + rm /tmp/sdk.run; \ + fi + +RUN pip install --no-cache-dir \ + scikit-build-core \ + pybind11 \ + libclang \ + pytest \ + pytest-cov \ + pytest-xdist \ + pyyaml + +WORKDIR /workspace diff --git a/.ci/images/cambricon/Dockerfile b/.ci/images/cambricon/Dockerfile new file mode 100644 index 0000000..138f3cb --- /dev/null +++ b/.ci/images/cambricon/Dockerfile @@ -0,0 +1,33 @@ +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +# Python 3.10 executables (`pip`-installed tools) live under `/usr/local/python3.10/bin`. +ENV PATH=/usr/local/python3.10/bin:${PATH} + +ARG HTTP_PROXY +ARG HTTPS_PROXY +ARG NO_PROXY +ARG http_proxy +ARG https_proxy +ARG no_proxy + +# `git` and `cmake` are pre-installed; `coreutils-single` covers coreutils needs. +RUN dnf install -y ninja-build && dnf clean all + +ARG PIP_INDEX_URL +RUN pip install --no-cache-dir \ + ${PIP_INDEX_URL:+--index-url "$PIP_INDEX_URL"} \ + scikit-build-core \ + libclang \ + pytest \ + pytest-cov \ + pytest-xdist \ + ruff==0.15.7 + +# Pin pre-installed Cambricon `torch` to prevent `pip` from replacing it with upstream version. 
+RUN pip show torch >/dev/null 2>&1 && \ + echo "torch==$(pip show torch | grep '^Version:' | awk '{print $2}')" > /etc/pip-constraints.txt || \ + touch /etc/pip-constraints.txt +ENV PIP_CONSTRAINT=/etc/pip-constraints.txt + +WORKDIR /workspace diff --git a/.ci/images/iluvatar/Dockerfile b/.ci/images/iluvatar/Dockerfile new file mode 100644 index 0000000..79afc85 --- /dev/null +++ b/.ci/images/iluvatar/Dockerfile @@ -0,0 +1,53 @@ +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +ENV DEBIAN_FRONTEND=noninteractive + +# CoreX runtime environment (base image sets these in `/etc/bash.bashrc`, +# but `docker build` `RUN` uses `/bin/sh` which doesn't source it). +ENV PATH=/usr/local/corex/bin:/usr/local/corex-4.3.0/corex-toolbox-1.0.0/bin:/usr/local/corex/lib64/python3/dist-packages/bin:/usr/local/openmpi/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV PYTHONPATH=/usr/local/corex/lib64/python3/dist-packages +ENV LD_LIBRARY_PATH=/usr/local/corex/lib64:/usr/local/lib:/usr/local/openmpi/lib + +ARG HTTP_PROXY +ARG HTTPS_PROXY +ARG NO_PROXY +ARG http_proxy +ARG https_proxy +ARG no_proxy + +ARG APT_MIRROR +RUN if [ -n "$APT_MIRROR" ]; then \ + sed -i "s|http://[^/]*/ubuntu|${APT_MIRROR}|g" /etc/apt/sources.list; \ + fi && \ + apt-get update && \ + apt-get install -y --no-install-recommends \ + git \ + ninja-build \ + coreutils \ + && rm -rf /var/lib/apt/lists/* + +RUN ln -sf $(which python3) /usr/local/bin/python 2>/dev/null || true + +ARG PIP_INDEX_URL +RUN pip install --no-cache-dir --upgrade pip && \ + pip install --no-cache-dir \ + ${PIP_INDEX_URL:+--index-url "$PIP_INDEX_URL"} \ + scikit-build-core \ + pybind11 \ + libclang \ + pytest \ + pytest-cov \ + pytest-xdist \ + pyyaml \ + ruff==0.15.7 + +RUN pip config set global.index-url https://pypi.org/simple + +# Pin pre-installed CoreX `torch` to prevent `pip` from replacing it with upstream version. 
+RUN pip show torch >/dev/null 2>&1 && \ + echo "torch==$(pip show torch | grep '^Version:' | awk '{print $2}')" > /etc/pip-constraints.txt || \ + touch /etc/pip-constraints.txt +ENV PIP_CONSTRAINT=/etc/pip-constraints.txt + +WORKDIR /workspace diff --git a/.ci/images/metax/Dockerfile b/.ci/images/metax/Dockerfile new file mode 100644 index 0000000..540bc9d --- /dev/null +++ b/.ci/images/metax/Dockerfile @@ -0,0 +1,46 @@ +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +ENV DEBIAN_FRONTEND=noninteractive + +# `conda` Python is used in this image. +ENV PATH=/opt/conda/bin:${PATH} + +ARG HTTP_PROXY +ARG HTTPS_PROXY +ARG NO_PROXY +ARG http_proxy +ARG https_proxy +ARG no_proxy + +ARG APT_MIRROR +RUN if [ -n "$APT_MIRROR" ]; then \ + sed -i "s|http://[^/]*/ubuntu|${APT_MIRROR}|g" /etc/apt/sources.list; \ + fi && \ + apt-get update && \ + apt-get install -y --no-install-recommends \ + git \ + cmake \ + ninja-build \ + coreutils \ + libclang-dev \ + && rm -rf /var/lib/apt/lists/* + +ARG PIP_INDEX_URL +RUN pip install --no-cache-dir \ + ${PIP_INDEX_URL:+--index-url "$PIP_INDEX_URL"} \ + scikit-build-core \ + pybind11 \ + libclang \ + pytest-cov \ + pytest-xdist \ + pyyaml \ + ruff==0.15.7 + +# Pin pre-installed MetaX `torch` to prevent `pip` from replacing it with upstream version. +RUN pip show torch >/dev/null 2>&1 && \ + echo "torch==$(pip show torch | grep '^Version:' | awk '{print $2}')" > /etc/pip-constraints.txt || \ + touch /etc/pip-constraints.txt +ENV PIP_CONSTRAINT=/etc/pip-constraints.txt + +WORKDIR /workspace diff --git a/.ci/images/moore/Dockerfile b/.ci/images/moore/Dockerfile new file mode 100644 index 0000000..a95d9bd --- /dev/null +++ b/.ci/images/moore/Dockerfile @@ -0,0 +1,38 @@ +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +ENV DEBIAN_FRONTEND=noninteractive + +# `MUSA_HOME`, `PATH`, `LD_LIBRARY_PATH` already set by base image. 
+ +ARG HTTP_PROXY +ARG HTTPS_PROXY +ARG NO_PROXY +ARG http_proxy +ARG https_proxy +ARG no_proxy + +ARG APT_MIRROR +RUN if [ -n "$APT_MIRROR" ]; then \ + sed -i "s|http://[^/]*/ubuntu|${APT_MIRROR}|g" /etc/apt/sources.list; \ + fi && \ + apt-get update && \ + apt-get install -y --no-install-recommends \ + ninja-build \ + libclang-dev \ + && rm -rf /var/lib/apt/lists/* + +ARG PIP_INDEX_URL +RUN pip install --no-cache-dir \ + ${PIP_INDEX_URL:+--index-url "$PIP_INDEX_URL"} \ + scikit-build-core \ + libclang \ + pytest-cov \ + pytest-xdist \ + ruff==0.15.7 + +# Pin pre-installed `torch` to prevent `pip` from replacing it with upstream version. +RUN echo "torch==$(pip show torch | grep '^Version:' | awk '{print $2}')" > /etc/pip-constraints.txt +ENV PIP_CONSTRAINT=/etc/pip-constraints.txt + +WORKDIR /workspace diff --git a/.ci/images/nvidia/Dockerfile b/.ci/images/nvidia/Dockerfile new file mode 100644 index 0000000..b4984da --- /dev/null +++ b/.ci/images/nvidia/Dockerfile @@ -0,0 +1,46 @@ +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +ENV DEBIAN_FRONTEND=noninteractive + +ARG HTTP_PROXY +ARG HTTPS_PROXY +ARG NO_PROXY +ARG http_proxy +ARG https_proxy +ARG no_proxy + +ARG APT_MIRROR +RUN if [ -n "$APT_MIRROR" ]; then \ + sed -i "s|http://[^/]*/ubuntu|${APT_MIRROR}|g" /etc/apt/sources.list; \ + fi && \ + apt-get update && \ + apt-get install -y --no-install-recommends \ + git \ + cmake \ + ninja-build \ + coreutils \ + libclang-dev \ + && rm -rf /var/lib/apt/lists/* + + +ARG PIP_INDEX_URL +RUN pip install --no-cache-dir --upgrade pip && \ + pip install --no-cache-dir \ + ${PIP_INDEX_URL:+--index-url "$PIP_INDEX_URL"} \ + scikit-build-core \ + pybind11 \ + libclang \ + pytest \ + pytest-cov \ + pytest-xdist \ + pyyaml \ + ruff==0.15.7 + +# Pin pre-installed `torch` to prevent `pip` from replacing it with a different version. 
+RUN pip show torch >/dev/null 2>&1 && \ + echo "torch==$(pip show torch | grep '^Version:' | awk '{print $2}')" > /etc/pip-constraints.txt || \ + touch /etc/pip-constraints.txt +ENV PIP_CONSTRAINT=/etc/pip-constraints.txt + +WORKDIR /workspace diff --git a/.ci/run.py b/.ci/run.py new file mode 100644 index 0000000..24a8867 --- /dev/null +++ b/.ci/run.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python3 +"""Standalone Docker CI runner: clone repo, setup, run stages. Output to stdout.""" + +import argparse +import os +import shlex +import subprocess +import sys +from datetime import datetime +from pathlib import Path + +from ci_resource import ( + GPU_STYLE_NVIDIA, + GPU_STYLE_NONE, + GPU_STYLE_MLU, + ResourcePool, + detect_platform, +) +from utils import get_git_commit, load_config + +# Flags that consume the next token as their value (e.g. -n 4, -k expr). +_PYTEST_VALUE_FLAGS = {"-n", "-k", "-m", "-p", "--tb", "--junitxml", "--rootdir"} + + +def apply_test_override(run_cmd, test_path): + """Replace positional test path(s) in a pytest stage command. + + For example: ``pytest tests/ -n 4 ...`` becomes + ``pytest tests/test_gemm.py -n 4 ...`` when ``test_path`` is + ``tests/test_gemm.py``. + """ + parts = shlex.split(run_cmd) + + if not parts or parts[0] != "pytest": + return run_cmd + + result = ["pytest", test_path] + skip_next = False + + for p in parts[1:]: + if skip_next: + result.append(p) + skip_next = False + continue + + if p.startswith("-"): + result.append(p) + if p in _PYTEST_VALUE_FLAGS: + skip_next = True + continue + + # Skip existing test paths; the override is already in result[1]. 
+ if not ("/" in p or p.endswith(".py") or "::" in p): + result.append(p) + + return shlex.join(result) + + +def build_results_dir(base, platform, stages, commit): + """Build a results directory path: `{base}/{platform}_{stages}_{commit}_{timestamp}`.""" + stage_names = "+".join(s["name"] for s in stages) + timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + dirname = f"{platform}_{stage_names}_{commit}_{timestamp}" + + return Path(base) / dirname + + +def resolve_image(config, platform, image_tag): + """Resolve an image reference to a full image name. + + Accepts `stable`, `latest`, or a commit hash as `image_tag`. When config + contains a registry section, returns a registry-prefixed URL. Otherwise + returns a local tag (current default). + """ + registry = config.get("registry", {}) + registry_url = registry.get("url", "") + project = registry.get("project", "infiniops") + + if not registry_url: + return f"{project}-ci/{platform}:{image_tag}" + + return f"{registry_url}/{project}/{platform}:{image_tag}" + + +def build_runner_script(): + return r""" +set -e +cd /workspace +mkdir -p /workspace/results +if [ -n "$LOCAL_SRC" ]; then + cp -r "$LOCAL_SRC" /tmp/src + cd /tmp/src +else + git clone "$REPO_URL" repo + cd repo + git checkout "$BRANCH" +fi +echo "========== Setup ==========" +eval "$SETUP_CMD" +set +e +failed=0 +for i in $(seq 1 "$NUM_STAGES"); do + name_var="STAGE_${i}_NAME" + cmd_var="STAGE_${i}_CMD" + name="${!name_var}" + cmd="${!cmd_var}" + echo "========== Stage: $name ==========" + [ -n "$cmd" ] && { eval "$cmd" || failed=1; } +done +echo "========== Summary ==========" +if [ -n "$HOST_UID" ] && [ -n "$HOST_GID" ]; then + chown -R "$HOST_UID:$HOST_GID" /workspace/results 2>/dev/null || true +fi +exit $failed +""" + + +def build_docker_args( + config, + job_name, + repo_url, + branch, + stages, + workdir, + image_tag_override, + gpu_id_override=None, + results_dir=None, + local_path=None, +): + job = config["jobs"][job_name] + platform = 
job.get("platform", "nvidia") + image_tag = image_tag_override or job.get("image", "latest") + image = resolve_image(config, platform, image_tag) + resources = job.get("resources", {}) + setup_raw = job.get("setup", "pip install .[dev]") + + if isinstance(setup_raw, list): + setup_cmd = "\n".join(setup_raw) + else: + setup_cmd = setup_raw + + args = [ + "docker", + "run", + "--rm", + "--network", + "host", + "-i", + "-w", + workdir, + "-e", + f"REPO_URL={repo_url}", + "-e", + f"BRANCH={branch}", + "-e", + f"SETUP_CMD={setup_cmd}", + "-e", + f"NUM_STAGES={len(stages)}", + "-e", + f"HOST_UID={os.getuid()}", + "-e", + f"HOST_GID={os.getgid()}", + ] + + for proxy_var in ("HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"): + proxy_val = os.environ.get(proxy_var) or os.environ.get(proxy_var.lower()) + + if proxy_val: + args.extend(["-e", f"{proxy_var}={proxy_val}"]) + args.extend(["-e", f"{proxy_var.lower()}={proxy_val}"]) + + for key, value in job.get("env", {}).items(): + args.extend(["-e", f"{key}={value}"]) + + if results_dir: + args.extend(["-v", f"{results_dir.resolve()}:/workspace/results"]) + + if local_path: + args.extend(["-v", f"{local_path}:/workspace/repo:ro"]) + args.extend(["-e", "LOCAL_SRC=/workspace/repo"]) + + for i, s in enumerate(stages): + args.append("-e") + args.append(f"STAGE_{i + 1}_NAME={s['name']}") + args.append("-e") + args.append(f"STAGE_{i + 1}_CMD={s.get('run', '')}") + + # Platform-specific device access + for flag in job.get("docker_args", []): + args.append(flag) + + for vol in job.get("volumes", []): + args.extend(["-v", vol]) + + gpu_id = gpu_id_override or str(resources.get("gpu_ids", "")) + ngpus = resources.get("ngpus") + gpu_style = resources.get("gpu_style", GPU_STYLE_NVIDIA) + + if gpu_style == GPU_STYLE_NVIDIA: + if gpu_id: + if gpu_id == "all": + args.extend(["--gpus", "all"]) + else: + args.extend(["--gpus", f'"device={gpu_id}"']) + elif ngpus: + args.extend(["--gpus", f"count={ngpus}"]) + elif gpu_style == GPU_STYLE_NONE and gpu_id 
and gpu_id != "all": + # For platforms like Iluvatar/CoreX that use --privileged + /dev mount, + # control visible GPUs via CUDA_VISIBLE_DEVICES. + args.extend(["-e", f"CUDA_VISIBLE_DEVICES={gpu_id}"]) + elif gpu_style == GPU_STYLE_MLU and gpu_id and gpu_id != "all": + # For Cambricon MLU platforms that use --privileged, + # control visible devices via MLU_VISIBLE_DEVICES. + args.extend(["-e", f"MLU_VISIBLE_DEVICES={gpu_id}"]) + + memory = resources.get("memory") + + if memory: + mem = str(memory).lower().replace("gb", "g").replace("mb", "m") + + if not mem.endswith("g") and not mem.endswith("m"): + mem = f"{mem}g" + + args.extend(["--memory", mem]) + + shm_size = resources.get("shm_size") + + if shm_size: + args.extend(["--shm-size", str(shm_size)]) + + timeout_sec = resources.get("timeout") + args.append(image) + + if timeout_sec: + # Requires coreutils `timeout` inside the container image. + args.extend(["timeout", str(timeout_sec)]) + + args.extend(["bash", "-c", build_runner_script().strip()]) + + return args + + +def resolve_job_names(jobs, platform, job=None): + """Resolve job names for a platform. + + - ``job=None`` — all jobs for the platform. + - ``job="gpu"`` (short name) — matched via ``short_name`` field. + - ``job="nvidia_gpu"`` (full name) — direct lookup. 
+ """ + if job and job in jobs: + return [job] + + if job: + matches = [ + name + for name, cfg in jobs.items() + if cfg.get("platform") == platform and cfg.get("short_name") == job + ] + + if not matches: + print( + f"error: job {job!r} not found for platform {platform!r}", + file=sys.stderr, + ) + sys.exit(1) + + return matches + + matches = [name for name, cfg in jobs.items() if cfg.get("platform") == platform] + + if not matches: + print(f"error: no jobs for platform {platform!r}", file=sys.stderr) + sys.exit(1) + + return matches + + +def main(): + parser = argparse.ArgumentParser(description="Run Docker CI pipeline") + parser.add_argument( + "--config", + type=Path, + default=Path(__file__).resolve().parent / "config.yaml", + help="Path to config.yaml", + ) + parser.add_argument( + "--branch", type=str, help="Override repo branch (default: config repo.branch)" + ) + parser.add_argument( + "--job", + type=str, + help="Job name: short name (gpu) or full name (nvidia_gpu). Default: all jobs", + ) + parser.add_argument( + "--stage", + type=str, + help="Run only this stage name (still runs setup first)", + ) + parser.add_argument( + "--image-tag", + type=str, + help="Override image tag (stable, latest, or commit hash)", + ) + parser.add_argument( + "--gpu-id", + type=str, + help='GPU device IDs to use, e.g. "0", "0,2", "all"', + ) + parser.add_argument( + "--results-dir", + type=Path, + default=Path("ci-results"), + help="Base directory for test results (default: ./ci-results)", + ) + parser.add_argument( + "--test", + type=str, + help='Override pytest test path, e.g. 
"tests/test_gemm.py" or "tests/test_gemm.py::test_gemm"', + ) + parser.add_argument( + "--local", + action="store_true", + help="Mount current directory (read-only) into the container instead of cloning from git", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print docker command and exit", + ) + args = parser.parse_args() + + config = load_config(args.config) + repo = config.get("repo", {}) + repo_url = repo.get("url", "https://github.com/InfiniTensor/InfiniOps.git") + branch = args.branch or repo.get("branch", "master") + + platform = detect_platform() + + if not platform: + tools = ", ".join(ResourcePool.GPU_QUERY_TOOLS.values()) + print(f"error: could not detect platform (no {tools} found)", file=sys.stderr) + sys.exit(1) + + print(f"platform: {platform}", file=sys.stderr) + + jobs = config.get("jobs", {}) + + if not jobs: + print("error: no jobs in config", file=sys.stderr) + sys.exit(1) + + job_names = resolve_job_names(jobs, platform, job=args.job) + failed = 0 + + for job_name in job_names: + job = jobs[job_name] + all_stages = job.get("stages", []) + + if args.stage: + stages = [s for s in all_stages if s["name"] == args.stage] + + if not stages: + print( + f"error: stage {args.stage!r} not found in {job_name}", + file=sys.stderr, + ) + sys.exit(1) + else: + stages = all_stages + + if args.test: + stages = [ + {**s, "run": apply_test_override(s.get("run", ""), args.test)} + for s in stages + ] + + job_platform = job.get("platform", platform) + commit = get_git_commit() + results_dir = build_results_dir(args.results_dir, job_platform, stages, commit) + + local_path = Path.cwd().resolve() if args.local else None + docker_args = build_docker_args( + config, + job_name, + repo_url, + branch, + stages, + "/workspace", + args.image_tag, + gpu_id_override=args.gpu_id, + results_dir=results_dir, + local_path=local_path, + ) + + if args.dry_run: + print(shlex.join(docker_args)) + continue + + print(f"==> running job: {job_name}", 
file=sys.stderr) + results_dir.mkdir(parents=True, exist_ok=True) + returncode = subprocess.run(docker_args).returncode + + if returncode != 0: + print(f"job {job_name} failed (exit code {returncode})", file=sys.stderr) + failed += 1 + + sys.exit(1 if failed else 0) + + +if __name__ == "__main__": + main() diff --git a/.ci/tests/__init__.py b/.ci/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/.ci/tests/conftest.py b/.ci/tests/conftest.py new file mode 100644 index 0000000..38ed716 --- /dev/null +++ b/.ci/tests/conftest.py @@ -0,0 +1,46 @@ +import sys +from pathlib import Path + +# Allow `import run` and `import build` directly. +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import pytest + +from utils import normalize_config + + +@pytest.fixture +def minimal_config(): + """Minimal platform-centric config, normalized to flat format.""" + raw = { + "repo": { + "url": "https://github.com/InfiniTensor/InfiniOps.git", + "branch": "master", + }, + "platforms": { + "nvidia": { + "image": { + "dockerfile": ".ci/images/nvidia/", + "build_args": {"BASE_IMAGE": "nvcr.io/nvidia/pytorch:24.10-py3"}, + }, + "setup": "pip install .[dev]", + "jobs": { + "gpu": { + "resources": { + "gpu_ids": "0", + "memory": "32GB", + "shm_size": "16g", + "timeout": 3600, + }, + "stages": [ + { + "name": "test", + "run": "pytest tests/ -v", + } + ], + } + }, + } + }, + } + return normalize_config(raw) diff --git a/.ci/tests/test_agent.py b/.ci/tests/test_agent.py new file mode 100644 index 0000000..73708db --- /dev/null +++ b/.ci/tests/test_agent.py @@ -0,0 +1,535 @@ +import hashlib +import hmac +import json +import threading +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +import agent +import ci_resource as res +from utils import normalize_config + + +# --------------------------------------------------------------------------- +# Test fixtures. 
+# --------------------------------------------------------------------------- + + +@pytest.fixture +def agent_config(): + raw = { + "repo": { + "url": "https://github.com/InfiniTensor/InfiniOps.git", + "branch": "master", + }, + "github": { + "status_context_prefix": "ci/infiniops", + }, + "agents": { + "nvidia": {"url": "http://nvidia-host:8080"}, + "iluvatar": {"url": "http://iluvatar-host:8080"}, + }, + "platforms": { + "nvidia": { + "image": { + "dockerfile": ".ci/images/nvidia/", + "build_args": {"BASE_IMAGE": "nvcr.io/nvidia/pytorch:24.10-py3"}, + }, + "setup": "pip install .[dev]", + "jobs": { + "gpu": { + "resources": { + "gpu_ids": "0", + "memory": "32GB", + "shm_size": "16g", + "timeout": 3600, + }, + "stages": [{"name": "test", "run": "pytest tests/ -v"}], + }, + }, + }, + "iluvatar": { + "image": { + "dockerfile": ".ci/images/iluvatar/", + "build_args": {"BASE_IMAGE": "corex:qs_pj20250825"}, + }, + "setup": "pip install .[dev]", + "jobs": { + "gpu": { + "resources": { + "gpu_ids": "0", + "gpu_style": "none", + "memory": "32GB", + "shm_size": "16g", + "timeout": 3600, + }, + "stages": [{"name": "test", "run": "pytest tests/ -v"}], + }, + }, + }, + }, + } + return normalize_config(raw) + + +@pytest.fixture +def mock_resource_pool(): + pool = MagicMock(spec=res.ResourcePool) + pool.platform = "nvidia" + pool.allocate.return_value = ([0], True) + pool.release.return_value = None + pool.get_status.return_value = { + "platform": "nvidia", + "gpus": [], + "allocated_gpu_ids": [], + "system": {}, + } + return pool + + +# --------------------------------------------------------------------------- +# Tests for `select_jobs`. 
# ---------------------------------------------------------------------------


def test_select_jobs_by_name(agent_config):
    # A concrete job name selects exactly that job.
    selected = agent.select_jobs(agent_config, job_name="nvidia_gpu")
    assert selected == ["nvidia_gpu"]


def test_select_jobs_by_platform(agent_config):
    # Filtering by platform yields only that platform's jobs.
    selected = agent.select_jobs(agent_config, platform="nvidia")
    assert selected == ["nvidia_gpu"]


def test_select_jobs_by_platform_iluvatar(agent_config):
    selected = agent.select_jobs(agent_config, platform="iluvatar")
    assert selected == ["iluvatar_gpu"]


def test_select_jobs_all(agent_config):
    # No filter returns every configured job; order is not guaranteed.
    selected = agent.select_jobs(agent_config)
    assert set(selected) == {"nvidia_gpu", "iluvatar_gpu"}


def test_select_jobs_invalid_name(agent_config):
    # An unknown job name is reported via ValueError mentioning the name.
    with pytest.raises(ValueError, match="not_exist"):
        agent.select_jobs(agent_config, job_name="not_exist")


# ---------------------------------------------------------------------------
# Tests for `verify_signature`.
# ---------------------------------------------------------------------------


def test_verify_signature_valid():
    # A signature computed with the shared secret must verify.
    secret = "my-secret"
    payload = b'{"action": "push"}'
    digest = hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()
    assert agent.verify_signature(secret, payload, "sha256=" + digest) is True


def test_verify_signature_invalid():
    assert agent.verify_signature("secret", b"body", "sha256=wrong") is False


def test_verify_signature_empty():
    # A missing signature header is rejected rather than treated as valid.
    assert agent.verify_signature("secret", b"body", "") is False


# ---------------------------------------------------------------------------
# Tests for `JobRequest` and `JobResult`.
# ---------------------------------------------------------------------------


def test_job_request_fields(agent_config):
    request = agent.JobRequest("nvidia_gpu", "master", "abc123", agent_config)
    assert request.job_name == "nvidia_gpu"
    assert request.platform == "nvidia"
    assert request.commit_sha == "abc123"
    # Job ids are short 8-character identifiers.
    assert len(request.job_id) == 8
    assert request.to_dict()["job_name"] == "nvidia_gpu"


def test_job_result_success():
    result = agent.JobResult("id1", "nvidia_gpu", "abc", 0, Path("/tmp/res"), 42.5)
    assert result.state == "success"


def test_job_result_failure():
    # A non-zero exit code maps to the "failure" state.
    result = agent.JobResult("id1", "nvidia_gpu", "abc", 1, Path("/tmp/res"), 10.0)
    assert result.state == "failure"


# ---------------------------------------------------------------------------
# Tests for the `Scheduler` class.
# ---------------------------------------------------------------------------


def test_scheduler_submit_and_run(agent_config, mock_resource_pool, monkeypatch):
    # Stub out the docker invocation and GitHub status posting.
    monkeypatch.setattr("subprocess.run", lambda cmd, **kw: MagicMock(returncode=0))
    monkeypatch.setattr("agent.gh.post_commit_status", lambda *a, **kw: True)

    scheduler = agent.Scheduler(
        agent_config,
        "nvidia",
        mock_resource_pool,
        results_dir=Path("/tmp/test-results"),
        no_status=True,
        dry_run=True,
    )
    request = agent.JobRequest(
        "nvidia_gpu",
        "master",
        "abc123",
        agent_config,
        results_dir=Path("/tmp/test-results"),
    )
    scheduler.submit(request)

    results = scheduler.wait_all()
    assert len(results) == 1
    assert results[0].state == "success"


def test_scheduler_queues_when_no_resources(agent_config, monkeypatch):
    # A pool that never grants GPUs forces the job to stay queued.
    pool = MagicMock(spec=res.ResourcePool)
    pool.allocate.return_value = ([], False)
    pool.get_status.return_value = {
        "platform": "nvidia",
        "gpus": [],
        "allocated_gpu_ids": [],
        "system": {},
    }

    scheduler = agent.Scheduler(
        agent_config,
        "nvidia",
        pool,
        no_status=True,
        dry_run=False,
    )

    request = agent.JobRequest("nvidia_gpu", "master", "abc123", agent_config)
    scheduler.submit(request)

    assert scheduler.get_job(request.job_id)["state"] == "queued"


def test_scheduler_get_status(agent_config, mock_resource_pool):
    scheduler = agent.Scheduler(
        agent_config,
        "nvidia",
        mock_resource_pool,
        no_status=True,
        dry_run=True,
    )

    status = scheduler.get_status()

    # The status snapshot exposes all scheduler queues plus resource info.
    for key in ("queued", "running", "completed", "resources"):
        assert key in status


# ---------------------------------------------------------------------------
# Tests for `WebhookHandler` push event parsing.
# ---------------------------------------------------------------------------


def test_webhook_parse_push():
    # `__new__` skips BaseHTTPRequestHandler.__init__ (no socket needed).
    handler = agent.WebhookHandler.__new__(agent.WebhookHandler)
    branch, sha = handler._parse_push(
        {"ref": "refs/heads/feat/test", "after": "abc123def456"}
    )
    assert branch == "feat/test"
    assert sha == "abc123def456"


def test_webhook_parse_pr():
    handler = agent.WebhookHandler.__new__(agent.WebhookHandler)
    payload = {
        "pull_request": {
            "head": {
                "ref": "feat/pr-branch",
                "sha": "def789",
            }
        }
    }
    branch, sha = handler._parse_pull_request(payload)
    assert branch == "feat/pr-branch"
    assert sha == "def789"


# ---------------------------------------------------------------------------
# Integration-style webhook HTTP tests.
+# --------------------------------------------------------------------------- + + +def _urlopen_no_proxy(url_or_req, **kwargs): + """`urlopen` mock that bypasses any `HTTP_PROXY`.""" + import urllib.request + + opener = urllib.request.build_opener(urllib.request.ProxyHandler({})) + return opener.open(url_or_req, **kwargs) + + +def test_health_endpoint(agent_config, mock_resource_pool): + scheduler = agent.Scheduler( + agent_config, + "nvidia", + mock_resource_pool, + no_status=True, + ) + server = agent.AgentServer( + "127.0.0.1", + 0, + agent_config, + scheduler, + "nvidia", + ) + port = server.server_address[1] + + t = threading.Thread(target=server.handle_request, daemon=True) + t.start() + + try: + resp = _urlopen_no_proxy(f"http://127.0.0.1:{port}/health", timeout=5) + data = json.loads(resp.read()) + assert data["status"] == "ok" + assert data["platform"] == "nvidia" + finally: + server.server_close() + + +def test_api_run_endpoint(agent_config, mock_resource_pool, monkeypatch): + monkeypatch.setattr("agent.gh.post_commit_status", lambda *a, **kw: True) + + scheduler = agent.Scheduler( + agent_config, + "nvidia", + mock_resource_pool, + no_status=True, + dry_run=True, + ) + server = agent.AgentServer( + "127.0.0.1", + 0, + agent_config, + scheduler, + "nvidia", + results_dir=Path("/tmp/test-results"), + ) + port = server.server_address[1] + + t = threading.Thread(target=server.handle_request, daemon=True) + t.start() + + import urllib.request + + body = json.dumps({"branch": "master", "commit_sha": "abc123"}).encode() + req = urllib.request.Request( + f"http://127.0.0.1:{port}/api/run", + data=body, + headers={"Content-Type": "application/json"}, + ) + + try: + resp = _urlopen_no_proxy(req, timeout=5) + data = json.loads(resp.read()) + assert data["accepted"] is True + assert len(data["job_ids"]) >= 1 + finally: + server.server_close() + + +def test_webhook_with_signature(agent_config, mock_resource_pool, monkeypatch): + 
monkeypatch.setattr("agent.gh.post_commit_status", lambda *a, **kw: True) + + scheduler = agent.Scheduler( + agent_config, + "nvidia", + mock_resource_pool, + no_status=True, + dry_run=True, + ) + secret = "test-secret" + server = agent.AgentServer( + "127.0.0.1", + 0, + agent_config, + scheduler, + "nvidia", + webhook_secret=secret, + results_dir=Path("/tmp/test-results"), + ) + port = server.server_address[1] + + t = threading.Thread(target=server.handle_request, daemon=True) + t.start() + + import urllib.request + + payload = json.dumps( + { + "ref": "refs/heads/master", + "after": "abc123def456", + } + ).encode() + sig = "sha256=" + hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest() + + req = urllib.request.Request( + f"http://127.0.0.1:{port}/webhook", + data=payload, + headers={ + "Content-Type": "application/json", + "X-GitHub-Event": "push", + "X-Hub-Signature-256": sig, + }, + ) + + try: + resp = _urlopen_no_proxy(req, timeout=5) + data = json.loads(resp.read()) + assert data["accepted"] is True + finally: + server.server_close() + + +def test_webhook_invalid_signature(agent_config, mock_resource_pool): + scheduler = agent.Scheduler( + agent_config, + "nvidia", + mock_resource_pool, + no_status=True, + ) + server = agent.AgentServer( + "127.0.0.1", + 0, + agent_config, + scheduler, + "nvidia", + webhook_secret="real-secret", + ) + port = server.server_address[1] + + t = threading.Thread(target=server.handle_request, daemon=True) + t.start() + + import urllib.error + import urllib.request + + payload = b'{"ref": "refs/heads/master", "after": "abc"}' + req = urllib.request.Request( + f"http://127.0.0.1:{port}/webhook", + data=payload, + headers={ + "Content-Type": "application/json", + "X-GitHub-Event": "push", + "X-Hub-Signature-256": "sha256=invalid", + }, + ) + + try: + with pytest.raises(urllib.error.HTTPError) as exc_info: + _urlopen_no_proxy(req, timeout=5) + + assert exc_info.value.code == 401 + finally: + server.server_close() + + +# 
--------------------------------------------------------------------------- +# Tests for API token authentication. +# --------------------------------------------------------------------------- + + +def test_api_run_requires_token(agent_config, mock_resource_pool, monkeypatch): + """When `api_token` is set, `/api/run` rejects requests without a valid token.""" + monkeypatch.setattr("agent.gh.post_commit_status", lambda *a, **kw: True) + + scheduler = agent.Scheduler( + agent_config, + "nvidia", + mock_resource_pool, + no_status=True, + dry_run=True, + ) + server = agent.AgentServer( + "127.0.0.1", + 0, + agent_config, + scheduler, + "nvidia", + api_token="my-secret-token", + results_dir=Path("/tmp/test-results"), + ) + port = server.server_address[1] + + t = threading.Thread(target=server.handle_request, daemon=True) + t.start() + + import urllib.error + import urllib.request + + body = json.dumps({"branch": "master", "commit_sha": "abc123"}).encode() + req = urllib.request.Request( + f"http://127.0.0.1:{port}/api/run", + data=body, + headers={"Content-Type": "application/json"}, + ) + + try: + with pytest.raises(urllib.error.HTTPError) as exc_info: + _urlopen_no_proxy(req, timeout=5) + + assert exc_info.value.code == 401 + finally: + server.server_close() + + +def test_api_run_accepts_valid_token(agent_config, mock_resource_pool, monkeypatch): + """When `api_token` is set, `/api/run` accepts requests with a correct Bearer token.""" + monkeypatch.setattr("agent.gh.post_commit_status", lambda *a, **kw: True) + + scheduler = agent.Scheduler( + agent_config, + "nvidia", + mock_resource_pool, + no_status=True, + dry_run=True, + ) + server = agent.AgentServer( + "127.0.0.1", + 0, + agent_config, + scheduler, + "nvidia", + api_token="my-secret-token", + results_dir=Path("/tmp/test-results"), + ) + port = server.server_address[1] + + t = threading.Thread(target=server.handle_request, daemon=True) + t.start() + + import urllib.request + + body = json.dumps({"branch": 
"master", "commit_sha": "abc123"}).encode() + req = urllib.request.Request( + f"http://127.0.0.1:{port}/api/run", + data=body, + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer my-secret-token", + }, + ) + + try: + resp = _urlopen_no_proxy(req, timeout=5) + data = json.loads(resp.read()) + assert data["accepted"] is True + finally: + server.server_close() diff --git a/.ci/tests/test_build.py b/.ci/tests/test_build.py new file mode 100644 index 0000000..4d28885 --- /dev/null +++ b/.ci/tests/test_build.py @@ -0,0 +1,186 @@ +import build + + +# --------------------------------------------------------------------------- +# Tests for `build_image_tag`. +# --------------------------------------------------------------------------- + + +def test_build_image_tag_with_registry(): + tag = build.build_image_tag("localhost:5000", "infiniops", "nvidia", "latest") + assert tag == "localhost:5000/infiniops/nvidia:latest" + + +def test_build_image_tag_without_registry(): + tag = build.build_image_tag("", "infiniops", "nvidia", "abc1234") + assert tag == "infiniops-ci/nvidia:abc1234" + + +def test_build_image_tag_commit_hash(): + tag = build.build_image_tag( + "registry.example.com:5000", "proj", "ascend", "deadbeef" + ) + assert tag == "registry.example.com:5000/proj/ascend:deadbeef" + + +# --------------------------------------------------------------------------- +# Tests for `has_dockerfile_changed`. 
+# --------------------------------------------------------------------------- + + +def test_has_dockerfile_changed_true_when_stdout_nonempty(mocker): + mocker.patch( + "subprocess.run", + return_value=mocker.Mock(returncode=0, stdout="Dockerfile\n"), + ) + assert build.has_dockerfile_changed(".ci/images/nvidia/") is True + + +def test_has_dockerfile_changed_false_when_stdout_empty(mocker): + mocker.patch( + "subprocess.run", + return_value=mocker.Mock(returncode=0, stdout=""), + ) + assert build.has_dockerfile_changed(".ci/images/nvidia/") is False + + +def test_has_dockerfile_changed_true_on_git_error(mocker): + # Shallow clone or initial commit: `git diff` returns non-zero. + mocker.patch( + "subprocess.run", + return_value=mocker.Mock(returncode=128, stdout=""), + ) + assert build.has_dockerfile_changed(".ci/images/nvidia/") is True + + +# --------------------------------------------------------------------------- +# Tests for `docker_login`. +# --------------------------------------------------------------------------- + + +def test_docker_login_no_credentials_env(mocker): + run_mock = mocker.patch("subprocess.run") + result = build.docker_login({"url": "localhost:5000"}, dry_run=False) + assert result is True + run_mock.assert_not_called() + + +def test_docker_login_token_not_set(mocker, monkeypatch, capsys): + monkeypatch.delenv("REGISTRY_TOKEN", raising=False) + run_mock = mocker.patch("subprocess.run") + cfg = {"url": "localhost:5000", "credentials_env": "REGISTRY_TOKEN"} + result = build.docker_login(cfg, dry_run=False) + assert result is False + run_mock.assert_not_called() + + +def test_docker_login_dry_run_does_not_call_subprocess(mocker, monkeypatch): + monkeypatch.setenv("REGISTRY_TOKEN", "mytoken") + run_mock = mocker.patch("subprocess.run") + cfg = {"url": "localhost:5000", "credentials_env": "REGISTRY_TOKEN"} + result = build.docker_login(cfg, dry_run=True) + assert result is True + run_mock.assert_not_called() + + +def 
test_docker_login_success(mocker, monkeypatch): + monkeypatch.setenv("REGISTRY_TOKEN", "mytoken") + run_mock = mocker.patch( + "subprocess.run", + return_value=mocker.Mock(returncode=0), + ) + cfg = {"url": "localhost:5000", "credentials_env": "REGISTRY_TOKEN"} + result = build.docker_login(cfg, dry_run=False) + assert result is True + run_mock.assert_called_once() + cmd = run_mock.call_args[0][0] + assert "docker" in cmd + assert "login" in cmd + + +# --------------------------------------------------------------------------- +# Tests for `build_image` dry-run mode and proxy forwarding. +# --------------------------------------------------------------------------- + + +def _platform_cfg(): + return { + "dockerfile": ".ci/images/nvidia/", + "build_args": {"BASE_IMAGE": "nvcr.io/nvidia/pytorch:24.10-py3"}, + } + + +def _registry_cfg(): + return {"url": "localhost:5000", "project": "infiniops"} + + +def test_build_image_dry_run_no_subprocess(mocker, monkeypatch, capsys): + monkeypatch.delenv("HTTP_PROXY", raising=False) + run_mock = mocker.patch("subprocess.run") + build.build_image( + "nvidia", + _platform_cfg(), + _registry_cfg(), + "abc1234", + push=False, + dry_run=True, + logged_in=True, + ) + run_mock.assert_not_called() + captured = capsys.readouterr() + assert "[dry-run]" in captured.out + + +def test_build_image_dry_run_output_contains_image_tag(mocker, monkeypatch, capsys): + monkeypatch.delenv("HTTP_PROXY", raising=False) + mocker.patch("subprocess.run") + build.build_image( + "nvidia", + _platform_cfg(), + _registry_cfg(), + "abc1234", + push=False, + dry_run=True, + logged_in=True, + ) + captured = capsys.readouterr() + assert "abc1234" in captured.out + + +def test_build_image_proxy_in_build_args(mocker, monkeypatch): + monkeypatch.setenv("HTTP_PROXY", "http://proxy.test:3128") + run_mock = mocker.patch( + "subprocess.run", + return_value=mocker.Mock(returncode=0), + ) + build.build_image( + "nvidia", + _platform_cfg(), + _registry_cfg(), + "abc1234", + 
push=False, + dry_run=False, + logged_in=True, + ) + called_cmd = run_mock.call_args[0][0] + joined = " ".join(called_cmd) + assert "HTTP_PROXY=http://proxy.test:3128" in joined + assert "http_proxy=http://proxy.test:3128" in joined + + +def test_build_image_returns_false_on_docker_error(mocker, monkeypatch): + monkeypatch.delenv("HTTP_PROXY", raising=False) + mocker.patch( + "subprocess.run", + return_value=mocker.Mock(returncode=1), + ) + result = build.build_image( + "nvidia", + _platform_cfg(), + _registry_cfg(), + "abc1234", + push=False, + dry_run=False, + logged_in=True, + ) + assert result is False diff --git a/.ci/tests/test_github_status.py b/.ci/tests/test_github_status.py new file mode 100644 index 0000000..9e29c79 --- /dev/null +++ b/.ci/tests/test_github_status.py @@ -0,0 +1,145 @@ +import json +from unittest.mock import MagicMock + + +import github_status as gh + + +# --------------------------------------------------------------------------- +# Tests for `parse_repo_url`. +# --------------------------------------------------------------------------- + + +def test_parse_repo_url_https(): + owner, repo = gh.parse_repo_url("https://github.com/InfiniTensor/InfiniOps.git") + assert owner == "InfiniTensor" + assert repo == "InfiniOps" + + +def test_parse_repo_url_https_no_git(): + owner, repo = gh.parse_repo_url("https://github.com/Owner/Repo") + assert owner == "Owner" + assert repo == "Repo" + + +def test_parse_repo_url_ssh(): + owner, repo = gh.parse_repo_url("git@github.com:Owner/Repo.git") + assert owner == "Owner" + assert repo == "Repo" + + +def test_parse_repo_url_invalid(): + owner, repo = gh.parse_repo_url("not-a-url") + assert owner == "" + assert repo == "" + + +# --------------------------------------------------------------------------- +# Tests for `build_status_context`. 
+# --------------------------------------------------------------------------- + + +def test_build_status_context(): + ctx = gh.build_status_context("ci/infiniops", "nvidia_gpu") + assert ctx == "ci/infiniops/nvidia_gpu" + + +# --------------------------------------------------------------------------- +# Tests for `post_commit_status`. +# --------------------------------------------------------------------------- + + +def test_post_status_no_token(monkeypatch): + monkeypatch.delenv("GITHUB_TOKEN", raising=False) + result = gh.post_commit_status("owner", "repo", "abc123", "success", "ctx", "desc") + assert result is False + + +def test_post_status_missing_owner(): + result = gh.post_commit_status( + "", "repo", "abc123", "success", "ctx", "desc", token="tok" + ) + assert result is False + + +def test_post_status_success(monkeypatch): + mock_response = MagicMock() + mock_response.status = 201 + mock_response.__enter__ = MagicMock(return_value=mock_response) + mock_response.__exit__ = MagicMock(return_value=False) + + captured_req = {} + + def mock_urlopen(req, **kwargs): + captured_req["url"] = req.full_url + captured_req["data"] = json.loads(req.data) + captured_req["headers"] = dict(req.headers) + return mock_response + + monkeypatch.setattr("urllib.request.urlopen", mock_urlopen) + + result = gh.post_commit_status( + "InfiniTensor", + "InfiniOps", + "abc123def", + "success", + "ci/infiniops/nvidia_gpu", + "Tests passed", + token="ghp_test_token", + ) + + assert result is True + assert "abc123def" in captured_req["url"] + assert captured_req["data"]["state"] == "success" + assert captured_req["data"]["context"] == "ci/infiniops/nvidia_gpu" + assert "ghp_test_token" in captured_req["headers"]["Authorization"] + + +def test_post_status_http_error(monkeypatch): + import urllib.error + + def mock_urlopen(req, **kwargs): + raise urllib.error.HTTPError( + url="", code=422, msg="Unprocessable", hdrs=None, fp=None + ) + + monkeypatch.setattr("urllib.request.urlopen", 
mock_urlopen) + + result = gh.post_commit_status( + "owner", "repo", "sha", "success", "ctx", "desc", token="tok" + ) + assert result is False + + +def test_post_status_url_error(monkeypatch): + import urllib.error + + def mock_urlopen(req, **kwargs): + raise urllib.error.URLError("connection refused") + + monkeypatch.setattr("urllib.request.urlopen", mock_urlopen) + + result = gh.post_commit_status( + "owner", "repo", "sha", "success", "ctx", "desc", token="tok" + ) + assert result is False + + +def test_post_status_truncates_description(monkeypatch): + mock_response = MagicMock() + mock_response.status = 201 + mock_response.__enter__ = MagicMock(return_value=mock_response) + mock_response.__exit__ = MagicMock(return_value=False) + + captured = {} + + def mock_urlopen(req, **kwargs): + captured["data"] = json.loads(req.data) + return mock_response + + monkeypatch.setattr("urllib.request.urlopen", mock_urlopen) + + long_desc = "x" * 200 + gh.post_commit_status("o", "r", "sha", "success", "ctx", long_desc, token="tok") + + assert len(captured["data"]["description"]) == 140 diff --git a/.ci/tests/test_resource.py b/.ci/tests/test_resource.py new file mode 100644 index 0000000..0db3fbb --- /dev/null +++ b/.ci/tests/test_resource.py @@ -0,0 +1,327 @@ +import threading + + +import ci_resource as res + + +# --------------------------------------------------------------------------- +# Tests for `GpuInfo` and `SystemResources`. +# --------------------------------------------------------------------------- + + +def test_gpu_info_fields(): + g = res.GpuInfo( + index=0, memory_used_mb=1000, memory_total_mb=8000, utilization_pct=50 + ) + assert g.index == 0 + assert g.memory_total_mb == 8000 + + +def test_system_resources_fields(): + s = res.SystemResources( + total_memory_mb=32000, available_memory_mb=16000, cpu_count=8 + ) + assert s.cpu_count == 8 + + +# --------------------------------------------------------------------------- +# Tests for `detect_gpus`. 
# ---------------------------------------------------------------------------


def _stub_run(returncode, stdout):
    """Return a `subprocess.run` replacement yielding a fixed result."""

    def runner(cmd, **kwargs):
        result = type("CompletedStub", (), {})()
        result.returncode = returncode
        result.stdout = stdout
        return result

    return runner


def test_detect_gpus_nvidia_parses_csv(monkeypatch):
    # Two rows of `index, used_mb, total_mb, util%` CSV output.
    csv_rows = "0, 512, 8192, 5\n1, 1024, 8192, 80\n"
    monkeypatch.setattr("subprocess.run", _stub_run(0, csv_rows))

    gpus = res.ResourcePool("nvidia").detect_gpus()

    assert len(gpus) == 2
    assert (gpus[0].index, gpus[0].memory_used_mb, gpus[0].utilization_pct) == (
        0,
        512,
        5,
    )
    assert (gpus[1].index, gpus[1].utilization_pct) == (1, 80)


def test_detect_gpus_empty_on_failure(monkeypatch):
    # A non-zero exit from the query tool yields an empty GPU list.
    monkeypatch.setattr("subprocess.run", _stub_run(1, ""))
    assert res.ResourcePool("nvidia").detect_gpus() == []


def test_detect_gpus_unknown_platform():
    assert res.ResourcePool("unknown_platform").detect_gpus() == []


def test_detect_gpus_file_not_found(monkeypatch):
    # The query tool binary may be missing entirely; that is not an error.
    def missing_tool(cmd, **kwargs):
        raise FileNotFoundError("nvidia-smi not found")

    monkeypatch.setattr("subprocess.run", missing_tool)
    assert res.ResourcePool("nvidia").detect_gpus() == []


# ---------------------------------------------------------------------------
# Tests for `detect_system_resources`.
# ---------------------------------------------------------------------------


def test_detect_system_resources(monkeypatch, tmp_path):
    """Memory figures come from /proc/meminfo; values there are in kB."""
    meminfo = tmp_path / "meminfo"
    meminfo.write_text(
        "MemTotal: 32000000 kB\n"
        "MemFree: 10000000 kB\n"
        "MemAvailable: 20000000 kB\n"
    )

    _real_open = open

    # Redirect reads of /proc/meminfo to the fixture file. Accept positional
    # arguments too (e.g. open(path, "r")) — the original mock only took
    # **kw and raised TypeError for a perfectly valid positional mode.
    def fake_open(path, *args, **kw):
        if str(path) == "/proc/meminfo":
            return _real_open(str(meminfo), *args, **kw)
        return _real_open(path, *args, **kw)

    monkeypatch.setattr("builtins.open", fake_open)

    pool = res.ResourcePool("nvidia")
    sys_res = pool.detect_system_resources()
    # kB -> MB conversion, allowing for rounding.
    assert abs(sys_res.total_memory_mb - 32000000 / 1024) < 1
    assert abs(sys_res.available_memory_mb - 20000000 / 1024) < 1
    assert sys_res.cpu_count > 0


# ---------------------------------------------------------------------------
# Tests for `get_free_gpus`.
# ---------------------------------------------------------------------------


def test_get_free_gpus_filters_by_utilization(monkeypatch):
    # GPU 1 is busy (95% > threshold 10); GPUs 0 and 2 are idle enough.
    csv_output = "0, 100, 8192, 5\n1, 4000, 8192, 95\n2, 200, 8192, 8\n"

    def mock_run(cmd, **kwargs):
        class R:
            returncode = 0
            stdout = csv_output

        return R()

    monkeypatch.setattr("subprocess.run", mock_run)

    pool = res.ResourcePool("nvidia", utilization_threshold=10)
    free = pool.get_free_gpus()
    assert 0 in free
    assert 2 in free
    assert 1 not in free


# ---------------------------------------------------------------------------
# Tests for `allocate` and `release`.
+# --------------------------------------------------------------------------- + + +def test_allocate_success(monkeypatch): + csv_output = "0, 100, 8192, 5\n1, 200, 8192, 3\n" + + def mock_run(cmd, **kwargs): + class R: + returncode = 0 + stdout = csv_output + + return R() + + monkeypatch.setattr("subprocess.run", mock_run) + + pool = res.ResourcePool("nvidia", utilization_threshold=10) + gpu_ids, ok = pool.allocate(1) + assert ok is True + assert len(gpu_ids) == 1 + assert gpu_ids[0] in (0, 1) + + +def test_allocate_insufficient_gpus(monkeypatch): + csv_output = "0, 100, 8192, 5\n" + + def mock_run(cmd, **kwargs): + class R: + returncode = 0 + stdout = csv_output + + return R() + + monkeypatch.setattr("subprocess.run", mock_run) + + pool = res.ResourcePool("nvidia", utilization_threshold=10) + gpu_ids, ok = pool.allocate(3) + assert ok is False + assert gpu_ids == [] + + +def test_allocate_zero_gpus(): + pool = res.ResourcePool("unknown") + gpu_ids, ok = pool.allocate(0) + assert ok is True + assert gpu_ids == [] + + +def test_release_frees_gpus(monkeypatch): + csv_output = "0, 100, 8192, 5\n1, 200, 8192, 3\n" + + def mock_run(cmd, **kwargs): + class R: + returncode = 0 + stdout = csv_output + + return R() + + monkeypatch.setattr("subprocess.run", mock_run) + + pool = res.ResourcePool("nvidia", utilization_threshold=10) + gpu_ids, ok = pool.allocate(2) + assert ok is True + assert len(gpu_ids) == 2 + + # All GPUs allocated; next allocation should fail. + _, ok2 = pool.allocate(1) + assert ok2 is False + + # Release one GPU. 
+ pool.release([gpu_ids[0]]) + gpu_ids2, ok3 = pool.allocate(1) + assert ok3 is True + assert gpu_ids2 == [gpu_ids[0]] + + +def test_allocate_excludes_allocated(monkeypatch): + csv_output = "0, 100, 8192, 5\n1, 200, 8192, 3\n" + + def mock_run(cmd, **kwargs): + class R: + returncode = 0 + stdout = csv_output + + return R() + + monkeypatch.setattr("subprocess.run", mock_run) + + pool = res.ResourcePool("nvidia", utilization_threshold=10) + gpu_ids1, _ = pool.allocate(1) + gpu_ids2, _ = pool.allocate(1) + + assert gpu_ids1 != gpu_ids2 + assert set(gpu_ids1 + gpu_ids2) == {0, 1} + + +def test_thread_safety(monkeypatch): + csv_output = "0, 0, 8192, 0\n1, 0, 8192, 0\n2, 0, 8192, 0\n3, 0, 8192, 0\n" + + def mock_run(cmd, **kwargs): + class R: + returncode = 0 + stdout = csv_output + + return R() + + monkeypatch.setattr("subprocess.run", mock_run) + + pool = res.ResourcePool("nvidia", utilization_threshold=50) + allocated_all = [] + lock = threading.Lock() + + def allocate_one(): + ids, ok = pool.allocate(1) + + if ok: + with lock: + allocated_all.extend(ids) + + threads = [threading.Thread(target=allocate_one) for _ in range(4)] + + for t in threads: + t.start() + + for t in threads: + t.join() + + assert len(allocated_all) == 4 + assert len(set(allocated_all)) == 4 + + +# --------------------------------------------------------------------------- +# Tests for `get_status`. 
+# --------------------------------------------------------------------------- + + +def test_get_status(monkeypatch): + csv_output = "0, 512, 8192, 5\n" + + def mock_run(cmd, **kwargs): + class R: + returncode = 0 + stdout = csv_output + + return R() + + monkeypatch.setattr("subprocess.run", mock_run) + + pool = res.ResourcePool("nvidia") + status = pool.get_status() + assert status["platform"] == "nvidia" + assert len(status["gpus"]) == 1 + assert "system" in status + + +# --------------------------------------------------------------------------- +# Tests for `parse_gpu_requirement` and `parse_memory_requirement`. +# --------------------------------------------------------------------------- + + +def test_parse_gpu_requirement_nvidia(): + job = {"resources": {"gpu_ids": "0,1", "gpu_style": "nvidia"}} + assert res.parse_gpu_requirement(job) == 2 + + +def test_parse_gpu_requirement_none(): + job = {"resources": {"gpu_style": "none"}} + assert res.parse_gpu_requirement(job) == 0 + + +def test_parse_gpu_requirement_all(): + job = {"resources": {"gpu_ids": "all"}} + assert res.parse_gpu_requirement(job) == 0 + + +def test_parse_gpu_requirement_default(): + job = {"resources": {"gpu_ids": "0"}} + assert res.parse_gpu_requirement(job) == 1 + + +def test_parse_memory_requirement_gb(): + assert res.parse_memory_requirement({"resources": {"memory": "32GB"}}) == 32 * 1024 + + +def test_parse_memory_requirement_mb(): + assert res.parse_memory_requirement({"resources": {"memory": "512MB"}}) == 512 + + +def test_parse_memory_requirement_empty(): + assert res.parse_memory_requirement({"resources": {}}) == 0 diff --git a/.ci/tests/test_run.py b/.ci/tests/test_run.py new file mode 100644 index 0000000..93987e5 --- /dev/null +++ b/.ci/tests/test_run.py @@ -0,0 +1,298 @@ +from pathlib import Path + +import pytest + +import run + + +# --------------------------------------------------------------------------- +# Tests for `resolve_image`. 
# ---------------------------------------------------------------------------


def test_resolve_image_with_registry():
    # Registry url + project form the image path prefix.
    cfg = {"registry": {"url": "localhost:5000", "project": "infiniops"}}
    resolved = run.resolve_image(cfg, "nvidia", "latest")
    assert resolved == "localhost:5000/infiniops/nvidia:latest"


def test_resolve_image_without_registry(minimal_config):
    # Without a registry section, images use the local naming scheme.
    resolved = run.resolve_image(minimal_config, "nvidia", "abc1234")
    assert resolved == "infiniops-ci/nvidia:abc1234"


# ---------------------------------------------------------------------------
# Tests for `build_runner_script`.
# ---------------------------------------------------------------------------


def test_runner_script_contains_git_clone():
    assert "git clone" in run.build_runner_script()


def test_runner_script_contains_setup_cmd():
    assert "SETUP_CMD" in run.build_runner_script()


def test_runner_script_exits_on_failure():
    # The in-container script must propagate a stage failure as its exit code.
    assert "exit $failed" in run.build_runner_script()


def test_runner_script_creates_results_dir():
    assert "mkdir -p /workspace/results" in run.build_runner_script()


# ---------------------------------------------------------------------------
# Tests for `build_docker_args` basic structure.
+# --------------------------------------------------------------------------- + + +def test_docker_args_basic_structure(minimal_config): + args = run.build_docker_args( + minimal_config, + "nvidia_gpu", + "https://github.com/example/repo.git", + "master", + minimal_config["jobs"]["nvidia_gpu"]["stages"], + "/workspace", + None, + ) + assert args[0] == "docker" + assert "run" in args + assert "--rm" in args + + +def test_docker_args_correct_image(minimal_config): + args = run.build_docker_args( + minimal_config, + "nvidia_gpu", + "https://github.com/example/repo.git", + "master", + minimal_config["jobs"]["nvidia_gpu"]["stages"], + "/workspace", + None, + ) + assert "infiniops-ci/nvidia:latest" in args + + +def test_docker_args_image_tag_override(minimal_config): + args = run.build_docker_args( + minimal_config, + "nvidia_gpu", + "https://github.com/example/repo.git", + "master", + minimal_config["jobs"]["nvidia_gpu"]["stages"], + "/workspace", + "abc1234", + ) + assert "infiniops-ci/nvidia:abc1234" in args + + +# --------------------------------------------------------------------------- +# Tests for `build_docker_args` proxy passthrough. 
+# --------------------------------------------------------------------------- + + +def test_docker_args_proxy_present_when_set(minimal_config, monkeypatch): + monkeypatch.setenv("HTTP_PROXY", "http://proxy.example.com:8080") + args = run.build_docker_args( + minimal_config, + "nvidia_gpu", + "https://github.com/example/repo.git", + "master", + minimal_config["jobs"]["nvidia_gpu"]["stages"], + "/workspace", + None, + ) + assert "-e" in args + assert "HTTP_PROXY=http://proxy.example.com:8080" in args + assert "http_proxy=http://proxy.example.com:8080" in args + + +def test_docker_args_proxy_absent_when_not_set(minimal_config, monkeypatch): + monkeypatch.delenv("HTTP_PROXY", raising=False) + monkeypatch.delenv("http_proxy", raising=False) + monkeypatch.delenv("HTTPS_PROXY", raising=False) + monkeypatch.delenv("https_proxy", raising=False) + monkeypatch.delenv("NO_PROXY", raising=False) + monkeypatch.delenv("no_proxy", raising=False) + args = run.build_docker_args( + minimal_config, + "nvidia_gpu", + "https://github.com/example/repo.git", + "master", + minimal_config["jobs"]["nvidia_gpu"]["stages"], + "/workspace", + None, + ) + + for arg in args: + assert not arg.startswith("HTTP_PROXY=") + assert not arg.startswith("http_proxy=") + assert not arg.startswith("HTTPS_PROXY=") + assert not arg.startswith("https_proxy=") + assert not arg.startswith("NO_PROXY=") + assert not arg.startswith("no_proxy=") + + +def test_docker_args_proxy_lowercase_fallback(minimal_config, monkeypatch): + monkeypatch.delenv("HTTP_PROXY", raising=False) + monkeypatch.setenv("http_proxy", "http://lowercase.proxy:3128") + args = run.build_docker_args( + minimal_config, + "nvidia_gpu", + "https://github.com/example/repo.git", + "master", + minimal_config["jobs"]["nvidia_gpu"]["stages"], + "/workspace", + None, + ) + assert "HTTP_PROXY=http://lowercase.proxy:3128" in args + assert "http_proxy=http://lowercase.proxy:3128" in args + + +# 
--------------------------------------------------------------------------- +# Tests for `build_docker_args` GPU flags. +# --------------------------------------------------------------------------- + + +def _make_args(config, gpu_id_override=None): + return run.build_docker_args( + config, + "nvidia_gpu", + "https://github.com/example/repo.git", + "master", + config["jobs"]["nvidia_gpu"]["stages"], + "/workspace", + None, + gpu_id_override=gpu_id_override, + ) + + +def test_docker_args_gpu_device(minimal_config): + args = _make_args(minimal_config) + idx = args.index("--gpus") + assert "device=0" in args[idx + 1] + + +def test_docker_args_gpu_all(minimal_config): + minimal_config["jobs"]["nvidia_gpu"]["resources"]["gpu_ids"] = "all" + args = _make_args(minimal_config) + idx = args.index("--gpus") + assert args[idx + 1] == "all" + + +def test_docker_args_no_gpu(minimal_config): + minimal_config["jobs"]["nvidia_gpu"]["resources"]["gpu_ids"] = "" + minimal_config["jobs"]["nvidia_gpu"]["resources"].pop("gpu_count", None) + args = _make_args(minimal_config) + assert "--gpus" not in args + + +def test_docker_args_gpu_override(minimal_config): + args = _make_args(minimal_config, gpu_id_override="2,3") + idx = args.index("--gpus") + assert "2,3" in args[idx + 1] + + +# --------------------------------------------------------------------------- +# Tests for `build_docker_args` memory format. 
+# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "raw,expected", + [ + ("32GB", "32g"), + ("512MB", "512m"), + ("8", "8g"), + ("16gb", "16g"), + ("256mb", "256m"), + ], +) +def test_docker_args_memory_format(minimal_config, raw, expected): + minimal_config["jobs"]["nvidia_gpu"]["resources"]["memory"] = raw + args = _make_args(minimal_config) + idx = args.index("--memory") + assert args[idx + 1] == expected + + +# --------------------------------------------------------------------------- +# Tests for `build_docker_args` stages encoding. +# --------------------------------------------------------------------------- + + +def test_docker_args_num_stages(minimal_config): + args = _make_args(minimal_config) + assert "NUM_STAGES=1" in args + + +def test_docker_args_stage_name_cmd(minimal_config): + args = _make_args(minimal_config) + assert "STAGE_1_NAME=test" in args + assert any(a.startswith("STAGE_1_CMD=") for a in args) + + +def test_docker_args_multiple_stages(minimal_config): + minimal_config["jobs"]["nvidia_gpu"]["stages"] = [ + {"name": "lint", "run": "ruff check ."}, + {"name": "test", "run": "pytest tests/"}, + ] + args = _make_args(minimal_config) + assert "NUM_STAGES=2" in args + assert "STAGE_1_NAME=lint" in args + assert "STAGE_2_NAME=test" in args + + +# --------------------------------------------------------------------------- +# Tests for `build_docker_args` `results_dir` mount. 
+# --------------------------------------------------------------------------- + + +def test_docker_args_results_dir(minimal_config, tmp_path): + args = run.build_docker_args( + minimal_config, + "nvidia_gpu", + "https://github.com/example/repo.git", + "master", + minimal_config["jobs"]["nvidia_gpu"]["stages"], + "/workspace", + None, + results_dir=tmp_path, + ) + joined = " ".join(str(a) for a in args) + assert "-v" in args + assert "/workspace/results" in joined + + +# --------------------------------------------------------------------------- +# Tests for `build_results_dir`. +# --------------------------------------------------------------------------- + + +def test_build_results_dir_contains_platform(): + stages = [{"name": "test", "run": "pytest"}] + d = run.build_results_dir("ci-results", "nvidia", stages, "abc1234") + assert "nvidia" in d.name + + +def test_build_results_dir_contains_commit(): + stages = [{"name": "test", "run": "pytest"}] + d = run.build_results_dir("ci-results", "nvidia", stages, "abc1234") + assert "abc1234" in d.name + + +def test_build_results_dir_contains_stage_names(): + stages = [{"name": "lint", "run": "ruff"}, {"name": "test", "run": "pytest"}] + d = run.build_results_dir("ci-results", "nvidia", stages, "abc1234") + assert "lint+test" in d.name + + +def test_build_results_dir_under_base(): + stages = [{"name": "test", "run": "pytest"}] + d = run.build_results_dir("/tmp/my-results", "ascend", stages, "def5678") + assert d.parent == Path("/tmp/my-results") diff --git a/.ci/tests/test_utils.py b/.ci/tests/test_utils.py new file mode 100644 index 0000000..b07011c --- /dev/null +++ b/.ci/tests/test_utils.py @@ -0,0 +1,90 @@ +from utils import normalize_config + + +def test_normalize_creates_flat_jobs(): + raw = { + "repo": {"url": "https://github.com/org/repo.git"}, + "platforms": { + "nvidia": { + "image": {"dockerfile": ".ci/images/nvidia/"}, + "setup": "pip install .", + "docker_args": ["--gpus", "all"], + "jobs": { + "gpu": { + 
"resources": {"gpu_ids": "0"}, + "stages": [{"name": "test", "run": "pytest"}], + }, + "multi_gpu": { + "resources": {"gpu_ids": "0,1"}, + "stages": [{"name": "test", "run": "pytest"}], + }, + }, + }, + }, + } + config = normalize_config(raw) + + assert "nvidia_gpu" in config["jobs"] + assert "nvidia_multi_gpu" in config["jobs"] + assert config["jobs"]["nvidia_gpu"]["platform"] == "nvidia" + assert config["jobs"]["nvidia_gpu"]["setup"] == "pip install ." + assert config["jobs"]["nvidia_gpu"]["docker_args"] == ["--gpus", "all"] + assert config["jobs"]["nvidia_gpu"]["resources"]["gpu_ids"] == "0" + assert config["jobs"]["nvidia_multi_gpu"]["resources"]["gpu_ids"] == "0,1" + + +def test_normalize_extracts_images(): + raw = { + "platforms": { + "nvidia": { + "image": { + "dockerfile": ".ci/images/nvidia/", + "build_args": {"BASE_IMAGE": "pytorch:latest"}, + }, + "jobs": {}, + }, + }, + } + config = normalize_config(raw) + assert config["images"]["nvidia"]["dockerfile"] == ".ci/images/nvidia/" + assert config["images"]["nvidia"]["build_args"]["BASE_IMAGE"] == "pytorch:latest" + + +def test_normalize_job_overrides_platform_defaults(): + raw = { + "platforms": { + "nvidia": { + "setup": "default setup", + "jobs": { + "special": { + "setup": "custom setup", + "stages": [], + }, + }, + }, + }, + } + config = normalize_config(raw) + assert config["jobs"]["nvidia_special"]["setup"] == "custom setup" + + +def test_normalize_preserves_top_level_keys(): + raw = { + "repo": {"url": "https://github.com/org/repo.git"}, + "github": {"status_context_prefix": "ci/test"}, + "agents": {"nvidia": {"url": "http://host:8080"}}, + "platforms": {}, + } + config = normalize_config(raw) + assert config["repo"]["url"] == "https://github.com/org/repo.git" + assert config["github"]["status_context_prefix"] == "ci/test" + assert config["agents"]["nvidia"]["url"] == "http://host:8080" + + +def test_normalize_passthrough_flat_config(): + """Old flat format without `platforms` key is returned 
as-is.""" + flat = { + "images": {"nvidia": {}}, + "jobs": {"nvidia_gpu": {"platform": "nvidia"}}, + } + assert normalize_config(flat) is flat diff --git a/.ci/utils.py b/.ci/utils.py new file mode 100644 index 0000000..07dec87 --- /dev/null +++ b/.ci/utils.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +"""Shared utilities for the CI toolchain.""" + +import subprocess +import sys + +try: + import yaml +except ImportError: + print( + "error: pyyaml is required. Install with: pip install pyyaml", file=sys.stderr + ) + sys.exit(1) + + +def normalize_config(raw): + """Convert platform-centric config to flat images/jobs format. + + Input (new format): + platforms: + nvidia: + image: {dockerfile: ..., build_args: ...} + setup: pip install .[dev] + jobs: + gpu: {resources: ..., stages: ...} + + Output (flat format consumed by run.py / build.py / agent.py): + images: + nvidia: {dockerfile: ..., build_args: ...} + jobs: + nvidia_gpu: {platform: nvidia, setup: ..., resources: ..., stages: ...} + + If the config already uses the flat format (no 'platforms' key), returns as-is. 
+ """ + if "platforms" not in raw: + return raw + + config = {} + + for key in ("repo", "github", "agents"): + if key in raw: + config[key] = raw[key] + + config["images"] = {} + config["jobs"] = {} + + for platform, pcfg in raw.get("platforms", {}).items(): + # Image config + if "image" in pcfg: + config["images"][platform] = pcfg["image"] + + # Platform-level defaults inherited by jobs + defaults = {} + + for key in ("image_tag", "docker_args", "volumes", "setup", "env"): + if key in pcfg: + defaults[key] = pcfg[key] + + # Flatten jobs: {platform}_{job_name} + for job_name, job_cfg in pcfg.get("jobs", {}).items(): + full_name = f"{platform}_{job_name}" + flat = { + "platform": platform, + "short_name": job_name, + "image": defaults.get("image_tag", "latest"), + } + + # Apply platform defaults + for key in ("docker_args", "volumes", "setup", "env"): + if key in defaults: + flat[key] = defaults[key] + + # Job-level overrides + flat.update(job_cfg) + + config["jobs"][full_name] = flat + + # Warn on mismatched agent/platform keys (catches typos like 'nvdia'). + agent_keys = set(config.get("agents", {}).keys()) + platform_keys = set(raw.get("platforms", {}).keys()) + + for key in agent_keys - platform_keys: + print( + f"warning: agents.{key} has no matching platform in platforms.*", + file=sys.stderr, + ) + + return config + + +def load_config(path): + """Load a YAML config file and normalize to flat format.""" + with open(path, encoding="utf-8") as f: + raw = yaml.safe_load(f) + + return normalize_config(raw) + + +def get_git_commit(ref="HEAD", short=True): + """Get git commit SHA. 
Returns 'unknown' on failure.""" + cmd = ["git", "rev-parse"] + + if short: + cmd.append("--short") + + cmd.append(ref) + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + return "unknown" + + return result.stdout.strip() diff --git a/pyproject.toml b/pyproject.toml index 765b90a..3dbc186 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "InfiniOps" version = "0.1.0" [project.optional-dependencies] -dev = ["pytest", "pytest-cov", "pytest-xdist", "ruff", "torch"] +dev = ["pytest", "pytest-cov", "pytest-xdist", "ruff", "torch", "pyyaml"] [tool.scikit-build.wheel] install-dir = "infini" diff --git a/tests/test_gemm.py b/tests/test_gemm.py index 43a47b6..136e991 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -48,6 +48,10 @@ def test_gemm( if device == "mlu" and (trans_a or trans_b): pytest.skip("transposing is not currently supported on MLU") + # `cnnlBatchMatMulEx` does not accept `bfloat16` inputs on MLU. + if device == "mlu" and dtype == torch.bfloat16: + pytest.skip("`bfloat16` is not supported by `cnnlBatchMatMulEx`") + a = randn_strided(a_shape, a_strides, dtype=dtype, device=device) b = randn_strided(b_shape, b_strides, dtype=dtype, device=device) @@ -97,8 +101,10 @@ def _torch_gemm(a, b, alpha=1.0, beta=1.0, trans_a=False, trans_b=False, c=None) return torch.baddbmm(c, a, b, beta=beta, alpha=alpha, out=c) except RuntimeError: - c_original = c.clone() - torch.matmul(a, b, out=c) - c.mul_(alpha).add_(c_original, alpha=beta) + # Fallback for backends that don't support `addmm`/`baddbmm` (e.g. CPU `float16`/`bfloat16`): + # compute in float32 and cast back. 
+ c_original = c.float() + result = torch.matmul(a.float(), b.float()) + c.copy_((alpha * result + beta * c_original).to(c.dtype)) return c From d6b5fd5b941d9b98819d7c7c68b2e40b2a5ce5ce Mon Sep 17 00:00:00 2001 From: zhangyue Date: Wed, 25 Mar 2026 10:02:24 +0000 Subject: [PATCH 89/93] chore: ignore ci-results/ directory Co-Authored-By: Claude --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 2effaff..540edc5 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,9 @@ build/ generated/ +# CI test results +ci-results/ + # Prerequisites *.d From 56f33301e1f0134c1e76df75fd345d5b82dee571 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Thu, 26 Mar 2026 01:55:18 +0000 Subject: [PATCH 90/93] Revert "chore: ignore ci-results/ directory" This reverts commit d6b5fd5b941d9b98819d7c7c68b2e40b2a5ce5ce. --- .gitignore | 3 --- 1 file changed, 3 deletions(-) diff --git a/.gitignore b/.gitignore index 540edc5..2effaff 100644 --- a/.gitignore +++ b/.gitignore @@ -2,9 +2,6 @@ build/ generated/ -# CI test results -ci-results/ - # Prerequisites *.d From 8c92b2efd0287422210c66ab88dfe728a2ecb179 Mon Sep 17 00:00:00 2001 From: zhangyunze <93699316+bitzyz@users.noreply.github.com> Date: Thu, 26 Mar 2026 10:53:42 +0800 Subject: [PATCH 91/93] feat: add Cambricon `RMSNorm` (#19) * feat: Add RMSNorm op in cambricon backend. 
* refactor: make `Cast` utility to use `Device::Type` template parameter * refactor: add `Caster` mixin * refactor: rename `cast**` to `caster**` * fix: fix the mlu naming to google c++ naming style * chore: format files with `clang-format` * refactor: update CUDA kernels to use `Caster` * fix: fix rmsnorm dispatch to use one dispatch --------- Co-authored-by: Jiacheng Huang --- CMakeLists.txt | 2 +- src/CMakeLists.txt | 39 ++- src/base/rms_norm.h | 10 +- src/cambricon/common.h | 50 ++++ src/cambricon/rms_norm/kernel.mlu | 332 ++++++++++++++++++++++++ src/cambricon/rms_norm/rms_norm.h | 64 +++++ src/caster.h | 14 + src/common/cast.h | 11 - src/common/cpu/cast.h | 57 ---- src/common/cuda/cast.h | 107 -------- src/cpu/add/add.h | 5 +- src/cpu/caster_.h | 71 +++++ src/cpu/causal_softmax/causal_softmax.h | 5 +- src/cpu/gemm/gemm.h | 5 +- src/cpu/rms_norm/rms_norm.h | 5 +- src/cpu/swiglu/swiglu.h | 5 +- src/cuda/add/kernel.cuh | 2 +- src/cuda/add/kernel.h | 1 + src/cuda/caster_.h | 111 ++++++++ src/cuda/causal_softmax/kernel.cuh | 7 +- src/cuda/causal_softmax/kernel.h | 2 +- src/{common => }/cuda/kernel_commons.h | 2 +- src/cuda/rms_norm/kernel.cuh | 4 +- src/cuda/rms_norm/kernel.h | 14 +- src/cuda/swiglu/kernel.cuh | 2 +- src/data_type.h | 6 + src/iluvatar/causal_softmax/kernel.h | 4 + src/iluvatar/device_.h | 12 + src/iluvatar/rms_norm/kernel.h | 4 + src/metax/causal_softmax/kernel.h | 4 + src/metax/device_.h | 12 + src/metax/rms_norm/kernel.h | 4 + src/moore/device_.h | 12 + src/nvidia/causal_softmax/kernel.h | 4 + src/nvidia/device_.h | 12 + src/nvidia/rms_norm/kernel.h | 4 + src/operator.h | 5 +- 37 files changed, 801 insertions(+), 209 deletions(-) create mode 100644 src/cambricon/rms_norm/kernel.mlu create mode 100644 src/cambricon/rms_norm/rms_norm.h create mode 100644 src/caster.h delete mode 100644 src/common/cast.h delete mode 100644 src/common/cpu/cast.h delete mode 100644 src/common/cuda/cast.h create mode 100644 src/cpu/caster_.h create mode 100644 
src/cuda/caster_.h rename src/{common => }/cuda/kernel_commons.h (99%) create mode 100644 src/iluvatar/device_.h create mode 100644 src/metax/device_.h create mode 100644 src/moore/device_.h create mode 100644 src/nvidia/device_.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 570b7d7..b9e2deb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -179,7 +179,7 @@ if(WITH_CAMBRICON) endif() # If all other platforms are not enabled, CPU is enabled by default. -if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE) +if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON) add_compile_definitions(WITH_CPU=1) endif() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3ca0715..585e3ab 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -127,11 +127,48 @@ if(WITH_MOORE) endif() if(WITH_CAMBRICON) - target_compile_definitions(infiniops PUBLIC WITH_CAMBRICON=1) + file(GLOB_RECURSE CAMBRICON_MLU_SOURCES CONFIGURE_DEPENDS "cambricon/*/*.mlu") + find_program(CNCC_COMPILER cncc HINTS "${NEUWARE_HOME}/bin" "$ENV{NEUWARE_HOME}/bin" /usr/local/neuware/bin) + if(CNCC_COMPILER) + message(STATUS "Found cncc: ${CNCC_COMPILER}") + set(MLU_COMPILE_OPTS + -c --bang-mlu-arch=mtp_592 -O3 -fPIC -Wall -Werror -std=c++17 -pthread + -I${CMAKE_CURRENT_SOURCE_DIR} -I${NEUWARE_HOME}/include + -idirafter /usr/local/neuware/lib/clang/11.1.0/include + ) + function(compile_mlu_file src_file) + get_filename_component(name ${src_file} NAME_WE) + get_filename_component(path ${src_file} DIRECTORY) + set(out_file "${CMAKE_CURRENT_BINARY_DIR}/${path}/${name}.o") + file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${path}") + add_custom_command(OUTPUT ${out_file} + COMMAND ${CNCC_COMPILER} ${MLU_COMPILE_OPTS} -c ${src_file} -o ${out_file} + DEPENDS ${src_file} + COMMENT "Building MLU kernel: ${src_file}" + ) + set_property(DIRECTORY APPEND PROPERTY CAMBRICON_OBJECTS ${out_file}) + endfunction() + foreach(src 
${CAMBRICON_MLU_SOURCES}) + compile_mlu_file(${src}) + endforeach() + get_directory_property(CAMBRICON_OBJECT_FILES CAMBRICON_OBJECTS) + if(CAMBRICON_OBJECT_FILES) + target_sources(infiniops PRIVATE ${CAMBRICON_OBJECT_FILES}) + endif() + else() + message(WARNING "cncc compiler not found. MLU kernels will not be compiled.") + endif() + target_compile_definitions(infiniops PRIVATE WITH_CAMBRICON=1) target_include_directories(infiniops PUBLIC "${NEUWARE_HOME}/include") target_link_libraries(infiniops PUBLIC ${CAMBRICON_RUNTIME_LIB} ${CAMBRICON_CNNL_LIB} ${CAMBRICON_CNNL_EXTRA_LIB} ${CAMBRICON_PAPI_LIB}) + if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") + target_compile_options(infiniops PUBLIC + "$<$:SHELL:-idirafter /usr/local/neuware/lib/clang/11.1.0/include>" + ) + endif() + list(APPEND DEVICE_LIST "cambricon") endif() diff --git a/src/base/rms_norm.h b/src/base/rms_norm.h index 65f44b3..dc28f0a 100644 --- a/src/base/rms_norm.h +++ b/src/base/rms_norm.h @@ -12,15 +12,17 @@ namespace infini::ops { class RmsNorm : public Operator { public: RmsNorm(const Tensor input, const Tensor weight, float eps, Tensor out) - : eps_{eps}, + : input_shape_{input.shape()}, out_shape_{out.shape()}, - input_shape_{input.shape()}, - out_strides_{out.strides()}, input_strides_{input.strides()}, + out_strides_{out.strides()}, + eps_{eps}, dim_{out.size(-1)}, ndim_{out.ndim()}, batch_size_{ndim_ == 2 ? out.size(-2) : out.size(-3)}, - nhead_{ndim_ == 2 ? 1 : out.size(-2)} {} + nhead_{ndim_ == 2 ? 
1 : out.size(-2)} { + assert(input.dtype() == out.dtype()); + } RmsNorm(const Tensor input, const Tensor weight, Tensor out) : RmsNorm{input, weight, 1e-6f, out} {} diff --git a/src/cambricon/common.h b/src/cambricon/common.h index 50775c2..fc8ede0 100644 --- a/src/cambricon/common.h +++ b/src/cambricon/common.h @@ -2,19 +2,58 @@ #define INFINI_OPS_CAMBRICON_COMMON_H_ #include +#include #include "data_type.h" +#include "device.h" + +#define NRAM_MAX_SIZE (1024 * 240) + +#ifdef __BANG__ + +namespace infini::ops::reduce { + +constexpr int batch_size = 128 / sizeof(float); + +__mlu_func__ void SumInternal(float* dst, float* src, int max_batch) { + const int width = max_batch / batch_size; + + if (width >= 4) { + __bang_sumpool(dst, src, batch_size, 1, width, 1, width, 1, 1); + __bang_reduce_sum(dst, dst, batch_size); + } else { + float sum = 0.0f; + for (int i = 0; i < max_batch; ++i) { + sum += src[i]; + } + dst[0] = sum; + } +} + +} // namespace infini::ops::reduce + +#endif // __BANG__ namespace infini::ops::cnnl_utils { inline cnnlDataType_t GetDataType(DataType dtype) { switch (dtype) { + case DataType::kInt8: + return CNNL_DTYPE_INT8; + case DataType::kUInt8: + return CNNL_DTYPE_UINT8; case DataType::kInt32: return CNNL_DTYPE_INT32; + case DataType::kInt64: + return CNNL_DTYPE_INT64; case DataType::kFloat16: return CNNL_DTYPE_HALF; case DataType::kFloat32: return CNNL_DTYPE_FLOAT; + case DataType::kBFloat16: + return CNNL_DTYPE_BFLOAT16; + case DataType::kFloat64: + return CNNL_DTYPE_DOUBLE; default: return CNNL_DTYPE_INVALID; } @@ -22,4 +61,15 @@ inline cnnlDataType_t GetDataType(DataType dtype) { } // namespace infini::ops::cnnl_utils +namespace infini::ops::cnrt_utils { + +inline void GetLaunchConfig(const Device& device, int* core_per_cluster, + int* cluster_count) { + int device_id = device.index(); + cnrtDeviceGetAttribute(cluster_count, cnrtAttrClusterCount, device_id); + cnrtDeviceGetAttribute(core_per_cluster, cnrtAttrMcorePerCluster, device_id); +} + 
+} // namespace infini::ops::cnrt_utils + #endif diff --git a/src/cambricon/rms_norm/kernel.mlu b/src/cambricon/rms_norm/kernel.mlu new file mode 100644 index 0000000..6648d2d --- /dev/null +++ b/src/cambricon/rms_norm/kernel.mlu @@ -0,0 +1,332 @@ +#include "rms_norm.h" + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +namespace infini::ops { + +template +__mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output, + size_t *shape, ptrdiff_t *output_strides, + ptrdiff_t *input_strides, float epsilon, + int num_dims, int norm_dim_size) { + // Calculate problem dimensions. + int batch_volume = 1; + for (int dim_idx = 0; dim_idx < num_dims - 1; ++dim_idx) { + batch_volume *= shape[dim_idx]; + } + int vector_size = shape[num_dims - 1]; + + // Task distribution across cores. + int remaining_tasks = batch_volume % taskDim; + int base_tasks_per_core = batch_volume / taskDim; + int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0); + int task_start_idx = + (taskId < remaining_tasks + ? taskId * (base_tasks_per_core + 1) + : remaining_tasks * (base_tasks_per_core + 1) + + (taskId - remaining_tasks) * base_tasks_per_core); + + // Determine optimal batch size based on vector size. + int max_batch_size; + if (vector_size <= 64) { + max_batch_size = vector_size; + } else { + max_batch_size = + (NRAM_MAX_SIZE - 256) / (2 * sizeof(T) + sizeof(TW) + sizeof(float)); + max_batch_size = std::min(max_batch_size, vector_size); + max_batch_size = (max_batch_size / 64) * 64; // Align to 64 elements + } + + constexpr int reduce_buffer_size = 128 / sizeof(float); + + // NRAM buffer allocation with dynamic sizing. 
+ float *reduction_buffer = (float *)nram_buffer; + T *input_cache = (T *)(reduction_buffer + reduce_buffer_size); + TW *weight_cache = (TW *)(input_cache + max_batch_size); + float *float_buffer = (float *)(weight_cache + max_batch_size); + float *weight_float_buffer = (float *)(float_buffer + max_batch_size); + + // Process vectors assigned to current core. + for (int task_idx = 0; task_idx < actual_tasks; ++task_idx) { + int current_index = task_start_idx + task_idx; + + // Calculate memory offsets for current task. + int input_offset = 0; + int output_offset = 0; + int temp_index = current_index; + + for (int dim = 0; dim < num_dims - 1; ++dim) { + int dim_coord = temp_index % shape[dim]; + input_offset += dim_coord * input_strides[dim]; + output_offset += dim_coord * output_strides[dim]; + temp_index /= shape[dim]; + } + + // Compute sum of squares. + float sum_squared = 0.0f; + + if (vector_size <= 128) { + __memcpy(input_cache, input + input_offset, vector_size * sizeof(T), + GDRAM2NRAM); + if constexpr (std::is_same::value) { + __bang_half2float(float_buffer, reinterpret_cast(input_cache), + vector_size); + } else if constexpr (std::is_same::value) { + __bang_bfloat162float(float_buffer, input_cache, vector_size); + } else { + __memcpy(float_buffer, input_cache, vector_size * sizeof(float), + NRAM2NRAM); + } + + __bang_mul(float_buffer, float_buffer, float_buffer, vector_size); + + // Direct accumulation for small vectors. + for (int i = 0; i < vector_size; ++i) { + sum_squared += float_buffer[i]; + } + } else { + // Large vector processing with chunking. 
+ __bang_write_value(reduction_buffer, reduce_buffer_size, 0); + size_t processed_elements = 0; + + while (processed_elements < vector_size) { + size_t current_batch = + std::min((size_t)max_batch_size, vector_size - processed_elements); + + __memcpy(input_cache, + input + input_offset + + processed_elements * input_strides[num_dims - 1], + current_batch * sizeof(T), GDRAM2NRAM); + + if constexpr (std::is_same::value) { + __bang_half2float(float_buffer, reinterpret_cast(input_cache), + current_batch); + } else if constexpr (std::is_same::value) { + __bang_bfloat162float(float_buffer, input_cache, current_batch); + } else { + __memcpy(float_buffer, input_cache, current_batch * sizeof(float), + NRAM2NRAM); + } + + __bang_mul(float_buffer, float_buffer, float_buffer, current_batch); + + float batch_sum = 0.0f; + if (current_batch >= 128) { + infini::ops::reduce::SumInternal(reduction_buffer, float_buffer, + current_batch); + batch_sum = reduction_buffer[0]; + } else { + for (size_t i = 0; i < current_batch; ++i) { + batch_sum += float_buffer[i]; + } + } + + sum_squared += batch_sum; + processed_elements += current_batch; + } + } + + // Compute normalization factor. + float rms_value = sqrtf(sum_squared / vector_size + epsilon); + float inv_rms = 1.0f / rms_value; + + // Process vector for normalization. 
+ if (vector_size <= max_batch_size) { + __memcpy(input_cache, input + input_offset, vector_size * sizeof(T), + GDRAM2NRAM); + __memcpy(weight_cache, weight, vector_size * sizeof(TW), GDRAM2NRAM); + + if constexpr (std::is_same::value) { + __bang_half2float(float_buffer, reinterpret_cast(input_cache), + vector_size); + } else if constexpr (std::is_same::value) { + __bang_bfloat162float(float_buffer, input_cache, vector_size); + } else { + __memcpy(float_buffer, input_cache, vector_size * sizeof(float), + NRAM2NRAM); + } + + if constexpr (std::is_same::value) { + __bang_half2float(weight_float_buffer, + reinterpret_cast(weight_cache), vector_size); + } else if constexpr (std::is_same::value) { + __bang_bfloat162float(weight_float_buffer, weight_cache, vector_size); + } else { + __memcpy(weight_float_buffer, weight_cache, vector_size * sizeof(float), + NRAM2NRAM); + } + + // Multiply by weight and apply normalization. + __bang_mul(float_buffer, float_buffer, weight_float_buffer, vector_size); + __bang_mul_scalar(float_buffer, float_buffer, inv_rms, vector_size); + + if constexpr (std::is_same::value) { + __bang_float2half(reinterpret_cast(input_cache), float_buffer, + vector_size); + } else if constexpr (std::is_same::value) { + __bang_float2bfloat16(input_cache, float_buffer, vector_size); + } else { + __memcpy(input_cache, float_buffer, vector_size * sizeof(float), + NRAM2NRAM); + } + + __memcpy(output + output_offset, input_cache, vector_size * sizeof(T), + NRAM2GDRAM); + } else { + // Large vector processing with chunking. + size_t processed_elements = 0; + while (processed_elements < vector_size) { + size_t current_batch = + std::min((size_t)max_batch_size, vector_size - processed_elements); + + // Load input and weight data. 
+ __memcpy(input_cache, + input + input_offset + + processed_elements * input_strides[num_dims - 1], + current_batch * sizeof(T), GDRAM2NRAM); + __memcpy(weight_cache, weight + processed_elements, + current_batch * sizeof(TW), GDRAM2NRAM); + + if constexpr (std::is_same::value) { + __bang_half2float(float_buffer, reinterpret_cast(input_cache), + current_batch); + } else if constexpr (std::is_same::value) { + __bang_bfloat162float(float_buffer, input_cache, current_batch); + } else { + __memcpy(float_buffer, input_cache, current_batch * sizeof(float), + NRAM2NRAM); + } + + if constexpr (std::is_same::value) { + __bang_half2float(weight_float_buffer, + reinterpret_cast(weight_cache), + current_batch); + } else if constexpr (std::is_same::value) { + __bang_bfloat162float(weight_float_buffer, weight_cache, + current_batch); + } else { + __memcpy(weight_float_buffer, weight_cache, + current_batch * sizeof(float), NRAM2NRAM); + } + + __bang_mul(float_buffer, float_buffer, weight_float_buffer, + current_batch); + __bang_mul_scalar(float_buffer, float_buffer, inv_rms, current_batch); + + if constexpr (std::is_same::value) { + __bang_float2half(reinterpret_cast(input_cache), float_buffer, + current_batch); + } else if constexpr (std::is_same::value) { + __bang_float2bfloat16(input_cache, float_buffer, current_batch); + } else { + __memcpy(input_cache, float_buffer, current_batch * sizeof(float), + NRAM2NRAM); + } + + __memcpy(output + output_offset + + processed_elements * output_strides[num_dims - 1], + input_cache, current_batch * sizeof(T), NRAM2GDRAM); + + processed_elements += current_batch; + } + } + } +} + +template +void RmsNormUnion(void *workspace, int core_per_cluster, int cluster_count, + cnrtQueue_t queue, void *y, const void *x, const void *w, + const size_t *shape, const ptrdiff_t *y_strides, + const ptrdiff_t *x_strides, float eps, int ndim) { + cnrtDim3_t kernel_dim; + cnrtFunctionType_t kernel_type; + + // Configure kernel dimensions. 
+ kernel_dim.x = core_per_cluster; + kernel_dim.y = cluster_count; + kernel_dim.z = 1; + kernel_type = cnrtFuncTypeUnion1; // Can choose others, but must adapt + // kernel_type accordingly + int dimsize = shape[ndim - 1]; // Length of operation dimension + int dim_s; // dim_s is the next power of 2 greater than dimsize + float mi = log2(dimsize); + if (floor(mi) == mi) { + dim_s = dimsize; + } else { + dim_s = pow(2, floor(mi) + 1); + } + constexpr int reduce_num = + 128 / sizeof(float); // Cambricon __bang_reduce_sum can only reduce 128 + // bytes at a time + if (dim_s < reduce_num) { + dim_s = reduce_num; // Force dim_s >= reduce_num + } + + // Prepare device pointers. + auto y_ = reinterpret_cast(y); + auto x_ = reinterpret_cast(x); + auto w_ = reinterpret_cast(w); + char *tmp_device = reinterpret_cast(workspace); + char *tmp_stride = tmp_device + ndim * sizeof(size_t); + size_t *mlu_shape = (size_t *)tmp_device; + ptrdiff_t *mlu_x_strides = (ptrdiff_t *)tmp_stride; + ptrdiff_t *mlu_y_strides = mlu_x_strides + ndim; + + // Copy shape and stride information to device. 
+ CNRT_CHECK(cnrtMemcpyAsync(mlu_shape, const_cast(shape), + ndim * sizeof(size_t), queue, + cnrtMemcpyHostToDev)); // const not supported + CNRT_CHECK(cnrtMemcpyAsync(mlu_x_strides, const_cast(x_strides), + ndim * sizeof(ptrdiff_t), queue, + cnrtMemcpyHostToDev)); + CNRT_CHECK(cnrtMemcpyAsync(mlu_y_strides, const_cast(y_strides), + ndim * sizeof(ptrdiff_t), queue, + cnrtMemcpyHostToDev)); + + RmsNorm<<>>( + x_, w_, y_, mlu_shape, mlu_y_strides, mlu_x_strides, eps, ndim, dim_s); + + cnrtQueueSync(queue); +} + +template void RmsNormUnion<__half, __half>(void *, int, int, cnrtQueue_t, + void *, const void *, const void *, + const size_t *, const ptrdiff_t *, + const ptrdiff_t *, float, int); + +template void RmsNormUnion<__half, __bang_bfloat16>( + void *, int, int, cnrtQueue_t, void *, const void *, const void *, + const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + +template void RmsNormUnion<__half, float>(void *, int, int, cnrtQueue_t, void *, + const void *, const void *, + const size_t *, const ptrdiff_t *, + const ptrdiff_t *, float, int); + +template void RmsNormUnion<__bang_bfloat16, __half>( + void *, int, int, cnrtQueue_t, void *, const void *, const void *, + const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + +template void RmsNormUnion<__bang_bfloat16, __bang_bfloat16>( + void *, int, int, cnrtQueue_t, void *, const void *, const void *, + const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + +template void RmsNormUnion<__bang_bfloat16, float>( + void *, int, int, cnrtQueue_t, void *, const void *, const void *, + const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + +template void RmsNormUnion(void *, int, int, cnrtQueue_t, void *, + const void *, const void *, + const size_t *, const ptrdiff_t *, + const ptrdiff_t *, float, int); + +template void RmsNormUnion( + void *, int, int, cnrtQueue_t, void *, const void *, const void *, + const size_t *, const ptrdiff_t *, const ptrdiff_t *, 
float, int); + +template void RmsNormUnion(void *, int, int, cnrtQueue_t, void *, + const void *, const void *, + const size_t *, const ptrdiff_t *, + const ptrdiff_t *, float, int); + +} // namespace infini::ops diff --git a/src/cambricon/rms_norm/rms_norm.h b/src/cambricon/rms_norm/rms_norm.h new file mode 100644 index 0000000..0e331dd --- /dev/null +++ b/src/cambricon/rms_norm/rms_norm.h @@ -0,0 +1,64 @@ +#ifndef INFINI_OPS_CAMBRICON_RMS_NORM_H_ +#define INFINI_OPS_CAMBRICON_RMS_NORM_H_ + +#include +#include +#include + +#include "../common.h" +#include "base/rms_norm.h" + +namespace infini::ops { + +// TODO: Remove forward declaration. +template +void RmsNormUnion(void *workspace, int core_per_cluster, int cluster_count, + cnrtQueue_t queue, void *y, const void *x, const void *w, + const size_t *shape, const ptrdiff_t *y_strides, + const ptrdiff_t *x_strides, float eps, int ndim); + +template <> +class Operator : public RmsNorm { + public: + Operator(const Tensor input, const Tensor weight, float eps, Tensor out) + : RmsNorm{input, weight, eps, out} { + cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster, + &cluster_count); + cnrtMalloc(&default_workspace_, workspace_size_in_bytes()); + } + + void operator()(const Tensor input, const Tensor weight, float eps, + Tensor out) const override { + auto queue = static_cast(stream_ ? stream_ : 0); + auto workspace{workspace_ ? 
workspace_ : default_workspace_}; + + DispatchFunc< + List, + List>( + {input.dtype(), weight.dtype()}, + [&](auto input_tag, auto weight_tag) { + using InputT = typename decltype(input_tag)::type; + using WeightT = typename decltype(weight_tag)::type; + + RmsNormUnion( + workspace, core_per_cluster, cluster_count, queue, + out.data(), input.data(), weight.data(), out_shape_.data(), + out_strides_.data(), input_strides_.data(), eps, ndim_); + }, + "CambriconRmsNorm::operator() - output dispatch"); + } + + ~Operator() { cnrtFree(default_workspace_); } + + std::size_t workspace_size_in_bytes() const override { + return ndim_ * (sizeof(size_t) + 2 * sizeof(ptrdiff_t)); + } + + void *default_workspace_{nullptr}; + int core_per_cluster = 0; + int cluster_count = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/caster.h b/src/caster.h new file mode 100644 index 0000000..cefd116 --- /dev/null +++ b/src/caster.h @@ -0,0 +1,14 @@ +#ifndef INFINI_OPS_CASTER_H_ +#define INFINI_OPS_CASTER_H_ + +#include "data_type.h" +#include "device.h" + +namespace infini::ops { + +template +struct Caster; + +} // namespace infini::ops + +#endif diff --git a/src/common/cast.h b/src/common/cast.h deleted file mode 100644 index a37fb94..0000000 --- a/src/common/cast.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef INFINI_OPS_COMMON_CAST_H_ -#define INFINI_OPS_COMMON_CAST_H_ - -#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) || defined(WITH_METAX) || \ - defined(WITH_MOORE) -#include "common/cuda/cast.h" -#else -#include "common/cpu/cast.h" -#endif - -#endif diff --git a/src/common/cpu/cast.h b/src/common/cpu/cast.h deleted file mode 100644 index 68b95fc..0000000 --- a/src/common/cpu/cast.h +++ /dev/null @@ -1,57 +0,0 @@ -#ifndef INFINI_OPS_COMMON_CPU_CAST_H_ -#define INFINI_OPS_COMMON_CPU_CAST_H_ - -#include "data_type.h" - -namespace infini::ops { - -namespace detail { - -template -constexpr float ToFloatHelper(T &&x) { - using PureSrc = std::remove_cv_t >; - if constexpr 
(IsBFloat16 || IsFP16) { - return std::forward(x).ToFloat(); - } else { - return static_cast(std::forward(x)); - } -} - -template -constexpr Dst FromFloatHelper(float f) { - using PureDst = std::remove_cv_t >; - if constexpr (IsBFloat16 || IsFP16) { - return PureDst::FromFloat(f); - } else { - return static_cast(f); - } -} - -} // namespace detail - -template -Dst Cast(Src &&x) { - static_assert(!std::is_reference_v, - "`Cast` cannot return reference types"); - - using PureDst = std::remove_cv_t >; - using PureSrc = std::remove_cv_t >; - - if constexpr (std::is_same_v) { - return std::forward(x); - } - - constexpr bool src_is_custom = IsBFloat16 || IsFP16; - constexpr bool dst_is_custom = IsBFloat16 || IsFP16; - - if constexpr (!src_is_custom && !dst_is_custom) { - return static_cast(std::forward(x)); - } else { - return detail::FromFloatHelper( - detail::ToFloatHelper(std::forward(x))); - } -} - -} // namespace infini::ops - -#endif diff --git a/src/common/cuda/cast.h b/src/common/cuda/cast.h deleted file mode 100644 index 1f67a44..0000000 --- a/src/common/cuda/cast.h +++ /dev/null @@ -1,107 +0,0 @@ -#ifndef INFINI_OPS_COMMON_CUDA_CAST_H_ -#define INFINI_OPS_COMMON_CUDA_CAST_H_ - -#ifdef WITH_NVIDIA -#include -#elif defined(WITH_ILUVATAR) -#include -#elif defined(WITH_METAX) -#include -#elif defined(WITH_MOORE) -#include -#endif - -#include "data_type.h" - -namespace infini::ops { - -namespace detail { - -template -using PureType = std::remove_cv_t>; - -template -__host__ __device__ constexpr float ToFloatHelper(T&& x) { - using PureSrc = PureType; - if constexpr (IsBFloat16) { - return __bfloat162float(x); - } else if constexpr (IsFP16) { - return __half2float(x); - } else { - return static_cast(std::forward(x)); - } -} - -template -__host__ __device__ constexpr Dst FromFloatHelper(float f) { - using PureDst = PureType; - if constexpr (IsBFloat16) { - return __float2bfloat16(f); - } else if constexpr (IsFP16) { - return __float2half(f); - } else { - return 
static_cast(f); - } -} - -// Priority tags for overload resolution. -struct PriorityLow {}; - -struct PriorityHigh : PriorityLow {}; - -// Fallback: lowest priority. This always matches if nothing else does. -template -__host__ __device__ constexpr Dst HardwareCast(Src&& x, PriorityLow) { - return FromFloatHelper(ToFloatHelper(std::forward(x))); -} - -// Usage: `DEFINE_DIRECT_CAST(INTRINSIC, CONDITION)`. -#define DEFINE_DIRECT_CAST(INTRINSIC, ...) \ - template \ - __host__ __device__ auto HardwareCast(Src x, PriorityHigh) \ - ->std::enable_if_t<(__VA_ARGS__), \ - decltype(INTRINSIC(std::declval()))> { \ - return INTRINSIC(x); \ - } - -DEFINE_DIRECT_CAST( - __bfloat162int_rn, - std::is_same_v, int>&& IsBFloat16>) -DEFINE_DIRECT_CAST( - __bfloat162short_rn, - std::is_same_v, short>&& IsBFloat16>) -DEFINE_DIRECT_CAST( - __int2bfloat16_rn, - IsBFloat16>&& std::is_same_v, int>) -DEFINE_DIRECT_CAST(__int2half_rn, - IsFP16>&& std::is_same_v, int>) -DEFINE_DIRECT_CAST( - __double2bfloat16, - IsBFloat16>&& std::is_same_v, double>) -DEFINE_DIRECT_CAST( - __double2half, - IsFP16>&& std::is_same_v, double>) -DEFINE_DIRECT_CAST(__half, IsFP16>&& IsBFloat16>) -#undef DEFINE_DIRECT_CAST - -} // namespace detail - -template -__host__ __device__ Dst Cast(Src&& x) { - static_assert(!std::is_reference_v, - "`Cast` cannot return reference types"); - - using PureSrc = std::remove_cv_t>; - using PureDst = std::remove_cv_t>; - - if constexpr (std::is_same_v) { - return std::forward(x); - } else { - return detail::HardwareCast(std::forward(x), - detail::PriorityHigh{}); - } -} - -} // namespace infini::ops - -#endif diff --git a/src/cpu/add/add.h b/src/cpu/add/add.h index ec605c3..48d2469 100644 --- a/src/cpu/add/add.h +++ b/src/cpu/add/add.h @@ -4,13 +4,14 @@ #include #include "base/add.h" -#include "common/cast.h" #include "common/generic_utils.h" +#include "cpu/caster_.h" namespace infini::ops { template <> -class Operator : public Add { +class Operator : public Add, + Caster { public: 
Operator(const Tensor input, const Tensor other, Tensor out) : Add{input, other, out} { diff --git a/src/cpu/caster_.h b/src/cpu/caster_.h new file mode 100644 index 0000000..7bd182f --- /dev/null +++ b/src/cpu/caster_.h @@ -0,0 +1,71 @@ +#ifndef INFINI_OPS_COMMON_CPU_CASTER_H_ +#define INFINI_OPS_COMMON_CPU_CASTER_H_ + +#include + +#include "caster.h" + +namespace infini::ops { + +template <> +struct Caster { + template + static Dst Cast(Src&& x) { + static_assert(!std::is_reference_v, + "`Cast` cannot return reference types"); + + using PureDst = std::remove_cv_t>; + using PureSrc = std::remove_cv_t>; + + if constexpr (std::is_same_v) { + return std::forward(x); + } + + constexpr bool src_is_custom = IsBFloat16 || IsFP16; + constexpr bool dst_is_custom = IsBFloat16 || IsFP16; + + if constexpr (!src_is_custom && !dst_is_custom) { + return static_cast(std::forward(x)); + } else { + return FromFloatHelper(ToFloatHelper(std::forward(x))); + } + } + + private: + template + struct HasToFloat : std::false_type {}; + + template + struct HasToFloat().ToFloat())>> + : std::true_type {}; + + template + struct HasFromFloat : std::false_type {}; + + template + struct HasFromFloat< + T, std::void_t()))>> + : std::true_type {}; + + template + static constexpr float ToFloatHelper(T&& x) { + if constexpr (HasToFloat::value) { + return std::forward(x).ToFloat(); + } else { + return static_cast(x); + } + } + + template + static constexpr PureDst FromFloatHelper(float f) { + if constexpr (HasFromFloat::value) { + return PureDst::FromFloat(f); + } else { + return static_cast(f); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/causal_softmax/causal_softmax.h b/src/cpu/causal_softmax/causal_softmax.h index ca207a2..e8cee7e 100644 --- a/src/cpu/causal_softmax/causal_softmax.h +++ b/src/cpu/causal_softmax/causal_softmax.h @@ -4,15 +4,16 @@ #include #include "base/causal_softmax.h" -#include "common/cast.h" #include "common/generic_utils.h" +#include 
"cpu/caster_.h" #include "data_type.h" #include "tensor.h" namespace infini::ops { template <> -class Operator : public CausalSoftmax { +class Operator : public CausalSoftmax, + Caster { public: Operator(const Tensor input, Tensor out) : CausalSoftmax{input, out} {} diff --git a/src/cpu/gemm/gemm.h b/src/cpu/gemm/gemm.h index 685a94a..c472085 100644 --- a/src/cpu/gemm/gemm.h +++ b/src/cpu/gemm/gemm.h @@ -4,13 +4,14 @@ #include #include "base/gemm.h" -#include "common/cast.h" #include "common/generic_utils.h" +#include "cpu/caster_.h" namespace infini::ops { template <> -class Operator : public Gemm { +class Operator : public Gemm, + Caster { public: Operator(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, std::optional trans_a, diff --git a/src/cpu/rms_norm/rms_norm.h b/src/cpu/rms_norm/rms_norm.h index b3caeb0..752a6a6 100644 --- a/src/cpu/rms_norm/rms_norm.h +++ b/src/cpu/rms_norm/rms_norm.h @@ -4,15 +4,16 @@ #include #include "base/rms_norm.h" -#include "common/cast.h" #include "common/generic_utils.h" +#include "cpu/caster_.h" #include "data_type.h" #include "tensor.h" namespace infini::ops { template <> -class Operator : public RmsNorm { +class Operator : public RmsNorm, + Caster { public: using RmsNorm::RmsNorm; diff --git a/src/cpu/swiglu/swiglu.h b/src/cpu/swiglu/swiglu.h index ac2b3b2..374c8d8 100644 --- a/src/cpu/swiglu/swiglu.h +++ b/src/cpu/swiglu/swiglu.h @@ -4,13 +4,14 @@ #include #include "base/swiglu.h" -#include "common/cast.h" #include "common/generic_utils.h" +#include "cpu/caster_.h" namespace infini::ops { template <> -class Operator : public Swiglu { +class Operator : public Swiglu, + Caster { public: using Swiglu::Swiglu; diff --git a/src/cuda/add/kernel.cuh b/src/cuda/add/kernel.cuh index 2d58809..cfd6496 100644 --- a/src/cuda/add/kernel.cuh +++ b/src/cuda/add/kernel.cuh @@ -1,7 +1,7 @@ #ifndef INFINI_OPS_CUDA_ADD_KERNEL_CUH_ #define INFINI_OPS_CUDA_ADD_KERNEL_CUH_ -#include "common/cuda/kernel_commons.h" +#include 
"cuda/kernel_commons.h" namespace infini::ops { diff --git a/src/cuda/add/kernel.h b/src/cuda/add/kernel.h index c174afb..e3a5bd0 100644 --- a/src/cuda/add/kernel.h +++ b/src/cuda/add/kernel.h @@ -6,6 +6,7 @@ #include "base/add.h" #include "common/generic_utils.h" #include "cuda/add/kernel.cuh" +#include "cuda/kernel_commons.h" namespace infini::ops { diff --git a/src/cuda/caster_.h b/src/cuda/caster_.h new file mode 100644 index 0000000..45bb602 --- /dev/null +++ b/src/cuda/caster_.h @@ -0,0 +1,111 @@ +#ifndef INFINI_OPS_COMMON_CUDA_CASTER_H_ +#define INFINI_OPS_COMMON_CUDA_CASTER_H_ + +#ifdef WITH_NVIDIA +#include +#elif defined(WITH_ILUVATAR) +#include +#elif defined(WITH_METAX) +#include +#elif defined(WITH_MOORE) +#include +#endif + +#include "caster.h" + +namespace infini::ops { + +template <> +struct Caster { + template + __host__ __device__ static Dst Cast(Src&& x) { + static_assert(!std::is_reference_v, + "`Cast` cannot return reference types"); + + using PureSrc = std::remove_cv_t>; + using PureDst = std::remove_cv_t>; + + if constexpr (std::is_same_v) { + return std::forward(x); + } else { + return HardwareCast(std::forward(x), PriorityHigh{}); + } + } + + private: + template + using PureType = std::remove_cv_t>; + + template + __host__ __device__ static constexpr float ToFloatHelper(T&& x) { + using PureSrc = PureType; + if constexpr (IsBFloat16) { + return __bfloat162float(x); + } else if constexpr (IsFP16) { + return __half2float(x); + } else { + return static_cast(std::forward(x)); + } + } + + template + __host__ __device__ static constexpr Dst FromFloatHelper(float f) { + using PureDst = PureType; + if constexpr (IsBFloat16) { + return __float2bfloat16(f); + } else if constexpr (IsFP16) { + return __float2half(f); + } else { + return static_cast(f); + } + } + + // Priority tags for overload resolution. + struct PriorityLow {}; + + struct PriorityHigh : PriorityLow {}; + + // Fallback: lowest priority. This always matches if nothing else does. 
+ template + __host__ __device__ static constexpr Dst HardwareCast(Src&& x, PriorityLow) { + return FromFloatHelper(ToFloatHelper(std::forward(x))); + } + +// Usage: `DEFINE_DIRECT_CAST(INTRINSIC, CONDITION)`. +#define DEFINE_DIRECT_CAST(INTRINSIC, ...) \ + template \ + __host__ __device__ static auto HardwareCast(Src x, PriorityHigh) \ + -> std::enable_if_t<(__VA_ARGS__), \ + decltype(INTRINSIC(std::declval()))> { \ + return INTRINSIC(x); \ + } + + DEFINE_DIRECT_CAST( + __bfloat162int_rn, + std::is_same_v, int>&& IsBFloat16>) + DEFINE_DIRECT_CAST( + __bfloat162short_rn, + std::is_same_v, short>&& IsBFloat16>) + DEFINE_DIRECT_CAST( + __int2bfloat16_rn, + IsBFloat16>&& std::is_same_v, int>) + DEFINE_DIRECT_CAST(__int2half_rn, + IsFP16>&& std::is_same_v, int>) + DEFINE_DIRECT_CAST( + __double2bfloat16, + IsBFloat16>&& std::is_same_v, double>) + DEFINE_DIRECT_CAST( + __double2half, + IsFP16>&& std::is_same_v, double>) + DEFINE_DIRECT_CAST(__half, IsFP16>&& IsBFloat16>) +#undef DEFINE_DIRECT_CAST +}; + +template +__host__ __device__ __forceinline__ auto Cast(Args&&... 
args) { + return Caster::template Cast(std::forward(args)...); +} + +} // namespace infini::ops + +#endif diff --git a/src/cuda/causal_softmax/kernel.cuh b/src/cuda/causal_softmax/kernel.cuh index d578998..757c4cd 100644 --- a/src/cuda/causal_softmax/kernel.cuh +++ b/src/cuda/causal_softmax/kernel.cuh @@ -5,8 +5,8 @@ #include #include -#include "common/cuda/cast.h" -#include "common/cuda/kernel_commons.h" +#include "cuda/caster_.h" +#include "cuda/kernel_commons.h" namespace infini::ops { @@ -75,8 +75,7 @@ __global__ void CausalSoftmaxKernel( for (size_t col = threadIdx.x; col < total_seq_len; col += block_size) { if (col < valid_len) { - Compute diff = - Cast(input_row[col]) - Cast(max_val); + Compute diff = Cast(input_row[col]) - Cast(max_val); out_row[col] = ExpAndCast(diff); } else { out_row[col] = Cast(0.0f); diff --git a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h index 924be40..a320f63 100644 --- a/src/cuda/causal_softmax/kernel.h +++ b/src/cuda/causal_softmax/kernel.h @@ -4,8 +4,8 @@ #include #include "base/causal_softmax.h" -#include "common/cuda/kernel_commons.h" #include "cuda/causal_softmax/kernel.cuh" +#include "cuda/kernel_commons.h" #include "data_type.h" #include "dispatcher.h" diff --git a/src/common/cuda/kernel_commons.h b/src/cuda/kernel_commons.h similarity index 99% rename from src/common/cuda/kernel_commons.h rename to src/cuda/kernel_commons.h index e2ef107..8ccd7e9 100644 --- a/src/common/cuda/kernel_commons.h +++ b/src/cuda/kernel_commons.h @@ -29,7 +29,7 @@ using cuda_bfloat162 = __mt_bfloat162; #include #include -#include "cast.h" +#include "caster.h" namespace infini::ops { diff --git a/src/cuda/rms_norm/kernel.cuh b/src/cuda/rms_norm/kernel.cuh index 10228a6..1a35d22 100644 --- a/src/cuda/rms_norm/kernel.cuh +++ b/src/cuda/rms_norm/kernel.cuh @@ -5,8 +5,8 @@ #include #include -#include "common/cuda/cast.h" -#include "common/cuda/kernel_commons.h" +#include "cuda/caster_.h" +#include "cuda/kernel_commons.h" 
namespace infini::ops { diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h index dc28ee5..3f61c50 100644 --- a/src/cuda/rms_norm/kernel.h +++ b/src/cuda/rms_norm/kernel.h @@ -4,7 +4,7 @@ #include #include "base/rms_norm.h" -#include "common/cuda/kernel_commons.h" +#include "cuda/kernel_commons.h" #include "cuda/rms_norm/kernel.cuh" #include "data_type.h" #include "dispatcher.h" @@ -41,12 +41,12 @@ class CudaRmsNorm : public RmsNorm { [&](auto tag) { using T = typename decltype(tag)::type; -#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \ - RmsNormKernel \ - <<>>( \ - reinterpret_cast(out.data()), stride_out_batch, \ - stride_out_nhead, reinterpret_cast(input.data()), \ - stride_input_batch, stride_input_nhead, \ +#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \ + RmsNormKernel \ + <<>>( \ + reinterpret_cast(out.data()), stride_out_batch, \ + stride_out_nhead, reinterpret_cast(input.data()), \ + stride_input_batch, stride_input_nhead, \ reinterpret_cast(weight.data()), nhead_, dim_, eps); if (block_size == CUDA_BLOCK_SIZE_2048) { diff --git a/src/cuda/swiglu/kernel.cuh b/src/cuda/swiglu/kernel.cuh index 8004b76..f3997e6 100644 --- a/src/cuda/swiglu/kernel.cuh +++ b/src/cuda/swiglu/kernel.cuh @@ -3,7 +3,7 @@ #include -#include "common/cuda/kernel_commons.h" +#include "cuda/kernel_commons.h" namespace infini::ops { diff --git a/src/data_type.h b/src/data_type.h index af2aec7..ce2adfe 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -17,6 +17,9 @@ #elif defined(WITH_MOORE) #include #include +#elif defined(WITH_CAMBRICON) +#include "bang_bf16.h" +#include "bang_fp16.h" #endif #include "common/constexpr_map.h" @@ -207,6 +210,9 @@ DEFINE_DATA_TYPE_MAPPING(kBFloat16, __maca_bfloat16) #elif defined(WITH_MOORE) DEFINE_DATA_TYPE_MAPPING(kFloat16, half) DEFINE_DATA_TYPE_MAPPING(kBFloat16, __mt_bfloat16) +#elif defined(WITH_CAMBRICON) +DEFINE_DATA_TYPE_MAPPING(kFloat16, __half) +DEFINE_DATA_TYPE_MAPPING(kBFloat16, __bang_bfloat16) #else 
DEFINE_DATA_TYPE_MAPPING(kFloat16, Float16) DEFINE_DATA_TYPE_MAPPING(kBFloat16, BFloat16) diff --git a/src/iluvatar/causal_softmax/kernel.h b/src/iluvatar/causal_softmax/kernel.h index d216815..6187110 100644 --- a/src/iluvatar/causal_softmax/kernel.h +++ b/src/iluvatar/causal_softmax/kernel.h @@ -7,6 +7,10 @@ #include // clang-format on +// clang-format off +#include "iluvatar/device_.h" +// clang-format on + #include "cuda/causal_softmax/kernel.h" namespace infini::ops { diff --git a/src/iluvatar/device_.h b/src/iluvatar/device_.h new file mode 100644 index 0000000..9d46e77 --- /dev/null +++ b/src/iluvatar/device_.h @@ -0,0 +1,12 @@ +#ifndef INFINI_OPS_ILUVATAR_DEVICE_H_ +#define INFINI_OPS_ILUVATAR_DEVICE_H_ + +#include "device.h" + +namespace infini::ops { + +inline constexpr auto kDeviceType{Device::Type::kIluvatar}; + +} // namespace infini::ops + +#endif diff --git a/src/iluvatar/rms_norm/kernel.h b/src/iluvatar/rms_norm/kernel.h index 3971c3a..a07bff8 100644 --- a/src/iluvatar/rms_norm/kernel.h +++ b/src/iluvatar/rms_norm/kernel.h @@ -7,6 +7,10 @@ #include // clang-format on +// clang-format off +#include "iluvatar/device_.h" +// clang-format on + #include "cuda/rms_norm/kernel.h" namespace infini::ops { diff --git a/src/metax/causal_softmax/kernel.h b/src/metax/causal_softmax/kernel.h index d44919f..f8648e4 100644 --- a/src/metax/causal_softmax/kernel.h +++ b/src/metax/causal_softmax/kernel.h @@ -7,6 +7,10 @@ #include // clang-format on +// clang-format off +#include "metax/device_.h" +// clang-format on + #include "cuda/causal_softmax/kernel.h" namespace infini::ops { diff --git a/src/metax/device_.h b/src/metax/device_.h new file mode 100644 index 0000000..5e7c93c --- /dev/null +++ b/src/metax/device_.h @@ -0,0 +1,12 @@ +#ifndef INFINI_OPS_METAX_DEVICE_H_ +#define INFINI_OPS_METAX_DEVICE_H_ + +#include "device.h" + +namespace infini::ops { + +inline constexpr auto kDeviceType{Device::Type::kMetax}; + +} // namespace infini::ops + +#endif diff --git 
a/src/metax/rms_norm/kernel.h b/src/metax/rms_norm/kernel.h index b724552..bdf0bed 100644 --- a/src/metax/rms_norm/kernel.h +++ b/src/metax/rms_norm/kernel.h @@ -7,6 +7,10 @@ #include // clang-format on +// clang-format off +#include "metax/device_.h" +// clang-format on + #include "cuda/rms_norm/kernel.h" namespace infini::ops { diff --git a/src/moore/device_.h b/src/moore/device_.h new file mode 100644 index 0000000..fc9282f --- /dev/null +++ b/src/moore/device_.h @@ -0,0 +1,12 @@ +#ifndef INFINI_OPS_MOORE_DEVICE_H_ +#define INFINI_OPS_MOORE_DEVICE_H_ + +#include "device.h" + +namespace infini::ops { + +inline constexpr auto kDeviceType{Device::Type::kMoore}; + +} // namespace infini::ops + +#endif diff --git a/src/nvidia/causal_softmax/kernel.h b/src/nvidia/causal_softmax/kernel.h index 5be316a..0c13518 100644 --- a/src/nvidia/causal_softmax/kernel.h +++ b/src/nvidia/causal_softmax/kernel.h @@ -7,6 +7,10 @@ #include // clang-format on +// clang-format off +#include "nvidia/device_.h" +// clang-format on + #include "cuda/causal_softmax/kernel.h" namespace infini::ops { diff --git a/src/nvidia/device_.h b/src/nvidia/device_.h new file mode 100644 index 0000000..1d7fe05 --- /dev/null +++ b/src/nvidia/device_.h @@ -0,0 +1,12 @@ +#ifndef INFINI_OPS_NVIDIA_DEVICE_H_ +#define INFINI_OPS_NVIDIA_DEVICE_H_ + +#include "device.h" + +namespace infini::ops { + +inline constexpr auto kDeviceType{Device::Type::kNvidia}; + +} // namespace infini::ops + +#endif diff --git a/src/nvidia/rms_norm/kernel.h b/src/nvidia/rms_norm/kernel.h index 496bddd..bb766d4 100644 --- a/src/nvidia/rms_norm/kernel.h +++ b/src/nvidia/rms_norm/kernel.h @@ -7,6 +7,10 @@ #include // clang-format on +// clang-format off +#include "nvidia/device_.h" +// clang-format on + #include "cuda/rms_norm/kernel.h" namespace infini::ops { diff --git a/src/operator.h b/src/operator.h index be6fb51..e04e9af 100644 --- a/src/operator.h +++ b/src/operator.h @@ -93,7 +93,7 @@ class OperatorBase { std::size_t 
workspace_size_in_bytes_{0}; }; -template +template class Operator : public OperatorBase { public: template @@ -157,6 +157,9 @@ class Operator : public OperatorBase { auto operator()(Args&&... args) const { return (*static_cast(this))(std::forward(args)...); } + + protected: + static constexpr Device::Type device_type_{device_type}; }; } // namespace infini::ops From f17e37ccc8366383d025809b7d630fa4b5795072 Mon Sep 17 00:00:00 2001 From: Ziminli <70735843+Ziminli@users.noreply.github.com> Date: Thu, 26 Mar 2026 23:32:57 +0800 Subject: [PATCH 92/93] feat: add high-level `DispatchFunc()` interface for multi-type and mixed dispatch (#29) * feat: add a convenient interface for any `int64_t`-convertible types and use `DispatchFunc()` to dispatch `DataType` and block sizes with a single call. - add a convenient interface for any `int64_t`-convertible types, which is mostly used for multi-type dispatch and mixed dispatch - use `DispatchFunc()` to dispatch `DataType` and block sizes with a single function call in various kernels' implementation - remove the `CUDA_BLOCK_SIZE_XXX` macros and simply use numbers instead * style: fix the styling issue by adding a period to the TODO comment * fix: fix rebase error * style: fix the styling issues for comments in `dispatcher.h` and `cuda/causal_softmax/kernel.h` --- src/cuda/add/kernel.h | 34 +++++++++------------------ src/cuda/causal_softmax/kernel.h | 40 +++++++++++--------------------- src/cuda/kernel_commons.h | 26 +++++++++------------ src/cuda/rms_norm/kernel.h | 40 +++++++++++--------------------- src/cuda/swiglu/kernel.h | 34 +++++++++------------------ src/dispatcher.h | 10 ++++++++ src/operator.h | 6 ++--- 7 files changed, 74 insertions(+), 116 deletions(-) diff --git a/src/cuda/add/kernel.h b/src/cuda/add/kernel.h index e3a5bd0..928fa9c 100644 --- a/src/cuda/add/kernel.h +++ b/src/cuda/add/kernel.h @@ -51,10 +51,12 @@ class CudaAdd : public Add { void operator()(const Tensor input, const Tensor other, Tensor out) const 
override { int block_size = GetOptimalBlockSize(); - DispatchFunc( - out_type_, - [&](auto tag) { - using T = typename decltype(tag)::type; + DispatchFunc( + {static_cast(out_type_), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + auto cuda_stream = static_cast(stream_ ? stream_ : 0); dim3 blockDims( @@ -65,25 +67,11 @@ class CudaAdd : public Add { const T* d_input = reinterpret_cast(input.data()); const T* d_other = reinterpret_cast(other.data()); -#define LAUNCH_ADD_KERNEL(BLOCK_SIZE) \ - AddKernel<<>>( \ - d_out, d_input, d_other, d_out_shape_, d_input_shape_, d_other_shape_, \ - d_out_strides_, d_input_strides_, d_other_strides_, output_size_, ndim_, \ - is_out_contiguous_, is_input_contiguous_, is_other_contiguous_); - - if (block_size == CUDA_BLOCK_SIZE_2048) { - LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_2048) - } else if (block_size == CUDA_BLOCK_SIZE_1024) { - LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_1024) - } else if (block_size == CUDA_BLOCK_SIZE_512) { - LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_512) - } else if (block_size == CUDA_BLOCK_SIZE_256) { - LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_256) - } else { - LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_128) - } - -#undef LAUNCH_ADD_KERNEL + AddKernel<<>>( + d_out, d_input, d_other, d_out_shape_, d_input_shape_, + d_other_shape_, d_out_strides_, d_input_strides_, + d_other_strides_, output_size_, ndim_, is_out_contiguous_, + is_input_contiguous_, is_other_contiguous_); }, "CudaAdd::operator()"); } diff --git a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h index a320f63..3dce77d 100644 --- a/src/cuda/causal_softmax/kernel.h +++ b/src/cuda/causal_softmax/kernel.h @@ -34,32 +34,20 @@ class CudaCausalSoftmax : public CausalSoftmax { int block_size = GetOptimalBlockSize(); - DispatchFunc( - out.dtype(), - [&](auto tag) { - using T = typename decltype(tag)::type; - -#define LAUNCH_CAUSAL_SOFTMAX_KERNEL(BLOCK_SIZE) \ - CausalSoftmaxKernel \ - <<>>( \ - 
reinterpret_cast(out.data()), \ - reinterpret_cast(input.data()), batch_size_, seq_len_, \ - total_seq_len_, stride_out_batch, stride_out_row, \ - stride_input_batch, stride_input_row); - - if (block_size == CUDA_BLOCK_SIZE_2048) { - LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_2048) - } else if (block_size == CUDA_BLOCK_SIZE_1024) { - LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_1024) - } else if (block_size == CUDA_BLOCK_SIZE_512) { - LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_512) - } else if (block_size == CUDA_BLOCK_SIZE_256) { - LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_256) - } else { - LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_128) - } - -#undef LAUNCH_CAUSAL_SOFTMAX_KERNEL + DispatchFunc, ReducedFloatTypes>, + AllCudaBlockSizes>( + // TODO: Output dtype should use the one passed in during construction. + {static_cast(out.dtype()), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + + CausalSoftmaxKernel + <<>>( + reinterpret_cast(out.data()), + reinterpret_cast(input.data()), batch_size_, + seq_len_, total_seq_len_, stride_out_batch, stride_out_row, + stride_input_batch, stride_input_row); }, "CudaCausalSoftmax::operator()"); } diff --git a/src/cuda/kernel_commons.h b/src/cuda/kernel_commons.h index 8ccd7e9..6c987c7 100644 --- a/src/cuda/kernel_commons.h +++ b/src/cuda/kernel_commons.h @@ -33,11 +33,7 @@ using cuda_bfloat162 = __mt_bfloat162; namespace infini::ops { -constexpr int CUDA_BLOCK_SIZE_128 = 128; -constexpr int CUDA_BLOCK_SIZE_256 = 256; -constexpr int CUDA_BLOCK_SIZE_512 = 512; -constexpr int CUDA_BLOCK_SIZE_1024 = 1024; -constexpr int CUDA_BLOCK_SIZE_2048 = 2048; +using AllCudaBlockSizes = List<128, 256, 512, 1024, 2048>; #if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) // Cache `cudaDeviceProp` per device, initialized once at first access. 
@@ -76,7 +72,7 @@ inline int QueryMaxThreadsPerBlock() { #elif defined(WITH_METAX) inline int QueryMaxThreadsPerBlock() { // TODO: Add MCR device properties query for Metax. - return CUDA_BLOCK_SIZE_256; + return 256; } #elif defined(WITH_MOORE) inline int QueryMaxThreadsPerBlock() { @@ -91,16 +87,16 @@ inline int QueryMaxThreadsPerBlock() { // Get optimal block size based on GPU hardware architecture. inline int GetOptimalBlockSize() { int max_threads = QueryMaxThreadsPerBlock(); - if (max_threads >= CUDA_BLOCK_SIZE_2048) { - return CUDA_BLOCK_SIZE_2048; - } else if (max_threads >= CUDA_BLOCK_SIZE_1024) { - return CUDA_BLOCK_SIZE_1024; - } else if (max_threads >= CUDA_BLOCK_SIZE_512) { - return CUDA_BLOCK_SIZE_512; - } else if (max_threads >= CUDA_BLOCK_SIZE_256) { - return CUDA_BLOCK_SIZE_256; + if (max_threads >= 2048) { + return 2048; + } else if (max_threads >= 1024) { + return 1024; + } else if (max_threads >= 512) { + return 512; + } else if (max_threads >= 256) { + return 256; } else { - return CUDA_BLOCK_SIZE_128; + return 128; } } diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h index 3f61c50..29ee51e 100644 --- a/src/cuda/rms_norm/kernel.h +++ b/src/cuda/rms_norm/kernel.h @@ -36,32 +36,20 @@ class CudaRmsNorm : public RmsNorm { int block_size = GetOptimalBlockSize(); - DispatchFunc( - out.dtype(), - [&](auto tag) { - using T = typename decltype(tag)::type; - -#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \ - RmsNormKernel \ - <<>>( \ - reinterpret_cast(out.data()), stride_out_batch, \ - stride_out_nhead, reinterpret_cast(input.data()), \ - stride_input_batch, stride_input_nhead, \ - reinterpret_cast(weight.data()), nhead_, dim_, eps); - - if (block_size == CUDA_BLOCK_SIZE_2048) { - LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_2048) - } else if (block_size == CUDA_BLOCK_SIZE_1024) { - LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_1024) - } else if (block_size == CUDA_BLOCK_SIZE_512) { - LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_512) - } else if 
(block_size == CUDA_BLOCK_SIZE_256) { - LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_256) - } else { - LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_128) - } - -#undef LAUNCH_RMS_NORM_KERNEL + DispatchFunc, ReducedFloatTypes>, + AllCudaBlockSizes>( + {static_cast(out.dtype()), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + + RmsNormKernel + <<>>( + reinterpret_cast(out.data()), stride_out_batch, + stride_out_nhead, reinterpret_cast(input.data()), + stride_input_batch, stride_input_nhead, + reinterpret_cast(weight.data()), nhead_, dim_, + eps_); }, "CudaRmsNorm::operator()"); } diff --git a/src/cuda/swiglu/kernel.h b/src/cuda/swiglu/kernel.h index 47849fe..964e8b7 100644 --- a/src/cuda/swiglu/kernel.h +++ b/src/cuda/swiglu/kernel.h @@ -50,10 +50,12 @@ class CudaSwiglu : public Swiglu { void operator()(const Tensor input, const Tensor gate, Tensor out) const override { int block_size = GetOptimalBlockSize(); - DispatchFunc( - out_type_, - [&](auto tag) { - using T = typename decltype(tag)::type; + DispatchFunc( + {static_cast(out_type_), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + auto cuda_stream = static_cast(stream_ ? stream_ : 0); dim3 blockDims( @@ -64,25 +66,11 @@ class CudaSwiglu : public Swiglu { const T* d_input = reinterpret_cast(input.data()); const T* d_gate = reinterpret_cast(gate.data()); -// Launch kernel with appropriate block size based on GPU architecture. 
-#define LAUNCH_SWIGLU_KERNEL(BLOCK_SIZE) \ - SwigluKernel<<>>( \ - d_out, d_input, d_gate, d_out_shape_, d_input_shape_, d_gate_shape_, \ - d_out_strides_, d_input_strides_, d_gate_strides_, output_size_, ndim_, \ - is_out_contiguous_, is_input_contiguous_, is_gate_contiguous_); - if (block_size == CUDA_BLOCK_SIZE_2048) { - LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_2048) - } else if (block_size == CUDA_BLOCK_SIZE_1024) { - LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_1024) - } else if (block_size == CUDA_BLOCK_SIZE_512) { - LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_512) - } else if (block_size == CUDA_BLOCK_SIZE_256) { - LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_256) - } else { - LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_128) - } - -#undef LAUNCH_SWIGLU_KERNEL + SwigluKernel<<>>( + d_out, d_input, d_gate, d_out_shape_, d_input_shape_, + d_gate_shape_, d_out_strides_, d_input_strides_, d_gate_strides_, + output_size_, ndim_, is_out_contiguous_, is_input_contiguous_, + is_gate_contiguous_); }, "CudaSwiglu::operator()"); } diff --git a/src/dispatcher.h b/src/dispatcher.h index 83b282c..22aeb82 100644 --- a/src/dispatcher.h +++ b/src/dispatcher.h @@ -302,6 +302,16 @@ auto DispatchFunc(ValueType value, Functor &&func, std::forward(args)...); } +// Interface for Any `int64_t`-Convertible Types +template +auto DispatchFunc(std::initializer_list keys, Functor &&func, + std::string_view context_str = "", Args &&...args) { + std::vector v_keys(keys); + return DispatchFunc(v_keys, 0, std::forward(func), + context_str, List<>{}, + std::forward(args)...); +} + } // namespace infini::ops #endif diff --git a/src/operator.h b/src/operator.h index e04e9af..f482db2 100644 --- a/src/operator.h +++ b/src/operator.h @@ -103,10 +103,10 @@ class Operator : public OperatorBase { DispatchFunc( tensor.device().type(), [&](auto tag) { - constexpr Device::Type dev = decltype(tag)::value; - if constexpr (std::is_constructible_v, + constexpr Device::Type kDev = decltype(tag)::value; + if constexpr 
(std::is_constructible_v, const Tensor&, Args...>) { - op_ptr = std::make_unique>( + op_ptr = std::make_unique>( tensor, std::forward(args)...); } else { assert(false && "operator is not implemented for this device"); From 2816b5852cebdde96af9512518e65c3ae97049ee Mon Sep 17 00:00:00 2001 From: Jiacheng Huang <45955067+voltjia@users.noreply.github.com> Date: Wed, 1 Apr 2026 16:32:50 +0800 Subject: [PATCH 93/93] refactor: make data type mappings and shared CUDA headers device-aware (#38) * refactor: make `TypeMap`, `IsFP16`, `IsBFloat16`, and `DispatchFunc` device-aware * refactor: make `cuda/` shared headers self-contained and include-order-independent * fix: update call sites to device-aware `TypeMap`, `IsFP16`/`IsBFloat16`, and `DispatchFunc` * chore: format files with `clang-format` * fix: update `cuda/swiglu` kernels to use device-aware type predicates * fix: replace per-instance `blasHandle_t` with a static singleton in `Blas` * fix: restore kernel headers for `moore/add` and `moore/swiglu` to use `clang-format off` and `clang-format on` * fix: use absolute includes, consistent include guards, and formatted comments * refactor: extract `GetOptimalBlockSize` logic into shared `ComputeOptimalBlockSize` * fix: include `` in `polyfills.cuh` before `hrcp` macro to prevent collision * chore: add blank lines between `using` type alias declarations in `device_.h` * chore: add TODO comments for potential performance and concurrency issues * fix: move `clang-format` guards to wrap only CUDA headers in `iluvatar/device_.h` --- src/cambricon/device_.h | 23 +++++ src/cambricon/rms_norm/kernel.mlu | 114 ++++++++++---------- src/cambricon/rms_norm/rms_norm.h | 18 ++-- src/common/constexpr_map.h | 2 +- src/cpu/add/add.h | 7 +- src/cpu/caster_.h | 7 +- src/cpu/causal_softmax/causal_softmax.h | 2 +- src/cpu/device_.h | 21 ++++ src/cpu/gemm/gemm.h | 2 +- src/cpu/rms_norm/rms_norm.h | 2 +- src/cpu/swiglu/swiglu.h | 7 +- src/cuda/add/kernel.cuh | 10 +- src/cuda/add/kernel.h | 15 
+-- src/cuda/caster_.h | 45 +++----- src/cuda/causal_softmax/kernel.cuh | 26 +++-- src/cuda/causal_softmax/kernel.h | 11 +- src/cuda/gemm/blas.h | 23 +++-- src/cuda/kernel_commons.h | 103 ++---------------- src/cuda/rms_norm/kernel.cuh | 19 ++-- src/cuda/rms_norm/kernel.h | 11 +- src/cuda/swiglu/kernel.cuh | 57 +++------- src/cuda/swiglu/kernel.h | 15 +-- src/data_type.h | 74 +++---------- src/dispatcher.h | 132 ++++++++++++++---------- src/iluvatar/add/kernel.h | 11 +- src/iluvatar/causal_softmax/kernel.h | 15 ++- src/iluvatar/device_.h | 63 ++++++++++- src/iluvatar/rms_norm/kernel.h | 15 ++- src/iluvatar/swiglu/kernel.h | 11 +- src/metax/add/kernel.h | 11 +- src/metax/causal_softmax/kernel.h | 15 ++- src/metax/device_.h | 30 +++++- src/metax/rms_norm/kernel.h | 15 ++- src/metax/swiglu/kernel.h | 11 +- src/moore/add/kernel.h | 11 +- src/moore/device_.h | 35 ++++++- src/moore/polyfills.cuh | 1 + src/moore/swiglu/kernel.h | 11 +- src/nvidia/add/kernel.h | 11 +- src/nvidia/causal_softmax/kernel.h | 15 ++- src/nvidia/device_.h | 62 ++++++++++- src/nvidia/rms_norm/kernel.h | 15 ++- src/nvidia/swiglu/kernel.h | 11 +- src/pybind11_utils.h | 89 ++++++++++++---- src/tensor.cc | 3 +- 45 files changed, 683 insertions(+), 524 deletions(-) create mode 100644 src/cambricon/device_.h create mode 100644 src/cpu/device_.h diff --git a/src/cambricon/device_.h b/src/cambricon/device_.h new file mode 100644 index 0000000..224a8d8 --- /dev/null +++ b/src/cambricon/device_.h @@ -0,0 +1,23 @@ +#ifndef INFINI_OPS_CAMBRICON_DEVICE__H_ +#define INFINI_OPS_CAMBRICON_DEVICE__H_ + +#include "bang_bf16.h" +#include "bang_fp16.h" +#include "data_type.h" +#include "device.h" + +namespace infini::ops { + +template <> +struct TypeMap { + using type = __half; +}; + +template <> +struct TypeMap { + using type = __bang_bfloat16; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cambricon/rms_norm/kernel.mlu b/src/cambricon/rms_norm/kernel.mlu index 6648d2d..b4d7e8d 100644 --- 
a/src/cambricon/rms_norm/kernel.mlu +++ b/src/cambricon/rms_norm/kernel.mlu @@ -5,9 +5,9 @@ __nram__ char nram_buffer[NRAM_MAX_SIZE]; namespace infini::ops { template -__mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output, - size_t *shape, ptrdiff_t *output_strides, - ptrdiff_t *input_strides, float epsilon, +__mlu_global__ void RmsNorm(const T* input, const TW* weight, T* output, + size_t* shape, ptrdiff_t* output_strides, + ptrdiff_t* input_strides, float epsilon, int num_dims, int norm_dim_size) { // Calculate problem dimensions. int batch_volume = 1; @@ -40,11 +40,11 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output, constexpr int reduce_buffer_size = 128 / sizeof(float); // NRAM buffer allocation with dynamic sizing. - float *reduction_buffer = (float *)nram_buffer; - T *input_cache = (T *)(reduction_buffer + reduce_buffer_size); - TW *weight_cache = (TW *)(input_cache + max_batch_size); - float *float_buffer = (float *)(weight_cache + max_batch_size); - float *weight_float_buffer = (float *)(float_buffer + max_batch_size); + float* reduction_buffer = (float*)nram_buffer; + T* input_cache = (T*)(reduction_buffer + reduce_buffer_size); + TW* weight_cache = (TW*)(input_cache + max_batch_size); + float* float_buffer = (float*)(weight_cache + max_batch_size); + float* weight_float_buffer = (float*)(float_buffer + max_batch_size); // Process vectors assigned to current core. 
for (int task_idx = 0; task_idx < actual_tasks; ++task_idx) { @@ -69,7 +69,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output, __memcpy(input_cache, input + input_offset, vector_size * sizeof(T), GDRAM2NRAM); if constexpr (std::is_same::value) { - __bang_half2float(float_buffer, reinterpret_cast(input_cache), + __bang_half2float(float_buffer, reinterpret_cast(input_cache), vector_size); } else if constexpr (std::is_same::value) { __bang_bfloat162float(float_buffer, input_cache, vector_size); @@ -99,7 +99,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output, current_batch * sizeof(T), GDRAM2NRAM); if constexpr (std::is_same::value) { - __bang_half2float(float_buffer, reinterpret_cast(input_cache), + __bang_half2float(float_buffer, reinterpret_cast(input_cache), current_batch); } else if constexpr (std::is_same::value) { __bang_bfloat162float(float_buffer, input_cache, current_batch); @@ -137,7 +137,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output, __memcpy(weight_cache, weight, vector_size * sizeof(TW), GDRAM2NRAM); if constexpr (std::is_same::value) { - __bang_half2float(float_buffer, reinterpret_cast(input_cache), + __bang_half2float(float_buffer, reinterpret_cast(input_cache), vector_size); } else if constexpr (std::is_same::value) { __bang_bfloat162float(float_buffer, input_cache, vector_size); @@ -148,7 +148,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output, if constexpr (std::is_same::value) { __bang_half2float(weight_float_buffer, - reinterpret_cast(weight_cache), vector_size); + reinterpret_cast(weight_cache), vector_size); } else if constexpr (std::is_same::value) { __bang_bfloat162float(weight_float_buffer, weight_cache, vector_size); } else { @@ -161,7 +161,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output, __bang_mul_scalar(float_buffer, float_buffer, inv_rms, vector_size); if constexpr (std::is_same::value) { - 
__bang_float2half(reinterpret_cast(input_cache), float_buffer, + __bang_float2half(reinterpret_cast(input_cache), float_buffer, vector_size); } else if constexpr (std::is_same::value) { __bang_float2bfloat16(input_cache, float_buffer, vector_size); @@ -188,7 +188,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output, current_batch * sizeof(TW), GDRAM2NRAM); if constexpr (std::is_same::value) { - __bang_half2float(float_buffer, reinterpret_cast(input_cache), + __bang_half2float(float_buffer, reinterpret_cast(input_cache), current_batch); } else if constexpr (std::is_same::value) { __bang_bfloat162float(float_buffer, input_cache, current_batch); @@ -199,7 +199,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output, if constexpr (std::is_same::value) { __bang_half2float(weight_float_buffer, - reinterpret_cast(weight_cache), + reinterpret_cast(weight_cache), current_batch); } else if constexpr (std::is_same::value) { __bang_bfloat162float(weight_float_buffer, weight_cache, @@ -214,7 +214,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output, __bang_mul_scalar(float_buffer, float_buffer, inv_rms, current_batch); if constexpr (std::is_same::value) { - __bang_float2half(reinterpret_cast(input_cache), float_buffer, + __bang_float2half(reinterpret_cast(input_cache), float_buffer, current_batch); } else if constexpr (std::is_same::value) { __bang_float2bfloat16(input_cache, float_buffer, current_batch); @@ -234,10 +234,10 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output, } template -void RmsNormUnion(void *workspace, int core_per_cluster, int cluster_count, - cnrtQueue_t queue, void *y, const void *x, const void *w, - const size_t *shape, const ptrdiff_t *y_strides, - const ptrdiff_t *x_strides, float eps, int ndim) { +void RmsNormUnion(void* workspace, int core_per_cluster, int cluster_count, + cnrtQueue_t queue, void* y, const void* x, const void* w, + const size_t* shape, const 
ptrdiff_t* y_strides, + const ptrdiff_t* x_strides, float eps, int ndim) { cnrtDim3_t kernel_dim; cnrtFunctionType_t kernel_type; @@ -263,23 +263,23 @@ void RmsNormUnion(void *workspace, int core_per_cluster, int cluster_count, } // Prepare device pointers. - auto y_ = reinterpret_cast(y); - auto x_ = reinterpret_cast(x); - auto w_ = reinterpret_cast(w); - char *tmp_device = reinterpret_cast(workspace); - char *tmp_stride = tmp_device + ndim * sizeof(size_t); - size_t *mlu_shape = (size_t *)tmp_device; - ptrdiff_t *mlu_x_strides = (ptrdiff_t *)tmp_stride; - ptrdiff_t *mlu_y_strides = mlu_x_strides + ndim; + auto y_ = reinterpret_cast(y); + auto x_ = reinterpret_cast(x); + auto w_ = reinterpret_cast(w); + char* tmp_device = reinterpret_cast(workspace); + char* tmp_stride = tmp_device + ndim * sizeof(size_t); + size_t* mlu_shape = (size_t*)tmp_device; + ptrdiff_t* mlu_x_strides = (ptrdiff_t*)tmp_stride; + ptrdiff_t* mlu_y_strides = mlu_x_strides + ndim; // Copy shape and stride information to device. 
- CNRT_CHECK(cnrtMemcpyAsync(mlu_shape, const_cast(shape), + CNRT_CHECK(cnrtMemcpyAsync(mlu_shape, const_cast(shape), ndim * sizeof(size_t), queue, cnrtMemcpyHostToDev)); // const not supported - CNRT_CHECK(cnrtMemcpyAsync(mlu_x_strides, const_cast(x_strides), + CNRT_CHECK(cnrtMemcpyAsync(mlu_x_strides, const_cast(x_strides), ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev)); - CNRT_CHECK(cnrtMemcpyAsync(mlu_y_strides, const_cast(y_strides), + CNRT_CHECK(cnrtMemcpyAsync(mlu_y_strides, const_cast(y_strides), ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev)); @@ -289,44 +289,44 @@ void RmsNormUnion(void *workspace, int core_per_cluster, int cluster_count, cnrtQueueSync(queue); } -template void RmsNormUnion<__half, __half>(void *, int, int, cnrtQueue_t, - void *, const void *, const void *, - const size_t *, const ptrdiff_t *, - const ptrdiff_t *, float, int); +template void RmsNormUnion<__half, __half>(void*, int, int, cnrtQueue_t, void*, + const void*, const void*, + const size_t*, const ptrdiff_t*, + const ptrdiff_t*, float, int); template void RmsNormUnion<__half, __bang_bfloat16>( - void *, int, int, cnrtQueue_t, void *, const void *, const void *, - const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + void*, int, int, cnrtQueue_t, void*, const void*, const void*, + const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int); -template void RmsNormUnion<__half, float>(void *, int, int, cnrtQueue_t, void *, - const void *, const void *, - const size_t *, const ptrdiff_t *, - const ptrdiff_t *, float, int); +template void RmsNormUnion<__half, float>(void*, int, int, cnrtQueue_t, void*, + const void*, const void*, + const size_t*, const ptrdiff_t*, + const ptrdiff_t*, float, int); template void RmsNormUnion<__bang_bfloat16, __half>( - void *, int, int, cnrtQueue_t, void *, const void *, const void *, - const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + void*, int, int, cnrtQueue_t, void*, const void*, const void*, + 
const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int); template void RmsNormUnion<__bang_bfloat16, __bang_bfloat16>( - void *, int, int, cnrtQueue_t, void *, const void *, const void *, - const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + void*, int, int, cnrtQueue_t, void*, const void*, const void*, + const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int); template void RmsNormUnion<__bang_bfloat16, float>( - void *, int, int, cnrtQueue_t, void *, const void *, const void *, - const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + void*, int, int, cnrtQueue_t, void*, const void*, const void*, + const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int); -template void RmsNormUnion(void *, int, int, cnrtQueue_t, void *, - const void *, const void *, - const size_t *, const ptrdiff_t *, - const ptrdiff_t *, float, int); +template void RmsNormUnion(void*, int, int, cnrtQueue_t, void*, + const void*, const void*, + const size_t*, const ptrdiff_t*, + const ptrdiff_t*, float, int); template void RmsNormUnion( - void *, int, int, cnrtQueue_t, void *, const void *, const void *, - const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + void*, int, int, cnrtQueue_t, void*, const void*, const void*, + const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int); -template void RmsNormUnion(void *, int, int, cnrtQueue_t, void *, - const void *, const void *, - const size_t *, const ptrdiff_t *, - const ptrdiff_t *, float, int); +template void RmsNormUnion(void*, int, int, cnrtQueue_t, void*, + const void*, const void*, + const size_t*, const ptrdiff_t*, + const ptrdiff_t*, float, int); } // namespace infini::ops diff --git a/src/cambricon/rms_norm/rms_norm.h b/src/cambricon/rms_norm/rms_norm.h index 0e331dd..852fe66 100644 --- a/src/cambricon/rms_norm/rms_norm.h +++ b/src/cambricon/rms_norm/rms_norm.h @@ -5,17 +5,18 @@ #include #include -#include "../common.h" +#include "cambricon/common.h" +#include 
"cambricon/device_.h" #include "base/rms_norm.h" namespace infini::ops { // TODO: Remove forward declaration. template -void RmsNormUnion(void *workspace, int core_per_cluster, int cluster_count, - cnrtQueue_t queue, void *y, const void *x, const void *w, - const size_t *shape, const ptrdiff_t *y_strides, - const ptrdiff_t *x_strides, float eps, int ndim); +void RmsNormUnion(void* workspace, int core_per_cluster, int cluster_count, + cnrtQueue_t queue, void* y, const void* x, const void* w, + const size_t* shape, const ptrdiff_t* y_strides, + const ptrdiff_t* x_strides, float eps, int ndim); template <> class Operator : public RmsNorm { @@ -33,6 +34,7 @@ class Operator : public RmsNorm { auto workspace{workspace_ ? workspace_ : default_workspace_}; DispatchFunc< + Device::Type::kCambricon, List, List>( {input.dtype(), weight.dtype()}, @@ -41,8 +43,8 @@ class Operator : public RmsNorm { using WeightT = typename decltype(weight_tag)::type; RmsNormUnion( - workspace, core_per_cluster, cluster_count, queue, - out.data(), input.data(), weight.data(), out_shape_.data(), + workspace, core_per_cluster, cluster_count, queue, out.data(), + input.data(), weight.data(), out_shape_.data(), out_strides_.data(), input_strides_.data(), eps, ndim_); }, "CambriconRmsNorm::operator() - output dispatch"); @@ -54,7 +56,7 @@ class Operator : public RmsNorm { return ndim_ * (sizeof(size_t) + 2 * sizeof(ptrdiff_t)); } - void *default_workspace_{nullptr}; + void* default_workspace_{nullptr}; int core_per_cluster = 0; int cluster_count = 0; }; diff --git a/src/common/constexpr_map.h b/src/common/constexpr_map.h index 27fdc67..7454f54 100644 --- a/src/common/constexpr_map.h +++ b/src/common/constexpr_map.h @@ -14,7 +14,7 @@ struct ConstexprMap { : data_(data) {} constexpr Value at(Key key) const { - for (const auto &pr : data_) { + for (const auto& pr : data_) { if (pr.first == key) return pr.second; } // TODO(lzm): change to logging. 
diff --git a/src/cpu/add/add.h b/src/cpu/add/add.h index 48d2469..c56d31f 100644 --- a/src/cpu/add/add.h +++ b/src/cpu/add/add.h @@ -20,7 +20,7 @@ class Operator : public Add, void operator()(const Tensor input, const Tensor other, Tensor out) const override { - DispatchFunc( + DispatchFunc( out_type_, [&](auto tag) { using T = typename decltype(tag)::type; @@ -32,8 +32,9 @@ class Operator : public Add, private: template void Compute(const Tensor input, const Tensor other, Tensor out) const { - using ComputeType = - std::conditional_t || IsFP16, float, T>; + using ComputeType = std::conditional_t || + IsFP16, + float, T>; const auto* input_ptr = static_cast(input.data()); const auto* other_ptr = static_cast(other.data()); diff --git a/src/cpu/caster_.h b/src/cpu/caster_.h index 7bd182f..4d2cca6 100644 --- a/src/cpu/caster_.h +++ b/src/cpu/caster_.h @@ -4,6 +4,7 @@ #include #include "caster.h" +#include "cpu/device_.h" namespace infini::ops { @@ -21,8 +22,10 @@ struct Caster { return std::forward(x); } - constexpr bool src_is_custom = IsBFloat16 || IsFP16; - constexpr bool dst_is_custom = IsBFloat16 || IsFP16; + constexpr bool src_is_custom = IsBFloat16 || + IsFP16; + constexpr bool dst_is_custom = IsBFloat16 || + IsFP16; if constexpr (!src_is_custom && !dst_is_custom) { return static_cast(std::forward(x)); diff --git a/src/cpu/causal_softmax/causal_softmax.h b/src/cpu/causal_softmax/causal_softmax.h index e8cee7e..14848ee 100644 --- a/src/cpu/causal_softmax/causal_softmax.h +++ b/src/cpu/causal_softmax/causal_softmax.h @@ -18,7 +18,7 @@ class Operator : public CausalSoftmax, Operator(const Tensor input, Tensor out) : CausalSoftmax{input, out} {} void operator()(const Tensor input, Tensor out) const override { - DispatchFunc( + DispatchFunc( out.dtype(), [&](auto tag) { using T = typename decltype(tag)::type; diff --git a/src/cpu/device_.h b/src/cpu/device_.h new file mode 100644 index 0000000..0d74232 --- /dev/null +++ b/src/cpu/device_.h @@ -0,0 +1,21 @@ +#ifndef 
INFINI_OPS_CPU_DEVICE__H_ +#define INFINI_OPS_CPU_DEVICE__H_ + +#include "data_type.h" +#include "device.h" + +namespace infini::ops { + +template <> +struct TypeMap { + using type = Float16; +}; + +template <> +struct TypeMap { + using type = BFloat16; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/gemm/gemm.h b/src/cpu/gemm/gemm.h index c472085..a4dfb98 100644 --- a/src/cpu/gemm/gemm.h +++ b/src/cpu/gemm/gemm.h @@ -31,7 +31,7 @@ class Operator : public Gemm, void operator()(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c) const override { - DispatchFunc( + DispatchFunc( c.dtype(), [&](auto tag) { using T = typename decltype(tag)::type; diff --git a/src/cpu/rms_norm/rms_norm.h b/src/cpu/rms_norm/rms_norm.h index 752a6a6..9cae419 100644 --- a/src/cpu/rms_norm/rms_norm.h +++ b/src/cpu/rms_norm/rms_norm.h @@ -19,7 +19,7 @@ class Operator : public RmsNorm, void operator()(const Tensor input, const Tensor weight, float eps, Tensor out) const override { - DispatchFunc( + DispatchFunc( out.dtype(), [&](auto tag) { using T = typename decltype(tag)::type; diff --git a/src/cpu/swiglu/swiglu.h b/src/cpu/swiglu/swiglu.h index 374c8d8..57dccf1 100644 --- a/src/cpu/swiglu/swiglu.h +++ b/src/cpu/swiglu/swiglu.h @@ -17,7 +17,7 @@ class Operator : public Swiglu, void operator()(const Tensor input, const Tensor gate, Tensor out) const override { - DispatchFunc( + DispatchFunc( out_type_, [&](auto tag) { using T = typename decltype(tag)::type; @@ -29,8 +29,9 @@ class Operator : public Swiglu, private: template void Compute(const Tensor input, const Tensor gate, Tensor out) const { - using ComputeType = - std::conditional_t || IsFP16, float, T>; + using ComputeType = std::conditional_t || + IsFP16, + float, T>; const auto* input_ptr = static_cast(input.data()); const auto* gate_ptr = static_cast(gate.data()); diff --git a/src/cuda/add/kernel.cuh b/src/cuda/add/kernel.cuh index 
cfd6496..4928d6b 100644 --- a/src/cuda/add/kernel.cuh +++ b/src/cuda/add/kernel.cuh @@ -5,16 +5,14 @@ namespace infini::ops { +template struct AddOp { static constexpr std::size_t num_inputs = 2; template __device__ __forceinline__ T operator()(const T& input, const T& other) const { - if constexpr (std::is_same_v) { - return __hadd2(input, other); - } else if constexpr (std::is_same_v || - std::is_same_v>) { + if constexpr (IsFP16 || IsBFloat16) { return __hadd(input, other); } else if constexpr (std::is_same_v) { return __fadd_rn(input, other); @@ -24,7 +22,7 @@ struct AddOp { } }; -template +template __global__ void AddKernel(T* __restrict__ out, const T* __restrict__ input, const T* __restrict__ other, const size_t* __restrict__ out_shape, @@ -47,7 +45,7 @@ __global__ void AddKernel(T* __restrict__ out, const T* __restrict__ input, other_contiguous ? idx : IndexToOffset(idx, ndim, other_shape, other_strides); - out[out_idx] = AddOp{}(input[input_idx], other[other_idx]); + out[out_idx] = AddOp{}(input[input_idx], other[other_idx]); } } diff --git a/src/cuda/add/kernel.h b/src/cuda/add/kernel.h index 928fa9c..2e0ddb9 100644 --- a/src/cuda/add/kernel.h +++ b/src/cuda/add/kernel.h @@ -50,11 +50,11 @@ class CudaAdd : public Add { void operator()(const Tensor input, const Tensor other, Tensor out) const override { - int block_size = GetOptimalBlockSize(); + int block_size = Backend::GetOptimalBlockSize(); DispatchFunc( {static_cast(out_type_), block_size}, [&](auto list_tag) { - using T = TypeMapType(list_tag)>; + using T = TypeMapType(list_tag)>; constexpr int kBlockSize = ListGet<1>(list_tag); auto cuda_stream = @@ -67,11 +67,12 @@ class CudaAdd : public Add { const T* d_input = reinterpret_cast(input.data()); const T* d_other = reinterpret_cast(other.data()); - AddKernel<<>>( - d_out, d_input, d_other, d_out_shape_, d_input_shape_, - d_other_shape_, d_out_strides_, d_input_strides_, - d_other_strides_, output_size_, ndim_, is_out_contiguous_, - 
is_input_contiguous_, is_other_contiguous_); + AddKernel + <<>>( + d_out, d_input, d_other, d_out_shape_, d_input_shape_, + d_other_shape_, d_out_strides_, d_input_strides_, + d_other_strides_, output_size_, ndim_, is_out_contiguous_, + is_input_contiguous_, is_other_contiguous_); }, "CudaAdd::operator()"); } diff --git a/src/cuda/caster_.h b/src/cuda/caster_.h index 45bb602..4083f28 100644 --- a/src/cuda/caster_.h +++ b/src/cuda/caster_.h @@ -1,22 +1,12 @@ #ifndef INFINI_OPS_COMMON_CUDA_CASTER_H_ #define INFINI_OPS_COMMON_CUDA_CASTER_H_ -#ifdef WITH_NVIDIA -#include -#elif defined(WITH_ILUVATAR) -#include -#elif defined(WITH_METAX) -#include -#elif defined(WITH_MOORE) -#include -#endif - #include "caster.h" namespace infini::ops { -template <> -struct Caster { +template +struct CudaCasterImpl { template __host__ __device__ static Dst Cast(Src&& x) { static_assert(!std::is_reference_v, @@ -39,9 +29,9 @@ struct Caster { template __host__ __device__ static constexpr float ToFloatHelper(T&& x) { using PureSrc = PureType; - if constexpr (IsBFloat16) { + if constexpr (IsBFloat16) { return __bfloat162float(x); - } else if constexpr (IsFP16) { + } else if constexpr (IsFP16) { return __half2float(x); } else { return static_cast(std::forward(x)); @@ -51,9 +41,9 @@ struct Caster { template __host__ __device__ static constexpr Dst FromFloatHelper(float f) { using PureDst = PureType; - if constexpr (IsBFloat16) { + if constexpr (IsBFloat16) { return __float2bfloat16(f); - } else if constexpr (IsFP16) { + } else if constexpr (IsFP16) { return __float2half(f); } else { return static_cast(f); @@ -82,30 +72,27 @@ struct Caster { DEFINE_DIRECT_CAST( __bfloat162int_rn, - std::is_same_v, int>&& IsBFloat16>) + std::is_same_v, int>&& IsBFloat16>) DEFINE_DIRECT_CAST( __bfloat162short_rn, - std::is_same_v, short>&& IsBFloat16>) + std::is_same_v, short>&& IsBFloat16>) DEFINE_DIRECT_CAST( __int2bfloat16_rn, - IsBFloat16>&& std::is_same_v, int>) - DEFINE_DIRECT_CAST(__int2half_rn, - 
IsFP16>&& std::is_same_v, int>) + IsBFloat16>&& std::is_same_v, int>) + DEFINE_DIRECT_CAST( + __int2half_rn, + IsFP16>&& std::is_same_v, int>) DEFINE_DIRECT_CAST( __double2bfloat16, - IsBFloat16>&& std::is_same_v, double>) + IsBFloat16>&& std::is_same_v, double>) DEFINE_DIRECT_CAST( __double2half, - IsFP16>&& std::is_same_v, double>) - DEFINE_DIRECT_CAST(__half, IsFP16>&& IsBFloat16>) + IsFP16>&& std::is_same_v, double>) + DEFINE_DIRECT_CAST( + __half, IsFP16>&& IsBFloat16>) #undef DEFINE_DIRECT_CAST }; -template -__host__ __device__ __forceinline__ auto Cast(Args&&... args) { - return Caster::template Cast(std::forward(args)...); -} - } // namespace infini::ops #endif diff --git a/src/cuda/causal_softmax/kernel.cuh b/src/cuda/causal_softmax/kernel.cuh index 757c4cd..83acbc6 100644 --- a/src/cuda/causal_softmax/kernel.cuh +++ b/src/cuda/causal_softmax/kernel.cuh @@ -12,9 +12,10 @@ namespace infini::ops { namespace { -template +template __device__ __forceinline__ Data ExpAndCast(Compute x) { - return Cast(expf(Cast(x))); + return Caster::template Cast( + expf(Caster::template Cast(x))); } struct BlockMaxOp { @@ -36,12 +37,13 @@ __device__ __forceinline__ Data BlockMax(const Data* data_ptr, size_t count) { return BlockReduce(temp_storage).Reduce(thread_max, BlockMaxOp()); } -template +template __device__ __forceinline__ Compute BlockSum(const Data* data_ptr, size_t count) { Compute thread_sum = 0; for (size_t i = threadIdx.x; i < count; i += block_size) { - thread_sum += Cast(data_ptr[i]); + thread_sum += Caster::template Cast(data_ptr[i]); } using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -50,7 +52,8 @@ __device__ __forceinline__ Compute BlockSum(const Data* data_ptr, } // namespace -template +template __global__ void CausalSoftmaxKernel( Data* __restrict__ out_ptr, const Data* __restrict__ input_ptr, size_t batch_size, size_t seq_len, size_t total_seq_len, @@ -75,25 +78,26 @@ __global__ void CausalSoftmaxKernel( 
for (size_t col = threadIdx.x; col < total_seq_len; col += block_size) { if (col < valid_len) { - Compute diff = Cast(input_row[col]) - Cast(max_val); - out_row[col] = ExpAndCast(diff); + Compute diff = Caster::template Cast(input_row[col]) - + Caster::template Cast(max_val); + out_row[col] = ExpAndCast(diff); } else { - out_row[col] = Cast(0.0f); + out_row[col] = Caster::template Cast(0.0f); } } __syncthreads(); __shared__ Compute sum_val; Compute block_sum = - BlockSum(out_row, total_seq_len); + BlockSum(out_row, total_seq_len); if (threadIdx.x == 0) { sum_val = block_sum; } __syncthreads(); for (size_t col = threadIdx.x; col < total_seq_len; col += block_size) { - Compute quot = Cast(out_row[col]) / sum_val; - out_row[col] = Cast(quot); + Compute quot = Caster::template Cast(out_row[col]) / sum_val; + out_row[col] = Caster::template Cast(quot); } } diff --git a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h index 3dce77d..7ca0135 100644 --- a/src/cuda/causal_softmax/kernel.h +++ b/src/cuda/causal_softmax/kernel.h @@ -1,6 +1,7 @@ #ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_ #define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_ +#include #include #include "base/causal_softmax.h" @@ -28,21 +29,19 @@ class CudaCausalSoftmax : public CausalSoftmax { dim3 grid(static_cast(seq_len_), static_cast(batch_size_)); - if (out.dtype() != input.dtype()) { - std::abort(); - } + assert(out.dtype() == input.dtype()); - int block_size = GetOptimalBlockSize(); + int block_size = Backend::GetOptimalBlockSize(); DispatchFunc, ReducedFloatTypes>, AllCudaBlockSizes>( // TODO: Output dtype should use the one passed in during construction. 
{static_cast(out.dtype()), block_size}, [&](auto list_tag) { - using T = TypeMapType(list_tag)>; + using T = TypeMapType(list_tag)>; constexpr int kBlockSize = ListGet<1>(list_tag); - CausalSoftmaxKernel + CausalSoftmaxKernel <<>>( reinterpret_cast(out.data()), reinterpret_cast(input.data()), batch_size_, diff --git a/src/cuda/gemm/blas.h b/src/cuda/gemm/blas.h index 5a7cf2f..fe88a7b 100644 --- a/src/cuda/gemm/blas.h +++ b/src/cuda/gemm/blas.h @@ -17,12 +17,9 @@ class Blas : public Gemm { a_is_col_major_{a.stride(-1) == 1}, b_is_col_major_{b.stride(-1) == 1}, swap_a_and_b_{c.stride(-1) == 1} { - Backend::blasCreate(&handle_); // TODO: Check constraints. } - ~Blas() { Backend::blasDestroy(handle_); } - Blas(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, Tensor c) : Blas{a, b, alpha, beta, std::nullopt, std::nullopt, c} {} @@ -33,7 +30,7 @@ class Blas : public Gemm { void operator()(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c) const override { - Backend::blasSetStream(handle_, + Backend::blasSetStream(GetHandle(), static_cast(stream_)); const auto& alpha_value{alpha.value_or(alpha_)}; @@ -47,8 +44,9 @@ class Blas : public Gemm { const void* beta_ptr{GetBetaPtr(beta_value, c.dtype())}; Backend::blasGemmStridedBatchedEx( - handle_, op_a, op_b, swap_a_and_b_ ? n_ : m_, swap_a_and_b_ ? m_ : n_, - k_, alpha_ptr, swap_a_and_b_ ? b.data() : a.data(), + GetHandle(), op_a, op_b, swap_a_and_b_ ? n_ : m_, + swap_a_and_b_ ? m_ : n_, k_, alpha_ptr, + swap_a_and_b_ ? b.data() : a.data(), Backend::GetDataType(swap_a_and_b_ ? b.dtype() : a.dtype()), swap_a_and_b_ ? ldb_ : lda_, swap_a_and_b_ ? batch_stride_b_ : batch_stride_a_, @@ -88,13 +86,22 @@ class Blas : public Gemm { : Backend::BLAS_OP_N; } + // TODO: This static singleton is not thread-safe under concurrent access + // from multiple host threads. Add proper synchronization in the future. 
+ static typename Backend::blasHandle_t& GetHandle() { + static typename Backend::blasHandle_t handle = []() { + typename Backend::blasHandle_t h; + Backend::blasCreate(&h); + return h; + }(); + return handle; + } + bool a_is_col_major_{false}; bool b_is_col_major_{false}; bool swap_a_and_b_{false}; - - typename Backend::blasHandle_t handle_; }; } // namespace infini::ops diff --git a/src/cuda/kernel_commons.h b/src/cuda/kernel_commons.h index 6c987c7..bb25fad 100644 --- a/src/cuda/kernel_commons.h +++ b/src/cuda/kernel_commons.h @@ -1,105 +1,12 @@ #ifndef INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_ #define INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_ -#ifdef WITH_NVIDIA -#include -#include -#include -using cuda_bfloat16 = nv_bfloat16; -using cuda_bfloat162 = nv_bfloat162; -#elif defined(WITH_ILUVATAR) -#include -#include -#include -using cuda_bfloat16 = nv_bfloat16; -using cuda_bfloat162 = nv_bfloat162; -#elif defined(WITH_METAX) -#include -using cuda_bfloat16 = maca_bfloat16; -using cuda_bfloat162 = maca_bfloat162; -#elif defined(WITH_MOORE) -#include -#include -#include -using cuda_bfloat16 = __mt_bfloat16; -using cuda_bfloat162 = __mt_bfloat162; -#endif - -#include -#include -#include - #include "caster.h" namespace infini::ops { using AllCudaBlockSizes = List<128, 256, 512, 1024, 2048>; -#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) -// Cache `cudaDeviceProp` per device, initialized once at first access. 
-class DevicePropertyCache { - public: - static const cudaDeviceProp& GetCurrentDeviceProps() { - int device_id = 0; - cudaGetDevice(&device_id); - return GetDeviceProps(device_id); - } - - static const cudaDeviceProp& GetDeviceProps(int device_id) { - static std::vector cache = []() { - int count = 0; - cudaGetDeviceCount(&count); - if (count == 0) return std::vector{}; - std::vector props(count); - for (int i = 0; i < count; ++i) { - cudaGetDeviceProperties(&props[i], i); - } - return props; - }(); - - if (device_id < 0 || device_id >= static_cast(cache.size())) { - std::cerr << "error: `device_id` " << device_id << " is out of range [0, " - << cache.size() << ") in `GetDeviceProps`\n"; - std::abort(); - } - return cache[device_id]; - } -}; - -inline int QueryMaxThreadsPerBlock() { - return DevicePropertyCache::GetCurrentDeviceProps().maxThreadsPerBlock; -} -#elif defined(WITH_METAX) -inline int QueryMaxThreadsPerBlock() { - // TODO: Add MCR device properties query for Metax. - return 256; -} -#elif defined(WITH_MOORE) -inline int QueryMaxThreadsPerBlock() { - int device = 0; - musaGetDevice(&device); - musaDeviceProp prop; - musaGetDeviceProperties(&prop, device); - return prop.maxThreadsPerBlock; -} -#endif - -// Get optimal block size based on GPU hardware architecture. -inline int GetOptimalBlockSize() { - int max_threads = QueryMaxThreadsPerBlock(); - if (max_threads >= 2048) { - return 2048; - } else if (max_threads >= 1024) { - return 1024; - } else if (max_threads >= 512) { - return 512; - } else if (max_threads >= 256) { - return 256; - } else { - return 128; - } -} - __forceinline__ __device__ __host__ size_t IndexToOffset(size_t flat_index, size_t ndim, const size_t* shape, const ptrdiff_t* strides) { @@ -111,6 +18,16 @@ IndexToOffset(size_t flat_index, size_t ndim, const size_t* shape, return res; } +// Selects the largest block size from `AllCudaBlockSizes` that does not exceed +// `max_threads_per_block`. 
+inline int ComputeOptimalBlockSize(int max_threads_per_block) { + if (max_threads_per_block >= 2048) return 2048; + if (max_threads_per_block >= 1024) return 1024; + if (max_threads_per_block >= 512) return 512; + if (max_threads_per_block >= 256) return 256; + return 128; +} + } // namespace infini::ops #endif diff --git a/src/cuda/rms_norm/kernel.cuh b/src/cuda/rms_norm/kernel.cuh index 1a35d22..ccb091b 100644 --- a/src/cuda/rms_norm/kernel.cuh +++ b/src/cuda/rms_norm/kernel.cuh @@ -12,12 +12,13 @@ namespace infini::ops { namespace { -template +template __device__ __forceinline__ TCompute SumSquared(const TData* data_ptr, size_t count) { TCompute ss = 0; for (size_t i = threadIdx.x; i < count; i += block_size) { - TCompute value = Cast(data_ptr[i]); + TCompute value = Caster::template Cast(data_ptr[i]); ss += value * value; } using BlockReduce = cub::BlockReduce; @@ -27,8 +28,8 @@ __device__ __forceinline__ TCompute SumSquared(const TData* data_ptr, } // namespace -template +template __global__ void RmsNormKernel(TData* __restrict__ y, int64_t stride_y_batch, int64_t stride_y_nhead, const TData* __restrict__ x, @@ -42,17 +43,19 @@ __global__ void RmsNormKernel(TData* __restrict__ y, int64_t stride_y_batch, auto x_ptr = x + batch_idx * stride_x_batch + head_idx * stride_x_nhead; auto w_ptr = w; - TCompute ss = SumSquared(x_ptr, dim); + TCompute ss = SumSquared(x_ptr, dim); __shared__ TCompute rms; if (threadIdx.x == 0) { - rms = Cast(rsqrtf(ss / Cast(dim) + epsilon)); + rms = Caster::template Cast( + rsqrtf(ss / Caster::template Cast(dim) + epsilon)); } __syncthreads(); for (size_t i = threadIdx.x; i < dim; i += block_size) { - y_ptr[i] = - Cast(Cast(x_ptr[i]) * Cast(w_ptr[i]) * rms); + y_ptr[i] = Caster::template Cast( + Caster::template Cast(x_ptr[i]) * + Caster::template Cast(w_ptr[i]) * rms); } } diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h index 29ee51e..848f8fa 100644 --- a/src/cuda/rms_norm/kernel.h +++ b/src/cuda/rms_norm/kernel.h 
@@ -1,6 +1,7 @@ #ifndef INFINI_OPS_CUDA_RMS_NORM_KERNEL_H_ #define INFINI_OPS_CUDA_RMS_NORM_KERNEL_H_ +#include #include #include "base/rms_norm.h" @@ -30,20 +31,18 @@ class CudaRmsNorm : public RmsNorm { uint32_t num_blocks = static_cast(batch_size_ * nhead_); - if (out.dtype() != input.dtype() || out.dtype() != weight.dtype()) { - std::abort(); - } + assert(out.dtype() == input.dtype() && out.dtype() == weight.dtype()); - int block_size = GetOptimalBlockSize(); + int block_size = Backend::GetOptimalBlockSize(); DispatchFunc, ReducedFloatTypes>, AllCudaBlockSizes>( {static_cast(out.dtype()), block_size}, [&](auto list_tag) { - using T = TypeMapType(list_tag)>; + using T = TypeMapType(list_tag)>; constexpr int kBlockSize = ListGet<1>(list_tag); - RmsNormKernel + RmsNormKernel <<>>( reinterpret_cast(out.data()), stride_out_batch, stride_out_nhead, reinterpret_cast(input.data()), diff --git a/src/cuda/swiglu/kernel.cuh b/src/cuda/swiglu/kernel.cuh index f3997e6..5c3add3 100644 --- a/src/cuda/swiglu/kernel.cuh +++ b/src/cuda/swiglu/kernel.cuh @@ -7,23 +7,16 @@ namespace infini::ops { -// Optimized sigmoid function with support for vectorized types. -template +// Optimized sigmoid function with support for FP16 and BF16 types. +// TODO: The unified FP16/BF16 branch uses `Caster` and scalar float +// arithmetic instead of native vectorized intrinsics (e.g. `h2rcp`, +// `__hmul2`). Profile and restore specialized paths if needed. 
+template __device__ __forceinline__ T Sigmoid(const T& x) { - if constexpr (std::is_same_v) { - return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x)))); - } else if constexpr (std::is_same_v) { - return hrcp( - __hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x)))))); - } else if constexpr (std::is_same_v) { - float x0 = __bfloat162float(__low2bfloat16(x)); - float x1 = __bfloat162float(__high2bfloat16(x)); - float sig0 = __frcp_rn(__fadd_rn(1.0f, __expf(-x0))); - float sig1 = __frcp_rn(__fadd_rn(1.0f, __expf(-x1))); - return __floats2bfloat162_rn(sig0, sig1); - } else if constexpr (std::is_same_v) { - float xf = __bfloat162float(x); - return __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-xf)))); + if constexpr (IsFP16 || IsBFloat16) { + float xf = Caster::template Cast(x); + return Caster::template Cast( + __frcp_rn(__fadd_rn(1.0f, __expf(-xf)))); } else if constexpr (std::is_same_v) { return __frcp_rn(__fadd_rn(1.0f, __expf(-x))); } else { @@ -32,7 +25,7 @@ __device__ __forceinline__ T Sigmoid(const T& x) { } // SwiGLU(x, gate) = Swish(x) * gate = (x * sigmoid(x)) * gate. -template +template __global__ void SwigluKernel(T* __restrict__ out, const T* __restrict__ a, const T* __restrict__ b, const size_t* __restrict__ out_shape, @@ -70,32 +63,16 @@ __global__ void SwigluKernel(T* __restrict__ out, const T* __restrict__ a, T up = a[input_idx]; T gate = b[gate_idx]; - if constexpr (std::is_same_v) { - // Vectorized `half2` computation for better performance. - out[out_idx] = __hmul2(__hmul2(gate, Sigmoid(gate)), up); - } else if constexpr (std::is_same_v) { - // Optimized `half` precision computation. 
- out[out_idx] = __hmul(__hmul(gate, Sigmoid(gate)), up); - } else if constexpr (std::is_same_v) { - float gate0 = __bfloat162float(__low2bfloat16(gate)); - float gate1 = __bfloat162float(__high2bfloat16(gate)); - float up0 = __bfloat162float(__low2bfloat16(up)); - float up1 = __bfloat162float(__high2bfloat16(up)); - float sig0 = __frcp_rn(__fadd_rn(1.0f, __expf(-gate0))); - float sig1 = __frcp_rn(__fadd_rn(1.0f, __expf(-gate1))); - out[out_idx] = - __floats2bfloat162_rn(__fmul_rn(__fmul_rn(gate0, sig0), up0), - __fmul_rn(__fmul_rn(gate1, sig1), up1)); - } else if constexpr (std::is_same_v) { - float gatef = __bfloat162float(gate); - float upf = __bfloat162float(up); + if constexpr (IsFP16 || IsBFloat16) { + float gatef = Caster::template Cast(gate); + float upf = Caster::template Cast(up); float sigf = __frcp_rn(__fadd_rn(1.0f, __expf(-gatef))); - out[out_idx] = - __float2bfloat16_rn(__fmul_rn(__fmul_rn(gatef, sigf), upf)); + out[out_idx] = Caster::template Cast( + __fmul_rn(__fmul_rn(gatef, sigf), upf)); } else if constexpr (std::is_same_v) { - out[out_idx] = __fmul_rn(__fmul_rn(gate, Sigmoid(gate)), up); + out[out_idx] = __fmul_rn(__fmul_rn(gate, Sigmoid(gate)), up); } else { - out[out_idx] = gate * Sigmoid(gate) * up; + out[out_idx] = gate * Sigmoid(gate) * up; } } } diff --git a/src/cuda/swiglu/kernel.h b/src/cuda/swiglu/kernel.h index 964e8b7..72ff3cc 100644 --- a/src/cuda/swiglu/kernel.h +++ b/src/cuda/swiglu/kernel.h @@ -49,11 +49,11 @@ class CudaSwiglu : public Swiglu { void operator()(const Tensor input, const Tensor gate, Tensor out) const override { - int block_size = GetOptimalBlockSize(); + int block_size = Backend::GetOptimalBlockSize(); DispatchFunc( {static_cast(out_type_), block_size}, [&](auto list_tag) { - using T = TypeMapType(list_tag)>; + using T = TypeMapType(list_tag)>; constexpr int kBlockSize = ListGet<1>(list_tag); auto cuda_stream = @@ -66,11 +66,12 @@ class CudaSwiglu : public Swiglu { const T* d_input = reinterpret_cast(input.data()); 
const T* d_gate = reinterpret_cast(gate.data()); - SwigluKernel<<>>( - d_out, d_input, d_gate, d_out_shape_, d_input_shape_, - d_gate_shape_, d_out_strides_, d_input_strides_, d_gate_strides_, - output_size_, ndim_, is_out_contiguous_, is_input_contiguous_, - is_gate_contiguous_); + SwigluKernel + <<>>( + d_out, d_input, d_gate, d_out_shape_, d_input_shape_, + d_gate_shape_, d_out_strides_, d_input_strides_, + d_gate_strides_, output_size_, ndim_, is_out_contiguous_, + is_input_contiguous_, is_gate_contiguous_); }, "CudaSwiglu::operator()"); } diff --git a/src/data_type.h b/src/data_type.h index ce2adfe..05ea3c3 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -5,25 +5,9 @@ #include #include -#ifdef WITH_NVIDIA -#include -#include -#elif defined(WITH_ILUVATAR) -#include -#include -#elif defined(WITH_METAX) -#include -#include -#elif defined(WITH_MOORE) -#include -#include -#elif defined(WITH_CAMBRICON) -#include "bang_bf16.h" -#include "bang_fp16.h" -#endif - #include "common/constexpr_map.h" #include "common/traits.h" +#include "device.h" namespace infini::ops { @@ -167,27 +151,16 @@ struct BFloat16 { } }; -template +template struct TypeMap; -template -using TypeMapType = typename TypeMap::type; - -template -struct DataTypeMap; +template +using TypeMapType = typename TypeMap::type; -template -inline constexpr DataType DataTypeMapValue = DataTypeMap::value; - -#define DEFINE_DATA_TYPE_MAPPING(ENUM_VALUE, CPP_TYPE) \ - template <> \ - struct TypeMap { \ - using type = CPP_TYPE; \ - }; \ - \ - template <> \ - struct DataTypeMap { \ - static constexpr DataType value = DataType::ENUM_VALUE; \ +#define DEFINE_DATA_TYPE_MAPPING(ENUM_VALUE, CPP_TYPE) \ + template \ + struct TypeMap { \ + using type = CPP_TYPE; \ }; DEFINE_DATA_TYPE_MAPPING(kUInt8, std::uint8_t) @@ -200,31 +173,18 @@ DEFINE_DATA_TYPE_MAPPING(kUInt64, std::uint64_t) DEFINE_DATA_TYPE_MAPPING(kInt64, std::int64_t) DEFINE_DATA_TYPE_MAPPING(kFloat32, float) DEFINE_DATA_TYPE_MAPPING(kFloat64, double) - -#if 
defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) -DEFINE_DATA_TYPE_MAPPING(kFloat16, half) -DEFINE_DATA_TYPE_MAPPING(kBFloat16, __nv_bfloat16) -#elif defined(WITH_METAX) -DEFINE_DATA_TYPE_MAPPING(kFloat16, __half) -DEFINE_DATA_TYPE_MAPPING(kBFloat16, __maca_bfloat16) -#elif defined(WITH_MOORE) -DEFINE_DATA_TYPE_MAPPING(kFloat16, half) -DEFINE_DATA_TYPE_MAPPING(kBFloat16, __mt_bfloat16) -#elif defined(WITH_CAMBRICON) -DEFINE_DATA_TYPE_MAPPING(kFloat16, __half) -DEFINE_DATA_TYPE_MAPPING(kBFloat16, __bang_bfloat16) -#else -DEFINE_DATA_TYPE_MAPPING(kFloat16, Float16) -DEFINE_DATA_TYPE_MAPPING(kBFloat16, BFloat16) -#endif #undef DEFINE_DATA_TYPE_MAPPING -// Define the traits to check whether a type is bfloat16 or float16. -template -inline constexpr bool IsBFloat16 = (DataTypeMapValue == DataType::kBFloat16); +// Checks whether a C++ type is the bfloat16 or float16 type for the given +// device. Full specializations for each device's float16/bfloat16 types are +// provided in the corresponding platform `device_.h` headers. +template +inline constexpr bool IsBFloat16 = + std::is_same_v>; -template -inline constexpr bool IsFP16 = (DataTypeMapValue == DataType::kFloat16); +template +inline constexpr bool IsFP16 = + std::is_same_v>; // Defines the common categories of data types using List. using FloatTypes = List; diff --git a/src/dispatcher.h b/src/dispatcher.h index 22aeb82..c971d0d 100644 --- a/src/dispatcher.h +++ b/src/dispatcher.h @@ -21,9 +21,9 @@ namespace detail { // Implements the dispatch body over a resolved `List`. template -auto DispatchFuncImpl(ValueType value, Functor &&func, +auto DispatchFuncImpl(ValueType value, Functor&& func, std::string_view context_str, List, - Args &&...args) { + Args&&... 
args) { using ReturnType = decltype(std::forward(func)( ValueTag(head)>{}, std::forward(args)...)); @@ -84,8 +84,8 @@ template struct DispatchFuncUnwrap, std::tuple> { - static auto call(ValueType value, Functor &&func, - std::string_view context_str, Args &&...args) { + static auto call(ValueType value, Functor&& func, + std::string_view context_str, Args&&... args) { return DispatchFuncImpl(value, std::forward(func), context_str, List{}, std::forward(args)...); } @@ -94,8 +94,8 @@ struct DispatchFuncUnwrap, // Empty-list specialization template struct DispatchFuncUnwrap, std::tuple> { - static auto call(ValueType value, Functor &&, std::string_view context_str, - Args &&...) { + static auto call(ValueType value, Functor&&, std::string_view context_str, + Args&&...) { // TODO(lzm): change to logging. std::cerr << "dispatch error: no allowed values registered for value " << static_cast(value) @@ -109,8 +109,8 @@ struct DispatchFuncUnwrap, std::tuple> { // (Single Dispatch) Dispatches a runtime value to a compile-time functor. template -auto DispatchFunc(ValueType value, Functor &&func, - std::string_view context_str = "", Args &&...args) { +auto DispatchFunc(ValueType value, Functor&& func, + std::string_view context_str = "", Args&&... args) { using FilteredPack = typename Filter, List<>, all_values...>::type; @@ -124,9 +124,9 @@ auto DispatchFunc(ValueType value, Functor &&func, // functor. // Base Case: All Dimensions Resolved template -auto DispatchFunc(const std::vector &values, size_t /*index*/, - Functor &&func, std::string_view /*context_str*/, - List, Args &&...args) { +auto DispatchFunc(const std::vector& values, size_t /*index*/, + Functor&& func, std::string_view /*context_str*/, + List, Args&&... args) { return std::forward(func)(List{}, std::forward(args)...); } @@ -134,9 +134,9 @@ auto DispatchFunc(const std::vector &values, size_t /*index*/, // Forward declaration of the recursive multi-dispatch overload. 
template -auto DispatchFunc(const std::vector &values, size_t index, - Functor &&func, std::string_view context_str, List, - Args &&...args); +auto DispatchFunc(const std::vector& values, size_t index, + Functor&& func, std::string_view context_str, List, + Args&&... args); // Adapter used in the recursive multi-dispatch case: given a resolved value // `val` recurse into the next dimension. @@ -145,13 +145,13 @@ struct MultiDispatchRecurseAdapter; template struct MultiDispatchRecurseAdapter, Functor, items...> { - const std::vector &values; + const std::vector& values; size_t next_index; - Functor &func; + Functor& func; std::string_view context_str; template - auto operator()(ValueTag, Args &&...args) const { + auto operator()(ValueTag, Args&&... args) const { return DispatchFunc(values, next_index, func, context_str, List{}, std::forward(args)...); @@ -160,9 +160,9 @@ struct MultiDispatchRecurseAdapter, Functor, items...> { template -auto MultiDispatchFirstDim(const std::vector &values, size_t index, - Functor &func, std::string_view context_str, - List, List, Args &&...args) { +auto MultiDispatchFirstDim(const std::vector& values, size_t index, + Functor& func, std::string_view context_str, + List, List, Args&&... args) { static_assert(sizeof...(allowed) > 0, "`DispatchFunc` dimension list is empty"); using EnumType = std::common_type_t; @@ -178,9 +178,9 @@ auto MultiDispatchFirstDim(const std::vector &values, size_t index, // (Multi-Dispatch) Recursive Case template -auto DispatchFunc(const std::vector &values, size_t index, - Functor &&func, std::string_view context_str, List, - Args &&...args) { +auto DispatchFunc(const std::vector& values, size_t index, + Functor&& func, std::string_view context_str, List, + Args&&... 
args) { return MultiDispatchFirstDim>( values, index, func, context_str, List{}, FirstList{}, std::forward(args)...); @@ -195,44 +195,44 @@ namespace detail { // Bridges the generic value dispatch layer to the `DataType`-specific type // dispatch layer. -template +template struct DataTypeAdapter { - Functor &func; + Functor& func; template - auto operator()(ValueTag, Args &&...args) const { - using T = TypeMapType(dtype)>; + auto operator()(ValueTag, Args&&... args) const { + using T = TypeMapType(dtype)>; return func(TypeTag{}, std::forward(args)...); } }; -template +template struct DataTypeMultiAdapter { - Functor &func; + Functor& func; template - auto operator()(List, Args &&...args) const { - return func(TypeTag(dtypes)>>{}..., + auto operator()(List, Args&&... args) const { + return func(TypeTag(dtypes)>>{}..., std::forward(args)...); } }; template struct DeviceAdapter { - Functor &func; + Functor& func; template - auto operator()(ValueTag, Args &&...args) const { + auto operator()(ValueTag, Args&&... args) const { return func(ValueTag{}, std::forward(args)...); } }; template struct DeviceMultiAdapter { - Functor &func; + Functor& func; template - auto operator()(List, Args &&...args) const { + auto operator()(List, Args&&... args) const { return func(ValueTag{}..., std::forward(args)...); } }; @@ -240,30 +240,33 @@ struct DeviceMultiAdapter { } // namespace detail // `DataType` Dispatch -template -auto DispatchFunc(DataType dtype, Functor &&func, - std::string_view context_str = "", Args &&...args) { - detail::DataTypeAdapter> adapter{func}; +template +auto DispatchFunc(DataType dtype, Functor&& func, + std::string_view context_str = "", Args&&... 
args) { + detail::DataTypeAdapter> adapter{func}; return DispatchFunc(dtype, adapter, context_str, std::forward(args)...); } // `DataType` Multi-Dispatch -template -auto DispatchFunc(std::initializer_list dtypes, Functor &&func, - std::string_view context_str = "", Args &&...args) { +template +auto DispatchFunc(std::initializer_list dtypes, Functor&& func, + std::string_view context_str = "", Args&&... args) { std::vector v; for (auto d : dtypes) v.push_back(static_cast(d)); - detail::DataTypeMultiAdapter> adapter{func}; + detail::DataTypeMultiAdapter> adapter{ + func}; return DispatchFunc(v, 0, adapter, context_str, List<>{}, std::forward(args)...); } // `Device` Dispatch template -auto DispatchFunc(Device::Type device, Functor &&func, - std::string_view context_str = "", Args &&...args) { +auto DispatchFunc(Device::Type device, Functor&& func, + std::string_view context_str = "", Args&&... args) { detail::DeviceAdapter> adapter{func}; return DispatchFunc(allowed_devices)...>( @@ -272,8 +275,8 @@ auto DispatchFunc(Device::Type device, Functor &&func, // `Device` Multi-Dispatch template -auto DispatchFunc(std::initializer_list devices, Functor &&func, - std::string_view context_str = "", Args &&...args) { +auto DispatchFunc(std::initializer_list devices, Functor&& func, + std::string_view context_str = "", Args&&... args) { std::vector v; for (auto d : devices) v.push_back(static_cast(d)); @@ -283,29 +286,50 @@ auto DispatchFunc(std::initializer_list devices, Functor &&func, } template -auto DispatchFuncListAliasImpl(ValueType value, Functor &&func, +auto DispatchFuncListAliasImpl(ValueType value, Functor&& func, std::string_view context_str, List, - Args &&...args) { + Args&&... args) { return DispatchFunc>(items)...>( value, std::forward(func), context_str, std::forward(args)...); } -// Interface for Generic `List` Aliases +template +auto DispatchFuncListAliasImpl(ValueType value, Functor&& func, + std::string_view context_str, List, + Args&&... 
args) { + return DispatchFunc>(items)...>( + value, std::forward(func), context_str, + std::forward(args)...); +} + +// Interface for Generic `List` Aliases (for non-DataType dispatch, e.g. Device) template ::value>> -auto DispatchFunc(ValueType value, Functor &&func, - std::string_view context_str = "", Args &&...args) { +auto DispatchFunc(ValueType value, Functor&& func, + std::string_view context_str = "", Args&&... args) { return DispatchFuncListAliasImpl(value, std::forward(func), context_str, ListType{}, std::forward(args)...); } +// Interface for Generic `List` Aliases (for DataType dispatch with device type) +template ::value>> +auto DispatchFunc(ValueType value, Functor&& func, + std::string_view context_str = "", Args&&... args) { + return DispatchFuncListAliasImpl(value, std::forward(func), + context_str, ListType{}, + std::forward(args)...); +} + // Interface for Any `int64_t`-Convertible Types template -auto DispatchFunc(std::initializer_list keys, Functor &&func, - std::string_view context_str = "", Args &&...args) { +auto DispatchFunc(std::initializer_list keys, Functor&& func, + std::string_view context_str = "", Args&&... args) { std::vector v_keys(keys); return DispatchFunc(v_keys, 0, std::forward(func), context_str, List<>{}, diff --git a/src/iluvatar/add/kernel.h b/src/iluvatar/add/kernel.h index 551544f..78ccff0 100644 --- a/src/iluvatar/add/kernel.h +++ b/src/iluvatar/add/kernel.h @@ -3,11 +3,8 @@ #include -// clang-format off -#include -// clang-format on - #include "cuda/add/kernel.h" +#include "iluvatar/device_.h" namespace infini::ops { @@ -16,6 +13,8 @@ namespace add { struct IluvatarBackend { using stream_t = cudaStream_t; + static constexpr Device::Type kDeviceType = Device::Type::kIluvatar; + static constexpr auto malloc = [](auto&&... 
args) { return cudaMalloc(std::forward(args)...); }; @@ -25,6 +24,10 @@ struct IluvatarBackend { static constexpr auto free = cudaFree; static constexpr auto memcpyH2D = cudaMemcpyHostToDevice; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } }; } // namespace add diff --git a/src/iluvatar/causal_softmax/kernel.h b/src/iluvatar/causal_softmax/kernel.h index 6187110..0f45118 100644 --- a/src/iluvatar/causal_softmax/kernel.h +++ b/src/iluvatar/causal_softmax/kernel.h @@ -3,15 +3,8 @@ #include -// clang-format off -#include -// clang-format on - -// clang-format off -#include "iluvatar/device_.h" -// clang-format on - #include "cuda/causal_softmax/kernel.h" +#include "iluvatar/device_.h" namespace infini::ops { @@ -19,6 +12,12 @@ namespace causal_softmax { struct IluvatarBackend { using stream_t = cudaStream_t; + + static constexpr Device::Type kDeviceType = Device::Type::kIluvatar; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } }; } // namespace causal_softmax diff --git a/src/iluvatar/device_.h b/src/iluvatar/device_.h index 9d46e77..387c453 100644 --- a/src/iluvatar/device_.h +++ b/src/iluvatar/device_.h @@ -1,11 +1,68 @@ -#ifndef INFINI_OPS_ILUVATAR_DEVICE_H_ -#define INFINI_OPS_ILUVATAR_DEVICE_H_ +#ifndef INFINI_OPS_ILUVATAR_DEVICE__H_ +#define INFINI_OPS_ILUVATAR_DEVICE__H_ +#include +#include + +// clang-format off +#include +#include +#include +// clang-format on + +#include "cuda/caster_.h" +#include "data_type.h" #include "device.h" namespace infini::ops { -inline constexpr auto kDeviceType{Device::Type::kIluvatar}; +using cuda_bfloat16 = nv_bfloat16; + +using cuda_bfloat162 = nv_bfloat162; + +template <> +struct TypeMap { + using type = half; +}; + +template <> +struct TypeMap { + using type = __nv_bfloat16; +}; + +// Caches `cudaDeviceProp` per device, initialized once at first access. 
+class DevicePropertyCache { + public: + static const cudaDeviceProp& GetCurrentDeviceProps() { + int device_id = 0; + cudaGetDevice(&device_id); + return GetDeviceProps(device_id); + } + + static const cudaDeviceProp& GetDeviceProps(int device_id) { + static std::vector cache = []() { + int count = 0; + cudaGetDeviceCount(&count); + if (count == 0) return std::vector{}; + std::vector props(count); + for (int i = 0; i < count; ++i) { + cudaGetDeviceProperties(&props[i], i); + } + return props; + }(); + + assert(device_id >= 0 && device_id < static_cast(cache.size())); + return cache[device_id]; + } +}; + +inline int QueryMaxThreadsPerBlock() { + return DevicePropertyCache::GetCurrentDeviceProps().maxThreadsPerBlock; +} + +template <> +struct Caster + : CudaCasterImpl {}; } // namespace infini::ops diff --git a/src/iluvatar/rms_norm/kernel.h b/src/iluvatar/rms_norm/kernel.h index a07bff8..470e764 100644 --- a/src/iluvatar/rms_norm/kernel.h +++ b/src/iluvatar/rms_norm/kernel.h @@ -3,15 +3,8 @@ #include -// clang-format off -#include -// clang-format on - -// clang-format off -#include "iluvatar/device_.h" -// clang-format on - #include "cuda/rms_norm/kernel.h" +#include "iluvatar/device_.h" namespace infini::ops { @@ -19,6 +12,12 @@ namespace rms_norm { struct IluvatarBackend { using stream_t = cudaStream_t; + + static constexpr Device::Type kDeviceType = Device::Type::kIluvatar; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } }; } // namespace rms_norm diff --git a/src/iluvatar/swiglu/kernel.h b/src/iluvatar/swiglu/kernel.h index cf5310c..7fc2e16 100644 --- a/src/iluvatar/swiglu/kernel.h +++ b/src/iluvatar/swiglu/kernel.h @@ -3,11 +3,8 @@ #include -// clang-format off -#include -// clang-format on - #include "cuda/swiglu/kernel.h" +#include "iluvatar/device_.h" namespace infini::ops { @@ -16,6 +13,8 @@ namespace swiglu { struct IluvatarBackend { using stream_t = cudaStream_t; + static constexpr Device::Type 
kDeviceType = Device::Type::kIluvatar; + static constexpr auto malloc = [](auto&&... args) { return cudaMalloc(std::forward(args)...); }; @@ -25,6 +24,10 @@ struct IluvatarBackend { static constexpr auto free = cudaFree; static constexpr auto memcpyH2D = cudaMemcpyHostToDevice; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } }; } // namespace swiglu diff --git a/src/metax/add/kernel.h b/src/metax/add/kernel.h index ce9ec01..6ef2a09 100644 --- a/src/metax/add/kernel.h +++ b/src/metax/add/kernel.h @@ -3,11 +3,8 @@ #include -// clang-format off -#include -// clang-format on - #include "cuda/add/kernel.h" +#include "metax/device_.h" namespace infini::ops { @@ -16,6 +13,8 @@ namespace add { struct MetaxBackend { using stream_t = mcStream_t; + static constexpr Device::Type kDeviceType = Device::Type::kMetax; + static constexpr auto malloc = mcMalloc; static constexpr auto memcpy = mcMemcpy; @@ -23,6 +22,10 @@ struct MetaxBackend { static constexpr auto free = mcFree; static constexpr auto memcpyH2D = mcMemcpyHostToDevice; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } }; } // namespace add diff --git a/src/metax/causal_softmax/kernel.h b/src/metax/causal_softmax/kernel.h index f8648e4..5ec32b7 100644 --- a/src/metax/causal_softmax/kernel.h +++ b/src/metax/causal_softmax/kernel.h @@ -3,15 +3,8 @@ #include -// clang-format off -#include -// clang-format on - -// clang-format off -#include "metax/device_.h" -// clang-format on - #include "cuda/causal_softmax/kernel.h" +#include "metax/device_.h" namespace infini::ops { @@ -19,6 +12,12 @@ namespace causal_softmax { struct MetaxBackend { using stream_t = mcStream_t; + + static constexpr Device::Type kDeviceType = Device::Type::kMetax; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } }; } // namespace causal_softmax diff --git a/src/metax/device_.h 
b/src/metax/device_.h index 5e7c93c..6d59c76 100644 --- a/src/metax/device_.h +++ b/src/metax/device_.h @@ -1,11 +1,35 @@ -#ifndef INFINI_OPS_METAX_DEVICE_H_ -#define INFINI_OPS_METAX_DEVICE_H_ +#ifndef INFINI_OPS_METAX_DEVICE__H_ +#define INFINI_OPS_METAX_DEVICE__H_ +#include +#include +#include + +#include "cuda/caster_.h" +#include "data_type.h" #include "device.h" namespace infini::ops { -inline constexpr auto kDeviceType{Device::Type::kMetax}; +using cuda_bfloat16 = maca_bfloat16; + +using cuda_bfloat162 = maca_bfloat162; + +template <> +struct TypeMap { + using type = __half; +}; + +template <> +struct TypeMap { + using type = __maca_bfloat16; +}; + +// TODO: Add MCR device properties query for Metax. +inline int QueryMaxThreadsPerBlock() { return 256; } + +template <> +struct Caster : CudaCasterImpl {}; } // namespace infini::ops diff --git a/src/metax/rms_norm/kernel.h b/src/metax/rms_norm/kernel.h index bdf0bed..5806435 100644 --- a/src/metax/rms_norm/kernel.h +++ b/src/metax/rms_norm/kernel.h @@ -3,15 +3,8 @@ #include -// clang-format off -#include -// clang-format on - -// clang-format off -#include "metax/device_.h" -// clang-format on - #include "cuda/rms_norm/kernel.h" +#include "metax/device_.h" namespace infini::ops { @@ -19,6 +12,12 @@ namespace rms_norm { struct MetaxBackend { using stream_t = mcStream_t; + + static constexpr Device::Type kDeviceType = Device::Type::kMetax; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } }; } // namespace rms_norm diff --git a/src/metax/swiglu/kernel.h b/src/metax/swiglu/kernel.h index 77416aa..75b9c46 100644 --- a/src/metax/swiglu/kernel.h +++ b/src/metax/swiglu/kernel.h @@ -3,11 +3,8 @@ #include -// clang-format off -#include -// clang-format on - #include "cuda/swiglu/kernel.h" +#include "metax/device_.h" namespace infini::ops { @@ -16,6 +13,8 @@ namespace swiglu { struct MetaxBackend { using stream_t = mcStream_t; + static constexpr Device::Type 
kDeviceType = Device::Type::kMetax; + static constexpr auto malloc = [](auto&&... args) { return mcMalloc(std::forward(args)...); }; @@ -25,6 +24,10 @@ struct MetaxBackend { static constexpr auto free = mcFree; static constexpr auto memcpyH2D = mcMemcpyHostToDevice; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } }; } // namespace swiglu diff --git a/src/moore/add/kernel.h b/src/moore/add/kernel.h index 21a51f6..1f393dc 100644 --- a/src/moore/add/kernel.h +++ b/src/moore/add/kernel.h @@ -3,15 +3,12 @@ #include -// clang-format off -#include -// clang-format on - // clang-format off #include "moore/polyfills.cuh" // clang-format on #include "cuda/add/kernel.h" +#include "moore/device_.h" namespace infini::ops { @@ -20,6 +17,8 @@ namespace add { struct MooreBackend { using stream_t = musaStream_t; + static constexpr Device::Type kDeviceType = Device::Type::kMoore; + static constexpr auto malloc = [](auto&&... args) { return musaMalloc(std::forward(args)...); }; @@ -33,6 +32,10 @@ struct MooreBackend { }; static constexpr auto memcpyH2D = musaMemcpyHostToDevice; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } }; } // namespace add diff --git a/src/moore/device_.h b/src/moore/device_.h index fc9282f..d7c7599 100644 --- a/src/moore/device_.h +++ b/src/moore/device_.h @@ -1,11 +1,40 @@ -#ifndef INFINI_OPS_MOORE_DEVICE_H_ -#define INFINI_OPS_MOORE_DEVICE_H_ +#ifndef INFINI_OPS_MOORE_DEVICE__H_ +#define INFINI_OPS_MOORE_DEVICE__H_ +#include +#include +#include + +#include "cuda/caster_.h" +#include "data_type.h" #include "device.h" namespace infini::ops { -inline constexpr auto kDeviceType{Device::Type::kMoore}; +using cuda_bfloat16 = __mt_bfloat16; + +using cuda_bfloat162 = __mt_bfloat162; + +template <> +struct TypeMap { + using type = half; +}; + +template <> +struct TypeMap { + using type = __mt_bfloat16; +}; + +inline int QueryMaxThreadsPerBlock() { + int 
device = 0; + musaGetDevice(&device); + musaDeviceProp prop; + musaGetDeviceProperties(&prop, device); + return prop.maxThreadsPerBlock; +} + +template <> +struct Caster : CudaCasterImpl {}; } // namespace infini::ops diff --git a/src/moore/polyfills.cuh b/src/moore/polyfills.cuh index b3c7e70..88645a4 100644 --- a/src/moore/polyfills.cuh +++ b/src/moore/polyfills.cuh @@ -5,6 +5,7 @@ // clang-format off #include +#include // clang-format on namespace infini::ops { diff --git a/src/moore/swiglu/kernel.h b/src/moore/swiglu/kernel.h index a7759fb..0c6b058 100644 --- a/src/moore/swiglu/kernel.h +++ b/src/moore/swiglu/kernel.h @@ -3,15 +3,12 @@ #include -// clang-format off -#include -// clang-format on - // clang-format off #include "moore/polyfills.cuh" // clang-format on #include "cuda/swiglu/kernel.h" +#include "moore/device_.h" namespace infini::ops { @@ -20,6 +17,8 @@ namespace swiglu { struct MooreBackend { using stream_t = musaStream_t; + static constexpr Device::Type kDeviceType = Device::Type::kMoore; + static constexpr auto malloc = [](auto&&... args) { return musaMalloc(std::forward(args)...); }; @@ -33,6 +32,10 @@ struct MooreBackend { }; static constexpr auto memcpyH2D = musaMemcpyHostToDevice; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } }; } // namespace swiglu diff --git a/src/nvidia/add/kernel.h b/src/nvidia/add/kernel.h index 7e6c3e5..6e6c2c3 100644 --- a/src/nvidia/add/kernel.h +++ b/src/nvidia/add/kernel.h @@ -3,11 +3,8 @@ #include -// clang-format off -#include -// clang-format on - #include "cuda/add/kernel.h" +#include "nvidia/device_.h" namespace infini::ops { @@ -16,6 +13,8 @@ namespace add { struct NvidiaBackend { using stream_t = cudaStream_t; + static constexpr Device::Type kDeviceType = Device::Type::kNvidia; + static constexpr auto malloc = [](auto&&... 
args) { return cudaMalloc(std::forward(args)...); }; @@ -25,6 +24,10 @@ struct NvidiaBackend { static constexpr auto free = cudaFree; static constexpr auto memcpyH2D = cudaMemcpyHostToDevice; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } }; } // namespace add diff --git a/src/nvidia/causal_softmax/kernel.h b/src/nvidia/causal_softmax/kernel.h index 0c13518..62fdf8b 100644 --- a/src/nvidia/causal_softmax/kernel.h +++ b/src/nvidia/causal_softmax/kernel.h @@ -3,15 +3,8 @@ #include -// clang-format off -#include -// clang-format on - -// clang-format off -#include "nvidia/device_.h" -// clang-format on - #include "cuda/causal_softmax/kernel.h" +#include "nvidia/device_.h" namespace infini::ops { @@ -19,6 +12,12 @@ namespace causal_softmax { struct NvidiaBackend { using stream_t = cudaStream_t; + + static constexpr Device::Type kDeviceType = Device::Type::kNvidia; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } }; } // namespace causal_softmax diff --git a/src/nvidia/device_.h b/src/nvidia/device_.h index 1d7fe05..90de446 100644 --- a/src/nvidia/device_.h +++ b/src/nvidia/device_.h @@ -1,11 +1,67 @@ -#ifndef INFINI_OPS_NVIDIA_DEVICE_H_ -#define INFINI_OPS_NVIDIA_DEVICE_H_ +#ifndef INFINI_OPS_NVIDIA_DEVICE__H_ +#define INFINI_OPS_NVIDIA_DEVICE__H_ +#include +#include + +// clang-format off +#include +#include +#include +// clang-format on + +#include "cuda/caster_.h" +#include "data_type.h" #include "device.h" namespace infini::ops { -inline constexpr auto kDeviceType{Device::Type::kNvidia}; +using cuda_bfloat16 = nv_bfloat16; + +using cuda_bfloat162 = nv_bfloat162; + +template <> +struct TypeMap { + using type = half; +}; + +template <> +struct TypeMap { + using type = __nv_bfloat16; +}; + +// Caches `cudaDeviceProp` per device, initialized once at first access. 
+class DevicePropertyCache { + public: + static const cudaDeviceProp& GetCurrentDeviceProps() { + int device_id = 0; + cudaGetDevice(&device_id); + return GetDeviceProps(device_id); + } + + static const cudaDeviceProp& GetDeviceProps(int device_id) { + static std::vector cache = []() { + int count = 0; + cudaGetDeviceCount(&count); + if (count == 0) return std::vector{}; + std::vector props(count); + for (int i = 0; i < count; ++i) { + cudaGetDeviceProperties(&props[i], i); + } + return props; + }(); + + assert(device_id >= 0 && device_id < static_cast(cache.size())); + return cache[device_id]; + } +}; + +inline int QueryMaxThreadsPerBlock() { + return DevicePropertyCache::GetCurrentDeviceProps().maxThreadsPerBlock; +} + +template <> +struct Caster : CudaCasterImpl {}; } // namespace infini::ops diff --git a/src/nvidia/rms_norm/kernel.h b/src/nvidia/rms_norm/kernel.h index bb766d4..e346a31 100644 --- a/src/nvidia/rms_norm/kernel.h +++ b/src/nvidia/rms_norm/kernel.h @@ -3,15 +3,8 @@ #include -// clang-format off -#include -// clang-format on - -// clang-format off -#include "nvidia/device_.h" -// clang-format on - #include "cuda/rms_norm/kernel.h" +#include "nvidia/device_.h" namespace infini::ops { @@ -19,6 +12,12 @@ namespace rms_norm { struct NvidiaBackend { using stream_t = cudaStream_t; + + static constexpr Device::Type kDeviceType = Device::Type::kNvidia; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } }; } // namespace rms_norm diff --git a/src/nvidia/swiglu/kernel.h b/src/nvidia/swiglu/kernel.h index 54644e5..8ea30f8 100644 --- a/src/nvidia/swiglu/kernel.h +++ b/src/nvidia/swiglu/kernel.h @@ -3,11 +3,8 @@ #include -// clang-format off -#include -// clang-format on - #include "cuda/swiglu/kernel.h" +#include "nvidia/device_.h" namespace infini::ops { @@ -16,6 +13,8 @@ namespace swiglu { struct NvidiaBackend { using stream_t = cudaStream_t; + static constexpr Device::Type kDeviceType = 
Device::Type::kNvidia; + static constexpr auto malloc = [](auto&&... args) { return cudaMalloc(std::forward(args)...); }; @@ -25,6 +24,10 @@ struct NvidiaBackend { static constexpr auto free = cudaFree; static constexpr auto memcpyH2D = cudaMemcpyHostToDevice; + + static int GetOptimalBlockSize() { + return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock()); + } }; } // namespace swiglu diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h index 8f48bf2..de4fa62 100644 --- a/src/pybind11_utils.h +++ b/src/pybind11_utils.h @@ -10,33 +10,78 @@ namespace py = pybind11; namespace infini::ops { +namespace detail { + +template +struct TorchDeviceName; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"cpu"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"cuda"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"cuda"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"cuda"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"cuda"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"cuda"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"cuda"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"mlu"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"npu"}; +}; + +template <> +struct TorchDeviceName { + static constexpr std::string_view kValue{"musa"}; +}; + +template +std::unordered_map BuildTorchNameMap( + List) { + std::unordered_map map; + (map.emplace(std::string{TorchDeviceName::kValue}, kDevs), ...); + return map; +} + +} // namespace detail + inline DataType DataTypeFromString(const std::string& name) { return kStringToDataType.at(name); } inline Device::Type DeviceTypeFromString(const std::string& name) { - static 
const std::unordered_map kTorchNameToTypes{ - {"cpu", Device::Type::kCpu}, -#ifdef WITH_NVIDIA - {"cuda", Device::Type::kNvidia}, -#endif -#ifdef WITH_METAX - {"cuda", Device::Type::kMetax}, -#endif -#ifdef WITH_ILUVATAR - {"cuda", Device::Type::kIluvatar}, -#endif -#ifdef WITH_KUNLUN - {"cuda", Device::Type::kKunlun}, -#endif -#ifdef WITH_HYGON - {"cuda", Device::Type::kHygon}, -#endif -#ifdef WITH_QY - {"cuda", Device::Type::kQy}, -#endif - {"mlu", Device::Type::kCambricon}, {"npu", Device::Type::kAscend}, - {"musa", Device::Type::kMoore}}; + static const auto kTorchNameToTypes{ + detail::BuildTorchNameMap(ActiveDevices{})}; auto it{kTorchNameToTypes.find(name)}; diff --git a/src/tensor.cc b/src/tensor.cc index b4806a2..cd11087 100644 --- a/src/tensor.cc +++ b/src/tensor.cc @@ -112,7 +112,8 @@ Tensor::Strides Tensor::DefaultStrides(const Shape& shape) { std::string Tensor::ToStringHelper() const { if (ndim() == 0) { - return DispatchFunc>( + return DispatchFunc>( dtype_, [&](auto tag) { using T = typename decltype(tag)::type;