Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions commonir/src/target/codegen_commonir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,36 @@ bool AllZero(Array<PrimExpr> a) {
return true;
}

bool IsStaticUnitExtent(const PrimExpr &expr) {
if (const int64_t *expr_int = as_const_int(expr)) {
return *expr_int == 1;
}
return false;
}

Array<PrimExpr> GetSubviewResultShape(Array<PrimExpr> shape) {
Array<PrimExpr> result_shape;
for (const PrimExpr &extent : shape) {
if (!IsStaticUnitExtent(extent)) {
result_shape.push_back(extent);
}
}
return result_shape;
}

Array<PrimExpr> GetSubviewResultStride(Array<PrimExpr> shape,
Array<PrimExpr> stride) {
ICHECK(stride.empty() || stride.size() == shape.size())
<< "Subview shape and stride rank mismatch";
Array<PrimExpr> result_stride;
for (int i = 0; i < shape.size(); ++i) {
if (!IsStaticUnitExtent(shape[i]) && !stride.empty()) {
result_stride.push_back(stride[i]);
}
}
return result_stride;
}

std::vector<unsigned long> GetStrideFromShape(Array<tvm::PrimExpr> shape) {
std::vector<unsigned long> strides;
unsigned long total_size = 1;
Expand Down Expand Up @@ -720,10 +750,13 @@ String CodeGenTileLangCOMMONIR::GenSubviewFromRegion(Buffer buffer_data,
Array<String> cast_offset_array = GenConvertIndex(region_indeces);
Array<String> cast_shape_array = GenConvertIndex(region_shape);
unsigned long offset = ComputeOffset(src_memref, region_indeces);
Array<PrimExpr> result_shape = GetSubviewResultShape(region_shape);
Array<PrimExpr> result_stride =
GetSubviewResultStride(region_shape, src_memref->stride);
new_buffer_name = buffer_name_val + "_subview";
auto tempMemref = new Memref(new_buffer_name, region_shape, buffer_type,
auto tempMemref = new Memref(new_buffer_name, result_shape, buffer_type,
src_memref->address_space, offset == -1,
src_memref->stride, offset);
result_stride, offset);
String dst_data_info = GetMemrefInfo(tempMemref);
temp << "memref.subview \%" + buffer_name_val;
temp << "[";
Expand Down Expand Up @@ -757,8 +790,8 @@ String CodeGenTileLangCOMMONIR::GenSubviewFromRegion(Buffer buffer_data,
delete tempMemref;
new_buffer_name = SSAGetID(temp.str(), buffer_type);
this->type_info[new_buffer_name] = new Memref(
new_buffer_name, region_shape, buffer_type, src_memref->address_space,
offset == -1, src_memref->stride, offset);
new_buffer_name, result_shape, buffer_type, src_memref->address_space,
offset == -1, result_stride, offset);
}
return new_buffer_name;
}
Expand Down
Loading