Skip to content

Commit

Permalink
[Codegen] Load transform library only once in MaterializeUserConfigs (i…
Browse files Browse the repository at this point in the history
…ree-org#19313)

Hoist the library loading logic out of the loop that configures
functions.

This is in preparation for adding tuning spec loading from a new module
attr.

Issue: iree-org#19214
  • Loading branch information
kuhar authored and Groverkss committed Nov 29, 2024
1 parent e140ad7 commit 2b26b54
Showing 1 changed file with 61 additions and 51 deletions.
112 changes: 61 additions & 51 deletions compiler/src/iree/compiler/Codegen/Common/MaterializeUserConfigs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
#include "iree/compiler/Codegen/Common/UserConfig.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"

#define DEBUG_TYPE "iree-codegen-materialize-user-configs"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
Expand Down Expand Up @@ -61,15 +65,64 @@ runTransformConfigurationStrategy(Operation *payloadRoot,
return StrategyRunResult::Success;
}

struct TransformLibraryWithEntrypoint {
ModuleOp transformLibrary;
std::string entrypointName;
};

static FailureOr<TransformLibraryWithEntrypoint>
getTransformLibraryFromPath(ModuleOp compiledModule, StringRef path) {
SmallVector<StringRef, 2> parts;
llvm::SplitString(path, parts, "@");
if (parts.empty()) {
return failure();
}
if (parts.size() > 2) {
return compiledModule.emitError()
<< "Invalid transform library path and sequence name " << path;
}
StringRef libraryFileName = parts[0];
StringRef entrySequenceName = kKernelConfigSpecName;
if (parts.size() == 2) {
entrySequenceName = parts[1];
}

// Validate both the file name and the spec name.
if (libraryFileName.empty()) {
return compiledModule.emitError() << "Cannot specify an empty library path";
}
if (entrySequenceName.empty()) {
return compiledModule.emitError()
<< "Cannot specify an empty sequence name";
}

MLIRContext *ctx = compiledModule->getContext();
auto dialect = ctx->getOrLoadDialect<IREE::Codegen::IREECodegenDialect>();
auto maybeTransformLibrary =
dialect->getOrLoadTransformLibraryModule(libraryFileName.str());
if (failed(maybeTransformLibrary)) {
return compiledModule.emitError()
<< "Failed to load transform library module: " << libraryFileName;
}
LDBG("--found transform library " << libraryFileName << "@"
<< entrySequenceName);
return TransformLibraryWithEntrypoint{*maybeTransformLibrary,
entrySequenceName.str()};
}

struct MaterializeUserConfigsPass final
: impl::MaterializeUserConfigsPassBase<MaterializeUserConfigsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registerTransformDialectTranslationDependentDialects(registry);
}

void runOnOperation() override {
auto moduleOp = getOperation();
MLIRContext *context = &getContext();
ModuleOp moduleOp = getOperation();

FailureOr<TransformLibraryWithEntrypoint> userTransformLibrary =
getTransformLibraryFromPath(moduleOp,
clCodegenTransformDialectLibraryFileName);

for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {

// Parse the file path and kernel config strategy from flags. There are
Expand All @@ -84,54 +137,11 @@ struct MaterializeUserConfigsPass final
// "translation_info" =
// #iree_codegen.translation_info<pipeline = None>
// ```
SmallVector<StringRef, 2> parts;
llvm::SplitString(
llvm::StringRef(clCodegenTransformDialectLibraryFileName), parts,
"@");
if (parts.size() > 2) {
funcOp.emitError()
<< "Invalid transform library path and sequence name "
<< clCodegenTransformDialectLibraryFileName;
return signalPassFailure();
}
bool hasTransformLibrary = !parts.empty();

std::string libraryFileName;
if (hasTransformLibrary) {
if (parts[0].empty()) {
funcOp.emitError() << "Cannot specify an empty library path";
return signalPassFailure();
}
libraryFileName = parts[0];
}

StringRef entrySequenceName = kKernelConfigSpecName;
// Check if the user specified a custom entry point name.
if (parts.size() == 2) {
if (parts[1].empty()) {
funcOp.emitError() << "Cannot specify an empty sequence name";
return signalPassFailure();
}
entrySequenceName = parts[1];
}

LDBG("MaterializeUserConfigsPass on function: " << funcOp);
std::optional<ModuleOp> transformLibrary = std::nullopt;
if (hasTransformLibrary) {
auto dialect =
context->getOrLoadDialect<IREE::Codegen::IREECodegenDialect>();
auto maybeTransformLibrary =
dialect->getOrLoadTransformLibraryModule(libraryFileName);
if (failed(maybeTransformLibrary)) {
funcOp.emitError()
<< "failed to load transform library module: " << libraryFileName;
return signalPassFailure();
}
transformLibrary = *maybeTransformLibrary;
LDBG("--found transform library @" << libraryFileName);

if (succeeded(userTransformLibrary)) {
StringRef entrySequenceName = userTransformLibrary->entrypointName;
auto runResult = runTransformConfigurationStrategy(
funcOp, entrySequenceName, *transformLibrary);
funcOp, entrySequenceName, userTransformLibrary->transformLibrary);
if (runResult == StrategyRunResult::NotFound) {
funcOp.emitError() << "transform kernel config strategy `"
<< entrySequenceName << " not found";
Expand Down Expand Up @@ -186,9 +196,9 @@ struct MaterializeUserConfigsPass final
/// If we have a symbol, verify the existence of the symbol within the
/// transform library.
StringRef entryPoint = strategyName->getLeafReference();
if (!transformLibrary || !(*transformLibrary) ||
!transform::detail::findTransformEntryPoint(funcOp, *transformLibrary,
entryPoint)) {
if (failed(userTransformLibrary) ||
!transform::detail::findTransformEntryPoint(
funcOp, userTransformLibrary->transformLibrary, entryPoint)) {
funcOp.emitOpError("failed to find transform strategy symbol");
}
}
Expand Down

0 comments on commit 2b26b54

Please sign in to comment.