From 1fa9917335480051e6d93cc138985c9ac4b8f2bd Mon Sep 17 00:00:00 2001
From: zhangyunze
Date: Mon, 30 Mar 2026 02:16:00 +0000
Subject: [PATCH] feat: add cambricon add op

---
 src/CMakeLists.txt           |   1 +
 src/cambricon/add/add.h      |  75 ++++++++++++++
 src/cambricon/add/kernel.mlu | 214 +++++++++++++++++++++++++++++++++++
 tests/test_add.py            |   5 +
 4 files changed, 295 insertions(+)
 create mode 100644 src/cambricon/add/add.h
 create mode 100644 src/cambricon/add/kernel.mlu

diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 585e3ab..76a317d 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -133,6 +133,7 @@ if(WITH_CAMBRICON)
   message(STATUS "Found cncc: ${CNCC_COMPILER}")
   set(MLU_COMPILE_OPTS
       -c --bang-mlu-arch=mtp_592 -O3 -fPIC -Wall -Werror -std=c++17 -pthread
+      -DWITH_CAMBRICON=1
       -I${CMAKE_CURRENT_SOURCE_DIR} -I${NEUWARE_HOME}/include
       -idirafter /usr/local/neuware/lib/clang/11.1.0/include
   )
diff --git a/src/cambricon/add/add.h b/src/cambricon/add/add.h
new file mode 100644
index 0000000..84092ac
--- /dev/null
+++ b/src/cambricon/add/add.h
@@ -0,0 +1,75 @@
+#ifndef INFINI_OPS_CAMBRICON_ADD_ADD_H_
+#define INFINI_OPS_CAMBRICON_ADD_ADD_H_
+
+#include "../common.h"
+#include "base/add.h"
+
+namespace infini::ops {
+
+// Host-side launcher for the Cambricon BANG elementwise-add kernel.
+template <typename T>  // NOTE(review): template header was lost in transit; restored as <typename T>.
+void AddUnion(void *workspace, int core_per_cluster, int cluster_count,
+              cnrtQueue_t queue, void *out, const void *input,
+              const void *other, const size_t *out_shape,
+              const ptrdiff_t *out_strides, const size_t *input_shape,
+              const ptrdiff_t *input_strides, const size_t *other_shape,
+              const ptrdiff_t *other_strides, size_t output_size, int ndim,
+              bool fast_path, bool out_contiguous);
+
+template <>
+class Operator : public Add {  // NOTE(review): the specialization argument list after `Operator` was lost in transit — restore it from the sibling backends.
+ public:
+  Operator(const Tensor input, const Tensor other, Tensor out)
+      : Add{input, other, out} {
+    cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster,
+                                &cluster_count);
+    // Fallback workspace used when the caller supplies none; checked for
+    // allocation failure like every other CNRT call in this op.
+    CNRT_CHECK(cnrtMalloc(&default_workspace_, workspace_size_in_bytes()));
+  }
+
+  // This class owns `default_workspace_` (a raw cnrtMalloc buffer); the
+  // implicit copy would double-free it in ~Operator, so copying is disabled.
+  Operator(const Operator &) = delete;
+  Operator &operator=(const Operator &) = delete;
+
+  void operator()(const Tensor input, const Tensor other,
+                  Tensor out) const override {
+    auto queue = static_cast<cnrtQueue_t>(stream_ ? stream_ : 0);  // NOTE(review): cast target lost in transit; restored as cnrtQueue_t.
+    auto workspace{workspace_ ? workspace_ : default_workspace_};
+
+    // Fast path only when no broadcasting and no striding is involved.
+    bool fast_path = is_input_contiguous_ && is_other_contiguous_ &&
+                     is_out_contiguous_ && input_shape_ == out_shape_ &&
+                     other_shape_ == out_shape_;
+
+    DispatchFunc<__half, __bang_bfloat16, float, int32_t, int64_t>(  // NOTE(review): dispatch type list lost in transit; reconstructed from kernel.mlu's explicit instantiations — confirm.
+        out_type_,
+        [&](auto tag) {
+          using T = typename decltype(tag)::type;
+          AddUnion<T>(workspace, core_per_cluster, cluster_count, queue,
+                      out.data(), input.data(), other.data(), out_shape_.data(),
+                      out_strides_.data(), input_shape_.data(),
+                      input_strides_.data(), other_shape_.data(),
+                      other_strides_.data(), output_size_, ndim_, fast_path,
+                      is_out_contiguous_);
+        },
+        "CambriconAdd::operator() - output dispatch");
+  }
+
+  ~Operator() { cnrtFree(default_workspace_); }
+
+  // Device scratch that mirrors shapes/strides: three shape arrays plus
+  // three stride arrays of `ndim_` entries each.
+  std::size_t workspace_size_in_bytes() const override {
+    return ndim_ * (3 * sizeof(size_t) + 3 * sizeof(ptrdiff_t));
+  }
+
+  void *default_workspace_{nullptr};
+  int core_per_cluster = 0;
+  int cluster_count = 0;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/cambricon/add/kernel.mlu b/src/cambricon/add/kernel.mlu
new file mode 100644
index 0000000..5e482e0
--- /dev/null
+++ b/src/cambricon/add/kernel.mlu
@@ -0,0 +1,214 @@
+#include "add.h"
+
+__nram__ char nram_buffer[NRAM_MAX_SIZE];
+
+namespace infini::ops {
+
+template <typename T>  // NOTE(review): template header lost in transit; restored. Adds `n` values already staged in NRAM.
+__mlu_device__ void BangAdd(const T *src1, const T *src2, T *dst,
+                            size_t n) {
+  if constexpr (std::is_same_v<T, __bang_bfloat16>) {  // NOTE(review): the trait's and the casts' type arguments were lost in transit; this reconstruction assumes the bf16 path re-types the buffers for __bang_add — confirm against the original.
+    __bang_add(reinterpret_cast<bfloat16_t *>(dst),
+               reinterpret_cast<const bfloat16_t *>(src1),
+               reinterpret_cast<const bfloat16_t *>(src2), n);
+  } else {
+    __bang_add(dst, src1, src2, n);
+  }
+}
+
+template <typename T>  // NOTE(review): template header lost in transit; restored.
+__mlu_global__ void AddKernel(const T *input, const T *other, T *output,
+                              const size_t *out_shape,
+                              const ptrdiff_t *out_strides,
+                              const size_t *input_shape,
+                              const ptrdiff_t *input_strides,
+                              const size_t *other_shape,
+                              const ptrdiff_t *other_strides,
+                              size_t output_size, int ndim, bool fast_path,
+                              bool out_contiguous) {
+  size_t elements_per_task = (output_size + taskDim - 1) /
+      taskDim;
+  size_t start = taskId * elements_per_task;
+  size_t end = start + elements_per_task;
+  if (end > output_size) end = output_size;
+  size_t num_elements = end > start ? end - start : 0;
+  if (num_elements == 0) return;
+
+  size_t nram_usable = NRAM_MAX_SIZE - 256;
+  size_t block_size = nram_usable / (3 * sizeof(T));
+  block_size = (block_size / 64) * 64;  // Align to 64 elements
+  if (block_size == 0) block_size = 64;
+
+  T *input_buf = reinterpret_cast<T *>(nram_buffer);  // NOTE(review): cast target lost in transit; restored as T *.
+  T *other_buf = input_buf + block_size;
+  T *output_buf = other_buf + block_size;
+
+  size_t processed = 0;
+
+  if (fast_path) {
+    // Fast path: all tensors contiguous with matching shapes (no broadcast).
+    while (processed < num_elements) {
+      size_t curr = block_size;
+      if (curr > num_elements - processed) curr = num_elements - processed;
+
+      __memcpy(input_buf, input + start + processed,
+               curr * sizeof(T), GDRAM2NRAM);
+      __memcpy(other_buf, other + start + processed,
+               curr * sizeof(T), GDRAM2NRAM);
+      BangAdd(input_buf, other_buf, output_buf, curr);
+      __memcpy(output + start + processed, output_buf,
+               curr * sizeof(T), NRAM2GDRAM);
+
+      processed += curr;
+    }
+    return;
+  }
+
+  // General path: handle non-contiguous tensors and broadcasting.
+  while (processed < num_elements) {
+    size_t curr = block_size;
+    if (curr > num_elements - processed) curr = num_elements - processed;
+
+    for (size_t i = 0; i < curr; ++i) {
+      size_t flat_idx = start + processed + i;
+
+      // Compute `input` offset.
+      {
+        size_t tmp = flat_idx;
+        ptrdiff_t offset = 0;
+        for (int d = ndim - 1; d >= 0; --d) {
+          size_t coord = tmp % out_shape[d];
+          tmp /= out_shape[d];
+          size_t c = coord < input_shape[d] ? coord : 0;  // broadcast dims (size 1) clamp to 0
+          offset += static_cast<ptrdiff_t>(c) * input_strides[d];  // NOTE(review): cast target lost in transit; restored as ptrdiff_t.
+        }
+        input_buf[i] = input[offset];
+      }
+
+      // Compute `other` offset.
+      {
+        size_t tmp = flat_idx;
+        ptrdiff_t offset = 0;
+        for (int d = ndim - 1; d >= 0; --d) {
+          size_t coord = tmp % out_shape[d];
+          tmp /= out_shape[d];
+          size_t c = coord < other_shape[d] ?
+              coord : 0;
+          offset += static_cast<ptrdiff_t>(c) * other_strides[d];
+        }
+        other_buf[i] = other[offset];
+      }
+    }
+
+    BangAdd(input_buf, other_buf, output_buf, curr);
+
+    if (out_contiguous) {
+      __memcpy(output + start + processed, output_buf,
+               curr * sizeof(T), NRAM2GDRAM);
+    } else {
+      for (size_t i = 0; i < curr; ++i) {
+        size_t flat_idx = start + processed + i;
+        size_t tmp = flat_idx;
+        ptrdiff_t offset = 0;
+        for (int d = ndim - 1; d >= 0; --d) {
+          size_t coord = tmp % out_shape[d];
+          offset += static_cast<ptrdiff_t>(coord) * out_strides[d];
+          tmp /= out_shape[d];
+        }
+        output[offset] = output_buf[i];
+      }
+    }
+
+    processed += curr;
+  }
+}
+
+template <typename T>  // NOTE(review): template header lost in transit; restored.
+void AddUnion(void *workspace, int core_per_cluster, int cluster_count,
+              cnrtQueue_t queue, void *out, const void *input, const void *other,
+              const size_t *out_shape, const ptrdiff_t *out_strides,
+              const size_t *input_shape, const ptrdiff_t *input_strides,
+              const size_t *other_shape, const ptrdiff_t *other_strides,
+              size_t output_size, int ndim, bool fast_path,
+              bool out_contiguous) {
+  cnrtDim3_t kernel_dim;
+  cnrtFunctionType_t kernel_type;
+
+  kernel_dim.x = core_per_cluster;
+  kernel_dim.y = cluster_count;
+  kernel_dim.z = 1;
+  kernel_type = cnrtFuncTypeUnion1;
+
+  auto out_ = reinterpret_cast<T *>(out);  // NOTE(review): cast targets on these three lines lost in transit; restored from the kernel signature.
+  auto input_ = reinterpret_cast<const T *>(input);
+  auto other_ = reinterpret_cast<const T *>(other);
+
+  // Carve the caller-provided workspace into 3 shape + 3 stride arrays
+  // (layout must match Operator::workspace_size_in_bytes in add.h).
+  char *tmp = reinterpret_cast<char *>(workspace);
+  size_t *mlu_out_shape = reinterpret_cast<size_t *>(tmp);
+  size_t *mlu_input_shape = mlu_out_shape + ndim;
+  size_t *mlu_other_shape = mlu_input_shape + ndim;
+  ptrdiff_t *mlu_out_strides =
+      reinterpret_cast<ptrdiff_t *>(mlu_other_shape + ndim);
+  ptrdiff_t *mlu_input_strides = mlu_out_strides + ndim;
+  ptrdiff_t *mlu_other_strides = mlu_input_strides + ndim;
+
+  CNRT_CHECK(cnrtMemcpyAsync(mlu_out_shape, const_cast<size_t *>(out_shape),
+                             ndim * sizeof(size_t), queue,
+                             cnrtMemcpyHostToDev));
+  CNRT_CHECK(cnrtMemcpyAsync(mlu_input_shape, const_cast<size_t *>(input_shape),
+                             ndim * sizeof(size_t), queue,
+                             cnrtMemcpyHostToDev));
CNRT_CHECK(cnrtMemcpyAsync(mlu_other_shape, const_cast(other_shape), + ndim * sizeof(size_t), queue, + cnrtMemcpyHostToDev)); + CNRT_CHECK(cnrtMemcpyAsync(mlu_out_strides, + const_cast(out_strides), + ndim * sizeof(ptrdiff_t), queue, + cnrtMemcpyHostToDev)); + CNRT_CHECK(cnrtMemcpyAsync(mlu_input_strides, + const_cast(input_strides), + ndim * sizeof(ptrdiff_t), queue, + cnrtMemcpyHostToDev)); + CNRT_CHECK(cnrtMemcpyAsync(mlu_other_strides, + const_cast(other_strides), + ndim * sizeof(ptrdiff_t), queue, + cnrtMemcpyHostToDev)); + + AddKernel<<>>( + input_, other_, out_, mlu_out_shape, mlu_out_strides, mlu_input_shape, + mlu_input_strides, mlu_other_shape, mlu_other_strides, output_size, ndim, + fast_path, out_contiguous); + + cnrtQueueSync(queue); +} + +template void AddUnion<__half>(void *, int, int, cnrtQueue_t, void *, + const void *, const void *, const size_t *, + const ptrdiff_t *, const size_t *, + const ptrdiff_t *, const size_t *, + const ptrdiff_t *, size_t, int, bool, bool); + +template void AddUnion<__bang_bfloat16>(void *, int, int, cnrtQueue_t, void *, + const void *, const void *, + const size_t *, const ptrdiff_t *, + const size_t *, const ptrdiff_t *, + const size_t *, const ptrdiff_t *, + size_t, int, bool, bool); + +template void AddUnion(void *, int, int, cnrtQueue_t, void *, + const void *, const void *, const size_t *, + const ptrdiff_t *, const size_t *, + const ptrdiff_t *, const size_t *, + const ptrdiff_t *, size_t, int, bool, bool); + +template void AddUnion(void *, int, int, cnrtQueue_t, void *, + const void *, const void *, const size_t *, + const ptrdiff_t *, const size_t *, + const ptrdiff_t *, const size_t *, + const ptrdiff_t *, size_t, int, bool, bool); + +template void AddUnion(void *, int, int, cnrtQueue_t, void *, + const void *, const void *, const size_t *, + const ptrdiff_t *, const size_t *, + const ptrdiff_t *, const size_t *, + const ptrdiff_t *, size_t, int, bool, bool); + +} // namespace infini::ops diff --git 
a/tests/test_add.py b/tests/test_add.py index 8b8166c..0384ba5 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -45,6 +45,11 @@ def test_add( pytest.skip( "The `torch.musa` test cloning path does not support `uint16`, `uint32`, or `uint64`." ) + + if device == "mlu" and ( dtype in _UINT_DTYPES or dtype == torch.int16): + pytest.skip( + "The `torch.mlu` test cloning path does not support `int16`, `uint16`, `uint32`, or `uint64`." + ) if dtype in _INT_DTYPES or dtype in _UINT_DTYPES: input = randint_strided(