Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transform] Add a transform.match.operation_empty op to allow s… #68319

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 75 additions & 20 deletions mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,71 @@

#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/STLExtras.h"
#include <optional>
#include <type_traits>

namespace mlir {
namespace transform {
class MatchOpInterface;

namespace detail {
/// Dispatch `matchOperation` based on Operation* or std::optional<Operation*>
/// first operand.
template <typename OpTy>
class SingleOpMatcherOpTrait
: public OpTrait::TraitBase<OpTy, SingleOpMatcherOpTrait> {
DiagnosedSilenceableFailure matchOptionalOperation(OpTy op,
TransformResults &results,
TransformState &state) {
if constexpr (std::is_same_v<
typename llvm::function_traits<
decltype(&OpTy::matchOperation)>::template arg_t<0>,
Operation *>) {
return op.matchOperation(nullptr, results, state);
} else {
return op.matchOperation(std::nullopt, results, state);
}
}
} // namespace detail

template <typename OpTy>
class AtMostOneOpMatcherOpTrait
: public OpTrait::TraitBase<OpTy, AtMostOneOpMatcherOpTrait> {
template <typename T>
using has_get_operand_handle =
decltype(std::declval<T &>().getOperandHandle());
template <typename T>
using has_match_operation = decltype(std::declval<T &>().matchOperation(
using has_match_operation_ptr = decltype(std::declval<T &>().matchOperation(
std::declval<Operation *>(), std::declval<TransformResults &>(),
std::declval<TransformState &>()));
template <typename T>
using has_match_operation_optional =
decltype(std::declval<T &>().matchOperation(
std::declval<std::optional<Operation *>>(),
std::declval<TransformResults &>(),
std::declval<TransformState &>()));

public:
static LogicalResult verifyTrait(Operation *op) {
static_assert(llvm::is_detected<has_get_operand_handle, OpTy>::value,
"SingleOpMatcherOpTrait expects operation type to have the "
"getOperandHandle() method");
static_assert(llvm::is_detected<has_match_operation, OpTy>::value,
"SingleOpMatcherOpTrait expected operation type to have the "
"matchOperation(Operation *, TransformResults &, "
"TransformState &) method");
"AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects "
"operation type to have the getOperandHandle() method");
static_assert(
llvm::is_detected<has_match_operation_ptr, OpTy>::value ||
llvm::is_detected<has_match_operation_optional, OpTy>::value,
"AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected operation "
"type to have either the matchOperation(Operation *, TransformResults "
"&, TransformState &) or the matchOperation(std::optional<Operation*>, "
"TransformResults &, TransformState &) method");

// This must be a dynamic assert because interface registration is dynamic.
assert(isa<MatchOpInterface>(op) &&
"SingleOpMatchOpTrait is only available on operations with "
"MatchOpInterface");
assert(
isa<MatchOpInterface>(op) &&
"AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on "
"operations with MatchOpInterface");
Value operandHandle = cast<OpTy>(op).getOperandHandle();
if (!isa<TransformHandleTypeInterface>(operandHandle.getType())) {
return op->emitError() << "SingleOpMatchOpTrait requires the op handle "
return op->emitError() << "AtMostOneOpMatcherOpTrait/"
"SingleOpMatchOpTrait requires the op handle "
"to be of TransformHandleTypeInterface";
}

Expand All @@ -55,12 +87,15 @@ class SingleOpMatcherOpTrait
TransformState &state) {
Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
auto payload = state.getPayloadOps(operandHandle);
if (!llvm::hasSingleElement(payload)) {
if (!llvm::hasNItemsOrLess(payload, 1)) {
return emitDefiniteFailure(this->getOperation()->getLoc())
<< "SingleOpMatchOpTrait requires the operand handle to point to "
"a single payload op";
<< "AtMostOneOpMatcherOpTrait requires the operand handle to "
"point to at most one payload op";
}
if (payload.empty()) {
return detail::matchOptionalOperation(cast<OpTy>(this->getOperation()),
ftynse marked this conversation as resolved.
Show resolved Hide resolved
results, state);
}

return cast<OpTy>(this->getOperation())
.matchOperation(*payload.begin(), results, state);
}
Expand All @@ -72,12 +107,32 @@ class SingleOpMatcherOpTrait
}
};

template <typename OpTy>
class SingleOpMatcherOpTrait : public AtMostOneOpMatcherOpTrait<OpTy> {

public:
DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
TransformResults &results,
TransformState &state) {
Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
auto payload = state.getPayloadOps(operandHandle);
if (!llvm::hasSingleElement(payload)) {
return emitDefiniteFailure(this->getOperation()->getLoc())
<< "SingleOpMatchOpTrait requires the operand handle to point to "
"a single payload op";
}
return static_cast<AtMostOneOpMatcherOpTrait<OpTy> *>(this)->apply(
rewriter, results, state);
}
};

template <typename OpTy>
class SingleValueMatcherOpTrait
: public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
public:
static LogicalResult verifyTrait(Operation *op) {
// This must be a dynamic assert because interface registration is dynamic.
// This must be a dynamic assert because interface registration is
// dynamic.
assert(isa<MatchOpInterface>(op) &&
"SingleValueMatchOpTrait is only available on operations with "
"MatchOpInterface");
Expand All @@ -98,8 +153,8 @@ class SingleValueMatcherOpTrait
auto payload = state.getPayloadValues(operandHandle);
if (!llvm::hasSingleElement(payload)) {
return emitDefiniteFailure(this->getOperation()->getLoc())
<< "SingleValueMatchOpTrait requires the value handle to point to "
"a single payload value";
<< "SingleValueMatchOpTrait requires the value handle to point "
"to a single payload value";
}

return cast<OpTy>(this->getOperation())
Expand Down
21 changes: 19 additions & 2 deletions mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,28 @@ def MatchOpInterface
let cppNamespace = "::mlir::transform";
}

// Trait for "matcher" transform operations that apply to an operation handle
// associated with at most one payload operation. Checks that it is indeed
// the case and produces a definite failure when it is not. The matching logic
// is implemented in the `matchOperation` function instead of `apply`. The op
// with this trait must provide a `Value getOperandHandle()` function that
// returns the handle to be used for matching.
def AtMostOneOpMatcher : NativeOpTrait<"AtMostOneOpMatcherOpTrait"> {
let cppNamespace = "::mlir::transform";

string extraDeclaration = [{
::mlir::DiagnosedSilenceableFailure matchOperation(
::std::optional<::mlir::Operation *> maybeCurrent,
::mlir::transform::TransformResults &results,
::mlir::transform::TransformState &state);
}];
}

// Trait for "matcher" transform operations that apply to an operation handle
// associated with exactly one payload operation. Checks that it is indeed
// the case and produces a definite failure when it is not. The matching logic
// is implemented in the `matchOperation` function instead of `apply`. The op
// with this trait must provide a `Value getOperandHandle()` function that
// with this trait must provide a `Value getOperandHandle()` function that
// returns the handle to be used for matching.
def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
let cppNamespace = "::mlir::transform";
Expand All @@ -35,7 +52,7 @@ def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
// associated with exactly one payload value. Checks that it is indeed
// the case and produces a definite failure when it is not. The matching logic
// is implemented in the `matchValue` function instead of `apply`. The op
// with this trait must provide a `Value getOperandHandle()` function that
// with this trait must provide a `Value getOperandHandle()` function that
// returns the handle to be used for matching.
def SingleValueMatcher : NativeOpTrait<"SingleValueMatcherOpTrait"> {
let cppNamespace = "::mlir::transform";
Expand Down
204 changes: 103 additions & 101 deletions mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,107 +20,109 @@ def Transform_Dialect : Dialect {

let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
/// Name of the attribute attachable to the symbol table operation
/// containing named sequences. This is used to trigger verification.
constexpr const static ::llvm::StringLiteral
kWithNamedSequenceAttrName = "transform.with_named_sequence";

/// Name of the attribute attachable to an operation so it can be
/// identified as root by the default interpreter pass.
constexpr const static ::llvm::StringLiteral
kTargetTagAttrName = "transform.target_tag";

/// Name of the attribute attachable to an operation, indicating that
/// TrackingListener failures should be silenced.
constexpr const static ::llvm::StringLiteral
kSilenceTrackingFailuresAttrName = "transform.silence_tracking_failures";

/// Names of the attributes indicating whether an argument of an external
/// transform dialect symbol is consumed or only read.
constexpr const static ::llvm::StringLiteral
kArgConsumedAttrName = "transform.consumed";
constexpr const static ::llvm::StringLiteral
kArgReadOnlyAttrName = "transform.readonly";

template <typename DataTy>
const DataTy &getExtraData() const {
return *static_cast<const DataTy *>(extraData.at(::mlir::TypeID::get<DataTy>()).get());
}

/// Parses a type registered by this dialect or one of its extensions.
::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

/// Prints a type registered by this dialect or one of its extensions.
void printType(::mlir::Type type,
::mlir::DialectAsmPrinter &printer) const override;

/// Parser callback for an individual type registered by this dialect or
/// its extensions.
using ExtensionTypeParsingHook = ::mlir::Type (*)(::mlir::AsmParser &);

/// Printer callback for an individual type registered by this dialect or
/// its extensions.
using ExtensionTypePrintingHook =
std::function<void (::mlir::Type, ::mlir::AsmPrinter &)>;

private:
/// Registers operations specified as template parameters with this
/// dialect. Checks that they implement the required interfaces.
template <typename... OpTys>
void addOperationsChecked() {
(addOperationIfNotRegistered<OpTys>(), ...);
}
template <typename OpTy>
void addOperationIfNotRegistered();

/// Reports a repeated registration error of an op with the given name.
[[noreturn]] void reportDuplicateOpRegistration(StringRef opName);

/// Registers the types specified as template parameters with the
/// Transform dialect. Checks that they meet the requirements for
/// Transform IR types.
template <typename... TypeTys>
void addTypesChecked() {
(addTypeIfNotRegistered<TypeTys>(), ...);
}
template <typename Type>
void addTypeIfNotRegistered();

/// Reports a repeated registration error of a type with the given
/// mnemonic.
[[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);

/// Registers dialect types with the context.
void initializeTypes();

// Give extensions access to injection functions.
template <typename, typename...>
friend class TransformDialectExtension;

/// Gets a mutable reference to extra data of the kind specified as
/// template argument. Allocates the data on the first call.
template <typename DataTy>
DataTy &getOrCreateExtraData();

//===----------------------------------------------------------------===//
// Data fields
//===----------------------------------------------------------------===//

/// Additional data associated with and owned by the dialect. Accessible
/// to extensions.
::llvm::DenseMap<::mlir::TypeID, std::unique_ptr<
::mlir::transform::detail::TransformDialectDataBase>>
extraData;

/// A map from type mnemonic to its parsing function for the remainder of
/// the syntax. The parser has access to the mnemonic, so it is used for
/// further dispatch.
::llvm::StringMap<ExtensionTypeParsingHook> typeParsingHooks;

/// A map from type TypeID to its printing function. No need to do string
/// lookups when the type is fully constructed.
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
typePrintingHooks;
/// Name of the attribute attachable to the symbol table operation
/// containing named sequences. This is used to trigger verification.
constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName =
"transform.with_named_sequence";

/// Name of the attribute attachable to an operation so it can be
/// identified as root by the default interpreter pass.
constexpr const static ::llvm::StringLiteral kTargetTagAttrName =
"transform.target_tag";

/// Name of the attribute attachable to an operation, indicating that
/// TrackingListener failures should be silenced.
constexpr const static ::llvm::StringLiteral
kSilenceTrackingFailuresAttrName =
"transform.silence_tracking_failures";

/// Names of the attributes indicating whether an argument of an external
/// transform dialect symbol is consumed or only read.
constexpr const static ::llvm::StringLiteral kArgConsumedAttrName =
"transform.consumed";
constexpr const static ::llvm::StringLiteral kArgReadOnlyAttrName =
"transform.readonly";

template <typename DataTy>
const DataTy &getExtraData() const {
return *static_cast<const DataTy *>(
extraData.at(::mlir::TypeID::get<DataTy>()).get());
}

/// Parses a type registered by this dialect or one of its extensions.
::mlir::Type parseType(::mlir::DialectAsmParser & parser) const override;

/// Prints a type registered by this dialect or one of its extensions.
void printType(::mlir::Type type, ::mlir::DialectAsmPrinter & printer)
const override;

/// Parser callback for an individual type registered by this dialect or
/// its extensions.
using ExtensionTypeParsingHook = ::mlir::Type (*)(::mlir::AsmParser &);

/// Printer callback for an individual type registered by this dialect or
/// its extensions.
using ExtensionTypePrintingHook =
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;

private:
/// Registers operations specified as template parameters with this
/// dialect. Checks that they implement the required interfaces.
template <typename... OpTys>
void addOperationsChecked() {
(addOperationIfNotRegistered<OpTys>(), ...);
}
template <typename OpTy>
void addOperationIfNotRegistered();

/// Reports a repeated registration error of an op with the given name.
[[noreturn]] void reportDuplicateOpRegistration(StringRef opName);

/// Registers types specified as template parameters with the Transform
/// dialect. Checks that they meet the requirements for Transform IR types.
template <typename... TypeTys>
void addTypesChecked() {
(addTypeIfNotRegistered<TypeTys>(), ...);
}
template <typename Type>
void addTypeIfNotRegistered();

/// Reports a repeated registration error of a type with the given
/// mnemonic.
[[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);

/// Registers dialect types with the context.
void initializeTypes();

// Give extensions access to injection functions.
template <typename, typename...>
friend class TransformDialectExtension;

/// Gets a mutable reference to extra data of the kind specified as
/// template argument. Allocates the data on the first call.
template <typename DataTy>
DataTy &getOrCreateExtraData();

//===----------------------------------------------------------------===//
// Data fields
//===----------------------------------------------------------------===//

/// Additional data associated with and owned by the dialect. Accessible
/// to extensions.
::llvm::DenseMap<
::mlir::TypeID,
std::unique_ptr<::mlir::transform::detail::TransformDialectDataBase>>
extraData;

/// A map from type mnemonic to its parsing function for the remainder of
/// the syntax. The parser has access to the mnemonic, so it is used for
/// further dispatch.
::llvm::StringMap<ExtensionTypeParsingHook> typeParsingHooks;

/// A map from type TypeID to its printing function. No need to do string
/// lookups when the type is fully constructed.
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
typePrintingHooks;
}];
}

Expand Down
Loading