Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions water/include/water/Dialect/Wave/IR/WaveAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ def WaveHyperparameterAttr : AttrDef<WaveDialect, "WaveHyperparameter"> {

let parameters = (ins "::mlir::DictionaryAttr":$mapping);
let assemblyFormat = "`<` $mapping `>`";
let genVerifyDecl = 1;

let extraClassDeclaration = [{
/// Get the concrete value for a symbol name, returns std::nullopt if not found
Expand Down
27 changes: 22 additions & 5 deletions water/lib/Dialect/Wave/IR/WaveAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,12 @@ WaveIndexMappingAttr WaveIndexMappingAttr::removeInput(Attribute input) const {
// WaveHyperparameterAttr
//===----------------------------------------------------------------------===//

/// Returns true if `attr` is a signless i64 IntegerAttr.
static bool isI64IntegerAttr(Attribute attr) {
auto intAttr = dyn_cast<IntegerAttr>(attr);
return intAttr && intAttr.getType().isSignlessInteger(64);
}

std::optional<int64_t>
WaveHyperparameterAttr::getSymbolValue(StringRef symbolName) const {
DictionaryAttr mapping = getMapping();
Expand All @@ -500,6 +506,18 @@ bool WaveHyperparameterAttr::hasSymbol(StringRef symbolName) const {
return getMapping().get(symbolName) != nullptr;
}

LogicalResult
WaveHyperparameterAttr::verify(function_ref<InFlightDiagnostic()> emitError,
DictionaryAttr mapping) {
for (NamedAttribute attr : mapping) {
if (!isI64IntegerAttr(attr.getValue()))
return emitError() << "hyperparameter '" << attr.getName().getValue()
<< "' must be an i64 integer value, got "
<< attr.getValue();
}
return success();
}

//===----------------------------------------------------------------------===//
// WaveSymbolAttr
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -718,11 +736,10 @@ LogicalResult HardwareConstraintAttr::verify(
if (vectorShapes) {
for (NamedAttribute attr : vectorShapes) {
// TODO: verify that attr.getName() is a valid WaveSymbol
Attribute value = attr.getValue();

if (!isa<IntegerAttr>(value))
return emitError() << attr.getName()
<< " is not an IntegerAttr: " << attr.getValue();
if (!isI64IntegerAttr(attr.getValue()))
return emitError() << "vector_shapes entry '" << attr.getName().getValue()
<< "' must be an i64 integer value, got "
<< attr.getValue();
}
}

Expand Down
10 changes: 9 additions & 1 deletion water/test/Dialect/Wave/attr-constraint-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ func.func private @test_num_dimensions_mismatch1() attributes { wave.constraints

// -----

// expected-error @below {{"M" is not an IntegerAttr: "BLOCK_M"}}
// expected-error @below {{vector_shapes entry 'M' must be an i64 integer value, got "BLOCK_M"}}
#hw_constraint = #wave.hardware_constraint<threads_per_wave = 64,
waves_per_block = [1, 1, 1],
mma_type = #wave.mma_kind<f32_16x16x16_f16>,
Expand All @@ -18,6 +18,14 @@ func.func private @test_num_dimensions_mismatch2() attributes { wave.constraints

// -----

// expected-error @below {{vector_shapes entry 'M' must be an i64 integer value, got 1 : i32}}
#hw_constraint = #wave.hardware_constraint<threads_per_wave = 64,
waves_per_block = [1, 1, 1],
vector_shapes = {M = 1 : i32, N = 64}>
func.func private @test_vector_shapes_non_i64() attributes { wave.constraints = [#hw_constraint] }

// -----

#hw_constraint = #wave.hardware_constraint<threads_per_wave = 64>
#hw_constraint2 = #wave.hardware_constraint<threads_per_wave = 32>
// expected-error @below {{only one hardware constraint is allowed}}
Expand Down
5 changes: 5 additions & 0 deletions water/test/Dialect/Wave/ops-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,11 @@ module attributes { wave.hyperparameters = #wave.hyperparameters<{}> } {

// -----

// expected-error @below {{hyperparameter 'A' must be an i64 integer value, got 42 : i32}}
module attributes { wave.hyperparameters = #wave.hyperparameters<{A = 42 : i32}> } {}

// -----

module attributes { wave.hyperparameters = #wave.hyperparameters<{A = 42, C = 43}> } {
// expected-error @below {{region #0 block #0 argument type #0 uses symbolic value #wave.symbol<"B"> not provided as a hyperparameter}}
// expected-note @below {{available symbols: A, C}}
Expand Down
Loading