Skip to content

Commit

Permalink
Revert "[TOSA] Add StatefulOps to TOSA Dialect (llvm#66843)"
Browse files Browse the repository at this point in the history
This reverts commit af972f0.
  • Loading branch information
Groverkss committed Oct 19, 2023
1 parent ad80af9 commit 0636b72
Show file tree
Hide file tree
Showing 10 changed files with 19 additions and 281 deletions.
4 changes: 2 additions & 2 deletions mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
void addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
// Note: Default to 'none' level unless otherwise specified.
tosa::TosaValidationOptions const &validationOptions = {
tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});
tosa::ValidationOptions const &validationOptions =
tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));

/// Populates conversion passes from TOSA dialect to Linalg dialect.
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
Expand Down
5 changes: 0 additions & 5 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ class PatternRewriter;

namespace tosa {

ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
Attribute &attr);
void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
Attribute attr);

#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"

} // namespace tosa
Expand Down
67 changes: 0 additions & 67 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,71 +79,4 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
}

//===----------------------------------------------------------------------===//
// Operator: variable
//===----------------------------------------------------------------------===//
def Tosa_VariableOp : Tosa_Op<"variable", []> {
let summary = "Defines a variable";

let description = [{
Defines a new TOSA variable. This is a mutable value.
Modifications are expressed using read/write semantics.
}];

let arguments = (ins
SymbolNameAttr:$name,
TypeAttr:$type,
OptionalAttr<AnyAttr>:$initial_value
);

let assemblyFormat = [{
$name
attr-dict
custom<TypeOrAttr>($type, $initial_value)
}];
}

//===----------------------------------------------------------------------===//
// Operator: variable.write
//===----------------------------------------------------------------------===//
def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
let summary = "write_buffer operator";

let description = [{
Assigns a value to pseudo-buffer resource holding a mutable tensor.
}];

let arguments = (ins
SymbolNameAttr:$name,
AnyType:$value
);

let assemblyFormat = [{
$name attr-dict `,` $value `:` type($value)
}];
}

//===----------------------------------------------------------------------===//
// Operator: variable.read
//===----------------------------------------------------------------------===//
def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
let summary = "read_buffer operator";

let description = [{
Reads the value from a pseudo-buffer resource holding a mutable tensor.
}];

let arguments = (ins
SymbolNameAttr:$name
);

let results = (outs
AnyType:$value
);

let assemblyFormat = [{
$name attr-dict `:` type($value)
}];
}

#endif // TOSA_UTIL_OPS
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ struct ValidationOptions {
}
};

std::unique_ptr<Pass> createTosaValidationPass(
ValidationOptions const &options = ValidationOptions());

#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"

Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,13 @@ def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
let cppNamespace = "mlir::tosa";
}

def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
let summary = "Validates TOSA dialect";
let description = [{
This pass validates if input TOSA operations match the specification for given
criteria, e.g. TOSA profile.
}];
let constructor = "createTosaValidationPass()";

let options = [
Option<"profile", "profile", "mlir::tosa::TosaProfileEnum",
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {

void mlir::tosa::addTosaToLinalgPasses(
OpPassManager &pm, const TosaToLinalgOptions &options,
tosa::TosaValidationOptions const &validationOptions) {
tosa::ValidationOptions const &validationOptions) {
// Optional decompositions are designed to benefit linalg.
if (!options.disableTosaDecompositions)
pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
Expand All @@ -90,6 +90,7 @@ void mlir::tosa::addTosaToLinalgPasses(
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
{options.aggressiveReduceConstant}));
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
pm.addNestedPass<func::FuncOp>(tosa::createTosaValidation(validationOptions));
pm.addNestedPass<func::FuncOp>(
tosa::createTosaValidationPass(validationOptions));
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
}
43 changes: 0 additions & 43 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,49 +146,6 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
return nullptr;
}

//===----------------------------------------------------------------------===//
// Parsers and printers
//===----------------------------------------------------------------------===//

ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
Attribute &attr) {
if (succeeded(parser.parseOptionalEqual())) {
if (failed(parser.parseAttribute(attr))) {
return parser.emitError(parser.getCurrentLocation())
<< "expected attribute";
}
if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
typeAttr = TypeAttr::get(typedAttr.getType());
}
return success();
}

Type type;
if (failed(parser.parseColonType(type))) {
return parser.emitError(parser.getCurrentLocation()) << "expected type";
}
typeAttr = TypeAttr::get(type);

return success();
}

void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
Attribute attr) {
bool needsSpace = false;
auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
if (!typedAttr || typedAttr.getType() != type.getValue()) {
p << ": ";
p.printAttribute(type);
needsSpace = true; // subsequent attr value needs a space separator
}
if (attr) {
if (needsSpace)
p << ' ';
p << "= ";
p.printAttribute(attr);
}
}

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
Expand Down
92 changes: 9 additions & 83 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"

#include <string>
#include <unordered_map>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -99,13 +96,12 @@ static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0};
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
public:
explicit TosaValidation() { populateConstantOperandChecks(); }
explicit TosaValidation(const TosaValidationOptions &options)
: TosaValidation() {
explicit TosaValidation(const ValidationOptions &options) : TosaValidation() {
this->profile = options.profile;
this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment;
this->level = options.level;
}
void runOnOperation() final;
void runOnOperation() override;

LogicalResult applyConstantOperandCheck(Operation *op) {
for (auto &checker : const_checkers) {
Expand All @@ -117,9 +113,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {

LogicalResult applyLevelCheck(Operation *op);

// check variable read/write data types against variable declarations
LogicalResult applyVariableCheck(Operation *op);

private:
void populateConstantOperandChecks() {
const_checkers.emplace_back(checkConstantOperandPad);
Expand Down Expand Up @@ -405,12 +398,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
}

bool CheckVariable(Operation *op);
bool CheckVariableReadOrWrite(Operation *op);

SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
tosa_level_t tosa_level;
DenseMap<const mlir::StringAttr *, mlir::Type> variables_map;
};

LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
Expand Down Expand Up @@ -438,69 +427,6 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
return success();
}

inline bool CompatibleTypes(const mlir::Type &type,
const mlir::Type &declared_type) {
// for now, simply use type equality comparison
return type == declared_type;
}

bool TosaValidation::CheckVariable(Operation *op) {
if (isa<mlir::tosa::VariableOp>(op)) {
auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));

if (variables_map.count(&name_attr)) {
op->emitOpError() << "name has already been declared";
return false;
}

auto type_attr = cast<mlir::TypeAttr>(op->getAttr("type"));
mlir::Type type = type_attr.getValue();

variables_map[&name_attr] = type;
}

return true;
}

bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
if (isa<mlir::tosa::VariableReadOp>(op) ||
isa<mlir::tosa::VariableWriteOp>(op)) {
auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));

if (!variables_map.count(&name_attr)) {
op->emitOpError() << "name has not been declared";
return false;
}

auto var_type = variables_map[&name_attr];

for (auto v : op->getOperands()) {
auto type = v.getType();
if (!CompatibleTypes(type, var_type)) {
op->emitOpError() << "operand type does not equal variable type";
return false;
}
}

for (auto v : op->getResults()) {
auto type = v.getType();
if (!CompatibleTypes(type, var_type)) {
op->emitOpError() << "result type does not equal variable type";
return false;
}
}
}

return true;
}

LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
return failure();
}
return success();
}

void TosaValidation::runOnOperation() {
configLevelAndProfile();
getOperation().walk([&](Operation *op) {
Expand All @@ -514,18 +440,18 @@ void TosaValidation::runOnOperation() {
}
}

// Some uses of TOSA rely on the constant operands of particular
// operations.
// Some uses of TOSA rely on the constant operands of particular operations.
if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
signalPassFailure();

// do level checks
if (failed(applyLevelCheck(op)))
signalPassFailure();

// do variable type checks
if (failed(applyVariableCheck(op)))
signalPassFailure();
});
}
} // namespace

std::unique_ptr<Pass>
mlir::tosa::createTosaValidationPass(ValidationOptions const &options) {
return std::make_unique<TosaValidation>(options);
}
45 changes: 0 additions & 45 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -203,48 +203,3 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}

// -----

func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
// expected-error@+1 {{'tosa.variable' op name has already been declared}}
tosa.variable @stored_var : tensor<1x4x8xi32>
return
}

// -----

func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
return
}

// -----

func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
%0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
return
}

// -----

func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
return
}

// -----

func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
return
}
Loading

0 comments on commit 0636b72

Please sign in to comment.