Skip to content

Commit

Permalink
[Codegen] Add pass to materialize tuning specs
Browse files Browse the repository at this point in the history
... and update 'Materialize User Configs' to pick up those tuning specs.

The overall flow is as follows:
* We pick up any user-specified tuning specs in `materialize tuning
  specs` and link them into a single transform dialect library module.
* We serialize that linked tuning spec as MLIR bytecode.
* We embed this MLIR bytecode as a module attribute. This is so that
  none of the subsequent passes will accidentally `walk` or otherwise
  modify it.
* In `materilize user configs`, we first check if there are any
  transform libraries provided. If not, then we check if the tuning spec
  is present.
* We deserialize the tuning spec attribute into a transform dialect
  library module and execute it.
* We remove the serialized tuning spec from the module, as it's no
  longer needed.

Signed-off-by: Jakub Kuderski <[email protected]>
  • Loading branch information
kuhar committed Nov 30, 2024
1 parent 1684c56 commit d3cd5b2
Show file tree
Hide file tree
Showing 17 changed files with 417 additions and 36 deletions.
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ iree_compiler_cc_library(
"LowerUKernelsToCalls.cpp",
"MaterializeEncodingIntoNop.cpp",
"MaterializeEncodingIntoPackUnPack.cpp",
"MaterializeTuningSpecsPass.cpp",
"MemrefCopyToLinalg.cpp",
"NormalizeLoopBounds.cpp",
"OptimizeTensorInsertExtractSlices.cpp",
Expand Down Expand Up @@ -201,6 +202,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:BufferizationInterfaces",
"@llvm-project//mlir:BufferizationTransforms",
"@llvm-project//mlir:BytecodeWriter",
"@llvm-project//mlir:DestinationStyleOpInterface",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
Expand All @@ -219,6 +221,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:MemRefUtils",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFToControlFlow",
Expand Down Expand Up @@ -284,6 +287,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:PDLDialect",
"@llvm-project//mlir:PDLInterpDialect",
"@llvm-project//mlir:SCFDialect",
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ iree_cc_library(
"LowerUKernelsToCalls.cpp"
"MaterializeEncodingIntoNop.cpp"
"MaterializeEncodingIntoPackUnPack.cpp"
"MaterializeTuningSpecsPass.cpp"
"MemrefCopyToLinalg.cpp"
"NormalizeLoopBounds.cpp"
"OptimizeTensorInsertExtractSlices.cpp"
Expand Down Expand Up @@ -163,6 +164,7 @@ iree_cc_library(
MLIRArithUtils
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRBytecodeWriter
MLIRDestinationStyleOpInterface
MLIRFuncDialect
MLIRFuncTransforms
Expand All @@ -180,6 +182,7 @@ iree_cc_library(
MLIRMemRefDialect
MLIRMemRefTransforms
MLIRMemRefUtils
MLIRParser
MLIRPass
MLIRSCFDialect
MLIRSCFToControlFlow
Expand Down Expand Up @@ -257,6 +260,7 @@ iree_cc_library(
MLIRMemRefTransformOps
MLIRPDLDialect
MLIRPDLInterpDialect
MLIRParser
MLIRPass
MLIRRewrite
MLIRSCFDialect
Expand Down
47 changes: 27 additions & 20 deletions compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ findNestedModulesWithNamedSequences(ModuleOp module) {
static SmallVector<NamedSequenceOp> findTuningSpecs(ModuleOp module) {
Block *body = module.getBody();
return llvm::filter_to_vector(
body->getOps<NamedSequenceOp>(),
[](NamedSequenceOp op) { return op->hasAttr(kTuningSpecAttrName); });
body->getOps<NamedSequenceOp>(), [](NamedSequenceOp op) {
return op->hasAttr(kTuningSpecEntrypointAttrName);
});
}

static LogicalResult validateTuningSpec(NamedSequenceOp op) {
Expand Down Expand Up @@ -85,7 +86,7 @@ emitLinkedTuningSpec(ModuleOp module, ArrayRef<NamedSequenceOp> specsToLink) {
/*res_attrs*/ ArrayAttr{});
newSpec.setArgAttr(0, transform::TransformDialect::kArgReadOnlyAttrName,
builder.getUnitAttr());
newSpec->setAttr(kTuningSpecAttrName, builder.getUnitAttr());
newSpec->setAttr(kTuningSpecEntrypointAttrName, builder.getUnitAttr());

Region &region = newSpec.getRegion();
Block *body = builder.createBlock(&region, region.begin(),
Expand Down Expand Up @@ -122,28 +123,34 @@ struct LinkTuningSpecsPass final
}

void runOnOperation() override {
ModuleOp module = getOperation();
SmallVector<NamedSequenceOp> tuningSpecs;

for (ModuleOp nested : findNestedModulesWithNamedSequences(module)) {
llvm::append_range(tuningSpecs, findTuningSpecs(nested));
if (failed(linkTuningSpecs(getOperation()))) {
signalPassFailure();
}
}
};

for (NamedSequenceOp spec : tuningSpecs) {
LDBG("Found tuning spec: " << spec.getSymName());
if (failed(validateTuningSpec(spec))) {
return signalPassFailure();
}
}
} // namespace

FailureOr<NamedSequenceOp> linkTuningSpecs(ModuleOp module) {
SmallVector<NamedSequenceOp> tuningSpecs;

if (tuningSpecs.empty()) {
LDBG("No tuning specs found, exiting without linking");
return;
for (ModuleOp nested : findNestedModulesWithNamedSequences(module)) {
llvm::append_range(tuningSpecs, findTuningSpecs(nested));
}

for (NamedSequenceOp spec : tuningSpecs) {
LDBG("Found tuning spec: " << spec.getSymName());
if (failed(validateTuningSpec(spec))) {
return failure();
}
}

emitLinkedTuningSpec(module, tuningSpecs);
if (tuningSpecs.empty()) {
LDBG("No tuning specs found, exiting without linking");
return NamedSequenceOp{};
}
};

} // namespace
return emitLinkedTuningSpec(module, tuningSpecs);
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cassert>
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Support/FileUtilities.h"

#define DEBUG_TYPE "iree-codegen-materialize-tuning-specs"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_MATERIALIZETUNINGSPECSPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

llvm::cl::opt<std::string> clCodegenTuningSpecPath(
"iree-codegen-tuning-spec-path",
llvm::cl::desc("File path to a module containing a tuning spec (transform "
"dialect library)."),
llvm::cl::init(""));

llvm::cl::opt<std::string> clCodegenTuningSpecDumpDir(
"iree-codegen-dump-tuning-specs-to",
llvm::cl::desc(
"Dump the final tuning spec modules to the specified directory. When "
"set to '-', prints the tuning spec to stdout."),
llvm::cl::init(""));

using mlir::transform::NamedSequenceOp;

static LogicalResult dumpFinalTuningSpecToDir(ModuleOp tuningSpec,
StringRef dir) {
if (dir == "-") {
tuningSpec->print(llvm::outs());
return success();
}

llvm::sys::fs::create_directories(dir);
llvm::SmallString<64> dumpPath;
auto dumpFileEC = llvm::sys::fs::createUniqueFile(
Twine(dir) + "/iree_tuning_spec_%%.mlir", dumpPath);
if (dumpFileEC) {
return tuningSpec->emitError()
<< "Failed to create a unique file in " << dir << "\n";
}
LDBG("Linked tuning spec file path: " << dumpPath);

std::string error;
auto file = mlir::openOutputFile(dumpPath, &error);
if (!file) {
return tuningSpec->emitError()
<< "Failed to open a tuning spec dump file " << dumpPath << "\n";
}

tuningSpec->print(file->os());
file->keep();
return success();
}

static FailureOr<DenseElementsAttr>
serializeTuningSpecToAttr(ModuleOp tuningSpec) {
std::string buffer;
llvm::raw_string_ostream os(buffer);
if (failed(writeBytecodeToFile(tuningSpec, os))) {
return failure();
}

auto bufferSize = static_cast<int64_t>(buffer.size());
auto bufferShape = VectorType::get(
bufferSize, IntegerType::get(tuningSpec->getContext(), 8));
return DenseElementsAttr::getFromRawBuffer(
bufferShape, ArrayRef(buffer.data(), buffer.data() + bufferSize));
}

struct MaterializeTuningSpecsPass final
: impl::MaterializeTuningSpecsPassBase<MaterializeTuningSpecsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registerTransformDialectTranslationDependentDialects(registry);
}

void runOnOperation() override {
if (clCodegenTuningSpecPath.empty()) {
return;
}

ModuleOp module = getOperation();
MLIRContext *ctx = &getContext();
auto dialect = ctx->getOrLoadDialect<IREE::Codegen::IREECodegenDialect>();
auto maybeTransformLibrary =
dialect->getOrLoadTransformLibraryModule(clCodegenTuningSpecPath);
if (failed(maybeTransformLibrary)) {
module->emitError()
<< "Failed to load tuning spec transform dialect library from "
<< clCodegenTuningSpecPath;
return signalPassFailure();
}

ModuleOp userTuningSpec = *maybeTransformLibrary;
if (!userTuningSpec.getSymName()) {
// Set a module name so that we can refer to its nested symbols.
userTuningSpec.setSymName("iree_user_tuning_spec");
}

Location loc = userTuningSpec.getLoc();

// This module will always be released at the end of the pass.
OwningOpRef<ModuleOp> linkedTuningSpec(
ModuleOp::create(loc, "iree_linked_tuning_spec"));
linkedTuningSpec.get()->setAttr(
transform::TransformDialect::kWithNamedSequenceAttrName,
UnitAttr::get(ctx));
linkedTuningSpec->insert(linkedTuningSpec->begin(), userTuningSpec.clone());

// TODO(https://github.com/iree-org/iree/issues/19214): Add linked tuning
// spec memoization to IREECodegenDialect. We should be able to provide a
// list of input libraries that may have already been linked and ask the
// dialect to return it to us, or invoke a callback that will insert it if
// not found.
FailureOr<transform::NamedSequenceOp> newEntrypoint =
linkTuningSpecs(linkedTuningSpec.get());
if (failed(newEntrypoint)) {
module->emitError("Failed to link tuning specs");
return signalPassFailure();
}

if (!clCodegenTuningSpecDumpDir.empty()) {
if (failed(dumpFinalTuningSpecToDir(linkedTuningSpec.get(),
clCodegenTuningSpecDumpDir))) {
return signalPassFailure();
}
}

FailureOr<DenseElementsAttr> serializedSpec =
serializeTuningSpecToAttr(linkedTuningSpec.get());
if (failed(serializedSpec)) {
module->emitError("Failed to serialize linked tuning specs");
return signalPassFailure();
}
module->setAttr(kSerializedTuningSpecAttrName, *serializedSpec);
}
};

} // namespace
} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cassert>
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/UserConfig.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"

#define DEBUG_TYPE "iree-codegen-materialize-user-configs"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
Expand Down Expand Up @@ -110,6 +112,40 @@ getTransformLibraryFromPath(ModuleOp compiledModule, StringRef path) {
entrySequenceName.str()};
}

// Look up the tuning spec in the given module or any of its parents.
static LogicalResult getModuleTuningSpec(ModuleOp compiledModule,
OwningOpRef<ModuleOp> &tuningSpec) {
IREE::Util::SerializableAttrInterface serializedTuningSpec;
Operation *op = compiledModule;
while (!serializedTuningSpec && op) {
serializedTuningSpec =
op->getAttrOfType<IREE::Util::SerializableAttrInterface>(
kSerializedTuningSpecAttrName);
op = op->getParentOp();
}

if (!serializedTuningSpec) {
return failure();
}

SmallVector<char, 0> bytecode;
if (failed(serializedTuningSpec.serializeToVector(
compiledModule->getLoc(), llvm::endianness::native, bytecode))) {
return compiledModule.emitError()
<< "Failed to read attribute " << kSerializedTuningSpecAttrName;
}

ParserConfig config(compiledModule.getContext());
tuningSpec = parseSourceString<ModuleOp>(
StringRef(bytecode.data(), bytecode.size()), config);
if (!tuningSpec) {
return compiledModule.emitError() << "Failed to parse tuning spec in "
<< kSerializedTuningSpecAttrName;
}
LDBG("--loaded tuning spec");
return success();
}

struct MaterializeUserConfigsPass final
: impl::MaterializeUserConfigsPassBase<MaterializeUserConfigsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
Expand All @@ -119,9 +155,28 @@ struct MaterializeUserConfigsPass final
void runOnOperation() override {
ModuleOp moduleOp = getOperation();

// Try to load the transform library from the user flag first. If none is
// specified, fall back to using the module tuning spec.
FailureOr<TransformLibraryWithEntrypoint> userTransformLibrary =
getTransformLibraryFromPath(moduleOp,
clCodegenTransformDialectLibraryFileName);
OwningOpRef<ModuleOp> tuningSpec;
if (failed(userTransformLibrary)) {
if (succeeded(getModuleTuningSpec(moduleOp, tuningSpec))) {
assert(tuningSpec);
userTransformLibrary = TransformLibraryWithEntrypoint{
tuningSpec.get(), kKernelConfigSpecName.str()};
}
}

// Remove the tuning spec, if any, from the current module. If the tuning
// spec is attached to some other parent op, we conservatively keep it
// as-is, as we are not sure who the producer is and if they want it
// removed.
if (moduleOp->hasAttr(kSerializedTuningSpecAttrName)) {
moduleOp->removeAttr(kSerializedTuningSpecAttrName);
LDBG("--dropped the serialized tuning spec from the module");
}

for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {

Expand Down
Loading

0 comments on commit d3cd5b2

Please sign in to comment.