Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/Kernel/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ cc_test(
":KernelImplementation",
":RotationCountVisitor",
"@googletest//:gtest_main",
"@heir//lib/Utils/Layout:Convolution",
"@heir//lib/Utils/Layout:Evaluate",
"@heir//lib/Utils/Layout:Utils",
"@llvm-project//mlir:IR",
Expand Down
1 change: 1 addition & 0 deletions lib/Kernel/KernelImplementationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "lib/Kernel/KernelImplementation.h"
#include "lib/Kernel/KernelName.h"
#include "lib/Kernel/RotationCountVisitor.h"
#include "lib/Utils/Layout/Convolution.h"
#include "lib/Utils/Layout/Evaluate.h"
#include "lib/Utils/Layout/Utils.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
Expand Down
1 change: 1 addition & 0 deletions lib/Transforms/ConvertToCiphertextSemantics/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ cc_library(
"@heir//lib/Utils:ContextAwareTypeConversion",
"@heir//lib/Utils:MathUtils",
"@heir//lib/Utils/Layout:Codegen",
"@heir//lib/Utils/Layout:Convolution",
"@heir//lib/Utils/Layout:Utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "lib/Utils/ContextAwareDialectConversion.h"
#include "lib/Utils/ContextAwareTypeConversion.h"
#include "lib/Utils/Layout/Codegen.h"
#include "lib/Utils/Layout/Convolution.h"
#include "lib/Utils/Layout/Utils.h"
#include "lib/Utils/MathUtils.h"
#include "lib/Utils/Utils.h"
Expand Down
1 change: 1 addition & 0 deletions lib/Transforms/LayoutPropagation/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ cc_library(
"@heir//lib/Dialect/TensorExt/Transforms:Patterns",
"@heir//lib/Kernel",
"@heir//lib/Utils:AttributeUtils",
"@heir//lib/Utils/Layout:Convolution",
"@heir//lib/Utils/Layout:Hoisting",
"@heir//lib/Utils/Layout:Utils",
"@llvm-project//llvm:Support",
Expand Down
1 change: 1 addition & 0 deletions lib/Transforms/LayoutPropagation/LayoutPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "lib/Kernel/KernelName.h"
#include "lib/Transforms/LayoutPropagation/Utils.h"
#include "lib/Utils/AttributeUtils.h"
#include "lib/Utils/Layout/Convolution.h"
#include "lib/Utils/Layout/Hoisting.h"
#include "lib/Utils/Layout/Utils.h"
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
Expand Down
32 changes: 32 additions & 0 deletions lib/Utils/Layout/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ cc_test(
srcs = ["CodegenTest.cpp"],
deps = [
":Codegen",
":Convolution",
"@googletest//:gtest_main",
"@heir//lib/Utils/Layout:IslConversion",
"@heir//lib/Utils/Layout:Utils",
Expand Down Expand Up @@ -58,6 +59,18 @@ cc_library(
],
)

cc_library(
name = "Convolution",
srcs = ["Convolution.cpp"],
hdrs = ["Convolution.h"],
deps = [
":Utils",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)

cc_library(
name = "Evaluate",
srcs = ["Evaluate.cpp"],
Expand All @@ -77,6 +90,7 @@ cc_test(
name = "EvaluateTest",
srcs = ["EvaluateTest.cpp"],
deps = [
":Convolution",
":Evaluate",
":IslConversion",
":Utils",
Expand All @@ -91,6 +105,7 @@ cc_test(
name = "UtilsTest",
srcs = ["UtilsTest.cpp"],
deps = [
":Convolution",
":Evaluate",
":IslConversion",
":Utils",
Expand All @@ -104,6 +119,23 @@ cc_test(
],
)

cc_test(
name = "ConvolutionTest",
srcs = ["ConvolutionTest.cpp"],
deps = [
":Convolution",
":Evaluate",
":IslConversion",
":Utils",
"@googletest//:gtest_main",
"@heir//lib/Utils:TensorUtils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)

cc_library(
name = "Hoisting",
srcs = ["Hoisting.cpp"],
Expand Down
1 change: 1 addition & 0 deletions lib/Utils/Layout/CodegenTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "gmock/gmock.h" // from @googletest
#include "gtest/gtest.h" // from @googletest
#include "lib/Utils/Layout/Codegen.h"
#include "lib/Utils/Layout/Convolution.h"
#include "lib/Utils/Layout/IslConversion.h"
#include "lib/Utils/Layout/Utils.h"
#include "mlir/include/mlir/Analysis/Presburger/IntegerRelation.h" // from @llvm-project
Expand Down
260 changes: 260 additions & 0 deletions lib/Utils/Layout/Convolution.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
#include "lib/Utils/Layout/Convolution.h"

#include <cassert>
#include <cstdint>

#include "lib/Utils/Layout/Utils.h"
#include "mlir/include/mlir/Analysis/Presburger/IntegerRelation.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {

using presburger::BoundType;
using presburger::IntegerRelation;
using presburger::PresburgerSpace;
using presburger::VarKind;

presburger::IntegerRelation get2dConvFilterRelation(RankedTensorType filterType,
RankedTensorType dataType,
ArrayRef<int64_t> strides,
int64_t padding) {
auto domainSize = filterType.getRank();
assert(domainSize == 2 && "expected 2-D filter matrix");

IntegerRelation result(PresburgerSpace::getRelationSpace(
domainSize, /*numRange=*/2, /*numSymbol=*/0, /*numLocals=*/2));

// Filter row and column indices
auto filterRow = result.getVarKindOffset(VarKind::Domain);
auto filterCol = result.getVarKindOffset(VarKind::Domain) + 1;

// Matrix row and column indices
auto matRow = result.getVarKindOffset(VarKind::Range);
auto matCol = result.getVarKindOffset(VarKind::Range) + 1;

// Constant coefficient
auto constCoeff = result.getNumCols() - 1;

// Filter, datasize, and strides.
auto filterRowSize = filterType.getDimSize(0);
auto filterColSize = filterType.getDimSize(1);
auto dataRowSize = dataType.getDimSize(0);
auto dataColSize = dataType.getDimSize(1);
auto strideRow = strides[0];
auto strideCol = strides[1];

// These are the indices that represent the valid positions that the filter
// can move over the data. (0, 0) is the first position of (slidingRow,
// slidingCol).
auto slidingRow = result.getVarKindOffset(VarKind::Local);
auto slidingCol = result.getVarKindOffset(VarKind::Local) + 1;

// The maximum values for the sliding window indices.
auto slidingRowSize =
(dataRowSize + 2 * padding - filterRowSize) / strideRow + 1;
auto slidingColSize =
(dataColSize + 2 * padding - filterColSize) / strideCol + 1;

// Add bounds for the filter matrix dimensions.
addBounds(result, filterRow, 0, filterRowSize - 1);
addBounds(result, filterCol, 0, filterColSize - 1);

// Add bounds for the sliding window indices.
addBounds(result, slidingRow, 0, slidingRowSize - 1);
addBounds(result, slidingCol, 0, slidingColSize - 1);

// Define (dataRow, dataCol) to be the position on the data tensor for a given
// filter position (slidingRow, slidingCol) and a given filter index
// (filterRow, filterCol). E.g. the top left corner of the filter is at
// (filterRow, filterCol) = (0, 0) and the first position of the filter is at
// (slidingRow, slidingCol) = (0, 0). This corresponds to (-padding, -padding)
// on the data indices (dataRow, dataCol).
// dataRow = (slidingRow * strideRow - padding) + filterRow
// dataCol = (slidingCol * strideCol - padding) + filterCol

// Add constraints for when the filter sliding window index is at a valid
// data position. Require:
// 0 <= dataRow < dataRowSize and 0 <= dataCol < dataColSize.
// Substituting the expressions gives:
// 0 <= slidingRow * strideRow - padding + filterRow < dataRowSize
addConstraint(
result, {{slidingRow, strideRow}, {filterRow, 1}, {constCoeff, -padding}},
/*equality=*/false);
addConstraint(result,
{{constCoeff, dataRowSize + padding - 1},
{slidingRow, -strideRow},
{filterRow, -1}},
/*equality=*/false);

// 0 <= slidingCol * strideCol - padding + filterCol < dataColSize
addConstraint(
result, {{slidingCol, strideCol}, {filterCol, 1}, {constCoeff, -padding}},
/*equality=*/false);
addConstraint(result,
{{constCoeff, dataColSize + padding - 1},
{slidingCol, -strideCol},
{filterCol, -1}},
/*equality=*/false);

// Add equalities for the resulting matrix row and column. Each matrix row
// corresponds to one sliding window of the filter over the data. So flatten
// the filter sliding window indices (slidingRow, slidingCol):
// matRow = slidingRow * slidingColSize + slidingCol
addConstraint(result,
{{matRow, -1}, {slidingRow, slidingColSize}, {slidingCol, 1}},
/*equality=*/true);

// The matrix column is the flattened data indices:
// matCol = dataRow * dataColSize + dataCol
// matCol = (slidingRow * strideRow - padding + filterRow) * dataColSize +
// (slidingCol * strideCol - padding + filterCol)
// matCol = slidingRow * strideRow * dataColSize - padding * dataColSize +
// filterRow * dataColSize + slidingCol * strideCol - padding +
// filterCol
addConstraint(result,
{{matCol, -1},
{slidingRow, strideRow * dataColSize},
{slidingCol, strideCol},
{filterRow, dataColSize},
{filterCol, 1},
{constCoeff, -padding * dataColSize - padding}},
/*equality=*/true);
return result;
}

RankedTensorType get2dConvFilterExpandedType(RankedTensorType filterType,
RankedTensorType dataType,
int64_t padding,
ArrayRef<int64_t> strides) {
auto filterRowSize = filterType.getDimSize(0);
auto filterColSize = filterType.getDimSize(1);
auto dataRowSize = dataType.getDimSize(0);
auto dataColSize = dataType.getDimSize(1);
auto strideRow = strides[0];
auto strideCol = strides[1];

// Number of rows will be the filter sliding rows * filter sliding columns.
int64_t filterSlidingRows =
(dataRowSize + 2 * padding - filterRowSize) / strideRow + 1;
int64_t filterSlidingCols =
(dataColSize + 2 * padding - filterColSize) / strideCol + 1;
int64_t rows = filterSlidingRows * filterSlidingCols;

// Number of columns will be the number of data elements.
int64_t cols = dataType.getNumElements();

return RankedTensorType::get({rows, cols}, filterType.getElementType());
}

presburger::IntegerRelation get2dConvChwFchwFilterRelation(
RankedTensorType filterType, RankedTensorType dataType,
ArrayRef<int64_t> strides, int64_t padding) {
assert(filterType.getRank() == 4 && "expected 4-D filter matrix");
assert(dataType.getRank() == 3 && "expected 3-D data matrix");

// Get the filter relation for a single input and output channel.
RankedTensorType singleFilterType = RankedTensorType::get(
{filterType.getDimSize(2), filterType.getDimSize(3)},
filterType.getElementType());
RankedTensorType singleDataType =
RankedTensorType::get({dataType.getDimSize(1), dataType.getDimSize(2)},
dataType.getElementType());
auto singleFilterRelation = get2dConvFilterRelation(
singleFilterType, singleDataType, strides, padding);

// Map the single filter relation into the multi-channel matrix. Each single
// filter is offset into the result by adding (c * totalRowSize, f *
// totalColSize) to the range dimensions.

// First, add (f, c) to the domain vars and set bounds
singleFilterRelation.insertVar(VarKind::Domain, 0, 2);
auto fDim = singleFilterRelation.getVarKindOffset(VarKind::Domain);
auto cDim = singleFilterRelation.getVarKindOffset(VarKind::Domain) + 1;

auto inputChannels = dataType.getDimSize(0);
auto outputChannels = filterType.getDimSize(0);
assert(inputChannels == filterType.getDimSize(1) &&
"input channels must match filter input channels");
addBounds(singleFilterRelation, fDim, 0, outputChannels - 1);
addBounds(singleFilterRelation, cDim, 0, inputChannels - 1);

// Expand the range vars so that we can compose with the embedding relation.
singleFilterRelation.insertVar(VarKind::Range, 0, 2);
// (embedRow, embedCol) = position in the embedded matrix.
// (singleRow, singleCol) = position in the single filter matrix.
auto embedRow = singleFilterRelation.getVarKindOffset(VarKind::Range);
auto embedCol = singleFilterRelation.getVarKindOffset(VarKind::Range) + 1;
auto singleRow = singleFilterRelation.getVarKindOffset(VarKind::Range) + 2;
auto singleCol = singleFilterRelation.getVarKindOffset(VarKind::Range) + 3;

// embedRow = fDim * totalRowSize + singleRow
// embedCol = cDim * totalColSize + singleCol
auto singleResultType = get2dConvFilterExpandedType(
singleFilterType, singleDataType, padding, strides);
auto totalRowSize = singleResultType.getDimSize(0);
auto totalColSize = singleResultType.getDimSize(1);

addConstraint(singleFilterRelation,
{{embedRow, 1}, {fDim, -totalRowSize}, {singleRow, -1}}, true);
addConstraint(singleFilterRelation,
{{embedCol, 1}, {cDim, -totalColSize}, {singleCol, -1}}, true);
// Project out the single filter relation range vars.
singleFilterRelation.projectOut(singleRow, 2);

return singleFilterRelation;
}

FailureOr<presburger::IntegerRelation> get2dConvFilterDiagonalizedRelation(
RankedTensorType filterType, RankedTensorType dataType, int64_t padding,
int64_t ciphertextSize) {
SmallVector<int64_t> strides = {1, 1};
auto expandedFilterRelation =
get2dConvFilterRelation(filterType, dataType, strides, padding);
// Get size of the expanded filter matrix.
auto rowBound = expandedFilterRelation.getConstantBound64(
BoundType::UB, expandedFilterRelation.getVarKindOffset(VarKind::Range));
if (!rowBound.has_value()) {
return failure();
}
auto colBound = expandedFilterRelation.getConstantBound64(
BoundType::UB,
expandedFilterRelation.getVarKindOffset(VarKind::Range) + 1);
if (!colBound.has_value()) {
return failure();
}
RankedTensorType expandedFilterType =
RankedTensorType::get({rowBound.value() + 1, colBound.value() + 1},
filterType.getElementType());

auto diagonalizedFilterRelation =
getDiagonalLayoutRelation(expandedFilterType, ciphertextSize);

// Compose these relations.
expandedFilterRelation.compose(diagonalizedFilterRelation);
return expandedFilterRelation;
}

bool isRelation2dConvFilterDiagonalized(
RankedTensorType filterType, RankedTensorType dataType, int64_t padding,
int64_t ciphertextSize, const presburger::IntegerRelation& relation) {
auto diagonalizedRelation = get2dConvFilterDiagonalizedRelation(
filterType, dataType, padding, ciphertextSize);
if (failed(diagonalizedRelation)) {
return false;
}
bool fastCheck = relation.isObviouslyEqual(diagonalizedRelation.value());
if (fastCheck) return true;

LogicalResult inequalityTest =
tryProveUnequal(diagonalizedRelation.value(), relation);
if (succeeded(inequalityTest)) return false;

bool slowCheck = relation.isEqual(diagonalizedRelation.value());
return slowCheck;
}

} // namespace heir
} // namespace mlir
Loading
Loading