From 1e48844c3255fc7a35afe557c3e56903cdb21fb7 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Thu, 19 Oct 2023 22:21:37 +0530 Subject: [PATCH] Revert "[TOSA] Add StatefulOps to TOSA Dialect (#66843)" This reverts commit af972f01c01843a9ffe41ff496154267fa387a51. --- .../Conversion/TosaToLinalg/TosaToLinalg.h | 4 +- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 5 - .../mlir/Dialect/Tosa/IR/TosaUtilOps.td | 67 -------------- .../mlir/Dialect/Tosa/Transforms/Passes.h | 3 + .../mlir/Dialect/Tosa/Transforms/Passes.td | 3 +- .../TosaToLinalg/TosaToLinalgPass.cpp | 5 +- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 43 --------- .../Tosa/Transforms/TosaValidation.cpp | 92 ++----------------- mlir/test/Dialect/Tosa/invalid.mlir | 45 --------- mlir/test/Dialect/Tosa/variables.mlir | 33 ------- 10 files changed, 19 insertions(+), 281 deletions(-) delete mode 100644 mlir/test/Dialect/Tosa/variables.mlir diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h index c411010603ac61..d8d4027500f99c 100644 --- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -35,8 +35,8 @@ std::unique_ptr createTosaToLinalgNamed(); void addTosaToLinalgPasses( OpPassManager &pm, const TosaToLinalgOptions &options, // Note: Default to 'none' level unless otherwise specified. - tosa::TosaValidationOptions const &validationOptions = { - tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None}); + tosa::ValidationOptions const &validationOptions = + tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None)); /// Populates conversion passes from TOSA dialect to Linalg dialect. void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns); diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h index a9bc3351f4cff0..555d9bea18ba4d 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -34,11 +34,6 @@ class PatternRewriter; namespace tosa { -ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, - Attribute &attr); -void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, - Attribute attr); - #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc" } // namespace tosa diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td index f9f25da1b649de..d75f5dffa8716c 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td @@ -79,71 +79,4 @@ def Tosa_YieldOp : Tosa_Op<"yield", [ let assemblyFormat = "$inputs attr-dict `:` type($inputs)"; } -//===----------------------------------------------------------------------===// -// Operator: variable -//===----------------------------------------------------------------------===// -def Tosa_VariableOp : Tosa_Op<"variable", []> { - let summary = "Defines a variable"; - - let description = [{ - Defines a new TOSA variable. This is a mutable value. - Modifications are expressed using read/write semantics. - }]; - - let arguments = (ins - SymbolNameAttr:$name, - TypeAttr:$type, - OptionalAttr:$initial_value - ); - - let assemblyFormat = [{ - $name - attr-dict - custom($type, $initial_value) - }]; -} - -//===----------------------------------------------------------------------===// -// Operator: variable.write -//===----------------------------------------------------------------------===// -def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> { - let summary = "write_buffer operator"; - - let description = [{ - Assigns a value to pseudo-buffer resource holding a mutable tensor. - }]; - - let arguments = (ins - SymbolNameAttr:$name, - AnyType:$value - ); - - let assemblyFormat = [{ - $name attr-dict `,` $value `:` type($value) - }]; -} - -//===----------------------------------------------------------------------===// -// Operator: variable.read -//===----------------------------------------------------------------------===// -def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> { - let summary = "read_buffer operator"; - - let description = [{ - Reads the value from a pseudo-buffer resource holding a mutable tensor. - }]; - - let arguments = (ins - SymbolNameAttr:$name - ); - - let results = (outs - AnyType:$value - ); - - let assemblyFormat = [{ - $name attr-dict `:` type($value) - }]; -} - #endif // TOSA_UTIL_OPS diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index fbfc56dfe2cf4f..940aed107e2f91 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -68,6 +68,9 @@ struct ValidationOptions { } }; +std::unique_ptr createTosaValidationPass( + ValidationOptions const &options = ValidationOptions()); + #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index a0f670de20150f..ac100a6d75c7c0 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -89,12 +89,13 @@ def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level", let cppNamespace = "mlir::tosa"; } -def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { +def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> { let summary = "Validates TOSA dialect"; let description = [{ This pass validates if input TOSA operations match the specification for given criteria, e.g. TOSA profile. }]; + let constructor = "createTosaValidationPass()"; let options = [ Option<"profile", "profile", "mlir::tosa::TosaProfileEnum", diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index 3c54f85b033b0b..718e34ced8d7e7 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -76,7 +76,7 @@ std::unique_ptr mlir::tosa::createTosaToLinalg() { void mlir::tosa::addTosaToLinalgPasses( OpPassManager &pm, const TosaToLinalgOptions &options, - tosa::TosaValidationOptions const &validationOptions) { + tosa::ValidationOptions const &validationOptions) { // Optional decompositions are designed to benefit linalg. if (!options.disableTosaDecompositions) pm.addNestedPass(tosa::createTosaOptionalDecompositions()); @@ -90,6 +90,7 @@ void mlir::tosa::addTosaToLinalgPasses( pm.addNestedPass(tosa::createTosaLayerwiseConstantFoldPass( {options.aggressiveReduceConstant})); pm.addNestedPass(tosa::createTosaMakeBroadcastablePass()); - pm.addNestedPass(tosa::createTosaValidation(validationOptions)); + pm.addNestedPass( + tosa::createTosaValidationPass(validationOptions)); pm.addNestedPass(tosa::createTosaToLinalg()); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index ff34183f9a030a..6db04fe38bcd35 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -146,49 +146,6 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value, return nullptr; } -//===----------------------------------------------------------------------===// -// Parsers and printers -//===----------------------------------------------------------------------===// - -ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, - Attribute &attr) { - if (succeeded(parser.parseOptionalEqual())) { - if (failed(parser.parseAttribute(attr))) { - return parser.emitError(parser.getCurrentLocation()) - << "expected attribute"; - } - if (auto typedAttr = attr.dyn_cast()) { - typeAttr = TypeAttr::get(typedAttr.getType()); - } - return success(); - } - - Type type; - if (failed(parser.parseColonType(type))) { - return parser.emitError(parser.getCurrentLocation()) << "expected type"; - } - typeAttr = TypeAttr::get(type); - - return success(); -} - -void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, - Attribute attr) { - bool needsSpace = false; - auto typedAttr = attr.dyn_cast_or_null(); - if (!typedAttr || typedAttr.getType() != type.getValue()) { - p << ": "; - p.printAttribute(type); - needsSpace = true; // subsequent attr value needs a space separator - } - if (attr) { - if (needsSpace) - p << ' '; - p << "= "; - p.printAttribute(attr); - } -} - //===----------------------------------------------------------------------===// // TOSA Operator Verifiers. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index d686ce125c1351..52885e69c3924f 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -14,9 +14,6 @@ #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc" -#include -#include - #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/Builders.h" @@ -99,13 +96,12 @@ static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0}; struct TosaValidation : public tosa::impl::TosaValidationBase { public: explicit TosaValidation() { populateConstantOperandChecks(); } - explicit TosaValidation(const TosaValidationOptions &options) - : TosaValidation() { + explicit TosaValidation(const ValidationOptions &options) : TosaValidation() { this->profile = options.profile; - this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment; + this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment; this->level = options.level; } - void runOnOperation() final; + void runOnOperation() override; LogicalResult applyConstantOperandCheck(Operation *op) { for (auto &checker : const_checkers) { @@ -117,9 +113,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { LogicalResult applyLevelCheck(Operation *op); - // check variable read/write data types against variable declarations - LogicalResult applyVariableCheck(Operation *op); - private: void populateConstantOperandChecks() { const_checkers.emplace_back(checkConstantOperandPad); @@ -405,12 +398,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { } } - bool CheckVariable(Operation *op); - bool CheckVariableReadOrWrite(Operation *op); - SmallVector> const_checkers; tosa_level_t tosa_level; - DenseMap variables_map; }; LogicalResult TosaValidation::applyLevelCheck(Operation *op) { @@ -438,69 +427,6 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) { return success(); } -inline bool CompatibleTypes(const mlir::Type &type, - const mlir::Type &declared_type) { - // for now, simply use type equality comparison - return type == declared_type; -} - -bool TosaValidation::CheckVariable(Operation *op) { - if (isa(op)) { - auto name_attr = cast(op->getAttr("name")); - - if (variables_map.count(&name_attr)) { - op->emitOpError() << "name has already been declared"; - return false; - } - - auto type_attr = cast(op->getAttr("type")); - mlir::Type type = type_attr.getValue(); - - variables_map[&name_attr] = type; - } - - return true; -} - -bool TosaValidation::CheckVariableReadOrWrite(Operation *op) { - if (isa(op) || - isa(op)) { - auto name_attr = cast(op->getAttr("name")); - - if (!variables_map.count(&name_attr)) { - op->emitOpError() << "name has not been declared"; - return false; - } - - auto var_type = variables_map[&name_attr]; - - for (auto v : op->getOperands()) { - auto type = v.getType(); - if (!CompatibleTypes(type, var_type)) { - op->emitOpError() << "operand type does not equal variable type"; - return false; - } - } - - for (auto v : op->getResults()) { - auto type = v.getType(); - if (!CompatibleTypes(type, var_type)) { - op->emitOpError() << "result type does not equal variable type"; - return false; - } - } - } - - return true; -} - -LogicalResult TosaValidation::applyVariableCheck(Operation *op) { - if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) { - return failure(); - } - return success(); -} - void TosaValidation::runOnOperation() { configLevelAndProfile(); getOperation().walk([&](Operation *op) { @@ -514,18 +440,18 @@ void TosaValidation::runOnOperation() { } } - // Some uses of TOSA rely on the constant operands of particular - // operations. + // Some uses of TOSA rely on the constant operands of particular operations. if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op))) signalPassFailure(); // do level checks if (failed(applyLevelCheck(op))) signalPassFailure(); - - // do variable type checks - if (failed(applyVariableCheck(op))) - signalPassFailure(); }); } } // namespace + +std::unique_ptr +mlir::tosa::createTosaValidationPass(ValidationOptions const &options) { + return std::make_unique(options); +} diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 9233662e88db90..7c58bb10b9c5ed 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -203,48 +203,3 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor< : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> } - -// ----- - -func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () { - tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> - // expected-error@+1 {{'tosa.variable' op name has already been declared}} - tosa.variable @stored_var : tensor<1x4x8xi32> - return -} - -// ----- - -func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () { - tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> - // expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}} - %0 = tosa.variable.read @stored_var : tensor<2x4x8xi16> - return -} - -// ----- - -func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () { - tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> - // expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}} - %0 = tosa.variable.read @stored_var : tensor<1x4x8xi32> - return -} - -// ----- - -func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () { - tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> - // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}} - tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16> - return -} - -// ----- - -func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () { - tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> - // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}} - tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32> - return -} diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir deleted file mode 100644 index 9a26aa0bc8bf4d..00000000000000 --- a/mlir/test/Dialect/Tosa/variables.mlir +++ /dev/null @@ -1,33 +0,0 @@ -// RUN: mlir-opt %s | mlir-opt | FileCheck %s -// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s - - -// ----- -// CHECK-LABEL: @test_variable_scalar( -// CHECK-SAME: %[[ADD_VAL:.*]]: tensor) { -func.func @test_variable_scalar(%arg0: tensor) -> () { - // CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor - tosa.variable @stored_var = dense<3.14> : tensor - // CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor - %0 = tosa.variable.read @stored_var : tensor - // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor, tensor) -> tensor - %1 = "tosa.add"(%arg0, %0) : (tensor, tensor) -> tensor - // CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor - tosa.variable.write @stored_var, %1 : tensor - return -} - -// ----- -// CHECK-LABEL: @test_variable_tensor( -// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) { -func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () { - // CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> - tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> - // CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<2x4x8xi32> - %0 = tosa.variable.read @stored_var : tensor<2x4x8xi32> - // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> - %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> - // CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32> - tosa.variable.write @stored_var, %1 : tensor<2x4x8xi32> - return -}