diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 585e3ab..76a317d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -133,6 +133,7 @@ if(WITH_CAMBRICON) message(STATUS "Found cncc: ${CNCC_COMPILER}") set(MLU_COMPILE_OPTS -c --bang-mlu-arch=mtp_592 -O3 -fPIC -Wall -Werror -std=c++17 -pthread + -DWITH_CAMBRICON=1 -I${CMAKE_CURRENT_SOURCE_DIR} -I${NEUWARE_HOME}/include -idirafter /usr/local/neuware/lib/clang/11.1.0/include ) diff --git a/src/cambricon/swiglu/kernel.mlu b/src/cambricon/swiglu/kernel.mlu new file mode 100644 index 0000000..de21c4a --- /dev/null +++ b/src/cambricon/swiglu/kernel.mlu @@ -0,0 +1,211 @@ +#include "swiglu.h" + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +namespace infini::ops { + +template <typename T> +__mlu_device__ void ComputeSwiglu(const T *input, const T *gate, T *output, + size_t n) { + if constexpr (std::is_same_v<T, float>) { + for (size_t i = 0; i < n; ++i) { + float g = gate[i]; + output[i] = input[i] * g / (1.0f + expf(-g)); + } + } else if constexpr (std::is_same_v<T, __half>) { + auto *out_h = reinterpret_cast<half *>(output); + auto *in_h = reinterpret_cast<const half *>(input); + auto *gate_h = reinterpret_cast<const half *>(gate); + __bang_active_sigmoid(out_h, gate_h, n); + __bang_mul(out_h, out_h, gate_h, n); + __bang_mul(out_h, out_h, in_h, n); + } else { + __bang_active_sigmoid(output, gate, n); + __bang_mul(output, output, gate, n); + __bang_mul(output, output, input, n); + } +} + +template <typename T> +__mlu_global__ void SwigluKernel(const T *input, const T *gate, T *output, + const size_t *out_shape, + const ptrdiff_t *out_strides, + const size_t *input_shape, + const ptrdiff_t *input_strides, + const size_t *gate_shape, + const ptrdiff_t *gate_strides, + size_t output_size, int ndim, bool fast_path, + bool out_contiguous) { + size_t elements_per_task = (output_size + taskDim - 1) / taskDim; + size_t start = taskId * elements_per_task; + size_t end = start + elements_per_task; + if (end > output_size) end = output_size; + size_t num_elements = end > start ? 
end - start : 0; + if (num_elements == 0) return; + + size_t nram_usable = NRAM_MAX_SIZE - 256; + size_t block_size = nram_usable / (3 * sizeof(T)); + block_size = (block_size / 64) * 64; + if (block_size == 0) block_size = 64; + + T *input_buf = reinterpret_cast<T *>(nram_buffer); + T *gate_buf = input_buf + block_size; + T *output_buf = gate_buf + block_size; + + size_t processed = 0; + + if (fast_path) { + while (processed < num_elements) { + size_t curr = block_size; + if (curr > num_elements - processed) curr = num_elements - processed; + + __memcpy(input_buf, input + start + processed, + curr * sizeof(T), GDRAM2NRAM); + __memcpy(gate_buf, gate + start + processed, + curr * sizeof(T), GDRAM2NRAM); + ComputeSwiglu(input_buf, gate_buf, output_buf, curr); + __memcpy(output + start + processed, output_buf, + curr * sizeof(T), NRAM2GDRAM); + + processed += curr; + } + return; + } + + // General path: handle non-contiguous tensors and broadcasting. + while (processed < num_elements) { + size_t curr = block_size; + if (curr > num_elements - processed) curr = num_elements - processed; + + for (size_t i = 0; i < curr; ++i) { + size_t flat_idx = start + processed + i; + + // Compute `input` offset. + { + size_t tmp = flat_idx; + ptrdiff_t offset = 0; + for (int d = ndim - 1; d >= 0; --d) { + size_t coord = tmp % out_shape[d]; + tmp /= out_shape[d]; + size_t c = coord < input_shape[d] ? coord : 0; + offset += static_cast<ptrdiff_t>(c) * input_strides[d]; + } + input_buf[i] = input[offset]; + } + + // Compute `gate` offset. + { + size_t tmp = flat_idx; + ptrdiff_t offset = 0; + for (int d = ndim - 1; d >= 0; --d) { + size_t coord = tmp % out_shape[d]; + tmp /= out_shape[d]; + size_t c = coord < gate_shape[d] ? 
coord : 0; + offset += static_cast<ptrdiff_t>(c) * gate_strides[d]; + } + gate_buf[i] = gate[offset]; + } + } + + ComputeSwiglu(input_buf, gate_buf, output_buf, curr); + + if (out_contiguous) { + __memcpy(output + start + processed, output_buf, + curr * sizeof(T), NRAM2GDRAM); + } else { + for (size_t i = 0; i < curr; ++i) { + size_t flat_idx = start + processed + i; + size_t tmp = flat_idx; + ptrdiff_t offset = 0; + for (int d = ndim - 1; d >= 0; --d) { + size_t coord = tmp % out_shape[d]; + offset += static_cast<ptrdiff_t>(coord) * out_strides[d]; + tmp /= out_shape[d]; + } + output[offset] = output_buf[i]; + } + } + + processed += curr; + } +} + +template <typename T> +void SwigluUnion(void *workspace, int core_per_cluster, int cluster_count, + cnrtQueue_t queue, void *out, const void *input, + const void *gate, const size_t *out_shape, + const ptrdiff_t *out_strides, const size_t *input_shape, + const ptrdiff_t *input_strides, const size_t *gate_shape, + const ptrdiff_t *gate_strides, size_t output_size, int ndim, + bool fast_path, bool out_contiguous) { + cnrtDim3_t kernel_dim; + cnrtFunctionType_t kernel_type; + + kernel_dim.x = core_per_cluster; + kernel_dim.y = cluster_count; + kernel_dim.z = 1; + kernel_type = cnrtFuncTypeUnion1; + + auto out_ = reinterpret_cast<T *>(out); + auto input_ = reinterpret_cast<const T *>(input); + auto gate_ = reinterpret_cast<const T *>(gate); + + char *tmp = reinterpret_cast<char *>(workspace); + size_t *mlu_out_shape = reinterpret_cast<size_t *>(tmp); + size_t *mlu_input_shape = mlu_out_shape + ndim; + size_t *mlu_gate_shape = mlu_input_shape + ndim; + ptrdiff_t *mlu_out_strides = + reinterpret_cast<ptrdiff_t *>(mlu_gate_shape + ndim); + ptrdiff_t *mlu_input_strides = mlu_out_strides + ndim; + ptrdiff_t *mlu_gate_strides = mlu_input_strides + ndim; + + CNRT_CHECK(cnrtMemcpyAsync(mlu_out_shape, const_cast<size_t *>(out_shape), + ndim * sizeof(size_t), queue, + cnrtMemcpyHostToDev)); + CNRT_CHECK(cnrtMemcpyAsync(mlu_input_shape, const_cast<size_t *>(input_shape), + ndim * sizeof(size_t), queue, + cnrtMemcpyHostToDev)); + 
CNRT_CHECK(cnrtMemcpyAsync(mlu_gate_shape, const_cast<size_t *>(gate_shape), + ndim * sizeof(size_t), queue, + cnrtMemcpyHostToDev)); + CNRT_CHECK(cnrtMemcpyAsync(mlu_out_strides, + const_cast<ptrdiff_t *>(out_strides), + ndim * sizeof(ptrdiff_t), queue, + cnrtMemcpyHostToDev)); + CNRT_CHECK(cnrtMemcpyAsync(mlu_input_strides, + const_cast<ptrdiff_t *>(input_strides), + ndim * sizeof(ptrdiff_t), queue, + cnrtMemcpyHostToDev)); + CNRT_CHECK(cnrtMemcpyAsync(mlu_gate_strides, + const_cast<ptrdiff_t *>(gate_strides), + ndim * sizeof(ptrdiff_t), queue, + cnrtMemcpyHostToDev)); + + SwigluKernel<T><<<kernel_dim, kernel_type, queue>>>( + input_, gate_, out_, mlu_out_shape, mlu_out_strides, mlu_input_shape, + mlu_input_strides, mlu_gate_shape, mlu_gate_strides, output_size, ndim, + fast_path, out_contiguous); + + cnrtQueueSync(queue); +} + +template void SwigluUnion<__half>(void *, int, int, cnrtQueue_t, void *, + const void *, const void *, const size_t *, + const ptrdiff_t *, const size_t *, + const ptrdiff_t *, const size_t *, + const ptrdiff_t *, size_t, int, bool, bool); + +template void SwigluUnion<__bang_bfloat16>(void *, int, int, cnrtQueue_t, void *, + const void *, const void *, + const size_t *, const ptrdiff_t *, + const size_t *, const ptrdiff_t *, + const size_t *, const ptrdiff_t *, + size_t, int, bool, bool); + +template void SwigluUnion<float>(void *, int, int, cnrtQueue_t, void *, + const void *, const void *, const size_t *, + const ptrdiff_t *, const size_t *, + const ptrdiff_t *, const size_t *, + const ptrdiff_t *, size_t, int, bool, bool); + +} // namespace infini::ops diff --git a/src/cambricon/swiglu/swiglu.h b/src/cambricon/swiglu/swiglu.h new file mode 100644 index 0000000..01bb89d --- /dev/null +++ b/src/cambricon/swiglu/swiglu.h @@ -0,0 +1,66 @@ +#ifndef INFINI_OPS_CAMBRICON_SWIGLU_SWIGLU_H_ +#define INFINI_OPS_CAMBRICON_SWIGLU_SWIGLU_H_ + +#include "cambricon/common.h" +#include "base/swiglu.h" + +namespace infini::ops { + +template <typename T> +void SwigluUnion(void *workspace, int core_per_cluster, int cluster_count, + cnrtQueue_t queue, void 
*out, const void *input, + const void *gate, const size_t *out_shape, + const ptrdiff_t *out_strides, const size_t *input_shape, + const ptrdiff_t *input_strides, const size_t *gate_shape, + const ptrdiff_t *gate_strides, size_t output_size, int ndim, + bool fast_path, bool out_contiguous); + +template <> +class Operator /* NOTE(review): specialization argument list was stripped during extraction ('template <>' requires Operator<...>) - restore from base/swiglu.h */ : public Swiglu { + public: + Operator(const Tensor input, const Tensor gate, Tensor out) + : Swiglu{input, gate, out} { + cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster, + &cluster_count); + cnrtMalloc(&default_workspace_, workspace_size_in_bytes()); + } + + void operator()(const Tensor input, const Tensor gate, + Tensor out) const override { + auto queue = static_cast<cnrtQueue_t>(stream_ ? stream_ : 0); + auto workspace{workspace_ ? workspace_ : default_workspace_}; + + bool fast_path = is_input_contiguous_ && is_gate_contiguous_ && + is_out_contiguous_ && input_shape_ == out_shape_ && + gate_shape_ == out_shape_; + + DispatchFunc /* NOTE(review): template argument stripped during extraction; the dispatch type list should cover __half, __bang_bfloat16, float per the kernel.mlu instantiations */ >( + out_type_, + [&](auto tag) { + using T = typename decltype(tag)::type; + SwigluUnion<T>(workspace, core_per_cluster, cluster_count, queue, + out.data(), input.data(), gate.data(), + out_shape_.data(), out_strides_.data(), + input_shape_.data(), input_strides_.data(), + gate_shape_.data(), gate_strides_.data(), + output_size_, ndim_, fast_path, + is_out_contiguous_); + }, + "CambriconSwiglu::operator() - output dispatch"); + } + + ~Operator() { cnrtFree(default_workspace_); } + + std::size_t workspace_size_in_bytes() const override { + return ndim_ * (3 * sizeof(size_t) + 3 * sizeof(ptrdiff_t)); + } + + void *default_workspace_{nullptr}; + int core_per_cluster = 0; + int cluster_count = 0; +}; + +} // namespace infini::ops + +#endif