diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h index c8888f294f6ca1d..b155b110677d6c7 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h @@ -11,39 +11,71 @@ #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/STLExtras.h" +#include +#include namespace mlir { namespace transform { class MatchOpInterface; +namespace detail { +/// Dispatch `matchOperation` based on Operation* or std::optional +/// first operand. template -class SingleOpMatcherOpTrait - : public OpTrait::TraitBase { +DiagnosedSilenceableFailure matchOptionalOperation(OpTy op, + TransformResults &results, + TransformState &state) { + if constexpr (std::is_same_v< + typename llvm::function_traits< + decltype(&OpTy::matchOperation)>::template arg_t<0>, + Operation *>) { + return op.matchOperation(nullptr, results, state); + } else { + return op.matchOperation(std::nullopt, results, state); + } +} +} // namespace detail + +template +class AtMostOneOpMatcherOpTrait + : public OpTrait::TraitBase { template using has_get_operand_handle = decltype(std::declval().getOperandHandle()); template - using has_match_operation = decltype(std::declval().matchOperation( + using has_match_operation_ptr = decltype(std::declval().matchOperation( std::declval(), std::declval(), std::declval())); + template + using has_match_operation_optional = + decltype(std::declval().matchOperation( + std::declval>(), + std::declval(), + std::declval())); public: static LogicalResult verifyTrait(Operation *op) { static_assert(llvm::is_detected::value, - "SingleOpMatcherOpTrait expects operation type to have the " - "getOperandHandle() method"); - static_assert(llvm::is_detected::value, - "SingleOpMatcherOpTrait expected operation type to have the " - "matchOperation(Operation *, TransformResults &, " - "TransformState &) method"); + "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects " + "operation type to have the getOperandHandle() method"); + static_assert( + llvm::is_detected::value || + llvm::is_detected::value, + "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected operation " + "type to have either the matchOperation(Operation *, TransformResults " + "&, TransformState &) or the matchOperation(std::optional, " + "TransformResults &, TransformState &) method"); // This must be a dynamic assert because interface registration is dynamic. - assert(isa(op) && - "SingleOpMatchOpTrait is only available on operations with " - "MatchOpInterface"); + assert( + isa(op) && + "AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on " + "operations with MatchOpInterface"); Value operandHandle = cast(op).getOperandHandle(); if (!isa(operandHandle.getType())) { - return op->emitError() << "SingleOpMatchOpTrait requires the op handle " + return op->emitError() << "AtMostOneOpMatcherOpTrait/" + "SingleOpMatchOpTrait requires the op handle " "to be of TransformHandleTypeInterface"; } @@ -55,12 +87,15 @@ class SingleOpMatcherOpTrait TransformState &state) { Value operandHandle = cast(this->getOperation()).getOperandHandle(); auto payload = state.getPayloadOps(operandHandle); - if (!llvm::hasSingleElement(payload)) { + if (!llvm::hasNItemsOrLess(payload, 1)) { return emitDefiniteFailure(this->getOperation()->getLoc()) - << "SingleOpMatchOpTrait requires the operand handle to point to " - "a single payload op"; + << "AtMostOneOpMatcherOpTrait requires the operand handle to " + "point to at most one payload op"; + } + if (payload.empty()) { + return detail::matchOptionalOperation(cast(this->getOperation()), + results, state); } - return cast(this->getOperation()) .matchOperation(*payload.begin(), results, state); } @@ -72,12 +107,32 @@ class SingleOpMatcherOpTrait } }; +template +class SingleOpMatcherOpTrait : public AtMostOneOpMatcherOpTrait { + +public: + DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, + TransformResults &results, + TransformState &state) { + Value operandHandle = cast(this->getOperation()).getOperandHandle(); + auto payload = state.getPayloadOps(operandHandle); + if (!llvm::hasSingleElement(payload)) { + return emitDefiniteFailure(this->getOperation()->getLoc()) + << "SingleOpMatchOpTrait requires the operand handle to point to " + "a single payload op"; + } + return static_cast *>(this)->apply( + rewriter, results, state); + } +}; + template class SingleValueMatcherOpTrait : public OpTrait::TraitBase { public: static LogicalResult verifyTrait(Operation *op) { - // This must be a dynamic assert because interface registration is dynamic. + // This must be a dynamic assert because interface registration is + // dynamic. assert(isa(op) && "SingleValueMatchOpTrait is only available on operations with " "MatchOpInterface"); @@ -98,8 +153,8 @@ class SingleValueMatcherOpTrait auto payload = state.getPayloadValues(operandHandle); if (!llvm::hasSingleElement(payload)) { return emitDefiniteFailure(this->getOperation()->getLoc()) - << "SingleValueMatchOpTrait requires the value handle to point to " - "a single payload value"; + << "SingleValueMatchOpTrait requires the value handle to point " + "to a single payload value"; } return cast(this->getOperation()) diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td index 1f81fd5252eb45b..be92e4d91b42b32 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td @@ -14,11 +14,28 @@ def MatchOpInterface let cppNamespace = "::mlir::transform"; } +// Trait for "matcher" transform operations that apply to an operation handle +// associated with at most one payload operation. Checks that it is indeed +// the case and produces a definite failure when it is not. The matching logic +// is implemented in the `matchOperation` function instead of `apply`. The op +// with this trait must provide a `Value getOperandHandle()` function that +// returns the handle to be used for matching. +def AtMostOneOpMatcher : NativeOpTrait<"AtMostOneOpMatcherOpTrait"> { + let cppNamespace = "::mlir::transform"; + + string extraDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure matchOperation( + ::std::optional<::mlir::Operation *> maybeCurrent, + ::mlir::transform::TransformResults &results, + ::mlir::transform::TransformState &state); + }]; +} + // Trait for "matcher" transform operations that apply to an operation handle // associated with exactly one payload operation. Checks that it is indeed // the case and produces a definite failure when it is not. The matching logic // is implemented in the `matchOperation` function instead of `apply`. The op -// with this trait must provide a `Value getOperandHandle()` function that +// with this trait must provide a `Value getOperandHandle()` function that // returns the handle to be used for matching. def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> { let cppNamespace = "::mlir::transform"; @@ -35,7 +52,7 @@ def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> { // associated with exactly one payload value. Checks that it is indeed // the case and produces a definite failure when it is not. The matching logic // is implemented in the `matchValue` function instead of `apply`. The op -// with this trait must provide a `Value getOperandHandle()` function that +// with this trait must provide a `Value getOperandHandle()` function that // returns the handle to be used for matching. def SingleValueMatcher : NativeOpTrait<"SingleValueMatcherOpTrait"> { let cppNamespace = "::mlir::transform"; diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td index 3448e27a41a6804..70a76ab9670f907 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -20,107 +20,109 @@ def Transform_Dialect : Dialect { let hasOperationAttrVerify = 1; let extraClassDeclaration = [{ - /// Name of the attribute attachable to the symbol table operation - /// containing named sequences. This is used to trigger verification. - constexpr const static ::llvm::StringLiteral - kWithNamedSequenceAttrName = "transform.with_named_sequence"; - - /// Name of the attribute attachable to an operation so it can be - /// identified as root by the default interpreter pass. - constexpr const static ::llvm::StringLiteral - kTargetTagAttrName = "transform.target_tag"; - - /// Name of the attribute attachable to an operation, indicating that - /// TrackingListener failures should be silenced. - constexpr const static ::llvm::StringLiteral - kSilenceTrackingFailuresAttrName = "transform.silence_tracking_failures"; - - /// Names of the attributes indicating whether an argument of an external - /// transform dialect symbol is consumed or only read. - constexpr const static ::llvm::StringLiteral - kArgConsumedAttrName = "transform.consumed"; - constexpr const static ::llvm::StringLiteral - kArgReadOnlyAttrName = "transform.readonly"; - - template - const DataTy &getExtraData() const { - return *static_cast(extraData.at(::mlir::TypeID::get()).get()); - } - - /// Parses a type registered by this dialect or one of its extensions. - ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; - - /// Prints a type registered by this dialect or one of its extensions. - void printType(::mlir::Type type, - ::mlir::DialectAsmPrinter &printer) const override; - - /// Parser callback for an individual type registered by this dialect or - /// its extensions. - using ExtensionTypeParsingHook = ::mlir::Type (*)(::mlir::AsmParser &); - - /// Printer callback for an individual type registered by this dialect or - /// its extensions. - using ExtensionTypePrintingHook = - std::function; - - private: - /// Registers operations specified as template parameters with this - /// dialect. Checks that they implement the required interfaces. - template - void addOperationsChecked() { - (addOperationIfNotRegistered(), ...); - } - template - void addOperationIfNotRegistered(); - - /// Reports a repeated registration error of an op with the given name. - [[noreturn]] void reportDuplicateOpRegistration(StringRef opName); - - /// Registers the types specified as template parameters with the - /// Transform dialect. Checks that they meet the requirements for - /// Transform IR types. - template - void addTypesChecked() { - (addTypeIfNotRegistered(), ...); - } - template - void addTypeIfNotRegistered(); - - /// Reports a repeated registration error of a type with the given - /// mnemonic. - [[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic); - - /// Registers dialect types with the context. - void initializeTypes(); - - // Give extensions access to injection functions. - template - friend class TransformDialectExtension; - - /// Gets a mutable reference to extra data of the kind specified as - /// template argument. Allocates the data on the first call. - template - DataTy &getOrCreateExtraData(); - - //===----------------------------------------------------------------===// - // Data fields - //===----------------------------------------------------------------===// - - /// Additional data associated with and owned by the dialect. Accessible - /// to extensions. - ::llvm::DenseMap<::mlir::TypeID, std::unique_ptr< - ::mlir::transform::detail::TransformDialectDataBase>> - extraData; - - /// A map from type mnemonic to its parsing function for the remainder of - /// the syntax. The parser has access to the mnemonic, so it is used for - /// further dispatch. - ::llvm::StringMap typeParsingHooks; - - /// A map from type TypeID to its printing function. No need to do string - /// lookups when the type is fully constructed. - ::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook> - typePrintingHooks; + /// Name of the attribute attachable to the symbol table operation + /// containing named sequences. This is used to trigger verification. + constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName = + "transform.with_named_sequence"; + + /// Name of the attribute attachable to an operation so it can be + /// identified as root by the default interpreter pass. + constexpr const static ::llvm::StringLiteral kTargetTagAttrName = + "transform.target_tag"; + + /// Name of the attribute attachable to an operation, indicating that + /// TrackingListener failures should be silenced. + constexpr const static ::llvm::StringLiteral + kSilenceTrackingFailuresAttrName = + "transform.silence_tracking_failures"; + + /// Names of the attributes indicating whether an argument of an external + /// transform dialect symbol is consumed or only read. + constexpr const static ::llvm::StringLiteral kArgConsumedAttrName = + "transform.consumed"; + constexpr const static ::llvm::StringLiteral kArgReadOnlyAttrName = + "transform.readonly"; + + template + const DataTy &getExtraData() const { + return *static_cast( + extraData.at(::mlir::TypeID::get()).get()); + } + + /// Parses a type registered by this dialect or one of its extensions. + ::mlir::Type parseType(::mlir::DialectAsmParser & parser) const override; + + /// Prints a type registered by this dialect or one of its extensions. + void printType(::mlir::Type type, ::mlir::DialectAsmPrinter & printer) + const override; + + /// Parser callback for an individual type registered by this dialect or + /// its extensions. + using ExtensionTypeParsingHook = ::mlir::Type (*)(::mlir::AsmParser &); + + /// Printer callback for an individual type registered by this dialect or + /// its extensions. + using ExtensionTypePrintingHook = + std::function; + + private: + /// Registers operations specified as template parameters with this + /// dialect. Checks that they implement the required interfaces. + template + void addOperationsChecked() { + (addOperationIfNotRegistered(), ...); + } + template + void addOperationIfNotRegistered(); + + /// Reports a repeated registration error of an op with the given name. + [[noreturn]] void reportDuplicateOpRegistration(StringRef opName); + + /// Registers types specified as template parameters with the Transform + /// dialect. Checks that they meet the requirements for Transform IR types. + template + void addTypesChecked() { + (addTypeIfNotRegistered(), ...); + } + template + void addTypeIfNotRegistered(); + + /// Reports a repeated registration error of a type with the given + /// mnemonic. + [[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic); + + /// Registers dialect types with the context. + void initializeTypes(); + + // Give extensions access to injection functions. + template + friend class TransformDialectExtension; + + /// Gets a mutable reference to extra data of the kind specified as + /// template argument. Allocates the data on the first call. + template + DataTy &getOrCreateExtraData(); + + //===----------------------------------------------------------------===// + // Data fields + //===----------------------------------------------------------------===// + + /// Additional data associated with and owned by the dialect. Accessible + /// to extensions. + ::llvm::DenseMap< + ::mlir::TypeID, + std::unique_ptr<::mlir::transform::detail::TransformDialectDataBase>> + extraData; + + /// A map from type mnemonic to its parsing function for the remainder of + /// the syntax. The parser has access to the mnemonic, so it is used for + /// further dispatch. + ::llvm::StringMap typeParsingHooks; + + /// A map from type TypeID to its printing function. No need to do string + /// lookups when the type is fully constructed. + ::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook> + typePrintingHooks; }]; } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index ca5c915ef8c2caa..5bc92e8e954eae7 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -595,8 +595,9 @@ def GetDefiningOp : TransformDialectOp<"get_defining_op", def GetParentOp : TransformDialectOp<"get_parent_op", [DeclareOpInterfaceMethods, + MatchOpInterface, NavigationTransformOpTrait, MemoryEffectsOpInterface]> { - let summary = "Gets handles to the closest isolated-from-above parents"; + let summary = "Gets handles to the closest parent ops"; let description = [{ The handle defined by this Transform op corresponds to the parents of the targeted payload ops (in the same order). @@ -605,6 +606,9 @@ def GetParentOp : TransformDialectOp<"get_parent_op", that case for each target op, the closest parent op that fulfills all requirements, is returned. - `isolated_from_above`: the parent op must be isolated from above + - `allow_empty_results`: get_parent_op is allowed to return an empty list and + still succeeds. In such a case, if get_parent_op fails for any operation + in the list, the entire transform returns an empty handle. - `op_name`: the parent op must have the specified name If `deduplicate` is set, the result handle does not contain any duplicate @@ -614,12 +618,14 @@ def GetParentOp : TransformDialectOp<"get_parent_op", is applied, e.g., "B" may itself be a parent of "A". This may have an impact on the further transformation applied to the handle produced here. - If any of the given Payload IR ops has no such suitable parent, the - transformation fails silently. + If any of the given Payload IR ops has no such suitable parent, then: + - if `allow_empty_results` is set, the result handle is empty + - otherwise, the transformation produces a silenceable failure. }]; let arguments = (ins TransformHandleTypeInterface:$target, UnitAttr:$isolated_from_above, + UnitAttr:$allow_empty_results, OptionalAttr:$op_name, UnitAttr:$deduplicate); let results = (outs TransformHandleTypeInterface:$parent); @@ -739,6 +745,21 @@ def IncludeOp : TransformDialectOp<"include", }]; } +def MatchOperationEmptyOp : Op { + let summary = + "Matches if the handle is not associated to any op"; + let description = [{ + Succeeds if the handle is not associated to any op. + }]; + let arguments = (ins TransformHandleTypeInterface:$operand_handle); + let assemblyFormat = + "$operand_handle attr-dict `:` type($operand_handle)"; + let extraClassDeclaration = AtMostOneOpMatcher.extraDeclaration; +} + def MatchOperationNameOp : TransformDialectOp<"match.operation_name", [SingleOpMatcher, MatchOpInterface, diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 44626260e2f9ef3..0e20b379cc2a3e7 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1161,7 +1161,6 @@ void transform::ForeachOp::getEffects( SmallVectorImpl &effects) { BlockArgument iterVar = getIterationVariable(); if (any_of(getBody().front().without_terminator(), [&](Operation &op) { - return isHandleConsumed(iterVar, cast(&op)); })) { consumesHandle(getTarget(), effects); @@ -1244,6 +1243,10 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter, parent = parent->getParentOp(); } if (!parent) { + if (getAllowEmptyResults()) { + results.set(llvm::cast(getResult()), parents); + return DiagnosedSilenceableFailure::success(); + } DiagnosedSilenceableFailure diag = emitSilenceableError() << "could not find a parent op that matches all requirements"; @@ -1545,6 +1548,21 @@ transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { .checkAndReport(); } +//===----------------------------------------------------------------------===// +// MatchOperationEmptyOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation( + ::std::optional<::mlir::Operation *> maybeCurrent, + transform::TransformResults &results, transform::TransformState &state) { + if (!maybeCurrent.has_value()) { + DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; + return DiagnosedSilenceableFailure::success(); + } + DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; + return emitSilenceableError() << "operation is not empty"; +} + //===----------------------------------------------------------------------===// // MatchOperationNameOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index daa179cb15408b4..3891c16b4115595 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -2037,3 +2037,75 @@ transform.sequence failures(propagate) { // expected-remark @below{{0}} test_print_number_of_associated_payload_ir_ops %empty_op : !transform.any_op } + + +// ----- + +func.func @no_constant_under_loop(%lb: index, %ub: index, %step: index) { + scf.for %i= %lb to %ub step %step { + arith.constant 0 : index + } + return +} + +module @named_inclusion attributes { transform.with_named_sequence } { +// Match `arith.constant`s that are not nested under a `scf.for` and ensure +// there are none in the program + +transform.named_sequence @print(%root: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %root, "matched func" : !transform.any_op + transform.yield +} + +transform.named_sequence @match_constant_not_under_scf_for(%root: !transform.any_op {transform.readonly}) + -> !transform.any_op { + transform.match.operation_name %root ["arith.constant"] : !transform.any_op + %for = transform.get_parent_op %root { op_name = "scf.for", allow_empty_results } + : (!transform.any_op) -> (!transform.any_op) + transform.match.operation_empty %for : !transform.any_op + transform.yield %root : !transform.any_op +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + transform.foreach_match in %arg0 + @match_constant_not_under_scf_for -> @print + : (!transform.any_op) -> (!transform.any_op) + transform.yield +} +} + +// ----- + +func.func @no_constant_under_loop(%lb: index, %ub: index, %step: index) { + // expected-remark @below {{no parent scf.for}} + arith.constant 0 : index + return +} + +module @named_inclusion attributes { transform.with_named_sequence } { +// Match `arith.constant`s that are not nested under a `scf.for` and ensure +// there are none in the program + +transform.named_sequence @print(%root: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %root, "no parent scf.for" : !transform.any_op + transform.yield +} + +transform.named_sequence @match_constant_not_under_scf_for(%root: !transform.any_op {transform.readonly}) + -> !transform.any_op { + transform.match.operation_name %root ["arith.constant"] : !transform.any_op + %for = transform.get_parent_op %root { op_name = "scf.for", allow_empty_results } + : (!transform.any_op) -> (!transform.any_op) + transform.match.operation_empty %for : !transform.any_op + transform.yield %root : !transform.any_op +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + transform.foreach_match in %arg0 + @match_constant_not_under_scf_for -> @print + : (!transform.any_op) -> (!transform.any_op) + transform.yield +} +}