diff --git a/water/include/water/Dialect/Wave/IR/WaveAttrs.td b/water/include/water/Dialect/Wave/IR/WaveAttrs.td index a459c94d2d..bcf18e0188 100644 --- a/water/include/water/Dialect/Wave/IR/WaveAttrs.td +++ b/water/include/water/Dialect/Wave/IR/WaveAttrs.td @@ -506,6 +506,7 @@ def WaveHyperparameterAttr : AttrDef { let parameters = (ins "::mlir::DictionaryAttr":$mapping); let assemblyFormat = "`<` $mapping `>`"; + let genVerifyDecl = 1; let extraClassDeclaration = [{ /// Get the concrete value for a symbol name, returns std::nullopt if not found diff --git a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp index 6f08d6d78f..be11353aa4 100644 --- a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp +++ b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp @@ -483,6 +483,12 @@ WaveIndexMappingAttr WaveIndexMappingAttr::removeInput(Attribute input) const { // WaveHyperparameterAttr //===----------------------------------------------------------------------===// +/// Returns true if `attr` is a signless i64 IntegerAttr. +static bool isI64IntegerAttr(Attribute attr) { + auto intAttr = dyn_cast(attr); + return intAttr && intAttr.getType().isSignlessInteger(64); +} + std::optional WaveHyperparameterAttr::getSymbolValue(StringRef symbolName) const { DictionaryAttr mapping = getMapping(); @@ -500,6 +506,18 @@ bool WaveHyperparameterAttr::hasSymbol(StringRef symbolName) const { return getMapping().get(symbolName) != nullptr; } +LogicalResult +WaveHyperparameterAttr::verify(function_ref emitError, + DictionaryAttr mapping) { + for (NamedAttribute attr : mapping) { + if (!isI64IntegerAttr(attr.getValue())) + return emitError() << "hyperparameter '" << attr.getName().getValue() + << "' must be an i64 integer value, got " + << attr.getValue(); + } + return success(); +} + //===----------------------------------------------------------------------===// // WaveSymbolAttr //===----------------------------------------------------------------------===// @@ -718,11 +736,10 @@ LogicalResult HardwareConstraintAttr::verify( if (vectorShapes) { for (NamedAttribute attr : vectorShapes) { // TODO: verify that attr.getName() is a valid WaveSymbol - Attribute value = attr.getValue(); - - if (!isa(value)) - return emitError() << attr.getName() - << " is not an IntegerAttr: " << attr.getValue(); + if (!isI64IntegerAttr(attr.getValue())) + return emitError() << "vector_shapes entry '" << attr.getName().getValue() + << "' must be an i64 integer value, got " + << attr.getValue(); } } diff --git a/water/test/Dialect/Wave/attr-constraint-invalid.mlir b/water/test/Dialect/Wave/attr-constraint-invalid.mlir index 127b0d9f4e..90fd5889d9 100644 --- a/water/test/Dialect/Wave/attr-constraint-invalid.mlir +++ b/water/test/Dialect/Wave/attr-constraint-invalid.mlir @@ -9,7 +9,7 @@ func.func private @test_num_dimensions_mismatch1() attributes { wave.constraints // ----- -// expected-error @below {{"M" is not an IntegerAttr: "BLOCK_M"}} +// expected-error @below {{vector_shapes entry 'M' must be an i64 integer value, got "BLOCK_M"}} #hw_constraint = #wave.hardware_constraint, @@ -18,6 +18,14 @@ func.func private @test_num_dimensions_mismatch2() attributes { wave.constraints // ----- +// expected-error @below {{vector_shapes entry 'M' must be an i64 integer value, got 1 : i32}} +#hw_constraint = #wave.hardware_constraint +func.func private @test_vector_shapes_non_i64() attributes { wave.constraints = [#hw_constraint] } + +// ----- + #hw_constraint = #wave.hardware_constraint #hw_constraint2 = #wave.hardware_constraint // expected-error @below {{only one hardware constraint is allowed}} diff --git a/water/test/Dialect/Wave/ops-invalid.mlir b/water/test/Dialect/Wave/ops-invalid.mlir index f661d4ffe4..20e3d8be7f 100644 --- a/water/test/Dialect/Wave/ops-invalid.mlir +++ b/water/test/Dialect/Wave/ops-invalid.mlir @@ -481,6 +481,11 @@ module attributes { wave.hyperparameters = #wave.hyperparameters<{}> } { // ----- +// expected-error @below {{hyperparameter 'A' must be an i64 integer value, got 42 : i32}} +module attributes { wave.hyperparameters = #wave.hyperparameters<{A = 42 : i32}> } {} + +// ----- + module attributes { wave.hyperparameters = #wave.hyperparameters<{A = 42, C = 43}> } { // expected-error @below {{region #0 block #0 argument type #0 uses symbolic value #wave.symbol<"B"> not provided as a hyperparameter}} // expected-note @below {{available symbols: A, C}}