From a1047a3e8a7f052ae76c38db4f7f6066bb092bd4 Mon Sep 17 00:00:00 2001 From: zhaochaoxing Date: Wed, 25 Mar 2026 08:05:11 +0000 Subject: [PATCH] fix commonir --- commonir/src/target/codegen_commonir.cc | 41 ++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/commonir/src/target/codegen_commonir.cc b/commonir/src/target/codegen_commonir.cc index a8d2f2d7..1687898f 100644 --- a/commonir/src/target/codegen_commonir.cc +++ b/commonir/src/target/codegen_commonir.cc @@ -96,6 +96,36 @@ bool AllZero(Array a) { return true; } +bool IsStaticUnitExtent(const PrimExpr &expr) { + if (const int64_t *expr_int = as_const_int(expr)) { + return *expr_int == 1; + } + return false; +} + +Array GetSubviewResultShape(Array shape) { + Array result_shape; + for (const PrimExpr &extent : shape) { + if (!IsStaticUnitExtent(extent)) { + result_shape.push_back(extent); + } + } + return result_shape; +} + +Array GetSubviewResultStride(Array shape, + Array stride) { + ICHECK(stride.empty() || stride.size() == shape.size()) + << "Subview shape and stride rank mismatch"; + Array result_stride; + for (int i = 0; i < shape.size(); ++i) { + if (!IsStaticUnitExtent(shape[i]) && !stride.empty()) { + result_stride.push_back(stride[i]); + } + } + return result_stride; +} + std::vector GetStrideFromShape(Array shape) { std::vector strides; unsigned long total_size = 1; @@ -720,10 +750,13 @@ String CodeGenTileLangCOMMONIR::GenSubviewFromRegion(Buffer buffer_data, Array cast_offset_array = GenConvertIndex(region_indeces); Array cast_shape_array = GenConvertIndex(region_shape); unsigned long offset = ComputeOffset(src_memref, region_indeces); + Array result_shape = GetSubviewResultShape(region_shape); + Array result_stride = + GetSubviewResultStride(region_shape, src_memref->stride); new_buffer_name = buffer_name_val + "_subview"; - auto tempMemref = new Memref(new_buffer_name, region_shape, buffer_type, + auto tempMemref = new Memref(new_buffer_name, result_shape, buffer_type, src_memref->address_space, offset == -1, - src_memref->stride, offset); + result_stride, offset); String dst_data_info = GetMemrefInfo(tempMemref); temp << "memref.subview \%" + buffer_name_val; temp << "["; @@ -757,8 +790,8 @@ String CodeGenTileLangCOMMONIR::GenSubviewFromRegion(Buffer buffer_data, delete tempMemref; new_buffer_name = SSAGetID(temp.str(), buffer_type); this->type_info[new_buffer_name] = new Memref( - new_buffer_name, region_shape, buffer_type, src_memref->address_space, - offset == -1, src_memref->stride, offset); + new_buffer_name, result_shape, buffer_type, src_memref->address_space, + offset == -1, result_stride, offset); } return new_buffer_name; }