Skip to content

Commit

Permalink
[mlir][Transform] Provide a minimal set of utils that allow implement…
Browse files Browse the repository at this point in the history
…ing a simple transform dialect interpreter pass
  • Loading branch information
nicolasvasilache committed Oct 6, 2023
1 parent a16f646 commit cdfa540
Show file tree
Hide file tree
Showing 9 changed files with 577 additions and 289 deletions.
29 changes: 27 additions & 2 deletions mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@ def Transform_Dialect : Dialect {

let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
/// Symbol name for the default entry point "named sequence".
constexpr const static ::llvm::StringLiteral
kTransformEntryPointSymbolName = "__transform_main";

/// Name of the attribute attachable to the symbol table operation
/// containing named sequences. This is used to trigger verification.
constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName =
"transform.with_named_sequence";
constexpr const static ::llvm::StringLiteral
kWithNamedSequenceAttrName = "transform.with_named_sequence";

/// Name of the attribute attachable to an operation so it can be
/// identified as root by the default interpreter pass.
Expand Down Expand Up @@ -74,6 +78,22 @@ def Transform_Dialect : Dialect {
using ExtensionTypePrintingHook =
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;

/// Appends the given module as a transform symbol library available to
/// all dialect users.
void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
library) {
libraryModules.push_back(std::move(library));
}

/// Returns a range of registered library modules.
auto getLibraryModules() const {
return ::llvm::map_range(
libraryModules,
[](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) {
return library.get();
});
}

private:
/// Registers operations specified as template parameters with this
/// dialect. Checks that they implement the required interfaces.
Expand Down Expand Up @@ -132,6 +152,11 @@ def Transform_Dialect : Dialect {
/// lookups when the type is fully constructed.
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
typePrintingHooks;

/// Modules containing symbols, e.g. named sequences, that will be
/// resolved by the interpreter when used.
::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
libraryModules;
}];
}

Expand Down
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ class TransformOptions {
LogicalResult
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping = {},
const TransformOptions &options = TransformOptions());
const TransformOptions &options = TransformOptions(),
bool enforceToplevelTransformOp = true);

/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
Expand Down Expand Up @@ -193,7 +194,7 @@ class TransformState {

friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
const RaggedArray<MappedValue> &,
const TransformOptions &);
const TransformOptions &, bool);

friend TransformState
detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
//===- TransformInterpreterUtils.h - Transform Utils ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H

#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include <memory>

namespace mlir {
struct LogicalResult;
class MLIRContext;
class ModuleOp;
class Operation;
template <typename>
class OwningOpRef;
class Region;

namespace transform {
namespace detail {
/// Utility to parse and verify the content of a `transformFileName` MLIR file
/// containing a transform dialect specification.
LogicalResult
parseTransformModuleFromFile(MLIRContext *context,
llvm::StringRef transformFileName,
OwningOpRef<ModuleOp> &transformModule);

/// Utility to load a transform interpreter `module` from a module that has
/// already been preloaded in the context.
/// This mode is useful in cases where explicit parsing of a transform library
/// from file is expected to be prohibitively expensive.
/// In such cases, the transform module is expected to be found in the preloaded
/// library modules of the transform dialect.
/// Returns null if the module is not found.
ModuleOp getPreloadedTransformModule(MLIRContext *context);

/// Finds the first TransformOpInterface named `kTransformEntryPointSymbolName`
/// that is either:
/// 1. nested under `root` (takes precedence).
/// 2. nested under `module`, if not found in `root`.
/// Reports errors and returns null if no such operation found.
TransformOpInterface findTransformEntryPoint(
Operation *root, ModuleOp module,
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);

/// Merge all symbols from `other` into `target`. Both ops need to implement the
/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
/// modified by this function and might not verify after the function returns.
/// Upon merging, private symbols may be renamed in order to avoid collisions in
/// the result. Public symbols may not collide, with the exception of
/// instances of `SymbolOpInterface`, where collisions are allowed if at least
/// one of the two is external, in which case the other op preserved (or any one
/// of the two if both are external).
// TODO: Reconsider cloning individual ops rather than forcing users of the
// function to clone (or move) `other` in order to improve efficiency.
// This might primarily make sense if we can also prune the symbols that
// are merged to a subset (such as those that are actually used).
LogicalResult mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other);
} // namespace detail

/// Standalone util to apply the named sequence `entryPoint` to the payload.
/// This is done in 3 steps:
/// 1. lookup the `entryPoint` symbol in `{payload, sharedTransformModule}` by
/// calling detail::findTransformEntryPoint.
/// 2. if the entry point is found and not nested under
/// `sharedTransformModule`, call `detail::defineDeclaredSymbols` to "link" in
/// the `sharedTransformModule`. Note: this may modify the transform IR
/// embedded with the payload IR.
/// 3. apply the transform IR to the payload IR, relaxing the requirement that
/// the transform IR is a top-level transform op. We are applying a named
/// sequence anyway.
LogicalResult applyTransformNamedSequence(
Operation *payload, ModuleOp transformModule,
const TransformOptions &options,
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);

} // namespace transform
} // namespace mlir

#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
26 changes: 13 additions & 13 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2079,20 +2079,20 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
// Entry point.
//===----------------------------------------------------------------------===//

LogicalResult
transform::applyTransforms(Operation *payloadRoot,
TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping,
const TransformOptions &options) {
#ifndef NDEBUG
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
transform->getNumOperands() != 0) {
transform->emitError()
<< "expected transform to start at the top-level transform op";
llvm::report_fatal_error("could not run transforms",
/*gen_crash_diag=*/false);
LogicalResult transform::applyTransforms(
Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping,
const TransformOptions &options, bool enforceToplevelTransformOp) {
if (enforceToplevelTransformOp) {
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
transform->getNumOperands() != 0) {
return transform->emitError()
<< "expected transform to start at the top-level transform op";
}
} else if (failed(
detail::verifyPossibleTopLevelTransformOpTrait(transform))) {
return failure();
}
#endif // NDEBUG

TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
options);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTransformDialectTransforms
CheckUses.cpp
InferEffects.cpp
TransformInterpreterPassBase.cpp
TransformInterpreterUtils.cpp

DEPENDS
MLIRTransformDialectTransformsIncGen
Expand Down
Loading

0 comments on commit cdfa540

Please sign in to comment.