Skip to content

Commit

Permalink
[mlir][Transform] Add a transform.match.operation_empty op to allow s… (
Browse files Browse the repository at this point in the history
#68319)

…pecifying negative conditions

In the process, get_parent_op gains an attribute to allow it to return
empty handles explicitly and still succeed.
  • Loading branch information
nicolasvasilache authored Oct 6, 2023
1 parent e91a4be commit 98341df
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 127 deletions.
95 changes: 75 additions & 20 deletions mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,71 @@

#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/STLExtras.h"
#include <optional>
#include <type_traits>

namespace mlir {
namespace transform {
class MatchOpInterface;

namespace detail {
/// Dispatch `matchOperation` based on Operation* or std::optional<Operation*>
/// first operand.
template <typename OpTy>
class SingleOpMatcherOpTrait
: public OpTrait::TraitBase<OpTy, SingleOpMatcherOpTrait> {
DiagnosedSilenceableFailure matchOptionalOperation(OpTy op,
TransformResults &results,
TransformState &state) {
if constexpr (std::is_same_v<
typename llvm::function_traits<
decltype(&OpTy::matchOperation)>::template arg_t<0>,
Operation *>) {
return op.matchOperation(nullptr, results, state);
} else {
return op.matchOperation(std::nullopt, results, state);
}
}
} // namespace detail

template <typename OpTy>
class AtMostOneOpMatcherOpTrait
: public OpTrait::TraitBase<OpTy, AtMostOneOpMatcherOpTrait> {
template <typename T>
using has_get_operand_handle =
decltype(std::declval<T &>().getOperandHandle());
template <typename T>
using has_match_operation = decltype(std::declval<T &>().matchOperation(
using has_match_operation_ptr = decltype(std::declval<T &>().matchOperation(
std::declval<Operation *>(), std::declval<TransformResults &>(),
std::declval<TransformState &>()));
template <typename T>
using has_match_operation_optional =
decltype(std::declval<T &>().matchOperation(
std::declval<std::optional<Operation *>>(),
std::declval<TransformResults &>(),
std::declval<TransformState &>()));

public:
static LogicalResult verifyTrait(Operation *op) {
static_assert(llvm::is_detected<has_get_operand_handle, OpTy>::value,
"SingleOpMatcherOpTrait expects operation type to have the "
"getOperandHandle() method");
static_assert(llvm::is_detected<has_match_operation, OpTy>::value,
"SingleOpMatcherOpTrait expected operation type to have the "
"matchOperation(Operation *, TransformResults &, "
"TransformState &) method");
"AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects "
"operation type to have the getOperandHandle() method");
static_assert(
llvm::is_detected<has_match_operation_ptr, OpTy>::value ||
llvm::is_detected<has_match_operation_optional, OpTy>::value,
"AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected operation "
"type to have either the matchOperation(Operation *, TransformResults "
"&, TransformState &) or the matchOperation(std::optional<Operation*>, "
"TransformResults &, TransformState &) method");

// This must be a dynamic assert because interface registration is dynamic.
assert(isa<MatchOpInterface>(op) &&
"SingleOpMatchOpTrait is only available on operations with "
"MatchOpInterface");
assert(
isa<MatchOpInterface>(op) &&
"AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on "
"operations with MatchOpInterface");
Value operandHandle = cast<OpTy>(op).getOperandHandle();
if (!isa<TransformHandleTypeInterface>(operandHandle.getType())) {
return op->emitError() << "SingleOpMatchOpTrait requires the op handle "
return op->emitError() << "AtMostOneOpMatcherOpTrait/"
"SingleOpMatchOpTrait requires the op handle "
"to be of TransformHandleTypeInterface";
}

Expand All @@ -55,12 +87,15 @@ class SingleOpMatcherOpTrait
TransformState &state) {
Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
auto payload = state.getPayloadOps(operandHandle);
if (!llvm::hasSingleElement(payload)) {
if (!llvm::hasNItemsOrLess(payload, 1)) {
return emitDefiniteFailure(this->getOperation()->getLoc())
<< "SingleOpMatchOpTrait requires the operand handle to point to "
"a single payload op";
<< "AtMostOneOpMatcherOpTrait requires the operand handle to "
"point to at most one payload op";
}
if (payload.empty()) {
return detail::matchOptionalOperation(cast<OpTy>(this->getOperation()),
results, state);
}

return cast<OpTy>(this->getOperation())
.matchOperation(*payload.begin(), results, state);
}
Expand All @@ -72,12 +107,32 @@ class SingleOpMatcherOpTrait
}
};

template <typename OpTy>
class SingleOpMatcherOpTrait : public AtMostOneOpMatcherOpTrait<OpTy> {

public:
DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
TransformResults &results,
TransformState &state) {
Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
auto payload = state.getPayloadOps(operandHandle);
if (!llvm::hasSingleElement(payload)) {
return emitDefiniteFailure(this->getOperation()->getLoc())
<< "SingleOpMatchOpTrait requires the operand handle to point to "
"a single payload op";
}
return static_cast<AtMostOneOpMatcherOpTrait<OpTy> *>(this)->apply(
rewriter, results, state);
}
};

template <typename OpTy>
class SingleValueMatcherOpTrait
: public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
public:
static LogicalResult verifyTrait(Operation *op) {
// This must be a dynamic assert because interface registration is dynamic.
// This must be a dynamic assert because interface registration is
// dynamic.
assert(isa<MatchOpInterface>(op) &&
"SingleValueMatchOpTrait is only available on operations with "
"MatchOpInterface");
Expand All @@ -98,8 +153,8 @@ class SingleValueMatcherOpTrait
auto payload = state.getPayloadValues(operandHandle);
if (!llvm::hasSingleElement(payload)) {
return emitDefiniteFailure(this->getOperation()->getLoc())
<< "SingleValueMatchOpTrait requires the value handle to point to "
"a single payload value";
<< "SingleValueMatchOpTrait requires the value handle to point "
"to a single payload value";
}

return cast<OpTy>(this->getOperation())
Expand Down
21 changes: 19 additions & 2 deletions mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,28 @@ def MatchOpInterface
let cppNamespace = "::mlir::transform";
}

// Trait for "matcher" transform operations that apply to an operation handle
// associated with at most one payload operation. Checks that it is indeed
// the case and produces a definite failure when it is not. The matching logic
// is implemented in the `matchOperation` function instead of `apply`. The op
// with this trait must provide a `Value getOperandHandle()` function that
// returns the handle to be used for matching.
def AtMostOneOpMatcher : NativeOpTrait<"AtMostOneOpMatcherOpTrait"> {
let cppNamespace = "::mlir::transform";

string extraDeclaration = [{
::mlir::DiagnosedSilenceableFailure matchOperation(
::std::optional<::mlir::Operation *> maybeCurrent,
::mlir::transform::TransformResults &results,
::mlir::transform::TransformState &state);
}];
}

// Trait for "matcher" transform operations that apply to an operation handle
// associated with exactly one payload operation. Checks that it is indeed
// the case and produces a definite failure when it is not. The matching logic
// is implemented in the `matchOperation` function instead of `apply`. The op
// with this trait must provide a `Value getOperandHandle()` function that
// with this trait must provide a `Value getOperandHandle()` function that
// returns the handle to be used for matching.
def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
let cppNamespace = "::mlir::transform";
Expand All @@ -35,7 +52,7 @@ def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
// associated with exactly one payload value. Checks that it is indeed
// the case and produces a definite failure when it is not. The matching logic
// is implemented in the `matchValue` function instead of `apply`. The op
// with this trait must provide a `Value getOperandHandle()` function that
// with this trait must provide a `Value getOperandHandle()` function that
// returns the handle to be used for matching.
def SingleValueMatcher : NativeOpTrait<"SingleValueMatcherOpTrait"> {
let cppNamespace = "::mlir::transform";
Expand Down
204 changes: 103 additions & 101 deletions mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,107 +20,109 @@ def Transform_Dialect : Dialect {

let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
/// Name of the attribute attachable to the symbol table operation
/// containing named sequences. This is used to trigger verification.
constexpr const static ::llvm::StringLiteral
kWithNamedSequenceAttrName = "transform.with_named_sequence";

/// Name of the attribute attachable to an operation so it can be
/// identified as root by the default interpreter pass.
constexpr const static ::llvm::StringLiteral
kTargetTagAttrName = "transform.target_tag";

/// Name of the attribute attachable to an operation, indicating that
/// TrackingListener failures should be silenced.
constexpr const static ::llvm::StringLiteral
kSilenceTrackingFailuresAttrName = "transform.silence_tracking_failures";

/// Names of the attributes indicating whether an argument of an external
/// transform dialect symbol is consumed or only read.
constexpr const static ::llvm::StringLiteral
kArgConsumedAttrName = "transform.consumed";
constexpr const static ::llvm::StringLiteral
kArgReadOnlyAttrName = "transform.readonly";

template <typename DataTy>
const DataTy &getExtraData() const {
return *static_cast<const DataTy *>(extraData.at(::mlir::TypeID::get<DataTy>()).get());
}

/// Parses a type registered by this dialect or one of its extensions.
::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

/// Prints a type registered by this dialect or one of its extensions.
void printType(::mlir::Type type,
::mlir::DialectAsmPrinter &printer) const override;

/// Parser callback for an individual type registered by this dialect or
/// its extensions.
using ExtensionTypeParsingHook = ::mlir::Type (*)(::mlir::AsmParser &);

/// Printer callback for an individual type registered by this dialect or
/// its extensions.
using ExtensionTypePrintingHook =
std::function<void (::mlir::Type, ::mlir::AsmPrinter &)>;

private:
/// Registers operations specified as template parameters with this
/// dialect. Checks that they implement the required interfaces.
template <typename... OpTys>
void addOperationsChecked() {
(addOperationIfNotRegistered<OpTys>(), ...);
}
template <typename OpTy>
void addOperationIfNotRegistered();

/// Reports a repeated registration error of an op with the given name.
[[noreturn]] void reportDuplicateOpRegistration(StringRef opName);

/// Registers the types specified as template parameters with the
/// Transform dialect. Checks that they meet the requirements for
/// Transform IR types.
template <typename... TypeTys>
void addTypesChecked() {
(addTypeIfNotRegistered<TypeTys>(), ...);
}
template <typename Type>
void addTypeIfNotRegistered();

/// Reports a repeated registration error of a type with the given
/// mnemonic.
[[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);

/// Registers dialect types with the context.
void initializeTypes();

// Give extensions access to injection functions.
template <typename, typename...>
friend class TransformDialectExtension;

/// Gets a mutable reference to extra data of the kind specified as
/// template argument. Allocates the data on the first call.
template <typename DataTy>
DataTy &getOrCreateExtraData();

//===----------------------------------------------------------------===//
// Data fields
//===----------------------------------------------------------------===//

/// Additional data associated with and owned by the dialect. Accessible
/// to extensions.
::llvm::DenseMap<::mlir::TypeID, std::unique_ptr<
::mlir::transform::detail::TransformDialectDataBase>>
extraData;

/// A map from type mnemonic to its parsing function for the remainder of
/// the syntax. The parser has access to the mnemonic, so it is used for
/// further dispatch.
::llvm::StringMap<ExtensionTypeParsingHook> typeParsingHooks;

/// A map from type TypeID to its printing function. No need to do string
/// lookups when the type is fully constructed.
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
typePrintingHooks;
/// Name of the attribute attachable to the symbol table operation
/// containing named sequences. This is used to trigger verification.
constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName =
"transform.with_named_sequence";

/// Name of the attribute attachable to an operation so it can be
/// identified as root by the default interpreter pass.
constexpr const static ::llvm::StringLiteral kTargetTagAttrName =
"transform.target_tag";

/// Name of the attribute attachable to an operation, indicating that
/// TrackingListener failures should be silenced.
constexpr const static ::llvm::StringLiteral
kSilenceTrackingFailuresAttrName =
"transform.silence_tracking_failures";

/// Names of the attributes indicating whether an argument of an external
/// transform dialect symbol is consumed or only read.
constexpr const static ::llvm::StringLiteral kArgConsumedAttrName =
"transform.consumed";
constexpr const static ::llvm::StringLiteral kArgReadOnlyAttrName =
"transform.readonly";

template <typename DataTy>
const DataTy &getExtraData() const {
return *static_cast<const DataTy *>(
extraData.at(::mlir::TypeID::get<DataTy>()).get());
}

/// Parses a type registered by this dialect or one of its extensions.
::mlir::Type parseType(::mlir::DialectAsmParser & parser) const override;

/// Prints a type registered by this dialect or one of its extensions.
void printType(::mlir::Type type, ::mlir::DialectAsmPrinter & printer)
const override;

/// Parser callback for an individual type registered by this dialect or
/// its extensions.
using ExtensionTypeParsingHook = ::mlir::Type (*)(::mlir::AsmParser &);

/// Printer callback for an individual type registered by this dialect or
/// its extensions.
using ExtensionTypePrintingHook =
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;

private:
/// Registers operations specified as template parameters with this
/// dialect. Checks that they implement the required interfaces.
template <typename... OpTys>
void addOperationsChecked() {
(addOperationIfNotRegistered<OpTys>(), ...);
}
template <typename OpTy>
void addOperationIfNotRegistered();

/// Reports a repeated registration error of an op with the given name.
[[noreturn]] void reportDuplicateOpRegistration(StringRef opName);

/// Registers types specified as template parameters with the Transform
/// dialect. Checks that they meet the requirements for Transform IR types.
template <typename... TypeTys>
void addTypesChecked() {
(addTypeIfNotRegistered<TypeTys>(), ...);
}
template <typename Type>
void addTypeIfNotRegistered();

/// Reports a repeated registration error of a type with the given
/// mnemonic.
[[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);

/// Registers dialect types with the context.
void initializeTypes();

// Give extensions access to injection functions.
template <typename, typename...>
friend class TransformDialectExtension;

/// Gets a mutable reference to extra data of the kind specified as
/// template argument. Allocates the data on the first call.
template <typename DataTy>
DataTy &getOrCreateExtraData();

//===----------------------------------------------------------------===//
// Data fields
//===----------------------------------------------------------------===//

/// Additional data associated with and owned by the dialect. Accessible
/// to extensions.
::llvm::DenseMap<
::mlir::TypeID,
std::unique_ptr<::mlir::transform::detail::TransformDialectDataBase>>
extraData;

/// A map from type mnemonic to its parsing function for the remainder of
/// the syntax. The parser has access to the mnemonic, so it is used for
/// further dispatch.
::llvm::StringMap<ExtensionTypeParsingHook> typeParsingHooks;

/// A map from type TypeID to its printing function. No need to do string
/// lookups when the type is fully constructed.
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
typePrintingHooks;
}];
}

Expand Down
Loading

0 comments on commit 98341df

Please sign in to comment.