diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 585e3ab..76a317d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -133,6 +133,7 @@ if(WITH_CAMBRICON) message(STATUS "Found cncc: ${CNCC_COMPILER}") set(MLU_COMPILE_OPTS -c --bang-mlu-arch=mtp_592 -O3 -fPIC -Wall -Werror -std=c++17 -pthread + -DWITH_CAMBRICON=1 -I${CMAKE_CURRENT_SOURCE_DIR} -I${NEUWARE_HOME}/include -idirafter /usr/local/neuware/lib/clang/11.1.0/include ) diff --git a/src/cambricon/causal_softmax/causal_softmax.h b/src/cambricon/causal_softmax/causal_softmax.h new file mode 100644 index 0000000..33cde37 --- /dev/null +++ b/src/cambricon/causal_softmax/causal_softmax.h @@ -0,0 +1,60 @@ +#ifndef INFINI_OPS_CAMBRICON_CAUSAL_SOFTMAX_H +#define INFINI_OPS_CAMBRICON_CAUSAL_SOFTMAX_H + +#include "base/causal_softmax.h" +#include "cambricon/common.h" + +namespace infini::ops { + +// TODO: Remove forward declaration. +template +void CausalSoftmaxUnion(void *workspace, int core_per_cluster, + int cluster_count, cnrtQueue_t queue, void *y, + const void *x, size_t batch_size_, size_t seq_len_, + size_t total_seq_len_, ptrdiff_t y_stride_b, + ptrdiff_t y_stride_i, ptrdiff_t y_stride_j, + ptrdiff_t x_stride_b, ptrdiff_t x_stride_i, + ptrdiff_t x_stride_j); + +template <> +class Operator : public CausalSoftmax { /* Cambricon backend for causal softmax: the launch configuration is queried once in the constructor and each call dispatches CausalSoftmaxUnion on the input dtype. */ + public: + Operator(const Tensor input, Tensor out) : CausalSoftmax{input, out} { + cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster, + &cluster_count); + } + void operator()(const Tensor input, Tensor out) const override { + auto queue = static_cast(stream_ ? stream_ : 0); + auto workspace{workspace_ ? workspace_ : default_workspace_}; + ptrdiff_t y_stride_b = ndim_ == 3 ? out_strides_[0] : 1; /* For non-3-D tensors the batch stride is set to 1 and the row and column strides shift down one axis; presumably batch_size_ is 1 in that case - TODO confirm. */ + ptrdiff_t y_stride_i = ndim_ == 3 ? out_strides_[1] : out_strides_[0]; + ptrdiff_t y_stride_j = ndim_ == 3 ? out_strides_[2] : out_strides_[1]; + ptrdiff_t x_stride_b = ndim_ == 3 ? input_strides_[0] : 1; + ptrdiff_t x_stride_i = ndim_ == 3 ? 
input_strides_[1] : input_strides_[0]; + ptrdiff_t x_stride_j = ndim_ == 3 ? input_strides_[2] : input_strides_[1]; + + DispatchFunc< + List>( + {input.dtype()}, + [&](auto input_tag) { + using InputT = typename decltype(input_tag)::type; + CausalSoftmaxUnion( + workspace, core_per_cluster, cluster_count, queue, out.data(), + input.data(), batch_size_, seq_len_, total_seq_len_, y_stride_b, + y_stride_i, y_stride_j, x_stride_b, x_stride_i, x_stride_j); + }, + "CambriconCausalSoftmax::operator() - output dispatch"); + } + + std::size_t workspace_size_in_bytes() const override { return 0; } /* No scratch memory is requested; the workspace pointer handed to CausalSoftmaxUnion is not read there. */ + + ~Operator() {} + + void *default_workspace_{nullptr}; + int core_per_cluster = 0; + int cluster_count = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cambricon/causal_softmax/kernel.mlu b/src/cambricon/causal_softmax/kernel.mlu new file mode 100644 index 0000000..3bced48 --- /dev/null +++ b/src/cambricon/causal_softmax/kernel.mlu @@ -0,0 +1,164 @@ +#include "causal_softmax.h" + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; +const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 4; + +namespace infini::ops { + +template +__mlu_func__ void ProcessSoftmaxStep(const T *input, T *output, float scalar, + int num_elements, int stride, + bool is_exp_phase) { /* One elementwise pass over num_elements strided values, staged through nram_buffer in chunks. With is_exp_phase set it loads from input and stores exp(x - scalar); otherwise it reloads from output and stores x * scalar. Non-float T is widened to float for the arithmetic and narrowed again before the store. */ + constexpr bool is_half = std::is_same_v; + constexpr bool is_bfloat16 = std::is_same_v; + constexpr bool is_float = !is_half && !is_bfloat16; + + const int chunk_size = + SRC_MAX_SIZE / + ((is_half || is_bfloat16) ? (2 * sizeof(float)) : sizeof(float)); + float *float_buffer = (float *)nram_buffer; + T *temp_buffer = + is_float ? nullptr : (T *)(nram_buffer + chunk_size * sizeof(float)); + + // Common stride configurations. + const int src_stride = stride * sizeof(T); + const int dst_stride = stride * sizeof(T); /* NOTE(review): one element stride serves both the load and the store, yet the exp-phase call site passes x_stride_j while storing into y - confirm x_stride_j and y_stride_j always agree. */ + + int processed = 0; + while (processed < num_elements) { + int curr_batch = std::min(chunk_size, num_elements - processed); + + if constexpr (is_float) { + __memcpy( + float_buffer, (is_exp_phase ? 
input : output) + processed * stride, + sizeof(float), GDRAM2NRAM, sizeof(float), src_stride, curr_batch - 1); + } else { + __memcpy(temp_buffer, + (is_exp_phase ? input : output) + processed * stride, sizeof(T), + GDRAM2NRAM, sizeof(T), src_stride, curr_batch - 1); + + if constexpr (is_half) { + __bang_half2float(float_buffer, reinterpret_cast(temp_buffer), + curr_batch); + } else if constexpr (is_bfloat16) { + __bang_bfloat162float(float_buffer, temp_buffer, curr_batch); + } + } + + // Common processing for all types. + if (is_exp_phase) { + __bang_sub_scalar(float_buffer, float_buffer, scalar, + curr_batch); // scalar is max_val + __bang_active_exphp(float_buffer, float_buffer, curr_batch); + } else { + __bang_mul_scalar(float_buffer, float_buffer, scalar, + curr_batch); // scalar is 1.0f/sum_val + } + + if constexpr (is_float) { + __memcpy(output + processed * stride, float_buffer, sizeof(float), + NRAM2GDRAM, dst_stride, sizeof(float), curr_batch - 1); + } else { + if constexpr (is_half) { + __bang_float2half(reinterpret_cast(temp_buffer), float_buffer, + curr_batch); + } else if constexpr (is_bfloat16) { + __bang_float2bfloat16(temp_buffer, float_buffer, curr_batch); + } + + __memcpy(output + processed * stride, temp_buffer, sizeof(T), NRAM2GDRAM, + dst_stride, sizeof(T), curr_batch - 1); + } + + processed += curr_batch; + } +} + +template +__mlu_global__ void CausalSoftmax(T *y, const T *x, size_t batch_size, + size_t seq_len, size_t total_seq_len, + ptrdiff_t y_stride_b, ptrdiff_t y_stride_i, + ptrdiff_t y_stride_j, ptrdiff_t x_stride_b, + ptrdiff_t x_stride_i, ptrdiff_t x_stride_j) { /* Rows (batch, i) are split into contiguous ranges of tasks_per_core per task. For each row: zero the masked tail, reduce the max, store exp(x - max), reduce the sum, then scale by 1 / sum. */ + size_t task_id = taskId; + size_t task_num = taskDimX * taskDimY; + + size_t total_tasks = batch_size * seq_len; + size_t tasks_per_core = (total_tasks + task_num - 1) / task_num; + size_t start = task_id * tasks_per_core; + size_t end = std::min(start + tasks_per_core, total_tasks); + + const int max_batch = SRC_MAX_SIZE / sizeof(T); + T *src = (T *)nram_buffer; + float 
*dst = (float *)(nram_buffer + max_batch * sizeof(T)); + + for (size_t index = start; index < end; index++) { + size_t batch = index / seq_len; + size_t i = (index % seq_len); + ptrdiff_t y_offset = batch * y_stride_b + i * y_stride_i; + ptrdiff_t x_offset = batch * x_stride_b + i * x_stride_i; + T *y_ = y + y_offset; + const T *x_ = x + x_offset; + + // Calculate the valid sequence length for this position. + size_t valid_len = total_seq_len - seq_len + i + 1; /* NOTE(review): size_t arithmetic, so this wraps when total_seq_len < seq_len - confirm callers guarantee total_seq_len >= seq_len. */ + + // Zero out future positions. + for (size_t j = valid_len; j < total_seq_len; j++) { + y_[j * y_stride_j] = (T)0.0f; + } + + // Calculate max value using optimized reduction. + float max_val = + infini::ops::reduce::MaxBatched(x_, src, dst, valid_len, max_batch); /* NOTE(review): the reduce helpers copy from source as one dense run and take no stride, unlike ProcessSoftmaxStep - confirm rows are contiguous along j. */ + + // Compute `exp(x - max)`. + ProcessSoftmaxStep(x_, y_, max_val, valid_len, x_stride_j, true); + + // Calculate sum of exponentials. + float sum_val = + infini::ops::reduce::SumBatched(y_, src, dst, valid_len, max_batch); + + // Normalize by sum. + ProcessSoftmaxStep(y_, y_, 1.0f / sum_val, valid_len, y_stride_j, false); + } +} + +template +void CausalSoftmaxUnion(void *workspace, int core_per_cluster, + int cluster_count, cnrtQueue_t queue, void *y, + const void *x, size_t batch_size_, size_t seq_len_, + size_t total_seq_len_, ptrdiff_t y_stride_b, + ptrdiff_t y_stride_i, ptrdiff_t y_stride_j, + ptrdiff_t x_stride_b, ptrdiff_t x_stride_i, + ptrdiff_t x_stride_j) { /* Launcher for the CausalSoftmax kernel: dimensions are core_per_cluster by cluster_count by 1 with type cnrtFuncTypeUnion1, followed by a cnrtQueueSync call on the queue. The workspace argument is not used in this body. */ + cnrtDim3_t kernel_dim; + cnrtFunctionType_t kernel_type; + + kernel_dim.x = core_per_cluster; + kernel_dim.y = cluster_count; + kernel_dim.z = 1; + kernel_type = cnrtFuncTypeUnion1; + + CausalSoftmax<<>>( + (T *)y, (const T *)x, batch_size_, seq_len_, total_seq_len_, y_stride_b, + y_stride_i, y_stride_j, x_stride_b, x_stride_i, x_stride_j); + + cnrtQueueSync(queue); +} + +template void CausalSoftmaxUnion<__half>(void *, int, int, cnrtQueue_t, void *, + const void *, size_t, size_t, size_t, + ptrdiff_t, ptrdiff_t, ptrdiff_t, + ptrdiff_t, ptrdiff_t, ptrdiff_t); + +template void 
CausalSoftmaxUnion<__bang_bfloat16>( + void *, int, int, cnrtQueue_t, void *, const void *, size_t, size_t, size_t, + ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t); + +template void CausalSoftmaxUnion(void *, int, int, cnrtQueue_t, void *, + const void *, size_t, size_t, size_t, + ptrdiff_t, ptrdiff_t, ptrdiff_t, + ptrdiff_t, ptrdiff_t, ptrdiff_t); + +} // namespace infini::ops diff --git a/src/cambricon/common.h b/src/cambricon/common.h index fc8ede0..8e94333 100644 --- a/src/cambricon/common.h +++ b/src/cambricon/common.h @@ -15,7 +15,7 @@ namespace infini::ops::reduce { constexpr int batch_size = 128 / sizeof(float); -__mlu_func__ void SumInternal(float* dst, float* src, int max_batch) { +__mlu_func__ void SumInternal(float *dst, float *src, int max_batch) { const int width = max_batch / batch_size; if (width >= 4) { @@ -30,6 +30,164 @@ __mlu_func__ void SumInternal(float* dst, float* src, int max_batch) { } } +template +__mlu_func__ void SumTyped(float *result, T *data, size_t len) { + if constexpr (std::is_same_v) { + __bang_half2float((float *)data, reinterpret_cast(data) + len, len); + SumInternal(result, (float *)data, len); + } else if constexpr (std::is_same_v) { + __bang_bfloat162float((float *)data, data + len, len); + SumInternal(result, (float *)data, len); + } else { + SumInternal(result, data, len); + } +} + +template +__mlu_func__ float Sum(const T *source, T *src, float *dst, int num_elements, + int max_batch) { /* Accumulates the sum of num_elements values in chunks of at most max_batch; a short chunk is preceded by a zero fill of the staging buffer, so the full-width reduction over max_batch slots adds only zeros for the unused tail. */ + float res = 0.0f; + int offset = (sizeof(T) == 2 ? 
max_batch : 0); + + size_t processed = 0; + while (processed < num_elements) { + size_t curr_batch = std::min(max_batch, num_elements - processed); + + if (curr_batch < max_batch) { + __bang_write_value(src, max_batch + offset, 0); + } + + __memcpy(src + offset, source + processed, curr_batch * sizeof(T), + GDRAM2NRAM); + SumTyped(dst, src, max_batch); + res += dst[0]; + processed += curr_batch; + } + + return res; +} + +template +__mlu_func__ float SumBatched(const T *source, T *src, float *dst, + int num_elements, int max_batch) { /* Inputs shorter than min_vector_size defer to Sum. Otherwise each chunk is reduced with SumInternal over its batch_size-aligned prefix and the remaining elements are added one by one. */ + constexpr int min_vector_size = 32; + + if (num_elements < min_vector_size) { + return Sum(source, src, dst, num_elements, max_batch); + } + + float res = 0.0f; + int offset = (sizeof(T) == 2 ? max_batch : 0); + + size_t processed = 0; + while (processed < num_elements) { + size_t curr_batch = std::min(max_batch, num_elements - processed); + size_t aligned_batch = (curr_batch / batch_size) * batch_size; + size_t remainder = curr_batch % batch_size; + + // Ensure NRAM buffer is zeroed. + __bang_write_value(src, max_batch + offset, 0); + + // Copy data to NRAM. 
+ __memcpy(src + offset, source + processed, curr_batch * sizeof(T), + GDRAM2NRAM); + + if constexpr (std::is_same_v) { + __bang_half2float((float *)(src + offset), + reinterpret_cast(src) + offset, curr_batch); /* NOTE(review): the float destination and the 16-bit source begin at the same address here, whereas SumTyped converts from the upper half into the lower half - confirm the intrinsic permits this overlap and that curr_batch floats fit in the buffer. */ + } else if constexpr (std::is_same_v) { + __bang_bfloat162float((float *)(src + offset), src + offset, curr_batch); + } + + if (aligned_batch > 0) { + SumInternal(dst, (float *)(src + offset), aligned_batch); + res += dst[0]; + } + if (remainder > 0) { + for (size_t i = aligned_batch; i < curr_batch; ++i) { + res += ((float *)(src + offset))[i]; + } + } + + processed += curr_batch; + } + + return res; +} + +__mlu_func__ void MaxInternal(float *dst, float *src, int max_batch) { + __bang_maxpool(dst, src, batch_size, 1, max_batch / batch_size, 1, + max_batch / batch_size, 1, 1); + __bang_argmax(dst, dst, batch_size); +} + +template +__mlu_func__ void MaxTyped(float *result, T *data, size_t len) { + if constexpr (std::is_same_v) { + __bang_half2float((float *)data, reinterpret_cast(data) + len, len); + MaxInternal(result, (float *)data, len); + } else if constexpr (std::is_same_v) { + __bang_bfloat162float((float *)data, data + len, len); + MaxInternal(result, (float *)data, len); + } else { + MaxInternal(result, data, len); + } +} + +template +__mlu_func__ float Max(const T *source, T *src, float *dst, int num_elements, + int max_batch) { + float max_val = -INFINITY; + int offset = (sizeof(T) == 2 ? 
max_batch : 0); + + size_t processed = 0; + while (processed < num_elements) { + size_t curr_batch = std::min(max_batch, num_elements - processed); + + if (curr_batch < max_batch) { + __bang_write_value(src, max_batch + offset, 0); /* NOTE(review): the padding is 0, and MaxTyped then reduces over all max_batch slots, so a chunk whose values are all negative would seemingly yield 0 - confirm; -INFINITY padding may be intended. */ + } + + __memcpy(src + offset, source + processed, curr_batch * sizeof(T), + GDRAM2NRAM); + MaxTyped(dst, src, max_batch); + max_val = std::max(max_val, dst[0]); + processed += curr_batch; + } + + return max_val; +} + +template +__mlu_func__ float MaxBatched(const T *source, T *src, float *dst, + int num_elements, int max_batch) { + constexpr int min_vector_size = 32; /* NOTE(review): below this threshold the call defers to Max, and the loop that follows appears to be statement-for-statement the same as the one in Max, so the split currently seems to change nothing - confirm whether an aligned and remainder path like SumBatched was intended. */ + + if (num_elements < min_vector_size) { + return Max(source, src, dst, num_elements, max_batch); + } + + float max_val = -INFINITY; + int offset = (sizeof(T) == 2 ? max_batch : 0); + + size_t processed = 0; + while (processed < num_elements) { + size_t curr_batch = std::min(max_batch, num_elements - processed); + + if (curr_batch < max_batch) { + __bang_write_value(src, max_batch + offset, 0); + } + + __memcpy(src + offset, source + processed, curr_batch * sizeof(T), + GDRAM2NRAM); + MaxTyped(dst, src, max_batch); + max_val = std::max(max_val, dst[0]); + processed += curr_batch; + } + + return max_val; +} + +} // namespace infini::ops::reduce #endif // __BANG__ @@ -63,8 +221,8 @@ inline cnnlDataType_t GetDataType(DataType dtype) { namespace infini::ops::cnrt_utils { -inline void GetLaunchConfig(const Device& device, int* core_per_cluster, - int* cluster_count) { +inline void GetLaunchConfig(const Device &device, int *core_per_cluster, + int *cluster_count) { int device_id = device.index(); cnrtDeviceGetAttribute(cluster_count, cnrtAttrClusterCount, device_id); cnrtDeviceGetAttribute(core_per_cluster, cnrtAttrMcorePerCluster, device_id); diff --git a/src/cambricon/rms_norm/rms_norm.h b/src/cambricon/rms_norm/rms_norm.h index 0e331dd..35a7730 100644 --- a/src/cambricon/rms_norm/rms_norm.h +++ b/src/cambricon/rms_norm/rms_norm.h @@ -5,7 +5,7 @@ #include #include -#include 
"../common.h" +#include "cambricon/common.h" #include "base/rms_norm.h" namespace infini::ops { @@ -41,8 +41,8 @@ class Operator : public RmsNorm { using WeightT = typename decltype(weight_tag)::type; RmsNormUnion( - workspace, core_per_cluster, cluster_count, queue, - out.data(), input.data(), weight.data(), out_shape_.data(), + workspace, core_per_cluster, cluster_count, queue, out.data(), + input.data(), weight.data(), out_shape_.data(), out_strides_.data(), input_strides_.data(), eps, ndim_); }, "CambriconRmsNorm::operator() - output dispatch"); diff --git a/tests/test_causal_softmax.py b/tests/test_causal_softmax.py index 8b35457..d7840b1 100644 --- a/tests/test_causal_softmax.py +++ b/tests/test_causal_softmax.py @@ -20,9 +20,9 @@ @pytest.mark.parametrize( ("dtype", "rtol", "atol"), ( - (torch.float32, 1e-5, 1e-5), - (torch.float16, 1e-2, 1e-2), - (torch.bfloat16, 1e-2, 1e-2), + (torch.float32, 1e-5, 3e-5), + (torch.float16, 1e-2, 1e-3), + (torch.bfloat16, 5e-2, 5e-3), ), ) def test_causal_softmax(shape, input_strides, out_strides, dtype, device, rtol, atol):