From ddad4ecdffad5eec7b8a5822859a6dcae2a57e21 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 19 Feb 2024 13:22:35 -0800 Subject: [PATCH 01/25] Adding #hal.device.select and related attributes. These allow for device globals to be identified and initialized from available runtime devices. The new InitializeDevicesPass finds globals with the attributes set and builds the appropriate initializers as part of the HAL pipeline. --- .../LLVMCPU/test/smoketest_embedded.mlir | 4 +- .../target/LLVMCPU/test/smoketest_system.mlir | 4 +- .../iree/compiler/Dialect/HAL/IR/HALAttrs.cpp | 450 +++++++++++++++++- .../iree/compiler/Dialect/HAL/IR/HALAttrs.td | 244 +++++++++- .../compiler/Dialect/HAL/IR/HALInterfaces.td | 51 +- .../iree/compiler/Dialect/HAL/IR/HALOps.cpp | 14 +- .../iree/compiler/Dialect/HAL/IR/HALTypes.h | 7 + .../Dialect/HAL/IR/test/attributes.mlir | 42 +- .../HAL/Target/Devices/LocalDevice.cpp | 2 +- .../Dialect/HAL/Target/TargetDevice.cpp | 51 +- .../Dialect/HAL/Target/TargetDevice.h | 15 - .../Dialect/HAL/Transforms/BUILD.bazel | 1 + .../Dialect/HAL/Transforms/CMakeLists.txt | 1 + .../HAL/Transforms/InitializeDevices.cpp | 111 +++++ .../Dialect/HAL/Transforms/Passes.cpp | 16 +- .../compiler/Dialect/HAL/Transforms/Passes.td | 42 +- .../Dialect/HAL/Transforms/test/BUILD.bazel | 1 + .../HAL/Transforms/test/CMakeLists.txt | 1 + .../HAL/Transforms/test/convert_to_hal.mlir | 7 +- .../test/dump_executable_sources.mlir | 69 ++- .../Transforms/test/initialize_devices.mlir | 106 +++++ runtime/src/iree/hal/device.c | 8 - 22 files changed, 1081 insertions(+), 166 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/InitializeDevices.cpp create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/test/initialize_devices.mlir diff --git a/compiler/plugins/target/LLVMCPU/test/smoketest_embedded.mlir b/compiler/plugins/target/LLVMCPU/test/smoketest_embedded.mlir index 493a3c735f93..e772c4da3f86 100644 --- a/compiler/plugins/target/LLVMCPU/test/smoketest_embedded.mlir +++ b/compiler/plugins/target/LLVMCPU/test/smoketest_embedded.mlir @@ -4,7 +4,9 @@ module attributes { hal.device.targets = [ #hal.device.target<"llvm-cpu", [ - #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", { native_vector_size = 16 : index }> + #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", { + native_vector_size = 16 : index + }> ]> ] } { diff --git a/compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir b/compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir index bb5c607e2765..6e7f8d5327fc 100644 --- a/compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir +++ b/compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir @@ -6,7 +6,9 @@ module attributes { hal.device.targets = [ #hal.device.target<"llvm-cpu", [ - #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64",{ native_vector_size = 16 : index } > + #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", { + native_vector_size = 16 : index + }> ]> ] } { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp index 1ff5f24d1f7e..f5ac50deb4d8 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp @@ -12,6 +12,7 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/Path.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Parser/Parser.h" // clang-format off: must be included after all LLVM/MLIR headers. @@ -23,7 +24,7 @@ namespace mlir::iree_compiler::IREE::HAL { //===----------------------------------------------------------------------===// -// Enum utilities +// Utilities //===----------------------------------------------------------------------===// template @@ -348,6 +349,80 @@ DeviceTargetAttr::lookupExecutableTargets(Operation *op) { return resultAttrs; } +void IREE::HAL::DeviceTargetAttr::printStatusDescription( + llvm::raw_ostream &os) const { + cast().print(os, /*elideType=*/true); +} + +// Produces a while-loop that enumerates each device available and tries to +// match it against the target information. SCF is... not very wieldy, but this +// is effectively: +// ``` +// %device_count = hal.devices.count : index +// %result:2 = scf.while(%i = 0, %device = null) { +// %is_null = util.cmp.eq %device, null : !hal.device +// %in_bounds = arith.cmpi slt %i, %device_count : index +// %continue_while = arith.andi %is_null, %in_bounds : i1 +// scf.condition(%continue_while) %i, %device : index, !hal.device +// } do { +// %device_i = hal.devices.get %i : !hal.device +// %is_match = <>(%device_i) +// %try_device = arith.select %is_match, %device_i, null : !hal.device +// %next_i = arith.addi %i, %c1 : index +// scf.yield %next_i, %try_device : index, !hal.device +// } +// ``` +// Upon completion %result#1 contains the device (or null). +Value IREE::HAL::DeviceTargetAttr::buildDeviceEnumeration( + Location loc, const IREE::HAL::TargetRegistry &targetRegistry, + OpBuilder &builder) const { + // Defers to the target backend to build the device match or does a simple + // fallback for unregistered backends (usually for testing, but may be used + // as a way to bypass validation for out-of-tree experiments). + auto buildDeviceMatch = [&](Location loc, Value device, + OpBuilder &builder) -> Value { + // Ask the target backend to build the match expression. It may opt to + // let the default handling take care of things. + Value match; + auto targetDevice = targetRegistry.getTargetDevice(getDeviceID()); + if (targetDevice) + match = targetDevice->buildDeviceTargetMatch(loc, device, *this, builder); + if (match) + return match; + return buildDeviceIDAndExecutableFormatsMatch( + loc, device, getDeviceID(), getExecutableTargets(), builder); + }; + + // Enumerate all devices and match the first one found (if any). + Type indexType = builder.getIndexType(); + Type deviceType = builder.getType(); + Value c0 = builder.create(loc, 0); + Value c1 = builder.create(loc, 1); + Value nullDevice = builder.create(loc, deviceType); + Value deviceCount = builder.create(loc, indexType); + auto whileOp = builder.create( + loc, TypeRange{indexType, deviceType}, ValueRange{c0, nullDevice}, + [&](OpBuilder &beforeBuilder, Location loc, ValueRange operands) { + Value isNull = beforeBuilder.create( + loc, operands[1], nullDevice); + Value inBounds = beforeBuilder.create( + loc, arith::CmpIPredicate::slt, operands[0], deviceCount); + Value continueWhile = + beforeBuilder.create(loc, isNull, inBounds); + beforeBuilder.create(loc, continueWhile, operands); + }, + [&](OpBuilder &afterBuilder, Location loc, ValueRange operands) { + Value device = afterBuilder.create( + loc, deviceType, operands[0]); + Value isMatch = buildDeviceMatch(loc, device, afterBuilder); + Value tryDevice = afterBuilder.create( + loc, isMatch, device, nullDevice); + Value nextI = afterBuilder.create(loc, operands[0], c1); + afterBuilder.create(loc, ValueRange{nextI, tryDevice}); + }); + return whileOp.getResult(1); +} + //===----------------------------------------------------------------------===// // #hal.executable.target<*> //===----------------------------------------------------------------------===// @@ -673,6 +748,379 @@ std::optional ExecutableObjectsAttr::getApplicableObjects( return ArrayAttr::get(specificTargetAttr.getContext(), allObjectAttrs); } +//===----------------------------------------------------------------------===// +// #hal.device.alias<*> +//===----------------------------------------------------------------------===// + +// static +DeviceAliasAttr DeviceAliasAttr::get(MLIRContext *context, StringRef deviceID) { + return get(context, IREE::HAL::DeviceType::get(context), + StringAttr::get(context, deviceID), std::nullopt, + DictionaryAttr::get(context)); +} + +//===----------------------------------------------------------------------===// +// #hal.device.target<*> +//===----------------------------------------------------------------------===// + +// static +DeviceTargetAttr DeviceTargetAttr::get(MLIRContext *context, + StringRef deviceID) { + // TODO(benvanik): query default configuration from the target backend. + return get(context, StringAttr::get(context, deviceID), + DictionaryAttr::get(context), {}); +} + +// static +Attribute DeviceTargetAttr::parse(AsmParser &p, Type type) { + StringAttr deviceIDAttr; + DictionaryAttr configAttr; + SmallVector executableTargetAttrs; + // `<"device-id"` + if (failed(p.parseLess()) || failed(p.parseAttribute(deviceIDAttr))) { + return {}; + } + // `, ` + if (succeeded(p.parseOptionalComma())) { + if (succeeded(p.parseOptionalLSquare())) { + // `[targets, ...]` (optional) + do { + IREE::HAL::ExecutableTargetAttr executableTargetAttr; + if (failed(p.parseAttribute(executableTargetAttr))) + return {}; + executableTargetAttrs.push_back(executableTargetAttr); + } while (succeeded(p.parseOptionalComma())); + if (failed(p.parseRSquare())) + return {}; + } else { + // `{config dict}` (optional) + if (failed(p.parseAttribute(configAttr))) + return {}; + // `, [targets, ...]` (optional) + if (succeeded(p.parseOptionalComma())) { + if (failed(p.parseLSquare())) + return {}; + do { + IREE::HAL::ExecutableTargetAttr executableTargetAttr; + if (failed(p.parseAttribute(executableTargetAttr))) + return {}; + executableTargetAttrs.push_back(executableTargetAttr); + } while (succeeded(p.parseOptionalComma())); + if (failed(p.parseRSquare())) + return {}; + } + } + } + // `>` + if (failed(p.parseGreater())) { + return {}; + } + return get(p.getContext(), deviceIDAttr, configAttr, executableTargetAttrs); +} + +void DeviceTargetAttr::print(AsmPrinter &p) const { + auto &os = p.getStream(); + os << "<"; + p.printAttribute(getDeviceID()); + auto configAttr = getConfiguration(); + if (configAttr && !configAttr.empty()) { + os << ", "; + p.printAttribute(configAttr); + } + auto executableTargetAttrs = getExecutableTargets(); + if (!executableTargetAttrs.empty()) { + os << ", ["; + llvm::interleaveComma(executableTargetAttrs, os, + [&](auto executableTargetAttr) { + p.printAttribute(executableTargetAttr); + }); + os << "]"; + } + os << ">"; +} + +std::string DeviceTargetAttr::getSymbolNameFragment() { + return sanitizeSymbolName(getDeviceID().getValue().lower()); +} + +bool DeviceTargetAttr::hasConfigurationAttr(StringRef name) { + auto configAttr = getConfiguration(); + return configAttr && configAttr.get(name); +} + +void DeviceTargetAttr::getExecutableTargets( + SetVector &resultAttrs) { + for (auto attr : getExecutableTargets()) { + resultAttrs.insert(attr); + } +} + +void IREE::HAL::DeviceTargetAttr::printStatusDescription( + llvm::raw_ostream &os) const { + mlir::cast(this)->print(os, /*elideType=*/true); +} + +// Produces a while-loop that enumerates each device available and tries to +// match it against the target information. SCF is... not very wieldy, but this +// is effectively: +// ``` +// %device_count = hal.devices.count : index +// %result:3 = scf.while(%i = 0, %match_ordinal = 0, %device = null) { +// %is_null = util.cmp.eq %device, null : !hal.device +// %in_bounds = arith.cmpi slt %i, %device_count : index +// %continue_while = arith.andi %is_null, %in_bounds : i1 +// scf.condition(%continue_while) %i, %match_ordinal %device +// : index, index, !hal.device +// } do { +// %device_i = hal.devices.get %i : !hal.device +// %device_match = <>(%device_i) +// %ordinal_match = arith.cmpi eq %match_ordinal, %device_ordinal : index +// %is_match = arith.andi %device_match, %ordinal_match : i1 +// %try_device = arith.select %is_match, %device_i, null : !hal.device +// %next_i = arith.addi %i, %c1 : index +// %match_adv = arith.select %device_match, %c1, %c0 : index +// %next_match_ordinal = arith.addi %match_ordinal, %match_adv : index +// scf.yield %next_i, %next_match_ordinal, %try_device +// : index, index !hal.device +// } +// ``` +// Upon completion %result#1 contains the device (or null). +// If the target had an ordinal specified we skip matches until a match with the +// specified ordinal is reached. +Value IREE::HAL::DeviceTargetAttr::buildDeviceEnumeration( + Location loc, IREE::HAL::BuildDeviceTargetMatchFn buildDeviceTargetMatch, + OpBuilder &builder) const { + // Device configuration can control selection beyond just the match + // expression. + auto configAttr = getConfiguration(); + IntegerAttr deviceOrdinalAttr = + configAttr ? configAttr.getAs("ordinal") : IntegerAttr{}; + + // Defers to the target backend to build the device match or does a simple + // fallback for unregistered backends (usually for testing, but may be used + // as a way to bypass validation for out-of-tree experiments). + auto buildDeviceMatch = [&](Location loc, Value device, + OpBuilder &builder) -> Value { + // Ask the target backend to build the match expression. It may opt to + // let the default handling take care of things. + Value match = buildDeviceTargetMatch(loc, device, *this, builder); + if (match) + return match; + return IREE::HAL::DeviceTargetAttr::buildDeviceIDAndExecutableFormatsMatch( + loc, device, getDeviceID(), getExecutableTargets(), builder); + }; + + // Enumerate all devices and match the first one found (if any). + Type indexType = builder.getIndexType(); + Type deviceType = builder.getType(); + Value c0 = builder.create(loc, 0); + Value c1 = builder.create(loc, 1); + Value nullDevice = builder.create(loc, deviceType); + Value deviceOrdinal = deviceOrdinalAttr + ? builder.create( + loc, deviceOrdinalAttr.getInt()) + : c0; + Value deviceCount = builder.create(loc, indexType); + auto whileOp = builder.create( + loc, + TypeRange{ + /*i=*/indexType, + /*match_ordinal=*/indexType, + /*device=*/deviceType, + }, + ValueRange{ + /*i=*/c0, + /*match_ordinal=*/c0, + /*device=*/nullDevice, + }, + [&](OpBuilder &beforeBuilder, Location loc, ValueRange operands) { + Value isNull = beforeBuilder.create( + loc, operands[/*device=*/2], nullDevice); + Value inBounds = beforeBuilder.create( + loc, arith::CmpIPredicate::slt, operands[/*i=*/0], deviceCount); + Value continueWhile = + beforeBuilder.create(loc, isNull, inBounds); + beforeBuilder.create(loc, continueWhile, operands); + }, + [&](OpBuilder &afterBuilder, Location loc, ValueRange operands) { + // Check whether the device is a match. + Value device = afterBuilder.create( + loc, deviceType, operands[/*i=*/0]); + Value isDeviceMatch = buildDeviceMatch(loc, device, afterBuilder); + + // Check whether whether this matching device ordinal is the requested + // ordinal out of all matching devices. + Value isOrdinalMatch = afterBuilder.create( + loc, arith::CmpIPredicate::eq, operands[/*match_ordinal=*/1], + deviceOrdinal); + Value nextMatchOrdinal = afterBuilder.create( + loc, operands[/*match_ordinal=*/1], + afterBuilder.create(loc, isDeviceMatch, c1, c0)); + + // Break if the device and ordinal match, otherwise continue with null. + Value isMatch = afterBuilder.create(loc, isDeviceMatch, + isOrdinalMatch); + Value tryDevice = afterBuilder.create( + loc, isMatch, device, nullDevice); + + Value nextI = + afterBuilder.create(loc, operands[/*i=*/0], c1); + afterBuilder.create( + loc, ValueRange{ + /*i=*/nextI, + /*match_ordinal=*/nextMatchOrdinal, + /*device=*/tryDevice, + }); + }); + return whileOp.getResult(/*device=*/2); +} + +// static +Value DeviceTargetAttr::buildDeviceIDAndExecutableFormatsMatch( + Location loc, Value device, StringRef deviceIDPattern, + ArrayRef executableTargetAttrs, + OpBuilder &builder) { + // Match first on the device ID, as that's the top-level filter. + Value idMatch = IREE::HAL::DeviceQueryOp::createI1( + loc, device, "hal.device.id", deviceIDPattern, builder); + + // If there are executable formats defined we should check at least one of + // them is supported. + if (executableTargetAttrs.empty()) { + return idMatch; // just device ID + } else { + auto ifOp = builder.create(loc, builder.getI1Type(), idMatch, + true, true); + auto thenBuilder = ifOp.getThenBodyBuilder(); + Value anyFormatMatch = buildExecutableFormatMatch( + loc, device, executableTargetAttrs, thenBuilder); + thenBuilder.create(loc, anyFormatMatch); + auto elseBuilder = ifOp.getElseBodyBuilder(); + Value falseValue = elseBuilder.create(loc, 0, 1); + elseBuilder.create(loc, falseValue); + return ifOp.getResult(0); + } +} + +// static +Value DeviceTargetAttr::buildExecutableFormatMatch( + Location loc, Value device, + ArrayRef executableTargetAttrs, + OpBuilder &builder) { + if (executableTargetAttrs.empty()) + return builder.create(loc, 1, 1); + Value anyFormatMatch; + for (auto executableTargetAttr : executableTargetAttrs) { + Value formatMatch = IREE::HAL::DeviceQueryOp::createI1( + loc, device, "hal.executable.format", + executableTargetAttr.getFormat().getValue(), builder); + if (!anyFormatMatch) { + anyFormatMatch = formatMatch; + } else { + anyFormatMatch = + builder.create(loc, anyFormatMatch, formatMatch); + } + } + return anyFormatMatch; +} + +//===----------------------------------------------------------------------===// +// #hal.device.ordinal<*> +//===----------------------------------------------------------------------===// + +void IREE::HAL::DeviceOrdinalAttr::printStatusDescription( + llvm::raw_ostream &os) const { + mlir::cast(this)->print(os, /*elideType=*/true); +} + +Value IREE::HAL::DeviceOrdinalAttr::buildDeviceEnumeration( + Location loc, IREE::HAL::BuildDeviceTargetMatchFn buildDeviceTargetMatch, + OpBuilder &builder) const { + return builder.create( + loc, getType(), + builder.create(loc, getOrdinal())); +} + +//===----------------------------------------------------------------------===// +// #hal.device.fallback<*> +//===----------------------------------------------------------------------===// + +void IREE::HAL::DeviceFallbackAttr::printStatusDescription( + llvm::raw_ostream &os) const { + mlir::cast(this)->print(os, /*elideType=*/true); +} + +Value IREE::HAL::DeviceFallbackAttr::buildDeviceEnumeration( + Location loc, IREE::HAL::BuildDeviceTargetMatchFn buildDeviceTargetMatch, + OpBuilder &builder) const { + // TODO(benvanik): hal.device.cast if needed - may need to look up the global + // to do it as we don't encode what the device is here in a way that is + // guaranteed to be consistent. + return builder.create(loc, getType(), + getName().getValue()); +} + +//===----------------------------------------------------------------------===// +// #hal.device.select<*> +//===----------------------------------------------------------------------===// + +// static +LogicalResult +DeviceSelectAttr::verify(function_ref emitError, + Type type, ArrayAttr devicesAttr) { + if (devicesAttr.empty()) + return emitError() << "must have at least one device to select"; + for (auto deviceAttr : devicesAttr) { + if (!deviceAttr.isa()) { + return emitError() << "can only select between #hal.device.target, " + "#hal.device.ordinal, #hal.device.fallback, or " + "other device initialization attributes"; + } + } + // TODO(benvanik): when !hal.device is parameterized we should check that the + // type is compatible with the entries. + return success(); +} + +void IREE::HAL::DeviceSelectAttr::printStatusDescription( + llvm::raw_ostream &os) const { + // TODO(benvanik): print something easier to read (newline per device, etc). + mlir::cast(this)->print(os, /*elideType=*/true); +} + +// Builds a recursive nest of try-else blocks for each device specified. +Value IREE::HAL::DeviceSelectAttr::buildDeviceEnumeration( + Location loc, IREE::HAL::BuildDeviceTargetMatchFn buildDeviceTargetMatch, + OpBuilder &builder) const { + Type deviceType = builder.getType(); + Value nullDevice = builder.create(loc, deviceType); + std::function, + OpBuilder &)> + buildTry; + buildTry = + [&](ArrayRef deviceAttrs, + OpBuilder &tryBuilder) -> Value { + auto deviceAttr = deviceAttrs.front(); + Value tryDevice = deviceAttr.buildDeviceEnumeration( + loc, buildDeviceTargetMatch, tryBuilder); + if (deviceAttrs.size() == 1) + return tryDevice; // termination case + Value isNull = + tryBuilder.create(loc, tryDevice, nullDevice); + auto ifOp = + tryBuilder.create(loc, deviceType, isNull, true, true); + auto thenBuilder = ifOp.getThenBodyBuilder(); + Value tryChainDevice = buildTry(deviceAttrs.drop_front(1), thenBuilder); + thenBuilder.create(loc, tryChainDevice); + auto elseBuilder = ifOp.getElseBodyBuilder(); + elseBuilder.create(loc, tryDevice); + return ifOp.getResult(0); + }; + SmallVector deviceAttrs( + getDevices().getAsRange()); + return buildTry(deviceAttrs, builder); +} + //===----------------------------------------------------------------------===// // #hal.affinity.queue<*> //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index 9d85020b2302..69a1f15e633b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td @@ -480,8 +480,9 @@ def HAL_InterfaceBindingArrayAttr : // #hal.device.target<*> //===----------------------------------------------------------------------===// -def HAL_DeviceTargetAttr : - AttrDef { +def HAL_DeviceTargetAttr : AttrDef, +]> { let mnemonic = "device.target"; let summary = [{generic device target specification}]; let description = [{ @@ -491,6 +492,9 @@ def HAL_DeviceTargetAttr : several target executable formats specified with `#hal.executable.target`. An optional configuration dictionary allows for overriding backend defaults. + If used to initialize a device global returns the first device matching or + null if no devices match. + Example: ```mlir #hal.device.target<"llvm-cpu", { @@ -511,6 +515,8 @@ def HAL_DeviceTargetAttr : ]; let extraClassDeclaration = [{ + Type getType() { return IREE::HAL::DeviceType::get(getContext()); } + // Returns a symbol-compatible name that pseudo-uniquely identifies this // target. Callers must perform deduplication when required. std::string getSymbolNameFragment(); @@ -562,6 +568,7 @@ def HAL_DeviceTargetAttr : static SmallVector lookupExecutableTargets(Operation *op); }]; + let hasCustomAssemblyFormat = 1; } @@ -742,6 +749,239 @@ def HAL_ExecutableObjectsAttr : AttrDef { }]; } +//===----------------------------------------------------------------------===// +// #hal.device.alias<*> +//===----------------------------------------------------------------------===// + +def HAL_DeviceAliasAttr : AttrDef { + let mnemonic = "device.alias"; + let summary = [{device target named alias}]; + let description = [{ + Specifies a device target by named alias whose configuration will be + expanded based on compiler configuration and flags. Any configuration + provided will override any defaults provided by the configuration. + + Example: + ```mlir + // Default `vulkan` device: + #hal.device.alias<"vulkan"> : !hal.device + // Default `vulkan` device with configuration overrides: + #hal.device.alias<"vulkan", { + device_config = 123 : index + }> : !hal.device + // The 3rd default `vulkan` device detected at runtime (ordinal = 3): + #hal.device.alias<"vulkan"[3]> : !hal.device + ``` + }]; + + let parameters = (ins + AttributeSelfTypeParameter<"">:$type, + AttrParameter<"StringAttr", "">:$deviceID, + OptionalParameter<"std::optional", "">:$ordinal, + OptionalParameter<"DictionaryAttr", "">:$configuration + ); + + let builders = [ + AttrBuilder<(ins "StringRef":$deviceID)>, + ]; + + let assemblyFormat = [{ + `<` + $deviceID + `` (`[` $ordinal^ `]`)? + (`,` $configuration^)? + `>` + }]; + + let extraClassDeclaration = [{ + Type getType() { return IREE::HAL::DeviceType::get(getContext()); } + }]; +} + +//===----------------------------------------------------------------------===// +// #hal.device.target<*> +//===----------------------------------------------------------------------===// + +def HAL_DeviceTargetAttr : AttrDef, +]> { + let mnemonic = "device.target"; + let summary = [{generic device target specification}]; + let description = [{ + Specifies the properties of a target runtime device. + Target devices are specified with a canonical identifier matching those used + by the runtime (such as `cpu`, `vulkan`, etc). Target devices may support + several target executable formats specified with `#hal.executable.target`. + An optional configuration dictionary allows for overriding backend defaults. + + If used to initialize a device global returns the first device matching the + target requirements or null if no devices match. An optional `ordinal` + index may be provided that selects the N-th matching device and is used to + select between multiple homogeneous devices. + + Example: + ```mlir + #hal.device.target<"llvm-cpu", { + device_configuration = ... + }, [ + #hal.executable.target<"llvm-cpu", "embedded-elf-arm_32">, + #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64">, + ]> : !hal.device + ``` + }]; + let parameters = (ins + AttrParameter<"StringAttr", "">:$deviceID, + AttrParameter<"DictionaryAttr", "">:$configuration, + ArrayRefParameter<"ExecutableTargetAttr", "">:$executable_targets + ); + let builders = [ + AttrBuilder<(ins "StringRef":$deviceID)>, + ]; + + let extraClassDeclaration = [{ + Type getType() { return IREE::HAL::DeviceType::get(getContext()); } + + // Returns a symbol-compatible name that pseudo-uniquely identifies this + // target. Callers must perform deduplication when required. + std::string getSymbolNameFragment(); + + // Returns true if there's an attribute with the given name in the + // configuration dictionary. + bool hasConfigurationAttr(StringRef name); + + // Returns zero or more executable targets that this device supports. + void getExecutableTargets( + SetVector &resultAttrs); + + // Builds an expression that returns an i1 indicating whether the given + // |device| matches the device ID string pattern and executable target + // requirements. + static Value buildDeviceIDAndExecutableFormatsMatch( + Location loc, Value device, StringRef deviceIDPattern, + ArrayRef executableTargetAttrs, + OpBuilder &builder); + + // Builds a match expression that returns an i1 indicating whether the given + // |device| supports any one of the |executableTargetAttrs|. + static Value buildExecutableFormatMatch( + Location loc, Value device, + ArrayRef executableTargetAttrs, + OpBuilder &builder); + }]; + + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// #hal.device.ordinal<*> +//===----------------------------------------------------------------------===// + +def HAL_DeviceOrdinalAttr : AttrDef, +]> { + let mnemonic = "device.ordinal"; + let summary = [{specifies a device by runtime registration ordinal}]; + let description = [{ + Represents the device registered with the runtime in the order it was + registered with ordinal 0 being the first registered. Returns null during + initialization if the device ordinal is out of range. + }]; + + let parameters = (ins + AttributeSelfTypeParameter<"">:$type, + AttrParameter<"int64_t", "">:$ordinal + ); + + let assemblyFormat = [{ + `<` $ordinal `>` + }]; +} + +//===----------------------------------------------------------------------===// +// #hal.device.fallback<*> +//===----------------------------------------------------------------------===// + +def HAL_DeviceFallbackAttr : AttrDef, +]> { + let mnemonic = "device.fallback"; + let summary = [{specifies a reference to another device}]; + let description = [{ + Specifies by symbol a device that has already been initialized. + Returns null during initialization if the device specified as a fallback is + null. + }]; + + let parameters = (ins + AttributeSelfTypeParameter<"">:$type, + AttrParameter<"FlatSymbolRefAttr", "">:$name + ); + + let assemblyFormat = [{ + `<` $name `>` + }]; +} + +//===----------------------------------------------------------------------===// +// #hal.device.select<*> +//===----------------------------------------------------------------------===// + +def HAL_DeviceSelectAttr : AttrDef, +]> { + let mnemonic = "device.select"; + let summary = [{selects a device from one or more options}]; + let description = [{ + Selects a HAL device at runtime by either enumerating and querying for + target support or matching the given existing device by affinity. + Devices are selected in the order listed. Fails during initialization if no + device can be selected. + + Examples: + ```mlir + // Selects a single device matching the given target. + #hal.device.select<[ + #hal.device.target<"..."> : !hal.device + ]> : !hal.device + // Selects a specific device with the given symbol. + #hal.device.select<[ + #hal.device.fallback<@device_0> : !hal.device + ]> : !hal.device + // Selects a specific device by ordinal as registered at runtime. + #hal.device.select<[ + #hal.device.ordinal<0> : !hal.device + ]> : !hal.device + // Selects an optional device if available and otherwise @fallback. + #hal.device.select<[ + #hal.device.target<"some_optional_device"> : !hal.device, + #hal.device.fallback<@fallback> : !hal.device + ]> : !hal.device + ``` + }]; + + let parameters = (ins + AttributeSelfTypeParameter<"">:$type, + AttrParameter<"ArrayAttr", "">:$devices + ); + + let builders = [ + AttrBuilder<(ins + "IREE::HAL::DeviceTargetAttr":$device + )>, + AttrBuilder<(ins + "ArrayRef":$values + )>, + ]; + + let assemblyFormat = [{ + `<` $devices `>` + }]; + + let genVerifyDecl = 1; +} + //===----------------------------------------------------------------------===// // #hal.affinity.queue<*> //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td index b1df6fffa402..fb38415f12d9 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALInterfaces.td @@ -8,24 +8,53 @@ #define IREE_DIALECT_HAL_INTERFACES include "iree/compiler/Dialect/Util/IR/UtilBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" -def HAL_MatchAttrInterface : - AttrInterface<"MatchAttrInterface"> { +//===----------------------------------------------------------------------===// +// IREE::HAL::DeviceInitializationAttrInterface +//===----------------------------------------------------------------------===// + +def HAL_DeviceInitializationAttrInterface : + AttrInterface<"DeviceInitializationAttrInterface", [ + TypedAttrInterface, + ]> { let description = [{ - An attribute that can be used in `hal.*.match.*` expressions. - Each attribute defines some subexpression that can be expanded to one or - more operations that performs the actual query and matching logic. + Interface for attributes controlling device initialization. }]; let methods = [ InterfaceMethod< - [{ - Builds a set of operations that evaluate to a boolean (i1) value - indicating whether the expression tree represented by the match - attribute is true for the given value. + /*desc=*/[{ + prints a string description of the initialization specification for + inclusion in error messages. May include internal newlines but no + newline is expected at the end. + }], + /*retTy=*/"void", + /*methodName=*/"printStatusDescription", + /*args=*/(ins "llvm::raw_ostream &":$os), + /*methodBody=*/[{}] + >, + InterfaceMethod< + /*desc=*/[{ + Builds a `util.initializer` body responsible for initializing a device + global. Returns the device value that should be stored into the global. + The name provided is an informal identifier that can be used to produce + user-level error messages that reference the device. + + The provided `buildDeviceTargetMatch` function will be called with a + `!hal.device` SSA value and a device target specification and should + return an `i1` value indicating whether the given device matches the + specification. If the device always matches (rare!) a null value may + be returned. }], - "Value", "buildConditionExpression", - (ins "Location":$loc, "Value":$device, "OpBuilder":$builder) + /*retTy=*/"Value", + /*methodName=*/"buildDeviceEnumeration", + /*args=*/(ins + "Location":$loc, + "IREE::HAL::BuildDeviceTargetMatchFn":$buildDeviceTargetMatch, + "OpBuilder &":$builder + ), + /*methodBody=*/[{}] >, ]; } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index b8c0bab8c8d0..07ff982c1fbe 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -1480,17 +1480,9 @@ Value ExecutableVariantOp::createConditionOp(OpBuilder &builder) { Value ExecutableVariantOp::buildCondition(Value device, OpBuilder &builder) { // Base case dependent on target information. - // TODO(multi-device): condition on device target ID and other queries that - // may be useful for disambiguating two devices that support the same - // executable targets. Today executable targets are unique per device target - // but that need not always be the case. - auto i1Type = builder.getI1Type(); - Value selected = builder - .create( - getLoc(), i1Type, i1Type, device, - builder.getStringAttr("hal.executable.format"), - getTarget().getFormat(), builder.getZeroAttr(i1Type)) - .getValue(); + Value selected = IREE::HAL::DeviceQueryOp::createI1( + getLoc(), device, "hal.executable.format", + getTarget().getFormat().getValue(), builder); // Factor in variant condition region, if any. auto conditionOp = getConditionOp(); diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h index ef6702416182..c7b7344fd47f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h @@ -33,6 +33,13 @@ namespace mlir::iree_compiler::IREE::HAL { +class DeviceTargetAttr; +class TargetRegistry; + +using BuildDeviceTargetMatchFn = std::function; + #include "iree/compiler/Dialect/HAL/IR/HALAttrInterfaces.h.inc" // IWYU pragma: export #include "iree/compiler/Dialect/HAL/IR/HALOpInterfaces.h.inc" // IWYU pragma: export #include "iree/compiler/Dialect/HAL/IR/HALTypeInterfaces.h.inc" // IWYU pragma: export diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir index 95ea6739df38..47f8468e7770 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir @@ -62,18 +62,44 @@ // CHECK-LABEL: "device.targets" "device.targets"() { - // CHECK-SAME: target_0 = #hal.device.target<"a"> - target_0 = #hal.device.target<"a">, - // CHECK-SAME: target_1 = #hal.device.target<"b", {config}>, - target_1 = #hal.device.target<"b", {config}>, - // CHECK-SAME: target_2 = #hal.device.target<"c", {config}, [#hal.executable.target<"llvm-cpu", "f">]>, - target_2 = #hal.device.target<"c", {config}, [#hal.executable.target<"llvm-cpu", "f">]>, - // CHECK-SAME: target_3 = #hal.device.target<"d", [#hal.executable.target<"llvm-cpu", "f">]> - target_3 = #hal.device.target<"d", [#hal.executable.target<"llvm-cpu", "f">]> + // CHECK-SAME: target_0 = #hal.device.target<"a"> : !hal.device + target_0 = #hal.device.target<"a"> : !hal.device, + // CHECK-SAME: target_1 = #hal.device.target<"b", {config}> : !hal.device, + target_1 = #hal.device.target<"b", {config}> : !hal.device, + // CHECK-SAME: target_2 = #hal.device.target<"c", {config}, [#hal.executable.target<"llvm-cpu", "f">]> : !hal.device, + target_2 = #hal.device.target<"c", {config}, [#hal.executable.target<"llvm-cpu", "f">]> : !hal.device, + // CHECK-SAME: target_3 = #hal.device.target<"d", [#hal.executable.target<"llvm-cpu", "f">]> : !hal.device + target_3 = #hal.device.target<"d", [#hal.executable.target<"llvm-cpu", "f">]> : !hal.device } : () -> () // ----- +// CHECK: util.global private @device_a = #hal.device.target<"a"> : !hal.device +util.global private @device_a = #hal.device.target<"a"> : !hal.device +// CHECK: util.global private @device_0 = #hal.device.ordinal<0> : !hal.device +util.global private @device_0 = #hal.device.ordinal<0> : !hal.device + +// ----- + +// CHECK: util.global private @main = #hal.device.select<[ +// CHECK-SAME: #hal.device.target<"a"> : !hal.device +// CHECK-SAME: ]> : !hal.device +util.global private @main = #hal.device.select<[ + #hal.device.target<"a"> : !hal.device +]> : !hal.device +// CHECK: util.global private @optional = #hal.device.select<[ +// CHECK-SAME: #hal.device.target<"b"> : !hal.device, +// CHECK-SAME: #hal.device.ordinal<1> : !hal.device, +// CHECK-SAME: #hal.device.fallback<@main> : !hal.device +// CHECK-SAME: ]> : !hal.device +util.global private @optional = #hal.device.select<[ + #hal.device.target<"b"> : !hal.device, + #hal.device.ordinal<1> : !hal.device, + #hal.device.fallback<@main> : !hal.device +]> : !hal.device + +// ----- + "affinity.queue"() { // CHECK: any = #hal.affinity.queue<*> any = #hal.affinity.queue<*>, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.cpp index 302ad62d6bf0..e482a4992355 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.cpp @@ -87,7 +87,7 @@ LocalDevice::getHostDeviceTarget(MLIRContext *context, Value LocalDevice::buildDeviceTargetMatch( Location loc, Value device, IREE::HAL::DeviceTargetAttr targetAttr, OpBuilder &builder) const { - return buildDeviceIDAndExecutableFormatsMatch( + return IREE::HAL::DeviceTargetAttr::buildDeviceIDAndExecutableFormatsMatch( loc, device, "local*", targetAttr.getExecutableTargets(), builder); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp index 1695c5061ce1..2e51311c8f88 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp @@ -7,8 +7,6 @@ #include "iree/compiler/Dialect/HAL/Target/TargetDevice.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/SCF/IR/SCF.h" namespace mlir::iree_compiler::IREE::HAL { @@ -16,56 +14,9 @@ namespace mlir::iree_compiler::IREE::HAL { Value TargetDevice::buildDeviceTargetMatch( Location loc, Value device, IREE::HAL::DeviceTargetAttr targetAttr, OpBuilder &builder) const { - return buildDeviceIDAndExecutableFormatsMatch( + return IREE::HAL::DeviceTargetAttr::buildDeviceIDAndExecutableFormatsMatch( loc, device, targetAttr.getDeviceID(), targetAttr.getExecutableTargets(), builder); } -Value buildDeviceIDAndExecutableFormatsMatch( - Location loc, Value device, StringRef deviceIDPattern, - ArrayRef executableTargetAttrs, - OpBuilder &builder) { - // Match first on the device ID, as that's the top-level filter. - Value idMatch = IREE::HAL::DeviceQueryOp::createI1( - loc, device, "hal.device.id", deviceIDPattern, builder); - - // If there are executable formats defined we should check at least one of - // them is supported. - if (executableTargetAttrs.empty()) { - return idMatch; // just device ID - } else { - auto ifOp = builder.create(loc, builder.getI1Type(), idMatch, - true, true); - auto thenBuilder = ifOp.getThenBodyBuilder(); - Value anyFormatMatch = buildExecutableFormatMatch( - loc, device, executableTargetAttrs, thenBuilder); - thenBuilder.create(loc, anyFormatMatch); - auto elseBuilder = ifOp.getElseBodyBuilder(); - Value falseValue = elseBuilder.create(loc, 0, 1); - elseBuilder.create(loc, falseValue); - return ifOp.getResult(0); - } -} - -Value buildExecutableFormatMatch( - Location loc, Value device, - ArrayRef executableTargetAttrs, - OpBuilder &builder) { - if (executableTargetAttrs.empty()) - return builder.create(loc, 1, 1); - Value anyFormatMatch; - for (auto executableTargetAttr : executableTargetAttrs) { - Value formatMatch = IREE::HAL::DeviceQueryOp::createI1( - loc, device, "hal.executable.format", - executableTargetAttr.getFormat().getValue(), builder); - if (!anyFormatMatch) { - anyFormatMatch = formatMatch; - } else { - anyFormatMatch = - builder.create(loc, anyFormatMatch, formatMatch); - } - } - return anyFormatMatch; -} - } // namespace mlir::iree_compiler::IREE::HAL diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h index 2ce53f483abc..f9cb567cd844 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h @@ -48,21 +48,6 @@ class TargetDevice { // various stages. }; -// Builds an expression that returns an i1 indicating whether the given -// |device| matches the device ID string pattern and executable target -// requirements. -Value buildDeviceIDAndExecutableFormatsMatch( - Location loc, Value device, StringRef deviceIDPattern, - ArrayRef executableTargetAttrs, - OpBuilder &builder); - -// Builds a match expression that returns an i1 indicating whether the given -// |device| supports any one of the |executableTargetAttrs|. -Value buildExecutableFormatMatch( - Location loc, Value device, - ArrayRef executableTargetAttrs, - OpBuilder &builder); - } // namespace mlir::iree_compiler::IREE::HAL #endif // IREE_COMPILER_DIALECT_HAL_TARGET_TARGETDEVICE_H_ diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel index 33e9561164fe..63be1d24f00d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel @@ -23,6 +23,7 @@ iree_compiler_cc_library( "DumpExecutableSources.cpp", "ElideRedundantCommands.cpp", "FixupLegacySync.cpp", + "InitializeDevices.cpp", "LinkExecutables.cpp", "MaterializeDispatchInstrumentation.cpp", "MaterializeInterfaces.cpp", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt index 6cce442a87ad..382525e66958 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt @@ -24,6 +24,7 @@ iree_cc_library( "DumpExecutableSources.cpp" "ElideRedundantCommands.cpp" "FixupLegacySync.cpp" + "InitializeDevices.cpp" "LinkExecutables.cpp" "MaterializeDispatchInstrumentation.cpp" "MaterializeInterfaces.cpp" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/InitializeDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/InitializeDevices.cpp new file mode 100644 index 000000000000..fae26c54e0b5 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/InitializeDevices.cpp @@ -0,0 +1,111 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::HAL { + +#define GEN_PASS_DEF_INITIALIZEDEVICESPASS +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc" + +namespace { + +// Converts an initialized device global to one with a util.initializer that +// performs the device initialization. The initializer is added immediately +// following the global in its parent op. +static void initializeDeviceGlobal( + IREE::Util::GlobalOpInterface globalOp, + IREE::HAL::DeviceInitializationAttrInterface initialValue, + const IREE::HAL::TargetRegistry &targetRegistry) { + auto loc = globalOp.getLoc(); + + // Clear the initial value as we'll be initializing from the initializer. + globalOp.setGlobalInitialValue({}); + + // Build a new util.initializer. + OpBuilder moduleBuilder(globalOp); + moduleBuilder.setInsertionPointAfter(globalOp); + auto initializerOp = moduleBuilder.create(loc); + auto *block = moduleBuilder.createBlock(&initializerOp.getBody()); + auto initializerBuilder = OpBuilder::atBlockBegin(block); + + // Get the device from the attribute builder; note that it may be null. + Value enumeratedDevice = initialValue.buildDeviceEnumeration( + loc, + [&](Location loc, Value device, IREE::HAL::DeviceTargetAttr targetAttr, + OpBuilder &builder) { + auto targetDevice = + targetRegistry.getTargetDevice(targetAttr.getDeviceID()); + return targetDevice ? targetDevice->buildDeviceTargetMatch( + loc, device, targetAttr, builder) + : Value{}; + }, + initializerBuilder); + + // Check if the device is null and error out. We could support optional + // devices that are allowed to be null but don't support that anywhere else in + // the compiler today and may never want to. If selecting from multiple + // devices queries can be used to detect what the selected device was and + // those will be memoized. + Value nullDevice = initializerBuilder.create( + loc, enumeratedDevice.getType()); + Value isNull = initializerBuilder.create( + loc, enumeratedDevice, nullDevice); + initializerBuilder.create( + loc, isNull, [&](OpBuilder &thenBuilder, Location thenLoc) { + Value status = thenBuilder.create( + thenLoc, static_cast(IREE::Util::StatusCode::NotFound), + 32); + std::string str; + { + llvm::raw_string_ostream os(str); + os << "HAL device `" << globalOp.getGlobalName().getValue() + << "` not found or unavailable: "; + initialValue.printStatusDescription(os); + } + thenBuilder.create(thenLoc, status, str); + thenBuilder.create(thenLoc); + }); + + // Store the device back to the global to complete initialization. + globalOp.createStoreOp(loc, enumeratedDevice, initializerBuilder); + initializerBuilder.create(loc); +} + +//===----------------------------------------------------------------------===// +// --iree-hal-initialize-devices +//===----------------------------------------------------------------------===// + +struct InitializeDevicesPass + : public IREE::HAL::impl::InitializeDevicesPassBase { + using IREE::HAL::impl::InitializeDevicesPassBase< + InitializeDevicesPass>::InitializeDevicesPassBase; + void runOnOperation() override { + auto moduleOp = getOperation(); + for (auto globalOp : moduleOp.getOps()) { + auto initialValue = + dyn_cast_if_present( + globalOp.getGlobalInitialValue()); + if (initialValue) { + initializeDeviceGlobal(globalOp, initialValue, *targetRegistry.value); + } + } + } +}; + +} // namespace + +} // namespace mlir::iree_compiler::IREE::HAL diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index 5fe1fd61c745..12f2c7bf3252 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -278,6 +278,10 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, buildHALConfigurationPassPipeline(passManager, targetRegistry, targetOptions, hooks); + // HACK: this should not be here and will be going away. It exists for + // lowering iree_linalg_ext.upper_bound_tile_size ops that exist on the + // host. We should be using stream ops for performing such calculations that + // we can attach affinities to and understand what devices are being used. FunctionLikeNest(passManager).addPass([]() { return createCPUMaterializeUpperBoundTileSizePass(); }); @@ -456,6 +460,14 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, FunctionLikeNest(passManager) .addPass(IREE::HAL::createElideRedundantCommandsPass); + // Initialize device globals now that we've done the analysis that is easier + // with them in their original target specification. + passManager.addPass(IREE::HAL::createInitializeDevicesPass({targetRegistry})); + + // Combine the initializers we emitted during resource cache + // materialization. + passManager.addPass(IREE::Util::createCombineInitializersPass()); + // TODO: Maybe this should be a part of Affine lowering pass. // Remove if it is added there. // https://github.com/llvm/llvm-project/issues/78458 @@ -468,10 +480,6 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, // SimplifyGlobalAccesses are currently broken with scf present. FunctionLikeNest(passManager).addPass(mlir::createConvertSCFToCFPass); - // Combine the initializers we emitted during resource cache - // materialization. - passManager.addPass(IREE::Util::createCombineInitializersPass()); - //---------------------------------------------------------------------------- // Executable serialization //---------------------------------------------------------------------------- diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td index 188340b53659..848acc78a477 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td @@ -55,7 +55,7 @@ def VerifyTargetEnvironmentPass : Option< "targetRegistry", "target-registry", "llvm::cl::TargetRegistryRef", "", - "Target backend registry containing the list of available backends." + "Target registry containing the list of available devices and backends." >, ]; } @@ -70,7 +70,7 @@ def AssignTargetDevicesPass : Option< "targetRegistry", "target-registry", "llvm::cl::TargetRegistryRef", "", - "Target backend registry containing the list of available backends." + "Target registry containing the list of available devices and backends." >, ListOption< "targetBackends", "targetBackends", @@ -213,7 +213,7 @@ def ConfigureExecutablesPass : Option< "targetRegistry", "target-registry", "llvm::cl::TargetRegistryRef", "", - "Target backend registry containing the list of available backends." + "Target registry containing the list of available devices and backends." >, ]; } @@ -229,7 +229,7 @@ def ConfigureTargetExecutableVariantsPass : Option< "targetRegistry", "target-registry", "llvm::cl::TargetRegistryRef", "", - "Target backend registry containing the list of available backends." + "Target registry containing the list of available devices and backends." >, Option< "target", "target", @@ -251,7 +251,7 @@ def TranslateExecutablesPass : Option< "targetRegistry", "target-registry", "llvm::cl::TargetRegistryRef", "", - "Target backend registry containing the list of available backends." + "Target registry containing the list of available devices and backends." >, ]; } @@ -268,7 +268,7 @@ def TranslateTargetExecutableVariantsPass : Option< "targetRegistry", "target-registry", "llvm::cl::TargetRegistryRef", "", - "Target backend registry containing the list of available backends." + "Target registry containing the list of available devices and backends." >, Option< "target", "target", @@ -300,7 +300,7 @@ def LinkExecutablesPass : Option< "targetRegistry", "target-registry", "llvm::cl::TargetRegistryRef", "", - "Target backend registry containing the list of available backends." + "Target registry containing the list of available devices and backends." >, ]; } @@ -318,7 +318,7 @@ def LinkTargetExecutablesPass : Option< "targetRegistry", "target-registry", "llvm::cl::TargetRegistryRef", "", - "Target backend registry containing the list of available backends." + "Target registry containing the list of available devices and backends." >, Option< "target", "target", @@ -354,7 +354,7 @@ def SerializeExecutablesPass : Option< "targetRegistry", "target-registry", "llvm::cl::TargetRegistryRef", "", - "Target backend registry containing the list of available backends." + "Target registry containing the list of available devices and backends." >, Option< "debugLevel", "debug-level", @@ -386,7 +386,7 @@ def SerializeTargetExecutablesPass : Option< "targetRegistry", "target-registry", "llvm::cl::TargetRegistryRef", "", - "Target backend registry containing the list of available backends." + "Target registry containing the list of available devices and backends." >, Option< "target", "target", @@ -439,6 +439,28 @@ def MaterializeDispatchInstrumentationPass : ]; } +def InitializeDevicesPass : + Pass<"iree-hal-initialize-devices", "mlir::ModuleOp"> { + let summary = "Initializes global device handles based on their specification."; + let description = [{ + Initializes each global `!hal.device` based on the specification attribute + by building initializers that enumerate and select the appropriate device. + }]; + let options = [ + Option< + "targetRegistry", "target-registry", + "llvm::cl::TargetRegistryRef", "", + "Target registry containing the list of available devices and backends." + >, + ]; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::scf::SCFDialect", + "IREE::HAL::HALDialect", + "IREE::Util::UtilDialect", + ]; +} + def MaterializeResourceCachesPass : Pass<"iree-hal-materialize-resource-caches", "mlir::ModuleOp"> { let summary = "Materializes cached globals for device resources."; diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel index db9a2d4987e9..1bd2bb03c899 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel @@ -23,6 +23,7 @@ iree_lit_test_suite( "dump_executable_sources.mlir", "elide_redundant_commands.mlir", "fixup_legacy_sync.mlir", + "initialize_devices.mlir", "materialize_dispatch_instrumentation.mlir", "materialize_interfaces.mlir", "materialize_resource_caches.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt index ae4322d53213..d4ea9486ebdc 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt @@ -21,6 +21,7 @@ iree_lit_test_suite( "dump_executable_sources.mlir" "elide_redundant_commands.mlir" "fixup_legacy_sync.mlir" + "initialize_devices.mlir" "materialize_dispatch_instrumentation.mlir" "materialize_interfaces.mlir" "materialize_resource_caches.mlir" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir index d504c5e4c87c..82b6310af13a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir @@ -33,9 +33,7 @@ // CHECK: hal.executable private @ex hal.executable private @ex { hal.executable.variant public @embedded_elf_aarch64 target(#executable_target_embedded_elf_aarch64) { - hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout_0) attributes { - translation_info = #iree_codegen.translation_info - } { + hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout_0) { ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors %c1 = arith.constant 1 : index %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0] @@ -53,8 +51,7 @@ hal.executable private @ex { #hal.interface.binding<0, 4>, #hal.interface.binding<1, 5>, #hal.interface.binding<1, 6> - ], - translation_info = #iree_codegen.translation_info + ] } { ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors %c1 = arith.constant 1 : index diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_sources.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_sources.mlir index 458e841e9f2d..5de1c9c8686c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_sources.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_sources.mlir @@ -4,9 +4,6 @@ // but this is much easier to test with lit. #executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> -#device_target_cpu = #hal.device.target<"llvm-cpu", [ - #executable_target_embedded_elf_x86_64 -]> #pipeline_layout = #hal.pipeline.layout, @@ -15,46 +12,42 @@ ]> ]> -module attributes {hal.device.targets = [#device_target_cpu]} { - - // CHECK: hal.executable public @ex0 - hal.executable private @ex0 { - // We expect local outputs with attributes inlined: - // CHECK-NEXT: hal.executable.variant {{.+}} target(<"llvm-cpu" - hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) { - hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout) attributes { - translation_info = #iree_codegen.translation_info - } { - ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors - %c1 = arith.constant 1 : index - %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0] - hal.return %0, %c1, %c1 : index, index, index - } - builtin.module { - func.func @dispatch0() { - func.return - } +// CHECK: hal.executable public @ex0 +hal.executable private @ex0 { + // We expect local outputs with attributes inlined: + // CHECK-NEXT: hal.executable.variant {{.+}} target(<"llvm-cpu" + hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) { + hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout) attributes { + translation_info = #iree_codegen.translation_info + } { + ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors + %c1 = arith.constant 1 : index + %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0] + hal.return %0, %c1, %c1 : index, index, index + } + builtin.module { + func.func @dispatch0() { + func.return } } } +} - // CHECK: hal.executable private @ex1 - hal.executable private @ex1 { - hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) { - hal.executable.export public @dispatch1 ordinal(0) layout(#pipeline_layout) attributes { - translation_info = #iree_codegen.translation_info - } { - ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors - %c1 = arith.constant 1 : index - %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0] - hal.return %0, %c1, %c1 : index, index, index - } - builtin.module { - func.func @dispatch1() { - func.return - } +// CHECK: hal.executable private @ex1 +hal.executable private @ex1 { + hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) { + hal.executable.export public @dispatch1 ordinal(0) layout(#pipeline_layout) attributes { + translation_info = #iree_codegen.translation_info + } { + ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors + %c1 = arith.constant 1 : index + %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0] + hal.return %0, %c1, %c1 : index, index, index + } + builtin.module { + func.func @dispatch1() { + func.return } } } - } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/initialize_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/initialize_devices.mlir new file mode 100644 index 000000000000..d37c9bcb92e5 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/initialize_devices.mlir @@ -0,0 +1,106 @@ +// RUN: iree-opt --split-input-file --iree-hal-initialize-devices --cse %s | FileCheck %s + +// Tests that #hal.device.ordinal<*> gets the device with the given ordinal. + +// CHECK: util.global private @device_123 : !hal.device +util.global private @device_123 = #hal.device.ordinal<123> : !hal.device + +// CHECK-NEXT: util.initializer +// CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %c123 +// CHECK-DAG: %[[NULL_DEVICE:.+]] = util.null : !hal.device +// CHECK-DAG: %[[IS_NULL:.+]] = util.cmp.eq %[[DEVICE]], %[[NULL_DEVICE]] +// CHECK-NEXT: scf.if %[[IS_NULL]] { +// CHECK: util.status.check_ok %c5_i32, "HAL device `device_123` not found or unavailable: #hal.device.ordinal<123>" +// CHECK: util.global.store %[[DEVICE]], @device_123 + +// ----- + +// Tests that #hal.device.fallback<*> references the specified device global. + +util.global private @device_base : !hal.device + +// CHECK: util.global private @device_fallback : !hal.device +util.global private @device_fallback = #hal.device.fallback<@device_base> : !hal.device + +// CHECK-NEXT: util.initializer +// CHECK-DAG: %[[DEVICE:.+]] = util.global.load @device_base : !hal.device +// CHECK-DAG: %[[IS_NULL:.+]] = util.cmp.eq %[[DEVICE]], %{{.+}} +// CHECK-NEXT: scf.if %[[IS_NULL]] { +// CHECK: util.status.check_ok %c5_i32, "HAL device `device_fallback` not found or unavailable: #hal.device.fallback<@device_base>" +// CHECK: util.global.store %[[DEVICE]], @device_fallback + +// ----- + +// Tests that #hal.device.target<*> enumerates all devices. + +// CHECK: util.global private @device_a : !hal.device +util.global private @device_a = #hal.device.target<"a", [ + #hal.executable.target<"backend0", "format0">, + #hal.executable.target<"backend1", "format1"> +]> : !hal.device + +// CHECK-NEXT: util.initializer +// CHECK-DAG: %[[NULL_DEVICE:.+]] = util.null : !hal.device +// CHECK-DAG: %[[DEVICE_COUNT:.+]] = hal.devices.count +// CHECK: %[[WHILE:.+]]:2 = scf.while (%arg0 = %c0, %arg1 = %[[NULL_DEVICE]]) +// CHECK-DAG: %[[IS_DEVICE_NULL:.+]] = util.cmp.eq %arg1, %[[NULL_DEVICE]] +// CHECK-DAG: %[[IS_END:.+]] = arith.cmpi slt, %arg0, %[[DEVICE_COUNT]] +// CHECK-DAG: %[[CONTINUE:.+]] = arith.andi %[[IS_DEVICE_NULL]], %[[IS_END]] +// CHECK-NEXT: scf.condition(%[[CONTINUE]]) %arg0, %arg1 +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%arg0: index, %arg1: !hal.device) +// CHECK-DAG: %[[DEVICE_N:.+]] = hal.devices.get %arg0 : !hal.device + +// NOTE: this is the fallback path for device matching unregistered targets. +// Real targets can have much more complex logic if they so choose. +// CHECK-DAG: %{{.+}}, %[[ID_MATCH:.+]] = hal.device.query<%[[DEVICE_N]] : !hal.device> key("hal.device.id" :: "a") +// CHECK-NEXT: %[[ANY_FORMAT_MATCH:.+]] = scf.if %[[ID_MATCH]] -> (i1) { +// CHECK-DAG: %{{.+}}, %[[FORMAT0_MATCH:.+]] = hal.device.query<%[[DEVICE_N]] : !hal.device> key("hal.executable.format" :: "format0") +// CHECK-DAG: %{{.+}}, %[[FORMAT1_MATCH:.+]] = hal.device.query<%[[DEVICE_N]] : !hal.device> key("hal.executable.format" :: "format1") +// CHECK-DAG: %[[FORMAT_MATCH_OR:.+]] = arith.ori %[[FORMAT0_MATCH]], %[[FORMAT1_MATCH]] +// CHECK-DAG: scf.yield %[[FORMAT_MATCH_OR]] +// CHECK-NEXT: } else { +// CHECK-DAG: scf.yield %false + +// CHECK-DAG: %[[YIELD_DEVICE:.+]] = arith.select %[[ANY_FORMAT_MATCH]], %[[DEVICE_N]], %[[NULL_DEVICE]] +// CHECK-DAG: %[[NEXT_I:.+]] = arith.addi %arg0, %c1 +// CHECK-NEXT: scf.yield %[[NEXT_I]], %[[YIELD_DEVICE]] +// CHECK-DAG: %[[IS_NULL:.+]] = util.cmp.eq %[[WHILE]]#1, %[[NULL_DEVICE]] +// CHECK-NEXT: scf.if %[[IS_NULL]] { +// CHECK: util.status.check_ok %c5_i32, "HAL device `device_a` not found or unavailable: #hal.device.target<{{.+}}>" +// CHECK: util.global.store %[[WHILE]]#1, @device_a + +// ----- + +// Tests that #hal.device.select<*> expands to a chain of ifs. + +util.global private @fallback : !hal.device + +// CHECK: util.global private @selected : !hal.device +util.global private @selected = #hal.device.select<[ + #hal.device.ordinal<2> : !hal.device, + #hal.device.ordinal<1> : !hal.device, + #hal.device.fallback<@fallback> : !hal.device +]> : !hal.device + +// CHECK-NEXT: util.initializer +// CHECK-DAG: %[[NULL_DEVICE:.+]] = util.null : !hal.device +// CHECK-DAG: %[[DEVICE_2:.+]] = hal.devices.get %c2 +// CHECK-DAG: %[[NOT_DEVICE_2:.+]] = util.cmp.eq %[[DEVICE_2]], %[[NULL_DEVICE]] +// CHECK-NEXT: %[[IF_0:.+]] = scf.if %[[NOT_DEVICE_2]] +// CHECK-DAG: %[[DEVICE_1:.+]] = hal.devices.get %c1 +// CHECK-DAG: %[[NOT_DEVICE_1:.+]] = util.cmp.eq %[[DEVICE_1]], %[[NULL_DEVICE]] +// CHECK-NEXT: %[[IF_1:.+]] = scf.if %[[NOT_DEVICE_1]] +// CHECK-DAG: %[[DEVICE_FALLBACK:.+]] = util.global.load @fallback +// CHECK-NEXT: scf.yield %[[DEVICE_FALLBACK]] +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %[[DEVICE_1]] +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[IF_1]] +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %[[DEVICE_2]] +// CHECK-NEXT: } +// CHECK-DAG: %[[IS_NULL:.+]] = util.cmp.eq %[[IF_0]], %[[NULL_DEVICE]] +// CHECK-NEXT: scf.if %[[IS_NULL]] { +// CHECK: util.status.check_ok %c5_i32, "HAL device `selected` not found or unavailable: #hal.device.select<{{.+}}>" +// CHECK: util.global.store %[[IF_0]], @selected diff --git a/runtime/src/iree/hal/device.c b/runtime/src/iree/hal/device.c index 07bd660ad947..40f261089dad 100644 --- a/runtime/src/iree/hal/device.c +++ b/runtime/src/iree/hal/device.c @@ -65,14 +65,6 @@ IREE_API_EXPORT iree_status_t iree_hal_device_query_i64( iree_string_view_t key, int64_t* out_value) { IREE_ASSERT_ARGUMENT(device); IREE_ASSERT_ARGUMENT(out_value); - - if (iree_string_view_equal(category, - iree_make_cstring_view("hal.device.id"))) { - *out_value = - iree_string_view_match_pattern(iree_hal_device_id(device), key) ? 1 : 0; - return iree_ok_status(); - } - return _VTABLE_DISPATCH(device, query_i64)(device, category, key, out_value); } From 308fb66a0acd13c8bdc186e2acc62adeff4eca13 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 19 Feb 2024 13:42:39 -0800 Subject: [PATCH 02/25] Adding the #hal.device.affinity attr replacing #hal.affinity.queue. The queue affinity attr was added as a placeholder to test things but was never used/useful. --- .../test/convert_region_to_workgroups.mlir | 8 +- .../Transforms/test/outline_constants.mlir | 6 +- .../test/outline_dispatch_regions.mlir | 11 +- .../HAL/Conversion/StreamToHAL/Patterns.cpp | 146 ++++++++++-------- .../StreamToHAL/test/channel_ops.mlir | 6 +- .../Conversion/StreamToHAL/test/cmd_ops.mlir | 36 +++-- .../StreamToHAL/test/context_ops.mlir | 62 ++++++-- .../Conversion/StreamToHAL/test/file_ops.mlir | 18 ++- .../StreamToHAL/test/resource_ops.mlir | 24 ++- .../StreamToHAL/test/timepoint_ops.mlir | 6 +- .../StreamToHAL/test/transfer_ops.mlir | 16 +- .../iree/compiler/Dialect/HAL/IR/HALAttrs.cpp | 81 ++++++---- .../iree/compiler/Dialect/HAL/IR/HALAttrs.td | 33 ++-- .../Dialect/HAL/IR/test/attributes.mlir | 56 ++++++- .../Dialect/HAL/IR/test/executable_ops.mlir | 23 +-- .../HAL/Transforms/test/convert_to_hal.mlir | 9 +- .../test/materialize_resource_caches.mlir | 3 +- .../FlowToStream/test/dispatch_ops.mlir | 15 +- .../FlowToStream/test/tensor_ops.mlir | 14 +- .../compiler/Dialect/Stream/IR/StreamOps.td | 20 +-- .../Dialect/Stream/IR/test/async_ops.mlir | 18 ++- .../Dialect/Stream/IR/test/channel_ops.mlir | 6 +- .../Dialect/Stream/IR/test/context_ops.mlir | 10 +- .../test/fuse_dispatch_bindings.mlir | 12 +- .../test/materialize_copy_on_write.mlir | 6 +- .../Transforms/test/schedule_allocation.mlir | 12 +- .../Transforms/test/schedule_execution.mlir | 31 ++-- .../test/hoist_into_globals.mlir | 8 +- .../Modules/Check/Conversion/BUILD.bazel | 1 + .../Modules/Check/Conversion/CMakeLists.txt | 1 + .../Check/Conversion/ConversionPatterns.cpp | 4 +- 31 files changed, 451 insertions(+), 251 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir index 3caa5b0061c4..92a020823f41 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/convert_region_to_workgroups.mlir @@ -1,5 +1,7 @@ // RUN: iree-opt %s --pass-pipeline="builtin.module(util.func(iree-flow-convert-dispatch-regions-to-workgroups, iree-flow-canonicalize, cse))" -split-input-file | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: util.func public @foo( // CHECK: %[[argA:.*]]: tensor, %[[argB:.*]]: tensor<5x10xf32>, %[[argC:.*]]: tensor<10x11xf32> util.func public @foo(%argA: tensor, %argB: tensor<5x10xf32>, %argC: tensor<10x11xf32>) -> (tensor, tensor<5x11xf32>) { @@ -21,7 +23,7 @@ util.func public @foo(%argA: tensor, %argB: tensor<5x10xf32>, %argC: te flow.return %argA : tensor } // CHECK: %[[r1:.*]] = flow.dispatch.workgroups(%[[argB]], %[[argC]]) : (tensor<5x10xf32>, tensor<10x11xf32>) -> tensor<5x11xf32> - // CHECK-SAME: stream.affinity = #hal.affinity.queue<[0]> + // CHECK-SAME: stream.affinity = #hal.device.affinity<@device> // CHECK-NEXT: (%[[arg3:.*]]: !flow.dispatch.tensor>, %[[arg4:.*]]: !flow.dispatch.tensor>, %[[arg5:.*]]: !flow.dispatch.tensor>) // CHECK-DAG: %[[loadB:.*]] = flow.dispatch.tensor.load %[[arg3]], offsets = [0, 0], sizes = [5, 10], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<5x10xf32> // CHECK-DAG: %[[loadC:.*]] = flow.dispatch.tensor.load %[[arg4]], offsets = [0, 0], sizes = [10, 11], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<10x11xf32> @@ -31,7 +33,9 @@ util.func public @foo(%argA: tensor, %argB: tensor<5x10xf32>, %argC: te // CHECK: flow.dispatch.tensor.store %[[matmul]], %[[arg5]], offsets = [0, 0], sizes = [5, 11], strides = [1, 1] : tensor<5x11xf32> -> !flow.dispatch.tensor> // CHECK: flow.return // CHECK: } - %r1 = flow.dispatch.region -> (tensor<5x11xf32>) attributes {stream.affinity = #hal.affinity.queue<[0]>} { + %r1 = flow.dispatch.region -> (tensor<5x11xf32>) attributes { + stream.affinity = #hal.device.affinity<@device> + } { %zero = arith.constant 0.0 : f32 %0 = tensor.empty() : tensor<5x11xf32> %1 = linalg.fill ins(%zero : f32) outs(%0 : tensor<5x11xf32>) -> tensor<5x11xf32> diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir index e3db1b69795d..57304a122cf2 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_constants.mlir @@ -67,11 +67,13 @@ util.func private @func_1() { // Tests that any hoistable attrs are propagated to the outlined globals. +util.global private @device : !hal.device + // CHECK: util.global private @__constant_tensor_2xi32 -// CHECK-SAME: stream.affinity = #hal.affinity.queue<[0]> +// CHECK-SAME: stream.affinity = #hal.device.affinity<@device, [0]> // CHECK-NEXT: util.func private @set_affinity util.func private @set_affinity() attributes { - stream.affinity = #hal.affinity.queue<[0]> + stream.affinity = #hal.device.affinity<@device, [0]> } { // CHECK-NEXT: = util.global.load immutable @__constant_tensor_2xi32 %cst = arith.constant dense<[0, 1]> : tensor<2xi32> diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir index 0a0f9e5877ff..dd9d65116531 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir @@ -78,6 +78,9 @@ util.func public @dispatchFnMuli(%arg0 : tensor<8x4xf32>) -> tensor<8x4xf32> { // ----- +util.global private @device_a : !hal.device +util.global private @device_b : !hal.device + // CHECK: flow.executable private @dispatchFn1_dispatch_0 // CHECK-LABEL: util.func public @dispatchFn1 @@ -85,9 +88,9 @@ util.func public @dispatchFn1(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { %x = arith.constant 100 : index %y = arith.constant 50 : index // CHECK: flow.dispatch @dispatchFn1_dispatch_0::@dispatchFn1_dispatch_0 - // CHECK-SAME: stream.affinity = #hal.affinity.queue<[0]> + // CHECK-SAME: stream.affinity = #hal.device.affinity<@device_a> %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) attributes { - stream.affinity = #hal.affinity.queue<[0]> + stream.affinity = #hal.device.affinity<@device_a> } = ( %arg: !flow.dispatch.tensor>, %ret: !flow.dispatch.tensor> ) { @@ -103,9 +106,9 @@ util.func public @dispatchFn2(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { %x = arith.constant 100 : index %y = arith.constant 50 : index // CHECK: flow.dispatch @dispatchFn2_dispatch_0::@dispatchFn2_dispatch_0 - // CHECK-SAME: stream.affinity = #hal.affinity.queue<[1]> + // CHECK-SAME: stream.affinity = #hal.device.affinity<@device_b> %0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) attributes { - stream.affinity = #hal.affinity.queue<[1]> + stream.affinity = #hal.device.affinity<@device_b> } = ( %arg: !flow.dispatch.tensor>, %ret: !flow.dispatch.tensor> ) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 74c0fc99afc8..de9bbb4b5f1a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -23,24 +23,6 @@ namespace mlir::iree_compiler { namespace { -// Returns the device queue affinity mask indicating which device queues the -// operations are allowed to execute on. -static Value buildQueueAffinityMask(Location loc, - IREE::Stream::AffinityAttr affinityAttr, - Value device, OpBuilder &builder) { - // Try to find a specified affinity. This may be on the op provided or one of - // its parent regions. - if (auto queueAffinityAttr = - llvm::dyn_cast_if_present( - affinityAttr)) { - return builder.create( - loc, queueAffinityAttr.getMask(), 64); - } - - // No affinity specified; use default (any) affinity. - return builder.create(loc, -1, 64); -} - struct ContextResolveOpPattern : public StreamConversionPattern { using StreamConversionPattern::StreamConversionPattern; @@ -50,9 +32,36 @@ struct ContextResolveOpPattern auto resultTypes = llvm::to_vector(resolveOp.getResultTypes()); assert(!resultTypes.empty() && "must have at least one result"); - // TODO(multi-device): emit get with derived ordinal or lookup with attr. - Value device = - IREE::HAL::DeviceType::resolveAny(resolveOp.getLoc(), rewriter); + // Get the affinity from the op or an ancestor. Note that there may be no + // affinity specified at all. + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(resolveOp); + + // We currently only handle HAL device affinities. + // We could make this an interface to select the device and allow users to + // provide their own affinities to convert to HAL. In the future users may + // also want to provide devices as function arguments post-initialization. + // For now we just have one way to specify device globals. + auto deviceAffinityAttr = + dyn_cast_if_present(affinityAttr); + if (!deviceAffinityAttr) { + resolveOp.emitOpError() << "failed to resolve affinity: only HAL device " + "affinities are supported"; + return rewriter.notifyMatchFailure( + resolveOp, "only HAL device affinities are supported"); + } + + // Get the device handle and queue. + // + // TODO(multi-device): specialized types; may need analysis we don't have + // or at least a symbol lookup. An alternative would be an optional type + // on the affinity in cases where we've evaluated it early but for now + // we assume all device types are unspecialized. + auto deviceType = rewriter.getType(); + Value device = rewriter.create( + resolveOp.getLoc(), deviceType, + deviceAffinityAttr.getDevice().getValue(), + /*is_immutable=*/true); + int64_t queueMask = deviceAffinityAttr.getQueueMask(); SmallVector results; if (isa(resultTypes[0])) { @@ -66,8 +75,8 @@ struct ContextResolveOpPattern } if (resultTypes.size() > 1) { if (isa(resultTypes[1])) { - results.push_back(buildQueueAffinityMask( - resolveOp.getLoc(), resolveOp.getAffinityAttr(), device, rewriter)); + results.push_back(rewriter.create( + resolveOp.getLoc(), queueMask, 64)); } else { return rewriter.notifyMatchFailure( resolveOp, @@ -698,53 +707,66 @@ struct CmdDispatchOpPattern caseExportOps.push_back(std::make_pair(entryPointAttr, exportOp)); }); - // Select the variant index. - Value selectedIndex = buildIfElseTree( - loc, caseExportOps.size(), - [&](Location loc, size_t i, OpBuilder &builder) { - auto exportOp = caseExportOps[i].second; - auto variantOp = - exportOp->getParentOfType(); - return variantOp.buildCondition(device, rewriter); - }, - rewriter); - - // Allow each variant to define how it is dispatched. - auto switchOp = rewriter.replaceOpWithNewOp( - dispatchOp, TypeRange{}, selectedIndex, caseIndices, - caseIndices.size()); - for (size_t i = 0; i < caseExportOps.size(); ++i) { - auto entryPointAttr = caseExportOps[i].first; - auto exportOp = caseExportOps[i].second; - auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock(); - auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock); - + auto recordDispatch = [&](SymbolRefAttr entryPointAttr, + IREE::HAL::ExecutableExportOp exportOp, + OpBuilder &builder) { // Record push constants and buffer bindings. recordParameters(loc, affinityAttr, device, commandBuffer, exportOp, - dispatchOp, adaptor, caseBuilder); + dispatchOp, adaptor, builder); // Dispatch with a target-specific workgroup count. - auto caseWorkgroupCount = exportOp.calculateWorkgroupCount( - loc, device, adaptor.getWorkload(), caseBuilder); - Value executable = caseBuilder.create( - loc, caseBuilder.getType(), device, + auto workgroupCount = exportOp.calculateWorkgroupCount( + loc, device, adaptor.getWorkload(), builder); + Value executable = builder.create( + loc, builder.getType(), device, entryPointAttr.getRootReference().getValue()); - Value ordinal = caseBuilder.create( - loc, caseBuilder.getIndexType(), entryPointAttr); - auto flags = caseBuilder.getAttr( + Value ordinal = builder.create( + loc, builder.getIndexType(), entryPointAttr); + auto flags = builder.getAttr( IREE::HAL::DispatchFlags::None); - caseBuilder.create( - loc, commandBuffer, executable, ordinal, caseWorkgroupCount[0], - caseWorkgroupCount[1], caseWorkgroupCount[2], flags); + return builder.create( + loc, commandBuffer, executable, ordinal, workgroupCount[0], + workgroupCount[1], workgroupCount[2], flags); + }; - caseBuilder.create(loc); - } + // If there is only one variant we can emit that directly without a + // conditional check. The same result should occur later on but it saves + // a lot of IR during generation if we know we can avoid it. + if (caseExportOps.size() == 1) { + auto [entryPointAttr, exportOp] = caseExportOps.front(); + rewriter.replaceOp(dispatchOp, + recordDispatch(entryPointAttr, exportOp, rewriter)); + } else { + // Select the variant index. + Value selectedIndex = buildIfElseTree( + loc, caseExportOps.size(), + [&](Location loc, size_t i, OpBuilder &builder) { + auto exportOp = caseExportOps[i].second; + auto variantOp = + exportOp->getParentOfType(); + return variantOp.buildCondition(device, rewriter); + }, + rewriter); + + // Allow each variant to define how it is dispatched. + auto switchOp = rewriter.create( + loc, TypeRange{}, selectedIndex, caseIndices, caseIndices.size()); + for (size_t i = 0; i < caseExportOps.size(); ++i) { + auto [entryPointAttr, exportOp] = caseExportOps[i]; + auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock(); + auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock); + recordDispatch(entryPointAttr, exportOp, caseBuilder); + caseBuilder.create(loc); + } + + // Fallback for no available variant. Today we just no-op as executable + // loading should have already failed. + auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); + auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock); + defaultBuilder.create(loc); - // Fallback for no available variant. Today we just no-op as executable - // loading should have already failed. - auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); - auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock); - defaultBuilder.create(loc); + rewriter.replaceOp(dispatchOp, switchOp); + } return success(); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir index 3f88bd1ec696..bb2108f1411c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/channel_ops.mlir @@ -1,15 +1,17 @@ // RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: @channel_create // CHECK-SAME: () -> !hal.channel util.func public @channel_create() -> !stream.channel { - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} : !hal.device + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant 3 // CHECK-DAG: %[[ID:.+]] = util.null : !util.buffer // CHECK-DAG: %[[GROUP:.+]] = util.buffer.constant : !util.buffer = "group" // CHECK-DAG: %[[DEFAULT:.+]] = arith.constant -1 // CHECK: %[[CHANNEL:.+]] = hal.channel.create device(%[[DEVICE]] : !hal.device) affinity(%[[AFFINITY]]) flags(0) id(%[[ID]]) group(%[[GROUP]]) rank(%[[DEFAULT]]) count(%[[DEFAULT]]) : !hal.channel - %channel = stream.channel.create on(#hal.affinity.queue<[0, 1]>) group("group") : !stream.channel + %channel = stream.channel.create on(#hal.device.affinity<@device, [0, 1]>) group("group") : !stream.channel // CHECK: util.return %[[CHANNEL]] util.return %channel : !stream.channel } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir index 7cdd9917311f..941c15b417b3 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir @@ -3,12 +3,14 @@ // Today all memory control operations are ignored and we're just left with // the normal sequential execution barriers. +util.global private @device : !hal.device + // CHECK-LABEL: @cmdMemoryControl util.func public @cmdMemoryControl(%arg0: !stream.resource, %arg1: index) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource{%arg1}) { + %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %arg2: !stream.resource{%arg1}) { // CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]] stream.cmd.flush %arg2[%c0 for %c128] : !stream.resource{%arg1} // CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]] @@ -22,13 +24,15 @@ util.func public @cmdMemoryControl(%arg0: !stream.resource, %arg1: in // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @cmdFill util.func public @cmdFill(%arg0: !stream.resource, %arg1: index) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index %c255_i32 = arith.constant 255 : i32 // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource{%arg1}) { + %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %arg2: !stream.resource{%arg1}) { // CHECK-NEXT: hal.command_buffer.fill_buffer<%[[CMD]] : !hal.command_buffer> // CHECK-SAME: target(%arg0 : !hal.buffer)[%c0, %c128] // CHECK-SAME: pattern(%c255_i32 : i32) @@ -41,12 +45,14 @@ util.func public @cmdFill(%arg0: !stream.resource, %arg1: index) -> ! // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @cmdCopy util.func public @cmdCopy(%arg0: !stream.resource, %arg1: index, %arg2: !stream.resource, %arg3: index) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource{%arg1}, %arg2 as %arg5: !stream.resource{%arg3}) { + %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %arg4: !stream.resource{%arg1}, %arg2 as %arg5: !stream.resource{%arg3}) { // CHECK-NEXT: hal.command_buffer.copy_buffer<%[[CMD]] : !hal.command_buffer> // CHECK-SAME: source(%arg0 : !hal.buffer)[%c0] // CHECK-SAME: target(%arg2 : !hal.buffer)[%c0] @@ -60,12 +66,14 @@ util.func public @cmdCopy(%arg0: !stream.resource, %arg1: index, %arg // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @cmdCollective util.func public @cmdCollective(%arg0: !stream.resource, %arg1: index, %arg2: !stream.resource, %arg3: index, %arg4: !stream.channel) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute with(%arg0 as %arg5: !stream.resource{%arg1}, %arg2 as %arg6: !stream.resource{%arg3}) { + %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %arg5: !stream.resource{%arg1}, %arg2 as %arg6: !stream.resource{%arg3}) { // Out-of-place all-reduce: // CHECK-NEXT: hal.command_buffer.collective @@ -127,12 +135,14 @@ util.func public @cmdCollective(%arg0: !stream.resource, %arg1: index // than we actually need and guard a lot more work than we otherwise would need // to. +util.global private @device : !hal.device + // CHECK-LABEL: @cmdExecute util.func public @cmdExecute(%arg0: !stream.resource, %arg1: index, %arg2: !stream.resource, %arg3: index, %arg4: !stream.timepoint) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute await(%arg4) => with(%arg0 as %arg5: !stream.resource{%arg1}, %arg2 as %arg6: !stream.resource{%arg3}) { + %0 = stream.cmd.execute on(#hal.device.affinity<@device>) await(%arg4) => with(%arg0 as %arg5: !stream.resource{%arg1}, %arg2 as %arg6: !stream.resource{%arg3}) { stream.cmd.concurrent { // CHECK-NEXT: hal.command_buffer.copy_buffer<%[[CMD]] stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource{%arg1} -> !stream.resource{%arg3} @@ -166,10 +176,6 @@ util.func public @cmdExecute(%arg0: !stream.resource, %arg1: index, % #executable_target_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64"> #executable_target_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> -#device_target_cpu = #hal.device.target<"llvm-cpu", [ - #executable_target_aarch64, - #executable_target_x86_64 -]> #pipeline_layout = #hal.pipeline.layout @@ -219,6 +225,8 @@ hal.executable private @ex { } } +util.global private @device : !hal.device + // CHECK-LABEL: @cmdDispatch util.func public @cmdDispatch(%arg0: !stream.resource, %arg1: index, %arg2: !stream.resource, %arg3: index) -> !stream.timepoint { %c0 = arith.constant 0 : index @@ -229,7 +237,7 @@ util.func public @cmdDispatch(%arg0: !stream.resource, %arg1: index, %c5_i32 = arith.constant 5 : i32 %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource{%arg1}, %arg2 as %arg5: !stream.resource{%arg3}) { + %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %arg4: !stream.resource{%arg1}, %arg2 as %arg5: !stream.resource{%arg3}) { // Switch for each executable variant by checking conditions and ranking: // CHECK: %[[DEVICE:.+]] = hal.command_buffer.device<%[[CMD]] : !hal.command_buffer> // CHECK-DAG: %{{.+}}, %[[AARCH64_FORMAT:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-aarch64") @@ -297,6 +305,8 @@ util.func public @cmdDispatch(%arg0: !stream.resource, %arg1: index, // Tests conversion of streamable calls and function declarations. // Expect a command buffer and a buffer + offset + length for each resource. +util.global private @device : !hal.device + // CHECK: util.func private @cmdFunc(%arg0: !hal.command_buffer, %arg1: !hal.buffer, %arg2: index, %arg3: index, %arg4: i32, %arg5: !hal.buffer, %arg6: index, %arg7: index, %arg8: !custom.type, %arg9: !hal.buffer, %arg10: index, %arg11: index) stream.cmd.func private @cmdFunc(%arg0[%arg1 for %arg2]: !stream.resource<*>, %arg3: i32, %arg4[%arg5 for %arg6]: !stream.resource<*>, %arg7: !custom.type, %arg8[%arg9 for %arg10]: !stream.resource<*>) @@ -310,7 +320,7 @@ util.func public @cmdCall(%arg0: !stream.resource, %arg1: i32, %arg2: // CHECK-DAG: %[[SIZE2:.+]] = arith.constant 102 %size2 = arith.constant 102 : index // CHECK: %[[COMMAND_BUFFER:.+]] = hal.command_buffer.create - %timepoint = stream.cmd.execute with(%arg0 as %stream0: !stream.resource{%size0}, %arg2 as %stream1: !stream.resource{%size1}, %arg4 as %stream2: !stream.resource{%size2}) { + %timepoint = stream.cmd.execute on(#hal.device.affinity<@device>) with(%arg0 as %stream0: !stream.resource{%size0}, %arg2 as %stream1: !stream.resource{%size1}, %arg4 as %stream2: !stream.resource{%size2}) { // CHECK: util.call @cmdFunc(%[[COMMAND_BUFFER]], %arg0, %c0, %[[SIZE0]], %arg1, %arg2, %c0, %[[SIZE1]], %arg3, %arg4, %c0, %[[SIZE2]]) : // CHECK-SAME: (!hal.command_buffer, !hal.buffer, index, index, i32, !hal.buffer, index, index, !custom.type, !hal.buffer, index, index) -> () stream.cmd.call @cmdFunc(ro %stream0[%c0 for %size0], %arg1, rw %stream1[%c0 for %size1], %arg3, wo %stream2[%c0 for %size2]) : (!stream.resource{%size0}, i32, !stream.resource{%size1}, !custom.type, !stream.resource{%size2}) -> () @@ -324,12 +334,14 @@ util.func public @cmdCall(%arg0: !stream.resource, %arg1: i32, %arg2: // appropriate queue affinity mask. The final affinity is the result of ORing // the target affinities (0b01 | 0b10 = 0b11 = 3). +util.global private @device : !hal.device + // CHECK-LABEL: @cmdExecuteAffinities util.func public @cmdExecuteAffinities(%arg0: !stream.resource, %arg1: index, %arg2: !stream.resource, %arg3: index, %arg4: !stream.timepoint) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute on(#hal.affinity.queue<[0, 1]>) await(%arg4) => with(%arg0 as %arg5: !stream.resource{%arg1}, %arg2 as %arg6: !stream.resource{%arg3}) { + %0 = stream.cmd.execute on(#hal.device.affinity<@device, [0, 1]>) await(%arg4) => with(%arg0 as %arg5: !stream.resource{%arg1}, %arg2 as %arg6: !stream.resource{%arg3}) { stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource{%arg1} -> !stream.resource{%arg3} } => !stream.timepoint // CHECK: hal.device.queue.execute diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir index 5d7395169c2b..20a9c59a127f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir @@ -1,42 +1,78 @@ // RUN: iree-opt --split-input-file --allow-unregistered-dialect --iree-hal-conversion %s | FileCheck %s -// CHECK-LABEL: @contextResolveAllocator -util.func public @contextResolveAllocator() -> !hal.allocator { - // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} - // CHECK: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator - %allocator = stream.context.resolve : !hal.allocator - // CHECK: util.return %[[ALLOCATOR]] - util.return %allocator : !hal.allocator +util.global private @device : !hal.device + +// CHECK-LABEL: @contextResolveDefaultDevice +util.func public @contextResolveDefaultDevice() -> !hal.device attributes { + stream.affinity = #hal.device.affinity<@device> +} { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + %device = stream.context.resolve : !hal.device + // CHECK: util.return %[[DEVICE]] + util.return %device : !hal.device } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @contextResolveDevice util.func public @contextResolveDevice() -> !hal.device { - // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} - %device = stream.context.resolve : !hal.device + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + %device = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device // CHECK: util.return %[[DEVICE]] util.return %device : !hal.device } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @contextResolveDeviceQueueAffinityAny util.func public @contextResolveDeviceQueueAffinityAny() -> (!hal.device, i64) { - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant -1 : i64 - %device, %queue_affinity_any = stream.context.resolve on(#hal.affinity.queue<*>) : !hal.device, i64 + %device, %queue_affinity_any = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device, i64 // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]] util.return %device, %queue_affinity_any : !hal.device, i64 } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @contextResolveDeviceQueueAffinity45 util.func public @contextResolveDeviceQueueAffinity45() -> (!hal.device, i64) { - // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64 - %device, %queue_affinity_45 = stream.context.resolve on(#hal.affinity.queue<[4, 5]>) : !hal.device, i64 + %device, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64 // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]] util.return %device, %queue_affinity_45 : !hal.device, i64 } + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @contextResolveAllocator +util.func public @contextResolveAllocator() -> !hal.allocator { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator + %allocator = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.allocator + // CHECK: util.return %[[ALLOCATOR]] + util.return %allocator : !hal.allocator +} + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @contextResolveAllocatorQueueAffinity45 +util.func public @contextResolveAllocatorQueueAffinity45() -> (!hal.allocator, i64) { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator + // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64 + %allocator, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.allocator, i64 + // CHECK: util.return %[[ALLOCATOR]], %[[QUEUE_AFFINITY]] + util.return %allocator, %queue_affinity_45 : !hal.allocator, i64 +} diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir index 1182ee40d5c2..efa925a9ac75 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir @@ -1,44 +1,50 @@ // RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: @file_constant // CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer) util.func public @file_constant(%buffer: !util.buffer) { %c0 = arith.constant 0 : index %c1088 = arith.constant 1088 : index - // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK: = hal.ex.file.from_memory device(%[[DEVICE]] : !hal.device) affinity(%c-1_i64) access(Read) buffer(%[[BUFFER]] : !util.buffer)[%c0 for %c1088] flags(%c0_i32) : !hal.file - %file = stream.file.constant %buffer[%c0 for %c1088] : !util.buffer{%c1088} -> !stream.file + %file = stream.file.constant on(#hal.device.affinity<@device>) %buffer[%c0 for %c1088] : !util.buffer{%c1088} -> !stream.file util.return } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @file_read // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[FILE:.+]]: !hal.file, %[[RESOURCE:.+]]: !hal.buffer) util.func public @file_read(%wait: !stream.timepoint, %file: !stream.file, %resource: !stream.resource) -> !stream.timepoint { %c0 = arith.constant 0 : index %c0_i64 = arith.constant 0 : i64 %c1088 = arith.constant 1088 : index - // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK: %[[SIGNAL:.+]] = hal.fence.create // CHECK: hal.device.queue.read<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[FILE]] : !hal.file)[%c0_i64] target(%[[RESOURCE]] : !hal.buffer)[%c0] length(%c1088) flags(0) - %signal = stream.file.read await(%wait) => %file[%c0_i64], %resource[%c0], %c1088 : !stream.file -> !stream.resource{%c1088} => !stream.timepoint + %signal = stream.file.read on(#hal.device.affinity<@device>) await(%wait) => %file[%c0_i64], %resource[%c0], %c1088 : !stream.file -> !stream.resource{%c1088} => !stream.timepoint // CHECK: util.return %[[SIGNAL]] util.return %signal : !stream.timepoint } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @file_write // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[FILE:.+]]: !hal.file, %[[RESOURCE:.+]]: !hal.buffer) util.func public @file_write(%wait: !stream.timepoint, %file: !stream.file, %resource: !stream.resource) -> !stream.timepoint { %c0 = arith.constant 0 : index %c0_i64 = arith.constant 0 : i64 %c1088 = arith.constant 1088 : index - // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK: %[[SIGNAL:.+]] = hal.fence.create // CHECK: hal.device.queue.write<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[RESOURCE]] : !hal.buffer)[%c0] target(%[[FILE]] : !hal.file)[%c0_i64] length(%c1088) flags(0) - %signal = stream.file.write await(%wait) => %resource[%c0], %file[%c0_i64], %c1088 : !stream.resource{%c1088} -> !stream.file => !stream.timepoint + %signal = stream.file.write on(#hal.device.affinity<@device>) await(%wait) => %resource[%c0], %file[%c0_i64], %c1088 : !stream.resource{%c1088} -> !stream.file => !stream.timepoint // CHECK: util.return %[[SIGNAL]] util.return %signal : !stream.timepoint } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir index 6af93eec387e..09f704636765 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir @@ -1,18 +1,22 @@ // RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: @resourceAlloc util.func public @resourceAlloc(%arg0: index) -> !stream.resource { // CHECK: %[[RET0:.+]] = hal.allocator.allocate // CHECK-SAME: type("DeviceVisible|DeviceLocal") // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") // CHECK-SAME: : !hal.buffer{%arg0} - %0 = stream.resource.alloc uninitialized : !stream.resource{%arg0} + %0 = stream.resource.alloc uninitialized on(#hal.device.affinity<@device>) : !stream.resource{%arg0} // CHECK: util.return %[[RET0]] util.return %0 : !stream.resource } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @resourceAlloca // CHECK-SAME: (%[[SIZE:.+]]: index) util.func public @resourceAlloca(%size: index) -> (!stream.resource, !stream.timepoint) { @@ -26,13 +30,15 @@ util.func public @resourceAlloca(%size: index) -> (!stream.resource, // CHECK-SAME: type("DeviceVisible|DeviceLocal") // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") // CHECK-SAME: : !hal.buffer{%[[SIZE]]} - %0:2 = stream.resource.alloca uninitialized : !stream.resource{%size} => !stream.timepoint + %0:2 = stream.resource.alloca uninitialized on(#hal.device.affinity<@device>) : !stream.resource{%size} => !stream.timepoint // CHECK: util.return %[[RET0]], %[[SIGNAL_FENCE]] util.return %0#0, %0#1 : !stream.resource, !stream.timepoint } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @resourceAllocaAwait // CHECK-SAME: (%[[SIZE:.+]]: index, %[[WAIT_FENCE:.+]]: !hal.fence) util.func public @resourceAllocaAwait(%size: index, %await_timepoint: !stream.timepoint) -> (!stream.resource, !stream.timepoint) { @@ -45,13 +51,15 @@ util.func public @resourceAllocaAwait(%size: index, %await_timepoint: !stream.ti // CHECK-SAME: type("DeviceVisible|DeviceLocal") // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") // CHECK-SAME: : !hal.buffer{%[[SIZE]]} - %0:2 = stream.resource.alloca uninitialized await(%await_timepoint) => !stream.resource{%size} => !stream.timepoint + %0:2 = stream.resource.alloca uninitialized on(#hal.device.affinity<@device>) await(%await_timepoint) => !stream.resource{%size} => !stream.timepoint // CHECK: util.return %[[RET0]], %[[SIGNAL_FENCE]] util.return %0#0, %0#1 : !stream.resource, !stream.timepoint } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @resourceDealloca // CHECK-SAME: (%[[SIZE:.+]]: index, %[[RESOURCE:.+]]: !hal.buffer) util.func public @resourceDealloca(%size: index, %resource: !stream.resource) -> !stream.timepoint { @@ -62,14 +70,14 @@ util.func public @resourceDealloca(%size: index, %resource: !stream.resource{%size} => !stream.timepoint + %0 = stream.resource.dealloca on(#hal.device.affinity<@device>) %resource : !stream.resource{%size} => !stream.timepoint // CHECK: util.return %[[SIGNAL_FENCE]] util.return %0 : !stream.timepoint } // ----- -// TODO(#9572): implement stream ordered allocations. +util.global private @device : !hal.device // CHECK-LABEL: @resourceDeallocaAwait // CHECK-SAME: (%[[SIZE:.+]]: index, %[[RESOURCE:.+]]: !hal.buffer, %[[WAIT_FENCE:.+]]: !hal.fence) @@ -80,7 +88,7 @@ util.func public @resourceDeallocaAwait(%size: index, %resource: !stream.resourc // CHECK-SAME: wait(%[[WAIT_FENCE]]) // CHECK-SAME: signal(%[[SIGNAL_FENCE]]) // CHECK-SAME: buffer(%[[RESOURCE]] : !hal.buffer) - %0 = stream.resource.dealloca await(%await_timepoint) => %resource : !stream.resource{%size} => !stream.timepoint + %0 = stream.resource.dealloca on(#hal.device.affinity<@device>) await(%await_timepoint) => %resource : !stream.resource{%size} => !stream.timepoint // CHECK: util.return %[[SIGNAL_FENCE]] util.return %0 : !stream.timepoint } @@ -97,6 +105,8 @@ util.func public @resourceSize(%arg0: !stream.resource) -> index { // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @resourceTryMap util.func public @resourceTryMap(%arg0: !util.buffer) -> (i1, !stream.resource) { %c0 = arith.constant 0 : index @@ -105,7 +115,7 @@ util.func public @resourceTryMap(%arg0: !util.buffer) -> (i1, !stream.resource i1, !stream.resource{%c128} + %did_map, %mapping = stream.resource.try_map on(#hal.device.affinity<@device>) %arg0[%c0] : !util.buffer -> i1, !stream.resource{%c128} // CHECK: util.return %[[DID_IMPORT]], %[[IMPORTED]] util.return %did_map, %mapping : i1, !stream.resource } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir index 8a7b691bb9ce..007f457d0dfd 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir @@ -42,12 +42,14 @@ util.func public @timepointExportFence(%arg0: !stream.timepoint) -> !hal.fence { // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @timepointChainExternal // CHECK-SAME: (%[[TIMEPOINT:.+]]: !hal.fence, %[[SIGNAL:.+]]: !hal.fence) util.func public @timepointChainExternal(%timepoint: !stream.timepoint, %signal: !hal.fence) { - // CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK: hal.device.queue.execute<%[[DEVICE]] : !hal.device> affinity(%c-1_i64) wait(%[[TIMEPOINT]]) signal(%[[SIGNAL]]) - stream.timepoint.chain_external %timepoint => (%signal : !hal.fence) + stream.timepoint.chain_external on(#hal.device.affinity<@device>) %timepoint => (%signal : !hal.fence) util.return } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/transfer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/transfer_ops.mlir index 1dbcc24785bf..5805f7114e9a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/transfer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/transfer_ops.mlir @@ -1,5 +1,7 @@ // RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: @tensorImportBuffer util.func public @tensorImportBuffer(%arg0: !hal.buffer, %arg1: index) -> !stream.resource { %c20 = arith.constant 20 : index @@ -10,7 +12,7 @@ util.func public @tensorImportBuffer(%arg0: !hal.buffer, %arg1: index) -> !strea // CHECK-SAME: minimum_length(%c20) // CHECK-SAME: type(DeviceVisible) // CHECK-SAME: usage("Transfer{{.+}}Dispatch{{.+}}") - %0 = stream.tensor.import %arg0 : !hal.buffer -> tensor{%arg1} in !stream.resource{%c20} + %0 = stream.tensor.import on(#hal.device.affinity<@device>) %arg0 : !hal.buffer -> tensor{%arg1} in !stream.resource{%c20} // CHECK: util.return %arg0 util.return %0 : !stream.resource } @@ -21,6 +23,8 @@ util.func public @tensorImportBuffer(%arg0: !hal.buffer, %arg1: index) -> !strea // when lowering into the stream dialect; here we only care about the storage // buffer itself. +util.global private @device : !hal.device + // CHECK-LABEL: @tensorImportBufferView util.func public @tensorImportBufferView(%arg0: !hal.buffer_view, %arg1: index) -> !stream.resource { %c20 = arith.constant 20 : index @@ -32,23 +36,27 @@ util.func public @tensorImportBufferView(%arg0: !hal.buffer_view, %arg1: index) // CHECK-SAME: minimum_length(%c20) // CHECK-SAME: type(DeviceVisible) // CHECK-SAME: usage("Transfer{{.+}}Dispatch{{.+}}") - %0 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor{%arg1} in !stream.resource{%c20} + %0 = stream.tensor.import on(#hal.device.affinity<@device>) %arg0 : !hal.buffer_view -> tensor{%arg1} in !stream.resource{%c20} // CHECK: util.return %[[BUFFER]] util.return %0 : !stream.resource } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @tensorExportBuffer util.func public @tensorExportBuffer(%arg0: !stream.resource, %arg1: index) -> !hal.buffer { %c200 = arith.constant 200 : index - %0 = stream.tensor.export %arg0 : tensor{%arg1} in !stream.resource{%c200} -> !hal.buffer + %0 = stream.tensor.export on(#hal.device.affinity<@device>) %arg0 : tensor{%arg1} in !stream.resource{%c200} -> !hal.buffer // CHECK: util.return %arg0 : !hal.buffer util.return %0 : !hal.buffer } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @tensorExportBufferView util.func public @tensorExportBufferView(%arg0: !stream.resource, %arg1: index) -> !hal.buffer_view { %c200 = arith.constant 200 : index @@ -60,7 +68,7 @@ util.func public @tensorExportBufferView(%arg0: !stream.resource, %arg // CHECK-SAME: type(%[[ELEMENT_TYPE]]) // CHECK-SAME: encoding(%[[ENCODING_TYPE]]) // CHECK-SAME: : !hal.buffer_view - %0 = stream.tensor.export %arg0 : tensor{%arg1} in !stream.resource{%c200} -> !hal.buffer_view + %0 = stream.tensor.export on(#hal.device.affinity<@device>) %arg0 : tensor{%arg1} in !stream.resource{%c200} -> !hal.buffer_view // CHECK: util.return %[[VIEW]] util.return %0 : !hal.buffer_view } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp index f5ac50deb4d8..c1d8c227d6c4 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp @@ -1122,25 +1122,24 @@ Value IREE::HAL::DeviceSelectAttr::buildDeviceEnumeration( } //===----------------------------------------------------------------------===// -// #hal.affinity.queue<*> +// #hal.device.affinity<*> //===----------------------------------------------------------------------===// // static -Attribute AffinityQueueAttr::parse(AsmParser &p, Type type) { - int64_t mask = 0; - // `<` - if (failed(p.parseLess())) +Attribute DeviceAffinityAttr::parse(AsmParser &p, Type type) { + // `<@device` + StringAttr deviceName; + int64_t queueMask = -1; + if (failed(p.parseLess()) || failed(p.parseSymbolName(deviceName))) return {}; - // `*` (any) - if (succeeded(p.parseOptionalStar())) { - mask = -1; - } else { + if (succeeded(p.parseOptionalComma())) { // `[`queue_bit[, ...] `]` + queueMask = 0; if (failed(p.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { int64_t i = 0; if (failed(p.parseInteger(i))) return failure(); - mask |= 1ll << i; + queueMask |= 1ll << i; return success(); }))) { return {}; @@ -1149,19 +1148,18 @@ Attribute AffinityQueueAttr::parse(AsmParser &p, Type type) { // `>` if (failed(p.parseGreater())) return {}; - return get(p.getContext(), mask); + return get(p.getContext(), FlatSymbolRefAttr::get(deviceName), queueMask); } -void AffinityQueueAttr::print(AsmPrinter &p) const { +void DeviceAffinityAttr::print(AsmPrinter &p) const { auto &os = p.getStream(); os << "<"; - int64_t mask = getMask(); - if (mask == -1) { - os << "*"; - } else { - os << "["; - for (int i = 0, j = 0; i < sizeof(mask) * 8; ++i) { - if (mask & (1ll << i)) { + os << getDevice(); + int64_t queueMask = getQueueMask(); + if (queueMask != -1) { + os << ", ["; + for (int i = 0, j = 0; i < sizeof(queueMask) * 8; ++i) { + if (queueMask & (1ll << i)) { if (j++ > 0) os << ", "; os << i; @@ -1172,45 +1170,62 @@ void AffinityQueueAttr::print(AsmPrinter &p) const { os << ">"; } -bool AffinityQueueAttr::isExecutableWith( +bool DeviceAffinityAttr::isExecutableWith( IREE::Stream::AffinityAttr other) const { if (!other) return true; - // Only compatible with other queue affinities today. When we extend the - // attributes to specify device targets we'd want to check here. - auto otherQueueAttr = llvm::dyn_cast_if_present(other); - if (!otherQueueAttr) + // Only compatible with the same exact devices today. We could support a + // peering model to allow operations to move across devices in a peered set + // but that may be best done at higher levels and avoided once we get to the + // "are these the same device" stage. + auto otherAffinityAttr = llvm::dyn_cast_if_present(other); + if (!otherAffinityAttr || getDevice() != otherAffinityAttr.getDevice()) return false; // If this affinity is a subset of the target affinity then it can execute // with it. - if ((getMask() & otherQueueAttr.getMask()) == getMask()) + if ((getQueueMask() & otherAffinityAttr.getQueueMask()) == getQueueMask()) return true; // Otherwise not compatible. return false; } IREE::Stream::AffinityAttr -AffinityQueueAttr::joinOR(IREE::Stream::AffinityAttr other) const { +DeviceAffinityAttr::joinOR(IREE::Stream::AffinityAttr other) const { if (!other) return *this; if (!IREE::Stream::AffinityAttr::canExecuteTogether(*this, other)) { return nullptr; } - auto otherQueueAttr = llvm::dyn_cast_if_present(other); - return AffinityQueueAttr::get(getContext(), - getMask() | otherQueueAttr.getMask()); + auto otherAffinityAttr = llvm::dyn_cast_if_present(other); + return DeviceAffinityAttr::get(getContext(), getDevice(), + getQueueMask() | + otherAffinityAttr.getQueueMask()); } IREE::Stream::AffinityAttr -AffinityQueueAttr::joinAND(IREE::Stream::AffinityAttr other) const { +DeviceAffinityAttr::joinAND(IREE::Stream::AffinityAttr other) const { if (!other) return *this; if (!IREE::Stream::AffinityAttr::canExecuteTogether(*this, other)) { return nullptr; } - auto otherQueueAttr = llvm::dyn_cast_if_present(other); - return AffinityQueueAttr::get(getContext(), - getMask() & otherQueueAttr.getMask()); + auto otherAffinityAttr = llvm::dyn_cast_if_present(other); + return DeviceAffinityAttr::get(getContext(), getDevice(), + getQueueMask() & + otherAffinityAttr.getQueueMask()); +} + +bool DeviceAffinityAttr::isLegalToInline(Operation *inlineSite, + Operation *inlinable) const { + // Look up the affinity of the inlining target site and only allow inlining if + // it matches exactly. We could make a decision as to whether we allow + // inlining when queues are subsets (so if the target site allows any queue + // and the inlinable allows queue 2 then allow, etc). In the future we may + // want to allow util.scope restrictions within the inline target to keep + // queue specification tighter but today most queue masks are wildcarded + // anyway. + auto targetAffinityAttr = IREE::Stream::AffinityAttr::lookup(inlineSite); + return *this == targetAffinityAttr; } //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index 69a1f15e633b..b77c9a513f58 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td @@ -983,41 +983,40 @@ def HAL_DeviceSelectAttr : AttrDef +// #hal.device.affinity<*> //===----------------------------------------------------------------------===// -def HAL_AffinityQueueAttr : AttrDef, Util_HoistableAttrInterface, + DeclareAttrInterfaceMethods, ]> { - let mnemonic = "affinity.queue"; - let summary = [{specifies a set of allowed queues for an operation}]; + let mnemonic = "device.affinity"; + let summary = [{specifies a named device and optional queue affinity}]; let description = [{ - WIP; see [#10765](https://github.com/iree-org/iree/issues/10765). - This may change in the future to either be a nested attribute on a larger - affinity struct or be defined by an implementation of the affinity attr - interface. For now this allows higher levels of the stack to specify - queues such that the stream dialect can understand them and they can be - lowered into the HAL dialect. - Specifies that an annotated operation or scope is only allowed to execute on - the set of queues (0-64) provided. Operations will not run on other queues. + a specific device and optionally a set of queues (0-64) provided. + Operations will not run on other queues. If the queue mask is omitted then + any queue on the device is allowed to execute the specified operations. Example: ```mlir - // any queue - #hal.affinity.queue<*> - // queues 4 and 5 - #hal.affinity.queue<[4, 5]> + // Any queue on @device_a. + #hal.device.affinity<@device_a> + // Queues 4 and 5 on @device_b. + #hal.device.affinity<@device_b, [4, 5]> ``` }]; let parameters = (ins - AttrParameter<"int64_t", "">:$mask + AttrParameter<"FlatSymbolRefAttr", "">:$device, + AttrParameter<"int64_t", "">:$queue_mask ); let hasCustomAssemblyFormat = 1; diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir index 47f8468e7770..00f39f5e6b87 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir @@ -1,4 +1,5 @@ // RUN: iree-opt --allow-unregistered-dialect --split-input-file --mlir-print-local-scope %s | FileCheck %s +// RUN: iree-opt --inline --allow-unregistered-dialect --split-input-file --mlir-print-local-scope %s | FileCheck %s --check-prefix=CHECK-INLINE // CHECK-LABEL: descriptor_set_layout_binding.basic "descriptor_set_layout_binding.basic"() { @@ -100,11 +101,52 @@ util.global private @optional = #hal.device.select<[ // ----- -"affinity.queue"() { - // CHECK: any = #hal.affinity.queue<*> - any = #hal.affinity.queue<*>, - // CHECK: q0 = #hal.affinity.queue<[0]> - q0 = #hal.affinity.queue<[0]>, - // CHECK: q123 = #hal.affinity.queue<[1, 2, 3]> - q123 = #hal.affinity.queue<[1, 2, 3]> +util.global private @device : !hal.device +"device.affinity"() { + // CHECK: device_any = #hal.device.affinity<@device> + device_any = #hal.device.affinity<@device>, + // CHECK: device_queue_0 = #hal.device.affinity<@device, [0]> + device_queue_0 = #hal.device.affinity<@device, [0]>, + // CHECK: device_queue_123 = #hal.device.affinity<@device, [1, 2, 3]> + device_queue_123 = #hal.device.affinity<@device, [1, 2, 3]> } : () -> () + +// ----- + +// Tests that differing device affinities blocks inlining. +// Here the @inline_target is using the default affinity specified on the +// module and only functions also using the default affinity or a matching +// specified affinity will be inlined. The #hal.device.affinity controls this +// behavior and in the future we could allow inlining of compatible devices, +// the same device on differing queues, etc. + +builtin.module attributes { + stream.affinity = #hal.device.affinity<@device_a> +} { + util.global private @device_a : !hal.device + util.global private @device_b : !hal.device + // CHECK-INLINE: util.func public @inline_target + util.func public @inline_target() -> (i32, i32) { + // CHECK-INLINE-NOT: util.call @compat_inlinable + // CHECK-INLINE: %[[A:.+]] = arith.constant 0 + %a = util.call @compat_inlinable() : () -> i32 + // CHECK-INLINE: %[[B:.+]] = util.call @noncompat_inlinable + %b = util.call @noncompat_inlinable() : () -> i32 + // CHECK-INLINE: util.return %[[A]], %[[B]] + util.return %a, %b : i32, i32 + } + // CHECK-INLINE-NOT: util.func private @compat_inlinable + util.func private @compat_inlinable() -> i32 attributes { + stream.affinity = #hal.device.affinity<@device_a> + } { + %c0 = arith.constant 0 : i32 + util.return %c0 : i32 + } + // CHECK-INLINE: util.func private @noncompat_inlinable + util.func private @noncompat_inlinable() -> i32 attributes { + stream.affinity = #hal.device.affinity<@device_b> + } { + %c1 = arith.constant 1 : i32 + util.return %c1 : i32 + } +} diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir index 6da48666293c..7408ad9edbef 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir @@ -144,9 +144,10 @@ hal.executable @ex_with_constants { // CHECK-SAME: %[[DEVICE:.+]]: !hal.device, // CHECK-SAME: %[[LAYOUT0:.+]]: !hal.pipeline_layout, // CHECK-SAME: %[[LAYOUT1:.+]]: !hal.pipeline_layout -util.func public @executable_create(%device: !hal.device, - %layout0: !hal.pipeline_layout, - %layout1: !hal.pipeline_layout) { +util.func public @executable_create( + %device: !hal.device, + %layout0: !hal.pipeline_layout, + %layout1: !hal.pipeline_layout) { // CHECK: = hal.executable.create // CHECK-SAME: device(%[[DEVICE]] : !hal.device) // CHECK-SAME: target(@exe::@binary1) @@ -163,16 +164,17 @@ util.func public @executable_create(%device: !hal.device, // CHECK-SAME: %[[DEVICE:.+]]: !hal.device, // CHECK-SAME: %[[LAYOUT0:.+]]: !hal.descriptor_set_layout, // CHECK-SAME: %[[LAYOUT1:.+]]: !hal.descriptor_set_layout -util.func public @pipeline_layout_create(%device: !hal.device, - %layout0: !hal.descriptor_set_layout, - %layout1: !hal.descriptor_set_layout) { +util.func public @pipeline_layout_create( + %device: !hal.device, + %layout0: !hal.descriptor_set_layout, + %layout1: !hal.descriptor_set_layout) { // CHECK: hal.pipeline_layout.create // CHECK-SAME: device(%[[DEVICE]] : !hal.device) // CHECK-SAME: push_constants(1) // CHECK-SAME: layouts([%[[LAYOUT0]], %[[LAYOUT1]]]) : !hal.pipeline_layout %0 = hal.pipeline_layout.create device(%device : !hal.device) - push_constants(1) - layouts([%layout0, %layout1]) : !hal.pipeline_layout + push_constants(1) + layouts([%layout0, %layout1]) : !hal.pipeline_layout util.return } @@ -197,8 +199,9 @@ hal.executable @unresolved_workload_ex { // CHECK-LABEL: @unresolved_workload // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, // CHECK-SAME: %[[WORKLOAD_0:.+]]: index, %[[WORKLOAD_1:.+]]: index) -util.func public @unresolved_workload(%device: !hal.device, - %workload_0: index, %workload_1: index) -> (index, index, index) { +util.func public @unresolved_workload( + %device: !hal.device, + %workload_0: index, %workload_1: index) -> (index, index, index) { // CHECK: %[[WORKGROUP_X:.+]], %[[WORKGROUP_Y:.+]], %[[WORKGROUP_Z:.+]] = // CHECK-SAME: hal.executable.calculate_workgroups // CHECK-SAME: device(%[[DEVICE]] : !hal.device) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir index 82b6310af13a..1de9b90632e8 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir @@ -2,6 +2,8 @@ // Tests an end-to-end simple single-dispatch `dispatch(arg0, arg1) -> result`. +util.global private @device : !hal.device + #executable_target_embedded_elf_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64"> #executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> @@ -66,7 +68,9 @@ hal.executable private @ex { // CHECK: util.func public @simpleDispatch // CHECK-SAME: (%[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view) -> !hal.buffer_view -util.func public @simpleDispatch(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} { +util.func public @simpleDispatch(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes { + stream.affinity = #hal.device.affinity<@device> +} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c16 = arith.constant 16 : index @@ -76,8 +80,7 @@ util.func public @simpleDispatch(%arg0: !hal.buffer_view, %arg1: !hal.buffer_vie // CHECK: %[[ARG0_BUFFER:.+]] = hal.buffer_view.buffer<%[[ARG0]] : !hal.buffer_view> : !hal.buffer - // (annoyingly out of order) - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device : !hal.device // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator // CHECK: hal.buffer.assert<%[[ARG0_BUFFER]] : !hal.buffer> diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir index 287822afe3ae..3cb5f7606716 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir @@ -15,8 +15,7 @@ // CHECK-LABEL: @exeLayoutLookup util.func public @exeLayoutLookup(%device : !hal.device) -> !hal.pipeline_layout { // CHECK: %[[LAYOUT:.+]] = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout - %0 = hal.pipeline_layout.lookup device(%device : !hal.device) - layout(#hal.pipeline.layout, #hal.descriptor_set.binding<1, storage_buffer> diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir index da757047f0c4..410630703f81 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir @@ -46,18 +46,21 @@ util.func public @tiedDispatch(%input0: tensor, %input1: tensor<2x3xi32>) - // ----- +util.global private @device_a : !hal.device +util.global private @device_b : !hal.device + // CHECK-LABEL: @dispatchAffinity // CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM1:.+]]: index, %[[DIM3:.+]]: index) util.func public @dispatchAffinity(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> (tensor, tensor) { - // CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof on(#hal.affinity.queue<[0]>) tensor{%[[DIM1]], %[[DIM3]]} - // CHECK: %[[RESULT0:.+]] = stream.async.dispatch on(#hal.affinity.queue<[0]>) @ex::@entry0(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) + // CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor{%[[DIM1]], %[[DIM3]]} + // CHECK: %[[RESULT0:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) %0 = flow.dispatch @ex::@entry0(%input) { - stream.affinity = #hal.affinity.queue<[0]> + stream.affinity = #hal.device.affinity<@device_a> } : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor{%dim1, %dim3} - // CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof on(#hal.affinity.queue<[1]>) tensor{%[[DIM3]], %[[DIM1]]} - // CHECK: %[[RESULT1:.+]] = stream.async.dispatch on(#hal.affinity.queue<[1]>) @ex::@entry1(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) + // CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor{%[[DIM3]], %[[DIM1]]} + // CHECK: %[[RESULT1:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@entry1(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) %1 = flow.dispatch @ex::@entry1(%input) { - stream.affinity = #hal.affinity.queue<[1]> + stream.affinity = #hal.device.affinity<@device_b> } : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor{%dim3, %dim1} // return %[[RESULT0]], %[[RESULT0_SIZE]], %[[RESULT1]], %[[RESULT1_SIZE]] util.return %0, %1 : tensor, tensor diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir index 9a1272fdf4c6..7633d8c1b849 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir @@ -166,6 +166,8 @@ util.func public @tensorUpdate(%update : tensor<1x1x10xf32>, %target : tensor<5x // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @tensorLoad // CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource<*>, %[[SOURCE_SIZE:.+]]: index) util.func public @tensorLoad(%source : tensor<2x3xi32>) -> i32 { @@ -173,10 +175,10 @@ util.func public @tensorLoad(%source : tensor<2x3xi32>) -> i32 { %c1 = arith.constant 1 : index // CHECK: %[[T0:.+]] = stream.async.transfer // CHECK-SAME: %[[SOURCE]] : !stream.resource<*>{%[[SOURCE_SIZE]]} - // CHECK-SAME: from(#hal.affinity.queue<[0, 1]>) -> !stream.resource{%[[SOURCE_SIZE]]} + // CHECK-SAME: from(#hal.device.affinity<@device>) -> !stream.resource{%[[SOURCE_SIZE]]} // CHECK: %[[T1:.+]] = stream.tensor.load %[[T0]][%c0, %c1] : tensor<2x3xi32> in !stream.resource{%[[SOURCE_SIZE]]} -> i32 %0 = flow.tensor.load %source[%c0, %c1] : tensor<2x3xi32> attributes { - stream.affinity = #hal.affinity.queue<[0, 1]> + stream.affinity = #hal.device.affinity<@device> } // CHECK: util.return %[[T1]] util.return %0 : i32 @@ -184,6 +186,8 @@ util.func public @tensorLoad(%source : tensor<2x3xi32>) -> i32 { // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @tensorStore // CHECK-SAME: (%[[TARGET:.+]]: !stream.resource<*>, %[[TARGET_SIZE:.+]]: index) util.func public @tensorStore(%target : tensor<2x3xi32>) -> tensor<2x3xi32> { @@ -191,13 +195,13 @@ util.func public @tensorStore(%target : tensor<2x3xi32>) -> tensor<2x3xi32> { %c1 = arith.constant 1 : index %c9 = arith.constant 9 : i32 // CHECK: %[[T0:.+]] = stream.async.transfer %[[TARGET]] : !stream.resource<*>{%[[TARGET_SIZE]]} - // CHECK-SAME: from(#hal.affinity.queue<[0, 1]>) -> !stream.resource{%[[TARGET_SIZE]]} + // CHECK-SAME: from(#hal.device.affinity<@device>) -> !stream.resource{%[[TARGET_SIZE]]} // CHECK: %[[T1:.+]] = stream.tensor.store %c9_i32, %[[T0]][%c0, %c1] : // CHECK-SAME: i32 -> tensor<2x3xi32> in %[[T0]] as !stream.resource{%[[TARGET_SIZE]]} // CHECK: %[[T2:.+]] = stream.async.transfer %[[T1]] : !stream.resource{%[[TARGET_SIZE]]} -> - // CHECK-SAME: to(#hal.affinity.queue<[0, 1]>) !stream.resource<*>{%[[TARGET_SIZE]]} + // CHECK-SAME: to(#hal.device.affinity<@device>) !stream.resource<*>{%[[TARGET_SIZE]]} %0 = flow.tensor.store %c9, %target[%c0, %c1] : tensor<2x3xi32> attributes { - stream.affinity = #hal.affinity.queue<[0, 1]> + stream.affinity = #hal.device.affinity<@device> } // CHECK: util.return %[[T2]] util.return %0 : tensor<2x3xi32> diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td index c994e65fb587..cb362716081b 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td @@ -123,8 +123,8 @@ def Stream_ResourceAllocOp : Stream_Op<"resource.alloc", [ ); let assemblyFormat = [{ - (`on` `(` $affinity^ `)`)? (`uninitialized` $uninitialized^)? + (`on` `(` $affinity^ `)`)? attr-dict `:` type($result) `{` $storage_size `}` }]; @@ -1772,15 +1772,15 @@ def Stream_TensorTraceOp : Stream_Op<"tensor.trace", [ } // OpGroupTensorOps //===----------------------------------------------------------------------===// -// Resource transfer ops +// Async (stream.async*) ops //===----------------------------------------------------------------------===// -def OpGroupResourceTransferOps : OpDocGroup { - let summary = "Resource transfer ops"; +def OpGroupAsyncOps : OpDocGroup { + let summary = "Async ops"; let description = ""; } -let opDocGroup = OpGroupResourceTransferOps in { +let opDocGroup = OpGroupAsyncOps in { def Stream_AsyncAllocaOp : Stream_Op<"async.alloca", [ DeclareOpInterfaceMethods, %arg1: index, %arg2: !s // This covers all_gather, all_reduce, and reduce_scatter variants. +util.global private @device : !hal.device + // CHECK-LABEL: @asyncCollectiveAllGather util.func private @asyncCollectiveAllGather( // CHECK-SAME: %[[CHANNEL:.+]]: !stream.channel, @@ -95,8 +97,8 @@ util.func private @asyncCollectiveAllGather( %recv = stream.async.alloca : !stream.resource<*>{%recv_size} // CHECK: = stream.async.collective[%[[COUNT]]] %0 = stream.async.collective[%count] - // CHECK-SAME: on(#hal.affinity.queue<[0]>) channel(%[[CHANNEL]]) - on(#hal.affinity.queue<[0]>) channel(%channel) + // CHECK-SAME: on(#hal.device.affinity<@device>) channel(%[[CHANNEL]]) + on(#hal.device.affinity<@device>) channel(%channel) // CHECK-SAME: %[[SEND]][%c0 to %[[SEND_SIZE]] for %[[SEND_SIZE]]], %send[%c0 to %send_size for %send_size], // CHECK-SAME: %[[RECV]][%c0 to %[[RECV_SIZE]] for %[[RECV_SIZE]]] : @@ -110,6 +112,8 @@ util.func private @asyncCollectiveAllGather( // This covers broadcast and reduce variants. +util.global private @device : !hal.device + // CHECK-LABEL: @asyncCollectiveBroadcast util.func private @asyncCollectiveBroadcast( // CHECK-SAME: %[[CHANNEL:.+]]: !stream.channel, @@ -125,8 +129,8 @@ util.func private @asyncCollectiveBroadcast( %recv = stream.async.alloca : !stream.resource<*>{%recv_size} // CHECK: = stream.async.collective[%[[COUNT]]] %0 = stream.async.collective[%count] - // CHECK-SAME: on(#hal.affinity.queue<[0]>) channel(%[[CHANNEL]]) source(%[[RANK]]) - on(#hal.affinity.queue<[0]>) channel(%channel) source(%rank) + // CHECK-SAME: on(#hal.device.affinity<@device>) channel(%[[CHANNEL]]) source(%[[RANK]]) + on(#hal.device.affinity<@device>) channel(%channel) source(%rank) // CHECK-SAME: %[[SEND]][%c0 to %[[SEND_SIZE]] for %[[SEND_SIZE]]], %send[%c0 to %send_size for %send_size], // CHECK-SAME: %[[RECV]][%c0 to %[[RECV_SIZE]] for %[[RECV_SIZE]]] : @@ -147,10 +151,12 @@ util.func private @asyncTransfer(%arg0: !stream.resource, %arg1: index // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @asyncTransferAffinities util.func private @asyncTransferAffinities(%arg0: !stream.resource, %arg1: index) -> !stream.resource { - // CHECK: = stream.async.transfer %arg0 : !stream.resource{%arg1} from(#hal.affinity.queue<[0]>) -> to(#hal.affinity.queue<[1]>) !stream.resource{%arg1} - %0 = stream.async.transfer %arg0 : !stream.resource{%arg1} from(#hal.affinity.queue<[0]>) -> to(#hal.affinity.queue<[1]>) !stream.resource{%arg1} + // CHECK: = stream.async.transfer %arg0 : !stream.resource{%arg1} from(#hal.device.affinity<@device, [0]>) -> to(#hal.device.affinity<@device, [1]>) !stream.resource{%arg1} + %0 = stream.async.transfer %arg0 : !stream.resource{%arg1} from(#hal.device.affinity<@device, [0]>) -> to(#hal.device.affinity<@device, [1]>) !stream.resource{%arg1} util.return %0 : !stream.resource } diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/channel_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/channel_ops.mlir index 486a03f9c0cb..a465546344a1 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/channel_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/channel_ops.mlir @@ -1,10 +1,12 @@ // RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: @channel_create // CHECK-SAME: (%[[RANK:.+]]: index, %[[COUNT:.+]]: index) util.func private @channel_create(%rank: index, %count: index) { - // CHECK: %channel = stream.channel.create on(#hal.affinity.queue<[0, 1]>) rank(%[[RANK]]) count(%[[COUNT]]) : !stream.channel - %channel = stream.channel.create on(#hal.affinity.queue<[0, 1]>) rank(%rank) count(%count) : !stream.channel + // CHECK: %channel = stream.channel.create on(#hal.device.affinity<@device>) rank(%[[RANK]]) count(%[[COUNT]]) : !stream.channel + %channel = stream.channel.create on(#hal.device.affinity<@device>) rank(%rank) count(%count) : !stream.channel util.return } diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir index ab523ec40817..950643ab1eb0 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/context_ops.mlir @@ -1,12 +1,14 @@ // RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: @context_resolve util.func private @context_resolve() { // CHECK: = stream.context.resolve : !hal.allocator %allocator = stream.context.resolve : !hal.allocator - // CHECK: = stream.context.resolve on(#hal.affinity.queue<*>) : !hal.device, i64 - %device1, %queue_affinity_any = stream.context.resolve on(#hal.affinity.queue<*>) : !hal.device, i64 - // CHECK: = stream.context.resolve on(#hal.affinity.queue<[4, 5]>) : !hal.device, i64 - %device0, %queue_affinity_45 = stream.context.resolve on(#hal.affinity.queue<[4, 5]>) : !hal.device, i64 + // CHECK: = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device, i64 + %device1, %queue_affinity_any = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device, i64 + // CHECK: = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64 + %device0, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64 util.return } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir index 14e8fb26ef7b..ed1f338117dd 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir @@ -16,8 +16,8 @@ stream.executable private @rebaseBindingsEx { stream.executable.export public @dispatch attributes {stream.resources = #aliasConfig} builtin.module { - // CHECK: util.func public @dispatch(%[[BINDING_A:.+]]: !stream.binding, %[[BINDING_B:.+]]: !stream.binding, - // CHECK-SAME: %[[OFFSET_A:.+]]: index, %[[OFFSET_B:.+]]: index, %[[OPERAND:.+]]: index) + // CHECK: util.func public @dispatch(%[[BINDING_A:.+]]: !stream.binding, %[[BINDING_B:.+]]: !stream.binding, + // CHECK-SAME: %[[OFFSET_A:.+]]: index, %[[OFFSET_B:.+]]: index, %[[OPERAND:.+]]: index) util.func public @dispatch(%binding_a: !stream.binding, %binding_b: !stream.binding, %operand: index) { %c0 = arith.constant 0 : index %c20 = arith.constant 20 : index @@ -39,7 +39,7 @@ stream.executable private @rebaseBindingsEx { } } } -// CHECK: util.func public @rebaseBindings(%[[OPERAND:.+]]: index) +// CHECK: util.func public @rebaseBindings(%[[OPERAND:.+]]: index) util.func public @rebaseBindings(%operand: index) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -97,8 +97,8 @@ util.func public @rebaseBindings(%operand: index) { stream.executable private @deduplicateBindingsEx { stream.executable.export public @dispatch attributes {stream.resources = #aliasConfig} builtin.module { - // CHECK: util.func public @dispatch(%[[BINDING_A:.+]]: !stream.binding, %[[BINDING_B:.+]]: !stream.binding, - // CHECK-SAME: %[[OFFSET_A:.+]]: index, %[[OFFSET_C:.+]]: index, %[[OFFSET_B:.+]]: index, %[[OPERAND:.+]]: index) + // CHECK: util.func public @dispatch(%[[BINDING_A:.+]]: !stream.binding, %[[BINDING_B:.+]]: !stream.binding, + // CHECK-SAME: %[[OFFSET_A:.+]]: index, %[[OFFSET_C:.+]]: index, %[[OFFSET_B:.+]]: index, %[[OPERAND:.+]]: index) util.func public @dispatch(%binding_a: !stream.binding, %binding_b: !stream.binding, %binding_c: !stream.binding, %operand: index) { %c0 = arith.constant 0 : index %c20 = arith.constant 20 : index @@ -127,7 +127,7 @@ stream.executable private @deduplicateBindingsEx { } } } -// CHECK: util.func public @deduplicateBindings(%[[OPERAND:.+]]: index) +// CHECK: util.func public @deduplicateBindings(%[[OPERAND:.+]]: index) util.func public @deduplicateBindings(%operand: index) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir index a3b4ef6cfc43..a1509e327a6d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir @@ -110,13 +110,15 @@ util.func public @multiUseTiedOperand(%size: index) -> (!stream.resource<*>, !st // TODO(#11249): support in-place collectives - when supported this will become // a negative test as we'd expect %send_recv to be used for both operands. +util.global private @device : !hal.device + // CHECK-LABEL: @tiedCollectivesTODO // CHECK-SAME: (%[[CHANNEL:.+]]: !stream.channel, %[[SEND_RECV:.+]]: !stream.resource<*>, %[[SEND_SIZE:.+]]: index, %[[RECV_SIZE:.+]]: index, %[[COUNT:.+]]: index) util.func private @tiedCollectivesTODO(%channel: !stream.channel, %send_recv: !stream.resource<*>, %send_size: index, %recv_size: index, %count: index) -> !stream.resource<*> { %c0 = arith.constant 0 : index - // CHECK: %[[RECV_CLONE:.+]] = stream.async.clone on(#hal.affinity.queue<[0]>) %[[SEND_RECV]] + // CHECK: %[[RECV_CLONE:.+]] = stream.async.clone on(#hal.device.affinity<@device>) %[[SEND_RECV]] // CHECK: %[[ALL_GATHER:.+]] = stream.async.collective[%[[COUNT]]] - %0 = stream.async.collective[%count] on(#hal.affinity.queue<[0]>) channel(%channel) + %0 = stream.async.collective[%count] on(#hal.device.affinity<@device>) channel(%channel) // CHECK-SAME: %[[SEND_RECV]][%c0 to %[[SEND_SIZE]] for %[[SEND_SIZE]]], %send_recv[%c0 to %send_size for %send_size], // CHECK-SAME: %[[RECV_CLONE]][%c0 to %[[RECV_SIZE]] for %[[RECV_SIZE]]] : diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir index 00f5c32603a3..8266ca5e2038 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir @@ -223,27 +223,29 @@ util.func public @producedResults(%size0: index, %size1: index) { // execution region. We expect them to be placed into packed slices and // allocated with the async stream-ordered alloca/dealloca ops. +util.global private @device : !hal.device + // CHECK-LABEL: @locals // CHECK-SAME: (%[[SIZE0:.+]]: index, %[[SIZE1:.+]]: index, %[[AWAIT_TIMEPOINT:.+]]: !stream.timepoint) util.func public @locals(%size0: index, %size1: index, %await_timepoint: !stream.timepoint) -> !stream.timepoint { %c254_i32 = arith.constant 254 : i32 %c255_i32 = arith.constant 255 : i32 - // CHECK: %[[SLICES:.+]]:3 = stream.resource.pack on(#hal.affinity.queue<[0]>) slices({ + // CHECK: %[[SLICES:.+]]:3 = stream.resource.pack on(#hal.device.affinity<@device>) slices({ // CHECK-NEXT: [0, 0] = %[[SIZE0]], // CHECK-NEXT: [1, 1] = %[[SIZE1]] // CHECK-NEXT: }) - // CHECK-NEXT: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized on(#hal.affinity.queue<[0]>) await(%[[AWAIT_TIMEPOINT]]) => !stream.resource{%[[SLICES]]#0} => !stream.timepoint + // CHECK-NEXT: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized on(#hal.device.affinity<@device>) await(%[[AWAIT_TIMEPOINT]]) => !stream.resource{%[[SLICES]]#0} => !stream.timepoint // CHECK-NEXT: %[[AWAIT_JOIN:.+]] = stream.timepoint.join max(%[[AWAIT_TIMEPOINT]], %[[ALLOCA_TIMEPOINT]]) - // CHECK: %[[EXEC_TIMEPOINT:.+]] = stream.cmd.execute on(#hal.affinity.queue<[0]>) await(%[[AWAIT_JOIN]]) + // CHECK: %[[EXEC_TIMEPOINT:.+]] = stream.cmd.execute on(#hal.device.affinity<@device>) await(%[[AWAIT_JOIN]]) // CHECK-SAME: with(%[[ALLOCA]] as %[[CAPTURE:.+]]: !stream.resource{%[[SLICES]]#0}) - %result_timepoint = stream.async.execute on(#hal.affinity.queue<[0]>) await(%await_timepoint) => with() { + %result_timepoint = stream.async.execute on(#hal.device.affinity<@device>) await(%await_timepoint) => with() { // CHECK: stream.cmd.fill %c254_i32, %[[CAPTURE]][%[[SLICES]]#1 for %[[SIZE0]]] : i32 -> !stream.resource{%[[SLICES]]#0} %0 = stream.async.splat %c254_i32 : i32 -> !stream.resource{%size0} // CHECK: stream.cmd.fill %c255_i32, %[[CAPTURE]][%[[SLICES]]#2 for %[[SIZE1]]] : i32 -> !stream.resource{%[[SLICES]]#0} %1 = stream.async.splat %c255_i32 : i32 -> !stream.resource{%size1} stream.yield } => !stream.timepoint - // CHECK: %[[DEALLOCA_TIMEPOINT:.+]] = stream.resource.dealloca on(#hal.affinity.queue<[0]>) await(%[[EXEC_TIMEPOINT]]) => %[[ALLOCA]] : !stream.resource{%[[SLICES]]#0} => !stream.timepoint + // CHECK: %[[DEALLOCA_TIMEPOINT:.+]] = stream.resource.dealloca on(#hal.device.affinity<@device>) await(%[[EXEC_TIMEPOINT]]) => %[[ALLOCA]] : !stream.resource{%[[SLICES]]#0} => !stream.timepoint // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[DEALLOCA_TIMEPOINT]], %[[EXEC_TIMEPOINT]]) => !stream.timepoint // CHECK: util.return %[[JOIN]] util.return %result_timepoint : !stream.timepoint diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir index 3ccd78109383..0f33b51cbc0f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir @@ -38,6 +38,9 @@ util.func public @partitioning(%arg0: !stream.resource, %arg1: !stream // Dispatches with the same affinities should be placed into the same execution // regions. +util.global private @device_a : !hal.device +util.global private @device_b : !hal.device + // CHECK-LABEL: @partitioningWithAffinities // CHECK-SAME: (%[[ARG0:.+]]: !stream.resource) util.func public @partitioningWithAffinities(%arg0: !stream.resource) -> !stream.resource { @@ -48,31 +51,31 @@ util.func public @partitioningWithAffinities(%arg0: !stream.resource) %c255_i32 = arith.constant 255 : i32 // CHECK: %[[TRANSIENTS:.+]]:2, %[[TIMEPOINT0:.+]] = stream.async.execute - // CHECK-SAME: on(#hal.affinity.queue<[0]>) + // CHECK-SAME: on(#hal.device.affinity<@device_a>) // CHECK-SAME: with(%[[ARG0]] as %[[ARG0_CAPTURE:.+]]: !stream.resource{%c20}) // CHECK-SAME: -> (!stream.resource{%c1280}, !stream.resource{%c20}) { // CHECK-NEXT: %[[SPLAT:.+]] = stream.async.splat %splat = stream.async.splat %c255_i32 : i32 -> !stream.resource{%c1280} // CHECK-NEXT: %[[DISPATCH0:.+]] = stream.async.dispatch @ex::@dispatch_0[%c1](%[[ARG0_CAPTURE]][{{.+}}], %[[SPLAT]][{{.+}}]) - %dispatch0 = stream.async.dispatch on(#hal.affinity.queue<[0]>) @ex::@dispatch_0[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource{%c20}, !stream.resource{%c20}) -> !stream.resource{%c1280} + %dispatch0 = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@dispatch_0[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource{%c20}, !stream.resource{%c20}) -> !stream.resource{%c1280} // CHECK-NEXT: %[[DISPATCH1:.+]] = stream.async.dispatch @ex::@dispatch_1[%c1](%[[ARG0_CAPTURE]][{{.+}}], %[[SPLAT]][{{.+}}]) - %dispatch1 = stream.async.dispatch on(#hal.affinity.queue<[0]>) @ex::@dispatch_1[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource{%c20}, !stream.resource{%c20}) -> !stream.resource{%c20} + %dispatch1 = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@dispatch_1[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource{%c20}, !stream.resource{%c20}) -> !stream.resource{%c20} // CHECK-NEXT: stream.yield %[[DISPATCH0]], %[[DISPATCH1]] // CHECK-NEXT: } => !stream.timepoint // CHECK: %[[RESULT:.+]], %[[TIMEPOINT1:.+]] = stream.async.execute - // CHECK-SAME: on(#hal.affinity.queue<[1]>) + // CHECK-SAME: on(#hal.device.affinity<@device_b>) // CHECK-SAME: await(%[[TIMEPOINT0]]) // CHECK-SAME: with(%[[TRANSIENTS]]#0 as %[[TRANSIENT0_CAPTURE:.+]]: !stream.resource{%c1280}, // CHECK-SAME: %[[TRANSIENTS]]#1 as %[[TRANSIENT1_CAPTURE:.+]]: !stream.resource{%c20}) // CHECK-SAME: -> !stream.resource{%c20} // CHECK-NEXT: %[[DISPATCH2:.+]] = stream.async.dispatch @ex::@dispatch_2[%c1](%[[TRANSIENT0_CAPTURE]][{{.+}}], %[[TRANSIENT1_CAPTURE]][{{.+}}]) - %dispatch2 = stream.async.dispatch on(#hal.affinity.queue<[1]>) @ex::@dispatch_2[%c1](%dispatch0[%c0 to %c1280 for %c1280], %dispatch1[%c0 to %c20 for %c20]) : (!stream.resource{%c1280}, !stream.resource{%c20}) -> !stream.resource{%c20} + %dispatch2 = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@dispatch_2[%c1](%dispatch0[%c0 to %c1280 for %c1280], %dispatch1[%c0 to %c20 for %c20]) : (!stream.resource{%c1280}, !stream.resource{%c20}) -> !stream.resource{%c20} // CHECK-NEXT: stream.yield %[[DISPATCH2]] // CHECK-NEXT: } => !stream.timepoint // CHECK-NEXT: %[[READY:.+]] = stream.timepoint.await - // CHECK-SAME: on(#hal.affinity.queue<[1]>) + // CHECK-SAME: on(#hal.device.affinity<@device_b>) // CHECK-SAME: %[[TIMEPOINT1]] => %[[RESULT]] : !stream.resource{%c20} // CHECK-NEXT: util.return %[[READY]] util.return %dispatch2 : !stream.resource @@ -84,6 +87,10 @@ util.func public @partitioningWithAffinities(%arg0: !stream.resource) // dependencies. Unrelated dispatches with differing affinities should end up // in concurrently executable regions. +util.global private @device_a : !hal.device +util.global private @device_b : !hal.device +util.global private @device_c : !hal.device + // CHECK-LABEL: @partitioningWithConcurrentAffinities // CHECK-SAME: (%[[ARG0:.+]]: !stream.resource) util.func public @partitioningWithConcurrentAffinities(%arg0: !stream.resource) -> !stream.resource { @@ -94,23 +101,23 @@ util.func public @partitioningWithConcurrentAffinities(%arg0: !stream.resource) + // CHECK-SAME: on(#hal.device.affinity<@device_a>) // CHECK-SAME: with(%[[ARG0]] as %[[ARG0_CAPTURE0:.+]]: !stream.resource{%c20}) // CHECK-SAME: !stream.resource{%c1280} // CHECK-NEXT: %[[SPLAT0:.+]] = stream.async.splat %splat = stream.async.splat %c255_i32 : i32 -> !stream.resource{%c1280} // CHECK-NEXT: %[[DISPATCH0:.+]] = stream.async.dispatch @ex::@dispatch_0[%c1](%[[ARG0_CAPTURE0]][{{.+}}], %[[SPLAT0]][{{.+}}]) - %dispatch0 = stream.async.dispatch on(#hal.affinity.queue<[0]>) @ex::@dispatch_0[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource{%c20}, !stream.resource{%c20}) -> !stream.resource{%c1280} + %dispatch0 = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@dispatch_0[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource{%c20}, !stream.resource{%c20}) -> !stream.resource{%c1280} // CHECK-NEXT: stream.yield %[[DISPATCH0]] // CHECK-NEXT: } => !stream.timepoint // CHECK: %[[TRANSIENT1:.+]], %[[TIMEPOINT1:.+]] = stream.async.execute - // CHECK-SAME: on(#hal.affinity.queue<[1]>) + // CHECK-SAME: on(#hal.device.affinity<@device_b>) // CHECK-SAME: with(%[[ARG0]] as %[[ARG0_CAPTURE1:.+]]: !stream.resource{%c20}) // CHECK-SAME: -> !stream.resource{%c20} { // CHECK-NEXT: %[[SPLAT1:.+]] = stream.async.splat // CHECK-NEXT: %[[DISPATCH1:.+]] = stream.async.dispatch @ex::@dispatch_1[%c1](%[[ARG0_CAPTURE1]][{{.+}}], %[[SPLAT1]][{{.+}}]) - %dispatch1 = stream.async.dispatch on(#hal.affinity.queue<[1]>) @ex::@dispatch_1[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource{%c20}, !stream.resource{%c20}) -> !stream.resource{%c20} + %dispatch1 = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@dispatch_1[%c1](%arg0[%c0 to %c20 for %c20], %splat[%c0 to %c20 for %c20]) : (!stream.resource{%c20}, !stream.resource{%c20}) -> !stream.resource{%c20} // CHECK-NEXT: stream.yield %[[DISPATCH1]] // CHECK-NEXT: } => !stream.timepoint @@ -121,12 +128,12 @@ util.func public @partitioningWithConcurrentAffinities(%arg0: !stream.resource{%c1280}, // CHECK-SAME: %[[TRANSIENT1]] as %[[TRANSIENT1_CAPTURE:.+]]: !stream.resource{%c20}) // CHECK-NEXT: %[[DISPATCH2:.+]] = stream.async.dispatch @ex::@dispatch_2[%c1](%[[TRANSIENT0_CAPTURE]][{{.+}}], %[[TRANSIENT1_CAPTURE]][{{.+}}]) - %dispatch2 = stream.async.dispatch on(#hal.affinity.queue<[2]>) @ex::@dispatch_2[%c1](%dispatch0[%c0 to %c1280 for %c1280], %dispatch1[%c0 to %c20 for %c20]) : (!stream.resource{%c1280}, !stream.resource{%c20}) -> !stream.resource{%c20} + %dispatch2 = stream.async.dispatch on(#hal.device.affinity<@device_c>) @ex::@dispatch_2[%c1](%dispatch0[%c0 to %c1280 for %c1280], %dispatch1[%c0 to %c20 for %c20]) : (!stream.resource{%c1280}, !stream.resource{%c20}) -> !stream.resource{%c20} // CHECK-NEXT: stream.yield %[[DISPATCH2]] // CHECK-NEXT: } => !stream.timepoint // CHECK-NEXT: %[[READY:.+]] = stream.timepoint.await - // CHECK-SAME: on(#hal.affinity.queue<[2]>) + // CHECK-SAME: on(#hal.device.affinity<@device_c>) // CHECK-SAME: %[[TIMEPOINT2]] => %[[RESULT]] : !stream.resource{%c20} // CHECK-NEXT: util.return %[[READY]] util.return %dispatch2 : !stream.resource diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir index 4082bbfa4721..e289f07575a9 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir @@ -142,12 +142,14 @@ module @hoist_inline_parameters { // CHECK-LABEL: @hoist_dialect_attrs module @hoist_dialect_attrs { + // CHECK: util.global private @device + util.global private @device : !hal.device // CHECK: util.global private @[[HOISTED:[a-z0-9_]+]] - // CHECK-SAME: hal.affinity = #hal.affinity.queue<[0, 1]> + // CHECK-SAME: stream.affinity = #hal.device.affinity<@device> // CHECK: util.initializer - // CHECK-SAME: hal.affinity = #hal.affinity.queue<[0, 1]> + // CHECK-SAME: stream.affinity = #hal.device.affinity<@device> util.func public @main() -> tensor attributes { - hal.affinity = #hal.affinity.queue<[0, 1]> + stream.affinity = #hal.device.affinity<@device> } { %0 = arith.constant dense<3> : tensor %1 = "iree_unregistered.const_expr"(%0) : (tensor) -> tensor diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel index 42d582b921cd..a3255ea99a6e 100644 --- a/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel @@ -22,6 +22,7 @@ iree_compiler_cc_library( ], deps = [ "//compiler/src/iree/compiler/Dialect/HAL/Conversion", + "//compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL:Utils", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/VM/Conversion", "//compiler/src/iree/compiler/Modules/Check/IR", diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt index 161a143e8ce6..3c20a5ba1bf1 100644 --- a/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt +++ b/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt @@ -22,6 +22,7 @@ iree_cc_library( MLIRTransformUtils MLIRTransforms iree::compiler::Dialect::HAL::Conversion + iree::compiler::Dialect::HAL::Conversion::StreamToHAL::Utils iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::VM::Conversion iree::compiler::Modules::Check::IR diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp index 3dab905974a4..d9db0b3e7b0a 100644 --- a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp +++ b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Modules/Check/Conversion/ConversionPatterns.h" #include "iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h" +#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h" #include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h" @@ -68,8 +69,7 @@ static LogicalResult applyDefaultCheckBufferRewrite( state.addAttributes(srcOp->getAttrs()); // Add device argument. - // TODO(multi-device): support multiple devices in check tests . - Value device = IREE::HAL::DeviceType::resolveAny(srcOp->getLoc(), rewriter); + Value device = lookupDeviceFor(srcOp, rewriter); state.addOperands({device}); for (auto [srcOperand, dstOperand] : From ba36e425db6eec0a32d7f20888709f7e3d743e33 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 26 Feb 2024 17:57:47 -0800 Subject: [PATCH 03/25] Adding DeviceAnalysis and porting FixupLegacySync. This analysis allows for querying the potential targets of a `!hal.device` SSA value. --- .../HAL/Analysis/Attributes/BUILD.bazel | 35 +++ .../HAL/Analysis/Attributes/CMakeLists.txt | 34 +++ .../Analysis/Attributes/DeviceGlobalPVS.cpp | 117 ++++++++ .../HAL/Analysis/Attributes/DeviceGlobalPVS.h | 60 +++++ .../Analysis/Attributes/DeviceTargetPVS.cpp | 250 ++++++++++++++++++ .../HAL/Analysis/Attributes/DeviceTargetPVS.h | 97 +++++++ .../compiler/Dialect/HAL/Analysis/BUILD.bazel | 9 + .../Dialect/HAL/Analysis/BindingLayout.h | 6 +- .../Dialect/HAL/Analysis/CMakeLists.txt | 9 + .../Dialect/HAL/Analysis/DeviceAnalysis.cpp | 234 ++++++++++++++++ .../Dialect/HAL/Analysis/DeviceAnalysis.h | 104 ++++++++ .../Dialect/HAL/Analysis/DeviceSet.cpp | 139 ++++++++++ .../compiler/Dialect/HAL/Analysis/DeviceSet.h | 57 ++++ .../iree/compiler/Dialect/HAL/IR/HALAttrs.cpp | 163 +----------- .../iree/compiler/Dialect/HAL/IR/HALAttrs.td | 35 +-- .../HAL/Transforms/FixupLegacySync.cpp | 50 +++- .../Transforms/test/fixup_legacy_sync.mlir | 125 +++++++-- .../Dialect/Stream/Analysis/ResourceUsage.cpp | 9 +- .../Dialect/Util/Analysis/Explorer.cpp | 15 +- .../compiler/Dialect/Util/Analysis/Explorer.h | 15 +- 20 files changed, 1325 insertions(+), 238 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/BUILD.bazel create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/CMakeLists.txt create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.cpp create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.h create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.cpp create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.h create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.cpp create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.cpp create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.h diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/BUILD.bazel new file mode 100644 index 000000000000..46e9053e41ce --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/BUILD.bazel @@ -0,0 +1,35 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_compiler_cc_library( + name = "Attributes", + srcs = [ + "DeviceGlobalPVS.cpp", + "DeviceTargetPVS.cpp", + ], + hdrs = [ + "DeviceGlobalPVS.h", + "DeviceTargetPVS.h", + ], + deps = [ + "//compiler/src/iree/compiler/Dialect/HAL/IR", + "//compiler/src/iree/compiler/Dialect/Util/Analysis/DFX", + "//compiler/src/iree/compiler/Dialect/Util/IR", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/CMakeLists.txt new file mode 100644 index 000000000000..8be479f39809 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/CMakeLists.txt @@ -0,0 +1,34 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + Attributes + HDRS + "DeviceGlobalPVS.h" + "DeviceTargetPVS.h" + SRCS + "DeviceGlobalPVS.cpp" + "DeviceTargetPVS.cpp" + DEPS + LLVMSupport + MLIRAnalysis + MLIRIR + MLIRPass + MLIRSupport + iree::compiler::Dialect::HAL::IR + iree::compiler::Dialect::Util::Analysis::DFX + iree::compiler::Dialect::Util::IR + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.cpp new file mode 100644 index 000000000000..8e96e9797d57 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.cpp @@ -0,0 +1,117 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.h" + +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "iree-hal-device-analysis" + +namespace mlir::iree_compiler::IREE::HAL { + +//===----------------------------------------------------------------------===// +// DeviceGlobalValuePVS +//===----------------------------------------------------------------------===// + +const char DeviceGlobalValuePVS::ID = 0; + +void DeviceGlobalValuePVS::initializeValue(Value value, DFX::Solver &solver) { + assert(isa(value.getType()) && + "only initialize on values of type !hal.device"); + + // If the value is a function arg of a public function then we'll never be + // able to know (today). We could look for attributes defining device + // properties but we can't recover a DeviceTargetAttr from them. + if (auto blockArg = dyn_cast(value)) { + if (auto funcOp = + dyn_cast(blockArg.getOwner()->getParentOp())) { + if (funcOp.isPublic()) { + LLVM_DEBUG(llvm::dbgs() + << "DeviceGlobalValuePVS: argument to a public function - " + "treating as undefined\n"); + unionAssumedWithUndef(); + indicatePessimisticFixpoint(); + return; + } + } + } +} + +ChangeStatus DeviceGlobalValuePVS::updateValue(Value value, + DFX::Solver &solver) { + StateType newState; + auto traversalResult = TraversalResult::COMPLETE; + + // Walk into all producers of the SSA value. + // Note that we may end up at multiple global loads of different globals + // by walking up through calls/branches/etc. + traversalResult |= + solver.getExplorer().walkDefiningOps(value, [&](OpResult result) { + updateFromDefiningOp(value, result, newState, solver); + return WalkResult::advance(); + }); + + if (traversalResult == TraversalResult::INCOMPLETE) { + // Incomplete traversal because of external call graph edges or pointers. + newState.unionAssumedWithUndef(); + newState.indicatePessimisticFixpoint(); + } + return DFX::clampStateAndIndicateChange(getState(), newState); +} + +void DeviceGlobalValuePVS::updateFromDefiningOp(Value value, OpResult result, + StateType &newState, + DFX::Solver &solver) { + TypeSwitch(result.getOwner()) + .Case([&](mlir::arith::SelectOp op) { + auto &truePVS = solver.getElementFor( + *this, Position::forValue(op.getTrueValue()), + DFX::Resolution::REQUIRED); + auto &falsePVS = solver.getElementFor( + *this, Position::forValue(op.getFalseValue()), + DFX::Resolution::REQUIRED); + newState ^= truePVS.getState(); + newState ^= falsePVS.getState(); + }) + .Case([&](IREE::Util::OptimizationBarrierOp op) { + auto &sourcePVS = solver.getElementFor( + *this, Position::forValue(op.getOperand(0)), + DFX::Resolution::REQUIRED); + newState ^= sourcePVS.getState(); + }) + .Case([&](IREE::Util::GlobalLoadOpInterface op) { + auto *globalInfo = + solver.getExplorer().queryGlobalInfoFrom(op.getGlobalName(), op); + newState.unionAssumed(globalInfo->op); + }) + .Default([&](Operation *op) {}); +} + +const std::string DeviceGlobalValuePVS::getAsStr(AsmState &asmState) const { + std::string str; + llvm::raw_string_ostream sstream(str); + sstream << "pvs: "; + if (isValidState()) { + sstream << "["; + if (isUndefContained()) { + sstream << "undef, "; + } + llvm::interleaveComma(getAssumedSet(), sstream, + [&](IREE::Util::GlobalOpInterface value) { + value.print(sstream, asmState); + }); + sstream << "]"; + } else { + sstream << "(invalid)"; + } + sstream.flush(); + return str; +} + +} // namespace mlir::iree_compiler::IREE::HAL diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.h new file mode 100644 index 000000000000..10864d654f24 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.h @@ -0,0 +1,60 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_HAL_ANALYSIS_ATTRIBUTES_DEVICEGLOBALPVS_H_ +#define IREE_COMPILER_DIALECT_HAL_ANALYSIS_ATTRIBUTES_DEVICEGLOBALPVS_H_ + +#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "iree/compiler/Dialect/Util/Analysis/DFX/Solver.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" + +namespace mlir::iree_compiler::IREE::HAL { + +//===----------------------------------------------------------------------===// +// DeviceGlobalValuePVS (Potential Values State) +//===----------------------------------------------------------------------===// + +// Set of potential globals that provide a !hal.device SSA value. +// A set size of 1 indicates that the device SSA value is a particular device. +// Multiple entries indicate that multiple code paths may route to the value +// with different devices selected. +class DeviceGlobalValuePVS + : public DFX::StateWrapper< + DFX::PotentialValuesState, + DFX::ValueElement> { +public: + using BaseType = DFX::StateWrapper< + DFX::PotentialValuesState, + DFX::ValueElement>; + using BaseType::BaseType; + + static DeviceGlobalValuePVS &createForPosition(const Position &pos, + DFX::Solver &solver) { + return *(new (solver.getAllocator()) DeviceGlobalValuePVS(pos)); + } + + // Identity definitions. + const std::string getName() const override { return "DeviceGlobalValuePVS"; } + const void *getID() const override { return &ID; } + static bool classof(const DFX::AbstractElement *element) { + return (element->getID() == &ID); + } + static const char ID; + + const std::string getAsStr(AsmState &asmState) const override; + +private: + void initializeValue(Value value, DFX::Solver &solver) override; + ChangeStatus updateValue(Value value, DFX::Solver &solver) override; + void updateFromDefiningOp(Value value, OpResult result, StateType &newState, + DFX::Solver &solver); +}; + +} // namespace mlir::iree_compiler::IREE::HAL + +#endif // IREE_COMPILER_DIALECT_HAL_ANALYSIS_ATTRIBUTES_DEVICEGLOBALPVS_H_ diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.cpp new file mode 100644 index 000000000000..85813f616d75 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.cpp @@ -0,0 +1,250 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.h" + +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "iree-hal-device-analysis" + +namespace mlir::iree_compiler::IREE::HAL { + +//===----------------------------------------------------------------------===// +// DeviceTargetGlobalPVS +//===----------------------------------------------------------------------===// + +const char DeviceTargetGlobalPVS::ID = 0; + +void DeviceTargetGlobalPVS::initializeOperation(IREE::Util::GlobalOp globalOp, + DFX::Solver &solver) { + assert(isa(globalOp.getType()) && + "only initialize on globals of type !hal.device"); + + // We only support immutable initialized device globals. + // We could track usage up through stores to handle the mutable case but + // the compiler does not generate such programs today. + auto *globalInfo = solver.getExplorer().getGlobalInfo(globalOp); + if (!globalInfo || globalInfo->isIndirect || globalOp.isGlobalMutable()) { + LLVM_DEBUG(llvm::dbgs() + << "DeviceTargetGlobalPVS: mutable device globals or those used " + "indirectly are not yet implemented\n"); + unionAssumedWithUndef(); + indicatePessimisticFixpoint(); + return; + } + + // Use the initial value to populate the potential value set. + std::function unionAttr; + unionAttr = [&](Attribute attr) -> bool { + return TypeSwitch(attr) + .Case([&](auto targetAttr) { + LLVM_DEBUG({ + llvm::dbgs() << "DeviceTargetGlobalPVS: unioning with target: "; + attr.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + unionAssumed(targetAttr); + return true; + }) + .Case([&](auto fallbackAttr) { + LLVM_DEBUG({ + llvm::dbgs() << "DeviceTargetGlobalPVS: unioning with fallback: "; + attr.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + auto *fallbackInfo = solver.getExplorer().queryGlobalInfoFrom( + fallbackAttr.getName().getValue(), globalOp); + if (!fallbackInfo) { + LLVM_DEBUG( + llvm::dbgs() + << "DeviceTargetGlobalPVS: !! failed to find fallback global " + << fallbackAttr.getName().getValue() << "\n"); + return false; + } + auto fallbackPVS = + solver.getOrCreateElementFor( + Position::forOperation(fallbackInfo->op)); + if (fallbackPVS.isUndefContained()) { + LLVM_DEBUG(llvm::dbgs() + << "DeviceTargetGlobalPVS: !! fallback is undefined\n"); + return false; + } + unionAssumed(fallbackPVS.getState()); + return true; + }) + .Case([&](auto selectAttr) { + LLVM_DEBUG({ + llvm::dbgs() << "DeviceTargetGlobalPVS: unioning with selected " + "child devices: "; + attr.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + for (auto childAttr : selectAttr.getDevices()) { + if (!unionAttr(childAttr)) { + return false; + } + } + return true; + }) + .Default( + [&](auto attr) { + LLVM_DEBUG( + llvm::dbgs() + << "DeviceTargetGlobalPVS: !! unknown initial value type\n"); + return false; + }); + }; + if (auto initialValueAttr = globalOp.getInitialValueAttr()) { + if (unionAttr(initialValueAttr)) { + indicateOptimisticFixpoint(); + } else { + unionAssumedWithUndef(); + indicatePessimisticFixpoint(); + } + } else { + LLVM_DEBUG(llvm::dbgs() + << "DeviceTargetGlobalPVS: no initial value, dynamically " + "configure devices not yet implemented\n"); + unionAssumedWithUndef(); + indicatePessimisticFixpoint(); + } +} + +ChangeStatus +DeviceTargetGlobalPVS::updateOperation(IREE::Util::GlobalOp globalOp, + DFX::Solver &solver) { + // We only support running on initialized globals today. + // We could support walking store/load or other things, though. + return ChangeStatus::UNCHANGED; +} + +const std::string DeviceTargetGlobalPVS::getAsStr(AsmState &asmState) const { + std::string str; + llvm::raw_string_ostream sstream(str); + sstream << "pvs: "; + if (isValidState()) { + sstream << "["; + if (isUndefContained()) { + sstream << "undef, "; + } + llvm::interleaveComma(getAssumedSet(), sstream, + [&](IREE::HAL::DeviceTargetAttr value) { + cast(value).print(sstream); + }); + sstream << "]"; + } else { + sstream << "(invalid)"; + } + sstream.flush(); + return str; +} + +//===----------------------------------------------------------------------===// +// DeviceTargetValuePVS +//===----------------------------------------------------------------------===// + +const char DeviceTargetValuePVS::ID = 0; + +void DeviceTargetValuePVS::initializeValue(Value value, DFX::Solver &solver) { + assert(isa(value.getType()) && + "only initialize on values of type !hal.device"); + + // If the value is a function arg of a public function then we'll never be + // able to know (today). We could look for attributes defining device + // properties but we can't recover a DeviceTargetAttr from them. + if (auto blockArg = dyn_cast(value)) { + if (auto funcOp = + dyn_cast(blockArg.getOwner()->getParentOp())) { + if (funcOp.isPublic()) { + LLVM_DEBUG(llvm::dbgs() + << "DeviceTargetValuePVS: argument to a public function - " + "treating as undefined\n"); + unionAssumedWithUndef(); + indicatePessimisticFixpoint(); + return; + } + } + } +} + +ChangeStatus DeviceTargetValuePVS::updateValue(Value value, + DFX::Solver &solver) { + StateType newState; + auto traversalResult = TraversalResult::COMPLETE; + + // Walk into all producers of the SSA value. + // Note that we may end up at multiple global loads of different globals + // by walking up through calls/branches/etc. + traversalResult |= + solver.getExplorer().walkDefiningOps(value, [&](OpResult result) { + updateFromDefiningOp(value, result, newState, solver); + return WalkResult::advance(); + }); + + if (traversalResult == TraversalResult::INCOMPLETE) { + // Incomplete traversal because of external call graph edges or pointers. + newState.unionAssumedWithUndef(); + newState.indicatePessimisticFixpoint(); + } + return DFX::clampStateAndIndicateChange(getState(), newState); +} + +void DeviceTargetValuePVS::updateFromDefiningOp(Value value, OpResult result, + StateType &newState, + DFX::Solver &solver) { + TypeSwitch(result.getOwner()) + .Case([&](mlir::arith::SelectOp op) { + auto &truePVS = solver.getElementFor( + *this, Position::forValue(op.getTrueValue()), + DFX::Resolution::REQUIRED); + auto &falsePVS = solver.getElementFor( + *this, Position::forValue(op.getFalseValue()), + DFX::Resolution::REQUIRED); + newState ^= truePVS.getState(); + newState ^= falsePVS.getState(); + }) + .Case([&](IREE::Util::OptimizationBarrierOp op) { + auto &sourcePVS = solver.getElementFor( + *this, Position::forValue(op.getOperand(0)), + DFX::Resolution::REQUIRED); + newState ^= sourcePVS.getState(); + }) + .Case([&](IREE::Util::GlobalLoadOpInterface op) { + auto *globalInfo = + solver.getExplorer().queryGlobalInfoFrom(op.getGlobalName(), op); + auto &globalPVS = solver.getElementFor( + *this, Position::forOperation(globalInfo->op), + DFX::Resolution::REQUIRED); + newState ^= globalPVS.getState(); + }) + .Default([&](Operation *op) {}); +} + +const std::string DeviceTargetValuePVS::getAsStr(AsmState &asmState) const { + std::string str; + llvm::raw_string_ostream sstream(str); + sstream << "pvs: "; + if (isValidState()) { + sstream << "["; + if (isUndefContained()) { + sstream << "undef, "; + } + llvm::interleaveComma(getAssumedSet(), sstream, + [&](IREE::HAL::DeviceTargetAttr value) { + cast(value).print(sstream); + }); + sstream << "]"; + } else { + sstream << "(invalid)"; + } + sstream.flush(); + return str; +} + +} // namespace mlir::iree_compiler::IREE::HAL diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.h new file mode 100644 index 000000000000..f1b220f2c4ff --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.h @@ -0,0 +1,97 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_HAL_ANALYSIS_ATTRIBUTES_DEVICETARGETPVS_H_ +#define IREE_COMPILER_DIALECT_HAL_ANALYSIS_ATTRIBUTES_DEVICETARGETPVS_H_ + +#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "iree/compiler/Dialect/Util/Analysis/DFX/Solver.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" + +namespace mlir::iree_compiler::IREE::HAL { + +//===----------------------------------------------------------------------===// +// DeviceTargetGlobalPVS (Potential Values State) +//===----------------------------------------------------------------------===// + +// Set of potential IREE::HAL::DeviceTargetAttr values for an initialized +// !hal.device global. When defined the device global may take on the traits of +// any of the potential values. +class DeviceTargetGlobalPVS + : public DFX::StateWrapper< + DFX::PotentialValuesState, + DFX::TypedOperationElement> { +public: + using BaseType = + DFX::StateWrapper, + DFX::TypedOperationElement>; + using BaseType::BaseType; + + static DeviceTargetGlobalPVS &createForPosition(const Position &pos, + DFX::Solver &solver) { + return *(new (solver.getAllocator()) DeviceTargetGlobalPVS(pos)); + } + + // Identity definitions. + const std::string getName() const override { return "DeviceTargetGlobalPVS"; } + const void *getID() const override { return &ID; } + static bool classof(const DFX::AbstractElement *element) { + return (element->getID() == &ID); + } + static const char ID; + + const std::string getAsStr(AsmState &asmState) const override; + +private: + void initializeOperation(IREE::Util::GlobalOp globalOp, + DFX::Solver &solver) override; + ChangeStatus updateOperation(IREE::Util::GlobalOp globalOp, + DFX::Solver &solver) override; +}; + +//===----------------------------------------------------------------------===// +// DeviceTargetValuePVS +//===----------------------------------------------------------------------===// + +// Set of potential values for a !hal.device SSA value. +// When defined the value may take on the traits of any of the potential values. +class DeviceTargetValuePVS + : public DFX::StateWrapper< + DFX::PotentialValuesState, + DFX::ValueElement> { +public: + using BaseType = + DFX::StateWrapper, + DFX::ValueElement>; + using BaseType::BaseType; + + static DeviceTargetValuePVS &createForPosition(const Position &pos, + DFX::Solver &solver) { + return *(new (solver.getAllocator()) DeviceTargetValuePVS(pos)); + } + + // Identity definitions. + const std::string getName() const override { return "DeviceTargetValuePVS"; } + const void *getID() const override { return &ID; } + static bool classof(const DFX::AbstractElement *element) { + return (element->getID() == &ID); + } + static const char ID; + + const std::string getAsStr(AsmState &asmState) const override; + +private: + void initializeValue(Value value, DFX::Solver &solver) override; + ChangeStatus updateValue(Value value, DFX::Solver &solver) override; + void updateFromDefiningOp(Value value, OpResult result, StateType &newState, + DFX::Solver &solver); +}; + +} // namespace mlir::iree_compiler::IREE::HAL + +#endif // IREE_COMPILER_DIALECT_HAL_ANALYSIS_ATTRIBUTES_DEVICETARGETPVS_H_ diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BUILD.bazel index f6e18bf20374..0e2aa4ad1f05 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BUILD.bazel @@ -16,18 +16,27 @@ iree_compiler_cc_library( name = "Analysis", srcs = [ "BindingLayout.cpp", + "DeviceAnalysis.cpp", + "DeviceSet.cpp", ], hdrs = [ "BindingLayout.h", + "DeviceAnalysis.h", + "DeviceSet.h", ], deps = [ + "//compiler/src/iree/compiler/Dialect/HAL/Analysis/Attributes", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/Stream/IR", + "//compiler/src/iree/compiler/Dialect/Util/Analysis", + "//compiler/src/iree/compiler/Dialect/Util/Analysis/DFX", "//compiler/src/iree/compiler/Dialect/Util/IR", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", ], ) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h index 1e8704d6b02e..050e18e6a801 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h @@ -4,8 +4,8 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#ifndef IREE_COMPILER_DIALECT_HAL_ANALYSIS_BINDINGLAYOUT_ -#define IREE_COMPILER_DIALECT_HAL_ANALYSIS_BINDINGLAYOUT_ +#ifndef IREE_COMPILER_DIALECT_HAL_ANALYSIS_BINDINGLAYOUT_H_ +#define IREE_COMPILER_DIALECT_HAL_ANALYSIS_BINDINGLAYOUT_H_ #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" @@ -94,4 +94,4 @@ class BindingLayoutAnalysis { } // namespace mlir::iree_compiler::IREE::HAL -#endif // IREE_COMPILER_DIALECT_HAL_ANALYSIS_BINDINGLAYOUT_ +#endif // IREE_COMPILER_DIALECT_HAL_ANALYSIS_BINDINGLAYOUT_H_ diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Analysis/CMakeLists.txt index e25ba3432c06..6e733ac24afe 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/CMakeLists.txt @@ -15,16 +15,25 @@ iree_cc_library( Analysis HDRS "BindingLayout.h" + "DeviceAnalysis.h" + "DeviceSet.h" SRCS "BindingLayout.cpp" + "DeviceAnalysis.cpp" + "DeviceSet.cpp" DEPS LLVMSupport MLIRAnalysis + MLIRFunctionInterfaces MLIRIR MLIRPass + MLIRSCFDialect MLIRSupport + iree::compiler::Dialect::HAL::Analysis::Attributes iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::Stream::IR + iree::compiler::Dialect::Util::Analysis + iree::compiler::Dialect::Util::Analysis::DFX iree::compiler::Dialect::Util::IR PUBLIC ) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.cpp new file mode 100644 index 000000000000..f144e31e3e60 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.cpp @@ -0,0 +1,234 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" + +#include "iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceGlobalPVS.h" +#include "iree/compiler/Dialect/HAL/Analysis/Attributes/DeviceTargetPVS.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/Util/Analysis/DFX/Element.h" +#include "iree/compiler/Dialect/Util/Analysis/DFX/State.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Interfaces/FunctionInterfaces.h" + +namespace mlir::iree_compiler::IREE::HAL { + +//===----------------------------------------------------------------------===// +// DeviceAnalysis +//===----------------------------------------------------------------------===// + +DeviceAnalysis::DeviceAnalysis(Operation *rootOp) + : explorer(rootOp, TraversalAction::SHALLOW), solver(explorer, allocator) { + explorer.setOpInterfaceAction( + TraversalAction::RECURSE); + explorer.setOpAction(TraversalAction::RECURSE); + explorer.setOpAction(TraversalAction::RECURSE); + explorer.setOpAction(TraversalAction::RECURSE); + // Ignore the contents of executables (linalg goo, etc). + explorer.setOpAction(TraversalAction::IGNORE); + explorer.initialize(); +} + +DeviceAnalysis::~DeviceAnalysis() = default; + +LogicalResult DeviceAnalysis::run() { + // TODO(multi-device): remove this fallback path when device globals are fully + // plumbed through. Today we still have inputs with the hal.device.targets + // attribute. + if (auto targetsAttr = explorer.getRootOp()->getAttrOfType( + "hal.device.targets")) { + if (!targetsAttr.empty()) { + defaultDeviceSet = DeviceSet(targetsAttr); + } + } + + // Initialize device globals (in declaration order). + for (auto globalOp : explorer.getRootOp() + ->getRegion(0) + .getOps()) { + auto globalType = globalOp.getGlobalType(); + if (isa(globalType)) { + solver.getOrCreateElementFor( + Position::forOperation(globalOp)); + deviceGlobals.push_back(globalOp); + } + } + + // Initialize all SSA values so we can do just with trivial search. + explorer.walkValuesOfType([&](Value value) { + solver.getOrCreateElementFor( + Position::forValue(value)); + solver.getOrCreateElementFor( + Position::forValue(value)); + return WalkResult::advance(); + }); + + return solver.run(); +} + +std::optional> +DeviceAnalysis::lookupDeviceGlobals(Value deviceValue) { + auto globalPVS = solver.lookupElementFor( + Position::forValue(deviceValue)); + if (!globalPVS || !globalPVS->isValidState() || + globalPVS->isUndefContained()) { + return std::nullopt; + } + SetVector globalOps; + for (auto globalOp : globalPVS->getAssumedSet()) { + globalOps.insert(globalOp); + } + return globalOps; +} + +std::optional +DeviceAnalysis::lookupDeviceTargets(Value deviceValue) { + auto valuePVS = solver.lookupElementFor( + Position::forValue(deviceValue)); + if (!valuePVS || !valuePVS->isValidState() || valuePVS->isUndefContained()) { + return defaultDeviceSet; + } + return DeviceSet(valuePVS->getAssumedSet()); +} + +// Returns a set of target devices that may be active for the given +// operation. This will recursively walk parent operations until one with +// the `hal.device.targets` attribute is found. +// +// This is a legacy mechanism for performing the search. Newer code should use +// affinities or !hal.device analysis instead. +static void gatherLegacyDeviceTargetAttrs( + Operation *op, SetVector &resultSet) { + auto attrId = StringAttr::get(op->getContext(), "hal.device.targets"); + while (op) { + auto targetsAttr = op->getAttrOfType(attrId); + if (targetsAttr) { + for (auto elementAttr : targetsAttr) { + if (auto targetAttr = + dyn_cast(elementAttr)) { + resultSet.insert(targetAttr); + } else { + // HACK: this legacy approach is deprecated and only preserved for + // existing behavior. It's ok to get angry here as users should not be + // trying to use this pass prior to device materialization. + assert(false && + "legacy hal.device.targets only support hal.device.targets"); + } + } + return; + } + op = op->getParentOp(); + } + // No devices found; let caller decide what to do. +} + +// Recursively resolves the referenced device into targets. +void DeviceAnalysis::gatherDeviceTargets( + Attribute rootAttr, Operation *fromOp, + SetVector &resultSet) { + SetVector worklist; + worklist.insert(rootAttr); + do { + auto attr = worklist.pop_back_val(); + if (!TypeSwitch(attr) + .Case([&](auto symRefAttr) { + auto globalOp = + explorer.getSymbolTables() + .lookupNearestSymbolFrom( + fromOp, symRefAttr); + assert(globalOp && "global reference must be valid"); + if (auto initialValueAttr = globalOp.getGlobalInitialValue()) { + // Global with a device initialization value we can analyze. + worklist.insert(initialValueAttr); + return true; + } else { + return false; + } + }) + .Case([&](auto targetAttr) { + resultSet.insert(targetAttr); + return true; + }) + .Case([&](auto fallbackAttr) { + worklist.insert(fallbackAttr.getName()); + return true; + }) + .Case([&](auto selectAttr) { + worklist.insert(selectAttr.getDevices().begin(), + selectAttr.getDevices().end()); + return true; + }) + .Default([](auto attr) { return false; })) { + // No initial value means fall back to defaults. We do that by + // inserting all knowable targets. + gatherLegacyDeviceTargetAttrs(fromOp, resultSet); + return; + } + } while (!worklist.empty()); +} + +void DeviceAnalysis::gatherAllDeviceTargets( + SetVector &resultSet) { + for (auto globalOp : deviceGlobals) { + gatherDeviceTargets(FlatSymbolRefAttr::get(globalOp), explorer.getRootOp(), + resultSet); + } +} + +void DeviceAnalysis::gatherDeviceAffinityTargets( + IREE::Stream::AffinityAttr affinityAttr, Operation *fromOp, + SetVector &resultSet) { + // We currently only know how to handle HAL device affinities. + // We could support other ones via an interface but instead we just fall back + // to default logic if no affinity or an unknown one is found. + auto deviceAffinityAttr = + dyn_cast_if_present(affinityAttr); + if (!deviceAffinityAttr) { + gatherLegacyDeviceTargetAttrs(fromOp, resultSet); + return; + } + + // Recursively resolve the referenced device into targets. + gatherDeviceTargets(deviceAffinityAttr.getDevice(), fromOp, resultSet); +} + +void DeviceAnalysis::gatherAllExecutableTargets( + SetVector &resultSet) { + SetVector deviceTargetSet; + gatherAllDeviceTargets(deviceTargetSet); + for (auto deviceTargetAttr : deviceTargetSet) { + deviceTargetAttr.getExecutableTargets(resultSet); + } +} + +void DeviceAnalysis::gatherRequiredExecutableTargets( + Operation *forOp, SetVector &resultSet) { + // Get the affinity from the op or an ancestor. Note that there may be no + // affinity specified at all. + auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(forOp); + + // Gather the device targets that are referenced by the affinity. + SetVector deviceTargetSet; + gatherDeviceAffinityTargets(affinityAttr, forOp, deviceTargetSet); + + // Add all executable targets on the device targets. + for (auto deviceTargetAttr : deviceTargetSet) { + resultSet.insert(deviceTargetAttr.getExecutableTargets().begin(), + deviceTargetAttr.getExecutableTargets().end()); + } +} + +void DeviceAnalysis::gatherRequiredExecutableTargets( + IREE::Stream::AffinityAttr affinityAttr, Operation *fromOp, + SetVector &resultSet) { + SetVector deviceTargetAttrs; + gatherDeviceAffinityTargets(affinityAttr, fromOp, deviceTargetAttrs); + for (auto deviceTargetAttr : deviceTargetAttrs) { + deviceTargetAttr.getExecutableTargets(resultSet); + } +} + +} // namespace mlir::iree_compiler::IREE::HAL diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h new file mode 100644 index 000000000000..e4f6245ea034 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h @@ -0,0 +1,104 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_HAL_ANALYSIS_DEVICETARGET_H_ +#define IREE_COMPILER_DIALECT_HAL_ANALYSIS_DEVICETARGET_H_ + +#include + +#include "iree/compiler/Dialect/HAL/Analysis/DeviceSet.h" +#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "iree/compiler/Dialect/Util/Analysis/DFX/Solver.h" +#include "iree/compiler/Dialect/Util/Analysis/Explorer.h" +#include "llvm/ADT/DenseSet.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" + +namespace mlir::iree_compiler::IREE::HAL { + +//===----------------------------------------------------------------------===// +// DeviceAnalysis +//===----------------------------------------------------------------------===// + +// Performs whole-program analysis of device traits (limits, configuration, etc) +// and allows for queries against `!hal.device` values for known traits. +// +// Though safe to run at any time this may not provide meaningful results until +// after devices have been materialized and the program has been converted into +// the HAL dialect. +class DeviceAnalysis { +public: + explicit DeviceAnalysis(Operation *rootOp); + ~DeviceAnalysis(); + + Explorer &getExplorer() { return explorer; } + + // Runs analysis and populates the device traits map. + // May fail if analysis cannot be completed due to unsupported or unknown IR. + LogicalResult run(); + + // Returns a set of all !hal.device globals in the analyzed root op in the + // order they are declared in the root op. + ArrayRef getDeviceGlobals() { + return deviceGlobals; + } + + // Returns a set of possible device globals of the given `!hal.device` value, + // if analyzed. + std::optional> + lookupDeviceGlobals(Value deviceValue); + + // Returns a set of possible targets of the given `!hal.device` value, if + // analyzed. + std::optional lookupDeviceTargets(Value deviceValue); + + // Gathers all possible device targets in the root op. + // Ordering is undefined. + void + gatherAllDeviceTargets(SetVector &resultSet); + + // Gathers the set of device targets potentially referenced by the given + // affinity. Targets are ordered by most likely to least likely. + void gatherDeviceAffinityTargets( + IREE::Stream::AffinityAttr affinityAttr, Operation *fromOp, + SetVector &resultSet); + + // Gathers all executable targets from all devices in the root op. + // This should generally be avoided and the scoped + // gatherRequiredExecutableTargets gather should be used instead. + void gatherAllExecutableTargets( + SetVector &resultSet); + + // Gathers all executable targets that may be required by the given host op. + // This should be called on the most narrowly scoped op possible as multiple + // devices may be used within the same function-like op and have different + // requirements. This may return a set with more targets than expected. + void gatherRequiredExecutableTargets( + Operation *forOp, SetVector &resultSet); + + // Gathers all executable targets that may be required for the given affinity. + // This should be called on the most narrowly scoped op possible as multiple + // devices may be used within the same function-like op and have different + // requirements. This may return a set with more targets than expected. + void gatherRequiredExecutableTargets( + IREE::Stream::AffinityAttr affinityAttr, Operation *fromOp, + SetVector &resultSet); + +private: + // Recursively resolves the referenced device into targets. + void gatherDeviceTargets(Attribute rootAttr, Operation *fromOp, + SetVector &resultSet); + + Explorer explorer; + llvm::BumpPtrAllocator allocator; + DFX::Solver solver; + std::optional defaultDeviceSet; + SmallVector deviceGlobals; +}; + +} // namespace mlir::iree_compiler::IREE::HAL + +#endif // IREE_COMPILER_DIALECT_HAL_ANALYSIS_DEVICETARGET_H_ diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.cpp new file mode 100644 index 000000000000..b60e5e7f1119 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.cpp @@ -0,0 +1,139 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/HAL/Analysis/DeviceSet.h" + +namespace mlir::iree_compiler::IREE::HAL { + +//===----------------------------------------------------------------------===// +// DeviceSet +//===----------------------------------------------------------------------===// + +DeviceSet::DeviceSet(ArrayAttr targetsAttr) { + for (auto targetAttr : + targetsAttr.getAsRange()) { + targetAttrs.insert(targetAttr); + } +} + +DeviceSet::DeviceSet(const DenseSet &targetAttrs) + : targetAttrs(targetAttrs) {} + +DeviceSet::~DeviceSet() = default; + +std::optional> +DeviceSet::getExecutableTargets() const { + if (targetAttrs.empty()) { + return std::nullopt; + } + SetVector resultAttrs; + for (auto targetAttr : targetAttrs) { + targetAttr.getExecutableTargets(resultAttrs); + } + return llvm::to_vector(resultAttrs); +} + +template +static std::optional joinConfigAttrs( + const DenseSet &targetAttrs, StringRef name, + std::function + join) { + if (targetAttrs.empty()) { + return std::nullopt; + } + std::optional result; + for (auto targetAttr : targetAttrs) { + auto configAttr = targetAttr.getConfiguration(); + if (!configAttr) { + return std::nullopt; + } + auto valueAttr = configAttr.getAs(name); + if (!valueAttr) { + return std::nullopt; + } else if (!result) { + result = valueAttr.getValue(); + } else { + result = join(result.value(), valueAttr.getValue()); + } + } + return result; +} + +template +static std::optional> +joinConfigStaticRanges(const DenseSet &targetAttrs, + StringRef name, + std::function( + StaticRange, + StaticRange)> + join) { + if (targetAttrs.empty()) { + return std::nullopt; + } + std::optional> result; + for (auto targetAttr : targetAttrs) { + auto configAttr = targetAttr.getConfiguration(); + if (!configAttr) { + return std::nullopt; + } + auto valueAttr = configAttr.getAs(name); + if (!valueAttr) { + return std::nullopt; + } else if (!result) { + result = valueAttr.getValue(); + } else { + result = + join(result.value(), + StaticRange{valueAttr.getValue()}); + } + } + return result; +} + +bool DeviceSet::hasConfigAttrAny(StringRef name) const { + for (auto targetAttr : targetAttrs) { + if (auto configAttr = targetAttr.getConfiguration()) { + if (configAttr.get(name)) { + return true; + } + } + } + return false; +} + +bool DeviceSet::hasConfigAttrAll(StringRef name) const { + for (auto targetAttr : targetAttrs) { + auto configAttr = targetAttr.getConfiguration(); + if (!configAttr || !configAttr.get(name)) { + return false; + } + } + return true; +} + +std::optional DeviceSet::getConfigAttrAnd(StringRef name) const { + return joinConfigAttrs( + targetAttrs, name, [](bool lhs, bool rhs) { return lhs && rhs; }); +} + +std::optional DeviceSet::getConfigAttrOr(StringRef name) const { + return joinConfigAttrs( + targetAttrs, name, [](bool lhs, bool rhs) { return lhs || rhs; }); +} + +std::optional> +DeviceSet::getConfigAttrRange(StringRef name) const { + return joinConfigStaticRanges( + targetAttrs, name, [](StaticRange lhs, StaticRange rhs) { + return StaticRange{ + llvm::APIntOps::smin(lhs.min, rhs.min), + llvm::APIntOps::smax(lhs.max, rhs.max), + }; + }); +} + +} // namespace mlir::iree_compiler::IREE::HAL diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.h new file mode 100644 index 000000000000..18f7d9fb8abf --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.h @@ -0,0 +1,57 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_HAL_ANALYSIS_DEVICESET_H_ +#define IREE_COMPILER_DIALECT_HAL_ANALYSIS_DEVICESET_H_ + +#include + +#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "llvm/ADT/DenseSet.h" +#include "mlir/Support/LLVM.h" + +namespace mlir::iree_compiler::IREE::HAL { + +// Provides configuration queries over a set of devices. +class DeviceSet { +public: + DeviceSet() = default; + explicit DeviceSet(ArrayAttr targetsAttr); + explicit DeviceSet(const DenseSet &targetAttrs); + ~DeviceSet(); + + // Returns zero or more executable targets that may be used by any device. + std::optional> + getExecutableTargets() const; + + // Returns true if there is any UnitAttr with |name| in any device. + bool hasConfigAttrAny(StringRef name) const; + + // Returns true if all device configurations have a UnitAttr with |name|. + bool hasConfigAttrAll(StringRef name) const; + + // Returns the AND of boolean attributes of |name| in all devices. + // Returns nullopt if any config does not have the key defined indicating + // that it's not statically known/runtime dynamic. + std::optional getConfigAttrAnd(StringRef name) const; + + // Returns the OR of boolean attributes of |name| in all devices. + // Returns nullopt if any config does not have the key defined indicating + // that it's not statically known/runtime dynamic. + std::optional getConfigAttrOr(StringRef name) const; + + // Returns the range of integer attributes of |name| in all devices. + // Returns nullopt if any config does not have the key defined indicating + // that it's not statically known/runtime dynamic. + std::optional> getConfigAttrRange(StringRef name) const; + +private: + DenseSet targetAttrs; +}; + +} // namespace mlir::iree_compiler::IREE::HAL + +#endif // IREE_COMPILER_DIALECT_HAL_ANALYSIS_DEVICESET_H_ diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp index c1d8c227d6c4..69e08f0e5377 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp @@ -174,14 +174,22 @@ bool DeviceTargetAttr::hasConfigurationAttr(StringRef name) { return configAttr && configAttr.get(name); } -// static -SmallVector -DeviceTargetAttr::lookup(Operation *op) { +void DeviceTargetAttr::getExecutableTargets( + SetVector &resultAttrs) { + for (auto attr : getExecutableTargets()) { + resultAttrs.insert(attr); + } +} + +// Returns a list of target devices that may be active for the given +// operation. This will recursively walk parent operations until one with +// the `hal.device.targets` attribute is found. +static SmallVector lookupDeviceTargetAttrs(Operation *op) { auto attrId = mlir::StringAttr::get(op->getContext(), "hal.device.targets"); while (op) { auto targetsAttr = op->getAttrOfType(attrId); if (targetsAttr) { - SmallVector result; + SmallVector result; for (auto targetAttr : targetsAttr) { result.push_back(llvm::cast(targetAttr)); } @@ -192,154 +200,11 @@ DeviceTargetAttr::lookup(Operation *op) { return {}; // No devices found; let caller decide what to do. } -// Returns a set of all configuration attributes from all device targets with -// a configuration set. Targets with no configuration set are ignored. -static SmallVector lookupOptionalConfigAttrs(Operation *op) { - auto targetAttrs = IREE::HAL::DeviceTargetAttr::lookup(op); - if (targetAttrs.empty()) - return {}; - SmallVector configAttrs; - for (auto targetAttr : targetAttrs) { - auto configAttr = targetAttr.getConfiguration(); - if (configAttr) - configAttrs.push_back(configAttr); - } - return configAttrs; -} - -void DeviceTargetAttr::getExecutableTargets( - SetVector &resultAttrs) { - for (auto attr : getExecutableTargets()) { - resultAttrs.insert(attr); - } -} - -// Returns a set of all configuration attributes from all device targets. -// Returns nullopt if any target is missing a configuration attribute. -static std::optional> -lookupRequiredConfigAttrs(Operation *op) { - auto targetAttrs = IREE::HAL::DeviceTargetAttr::lookup(op); - if (targetAttrs.empty()) - return std::nullopt; - SmallVector configAttrs; - for (auto targetAttr : targetAttrs) { - auto configAttr = targetAttr.getConfiguration(); - if (!configAttr) - return std::nullopt; - configAttrs.push_back(configAttr); - } - return configAttrs; -} - -template -static std::optional joinConfigAttrs( - ArrayRef configAttrs, StringRef name, - std::function - join) { - if (configAttrs.empty()) - return std::nullopt; - auto firstValue = configAttrs.front().getAs(name); - if (!firstValue) - return std::nullopt; - auto result = firstValue.getValue(); - for (auto configAttr : configAttrs.drop_front(1)) { - auto value = configAttr.getAs(name); - if (!value) - return std::nullopt; - result = join(result, value.getValue()); - } - return result; -} - -template -static std::optional> -joinConfigStaticRanges(ArrayRef configAttrs, StringRef name, - std::function( - StaticRange, - StaticRange)> - join) { - if (configAttrs.empty()) - return std::nullopt; - auto firstValue = configAttrs.front().getAs(name); - if (!firstValue) - return std::nullopt; - StaticRange result{firstValue.getValue()}; - for (auto configAttr : configAttrs.drop_front(1)) { - auto value = configAttr.getAs(name); - if (!value) - return std::nullopt; - result = - join(result, StaticRange{value.getValue()}); - } - return result; -} - -// static -bool DeviceTargetAttr::lookupConfigAttrAny(Operation *op, StringRef name) { - auto configAttrs = lookupOptionalConfigAttrs(op); - if (configAttrs.empty()) - return false; - for (auto configAttr : configAttrs) { - if (configAttr.get(name)) - return true; - } - return false; -} - -// static -bool DeviceTargetAttr::lookupConfigAttrAll(Operation *op, StringRef name) { - auto configAttrs = lookupRequiredConfigAttrs(op); - if (!configAttrs) - return false; - for (auto configAttr : *configAttrs) { - if (!configAttr.get(name)) - return false; - } - return true; -} - -// static -std::optional DeviceTargetAttr::lookupConfigAttrAnd(Operation *op, - StringRef name) { - auto configAttrs = lookupRequiredConfigAttrs(op); - if (!configAttrs) - return std::nullopt; - return joinConfigAttrs( - configAttrs.value(), name, [](bool lhs, bool rhs) { return lhs && rhs; }); -} - -// static -std::optional DeviceTargetAttr::lookupConfigAttrOr(Operation *op, - StringRef name) { - auto configAttrs = lookupRequiredConfigAttrs(op); - if (!configAttrs) - return std::nullopt; - return joinConfigAttrs( - configAttrs.value(), name, [](bool lhs, bool rhs) { return lhs || rhs; }); -} - -// static -std::optional> -DeviceTargetAttr::lookupConfigAttrRange(Operation *op, StringRef name) { - auto configAttrs = lookupRequiredConfigAttrs(op); - if (!configAttrs) - return std::nullopt; - return joinConfigStaticRanges( - configAttrs.value(), name, - [](StaticRange lhs, StaticRange rhs) { - return StaticRange{ - llvm::APIntOps::smin(lhs.min, rhs.min), - llvm::APIntOps::smax(lhs.max, rhs.max), - }; - }); -} - // static SmallVector DeviceTargetAttr::lookupExecutableTargets(Operation *op) { - SmallVector resultAttrs; - for (auto deviceTargetAttr : lookup(op)) { + SmallVector resultAttrs; + for (auto deviceTargetAttr : lookupDeviceTargetAttrs(op)) { for (auto executableTargetAttr : deviceTargetAttr.getExecutableTargets()) { if (!llvm::is_contained(resultAttrs, executableTargetAttr)) { resultAttrs.push_back(executableTargetAttr); diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index b77c9a513f58..b43580c01090 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td @@ -525,44 +525,11 @@ def HAL_DeviceTargetAttr : AttrDef lookup(Operation *op); - - // Returns true if there is any UnitAttr with |name| in any device - // configuration for the given |op|. - static bool lookupConfigAttrAny(Operation *op, StringRef name); - - // Returns true if all device configurations found for the given |op| have - // a UnitAttr with |name|. - static bool lookupConfigAttrAll(Operation *op, StringRef name); - - // Returns the AND of boolean attributes of |name| in all device - // configurations found for the given |op|. - // Returns nullopt if any config does not have the key defined indicating - // that it's not statically known/runtime dynamic. - static std::optional - lookupConfigAttrAnd(Operation *op, StringRef name); - - // Returns the OR of boolean attributes of |name| in all device - // configurations found for the given |op|. - // Returns nullopt if any config does not have the key defined indicating - // that it's not statically known/runtime dynamic. - static std::optional - lookupConfigAttrOr(Operation *op, StringRef name); - - // Returns the range of integer attributes of |name| in all device - // configurations found for the given |op|. - // Returns nullopt if any config does not have the key defined indicating - // that it's not statically known/runtime dynamic. - static std::optional> - lookupConfigAttrRange(Operation *op, StringRef name); - // Returns zero or more executable targets that this device supports. void getExecutableTargets( SetVector &resultAttrs); + // DEPRECATED: analysis is required in order to query this information. // Returns a list of all target executable configurations that may be // required for the given operation. static SmallVector diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/FixupLegacySync.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/FixupLegacySync.cpp index 45bf830078a7..7f6a18e93453 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/FixupLegacySync.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/FixupLegacySync.cpp @@ -4,6 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" @@ -149,12 +150,15 @@ struct FixupLegacySyncPass void runOnOperation() override { auto moduleOp = getOperation(); - // See if any devices are marked as requiring the legacy_sync behavior. - // If any single device does we must uniformly apply the fixups. - if (!IREE::HAL::DeviceTargetAttr::lookupConfigAttrAny(moduleOp, - "legacy_sync")) { - return; - } + // Analyze the module to determine which devices need the behavior. + DeviceAnalysis deviceAnalysis(moduleOp); + if (failed(deviceAnalysis.run())) + return signalPassFailure(); + auto isLegacySync = [&](Value deviceValue) { + auto deviceSet = deviceAnalysis.lookupDeviceTargets(deviceValue); + return deviceSet.has_value() ? deviceSet->hasConfigAttrAny("legacy_sync") + : false; + }; // This could use an interface but it'd be better to remove the need for // this pass instead. @@ -162,19 +166,39 @@ struct FixupLegacySyncPass funcOp.walk([&](Operation *op) { TypeSwitch(op) .Case([&](IREE::HAL::CommandBufferCreateOp op) { - makeAllowInlineExecution(op); + if (isLegacySync(op.getDevice())) { + makeAllowInlineExecution(op); + } }) .Case([&](IREE::HAL::DeviceQueueAllocaOp op) { - insertWaitIfNeeded(op, op.getWaitFenceMutable(), - op.getSignalFence()); + if (isLegacySync(op.getDevice())) { + insertWaitIfNeeded(op, op.getWaitFenceMutable(), + op.getSignalFence()); + } }) .Case([&](IREE::HAL::DeviceQueueDeallocaOp op) { - insertWaitIfNeeded(op, op.getWaitFenceMutable(), - op.getSignalFence()); + if (isLegacySync(op.getDevice())) { + insertWaitIfNeeded(op, op.getWaitFenceMutable(), + op.getSignalFence()); + } + }) + .Case([&](IREE::HAL::DeviceQueueReadOp op) { + if (isLegacySync(op.getDevice())) { + insertWaitIfNeeded(op, op.getWaitFenceMutable(), + op.getSignalFence()); + } + }) + .Case([&](IREE::HAL::DeviceQueueWriteOp op) { + if (isLegacySync(op.getDevice())) { + insertWaitIfNeeded(op, op.getWaitFenceMutable(), + op.getSignalFence()); + } }) .Case([&](IREE::HAL::DeviceQueueExecuteOp op) { - insertWaitIfNeeded(op, op.getWaitFenceMutable(), - op.getSignalFence()); + if (isLegacySync(op.getDevice())) { + insertWaitIfNeeded(op, op.getWaitFenceMutable(), + op.getSignalFence()); + } }); }); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir index 29de091b2df4..d217b4784924 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir @@ -1,56 +1,95 @@ // RUN: iree-opt --split-input-file --iree-hal-fixup-legacy-sync %s | FileCheck %s +// TODO(multi-device): remove once device globals are used. This is a fallback +// path during the transition. +module attributes { + hal.device.targets = [ + #hal.device.target<"vulkan", {legacy_sync}> : !hal.device + ] +} { +// CHECK-LABEL: @default_device_targets +// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64) +util.func public @default_device_targets(%device: !hal.device, %affinity: i64) { + // CHECK: hal.command_buffer.create device(%[[DEVICE]] : !hal.device) mode("None") + %cmd = hal.command_buffer.create device(%device : !hal.device) mode("None") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer + util.return +} +} // module + +// ----- + +// Tests that unknown devices (here passed as an arg on a public function) +// don't trigger the pass, as we default to non-legacy behavior. + +// CHECK-LABEL: @unknown_device +// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64) +util.func public @unknown_device(%device: !hal.device, %affinity: i64) { + // CHECK: hal.command_buffer.create device(%[[DEVICE]] : !hal.device) mode("None") + %cmd = hal.command_buffer.create device(%device : !hal.device) mode("None") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer + util.return +} + +// ----- + // Tests that command buffers that are reusable don't execute inline. // Reusable + inline is not a valid combination. -module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} { +util.global private @device = #hal.device.target<"vulkan", {legacy_sync}> : !hal.device + // CHECK-LABEL: @command_buffer_reusable -util.func public @command_buffer_reusable(%device: !hal.device, %affinity: i64) { - // CHECK: hal.command_buffer.create device(%{{.+}} : !hal.device) mode("None") +util.func public @command_buffer_reusable(%affinity: i64) { + // CHECK: %[[DEVICE:.+]] = util.global.load @device + %device = util.global.load @device : !hal.device + // CHECK: hal.command_buffer.create device(%[[DEVICE]] : !hal.device) mode("None") %cmd = hal.command_buffer.create device(%device : !hal.device) mode("None") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer util.return } -} // module // ----- // Tests that one-shot command buffers are allowed to execute inline. -module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} { +util.global private @device = #hal.device.target<"vulkan", {legacy_sync}> : !hal.device + // CHECK-LABEL: @command_buffer_oneshot -util.func public @command_buffer_oneshot(%device: !hal.device, %affinity: i64) { - // CHECK: hal.command_buffer.create device(%{{.+}} : !hal.device) mode("OneShot|AllowInlineExecution") +util.func public @command_buffer_oneshot(%affinity: i64) { + // CHECK: %[[DEVICE:.+]] = util.global.load @device + %device = util.global.load @device : !hal.device + // CHECK: hal.command_buffer.create device(%[[DEVICE]] : !hal.device) mode("OneShot|AllowInlineExecution") %cmd = hal.command_buffer.create device(%device : !hal.device) mode(OneShot) categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer util.return } -} // module // ----- // Tests for a no-op if there are no devices requiring legacy mode. -module attributes {hal.device.targets = [ +util.global private @device = #hal.device.select<[ #hal.device.target<"local", {}>, #hal.device.target<"vulkan", {}> -]} { +]> : !hal.device + // CHECK-LABEL: @legacy_mode_not_required -util.func public @legacy_mode_not_required(%device: !hal.device, %affinity: i64) { - // CHECK: hal.command_buffer.create device(%{{.+}} : !hal.device) mode(OneShot) +util.func public @legacy_mode_not_required(%affinity: i64) { + // CHECK: %[[DEVICE:.+]] = util.global.load @device + %device = util.global.load @device : !hal.device + // CHECK: hal.command_buffer.create device(%[[DEVICE]] : !hal.device) mode(OneShot) %cmd = hal.command_buffer.create device(%device : !hal.device) mode(OneShot) categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer util.return } -} // module // ----- -// Tests that any device requiring legacy_sync will trigger the pass. +// Tests that any device requiring legacy_sync in a set will trigger the pass. -module attributes {hal.device.targets = [ +util.global private @device = #hal.device.select<[ #hal.device.target<"local", {}>, #hal.device.target<"vulkan", {legacy_sync}> -]} { +]> : !hal.device + // CHECK-LABEL: @mixed_legacy_mode_required -util.func public @mixed_legacy_mode_required(%device: !hal.device, %wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) { +util.func public @mixed_legacy_mode_required(%wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) { + %device = util.global.load @device : !hal.device %affinity = arith.constant 1 : i64 // CHECK: hal.fence.await // CHECK: hal.device.queue.execute @@ -61,17 +100,50 @@ util.func public @mixed_legacy_mode_required(%device: !hal.device, %wait: !hal.f commands([%cmd]) util.return } -} // module + +// ----- + +// Tests that only devices with legacy_sync trigger the pass. + +util.global private @device_async = #hal.device.target<"local", {}> : !hal.device +util.global private @device_sync = #hal.device.target<"vulkan", {legacy_sync}> : !hal.device + +// CHECK-LABEL: @mixed_legacy_mode_scoped +util.func public @mixed_legacy_mode_scoped(%wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) { + // CHECK-DAG: %[[DEVICE_ASYNC:.+]] = util.global.load @device_async + %device_async = util.global.load @device_async : !hal.device + // CHECK-DAG: %[[DEVICE_SYNC:.+]] = util.global.load @device_sync + %device_sync = util.global.load @device_sync : !hal.device + %affinity = arith.constant 1 : i64 + // CHECK-NOT: hal.fence.await + // CHECK: hal.device.queue.execute<%[[DEVICE_ASYNC]] + // CHECK-NOT: hal.fence.await + hal.device.queue.execute<%device_async : !hal.device> + affinity(%affinity) + wait(%wait) signal(%signal) + commands([%cmd]) + // CHECK: hal.fence.await + // CHECK: hal.device.queue.execute<%[[DEVICE_SYNC]] + // CHECK: hal.fence.await + hal.device.queue.execute<%device_sync : !hal.device> + affinity(%affinity) + wait(%wait) signal(%signal) + commands([%cmd]) + util.return +} // ----- // Tests that queued operations get the appropriate waits before/after. -module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} { +util.global private @device = #hal.device.target<"vulkan", {legacy_sync}> : !hal.device + // CHECK-LABEL: @blocking_execute -// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[WAIT:.+]]: !hal.fence, %[[CMD:.+]]: !hal.command_buffer, %[[SIGNAL:.+]]: !hal.fence) -util.func public @blocking_execute(%device: !hal.device, %wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) { +// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[CMD:.+]]: !hal.command_buffer, %[[SIGNAL:.+]]: !hal.fence) +util.func public @blocking_execute(%wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) { %affinity = arith.constant 1 : i64 + // CHECK: %[[DEVICE:.+]] = util.global.load @device + %device = util.global.load @device : !hal.device // CHECK-DAG: %[[NULL:.+]] = util.null : !hal.fence // CHECK-DAG: hal.fence.await until([%[[WAIT]]]) // CHECK-NEXT: hal.device.queue.execute<%[[DEVICE]] : !hal.device> @@ -84,16 +156,18 @@ util.func public @blocking_execute(%device: !hal.device, %wait: !hal.fence, %cmd commands([%cmd]) util.return } -} // module // ----- // Tests that waits are not inserted if they already exist. -module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} { +util.global private @device = #hal.device.target<"vulkan", {legacy_sync}> : !hal.device + // CHECK-LABEL: @blocking_execute -// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[WAIT:.+]]: !hal.fence, %[[CMD:.+]]: !hal.command_buffer, %[[SIGNAL:.+]]: !hal.fence) -util.func public @blocking_execute(%device: !hal.device, %wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) { +// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[CMD:.+]]: !hal.command_buffer, %[[SIGNAL:.+]]: !hal.fence) +util.func public @blocking_execute(%wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) { + // CHECK: %[[DEVICE:.+]] = util.global.load @device + %device = util.global.load @device : !hal.device // CHECK-NEXT: %[[TIMEOUT:.+]] = arith.constant 100 %timeout = arith.constant 100 : i32 // CHECK-NEXT: hal.fence.await until([%[[WAIT]]]) timeout_millis(%[[TIMEOUT]]) @@ -114,4 +188,3 @@ util.func public @blocking_execute(%device: !hal.device, %wait: !hal.fence, %cmd // CHECK-NEXT: util.return util.return } -} // module diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp index 7728ce8ff22d..1708782f21be 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp @@ -11,9 +11,7 @@ #include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" #include "iree/compiler/Dialect/Stream/IR/StreamOps.h" #include "iree/compiler/Dialect/Util/Analysis/DFX/Element.h" -#include "iree/compiler/Dialect/Util/Analysis/DFX/Solver.h" #include "iree/compiler/Dialect/Util/Analysis/DFX/State.h" -#include "iree/compiler/Dialect/Util/Analysis/Explorer.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "llvm/ADT/TypeSwitch.h" @@ -867,11 +865,8 @@ LogicalResult ResourceUsageAnalysis::run() { // }); // Initialize all SSA values we can do just with trivial search. - explorer.walkValues([&](Value value) { - if (llvm::isa(value.getType())) { - solver.getOrCreateElementFor( - Position::forValue(value)); - } + explorer.walkValuesOfType([&](Value value) { + solver.getOrCreateElementFor(Position::forValue(value)); return WalkResult::advance(); }); diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp index c25458cf9aad..e014588b055f 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp @@ -320,7 +320,8 @@ WalkResult Explorer::recursiveWalk(Operation *parentOp, return WalkResult::advance(); } -TraversalResult Explorer::walkValues(ValueWalkFn fn) { +TraversalResult Explorer::walkAllValues(ValueWalkFn fn, + std::optional typeID) { LLVM_DEBUG(llvm::dbgs() << "[[ Explorer::walkValues ]]\n"); TraversalResult result = TraversalResult::COMPLETE; @@ -357,7 +358,8 @@ TraversalResult Explorer::walkValues(ValueWalkFn fn) { LLVM_DEBUG(llvm::dbgs() << " + entering callable region @" << getRegionName(callableRegion) << "\n"); - auto emitResult = recursiveWalkValues(callableOp, visitedValues, fn); + auto emitResult = + recursiveWalkValues(callableOp, visitedValues, fn, typeID); if (emitResult.wasInterrupted()) break; if (emitResult.wasSkipped()) @@ -384,7 +386,8 @@ TraversalResult Explorer::walkValues(Operation *op, ValueWalkFn fn) { WalkResult Explorer::recursiveWalkValues(Operation *parentOp, DenseSet &visitedValues, - const ValueWalkFn &fn) { + const ValueWalkFn &fn, + std::optional typeID) { auto parentAction = getTraversalAction(parentOp); if (parentAction == TraversalAction::IGNORE) { LLVM_DEBUG(llvm::dbgs() @@ -396,6 +399,8 @@ WalkResult Explorer::recursiveWalkValues(Operation *parentOp, LLVM_DEBUG(llvm::dbgs() << " + processing op results " << getOpName(parentOp) << "\n"); for (auto result : parentOp->getResults()) { + if (typeID.has_value() && result.getType().getTypeID() != *typeID) + continue; if (visitedValues.insert(result).second) { LLVM_DEBUG({ llvm::dbgs() << " == emitting value "; @@ -425,6 +430,8 @@ WalkResult Explorer::recursiveWalkValues(Operation *parentOp, llvm::dbgs() << " arguments\n"; }); for (auto arg : block.getArguments()) { + if (typeID.has_value() && arg.getType().getTypeID() != *typeID) + continue; if (visitedValues.insert(arg).second) { LLVM_DEBUG({ llvm::dbgs() << " == emitting block arg "; @@ -437,7 +444,7 @@ WalkResult Explorer::recursiveWalkValues(Operation *parentOp, } } for (auto &op : block) { - auto opResult = recursiveWalkValues(&op, visitedValues, fn); + auto opResult = recursiveWalkValues(&op, visitedValues, fn, typeID); if (opResult.wasInterrupted()) return WalkResult::interrupt(); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h index 000aa2fa3be2..1e975be96937 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h @@ -229,7 +229,15 @@ class Explorer { TraversalResult walk(OperationWalkFn fn); // Walks all unique SSA values nested within the root op. - TraversalResult walkValues(ValueWalkFn fn); + TraversalResult walkValues(ValueWalkFn fn) { + return walkAllValues(fn, std::nullopt); + } + // Walks all unique SSA values nested within the root op that have the given + // type. + template + TraversalResult walkValuesOfType(ValueWalkFn fn) { + return walkAllValues(fn, OpT::getTypeID()); + } // Walks all unique SSA values used/defined by |op| and all nested regions. TraversalResult walkValues(Operation *op, ValueWalkFn fn); @@ -341,10 +349,13 @@ class Explorer { void initializeGlobalInfos(); void initializeInverseCallGraph(); + TraversalResult walkAllValues(ValueWalkFn fn, std::optional typeID); + WalkResult recursiveWalk(Operation *parentOp, const OperationWalkFn &fn); WalkResult recursiveWalkValues(Operation *parentOp, DenseSet &visitedValues, - const ValueWalkFn &fn); + const ValueWalkFn &fn, + std::optional typeID = std::nullopt); Operation *rootOp = nullptr; AsmState asmState; From 50c851276008bcb2f2bb39dd98596d5e6aa7e245 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 29 Feb 2024 13:21:37 -0800 Subject: [PATCH 04/25] Making MemoizeDeviceQueries support multiple devices. This will fail on cases where a query can't be tracked to a single device but it's possible in the future to hoist/propagate across CFG edges before running this pass such that it doesn't happen. Today we inline most things and don't deduplicate functions so it'll be rare that we end up being unable to memoize. Hopefully. --- .../HAL/Transforms/MemoizeDeviceQueries.cpp | 179 +++++++++++++----- .../test/memoize_device_queries.mlir | 103 ++++++---- 2 files changed, 198 insertions(+), 84 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp index 9857a3356f93..2068a8777b52 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp @@ -6,10 +6,12 @@ #include +#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Utils/StringUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" @@ -27,6 +29,34 @@ namespace { // --iree-hal-memoize-device-queries //===----------------------------------------------------------------------===// +// All queries for a particular !hal.device global. +struct DeviceQueries { + // Global !hal.device. + IREE::Util::GlobalOpInterface deviceOp; + // [category, key, default] used for lookup/indexing. + SmallVector queryKeys; + // Ops performing queries against the device by [category, key, default]. + DenseMap> queryOps; +}; + +// A query being replaced by global lookups. +struct Query { + Query(Location loc) : loc(loc) {} + Location loc; + IREE::Util::GlobalOp okGlobalOp; + IREE::Util::GlobalOp valueGlobalOp; + StringAttr categoryAttr; + StringAttr keyAttr; + TypedAttr defaultValueAttr; +}; + +static std::string getDeviceNamePrefix(IREE::Util::GlobalOpInterface deviceOp) { + StringRef deviceName = deviceOp.getGlobalName().getValue(); + if (deviceName.starts_with("__")) + return deviceName.str(); + return ("__" + deviceName).str(); +} + // NOTE: this implementation is just for a single active device. As we start to // support multiple devices we'll need to change this to be per-device. struct MemoizeDeviceQueriesPass @@ -35,28 +65,50 @@ struct MemoizeDeviceQueriesPass void runOnOperation() override { auto moduleOp = getOperation(); + // Analyze the module to determine which devices are used where. + DeviceAnalysis deviceAnalysis(moduleOp); + if (failed(deviceAnalysis.run())) { + return signalPassFailure(); + } + + // Prepare device table indexed by symbol name. + DenseMap allDeviceQueries; + for (auto deviceOp : deviceAnalysis.getDeviceGlobals()) { + allDeviceQueries[deviceOp.getGlobalName()].deviceOp = deviceOp; + } + // Find all query ops we want to memoize and group them together. // This lets us easily replace all usages of a match with a single variable. - SmallVector deviceQueryKeys; - DenseMap> deviceQueryOps; for (auto callableOp : moduleOp.getOps()) { callableOp.walk([&](IREE::HAL::DeviceQueryOp queryOp) { + // Try to find the device this query is made on. If analysis failed then + // we can't memoize the query today. + auto deviceGlobals = + deviceAnalysis.lookupDeviceGlobals(queryOp.getDevice()); + if (!deviceGlobals || deviceGlobals->size() != 1) + return WalkResult::advance(); + IREE::Util::GlobalOpInterface deviceGlobalOp = deviceGlobals->front(); + + // Construct key used to dedupe/lookup the query. auto fullKey = ArrayAttr::get( moduleOp.getContext(), { - // TODO(multi-device): add attr key on device resolve source. StringAttr::get(moduleOp.getContext(), queryOp.getCategory() + queryOp.getKey()), queryOp.getDefaultValue().has_value() ? queryOp.getDefaultValueAttr() : Attribute{}, }); - auto lookup = deviceQueryOps.try_emplace( - fullKey, std::vector{}); + + // Track the query on the device. + auto &deviceQueries = allDeviceQueries[deviceGlobalOp.getGlobalName()]; + auto lookup = deviceQueries.queryOps.try_emplace( + fullKey, SmallVector{}); if (lookup.second) { - deviceQueryKeys.push_back(std::move(fullKey)); + deviceQueries.queryKeys.push_back(std::move(fullKey)); } lookup.first->second.push_back(queryOp); + return WalkResult::advance(); }); } @@ -64,54 +116,83 @@ struct MemoizeDeviceQueriesPass // Create each query variable and replace the uses with loads. SymbolTable symbolTable(moduleOp); auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); - for (auto queryKey : llvm::enumerate(deviceQueryKeys)) { - auto queryOps = deviceQueryOps[queryKey.value()]; - auto anyQueryOp = queryOps.front(); - auto queryType = anyQueryOp.getValue().getType(); - - // Merge all the locs as we are deduping the original query ops. - auto fusedLoc = moduleBuilder.getFusedLoc(llvm::map_to_vector( - queryOps, [&](Operation *op) { return op->getLoc(); })); - - // The initializer will perform the query once and store it in the - // variable. - std::string variableName = - "_device_query_" + std::to_string(queryKey.index()); - auto valueGlobalOp = moduleBuilder.create( - fusedLoc, variableName, - /*isMutable=*/false, queryType); - symbolTable.insert(valueGlobalOp); - valueGlobalOp.setPrivate(); - auto okGlobalOp = moduleBuilder.create( - fusedLoc, variableName + "_ok", - /*isMutable=*/false, moduleBuilder.getI1Type()); - symbolTable.insert(okGlobalOp); - okGlobalOp.setPrivate(); + for (auto deviceOp : deviceAnalysis.getDeviceGlobals()) { + auto &deviceQueries = allDeviceQueries[deviceOp.getGlobalName()]; + if (deviceQueries.queryKeys.empty()) { + // No queries against this device. + continue; + } + + // Create one global per unique query key against the device. + SmallVector queries; + moduleBuilder.setInsertionPointAfter(deviceOp); + for (auto [i, queryKey] : llvm::enumerate(deviceQueries.queryKeys)) { + auto &queryOps = deviceQueries.queryOps[queryKey]; + auto queryLoc = moduleBuilder.getFusedLoc(llvm::map_to_vector( + queryOps, [&](auto queryOp) { return queryOp.getLoc(); })); + // Create a global for the ok flag and the queried value. + // TODO(benvanik): create a better name based on the key. + auto anyQueryOp = queryOps.front(); + auto queryType = anyQueryOp.getValue().getType(); + std::string variableName = + getDeviceNamePrefix(deviceOp) + "_query_" + std::to_string(i) + + "_" + sanitizeSymbolName(anyQueryOp.getCategory()) + "_" + + sanitizeSymbolName(anyQueryOp.getKey()); + auto okGlobalOp = moduleBuilder.create( + queryLoc, variableName + "_ok", + /*isMutable=*/false, moduleBuilder.getI1Type()); + symbolTable.insert(okGlobalOp); + okGlobalOp.setPrivate(); + auto valueGlobalOp = moduleBuilder.create( + queryLoc, variableName, + /*isMutable=*/false, queryType); + symbolTable.insert(valueGlobalOp); + valueGlobalOp.setPrivate(); + + // Stash the globals for initialization. + Query query(queryLoc); + query.okGlobalOp = okGlobalOp; + query.valueGlobalOp = valueGlobalOp; + query.categoryAttr = anyQueryOp.getCategoryAttr(); + query.keyAttr = anyQueryOp.getKeyAttr(); + query.defaultValueAttr = anyQueryOp.getDefaultValueAttr(); + queries.push_back(query); + + // Replace all queries with loads of the global values. + for (auto queryOp : queryOps) { + OpBuilder replaceBuilder(queryOp); + auto okLoadOp = + okGlobalOp.createLoadOp(queryOp.getLoc(), replaceBuilder); + auto resultLoadOp = + valueGlobalOp.createLoadOp(queryOp.getLoc(), replaceBuilder); + queryOp.replaceAllUsesWith(ValueRange{ + okLoadOp.getLoadedGlobalValue(), + resultLoadOp.getLoadedGlobalValue(), + }); + queryOp.erase(); + } + } + + // Create an initializer for the device where we will perform all queries. + auto fusedLoc = moduleBuilder.getFusedLoc( + llvm::map_to_vector(queries, [&](auto &query) { return query.loc; })); auto initializerOp = moduleBuilder.create(fusedLoc); auto funcBuilder = OpBuilder::atBlockBegin(initializerOp.addEntryBlock()); - // TODO(multi-device): pass in resolve info to the call and reuse. - Value device = IREE::HAL::DeviceType::resolveAny(fusedLoc, funcBuilder); - auto queryOp = funcBuilder.create( - fusedLoc, funcBuilder.getI1Type(), queryType, device, - anyQueryOp.getCategoryAttr(), anyQueryOp.getKeyAttr(), - anyQueryOp.getDefaultValueAttr()); - okGlobalOp.createStoreOp(fusedLoc, queryOp.getOk(), funcBuilder); - valueGlobalOp.createStoreOp(fusedLoc, queryOp.getValue(), funcBuilder); - funcBuilder.create(fusedLoc); - - for (auto queryOp : queryOps) { - OpBuilder replaceBuilder(queryOp); - auto okLoadOp = okGlobalOp.createLoadOp(fusedLoc, replaceBuilder); - auto resultLoadOp = - valueGlobalOp.createLoadOp(fusedLoc, replaceBuilder); - queryOp.replaceAllUsesWith(ValueRange{ - okLoadOp.getLoadedGlobalValue(), - resultLoadOp.getLoadedGlobalValue(), - }); - queryOp.erase(); + Value device = + deviceOp.createLoadOp(fusedLoc, funcBuilder).getLoadedGlobalValue(); + for (auto [i, queryKey] : llvm::enumerate(deviceQueries.queryKeys)) { + auto &query = queries[i]; + auto queryOp = funcBuilder.create( + fusedLoc, funcBuilder.getI1Type(), + query.valueGlobalOp.getGlobalType(), device, query.categoryAttr, + query.keyAttr, query.defaultValueAttr); + query.okGlobalOp.createStoreOp(fusedLoc, queryOp.getOk(), funcBuilder); + query.valueGlobalOp.createStoreOp(fusedLoc, queryOp.getValue(), + funcBuilder); } + funcBuilder.create(fusedLoc); } } }; diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir index 5211bd99dd5b..16dfd5c0ff58 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir @@ -1,38 +1,71 @@ // RUN: iree-opt --split-input-file --iree-hal-memoize-device-queries --canonicalize %s | FileCheck %s -// CHECK: util.global private @_device_query_0 : i1 -// CHECK-NEXT: util.global private @_device_query_0_ok : i1 -// CHECK-NEXT: util.initializer { -// CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} -// CHECK-NEXT: %[[OK0:.+]], %[[VALUE0:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "id0*") : i1, i1 = false -// CHECK-NEXT: util.global.store %[[OK0]], @_device_query_0_ok : i1 -// CHECK-NEXT: util.global.store %[[VALUE0]], @_device_query_0 : i1 - -// CHECK: util.global private @_device_query_1 : i1 -// CHECK-NEXT: util.global private @_device_query_1_ok : i1 -// CHECK-NEXT: util.initializer { -// CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} -// CHECK-NEXT: %[[OK1:.+]], %[[VALUE1:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "id1") : i1, i1 = false -// CHECK-NEXT: util.global.store %[[OK1]], @_device_query_1_ok : i1 -// CHECK-NEXT: util.global.store %[[VALUE1]], @_device_query_1 : i1 - -// CHECK: util.global private @_device_query_2 - -// CHECK-LABEL: util.func public @device_matchers -util.func public @device_matchers(%device : !hal.device) -> (i1, i1, i1, i1, i1, i1) { - // Same queries (same variables): - // CHECK-NEXT: = util.global.load @_device_query_0_ok : i1 - // CHECK-NEXT: = util.global.load @_device_query_0 : i1 - %id0_a_ok, %id0_a = hal.device.query<%device : !hal.device> key("hal.device.id" :: "id0*") : i1, i1 = false - // CHECK-NEXT: = util.global.load @_device_query_0_ok : i1 - // CHECK-NEXT: = util.global.load @_device_query_0 : i1 - %id0_b_ok, %id0_b = hal.device.query<%device : !hal.device> key("hal.device.id" :: "id0*") : i1, i1 = false - - // Same query but with different defaults (different variables): - // CHECK-NEXT: = util.global.load @_device_query_1 : i1 - %id1_a_ok, %id1_a = hal.device.query<%device : !hal.device> key("hal.device.id" :: "id1") : i1, i1 = false - // CHECK-NEXT: = util.global.load @_device_query_2 : i1 - %id1_b_ok, %id1_b = hal.device.query<%device : !hal.device> key("hal.device.id" :: "id1") : i1, i1 = true - - util.return %id0_a_ok, %id0_a, %id0_b_ok, %id0_b, %id1_a, %id1_b : i1, i1, i1, i1, i1, i1 +// Tests that unknown devices (here passed as an arg on a public function) don't +// get memoized. + +// CHECK-LABEL: @unknown_device +// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device) +util.func public @unknown_device(%device: !hal.device) -> i1 { + // CHECK-NEXT: hal.device.query<%[[DEVICE]] + %id0_ok, %id0 = hal.device.query<%device : !hal.device> key("hal.device.id" :: "id0") : i1, i1 = false + util.return %id0 : i1 +} + +// ----- + +// Tests that multiple possible devices disable memoization. +// TODO(multi-device): enable propagation of queried values across the program. +// We should be able to track back to each global, memoize there, then pass +// through the value as a normal SSA value. + +util.global private @device_a : !hal.device +util.global private @device_b : !hal.device + +// CHECK-LABEL: @multi_device_not_memoized +util.func public @multi_device_not_memoized(%cond: i1) -> i1 { + // CHECK-DAG: %[[DEVICE_A:.+]] = util.global.load @device_a + %device_a = util.global.load @device_a : !hal.device + // CHECK-DAG: %[[DEVICE_B:.+]] = util.global.load @device_b + %device_b = util.global.load @device_b : !hal.device + // CHECK: %[[DEVICE_AB:.+]] = arith.select %{{.+}}, %[[DEVICE_A]], %[[DEVICE_B]] + %device_ab = arith.select %cond, %device_a, %device_b : !hal.device + // CHECK-NEXT: hal.device.query<%[[DEVICE_AB]] + %id0_ok, %id0 = hal.device.query<%device_ab : !hal.device> key("hal.device.id" :: "id0") : i1, i1 = false + util.return %id0 : i1 +} + +// ----- + +// Tests basic hoisting of device queries up to an initializer per device. + +// CHECK: util.global private @device +util.global private @device : !hal.device +// CHECK-NEXT: util.global private @__device_query_0_hal_device_id_id0_ok : i1 +// CHECK-NEXT: util.global private @__device_query_0_hal_device_id_id0 : i1 +// CHECK-NEXT: util.global private @__device_query_1_hal_device_id_id1_ok : i1 +// CHECK-NEXT: util.global private @__device_query_1_hal_device_id_id1 : i1 +// CHECK-NEXT: util.initializer +// CHECK: %[[DEVICE:.+]] = util.global.load @device : !hal.device +// CHECK: %[[OK0:.+]], %[[VALUE0:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "id0") : i1, i1 = false +// CHECK: util.global.store %[[OK0]], @__device_query_0_hal_device_id_id0_ok : i1 +// CHECK: util.global.store %[[VALUE0]], @__device_query_0_hal_device_id_id0 : i1 +// CHECK: %[[OK1:.+]], %[[VALUE1:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.device.id" :: "id1") : i1, i1 = false +// CHECK: util.global.store %[[OK1]], @__device_query_1_hal_device_id_id1_ok : i1 +// CHECK: util.global.store %[[VALUE1]], @__device_query_1_hal_device_id_id1 : i1 + +// CHECK: @single_device_memoized_0 +util.func public @single_device_memoized_0() -> (i1, i1) { + %device = util.global.load @device : !hal.device + // CHECK-NEXT: = util.global.load @__device_query_0_hal_device_id_id0_ok : i1 + // CHECK-NEXT: = util.global.load @__device_query_0_hal_device_id_id0 : i1 + %id0_ok, %id0 = hal.device.query<%device : !hal.device> key("hal.device.id" :: "id0") : i1, i1 = false + util.return %id0_ok, %id0 : i1, i1 +} +// CHECK: @single_device_memoized_1 +util.func public @single_device_memoized_1() -> (i1, i1) { + %device = util.global.load @device : !hal.device + // CHECK-NEXT: = util.global.load @__device_query_1_hal_device_id_id1_ok : i1 + // CHECK-NEXT: = util.global.load @__device_query_1_hal_device_id_id1 : i1 + %id1_ok, %id1 = hal.device.query<%device : !hal.device> key("hal.device.id" :: "id1") : i1, i1 = false + util.return %id1_ok, %id1 : i1, i1 } From 8f6166c80e6d349c02b72b46bea2dc274aed2ed2 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 29 Feb 2024 18:35:15 -0800 Subject: [PATCH 05/25] Making MaterializeResourceCaches support multiple devices. This will fail on cases where a query can't be tracked to a single device but it's possible in the future to hoist/propagate across CFG edges before running this pass such that it doesn't happen. Today we inline most things and don't deduplicate functions so it'll be rare that we end up being unable to memoize. Hopefully. --- .../Transforms/MaterializeResourceCaches.cpp | 913 ++++++++++++------ .../HAL/Transforms/MemoizeDeviceQueries.cpp | 6 +- .../test/materialize_resource_caches.mlir | 264 ++++- .../VM/Transforms/GlobalInitialization.cpp | 14 +- 4 files changed, 875 insertions(+), 322 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp index 9761580ac94a..de22093e4e29 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp @@ -7,11 +7,13 @@ #include #include +#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Attributes.h" @@ -20,6 +22,9 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/Pass/Pass.h" +#define DEBUG_TYPE "iree-hal-materialize-resource-caches" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") + namespace mlir::iree_compiler::IREE::HAL { #define GEN_PASS_DEF_MATERIALIZERESOURCECACHESPASS @@ -27,315 +32,683 @@ namespace mlir::iree_compiler::IREE::HAL { namespace { -// TODO(multi-device): rewrite this to shard resources per device. -struct MaterializeResourceCachesPass - : public IREE::HAL::impl::MaterializeResourceCachesPassBase< - MaterializeResourceCachesPass> { - void runOnOperation() override { - auto moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) - return; - moduleBuilder = OpBuilder(&moduleOp.getBody()->front()); - - // Find all relevant ops. If we don't find any we skip the pass as it's - // likely it's already been run. We could fix the pass to better support - // partial materialization but there's no use cases for that today. - auto executableOps = llvm::to_vector<8>(moduleOp.getOps()); - SmallVector pipelineLayoutLookupOps; - SmallVector executableLookupOps; - for (auto funcOp : moduleOp.getOps()) { - for (auto &block : funcOp.getFunctionBody()) { - block.walk([&](Operation *op) { - if (auto lookupOp = dyn_cast(op)) { - pipelineLayoutLookupOps.push_back(lookupOp); - } else if (auto lookupOp = dyn_cast(op)) { - executableLookupOps.push_back(lookupOp); - } - }); - } - } - if (pipelineLayoutLookupOps.empty() && executableLookupOps.empty()) { - return; - } +//===----------------------------------------------------------------------===// +// --iree-hal-materialize-resource-caches +//===----------------------------------------------------------------------===// - // Declare all layouts used by the executables. This will ensure that the - // initialization order is correct as any pipeline layout needed (and its - // dependencies) will be created prior to the executable cache below. The - // other nice thing is that we get ordering similar to the executable - // variables above. - for (auto executableOp : executableOps) { - for (auto variantOp : - executableOp.getOps()) { - for (auto exportOp : variantOp.getExportOps()) { - definePipelineLayoutOp(exportOp.getLoc(), exportOp.getLayout()); - } - } - } +struct DescriptorSetLayout { + // All locations that use the layout. + SetVector locs; + // Value within the initializer once materialized. + Value initializerValue; +}; +using DescriptorSetLayoutKey = + std::pair; + +struct PipelineLayout { + // All locations that use the layout. + SetVector locs; + // Lookup ops for this layout. + SmallVector lookupOps; + // Global once materialized. + IREE::Util::GlobalOpInterface globalOp; + // Value within the initializer once materialized. + Value initializerValue; +}; - // Declare executable variables so that we can reference them during lookup - // replacement. - for (auto executableOp : executableOps) { - defineExecutableOp(executableOp); - } +struct Executable { + // All locations that use the executable. + SetVector locs; + // Executable representing the program to load. + IREE::HAL::ExecutableOp executableOp; + // Lookup ops for this executable. + SmallVector lookupOps; + // Global once materialized. + IREE::Util::GlobalOpInterface globalOp; +}; - // Generate cached resource singletons and replace lookup ops with direct - // loads from variables. - for (auto lookupOp : pipelineLayoutLookupOps) { - replacePipelineLayoutLookupOp(lookupOp); - } - for (auto lookupOp : executableLookupOps) { - replaceExecutableLookupOp(lookupOp); - } - } +struct DeviceResources { + DeviceResources() = default; + explicit DeviceResources(IREE::Util::GlobalOpInterface deviceOp) + : deviceOp(deviceOp) {} + + // Global !hal.device. + IREE::Util::GlobalOpInterface deviceOp; + + // Fallback devices that should be checked for resources. + // These are derived from the transitive set of #hal.device.fallback attrs. + SetVector fallbackDeviceResources; + + // Descriptor set layouts used on the device, keyed by [bindingAttrs, flags]. + llvm::MapVector + descriptorSetLayouts; + // Pipeline layouts used on the device, keyed by layout attr. + llvm::MapVector + pipelineLayouts; + // Executables used on the device, keyed by name. + llvm::MapVector executables; +}; -private: - IREE::Util::GlobalOp - defineDescriptorSetLayoutOp(Location loc, ArrayAttr bindingAttrs, - IREE::HAL::DescriptorSetLayoutFlags flags) { - std::pair key = { - bindingAttrs, flags}; - auto existingIt = descriptorSetLayoutCache_.find(key); - if (existingIt != descriptorSetLayoutCache_.end()) { - return existingIt->second; - } +static std::string getDeviceNamePrefix(IREE::Util::GlobalOpInterface deviceOp) { + StringRef deviceName = deviceOp.getGlobalName().getValue(); + if (deviceName.starts_with("__")) { + // Already prefixed. + return deviceName.str(); + } + auto prefixedName = "__" + deviceName; + return prefixedName.str(); +} + +static void declareDevicePipelineLayout(IREE::Util::GlobalOpInterface deviceOp, + PipelineLayout &pipelineLayout, + size_t pipelineLayoutIndex, + OpBuilder &moduleBuilder) { + // Create global in the module. + auto symbolName = getDeviceNamePrefix(deviceOp) + "_pipeline_layout_" + + std::to_string(pipelineLayoutIndex); + LLVM_DEBUG(DBGS() << "+ creating device `" + << deviceOp.getGlobalName().getValue() + << "` pipeline global `" << symbolName << "`\n"); + auto layoutType = moduleBuilder.getType(); + auto globalOp = moduleBuilder.create( + moduleBuilder.getFusedLoc(llvm::to_vector(pipelineLayout.locs)), + symbolName, + /*isMutable=*/false, layoutType); + globalOp.setPrivate(); + pipelineLayout.globalOp = globalOp; + + // Replace lookups with the global. + for (auto lookupOp : pipelineLayout.lookupOps) { + LLVM_DEBUG({ + DBGS() << " - replacing lookup: "; + lookupOp.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + OpBuilder lookupBuilder(lookupOp); + auto loadedValue = + pipelineLayout.globalOp.createLoadOp(lookupOp.getLoc(), lookupBuilder) + .getLoadedGlobalValue(); + lookupOp.replaceAllUsesWith(loadedValue); + lookupOp.erase(); + } + pipelineLayout.lookupOps.clear(); +} + +static void declareDeviceExecutable(IREE::Util::GlobalOpInterface deviceOp, + Executable &executable, + size_t executableIndex, + OpBuilder &moduleBuilder) { + // Create global in the module. + auto symbolName = (getDeviceNamePrefix(deviceOp) + "_executable_" + + std::to_string(executableIndex) + "_" + + executable.executableOp.getName()) + .str(); + LLVM_DEBUG(DBGS() << "+ creating device `" + << deviceOp.getGlobalName().getValue() + << "` executable global `" << symbolName << "`\n"); + auto executableType = moduleBuilder.getType(); + auto globalOp = moduleBuilder.create( + moduleBuilder.getFusedLoc(llvm::to_vector(executable.locs)), symbolName, + /*isMutable=*/false, executableType); + globalOp.setPrivate(); + executable.globalOp = globalOp; + + // Replace lookups with the global. + for (auto lookupOp : executable.lookupOps) { + LLVM_DEBUG({ + DBGS() << " - replacing lookup: "; + lookupOp.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + OpBuilder lookupBuilder(lookupOp); + auto loadedValue = + executable.globalOp.createLoadOp(lookupOp.getLoc(), lookupBuilder) + .getLoadedGlobalValue(); + lookupOp.replaceAllUsesWith(loadedValue); + lookupOp.erase(); + } + executable.lookupOps.clear(); +} + +static DescriptorSetLayoutKey +getDescriptorSetLayoutKey(IREE::HAL::DescriptorSetLayoutAttr setLayoutAttr) { + auto bindingAttrs = + llvm::to_vector_of(setLayoutAttr.getBindings()); + return DescriptorSetLayoutKey{ + ArrayAttr::get(setLayoutAttr.getContext(), bindingAttrs), + setLayoutAttr.getFlags().value_or( + IREE::HAL::DescriptorSetLayoutFlags::None), + }; +} + +// Inlines a constant block as a function in |moduleBuilder| and then inserts +// a call to it in |callerBuilder|. +static SmallVector inlineConstantBlockOp( + StringRef funcName, IREE::HAL::ExecutableConstantBlockOp blockOp, + OpBuilder &moduleBuilder, OpBuilder &callerBuilder, Value callerDevice) { + LLVM_DEBUG(DBGS() << "- inlining constant block `" << funcName << "`\n"); + + // Create the function with the region contents of the constant block. + auto funcOp = moduleBuilder.create( + blockOp.getLoc(), funcName, blockOp.getFunctionType()); + funcOp.setPrivate(); + IRMapping mapping; + blockOp.getRegion().cloneInto(&funcOp.getRegion(), mapping); + + // Replace the hal.return with a func.return. + for (auto returnOp : + llvm::make_early_inc_range(funcOp.getOps())) { + OpBuilder(returnOp).create(returnOp.getLoc(), + returnOp.getOperands()); + returnOp.erase(); + } - auto symbolName = (StringRef("_descriptor_set_layout_") + - std::to_string(nextUniqueDescriptorSetLayoutId++)) - .str(); - - auto layoutType = DescriptorSetLayoutType::get(loc.getContext()); - auto globalOp = moduleBuilder.create( - loc, symbolName, - /*isMutable=*/false, layoutType); - globalOp.setPrivate(); - descriptorSetLayoutCache_.try_emplace(key, globalOp); - - auto initializerOp = moduleBuilder.create(loc); - OpBuilder blockBuilder = - OpBuilder::atBlockEnd(initializerOp.addEntryBlock()); - // TODO(multi-device): pass in resolve info to the call and reuse. - Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder); - Value layout = blockBuilder.createOrFold( - loc, layoutType, device, flags, bindingAttrs); - globalOp.createStoreOp(loc, layout, blockBuilder); - blockBuilder.create(loc); - - return globalOp; + // Create the call passing in the device if needed. + SmallVector callOperands; + if (funcOp.getNumArguments() > 0) { + callOperands.push_back(callerDevice); + } + auto callOp = callerBuilder.create(blockOp.getLoc(), + funcOp, callOperands); + return llvm::to_vector_of(callOp.getResults()); +} + +static Value initializeExecutable(DeviceResources &deviceResources, + Executable &executable, + OpBuilder &moduleBuilder, + Value initializerDevice, + OpBuilder &initializerBuilder) { + auto loc = executable.globalOp.getLoc(); + auto executableType = moduleBuilder.getType(); + + // Create a switch statement with a case for each variant. + // Each case should then cache only executables which contain a matching + // ExecutableVariantOp. + // Afterwards, canonicalization will take care of de-duping/etc. + SmallVector caseIndices; + SmallVector caseVariantOps; + for (auto variantOp : + executable.executableOp.getOps()) { + caseIndices.push_back(caseIndices.size()); + caseVariantOps.push_back(variantOp); } - IREE::Util::GlobalOp - definePipelineLayoutOp(Location loc, - IREE::HAL::PipelineLayoutAttr layoutAttr) { - auto existingIt = pipelineLayoutCache_.find(layoutAttr); - if (existingIt != pipelineLayoutCache_.end()) { - return existingIt->second; + // Select the variant index. + Value selectedIndex = buildIfElseTree( + loc, caseVariantOps.size(), + [&](Location loc, size_t i, OpBuilder &builder) { + return caseVariantOps[i].buildCondition(initializerDevice, builder); + }, + initializerBuilder); + + // Allow each variant to define how it is loaded and what pipeline it has. + auto switchOp = initializerBuilder.create( + loc, executableType, selectedIndex, caseIndices, caseIndices.size()); + for (auto [i, variantOp] : llvm::enumerate(caseVariantOps)) { + auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock(); + auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock); + + // Gather each of the pipeline layouts needed for each entry point in + // the executable. + SmallVector pipelineLayoutValues; + for (auto exportOp : variantOp.getExportOps()) { + auto &pipelineLayout = + deviceResources.pipelineLayouts[exportOp.getLayoutAttr()]; + pipelineLayoutValues.push_back(pipelineLayout.initializerValue); } - // First lookup (or create) all the required descriptor sets. This ensures - // they end up in the proper initialization order. - SmallVector setLayoutGlobalOps; - for (auto setLayoutAttr : layoutAttr.getSetLayouts()) { - SmallVector bindingAttrs; - for (auto bindingAttr : setLayoutAttr.getBindings()) { - bindingAttrs.push_back(bindingAttr); - } - setLayoutGlobalOps.push_back(defineDescriptorSetLayoutOp( - loc, ArrayAttr::get(loc.getContext(), bindingAttrs), - setLayoutAttr.getFlags().value_or( - IREE::HAL::DescriptorSetLayoutFlags::None))); + // Inline constant initializer from the variant. + // We want these to all happen inside of this device switch case; they'll + // get deduplicated/hoisted if possible in future canonicalization passes. + SmallVector constantValues; + for (auto [blockIndex, blockOp] : + llvm::enumerate(variantOp.getConstantBlockOps())) { + auto blockName = (executable.globalOp.getGlobalName().getValue() + + "_constant_block_" + std::to_string(blockIndex)) + .str(); + constantValues.append(inlineConstantBlockOp( + blockName, blockOp, moduleBuilder, caseBuilder, initializerDevice)); } - auto symbolName = (StringRef("_pipeline_layout_") + - std::to_string(nextUniquePipelineLayoutId++)) - .str(); + Value executableValue = + caseBuilder.createOrFold( + loc, executableType, initializerDevice, + SymbolRefAttr::get( + executable.executableOp.getSymNameAttr(), + {SymbolRefAttr::get(variantOp.getSymNameAttr())}), + pipelineLayoutValues, constantValues); + + caseBuilder.create(loc, executableValue); + } - auto layoutType = PipelineLayoutType::get(loc.getContext()); - auto globalOp = moduleBuilder.create( - loc, symbolName, /*isMutable=*/false, layoutType); - globalOp.setPrivate(); - pipelineLayoutCache_.try_emplace(layoutAttr, globalOp); + // Fallback for no available variant. + auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); + auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock); + Value status = defaultBuilder.create( + loc, static_cast(IREE::Util::StatusCode::Unavailable), 32); + { + std::string errorStr; + llvm::raw_string_ostream errorStream(errorStr); + errorStream << "HAL device `" + << deviceResources.deviceOp.getGlobalName().getValue() + << "` does not support any variant of executable `" + << executable.executableOp.getName() + << "`; available formats: ["; + llvm::interleaveComma(caseVariantOps, errorStream, [&](auto variantOp) { + errorStream << variantOp.getTargetAttr().getFormat().getValue(); + }); + errorStream << "]"; + defaultBuilder.create(loc, status, errorStr); + } + auto nullValue = + defaultBuilder.createOrFold(loc, executableType); + defaultBuilder.create(loc, nullValue); + + return switchOp.getResult(0); +} + +static void initializeDeviceResources(DeviceResources &deviceResources, + OpBuilder &moduleBuilder, + Value initializerDevice, + OpBuilder &initializerBuilder) { + // Initialize all descriptor set layouts for use by the pipeline layouts. + auto setLayoutType = initializerBuilder.getType(); + for (auto [i, it] : llvm::enumerate(deviceResources.descriptorSetLayouts)) { + auto [bindingAttrs, flags] = it.first; + auto &descriptorSetLayout = it.second; + descriptorSetLayout.initializerValue = + initializerBuilder.createOrFold( + initializerBuilder.getFusedLoc( + llvm::to_vector(descriptorSetLayout.locs)), + setLayoutType, initializerDevice, flags, bindingAttrs); + } - auto initializerOp = moduleBuilder.create(loc); - OpBuilder blockBuilder = - OpBuilder::atBlockEnd(initializerOp.addEntryBlock()); + // Initialize all pipeline layouts required for executable creation. + auto pipelineLayoutType = initializerBuilder.getType(); + for (auto [i, it] : llvm::enumerate(deviceResources.pipelineLayouts)) { + auto &[layoutAttr, pipelineLayout] = it; SmallVector setLayoutValues; - for (auto setLayoutGlobalOp : setLayoutGlobalOps) { + for (auto setLayoutAttr : layoutAttr.getSetLayouts()) { + auto key = getDescriptorSetLayoutKey(setLayoutAttr); setLayoutValues.push_back( - setLayoutGlobalOp.createLoadOp(loc, blockBuilder) - .getLoadedGlobalValue()); + deviceResources.descriptorSetLayouts[key].initializerValue); } - // TODO(multi-device): pass in resolve info to the call and reuse. - Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder); - Value layout = blockBuilder.createOrFold( - loc, layoutType, device, - blockBuilder.getIndexAttr(layoutAttr.getPushConstants()), - setLayoutValues); - globalOp.createStoreOp(loc, layout, blockBuilder); - blockBuilder.create(loc); - - return globalOp; + pipelineLayout.initializerValue = + initializerBuilder.createOrFold( + pipelineLayout.globalOp.getLoc(), pipelineLayoutType, + initializerDevice, + initializerBuilder.getIndexAttr(layoutAttr.getPushConstants()), + setLayoutValues); + pipelineLayout.globalOp.createStoreOp(pipelineLayout.globalOp.getLoc(), + pipelineLayout.initializerValue, + initializerBuilder); } - void defineExecutableOp(ExecutableOp executableOp) { - auto loc = executableOp.getLoc(); - auto symbolName = - (StringRef("_executable_") + executableOp.getSymName()).str(); - - auto executableType = ExecutableType::get(executableOp.getContext()); - auto globalOp = moduleBuilder.create( - loc, symbolName, /*isMutable=*/false, executableType); - globalOp.setPrivate(); - executableCache_.try_emplace(executableOp.getSymName(), globalOp); - - auto initializerOp = moduleBuilder.create(loc); - OpBuilder blockBuilder = - OpBuilder::atBlockEnd(initializerOp.addEntryBlock()); - // TODO(multi-device): pass in resolve info to the call and reuse. - Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder); - - // Create a switch statement with a case for each variant. - // Each case should then cache only executables which contain a matching - // ExecutableVariantOp. - // Afterwards, canonicalization will take care of de-duping/etc. - SmallVector caseIndices; - SmallVector caseVariantOps; - for (auto variantOp : - executableOp.getOps()) { - caseIndices.push_back(caseIndices.size()); - caseVariantOps.push_back(variantOp); - } + // Initialize all executables. + for (auto [i, it] : llvm::enumerate(deviceResources.executables)) { + auto &[executableName, executable] = it; + executable.globalOp.createStoreOp( + executable.globalOp.getLoc(), + initializeExecutable(deviceResources, executable, moduleBuilder, + initializerDevice, initializerBuilder), + initializerBuilder); + } +} + +static void reuseFallbackDeviceResources(DeviceResources &deviceResources, + DeviceResources &fallbackResources, + Value initializerDevice, + OpBuilder &initializerBuilder) { + // Load fallback pipeline layouts for all required by this device. + for (auto &[layoutAttr, pipelineLayout] : deviceResources.pipelineLayouts) { + auto fallbackGlobalOp = + fallbackResources.pipelineLayouts[layoutAttr].globalOp; + assert(fallbackGlobalOp && "should have created global"); + Value fallbackPipelineLayout = + fallbackGlobalOp + .createLoadOp(pipelineLayout.globalOp.getLoc(), initializerBuilder) + .getLoadedGlobalValue(); + pipelineLayout.globalOp.createStoreOp(pipelineLayout.globalOp.getLoc(), + fallbackPipelineLayout, + initializerBuilder); + } - // Select the variant index. + // Load fallback executables for all required by this device. + for (auto &[executableName, executable] : deviceResources.executables) { + auto fallbackGlobalOp = + fallbackResources.executables[executable.executableOp.getNameAttr()] + .globalOp; + assert(fallbackGlobalOp && "should have created global"); + Value fallbackExecutable = + fallbackGlobalOp + .createLoadOp(executable.globalOp.getLoc(), initializerBuilder) + .getLoadedGlobalValue(); + executable.globalOp.createStoreOp(executable.globalOp.getLoc(), + fallbackExecutable, initializerBuilder); + } +} + +static void buildDeviceResourceInitializer(DeviceResources &deviceResources, + OpBuilder &moduleBuilder) { + auto loc = deviceResources.deviceOp.getLoc(); + auto initializerOp = moduleBuilder.create(loc); + OpBuilder initializerBuilder = + OpBuilder::atBlockEnd(initializerOp.addEntryBlock()); + Value initializerDevice = + deviceResources.deviceOp.createLoadOp(loc, initializerBuilder) + .getLoadedGlobalValue(); + + // If there are any fallbacks then we need to handle referencing their + // resources and otherwise will initialize our own. + if (deviceResources.fallbackDeviceResources.empty()) { + initializeDeviceResources(deviceResources, moduleBuilder, initializerDevice, + initializerBuilder); + } else { + SmallVector caseIndices; Value selectedIndex = buildIfElseTree( - loc, caseVariantOps.size(), - [&](Location loc, size_t i, OpBuilder &builder) { - return caseVariantOps[i].buildCondition(device, builder); + loc, deviceResources.fallbackDeviceResources.size(), + [&](Location loc, size_t i, OpBuilder &caseBuilder) { + caseIndices.push_back(caseIndices.size()); + auto *fallbackResources = deviceResources.fallbackDeviceResources[i]; + Value fallbackDevice = + fallbackResources->deviceOp.createLoadOp(loc, caseBuilder) + .getLoadedGlobalValue(); + return caseBuilder.create(loc, initializerDevice, + fallbackDevice); }, - blockBuilder); - - // Allow each variant to define how it is loaded and what pipeline it has. - auto switchOp = blockBuilder.create( - loc, executableType, selectedIndex, caseIndices, caseIndices.size()); - for (auto [i, variantOp] : llvm::enumerate(caseVariantOps)) { - auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock(); + initializerBuilder); + auto switchOp = initializerBuilder.create( + loc, TypeRange{}, selectedIndex, caseIndices, caseIndices.size()); + for (auto [fallbackResources, caseRegion] : + llvm::zip_equal(deviceResources.fallbackDeviceResources, + switchOp.getCaseRegions())) { + auto &caseBlock = caseRegion.emplaceBlock(); auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock); + reuseFallbackDeviceResources(deviceResources, *fallbackResources, + initializerDevice, caseBuilder); + caseBuilder.create(loc); + } + auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); + auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock); + initializeDeviceResources(deviceResources, moduleBuilder, initializerDevice, + defaultBuilder); + defaultBuilder.create(loc); + } - // Gather each of the pipeline layouts needed for each entry point in - // the executable. - SmallVector pipelineLayoutValues; - for (auto exportOp : variantOp.getExportOps()) { - auto pipelineLayoutGlobalOp = - definePipelineLayoutOp(executableOp.getLoc(), exportOp.getLayout()); - pipelineLayoutValues.push_back( - pipelineLayoutGlobalOp.createLoadOp(loc, caseBuilder) - .getLoadedGlobalValue()); + initializerBuilder.create(loc); +} + +// Returns zero or more devices globals that may act as fallbacks for the +// given device, if analyzed. The result is in selection order. +static std::optional> +getDeviceFallbackGlobals(IREE::Util::GlobalOpInterface deviceGlobal, + SymbolTable &symbolTable) { + SetVector resultSet; + auto processAttr = [&](Attribute attr) { + if (!attr) + return true; // ignore uninitialized devices + return TypeSwitch(attr) + .Case([](auto attr) { return true; }) + .Case([](auto attr) { return true; }) + .Case([&](auto fallbackAttr) { + resultSet.insert(symbolTable.lookup( + fallbackAttr.getName().getValue())); + return true; + }) + .Default([](auto attr) { return false; }); + }; + auto initialValue = deviceGlobal.getGlobalInitialValue(); + if (auto selectAttr = + dyn_cast_if_present(initialValue)) { + for (auto deviceAttr : selectAttr.getDevices()) { + if (!processAttr(deviceAttr)) { + // Fails if unsupported/unhandled device attribute type. + return std::nullopt; } + } + } else { + if (!processAttr(initialValue)) { + // Fails if unsupported/unhandled device attribute type. + return std::nullopt; + } + } + return resultSet; +} + +static LogicalResult gatherDeviceResources( + ModuleOp &moduleOp, SymbolTable &symbolTable, + DeviceAnalysis &deviceAnalysis, + llvm::MapVector &allDeviceResources) { + // Allocate storage for the resource sets. + for (auto deviceOp : deviceAnalysis.getDeviceGlobals()) { + LLVM_DEBUG(DBGS() << "Gathering device `" + << deviceOp.getGlobalName().getValue() + << "` resources...\n"); + allDeviceResources.try_emplace(deviceOp.getGlobalName(), + DeviceResources(deviceOp)); + } - // Inline constant initializer from the variant. - // We want these to all happen inside of this device switch case; they'll - // get deduplicated/hoisted if possible in future canonicalization passes. - SmallVector constantValues; - for (auto blockOp : - llvm::make_early_inc_range(variantOp.getConstantBlockOps())) { - constantValues.append( - inlineConstantBlockOp(blockOp, moduleBuilder, caseBuilder, device)); - blockOp.erase(); - } + // Link fallbacks between the resources. + for (auto deviceOp : deviceAnalysis.getDeviceGlobals()) { + auto fallbackOps = getDeviceFallbackGlobals(deviceOp, symbolTable); + if (!fallbackOps) { + return deviceOp->emitOpError() + << "analysis failed on device; currently analysis must succeed"; + } + auto &deviceResources = allDeviceResources[deviceOp.getGlobalName()]; + for (auto fallbackOp : *fallbackOps) { + LLVM_DEBUG(DBGS() << "* linking to fallback `" + << fallbackOp.getGlobalName().getValue() << "`\n"); + deviceResources.fallbackDeviceResources.insert( + &allDeviceResources[fallbackOp.getGlobalName()]); + } + } - Value executable = caseBuilder.createOrFold( - loc, executableType, device, - SymbolRefAttr::get(executableOp.getSymNameAttr(), - {SymbolRefAttr::get(variantOp.getSymNameAttr())}), - pipelineLayoutValues, constantValues); + // Find all relevant ops. If we don't find any we skip the pass as it's + // likely it's already been run. We could fix the pass to better support + // partial materialization but there's no use cases for that today. + auto tryGetDeviceResources = [&](Operation *op, + Value device) -> DeviceResources * { + auto deviceGlobals = deviceAnalysis.lookupDeviceGlobals(device); + if (!deviceGlobals || deviceGlobals->size() != 1) { + op->emitOpError() << "analysis failed on device; currently analysis " + "must succeed with a single device"; + return nullptr; + } + auto deviceOp = deviceGlobals->front(); + return &allDeviceResources.find(deviceOp.getGlobalName())->second; + }; + for (auto funcOp : moduleOp.getOps()) { + for (auto &block : funcOp.getFunctionBody()) { + if (block + .walk([&](Operation *op) -> WalkResult { + if (auto lookupOp = dyn_cast(op)) { + auto *deviceResources = + tryGetDeviceResources(lookupOp, lookupOp.getDevice()); + if (!deviceResources) { + return WalkResult::interrupt(); + } + auto layoutAttr = lookupOp.getLayoutAttr(); + LLVM_DEBUG(DBGS() + << "+ requiring pipeline layout from lookup: `" + << layoutAttr << "`\n"); + auto &pipelineLayout = + deviceResources->pipelineLayouts[layoutAttr]; + pipelineLayout.locs.insert(lookupOp.getLoc()); + pipelineLayout.lookupOps.push_back(lookupOp); + for (auto setLayoutAttr : layoutAttr.getSetLayouts()) { + LLVM_DEBUG( + DBGS() + << "+ requiring descriptor set layout from lookup: `" + << setLayoutAttr << "`\n"); + auto key = getDescriptorSetLayoutKey(setLayoutAttr); + auto &setLayout = + deviceResources->descriptorSetLayouts[key]; + setLayout.locs.insert(lookupOp.getLoc()); + } + } else if (auto lookupOp = dyn_cast(op)) { + auto *deviceResources = + tryGetDeviceResources(lookupOp, lookupOp.getDevice()); + if (!deviceResources) { + return WalkResult::interrupt(); + } + auto executableAttr = lookupOp.getExecutableAttr().getAttr(); + LLVM_DEBUG(DBGS() << "+ requiring executable from lookup: `" + << executableAttr.getValue() << "`\n"); + auto &executable = + deviceResources->executables[executableAttr]; + executable.locs.insert(lookupOp.getLoc()); + executable.lookupOps.push_back(lookupOp); + } + return WalkResult::advance(); + }) + .wasInterrupted()) { + return failure(); + } + } + } - caseBuilder.create(loc, executable); + // Gather the executables referenced by all lookup ops. + for (auto &[deviceName, deviceResources] : allDeviceResources) { + for (auto &[executableName, executable] : deviceResources.executables) { + executable.executableOp = + symbolTable.lookup(executableName); + for (auto variantOp : + executable.executableOp.getOps()) { + for (auto exportOp : variantOp.getExportOps()) { + auto layoutAttr = exportOp.getLayoutAttr(); + LLVM_DEBUG(DBGS() << "+ requiring pipeline layout from export: `" + << layoutAttr << "`\n"); + auto &pipelineLayout = deviceResources.pipelineLayouts[layoutAttr]; + pipelineLayout.locs.insert(exportOp.getLoc()); + for (auto setLayoutAttr : layoutAttr.getSetLayouts()) { + LLVM_DEBUG(DBGS() + << "+ requiring descriptor set layout from export: `" + << setLayoutAttr << "`\n"); + auto key = getDescriptorSetLayoutKey(setLayoutAttr); + auto &setLayout = deviceResources.descriptorSetLayouts[key]; + setLayout.locs.insert(exportOp.getLoc()); + } + } + } } + } - // Fallback for no available variant. - auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); - auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock); - Value status = defaultBuilder.create( - loc, static_cast(IREE::Util::StatusCode::Unavailable), 32); - defaultBuilder.create( - loc, status, - "none of the executable binaries in the module are supported by the " - "runtime"); - auto nullValue = - defaultBuilder.createOrFold(loc, executableType); - defaultBuilder.create(loc, nullValue); - - auto executableValue = switchOp.getResult(0); - globalOp.createStoreOp(loc, executableValue, blockBuilder); - blockBuilder.create(loc); + // Merge all resources that may be used by way of fallbacks into each fallback + // device. We could make this optional to improve startup performance by + // adding these as optional and create them on demand but that's more complex. + // For now we just always ensure the resources are available even if they end + // up unused. + for (auto &[deviceName, deviceResources] : + llvm::reverse(allDeviceResources)) { + for (auto *fallbackResources : deviceResources.fallbackDeviceResources) { + LLVM_DEBUG( + DBGS() << "-> requiring fallback resources from device `" + << fallbackResources->deviceOp.getGlobalName().getValue() + << "`\n"); + for (auto [setKey, setLayout] : deviceResources.descriptorSetLayouts) { + auto &fallbackSetLayout = + fallbackResources->descriptorSetLayouts[setKey]; + fallbackSetLayout.locs.insert(setLayout.locs.begin(), + setLayout.locs.end()); + } + for (auto [layoutAttr, pipelineLayout] : + deviceResources.pipelineLayouts) { + auto &fallbackPipelineLayout = + fallbackResources->pipelineLayouts[layoutAttr]; + fallbackPipelineLayout.locs.insert(pipelineLayout.locs.begin(), + pipelineLayout.locs.end()); + } + for (auto [executableName, executable] : deviceResources.executables) { + auto &fallbackExecutable = + fallbackResources->executables[executableName]; + fallbackExecutable.locs.insert(executable.locs.begin(), + executable.locs.end()); + fallbackExecutable.executableOp = executable.executableOp; + } + } } - // Inlines a constant block as a function in |moduleBuilder| and then inserts - // a call to it in |callerBuilder|. - SmallVector inlineConstantBlockOp(ExecutableConstantBlockOp blockOp, - OpBuilder &moduleBuilder, - OpBuilder &callerBuilder, - Value device) { - // Create the function with the region contents of the constant block. - auto funcName = (StringRef("__constant_block_") + - std::to_string(nextUniqueConstantBlockId++)) - .str(); - auto funcOp = moduleBuilder.create( - blockOp.getLoc(), funcName, blockOp.getFunctionType()); - funcOp.setPrivate(); - funcOp.getRegion().takeBody(blockOp.getRegion()); - - // Replace the hal.return with a func.return. - for (auto returnOp : - llvm::make_early_inc_range(funcOp.getOps())) { - OpBuilder(returnOp).create(returnOp.getLoc(), - returnOp.getOperands()); - returnOp.erase(); + return success(); +} + +struct MaterializeResourceCachesPass + : public IREE::HAL::impl::MaterializeResourceCachesPassBase< + MaterializeResourceCachesPass> { + void runOnOperation() override { + auto moduleOp = getOperation(); + SymbolTable symbolTable(moduleOp); + + // Analyze the module to determine which devices are used where. + LLVM_DEBUG(DBGS() << "Running device analysis...\n"); + DeviceAnalysis deviceAnalysis(moduleOp); + if (failed(deviceAnalysis.run())) { + return signalPassFailure(); } - // Create the call passing in the device if needed. - SmallVector callOperands; - if (funcOp.getNumArguments() > 0) { - callOperands.push_back(device); + // Build a table of all resources used by all devices in the program. + LLVM_DEBUG(DBGS() << "Gathering device resources...\n"); + llvm::MapVector allDeviceResources; + if (failed(gatherDeviceResources(moduleOp, symbolTable, deviceAnalysis, + allDeviceResources))) { + return signalPassFailure(); } - auto callOp = callerBuilder.create( - blockOp.getLoc(), funcOp, callOperands); - return llvm::map_to_vector(callOp.getResults(), - [](OpResult result) -> Value { return result; }); - } + // Materialize resources for each device (if any) and replace lookups. + for (auto &[nameAttr, deviceResources] : allDeviceResources) { + LLVM_DEBUG(DBGS() << "Materializing device `" + << deviceResources.deviceOp.getGlobalName().getValue() + << "` resources...\n"); + // Skip devices with no resources. + if (deviceResources.pipelineLayouts.empty() && + deviceResources.executables.empty()) { + LLVM_DEBUG(DBGS() << "~ skipping device with no resources\n"); + continue; + } - void replacePipelineLayoutLookupOp(PipelineLayoutLookupOp &lookupOp) { - OpBuilder builder(lookupOp); - auto globalOp = - definePipelineLayoutOp(lookupOp.getLoc(), lookupOp.getLayout()); - auto loadedValue = globalOp.createLoadOp(lookupOp.getLoc(), builder) - .getLoadedGlobalValue(); - lookupOp.replaceAllUsesWith(loadedValue); - lookupOp.erase(); - } + // TODO(benvanik): proper insertion order if devices are initialized via + // an initializer. Today this assumes the device hasn't been materialized + // yet if there are any lookups to them. + if (!deviceResources.deviceOp.getGlobalInitialValue()) { + deviceResources.deviceOp.emitOpError() + << "is expected to be initialized with an attribute and not yet " + "via a util.initializer"; + return signalPassFailure(); + } - void replaceExecutableLookupOp(ExecutableLookupOp &lookupOp) { - OpBuilder builder(lookupOp); - auto executableIt = executableCache_.find(lookupOp.getExecutable()); - assert(executableIt != executableCache_.end() && - "executable must have been cached"); - auto globalOp = executableIt->second; - auto loadedValue = globalOp.createLoadOp(lookupOp.getLoc(), builder) - .getLoadedGlobalValue(); - lookupOp.replaceAllUsesWith(loadedValue); - lookupOp.erase(); - } + // Declare globals for each pipeline layout and executable and replace all + // lookup ops to reference them. + OpBuilder moduleBuilder(moduleOp); + moduleBuilder.setInsertionPointAfter(deviceResources.deviceOp); + for (auto [i, it] : llvm::enumerate(deviceResources.pipelineLayouts)) { + auto &[layoutAttr, pipelineLayout] = it; + declareDevicePipelineLayout(deviceResources.deviceOp, pipelineLayout, i, + moduleBuilder); + } + for (auto [i, it] : llvm::enumerate(deviceResources.executables)) { + auto &[executableName, executable] = it; + declareDeviceExecutable(deviceResources.deviceOp, executable, i, + moduleBuilder); + } - OpBuilder moduleBuilder{static_cast(nullptr)}; - DenseMap, - IREE::Util::GlobalOp> - descriptorSetLayoutCache_; - DenseMap pipelineLayoutCache_; - DenseMap executableCache_; + // Create an initializer after the declared globals. + buildDeviceResourceInitializer(deviceResources, moduleBuilder); + } - int nextUniqueConstantBlockId = 0; - int nextUniquePipelineLayoutId = 0; - int nextUniqueDescriptorSetLayoutId = 0; + // Remove ops that are no longer required after materialization. + for (auto executableOp : moduleOp.getOps()) { + for (auto variantOp : + executableOp.getOps()) { + if (auto conditionOp = variantOp.getConditionOp()) { + conditionOp.erase(); + } + for (auto blockOp : + llvm::make_early_inc_range(variantOp.getConstantBlockOps())) { + blockOp.erase(); + } + } + } + } }; } // namespace diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp index 2068a8777b52..096b7bf643c4 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp @@ -52,9 +52,11 @@ struct Query { static std::string getDeviceNamePrefix(IREE::Util::GlobalOpInterface deviceOp) { StringRef deviceName = deviceOp.getGlobalName().getValue(); - if (deviceName.starts_with("__")) + if (deviceName.starts_with("__")) { return deviceName.str(); - return ("__" + deviceName).str(); + } + auto prefixedName = "__" + deviceName; + return prefixedName.str(); } // NOTE: this implementation is just for a single active device. As we start to diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir index 3cb5f7606716..4e562f63f72c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir @@ -1,27 +1,34 @@ // RUN: iree-opt --split-input-file --iree-hal-materialize-resource-caches %s | FileCheck %s -// CHECK: util.global private @_descriptor_set_layout_0 : !hal.descriptor_set_layout - -// CHECK: util.global private @_pipeline_layout_0 : !hal.pipeline_layout +// CHECK: util.global private @device = #hal.device.ordinal<0> +util.global private @device = #hal.device.ordinal<0> : !hal.device +// CHECK: util.global private @__device_pipeline_layout_0 : !hal.pipeline_layout // CHECK-NEXT: util.initializer { -// CHECK-DAG: %[[SET0:.+]] = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout -// CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} -// CHECK-NEXT: %[[LAYOUT:.+]] = hal.pipeline_layout.create -// CHECK-SAME: device(%[[DEVICE]] : !hal.device) -// CHECK-SAME: push_constants(1) -// CHECK-SAME: layouts([%[[SET0]]]) : !hal.pipeline_layout -// CHECK-NEXT: util.global.store %[[LAYOUT]], @_pipeline_layout_0 : !hal.pipeline_layout +// CHECK-DAG: %[[DEVICE:.+]] = util.global.load @device +// CHECK-DAG: %[[SET_LAYOUT_0:.+]] = hal.descriptor_set_layout.create +// CHECK-SAME: device(%[[DEVICE]] : !hal.device) +// CHECK-SAME: flags("None") +// CHECK-SAME: bindings([ +// CHECK-SAME: #hal.descriptor_set.binding<0, storage_buffer>, +// CHECK-SAME: #hal.descriptor_set.binding<1, storage_buffer> +// CHECK-SAME: ]) : !hal.descriptor_set_layout +// CHECK-NEXT: %[[PIPELINE_LAYOUT:.+]] = hal.pipeline_layout.create +// CHECK-SAME: device(%[[DEVICE]] : !hal.device) +// CHECK-SAME: push_constants(1) +// CHECK-SAME: layouts([%[[SET_LAYOUT_0]]]) : !hal.pipeline_layout +// CHECK-NEXT: util.global.store %[[PIPELINE_LAYOUT]], @__device_pipeline_layout_0 : !hal.pipeline_layout // CHECK-LABEL: @exeLayoutLookup -util.func public @exeLayoutLookup(%device : !hal.device) -> !hal.pipeline_layout { - // CHECK: %[[LAYOUT:.+]] = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout +util.func public @exeLayoutLookup() -> !hal.pipeline_layout { + %device = util.global.load @device : !hal.device + // CHECK: %[[LOADED_LAYOUT:.+]] = util.global.load @__device_pipeline_layout_0 : !hal.pipeline_layout %0 = hal.pipeline_layout.lookup device(%device : !hal.device) layout(#hal.pipeline.layout, #hal.descriptor_set.binding<1, storage_buffer> ]> ]>) : !hal.pipeline_layout - // CHECK-NEXT: util.return %[[LAYOUT]] + // CHECK-NEXT: util.return %[[LOADED_LAYOUT]] util.return %0 : !hal.pipeline_layout } @@ -41,28 +48,25 @@ util.func public @exeLayoutLookup(%device : !hal.device) -> !hal.pipeline_layout ]> ]> -// TODO(scotttodd): Test without depending on a specific HAL target? Or move to HAL/Target/*/test/? -// - If there is no matching hal.executable.variant then the executable will not be cached -hal.executable @exe { +// CHECK: hal.executable private @exe +hal.executable private @exe { + // CHECK: hal.executable.variant public @vmvx hal.executable.variant @vmvx target(<"vmvx", "vmvx-bytecode-fb">) { + // CHECK-NOT: hal.executable.condition hal.executable.condition(%device: !hal.device) -> i1 { %ok, %selected = hal.device.query<%device : !hal.device> key("some" :: "feature") : i1, i1 hal.return %selected : i1 } - hal.executable.export @entry0 ordinal(0) layout(#pipeline_layout_0) attributes { - workgroup_size = [32 : index, 1 : index, 1 : index] - } - hal.executable.export @entry0_alias ordinal(0) layout(#pipeline_layout_0) attributes { - workgroup_size = [32 : index, 1 : index, 1 : index] - } - hal.executable.export @entry1 ordinal(1) layout(#pipeline_layout_1) attributes { - workgroup_size = [32 : index, 1 : index, 1 : index] - } + hal.executable.export @entry0 ordinal(0) layout(#pipeline_layout_0) + hal.executable.export @entry0_alias ordinal(0) layout(#pipeline_layout_0) + hal.executable.export @entry1 ordinal(1) layout(#pipeline_layout_1) + // CHECK-NOT: hal.executable.constant.block hal.executable.constant.block() -> (i32, i32) as ("foo", "bar") { %c123 = arith.constant 123 : i32 %c456 = arith.constant 456 : i32 hal.return %c123, %c456 : i32, i32 } + // CHECK-NOT: hal.executable.constant.block hal.executable.constant.block(%device: !hal.device) -> i32 as "baz" { %ok, %query = hal.device.query<%device : !hal.device> key("sys" :: "baz") : i1, i32 cf.cond_br %ok, ^bb_ok, ^bb_fail @@ -75,16 +79,27 @@ hal.executable @exe { } } -// CHECK-DAG: util.global private @_descriptor_set_layout_0 -// CHECK-DAG: util.global private @_pipeline_layout_0 -// CHECK-DAG: util.global private @_descriptor_set_layout_1 -// CHECK-DAG: util.global private @_pipeline_layout_1 +// CHECK: util.global private @device = #hal.device.ordinal<0> +util.global private @device = #hal.device.ordinal<0> : !hal.device -// CHECK: util.global private @_executable_exe : !hal.executable -// CHECK-NEXT: util.initializer { +// Cached resources for the device. +// CHECK: util.global private @__device_pipeline_layout_0 : !hal.pipeline_layout +// CHECK: util.global private @__device_pipeline_layout_1 : !hal.pipeline_layout +// CHECK: util.global private @__device_executable_0_exe : !hal.executable + +// Device initializer for all resources used with the device: +// CHECK: util.initializer +// CHECK: %[[DEVICE:.+]] = util.global.load @device + +// Create pipeline layouts (and required descriptor set layouts): +// CHECK: %[[SET_LAYOUT_0:.+]] = hal.descriptor_set_layout.create device(%[[DEVICE]] : !hal.device) +// CHECK: %[[SET_LAYOUT_1:.+]] = hal.descriptor_set_layout.create device(%[[DEVICE]] : !hal.device) +// CHECK: %[[PIPELINE_LAYOUT_0:.+]] = hal.pipeline_layout.create device(%[[DEVICE]] : !hal.device) push_constants(0) layouts([%[[SET_LAYOUT_0]]]) : !hal.pipeline_layout +// CHECK: util.global.store %[[PIPELINE_LAYOUT_0]], @__device_pipeline_layout_0 +// CHECK: %[[PIPELINE_LAYOUT_1:.+]] = hal.pipeline_layout.create device(%device : !hal.device) push_constants(0) layouts([%[[SET_LAYOUT_1]]]) : !hal.pipeline_layout +// CHECK: util.global.store %[[PIPELINE_LAYOUT_1]], @__device_pipeline_layout_1 // Switch on the supported formats: -// CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} // CHECK: %{{.+}}, %[[FORMAT_VMVX:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "vmvx-bytecode-fb") // CHECK: %[[VMVX_CONDITION:.+]] = scf.execute_region -> i1 { // CHECK: %{{.+}}, %[[FEATURE:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("some" :: "feature") @@ -97,20 +112,15 @@ hal.executable @exe { // CHECK: %[[RET:.+]] = scf.index_switch %[[VARIANT_INDEX]] -> !hal.executable // CHECK: case 0 { -// Dependent layouts: -// CHECK: %[[LAYOUT0:.+]] = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout -// CHECK: %[[LAYOUT0_2:.+]] = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout -// CHECK: %[[LAYOUT1:.+]] = util.global.load @_pipeline_layout_1 : !hal.pipeline_layout - // Constant block initializers: -// CHECK: %[[CONST_01:.+]]:2 = util.call @__constant_block_0() -// CHECK: %[[CONST_2:.+]] = util.call @__constant_block_1(%[[DEVICE]]) +// CHECK: %[[CONST_01:.+]]:2 = util.call @__device_executable_0_exe_constant_block_0() +// CHECK: %[[CONST_2:.+]] = util.call @__device_executable_0_exe_constant_block_1(%[[DEVICE]]) // Executable creation: // CHECK: %[[EXE:.+]] = hal.executable.create // CHECK-SAME: device(%[[DEVICE]] : !hal.device) // CHECK-SAME: target(@exe::@vmvx) -// CHECK-SAME: layouts([%[[LAYOUT0]], %[[LAYOUT0_2]], %[[LAYOUT1]]]) +// CHECK-SAME: layouts([%[[PIPELINE_LAYOUT_0]], %[[PIPELINE_LAYOUT_0]], %[[PIPELINE_LAYOUT_1]]]) // CHECK-SAME: constants([%[[CONST_01]]#0, %[[CONST_01]]#1, %[[CONST_2]]]) // CHECK-SAME: : !hal.executable @@ -118,18 +128,18 @@ hal.executable @exe { // CHECK: } // CHECK: default { // CHECK: %[[C14:.+]] = arith.constant 14 : i32 -// CHECK: util.status.check_ok %[[C14]], "none of the executable binaries in the module are supported by the runtime" +// CHECK: util.status.check_ok %[[C14]], "HAL device `device` does not support any variant of executable `exe`; available formats: [vmvx-bytecode-fb]" // CHECK: %[[NULL:.+]] = util.null : !hal.executable // CHECK: scf.yield %[[NULL]] : !hal.executable // CHECK: } -// CHECK: util.global.store %[[RET]], @_executable_exe : !hal.executable +// CHECK: util.global.store %[[RET]], @__device_executable_0_exe : !hal.executable -// Inlined constant block functions (here we ensure all blocks are cloned): -// CHECK: util.func private @__constant_block_0() -> (i32, i32) +// Constant block functions (here we ensure all blocks are cloned): +// CHECK: util.func private @__device_executable_0_exe_constant_block_0() -> (i32, i32) // CHECK-DAG: %[[C0:.+]] = arith.constant 123 // CHECK-DAG: %[[C1:.+]] = arith.constant 456 // CHECK: util.return %[[C0]], %[[C1]] -// CHECK: util.func private @__constant_block_1(%[[BLOCK_DEVICE:.+]]: !hal.device) -> i32 +// CHECK: util.func private @__device_executable_0_exe_constant_block_1(%[[BLOCK_DEVICE:.+]]: !hal.device) -> i32 // CHECK: %[[OK:.+]], %[[VALUE:.+]] = hal.device.query<%[[BLOCK_DEVICE]] : !hal.device> key("sys" :: "baz") // CHECK: cf.cond_br %[[OK]], ^bb1, ^bb2 // CHECK: ^bb1: @@ -139,16 +149,172 @@ hal.executable @exe { // CHECK: util.return %[[DUMMY]] // CHECK-LABEL: @exeLookup -util.func public @exeLookup(%device : !hal.device) -> !hal.executable { - // CHECK: %[[EXE:.+]] = util.global.load @_executable_exe : !hal.executable +util.func public @exeLookup() -> !hal.executable { + %device = util.global.load @device : !hal.device + // CHECK: %[[EXE:.+]] = util.global.load @__device_executable_0_exe : !hal.executable %0 = hal.executable.lookup device(%device : !hal.device) - executable(@exe) : !hal.executable + executable(@exe) : !hal.executable // CHECK-NEXT: util.return %[[EXE]] util.return %0 : !hal.executable } // ----- +// Tests that fallback resources are reused instead of being created again +// when a device selects a fallback. + +// CHECK: hal.executable private @exe +hal.executable private @exe { + // CHECK: hal.executable.variant public @vmvx + hal.executable.variant @vmvx target(<"vmvx", "vmvx-bytecode-fb">) { + // CHECK-NOT: hal.executable.condition + hal.executable.condition(%device: !hal.device) -> i1 { + %ok, %selected = hal.device.query<%device : !hal.device> key("some" :: "feature") : i1, i1 + hal.return %selected : i1 + } + hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer> + ]> + ]>) + // CHECK-NOT: hal.executable.constant.block + hal.executable.constant.block() -> (i32, i32) as ("foo", "bar") { + %c123 = arith.constant 123 : i32 + %c456 = arith.constant 456 : i32 + hal.return %c123, %c456 : i32, i32 + } + } +} + +// CHECK: util.global private @primary_device +util.global private @primary_device = #hal.device.ordinal<0> : !hal.device +// CHECK-NEXT: util.global private @__primary_device_pipeline_layout_0 +// CHECK-NEXT: util.global private @__primary_device_executable_0_exe +// CHECK-NEXT: util.initializer +// CHECK: util.global.load @primary_device +// CHECK: hal.descriptor_set_layout.create +// CHECK: hal.pipeline_layout.create +// CHECK: util.global.store {{.+}}, @__primary_device_pipeline_layout_0 +// CHECK: hal.executable.create +// CHECK: util.global.store {{.+}}, @__primary_device_executable_0_exe +// CHECK: util.func private @__primary_device_executable_0_exe_constant_block_0 + +// CHECK: util.global private @optional_device +util.global private @optional_device = #hal.device.select<[ + #hal.device.ordinal<1> : !hal.device, + #hal.device.fallback<@primary_device> : !hal.device +]> : !hal.device +// CHECK-NEXT: util.global private @__optional_device_pipeline_layout_0 +// CHECK-NEXT: util.global private @__optional_device_executable_0_exe +// CHECK-NEXT: util.initializer +// CHECK-DAG: %[[OPTIONAL_DEVICE:.+]] = util.global.load @optional_device +// CHECK-DAG: %[[PRIMARY_DEVICE:.+]] = util.global.load @primary_device +// CHECK-DAG: %[[DEVICE_EQ:.+]] = util.cmp.eq %[[OPTIONAL_DEVICE]], %[[PRIMARY_DEVICE]] +// CHECK-DAG: %[[INDEX:.+]] = arith.select %[[DEVICE_EQ]] +// CHECK-DAG: scf.index_switch %[[INDEX]] +// CHECK: case 0 +// CHECK: %[[PRIMARY_LAYOUT:.+]] = util.global.load @__primary_device_pipeline_layout_0 +// CHECK: util.global.store %[[PRIMARY_LAYOUT]], @__optional_device_pipeline_layout_0 +// CHECK: %[[PRIMARY_EXE:.+]] = util.global.load @__primary_device_executable_0_exe +// CHECK: util.global.store %[[PRIMARY_EXE]], @__optional_device_executable_0_exe +// CHECK: default +// CHECK: hal.descriptor_set_layout.create +// CHECK: hal.pipeline_layout.create +// CHECK: util.global.store {{.+}}, @__optional_device_pipeline_layout_0 +// CHECK: hal.executable.create +// CHECK: util.global.store {{.+}}, @__optional_device_executable_0_exe +// CHECK: util.func private @__optional_device_executable_0_exe_constant_block_0 + +// CHECK-LABEL: @fallbackLookup +util.func public @fallbackLookup() -> (!hal.executable, !hal.executable) { + %primary_device = util.global.load @primary_device : !hal.device + // CHECK: %[[PRIMARY_EXE_LOOKUP:.+]] = util.global.load @__primary_device_executable_0_exe + %0 = hal.executable.lookup device(%primary_device : !hal.device) + executable(@exe) : !hal.executable + %optional_device = util.global.load @optional_device : !hal.device + // CHECK: %[[OPTIONAL_EXE_LOOKUP:.+]] = util.global.load @__optional_device_executable_0_exe + %1 = hal.executable.lookup device(%optional_device : !hal.device) + executable(@exe) : !hal.executable + util.return %0, %1 : !hal.executable, !hal.executable +} + +// ----- + +// Tests that resources only used by optional devices force the resources to +// be created on fallbacks. This isn't optimal as we should really only be +// creating them if the fallback is selected but that's more complex than it's +// worth today given the limited usage of fallbacks. + +hal.executable private @exe { + hal.executable.variant @vmvx target(<"vmvx", "vmvx-bytecode-fb">) { + hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout + ]> + ]>) + } +} + +// CHECK-LABEL: util.global private @primary_device +util.global private @primary_device = #hal.device.ordinal<0> : !hal.device +// CHECK-NEXT: util.global private @__primary_device_pipeline_layout_0 +// CHECK-NEXT: util.global private @__primary_device_executable_0_exe +// CHECK-NEXT: util.initializer +// CHECK: util.global.load @primary_device +// CHECK: hal.descriptor_set_layout.create +// CHECK: hal.pipeline_layout.create +// CHECK: util.global.store {{.+}}, @__primary_device_pipeline_layout_0 +// CHECK: hal.executable.create +// CHECK: util.global.store {{.+}}, @__primary_device_executable_0_exe + +// CHECK-LABEL: util.global private @optional_device_0 +util.global private @optional_device_0 = #hal.device.select<[ + #hal.device.ordinal<1> : !hal.device, + #hal.device.fallback<@primary_device> : !hal.device +]> : !hal.device +// CHECK-NEXT: util.global private @__optional_device_0_pipeline_layout_0 +// CHECK-NEXT: util.global private @__optional_device_0_executable_0_exe +// CHECK-NEXT: util.initializer +// CHECK-DAG: %[[OPTIONAL_DEVICE_0:.+]] = util.global.load @optional_device_0 +// CHECK-DAG: %[[PRIMARY_DEVICE:.+]] = util.global.load @primary_device +// CHECK-DAG: %[[DEVICE_EQ:.+]] = util.cmp.eq %[[OPTIONAL_DEVICE_0]], %[[PRIMARY_DEVICE]] +// CHECK-DAG: %[[INDEX:.+]] = arith.select %[[DEVICE_EQ]] +// CHECK-DAG: scf.index_switch %[[INDEX]] +// CHECK: util.global.load @__primary_device_pipeline_layout_0 +// CHECK: util.global.store {{.+}}, @__optional_device_0_pipeline_layout_0 +// CHECK: util.global.load @__primary_device_executable_0_exe +// CHECK: util.global.store {{.+}}, @__optional_device_0_executable_0_exe + +// CHECK-LABEL: util.global private @optional_device_1 +util.global private @optional_device_1 = #hal.device.select<[ + #hal.device.ordinal<2> : !hal.device, + #hal.device.fallback<@optional_device_0> : !hal.device +]> : !hal.device +// CHECK-NEXT: util.global private @__optional_device_1_pipeline_layout_0 +// CHECK-NEXT: util.global private @__optional_device_1_executable_0_exe +// CHECK-NEXT: util.initializer +// CHECK-DAG: %[[OPTIONAL_DEVICE_1:.+]] = util.global.load @optional_device_1 +// CHECK-DAG: %[[OPTIONAL_DEVICE_0:.+]] = util.global.load @optional_device_0 +// CHECK-DAG: %[[DEVICE_EQ:.+]] = util.cmp.eq %[[OPTIONAL_DEVICE_1]], %[[OPTIONAL_DEVICE_0]] +// CHECK-DAG: %[[INDEX:.+]] = arith.select %[[DEVICE_EQ]] +// CHECK-DAG: scf.index_switch %[[INDEX]] +// CHECK: util.global.load @__optional_device_0_pipeline_layout_0 +// CHECK: util.global.store {{.+}}, @__optional_device_1_pipeline_layout_0 +// CHECK: util.global.load @__optional_device_0_executable_0_exe +// CHECK: util.global.store {{.+}}, @__optional_device_1_executable_0_exe + +// CHECK-LABEL: @fallbackOnlyLookup +util.func public @fallbackOnlyLookup() -> !hal.executable { + %optional_device_1 = util.global.load @optional_device_1 : !hal.device + // CHECK: util.global.load @__optional_device_1_executable_0_exe + %0 = hal.executable.lookup device(%optional_device_1 : !hal.device) + executable(@exe) : !hal.executable + util.return %0 : !hal.executable +} + +// ----- + // Tests that materialization no-ops when resource caches have already been // materialized. Today this is rather simplistic and just bails if the names // match with the expectation being that users are mostly just running through @@ -163,6 +329,8 @@ util.func public @exeLookup(%device : !hal.device) -> !hal.executable { ]> ]> +util.global private @device : !hal.device + util.global private @_descriptor_set_layout_0 : !hal.descriptor_set_layout util.initializer { %c0 = arith.constant 0 : index diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp index 65ad8c9c2032..871087cbe74f 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp @@ -97,8 +97,18 @@ static void fixupGlobalMutability(Operation *moduleOp, // No uses - erase the global entirely. deadOps.push_back(globalInfo->op); } else { - // If there are stores mark the global as mutable. - globalInfo->op.setGlobalMutable(!globalInfo->getStores().empty()); + // TODO(benvanik): verify we want this behavior - we likely want to change + // this to be mutable only if stores exist outside of initializers. + // + // If there are stores mark the global as mutable. We need to update all + // of the loads if this changes anything. + bool hasStores = !globalInfo->getStores().empty(); + bool didChange = globalInfo->op.isGlobalMutable() != hasStores; + globalInfo->op.setGlobalMutable(hasStores); + if (didChange) { + for (auto loadOp : globalInfo->getLoads()) + loadOp.setGlobalImmutable(!hasStores); + } } for (auto loadOp : globalInfo->getLoads()) loadOp.setGlobalImmutable(!globalInfo->op.isGlobalMutable()); From 3bc43db250e6095edcbb660914967c5b08f17058 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Fri, 1 Mar 2024 09:29:44 -0800 Subject: [PATCH 06/25] Making MaterializeInterfaces support multiple devices. --- .../HAL/Transforms/MaterializeInterfaces.cpp | 196 +++++---- .../test/materialize_interfaces.mlir | 379 +++++++++--------- 2 files changed, 314 insertions(+), 261 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp index 1f5eb152be06..e1437b2cf77d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp @@ -9,6 +9,7 @@ #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Dialect/HAL/Analysis/BindingLayout.h" +#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/HAL/Target/TargetBackend.h" @@ -43,64 +44,98 @@ using ExportExpansions = DenseMap< Attribute, SmallVector>>; +// Map of operations (executables, dispatches, etc) to the executable targets +// required by those operations based on usage. If missing or empty the default +// set should be used. +using RequiredExecutableTargets = + DenseMap>; + //===----------------------------------------------------------------------===// // Utilities //===----------------------------------------------------------------------===// +static SymbolRefAttr +makeExportSymbolRefAttr(IREE::HAL::ExecutableOp executableOp, + IREE::HAL::ExecutableVariantOp variantOp, + IREE::HAL::ExecutableExportOp exportOp) { + return SymbolRefAttr::get(executableOp.getNameAttr(), + { + FlatSymbolRefAttr::get(variantOp.getNameAttr()), + FlatSymbolRefAttr::get(exportOp.getNameAttr()), + }); +} + static void setApplicableObjects(Operation *sourceOp, IREE::HAL::ExecutableVariantOp targetOp) { auto objectsAttr = sourceOp->getAttrOfType( "hal.executable.objects"); - if (!objectsAttr) + if (!objectsAttr) { return; + } auto objects = objectsAttr.getApplicableObjects(targetOp.getTarget()); - if (!objects) + if (!objects) { return; + } targetOp.setObjectsAttr(*objects); } -// Returns a set of executable targets required by any dispatch to the given -// executable. Not all exports may be dispatched on the targets. -// If the |executableOp| is public then targets specified on the module will be -// used in addition to any from the dispatches. -template -static SmallVector -gatherExecutableTargetAttrs(SymbolOpInterface executableOp, - llvm::iterator_range exportOps, - const BindingLayoutAnalysis &layoutAnalysis) { - llvm::SetVector> - targetAttrsSet; - if (executableOp.isPublic()) { - for (auto targetAttr : - IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(executableOp)) { - targetAttrsSet.insert(targetAttr); +template +static void +buildRequiredExecutableTypeTargetsMap(ModuleOp moduleOp, + DeviceAnalysis &deviceAnalysis, + BindingLayoutAnalysis &layoutAnalysis, + RequiredExecutableTargets &resultMap) { + // NOTE: we build the map before we process it so that the addresses are + // stable. + for (auto executableOp : moduleOp.template getOps()) { + (void)resultMap[executableOp]; + for (auto exportOp : executableOp.template getOps()) { + for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) { + (void)resultMap[dispatchOp]; + } } } - for (auto exportOp : exportOps) { - for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) { - for (auto targetAttr : - IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(dispatchOp)) { - targetAttrsSet.insert(targetAttr); + for (auto executableOp : moduleOp.template getOps()) { + auto &executableTargetAttrs = resultMap[executableOp]; + for (auto exportOp : executableOp.template getOps()) { + for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) { + auto &dispatchTargetAttrs = resultMap[dispatchOp]; + deviceAnalysis.gatherRequiredExecutableTargets(dispatchOp, + dispatchTargetAttrs); + executableTargetAttrs.insert(dispatchTargetAttrs.begin(), + dispatchTargetAttrs.end()); } } + if (executableOp.isPublic()) { + // Public executables need all possible targets. + deviceAnalysis.gatherAllExecutableTargets(executableTargetAttrs); + } } - auto targetAttrs = targetAttrsSet.takeVector(); - llvm::stable_sort(targetAttrs, [](auto lhs, auto rhs) { - return lhs.getSymbolNameFragment() < rhs.getSymbolNameFragment(); - }); - return targetAttrs; +} + +// Builds a map of executable and dispatch ops to the executable targets that +// may be required. +static RequiredExecutableTargets +buildRequiredExecutableTargetsMap(ModuleOp moduleOp, + DeviceAnalysis &deviceAnalysis, + BindingLayoutAnalysis &layoutAnalysis) { + RequiredExecutableTargets resultMap; + buildRequiredExecutableTypeTargetsMap( + moduleOp, deviceAnalysis, layoutAnalysis, resultMap); + buildRequiredExecutableTypeTargetsMap( + moduleOp, deviceAnalysis, layoutAnalysis, resultMap); + return resultMap; } // Updates the target entry point symbols of |dispatchOp| to the expanded set of // variant exports in |exportExpansions|. -static void updateDispatchTargets(IREE::Stream::CmdDispatchOp dispatchOp, - const ExportExpansions &exportExpansions) { - DenseSet requiredTargetAttrs; - for (auto targetAttr : - IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(dispatchOp)) { - requiredTargetAttrs.insert(targetAttr); - } +static void +updateDispatchTargets(IREE::Stream::CmdDispatchOp dispatchOp, + const ExportExpansions &exportExpansions, + RequiredExecutableTargets &requiredExecutableTargets) { + auto &requiredTargetAttrs = requiredExecutableTargets[dispatchOp]; SmallVector newAttrs; for (auto oldAttr : dispatchOp.getEntryPointRefs()) { auto it = exportExpansions.find(oldAttr); @@ -109,9 +144,13 @@ static void updateDispatchTargets(IREE::Stream::CmdDispatchOp dispatchOp, continue; } for (auto [newAttr, targetAttr] : it->second) { - // Filter the new expansions to only those used by the dispatch. - if (requiredTargetAttrs.contains(targetAttr)) + // Filter the new expansions to only those used by the dispatch (if we + // have a valid filter). + if (requiredTargetAttrs.empty()) { newAttrs.push_back(newAttr); + } else if (requiredTargetAttrs.contains(targetAttr)) { + newAttrs.push_back(newAttr); + } } } dispatchOp.setEntryPointsAttr( @@ -122,26 +161,20 @@ static void updateDispatchTargets(IREE::Stream::CmdDispatchOp dispatchOp, // hal.executable.source materialization //===----------------------------------------------------------------------===// -SymbolRefAttr makeExportSymbolRefAttr(IREE::HAL::ExecutableOp executableOp, - IREE::HAL::ExecutableVariantOp variantOp, - IREE::HAL::ExecutableExportOp exportOp) { - return SymbolRefAttr::get(executableOp.getNameAttr(), - { - FlatSymbolRefAttr::get(variantOp.getNameAttr()), - FlatSymbolRefAttr::get(exportOp.getNameAttr()), - }); -} - -static void -materializeExecutableFromSourceOp(IREE::HAL::ExecutableSourceOp sourceOp, - BindingLayoutAnalysis &layoutAnalysis) { +static void materializeExecutableFromSourceOp( + IREE::HAL::ExecutableSourceOp sourceOp, + BindingLayoutAnalysis &layoutAnalysis, + RequiredExecutableTargets &requiredExecutableTargets) { // Gather the required executable targets based on the dispatches to exports // in the source op. - auto targetAttrs = gatherExecutableTargetAttrs( - sourceOp, sourceOp.getOps(), - layoutAnalysis); - if (targetAttrs.empty()) + SmallVector targetAttrs( + requiredExecutableTargets[sourceOp].getArrayRef()); + if (targetAttrs.empty()) { return; + } + llvm::stable_sort(targetAttrs, [](auto lhs, auto rhs) { + return lhs.getSymbolNameFragment() < rhs.getSymbolNameFragment(); + }); // Create the op that will contain the translated executable. OpBuilder moduleBuilder(sourceOp); @@ -180,8 +213,9 @@ materializeExecutableFromSourceOp(IREE::HAL::ExecutableSourceOp sourceOp, // Clone any target-specific object files specified. if (auto objectsAttr = sourceOp.getObjectsAttr()) { auto objects = objectsAttr.getApplicableObjects(targetAttr); - if (objects) + if (objects) { targetVariantOp.setObjectsAttr(*objects); + } } // Clone inner module contents. @@ -194,7 +228,8 @@ materializeExecutableFromSourceOp(IREE::HAL::ExecutableSourceOp sourceOp, // Update all dispatch sites to reference the new expanded variants. for (auto exportOp : sourceExportOps) { for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) { - updateDispatchTargets(dispatchOp, exportExpansions); + updateDispatchTargets(dispatchOp, exportExpansions, + requiredExecutableTargets); } } @@ -312,8 +347,9 @@ cloneFuncWithInterface(mlir::func::FuncOp sourceFuncOp, } unsigned resourceIdx = 0; for (auto arg : entryBlock->getArguments()) { - if (!llvm::isa(arg.getType())) - continue; + if (!llvm::isa(arg.getType())) { + continue; // unhandled arg type (primitive/etc) + } auto setBinding = resourceMap[resourceIdx++]; auto setLayoutAttr = layoutAttr.getSetLayouts()[setBinding.first]; auto bindingAttr = setLayoutAttr.getBindings()[setBinding.second]; @@ -332,7 +368,8 @@ cloneFuncWithInterface(mlir::func::FuncOp sourceFuncOp, static LogicalResult declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp, IREE::HAL::ExecutableOp targetExecutableOp, - const BindingLayoutAnalysis &layoutAnalysis) { + const BindingLayoutAnalysis &layoutAnalysis, + RequiredExecutableTargets &requiredExecutableTargets) { auto variantOps = targetExecutableOp.getBlock().getOps(); OpBuilder executableBuilder(&targetExecutableOp.getBlock().front()); @@ -348,8 +385,9 @@ declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp, if (auto sourceModuleOp = sourceExecutableOp.getInnerModule()) { sourceFuncOp = sourceModuleOp.lookupSymbol( exportOp.getFunctionRef()); - if (failed(verifyEntryPointTypes(sourceFuncOp))) + if (failed(verifyEntryPointTypes(sourceFuncOp))) { return failure(); + } } // Lookup to see if a layout was specified already. If not we'll perform @@ -441,7 +479,8 @@ declareEntryPointOps(IREE::Stream::ExecutableOp sourceExecutableOp, // Update all dispatch sites to reference the new expanded variants. for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) { - updateDispatchTargets(dispatchOp, exportExpansions); + updateDispatchTargets(dispatchOp, exportExpansions, + requiredExecutableTargets); } } @@ -513,8 +552,9 @@ struct InlineConstantWorkgroupSizePattern assert(exportOp && "must have an entry point corresponding to the parent func"); auto workgroupSizeAttr = exportOp.getWorkgroupSizeAttr(); - if (!workgroupSizeAttr) + if (!workgroupSizeAttr) { return failure(); + } uint64_t dimIdx = sizeOp.getDimension().getZExtValue(); auto dimAttr = workgroupSizeAttr[dimIdx]; @@ -554,11 +594,22 @@ struct MaterializeInterfacesPass SymbolTable symbolTable(moduleOp); BindingLayoutAnalysis layoutAnalysis(moduleOp, symbolTable); + // Run required analysis passes. + DeviceAnalysis deviceAnalysis(moduleOp); + if (failed(deviceAnalysis.run())) { + return signalPassFailure(); + } + + // Gather the required executable targets per executable and dispatch site. + auto requiredExecutableTargets = buildRequiredExecutableTargetsMap( + moduleOp, deviceAnalysis, layoutAnalysis); + // Handle any hand-authored executables; these only need variant expansion // and no layout analysis as the user specified the layout themselves. for (auto sourceOp : llvm::make_early_inc_range( moduleOp.getOps())) { - materializeExecutableFromSourceOp(sourceOp, layoutAnalysis); + materializeExecutableFromSourceOp(sourceOp, layoutAnalysis, + requiredExecutableTargets); } // Processes all executables within the input module and produce the @@ -568,17 +619,22 @@ struct MaterializeInterfacesPass for (auto sourceOp : llvm::make_early_inc_range( moduleOp.getOps())) { auto exportOps = sourceOp.getOps(); - if (exportOps.empty()) + if (exportOps.empty()) { continue; + } // Gather a list of all #hal.executable.targets that we should produce // variants for based on the dispatches performed. Not all exports may be // used on any particular target but we let future DCE/pruning passes // remove them instead of modifying the inner modules here. - auto targetAttrs = - gatherExecutableTargetAttrs(sourceOp, exportOps, layoutAnalysis); - if (targetAttrs.empty()) - continue; + SmallVector targetAttrs( + requiredExecutableTargets[sourceOp].getArrayRef()); + if (targetAttrs.empty()) { + return; + } + llvm::stable_sort(targetAttrs, [](auto lhs, auto rhs) { + return lhs.getSymbolNameFragment() < rhs.getSymbolNameFragment(); + }); // Create the op that will contain the translated executable. OpBuilder builder = OpBuilder::atBlockEnd(moduleOp.getBody()); @@ -605,8 +661,8 @@ struct MaterializeInterfacesPass } // Define interfaces for each exported function based on analysis. - if (failed( - declareEntryPointOps(sourceOp, executableOp, layoutAnalysis))) { + if (failed(declareEntryPointOps(sourceOp, executableOp, layoutAnalysis, + requiredExecutableTargets))) { return signalPassFailure(); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir index a688bddf687d..d350e0e038c0 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir @@ -2,138 +2,139 @@ // Tests an executable with a workgroup count region specified. -module attributes {hal.device.targets = [ - #hal.device.target<"llvm-cpu", [ - #hal.executable.target<"llvm-cpu", "arm_64">, - #hal.executable.target<"llvm-cpu", "x86_64"> - ]> -]} { - // CHECK: #pipeline_layout = #hal.pipeline.layout< - // CHECK-SAME: push_constants = 1 - // CHECK-SAME: sets = [ - // CHECK-SAME: <0, bindings = [ - // CHECK-SAME: <0, storage_buffer, ReadOnly> - // CHECK-SAME: <1, storage_buffer, ReadOnly> - // CHECK-SAME: <2, storage_buffer> +// The default device when none is specified. +// Functions and scopes can override the target device. +util.global private @default_device = #hal.device.target<"cpu", [ + #hal.executable.target<"llvm-cpu", "arm_64">, + #hal.executable.target<"llvm-cpu", "x86_64"> +]> : !hal.device - // CHECK: hal.executable private @ex - // CHECK: hal.executable.variant public @arm_64 target(#executable_target_arm_64 - // CHECK: hal.executable.export public @entry ordinal(0) layout(#pipeline_layout) - // CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>] - // CHECK-NEXT: ^bb0(%[[DEVICE:.+]]: !hal.device, %[[ARG0:.+]]: index, %[[ARG1:.+]]: index): - // CHECK-NEXT: hal.return %[[ARG0]], %[[ARG1]], %[[ARG0]] : index, index, index - // CHECK-NEXT: } - // CHECK: builtin.module - // CHECK-NEXT: func.func private @extern_func() - // CHECK-NEXT: func.func @entry - // CHECK: hal.executable.variant public @x86_64 target(#executable_target_x86_64 - // CHECK: hal.executable.export public @entry ordinal(0) layout(#pipeline_layout) - // CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>] - // CHECK-NEXT: ^bb0(%[[DEVICE:.+]]: !hal.device, %[[ARG0:.+]]: index, %[[ARG1:.+]]: index): - // CHECK-NEXT: hal.return %[[ARG0]], %[[ARG1]], %[[ARG0]] : index, index, index - // CHECK-NEXT: } - // CHECK: builtin.module - // CHECK-NEXT: func.func private @extern_func() +// CHECK: #pipeline_layout = #hal.pipeline.layout< +// CHECK-SAME: push_constants = 1 +// CHECK-SAME: sets = [ +// CHECK-SAME: <0, bindings = [ +// CHECK-SAME: <0, storage_buffer, ReadOnly> +// CHECK-SAME: <1, storage_buffer, ReadOnly> +// CHECK-SAME: <2, storage_buffer> - // CHECK-NEXT: func.func @entry - stream.executable private @ex { - stream.executable.export public @entry workgroups(%arg0: index, %arg1: index) -> (index, index, index) { - stream.return %arg0, %arg1, %arg0 : index, index, index - } - builtin.module { - func.func private @extern_func() - func.func @entry(%operand: i32, %arg0: !stream.binding {stream.alignment = 64 : index}, %arg1: !stream.binding {stream.alignment = 64 : index}, %arg2: !stream.binding {stream.alignment = 64 : index}) { - return - } - } +// CHECK: hal.executable private @ex +// CHECK: hal.executable.variant public @arm_64 target(#executable_target_arm_64 +// CHECK: hal.executable.export public @entry ordinal(0) layout(#pipeline_layout) +// CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>] +// CHECK-NEXT: ^bb0(%[[DEVICE:.+]]: !hal.device, %[[ARG0:.+]]: index, %[[ARG1:.+]]: index): +// CHECK-NEXT: hal.return %[[ARG0]], %[[ARG1]], %[[ARG0]] : index, index, index +// CHECK-NEXT: } +// CHECK: builtin.module +// CHECK-NEXT: func.func private @extern_func() +// CHECK-NEXT: func.func @entry +// CHECK: hal.executable.variant public @x86_64 target(#executable_target_x86_64 +// CHECK: hal.executable.export public @entry ordinal(0) layout(#pipeline_layout) +// CHECK-SAME: hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>] +// CHECK-NEXT: ^bb0(%[[DEVICE:.+]]: !hal.device, %[[ARG0:.+]]: index, %[[ARG1:.+]]: index): +// CHECK-NEXT: hal.return %[[ARG0]], %[[ARG1]], %[[ARG0]] : index, index, index +// CHECK-NEXT: } +// CHECK: builtin.module +// CHECK-NEXT: func.func private @extern_func() + +// CHECK-NEXT: func.func @entry +stream.executable private @ex { + stream.executable.export public @entry workgroups(%arg0: index, %arg1: index) -> (index, index, index) { + stream.return %arg0, %arg1, %arg0 : index, index, index } - util.func public @main(%arg0: !stream.resource, %arg1: !stream.resource, %arg2: index, %arg3: i32) -> !stream.resource { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %0 = stream.resource.alloc uninitialized : !stream.resource{%arg2} - %1 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource{%arg2}, %arg1 as %arg5: !stream.resource{%arg2}, %0 as %arg6: !stream.resource{%arg2}) { - // CHECK: stream.cmd.dispatch - // CHECK-SAME: @ex::@arm_64::@entry - // CHECK-SAME: @ex::@x86_64::@entry - stream.cmd.dispatch @ex::@entry[%c1, %c2](%arg3 : i32) { - ro %arg4[%c0 for %arg2] : !stream.resource{%arg2}, - ro %arg5[%c0 for %arg2] : !stream.resource{%arg2}, - wo %arg6[%c0 for %arg2] : !stream.resource{%arg2} - } - } => !stream.timepoint - %2 = stream.timepoint.await %1 => %0 : !stream.resource{%arg2} - util.return %2 : !stream.resource + builtin.module { + func.func private @extern_func() + func.func @entry(%operand: i32, %arg0: !stream.binding {stream.alignment = 64 : index}, %arg1: !stream.binding {stream.alignment = 64 : index}, %arg2: !stream.binding {stream.alignment = 64 : index}) { + return + } } } +util.func public @main(%arg0: !stream.resource, %arg1: !stream.resource, %arg2: index, %arg3: i32) -> !stream.resource attributes { + stream.affinity = #hal.device.affinity<@default_device> +} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0 = stream.resource.alloc uninitialized : !stream.resource{%arg2} + %1 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource{%arg2}, %arg1 as %arg5: !stream.resource{%arg2}, %0 as %arg6: !stream.resource{%arg2}) { + // CHECK: stream.cmd.dispatch + // CHECK-SAME: @ex::@arm_64::@entry + // CHECK-SAME: @ex::@x86_64::@entry + stream.cmd.dispatch @ex::@entry[%c1, %c2](%arg3 : i32) { + ro %arg4[%c0 for %arg2] : !stream.resource{%arg2}, + ro %arg5[%c0 for %arg2] : !stream.resource{%arg2}, + wo %arg6[%c0 for %arg2] : !stream.resource{%arg2} + } + } => !stream.timepoint + %2 = stream.timepoint.await %1 => %0 : !stream.resource{%arg2} + util.return %2 : !stream.resource +} // ----- // Tests that executable variants are expanded based on what devices they are // dispatched on. -module attributes { - // The default device when none is specified. - // Functions and scopes can override the target device. - hal.device.targets = [ - #hal.device.target<"cpu", [ - #hal.executable.target<"llvm-cpu", "arm_64">, - #hal.executable.target<"llvm-cpu", "x86_64"> - ]> - ] +// The default device when none is specified. +// Functions and scopes can override the target device. +util.global private @default_device = #hal.device.target<"cpu", [ + #hal.executable.target<"llvm-cpu", "arm_64">, + #hal.executable.target<"llvm-cpu", "x86_64"> +]> : !hal.device +util.global private @riscv_device = #hal.device.target<"cpu", [ + #hal.executable.target<"llvm-cpu", "riscv_32"> +]> : !hal.device + +// CHECK: hal.executable private @ex +// CHECK: hal.executable.variant public @arm_64 +// CHECK: hal.executable.variant public @riscv_32 +// CHECK: hal.executable.variant public @x86_64 +stream.executable private @ex { + stream.executable.export public @entry workgroups() -> (index, index, index) { + %c1 = arith.constant 1 : index + stream.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + func.func @entry(%arg0: !stream.binding {stream.alignment = 64 : index}) { + return + } + } +} + +// This function uses the default HAL device targeting arm_64 and x86_64. +// CHECK-LABEL: @using_default +util.func public @using_default(%arg0: !stream.resource, %arg1: index) -> !stream.timepoint attributes { + stream.affinity = #hal.device.affinity<@default_device> } { - // CHECK: hal.executable private @ex - // CHECK: hal.executable.variant public @arm_64 - // CHECK: hal.executable.variant public @riscv_32 - // CHECK: hal.executable.variant public @x86_64 - stream.executable private @ex { - stream.executable.export public @entry workgroups() -> (index, index, index) { - %c1 = arith.constant 1 : index - stream.return %c1, %c1, %c1 : index, index, index + %c0 = arith.constant 0 : index + %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource{%arg1}) { + // CHECK: stream.cmd.dispatch + // CHECK-SAME: @ex::@arm_64::@entry + // CHECK-NOT: @ex::@riscv_32::@entry + // CHECK-SAME: @ex::@x86_64::@entry + stream.cmd.dispatch @ex::@entry { + rw %arg2[%c0 for %arg1] : !stream.resource{%arg1} } - builtin.module { - func.func @entry(%arg0: !stream.binding {stream.alignment = 64 : index}) { - return - } + } => !stream.timepoint + util.return %0 : !stream.timepoint +} + +// This function is specialized to only run on only riscv_32 and should +// not get assigned the arm_64/x86_64 variant entry points. +// CHECK-LABEL: @using_specialized +util.func public @using_specialized(%arg0: !stream.resource, %arg1: index) -> !stream.timepoint attributes { + stream.affinity = #hal.device.affinity<@riscv_device> +} { + %c0 = arith.constant 0 : index + %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource{%arg1}) { + // CHECK: stream.cmd.dispatch + // CHECK-NOT: @ex::@arm_64::@entry + // CHECK-SAME: @ex::@riscv_32::@entry + // CHECK-NOT: @ex::@x86_64::@entry + stream.cmd.dispatch @ex::@entry { + rw %arg2[%c0 for %arg1] : !stream.resource{%arg1} } - } - // This function uses the default HAL device targeting arm_64 and x86_64. - // CHECK-LABEL: @using_default - util.func public @using_default(%arg0: !stream.resource, %arg1: index) -> !stream.timepoint { - %c0 = arith.constant 0 : index - %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource{%arg1}) { - // CHECK: stream.cmd.dispatch - // CHECK-SAME: @ex::@arm_64::@entry - // CHECK-NOT: @ex::@riscv_32::@entry - // CHECK-SAME: @ex::@x86_64::@entry - stream.cmd.dispatch @ex::@entry { - rw %arg2[%c0 for %arg1] : !stream.resource{%arg1} - } - } => !stream.timepoint - util.return %0 : !stream.timepoint - } - // This function is specialized to only run on only riscv_32 and should - // not get assigned the arm_64/x86_64 variant entry points. - // CHECK-LABEL: @using_specialized - util.func public @using_specialized(%arg0: !stream.resource, %arg1: index) -> !stream.timepoint attributes { - hal.device.targets = [ - #hal.device.target<"cpu", [ - #hal.executable.target<"llvm-cpu", "riscv_32"> - ]> - ] - } { - %c0 = arith.constant 0 : index - %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource{%arg1}) { - // CHECK: stream.cmd.dispatch - // CHECK-NOT: @ex::@arm_64::@entry - // CHECK-SAME: @ex::@riscv_32::@entry - // CHECK-NOT: @ex::@x86_64::@entry - stream.cmd.dispatch @ex::@entry { - rw %arg2[%c0 for %arg1] : !stream.resource{%arg1} - } - } => !stream.timepoint - util.return %0 : !stream.timepoint - } + } => !stream.timepoint + util.return %0 : !stream.timepoint } // ----- @@ -143,69 +144,68 @@ module attributes { // hand-authored code or other dialects that perform interface assignment // themselves. -module attributes { - // The default device when none is specified. - // Functions and scopes can override the target device. - hal.device.targets = [ - #hal.device.target<"cpu", [ - #hal.executable.target<"llvm-cpu", "arm_64">, - #hal.executable.target<"llvm-cpu", "x86_64"> +// The default device when none is specified. +// Functions and scopes can override the target device. +util.global private @default_device = #hal.device.target<"cpu", [ + #hal.executable.target<"llvm-cpu", "arm_64">, + #hal.executable.target<"llvm-cpu", "x86_64"> +]> : !hal.device +util.global private @riscv_device = #hal.device.target<"cpu", [ + #hal.executable.target<"llvm-cpu", "riscv_32"> +]> : !hal.device + +// CHECK: hal.executable private @ex +// CHECK: hal.executable.variant public @arm_64 +// CHECK: hal.executable.variant public @riscv_32 +// CHECK: hal.executable.variant public @x86_64 +hal.executable.source private @ex { + hal.executable.export public @entry layout(#hal.pipeline.layout ]> - ] -} { - // CHECK: hal.executable private @ex - // CHECK: hal.executable.variant public @arm_64 - // CHECK: hal.executable.variant public @riscv_32 - // CHECK: hal.executable.variant public @x86_64 - hal.executable.source private @ex { - hal.executable.export public @entry layout(#hal.pipeline.layout - ]> - ]>) - builtin.module { - func.func @entry() { - return - } + ]>) + builtin.module { + func.func @entry() { + return } } - // This function uses the default HAL device targeting arm_64 and x86_64. - // CHECK-LABEL: @using_default - util.func public @using_default(%arg0: !stream.resource, %arg1: index) -> !stream.timepoint { - %c0 = arith.constant 0 : index - %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource{%arg1}) { - // CHECK: stream.cmd.dispatch - // CHECK-SAME: @ex::@arm_64::@entry - // CHECK-NOT: @ex::@riscv_32::@entry - // CHECK-SAME: @ex::@x86_64::@entry - stream.cmd.dispatch @ex::@entry { - rw %arg2[%c0 for %arg1] : !stream.resource{%arg1} - } - } => !stream.timepoint - util.return %0 : !stream.timepoint - } - // This function is specialized to only run on only riscv_32 and should - // not get assigned the arm_64/x86_64 variant entry points. - // CHECK-LABEL: @using_specialized - util.func public @using_specialized(%arg0: !stream.resource, %arg1: index) -> !stream.timepoint attributes { - hal.device.targets = [ - #hal.device.target<"cpu", [ - #hal.executable.target<"llvm-cpu", "riscv_32"> - ]> - ] - } { - %c0 = arith.constant 0 : index - %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource{%arg1}) { - // CHECK: stream.cmd.dispatch - // CHECK-NOT: @ex::@arm_64::@entry - // CHECK-SAME: @ex::@riscv_32::@entry - // CHECK-NOT: @ex::@x86_64::@entry - stream.cmd.dispatch @ex::@entry { - rw %arg2[%c0 for %arg1] : !stream.resource{%arg1} - } - } => !stream.timepoint - util.return %0 : !stream.timepoint - } +} + +// This function uses the default HAL device targeting arm_64 and x86_64. +// CHECK-LABEL: @using_default +util.func public @using_default(%arg0: !stream.resource, %arg1: index) -> !stream.timepoint attributes { + stream.affinity = #hal.device.affinity<@default_device> +} { + %c0 = arith.constant 0 : index + %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource{%arg1}) { + // CHECK: stream.cmd.dispatch + // CHECK-SAME: @ex::@arm_64::@entry + // CHECK-NOT: @ex::@riscv_32::@entry + // CHECK-SAME: @ex::@x86_64::@entry + stream.cmd.dispatch @ex::@entry { + rw %arg2[%c0 for %arg1] : !stream.resource{%arg1} + } + } => !stream.timepoint + util.return %0 : !stream.timepoint +} + +// This function is specialized to only run on only riscv_32 and should +// not get assigned the arm_64/x86_64 variant entry points. +// CHECK-LABEL: @using_specialized +util.func public @using_specialized(%arg0: !stream.resource, %arg1: index) -> !stream.timepoint attributes { + stream.affinity = #hal.device.affinity<@riscv_device> +} { + %c0 = arith.constant 0 : index + %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource{%arg1}) { + // CHECK: stream.cmd.dispatch + // CHECK-NOT: @ex::@arm_64::@entry + // CHECK-SAME: @ex::@riscv_32::@entry + // CHECK-NOT: @ex::@x86_64::@entry + stream.cmd.dispatch @ex::@entry { + rw %arg2[%c0 for %arg1] : !stream.resource{%arg1} + } + } => !stream.timepoint + util.return %0 : !stream.timepoint } // ----- @@ -213,14 +213,15 @@ module attributes { // Tests that a hal.executable.source op gets expanded to all default targets // when it's public in addition to any ones from dispatch sites. -module attributes { - hal.device.targets = [ - #hal.device.target<"cpu", [ - #hal.executable.target<"llvm-cpu", "arm_64">, - #hal.executable.target<"llvm-cpu", "x86_64"> - ]> - ] -} { +module { + util.global private @primary_device = #hal.device.target<"cpu", [ + #hal.executable.target<"llvm-cpu", "arm_64">, + #hal.executable.target<"llvm-cpu", "x86_64"> + ]> : !hal.device + util.global private @riscv_device = #hal.device.target<"cpu", [ + #hal.executable.target<"llvm-cpu", "riscv_32"> + ]> : !hal.device + // CHECK: hal.executable public @ex // CHECK: hal.executable.variant public @arm_64 // CHECK: hal.executable.variant public @riscv_32 @@ -239,11 +240,7 @@ module attributes { } // CHECK-LABEL: @using_specialized util.func public @using_specialized(%arg0: !stream.resource, %arg1: index) -> !stream.timepoint attributes { - hal.device.targets = [ - #hal.device.target<"cpu", [ - #hal.executable.target<"llvm-cpu", "riscv_32"> - ]> - ] + stream.affinity = #hal.device.affinity<@riscv_device> } { %c0 = arith.constant 0 : index %0 = stream.cmd.execute with(%arg0 as %arg2: !stream.resource{%arg1}) { From 7e35749f124e67e7da92796f62ca2f5de05b5d7d Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 6 Mar 2024 17:50:20 -0800 Subject: [PATCH 07/25] Wiring up AssignTargetDevices and associated passes. This materializes device globals early on and sets the affinity so that all following passes can assume the affinity exists. --- .../Torch/InputConversion/FuncConversion.cpp | 1 + .../Native/Transforms/WrapEntryPoints.cpp | 2 + .../TFLite/Transforms/WrapEntryPoints.cpp | 2 + .../iree/compiler/Dialect/HAL/IR/HALTypes.h | 3 +- .../HAL/Transforms/AssignTargetDevices.cpp | 67 ++++---- .../Dialect/HAL/Transforms/BUILD.bazel | 3 +- .../Dialect/HAL/Transforms/CMakeLists.txt | 3 +- .../Transforms/MaterializeTargetDevices.cpp | 103 ++++++++++++ .../Dialect/HAL/Transforms/Passes.cpp | 47 ++++-- .../compiler/Dialect/HAL/Transforms/Passes.h | 12 +- .../compiler/Dialect/HAL/Transforms/Passes.td | 47 ++++-- .../Dialect/HAL/Transforms/VerifyDevices.cpp | 154 ++++++++++++++++++ .../Transforms/VerifyTargetEnvironment.cpp | 120 -------------- .../Dialect/HAL/Transforms/test/BUILD.bazel | 3 +- .../HAL/Transforms/test/CMakeLists.txt | 3 +- .../test/assign_target_devices.mlir | 11 ++ .../test/materialize_target_devices.mlir | 58 +++++++ .../HAL/Transforms/test/verify_devices.mlir | 58 +++++++ .../test/verify_target_environment.mlir | 54 ------ .../Common/IREEImportPublic.cpp | 6 +- .../Modules/HAL/Inline/Transforms/Passes.cpp | 2 + .../Modules/HAL/Loader/Transforms/Passes.cpp | 2 + .../src/iree/compiler/Pipelines/Pipelines.cpp | 6 +- 23 files changed, 511 insertions(+), 256 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp delete mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyTargetEnvironment.cpp create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir delete mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_target_environment.mlir diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp index d9541a723276..4bdd5c468729 100644 --- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp +++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp @@ -439,6 +439,7 @@ void retainFunctionAttributes(Operation *srcOp, IREE::Util::FuncOp destOp) { // Allowlist of function attributes to retain when importing funcs. constexpr const char *kRetainedAttributes[] = { "iree.reflection", + "stream.affinity", }; auto retainedAttributes = ArrayRef( kRetainedAttributes, diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp index a5fd2fe77e1b..bd2702f5f87e 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp @@ -494,6 +494,8 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, // Populate the reflection attrs based on the original types. populateReflectionAttrs(invocationModel, exportOp, wrapperOp); exportOp->removeAttr("iree.reflection"); + if (auto affinityAttr = exportOp->getAttr("stream.affinity")) + wrapperOp->setAttr("stream.affinity", affinityAttr); auto *entryBlock = wrapperOp.addEntryBlock(); auto entryBuilder = OpBuilder::atBlockBegin(entryBlock); diff --git a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp index 1ec24e84d8a6..3e8c0fdb8d71 100644 --- a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp @@ -499,6 +499,8 @@ class WrapEntryPointsPass wrapperFuncOp.setAllResultAttrs(resultAttrDict); populateReflectionAttrs(entryFuncOp, wrapperFuncOp); + if (auto affinityAttr = entryFuncOp->getAttr("stream.affinity")) + wrapperFuncOp->setAttr("stream.affinity", affinityAttr); // Call the entryFuncOp and return the results. // If we wanted to perform additional work here to invalidate cached shapes diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h index c7b7344fd47f..ef6417dac598 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h @@ -120,7 +120,8 @@ struct DescriptorSetLayoutType struct DeviceType : public Type::TypeBase { + mlir::OpTrait::IREE::Util::ImplicitlyCaptured, + IREE::Util::ReferenceTypeInterface::Trait> { using Base::Base; static constexpr StringLiteral name = "hal.device"; diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp index 7e0b5d4cf161..ddcc2ef70aa6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp @@ -36,66 +36,61 @@ struct AssignTargetDevicesPass AssignTargetDevicesPass> { using IREE::HAL::impl::AssignTargetDevicesPassBase< AssignTargetDevicesPass>::AssignTargetDevicesPassBase; - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - for (auto &targetBackend : targetRegistry->getTargetBackends( - targetRegistry->getRegisteredTargetBackends())) { - targetBackend->getDependentDialects(registry); - } - } void runOnOperation() override { auto moduleOp = getOperation(); - // Check to see if targets are already specified. - auto existingTargetsAttr = - moduleOp->getAttrOfType("hal.device.targets"); - if (existingTargetsAttr) { - // Targets already exist on the module; no-op the pass so that we don't - // mess with whatever the user intended. + // If no targets are specified we can't do anything - another pass earlier + // in the pipeline will have had to add the targets. + if (targetDevices.empty()) { return; } - // If no targets are specified we can't do anything - another pass earlier - // in the pipeline will have had to add the targets. - if (targetBackends.empty()) { - emitRemark(moduleOp.getLoc()) - << "no target HAL target backends specified during assignment"; + // Check to see if targets are already specified and if so then no-op the + // pass so that we don't mess with whatever the user intended. + if (moduleOp->hasAttr("hal.device.targets")) { return; } + // If there are any device globals declared then bail as it means the user + // has already materialized the devices they want. + for (auto globalOp : moduleOp.getOps()) { + if (isa(globalOp.getGlobalType())) + return; + } + llvm::SmallDenseSet targetAttrSet; SmallVector targetAttrs; for (const auto &targetBackendName : targetBackends) { auto targetBackend = targetRegistry->getTargetBackend(targetBackendName); if (!targetBackend) { - std::string backends; - llvm::raw_string_ostream os(backends); - llvm::interleaveComma(targetRegistry->getRegisteredTargetBackends(), os, - [&os](const std::string &name) { os << name; }); - emitError(moduleOp.getLoc()) - << "target backend '" << targetBackendName - << "' not registered; registered backends: " << os.str(); - signalPassFailure(); - return; + auto diagnostic = emitError(moduleOp.getLoc()) + << "target backend '" << targetBackendName + << "' not registered; registered backends: ["; + llvm::interleaveComma(targetRegistry->getRegisteredTargetBackends(), + diagnostic); + diagnostic << "]"; + return signalPassFailure(); } auto targetDeviceName = targetBackend->getLegacyDefaultDeviceID(); auto targetDevice = targetRegistry->getTargetDevice(targetDeviceName); if (!targetDevice) { - std::string devices; - llvm::raw_string_ostream os(devices); - llvm::interleaveComma(targetRegistry->getRegisteredTargetDevices(), os, - [&os](const std::string &name) { os << name; }); - emitError(moduleOp.getLoc()) - << "target device '" << targetDeviceName - << "' not registered; registered devices: " << os.str(); - signalPassFailure(); - return; + auto diagnostic = emitError(moduleOp.getLoc()) + << "target device '" << targetDeviceName + << "' not registered; registered devices: ["; + llvm::interleaveComma(targetRegistry->getRegisteredTargetDevices(), + diagnostic); + diagnostic << "]"; + return signalPassFailure(); } // Ask the target backend for its default device specification attribute. auto targetAttr = targetDevice->getDefaultDeviceTarget( moduleOp.getContext(), *targetRegistry.value); + if (!targetAttr) { + emitError(moduleOp.getLoc()) << "no default device targets available"; + return signalPassFailure(); + } if (!targetAttrSet.contains(targetAttr)) { targetAttrSet.insert(targetAttr); targetAttrs.push_back(targetAttr); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel index 63be1d24f00d..3063a8d5a970 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel @@ -28,6 +28,7 @@ iree_compiler_cc_library( "MaterializeDispatchInstrumentation.cpp", "MaterializeInterfaces.cpp", "MaterializeResourceCaches.cpp", + "MaterializeTargetDevices.cpp", "MemoizeDeviceQueries.cpp", "Passes.cpp", "Passes.h.inc", @@ -39,7 +40,7 @@ iree_compiler_cc_library( "StripExecutableContents.cpp", "SubstituteExecutables.cpp", "TranslateExecutables.cpp", - "VerifyTargetEnvironment.cpp", + "VerifyDevices.cpp", ], hdrs = [ "Passes.h", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt index 382525e66958..50ce275ad91b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt @@ -29,6 +29,7 @@ iree_cc_library( "MaterializeDispatchInstrumentation.cpp" "MaterializeInterfaces.cpp" "MaterializeResourceCaches.cpp" + "MaterializeTargetDevices.cpp" "MemoizeDeviceQueries.cpp" "Passes.cpp" "Passes.h.inc" @@ -40,7 +41,7 @@ iree_cc_library( "StripExecutableContents.cpp" "SubstituteExecutables.cpp" "TranslateExecutables.cpp" - "VerifyTargetEnvironment.cpp" + "VerifyDevices.cpp" DEPS ::PassesIncGen LLVMSupport diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp new file mode 100644 index 000000000000..904e20aa95af --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp @@ -0,0 +1,103 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::HAL { + +#define GEN_PASS_DEF_MATERIALIZETARGETDEVICESPASS +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// --iree-hal-materialize-target-devices +//===----------------------------------------------------------------------===// + +struct MaterializeTargetDevicesPass + : public IREE::HAL::impl::MaterializeTargetDevicesPassBase< + MaterializeTargetDevicesPass> { + using IREE::HAL::impl::MaterializeTargetDevicesPassBase< + MaterializeTargetDevicesPass>::MaterializeTargetDevicesPassBase; + + void runOnOperation() override { + auto moduleOp = getOperation(); + + // Only run if there's a module-level attribute specified. + auto deviceTargetAttrs = + moduleOp->getAttrOfType("hal.device.targets"); + if (!deviceTargetAttrs || deviceTargetAttrs.empty()) + return; + moduleOp->removeAttr("hal.device.targets"); + + // Create the default device global. + auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); + auto deviceType = moduleBuilder.getType(); + auto globalOp = moduleBuilder.create( + moduleOp.getLoc(), "__device.0", /*isMutable=*/false, deviceType); + globalOp.setPrivate(); + if (deviceTargetAttrs.size() == 1) { + auto typedAttr = + dyn_cast(deviceTargetAttrs.getValue().front()); + if (typedAttr && isa(typedAttr.getType())) { + globalOp.setInitialValueAttr(typedAttr); + } else { + moduleOp.emitOpError() + << "has invalid device targets specified; " + "expect hal.device.targets to be an " + "ArrayAttr of !hal.device initialization attributes"; + return signalPassFailure(); + } + } else { + globalOp.setInitialValueAttr( + moduleBuilder.getAttr( + deviceType, deviceTargetAttrs)); + } + + // Assign affinities to all top level ops that don't already have one set. + auto affinityName = StringAttr::get(&getContext(), "stream.affinity"); + auto affinityAttr = moduleBuilder.getAttr( + FlatSymbolRefAttr::get(globalOp), /*queue_mask=*/-1ll); + auto isAnnotatableType = [](Type type) { + return isa(type) || isa(type); + }; + for (auto &op : moduleOp.getOps()) { + bool shouldAnnotate = true; + if (auto globalOp = dyn_cast(op)) { + if (!isAnnotatableType(globalOp.getGlobalType())) + shouldAnnotate = false; + } else if (op.hasTrait()) { + // Symbol table ops can't reference parent symbols properly. + shouldAnnotate = false; + } + if (!shouldAnnotate) + continue; + if (auto affinityOp = dyn_cast(op)) { + if (!affinityOp.getAffinity()) + affinityOp.setAffinity(affinityAttr); + } else { + if (!op.hasAttr(affinityName)) + op.setAttr(affinityName, affinityAttr); + } + } + } +}; + +} // namespace + +} // namespace mlir::iree_compiler::IREE::HAL diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index 12f2c7bf3252..f2752c953097 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -180,6 +180,28 @@ static void addExecutableSubstitutionPasses(OpPassManager &passManager, } } +//===----------------------------------------------------------------------===// +// --iree-hal-device-assignment-pipeline +//===----------------------------------------------------------------------===// + +void buildHALDeviceAssignmentPassPipeline(OpPassManager &passManager, + const TargetRegistry &targetRegistry, + const TargetOptions &targetOptions) { + // The HAL must know its targets early on in the process. This pass discovers/ + // derives/specifies the target devices and annotates the module with that + // information. This allows subsequent passes to lookup which devices they are + // targeting. + if (!targetOptions.targets.empty()) { + // Today we just assign devices from parameters but we should instead be + // performing analysis at the flow level and then doing magic device + // database lookups here. + passManager.addPass(IREE::HAL::createAssignTargetDevicesPass( + {&targetRegistry, targetOptions.targets})); + } + passManager.addPass(IREE::HAL::createMaterializeTargetDevicesPass()); + passManager.addPass(IREE::HAL::createVerifyDevicesPass({&targetRegistry})); +} + //===----------------------------------------------------------------------===// // --iree-hal-configuration-pipeline //===----------------------------------------------------------------------===// @@ -197,23 +219,9 @@ void buildHALConfigurationPassPipeline(OpPassManager &passManager, addCleanupPatterns(passManager); //---------------------------------------------------------------------------- - // Device assignment and interface materialization + // Device-specific interface materialization //---------------------------------------------------------------------------- - // The HAL must know its targets early on in the process. This pass discovers/ - // derives/specifies the target devices and annotates the module with that - // information. This allows subsequent passes to lookup which devices they are - // targeting. - if (!targetOptions.targets.empty()) { - // Today we just assign devices from parameters but we should instead be - // performing analysis at the flow level and then doing magic device - // database lookups here. - passManager.addPass(IREE::HAL::createAssignTargetDevicesPass( - {&targetRegistry, targetOptions.targets})); - } - passManager.addPass( - IREE::HAL::createVerifyTargetEnvironmentPass({targetRegistry})); - // Add dispatch instrumentation prior to materializing interfaces so we can // more easily mutate the stream dispatch ops and exports. if (auto bufferSize = clInstrumentDispatchBufferSize.getValue()) { @@ -275,6 +283,8 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, hooks.beforePhase(PipelinePhase::ExecutableSources, passManager); if (compileFrom < PipelinePhase::ExecutableSources) { + buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry, + targetOptions); buildHALConfigurationPassPipeline(passManager, targetRegistry, targetOptions, hooks); @@ -566,6 +576,13 @@ void registerHALPasses() { registerPasses(); // Pipelines. + PassPipelineRegistration<>("iree-hal-device-assignment-pipeline", + "Runs HAL target device assignment pipeline.", + [](OpPassManager &passManager) { + buildHALDeviceAssignmentPassPipeline( + passManager, TargetRegistry::getGlobal(), + TargetOptions::FromFlags::get()); + }); PassPipelineRegistration<>("iree-hal-configuration-pipeline", "Runs HAL target configuration pipeline.", [](OpPassManager &passManager) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h index ec5c9c56dc69..d09080bf5a2c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h @@ -46,9 +46,17 @@ struct PipelineHooks { std::function afterPhase; }; +// Assigns devices from flags and coarse module-level specification. +// Frontends are encouraged to create and assign devices themselves in order to +// support more complex configurations (multiple devices, fallbacks, etc). +void buildHALDeviceAssignmentPassPipeline(OpPassManager &passManager, + const TargetRegistry &targetRegistry, + const TargetOptions &targetOptions); + // Adds a set of passes to the given pass manager that run the head of the HAL -// pipeline to assign devices, materialize interfaces, and translate -// executables. The host portion of the program is annotated but not modified. +// pipeline to materialize interfaces, import externally specified executables, +// and translate executables. The host portion of the program is annotated but +// not modified. void buildHALConfigurationPassPipeline(OpPassManager &passManager, const TargetRegistry &targetRegistry, const TargetOptions &targetOptions, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td index 848acc78a477..cdd8d28f5884 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td @@ -42,14 +42,11 @@ def ConvertToHALPass : // Device management //===----------------------------------------------------------------------===// -def VerifyTargetEnvironmentPass : - Pass<"iree-hal-verify-target-environment", "mlir::ModuleOp"> { - let summary = "Verifies that the target execution environment is valid."; +def AssignTargetDevicesPass : + Pass<"iree-hal-assign-target-devices", "mlir::ModuleOp"> { + let summary = "Assigns the HAL devices the module will target to the given list of targets."; let description = [{ - Verifies that the target execution environment is valid. - `#hal.device.target` and `#hal.executable.target` attribute placement and - definition will be checked that they reference the available target backends - and that they are structurally valid. + Assigns target HAL devices to the module based on the given list. }]; let options = [ Option< @@ -57,14 +54,37 @@ def VerifyTargetEnvironmentPass : "llvm::cl::TargetRegistryRef", "", "Target registry containing the list of available devices and backends." >, + ListOption< + "targetBackends", "targetBackends", + "std::string", + "List of target backends to assign as device targets." + >, + ]; + let dependentDialects = [ + "IREE::HAL::HALDialect", ]; } -def AssignTargetDevicesPass : - Pass<"iree-hal-assign-target-devices", "mlir::ModuleOp"> { - let summary = "Assigns the HAL devices the module will target to the given list of targets."; +def MaterializeTargetDevicesPass : + Pass<"iree-hal-materialize-target-devices", "mlir::ModuleOp"> { + let summary = "Materializes global device handles based on a `hal.device.targets` spec."; let description = [{ - Assigns target HAL devices to the module based on the given list. + Materializes a global `!hal.device` for the devices specified by the + `hal.device.targets` attribute on the module. It's preferred that frontends + provide IR with the globals assigned as this only supports a single device. + }]; + let dependentDialects = [ + "IREE::HAL::HALDialect", + "IREE::Util::UtilDialect", + ]; +} + +def VerifyDevicesPass : + Pass<"iree-hal-verify-devices", "mlir::ModuleOp"> { + let summary = "Verifies that all devices can be targeted with the available compiler plugins."; + let description = [{ + Verifies that `#hal.device.target` and `#hal.executable.target` attributes + reference targets that are registered with the compiler. }]; let options = [ Option< @@ -72,11 +92,6 @@ def AssignTargetDevicesPass : "llvm::cl::TargetRegistryRef", "", "Target registry containing the list of available devices and backends." >, - ListOption< - "targetBackends", "targetBackends", - "std::string", - "List of target backends to assign as device targets." - >, ]; } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp new file mode 100644 index 000000000000..21f37543f36a --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp @@ -0,0 +1,154 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::HAL { + +#define GEN_PASS_DEF_VERIFYDEVICESPASS +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// --iree-hal-verify-devices +//===----------------------------------------------------------------------===// + +static void printAvailable(InFlightDiagnostic &diagnostic, + const TargetRegistry &targetRegistry) { + diagnostic << "available devices: ["; + llvm::interleaveComma(targetRegistry.getRegisteredTargetDevices(), + diagnostic); + diagnostic << "], available backends = ["; + llvm::interleaveComma(targetRegistry.getRegisteredTargetBackends(), + diagnostic); + diagnostic << "]"; +} + +static LogicalResult +verifyDeviceTargetAttr(Operation *deviceOp, + IREE::HAL::DeviceTargetAttr deviceTargetAttr, + const TargetRegistry &targetRegistry) { + auto targetDevice = + targetRegistry.getTargetDevice(deviceTargetAttr.getDeviceID().getValue()); + if (!targetDevice) { + auto diagnostic = deviceOp->emitError(); + diagnostic << "unregistered target device " + << deviceTargetAttr.getDeviceID() + << "; ensure it is linked in to the compiler (available = [ "; + for (const auto &targetName : targetRegistry.getRegisteredTargetDevices()) { + diagnostic << "'" << targetName << "' "; + } + diagnostic << "])"; + return diagnostic; + } + + for (auto executableTargetAttr : deviceTargetAttr.getExecutableTargets()) { + auto targetBackend = targetRegistry.getTargetBackend( + executableTargetAttr.getBackend().getValue()); + if (!targetBackend) { + auto diagnostic = deviceOp->emitError(); + diagnostic << "unregistered target backend " + << executableTargetAttr.getBackend() + << "; ensure it is linked in to the compiler (available = [ "; + for (const auto &targetName : + targetRegistry.getRegisteredTargetBackends()) { + diagnostic << "'" << targetName << "' "; + } + diagnostic << "])"; + return diagnostic; + } + } + + return success(); +} + +static LogicalResult verifyAttr(Operation *deviceOp, Attribute attr, + const TargetRegistry &targetRegistry) { + return TypeSwitch(attr) + .Case([&](auto deviceTargetAttr) { + return verifyDeviceTargetAttr(deviceOp, deviceTargetAttr, + targetRegistry); + }) + .Case([&](auto deviceSelectAttr) { + for (auto attr : deviceSelectAttr.getDevices().getValue()) { + if (failed(verifyAttr(deviceOp, attr, targetRegistry))) { + return failure(); + } + } + return success(); + }) + .Default([&](auto attr) { + return success(); // probably fallback/ordinal/etc - can't verify + }); +} + +struct VerifyDevicesPass + : public IREE::HAL::impl::VerifyDevicesPassBase { + using IREE::HAL::impl::VerifyDevicesPassBase< + VerifyDevicesPass>::VerifyDevicesPassBase; + void runOnOperation() override { + auto moduleOp = getOperation(); + + // Devices are required if we need to convert host code or executables. + // If we only have hal.executables as input then we can bypass this. + // We could extend this check to be a bit smarter at the risk of false + // negatives - today this is just handling the standalone hal.executable + // compilation workflow. + bool anyNonExecutableOps = false; + for (auto &op : moduleOp.getOps()) { + if (!isa(op)) { + anyNonExecutableOps = true; + break; + } + } + if (!anyNonExecutableOps) { + return; + } + + // Analyze the module to find all devices. + DeviceAnalysis deviceAnalysis(moduleOp); + if (failed(deviceAnalysis.run())) { + return signalPassFailure(); + } + + // Must have at least one device specified. + if (deviceAnalysis.getDeviceGlobals().empty()) { + auto diagnostic = moduleOp.emitError(); + diagnostic + << "no HAL devices defined in the module; use the module-level " + "hal.device.targets attribute, the --iree-hal-target-device= " + "flags, or provide inputs with global !hal.devices defined; "; + printAvailable(diagnostic, *targetRegistry.value); + return signalPassFailure(); + } + + // Walk all devices and verify them. + for (auto deviceOp : deviceAnalysis.getDeviceGlobals()) { + if (auto initialValue = deviceOp.getGlobalInitialValue()) { + if (failed(verifyAttr(deviceOp, initialValue, *targetRegistry.value))) { + return signalPassFailure(); + } + } + } + } +}; + +} // namespace + +} // namespace mlir::iree_compiler::IREE::HAL diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyTargetEnvironment.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyTargetEnvironment.cpp deleted file mode 100644 index 7362b92570af..000000000000 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyTargetEnvironment.cpp +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include -#include - -#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" -#include "iree/compiler/Dialect/HAL/IR/HALOps.h" -#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h" -#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" -#include "iree/compiler/Dialect/HAL/Transforms/Passes.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/Pass/Pass.h" - -namespace mlir::iree_compiler::IREE::HAL { - -#define GEN_PASS_DEF_VERIFYTARGETENVIRONMENTPASS -#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc" - -namespace { - -//===----------------------------------------------------------------------===// -// --iree-hal-verify-target-environment -//===----------------------------------------------------------------------===// - -struct VerifyTargetEnvironmentPass - : public IREE::HAL::impl::VerifyTargetEnvironmentPassBase< - VerifyTargetEnvironmentPass> { - using IREE::HAL::impl::VerifyTargetEnvironmentPassBase< - VerifyTargetEnvironmentPass>::VerifyTargetEnvironmentPassBase; - void runOnOperation() override { - auto moduleOp = getOperation(); - - // Targets are required if we need to convert host code or executables. - // If we only have hal.executables as input then we can bypass this. - // We could extend this check to be a bit smarter at the risk of false - // negatives - today this is just handling the standalone hal.executable - // compilation workflow. - bool anyNonExecutableOps = false; - for (auto &op : moduleOp.getOps()) { - if (!isa(op)) { - anyNonExecutableOps = true; - break; - } - } - if (!anyNonExecutableOps) - return; - - // Must have targets specified. - auto targetsAttr = moduleOp->getAttrOfType("hal.device.targets"); - if (!targetsAttr || targetsAttr.empty()) { - auto diagnostic = moduleOp.emitError(); - diagnostic - << "no HAL target devices specified on the module (available = [ "; - for (const auto &targetName : - targetRegistry->getRegisteredTargetBackends()) { - diagnostic << "'" << targetName << "' "; - } - diagnostic << "])"; - signalPassFailure(); - return; - } - - // Verify each target is registered. - for (auto attr : targetsAttr) { - auto deviceTargetAttr = llvm::dyn_cast(attr); - if (!deviceTargetAttr) { - moduleOp.emitError() << "invalid target attr type: " << attr; - signalPassFailure(); - return; - } - - auto targetDevice = targetRegistry->getTargetDevice( - deviceTargetAttr.getDeviceID().getValue()); - if (!targetDevice) { - auto diagnostic = moduleOp.emitError(); - diagnostic - << "unregistered target device " << deviceTargetAttr.getDeviceID() - << "; ensure it is linked in to the compiler (available = [ "; - for (const auto &targetName : - targetRegistry->getRegisteredTargetDevices()) { - diagnostic << "'" << targetName << "' "; - } - diagnostic << "])"; - signalPassFailure(); - return; - } - - for (auto executableTargetAttr : - deviceTargetAttr.getExecutableTargets()) { - auto targetBackend = targetRegistry->getTargetBackend( - executableTargetAttr.getBackend().getValue()); - if (!targetBackend) { - auto diagnostic = moduleOp.emitError(); - diagnostic - << "unregistered target backend " - << executableTargetAttr.getBackend() - << "; ensure it is linked in to the compiler (available = [ "; - for (const auto &targetName : - targetRegistry->getRegisteredTargetBackends()) { - diagnostic << "'" << targetName << "' "; - } - diagnostic << "])"; - signalPassFailure(); - return; - } - } - } - } -}; - -} // namespace - -} // namespace mlir::iree_compiler::IREE::HAL diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel index 1bd2bb03c899..5e84ae3fa741 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel @@ -27,6 +27,7 @@ iree_lit_test_suite( "materialize_dispatch_instrumentation.mlir", "materialize_interfaces.mlir", "materialize_resource_caches.mlir", + "materialize_target_devices.mlir", "memoize_device_queries.mlir", "preprocess_executables.mlir", "prune_executables.mlir", @@ -34,7 +35,7 @@ iree_lit_test_suite( "resolve_export_ordinals.mlir", "strip_executable_contents.mlir", "substitute_executables.mlir", - "verify_target_environment.mlir", + "verify_devices.mlir", ], include = ["*.mlir"], exclude = ["substitute_executables_replacement.mlir"], diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt index d4ea9486ebdc..8079246613d0 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt @@ -25,6 +25,7 @@ iree_lit_test_suite( "materialize_dispatch_instrumentation.mlir" "materialize_interfaces.mlir" "materialize_resource_caches.mlir" + "materialize_target_devices.mlir" "memoize_device_queries.mlir" "preprocess_executables.mlir" "prune_executables.mlir" @@ -32,7 +33,7 @@ iree_lit_test_suite( "resolve_export_ordinals.mlir" "strip_executable_contents.mlir" "substitute_executables.mlir" - "verify_target_environment.mlir" + "verify_devices.mlir" TOOLS FileCheck iree-opt diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir index 46c0a5e683c5..e10b6494c00d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir @@ -29,3 +29,14 @@ module @module {} module @module attributes { hal.device.targets = [#hal.device.target<"foo">] } {} + +// ----- + +// The pass does nothing when one or more devices has already been defined. + +// CHECK: module @module +// CHECK-NOT: hal.device.targets +module @module { + // CHECK: @existing_device + util.global private @existing_device : !hal.device +} diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir new file mode 100644 index 000000000000..d5d011c649b9 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir @@ -0,0 +1,58 @@ +// RUN: iree-opt --split-input-file --iree-hal-materialize-target-devices %s --verify-diagnostics | FileCheck %s + +// expected-error@+1 {{invalid device targets specified}} +module @module attributes { + hal.device.targets = [ + "wrong_type" + ] +} { + util.func private @func() -> () +} + +// ----- + +// Valid input with proper attributes. + +// CHECK: #device_target_llvm_cpu = #hal.device.target<"llvm-cpu"> +#device_target_llvm_cpu = #hal.device.target<"llvm-cpu"> +// CHECK: #device_target_vmvx = #hal.device.target<"vmvx"> +#device_target_vmvx = #hal.device.target<"vmvx"> + +// CHECK: module @module +// CHECK-NOT: hal.device.targets +module @module attributes { + hal.device.targets = [ + #device_target_llvm_cpu, + #device_target_vmvx + ] +} { + // CHECK: util.global private @__device.0 = #hal.device.select<[ + // CHECK-SAME: #device_target_llvm_cpu, + // CHECK-SAME: #device_target_vmvx + // CHECK-SAME: ]> : !hal.device + + // CHECK: util.global private @tensor_global + // CHECK-SAME: stream.affinity = #hal.device.affinity<@__device_0> + util.global private @tensor_global : tensor<4xf32> + + // CHECK: util.global private @primitive_global + // CHECK-NOT: stream.affinity + util.global private @primitive_global : i32 + + // CHECK: util.func private @func + // CHECK-SAME: stream.affinity = #hal.device.affinity<@__device_0> + util.func private @func() -> () +} + +// ----- + +// Modules without anything that needs an environment are OK. + +// CHECK: module @module +module @module { + // CHECK-NEXT: hal.executable private @exe + hal.executable private @exe { + // CHECK-NEXT: hal.executable.variant public @embedded_elf_arm_64 + hal.executable.variant public @embedded_elf_arm_64 target(#hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {}>) {} + } +} diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir new file mode 100644 index 000000000000..4511a0becdcb --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir @@ -0,0 +1,58 @@ +// RUN: iree-opt --split-input-file --iree-hal-verify-devices %s --mlir-print-local-scope --verify-diagnostics | FileCheck %s + +// expected-error@+1 {{no HAL devices defined in the module}} +module @module { + util.func private @func() -> () +} + +// ----- + +module @module { + // expected-error@+1 {{unregistered target device "__unregistered__"}} + util.global private @device = #hal.device.target<"__unregistered__"> : !hal.device + util.func private @func() -> () attributes { + stream.affinity = #hal.device.affinity<@device> + } +} + +// ----- + +module @module { + // expected-error@+1 {{unregistered target device "__unregistered__"}} + util.global private @device = #hal.device.select<[ + #hal.device.target<"vmvx"> : !hal.device, + #hal.device.target<"__unregistered__"> : !hal.device + ]> : !hal.device + util.func private @func() -> () attributes { + stream.affinity = #hal.device.affinity<@device> + } +} + +// ----- + +// Valid input with proper attributes. + +// CHECK: module @module +module @module { + util.global private @device = #hal.device.target<"vmvx"> : !hal.device + util.global private @optional = #hal.device.fallback<@device> : !hal.device + util.global private @ordinal = #hal.device.ordinal<0> : !hal.device + util.global private @selected = #hal.device.select<[ + #hal.device.target<"llvm-cpu"> : !hal.device, + #hal.device.target<"vmvx"> : !hal.device + ]> : !hal.device + util.func private @func() -> () attributes { + stream.affinity = #hal.device.affinity<@device> + } +} + +// ----- + +// Modules without anything that needs an environment are OK. + +// CHECK: module @module +module @module { + hal.executable private @exe { + hal.executable.variant public @embedded_elf_arm_64 target(#hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {}>) {} + } +} diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_target_environment.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_target_environment.mlir deleted file mode 100644 index 81f1e1ce813d..000000000000 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_target_environment.mlir +++ /dev/null @@ -1,54 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-hal-verify-target-environment %s --verify-diagnostics | FileCheck %s - -// expected-error@+1 {{no HAL target devices specified}} -module @module { - util.func private @func() -> () -} - -// ----- - -// expected-error@+1 {{no HAL target devices specified}} -module @module attributes {hal.device.targets = []} { - util.func private @func() -> () -} - -// ----- - -// expected-error@+1 {{invalid target attr type}} -module @module attributes {hal.device.targets = ["wrong_type"]} { - util.func private @func() -> () -} - -// ----- - -// expected-error@+1 {{unregistered target device "foo"}} -module @module attributes {hal.device.targets = [#hal.device.target<"foo">]} { - util.func private @func() -> () -} - -// ----- - -// Valid input with proper attributes. - -// CHECK: #device_target_vmvx = #hal.device.target<"vmvx"> -#device_target_vmvx = #hal.device.target<"vmvx"> - -// CHECK: module @module attributes {hal.device.targets = [#device_target_vmvx]} -module @module attributes {hal.device.targets = [#device_target_vmvx]} { - util.func private @func() -> () -} - -// ----- - -// Modules without anything that needs an environment are OK. - -#executable_target = #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {}> - -// CHECK: module @module -module @module { - // CHECK-NEXT: hal.executable private @exe - hal.executable private @exe { - // CHECK-NEXT: hal.executable.variant public @embedded_elf_arm_64 - hal.executable.variant public @embedded_elf_arm_64 target(#executable_target) {} - } -} diff --git a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp index f6162eb283d5..3200829635bf 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp @@ -349,10 +349,8 @@ class FuncFuncOpPattern : public OpConversionPattern { // Allowlist of function attributes to retain when importing funcs. constexpr const char *kRetainedAttributes[] = { - "iree.reflection", - "vm.fallback", - "vm.signature", - "vm.version", + "iree.reflection", "stream.affinity", "vm.fallback", + "vm.signature", "vm.version", }; auto retainedAttributes = ArrayRef( kRetainedAttributes, diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp index 2c11a333d6a6..f7a8276e9da9 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp @@ -53,6 +53,8 @@ void buildHALInlineStaticTransformPassPipeline( // Device assignment and interface materialization //---------------------------------------------------------------------------- + IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry, + targetOptions); IREE::HAL::buildHALConfigurationPassPipeline(passManager, targetRegistry, targetOptions); diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp index 47f1bcdf6b73..370d0abedc16 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp @@ -53,6 +53,8 @@ void buildHALInlineDynamicTransformPassPipeline( // Device assignment and interface materialization //---------------------------------------------------------------------------- + IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry, + targetOptions); IREE::HAL::buildHALConfigurationPassPipeline(passManager, targetRegistry, targetOptions); diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp index d7a155b90807..0052ed868bce 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp @@ -84,10 +84,8 @@ void buildIREEPrecompileTransformPassPipeline( // IR so that they are available for all passes that may want to use this // information. If trying to compile in a generic mode the user should omit // specifying targets. - if (!executableOptions.targets.empty()) { - passManager.addPass(IREE::HAL::createAssignTargetDevicesPass( - {&targetRegistry, executableOptions.targets})); - } + IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry, + executableOptions); // Input pipelines can result in changes to the exported functions and types // and must run before generating bindings. From 3af1211b127f6c34bd8bd9a85b8595d9a23b4161 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 7 Mar 2024 08:40:54 -0800 Subject: [PATCH 08/25] Adding #hal.device.promise and a resolution pass. This allows for devices to be referenced prior to materialization. --- .../iree/compiler/Dialect/HAL/IR/HALAttrs.cpp | 107 ++++++++++++ .../iree/compiler/Dialect/HAL/IR/HALAttrs.td | 46 ++++++ .../Dialect/HAL/IR/test/attributes.mlir | 11 ++ .../Dialect/HAL/Transforms/BUILD.bazel | 1 + .../Dialect/HAL/Transforms/CMakeLists.txt | 1 + .../Dialect/HAL/Transforms/Passes.cpp | 1 + .../compiler/Dialect/HAL/Transforms/Passes.td | 12 ++ .../HAL/Transforms/ResolveDevicePromises.cpp | 155 ++++++++++++++++++ .../Dialect/HAL/Transforms/test/BUILD.bazel | 1 + .../HAL/Transforms/test/CMakeLists.txt | 1 + .../test/resolve_device_promises.mlir | 43 +++++ 11 files changed, 379 insertions(+) create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/ResolveDevicePromises.cpp create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_device_promises.mlir diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp index 69e08f0e5377..8460588b4c31 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp @@ -1093,6 +1093,113 @@ bool DeviceAffinityAttr::isLegalToInline(Operation *inlineSite, return *this == targetAffinityAttr; } +//===----------------------------------------------------------------------===// +// #hal.device.promise<*> +//===----------------------------------------------------------------------===// + +// static +Attribute DevicePromiseAttr::parse(AsmParser &p, Type type) { + // `<@device` + StringAttr deviceName; + int64_t queueMask = -1; + if (failed(p.parseLess()) || failed(p.parseSymbolName(deviceName))) + return {}; + if (succeeded(p.parseOptionalComma())) { + // `[`queue_bit[, ...] `]` + queueMask = 0; + if (failed(p.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { + int64_t i = 0; + if (failed(p.parseInteger(i))) + return failure(); + queueMask |= 1ll << i; + return success(); + }))) { + return {}; + } + } + // `>` + if (failed(p.parseGreater())) + return {}; + return get(p.getContext(), deviceName, queueMask); +} + +void DevicePromiseAttr::print(AsmPrinter &p) const { + auto &os = p.getStream(); + os << "<@"; + os << getDevice().getValue(); + int64_t queueMask = getQueueMask(); + if (queueMask != -1) { + os << ", ["; + for (int i = 0, j = 0; i < sizeof(queueMask) * 8; ++i) { + if (queueMask & (1ll << i)) { + if (j++ > 0) + os << ", "; + os << i; + } + } + os << "]"; + } + os << ">"; +} + +bool DevicePromiseAttr::isExecutableWith( + IREE::Stream::AffinityAttr other) const { + if (!other) + return true; + // Only compatible with the same exact devices today. We could support a + // peering model to allow operations to move across devices in a peered set + // but that may be best done at higher levels and avoided once we get to the + // "are these the same device" stage. + auto otherPromiseAttr = llvm::dyn_cast_if_present(other); + if (!otherPromiseAttr || getDevice() != otherPromiseAttr.getDevice()) + return false; + // If this affinity is a subset of the target affinity then it can execute + // with it. + if ((getQueueMask() & otherPromiseAttr.getQueueMask()) == getQueueMask()) + return true; + // Otherwise not compatible. + return false; +} + +IREE::Stream::AffinityAttr +DevicePromiseAttr::joinOR(IREE::Stream::AffinityAttr other) const { + if (!other) + return *this; + if (!IREE::Stream::AffinityAttr::canExecuteTogether(*this, other)) { + return nullptr; + } + auto otherPromiseAttr = llvm::dyn_cast_if_present(other); + return DevicePromiseAttr::get(getContext(), getDevice(), + getQueueMask() | + otherPromiseAttr.getQueueMask()); +} + +IREE::Stream::AffinityAttr +DevicePromiseAttr::joinAND(IREE::Stream::AffinityAttr other) const { + if (!other) + return *this; + if (!IREE::Stream::AffinityAttr::canExecuteTogether(*this, other)) { + return nullptr; + } + auto otherPromiseAttr = llvm::dyn_cast_if_present(other); + return DevicePromiseAttr::get(getContext(), getDevice(), + getQueueMask() & + otherPromiseAttr.getQueueMask()); +} + +bool DevicePromiseAttr::isLegalToInline(Operation *inlineSite, + Operation *inlinable) const { + // Look up the affinity of the inlining target site and only allow inlining if + // it matches exactly. We could make a decision as to whether we allow + // inlining when queues are subsets (so if the target site allows any queue + // and the inlinable allows queue 2 then allow, etc). In the future we may + // want to allow util.scope restrictions within the inline target to keep + // queue specification tighter but today most queue masks are wildcarded + // anyway. + auto targetAffinityAttr = IREE::Stream::AffinityAttr::lookup(inlineSite); + return *this == targetAffinityAttr; +} + //===----------------------------------------------------------------------===// // IREE::HAL::HALDialect //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index b43580c01090..c28c7d1a7bcf 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td @@ -989,4 +989,50 @@ def HAL_DeviceAffinityAttr : AttrDef +//===----------------------------------------------------------------------===// + +def HAL_DevicePromiseAttr : AttrDef, + DeclareAttrInterfaceMethods, +]> { + let mnemonic = "device.promise"; + let summary = [{promises a named device and optional queue affinity}]; + let description = [{ + Specifies that an annotated operation or scope is only allowed to execute on + a specific device that has not yet been declared and optionally a set of + queues (0-64) provided. Operations will not run on other queues. If the + queue mask is omitted then any queue on the device is allowed to execute the + specified operations. + + This is used in input programs to assign operations to particular devices + prior to the devices being declared. This allows device categories to be + referenced in the program as produced from the frontend and for those + device specifications to be provided later on during compilation. + Verification is performed as part of the ResolveDevicePromisesPass. + + Example: + ```mlir + // Any queue on whatever @device_a will be after declaration. + #hal.device.promise<@device_a> + // Queues 4 and 5 on whatever @device_b will be after declaration. + #hal.device.promise<@device_b, [4, 5]> + ``` + }]; + + let parameters = (ins + AttrParameter<"StringAttr", "">:$device, + AttrParameter<"int64_t", "">:$queue_mask + ); + + let hasCustomAssemblyFormat = 1; +} + #endif // IREE_DIALECT_HAL_IR_HAL_ATTRS diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir index 00f39f5e6b87..3a05d9cf8ecf 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir @@ -113,6 +113,17 @@ util.global private @device : !hal.device // ----- +"device.promise"() { + // CHECK: device_any = #hal.device.promise<@device> + device_any = #hal.device.promise<@device>, + // CHECK: device_queue_0 = #hal.device.promise<@device, [0]> + device_queue_0 = #hal.device.promise<@device, [0]>, + // CHECK: device_queue_123 = #hal.device.promise<@device, [1, 2, 3]> + device_queue_123 = #hal.device.promise<@device, [1, 2, 3]> +} : () -> () + +// ----- + // Tests that differing device affinities blocks inlining. // Here the @inline_target is using the default affinity specified on the // module and only functions also using the default affinity or a matching diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel index 3063a8d5a970..e5be69d00346 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel @@ -35,6 +35,7 @@ iree_compiler_cc_library( "PreprocessExecutables.cpp", "PruneExecutables.cpp", "RepeatDispatches.cpp", + "ResolveDevicePromises.cpp", "ResolveExportOrdinals.cpp", "SerializeExecutables.cpp", "StripExecutableContents.cpp", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt index 50ce275ad91b..c7b1207f5bf7 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt @@ -36,6 +36,7 @@ iree_cc_library( "PreprocessExecutables.cpp" "PruneExecutables.cpp" "RepeatDispatches.cpp" + "ResolveDevicePromises.cpp" "ResolveExportOrdinals.cpp" "SerializeExecutables.cpp" "StripExecutableContents.cpp" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index f2752c953097..e60d6fe5209a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -199,6 +199,7 @@ void buildHALDeviceAssignmentPassPipeline(OpPassManager &passManager, {&targetRegistry, targetOptions.targets})); } passManager.addPass(IREE::HAL::createMaterializeTargetDevicesPass()); + passManager.addPass(IREE::HAL::createResolveDevicePromisesPass()); passManager.addPass(IREE::HAL::createVerifyDevicesPass({&targetRegistry})); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td index cdd8d28f5884..8ff1cf7e5ee4 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td @@ -79,6 +79,18 @@ def MaterializeTargetDevicesPass : ]; } +def ResolveDevicePromisesPass : + Pass<"iree-hal-resolve-device-promises", "mlir::ModuleOp"> { + let summary = "Resolves `#hal.device.promise` attributes to their devices."; + let description = [{ + Resolves promised device affinities to the materialized device globals that + were promised. Verifies that all promises are resolved. + }]; + let dependentDialects = [ + "IREE::HAL::HALDialect", + ]; +} + def VerifyDevicesPass : Pass<"iree-hal-verify-devices", "mlir::ModuleOp"> { let summary = "Verifies that all devices can be targeted with the available compiler plugins."; diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ResolveDevicePromises.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ResolveDevicePromises.cpp new file mode 100644 index 000000000000..413c1073d160 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ResolveDevicePromises.cpp @@ -0,0 +1,155 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::HAL { + +#define GEN_PASS_DEF_RESOLVEDEVICEPROMISESPASS +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// --iree-hal-resolve-device-promises +//===----------------------------------------------------------------------===// + +struct ResolveDevicePromisesPass + : public IREE::HAL::impl::ResolveDevicePromisesPassBase< + ResolveDevicePromisesPass> { + using IREE::HAL::impl::ResolveDevicePromisesPassBase< + ResolveDevicePromisesPass>::ResolveDevicePromisesPassBase; + + void runOnOperation() override { + auto moduleOp = getOperation(); + + // Resolves a #hal.device.promise attr to a #hal.device.affinity. Fails if + // the referenced device is not found. + SymbolTable symbolTable(moduleOp); + auto resolvePromise = [&](Operation *fromOp, + IREE::HAL::DevicePromiseAttr promiseAttr) + -> FailureOr { + auto deviceOp = + symbolTable.lookupNearestSymbolFrom( + fromOp, promiseAttr.getDevice()); + if (!deviceOp) { + return fromOp->emitOpError() + << "references a promised device that was not declared: " + << promiseAttr; + } + return cast( + IREE::HAL::DeviceAffinityAttr::get(&getContext(), + FlatSymbolRefAttr::get(deviceOp), + promiseAttr.getQueueMask())); + }; + + // Resolves any #hal.device.promise attr on the op. + auto resolvePromiseAttrs = [&](Operation *op, DictionaryAttr attrDict) + -> std::optional> { + bool didReplaceAny = false; + auto newDict = dyn_cast_if_present(attrDict.replace( + [&](Attribute attr) + -> std::optional> { + if (auto promiseAttr = + dyn_cast_if_present(attr)) { + auto resolvedAttrOr = resolvePromise(op, promiseAttr); + if (failed(resolvedAttrOr)) { + return std::make_pair(attr, WalkResult::interrupt()); + } + didReplaceAny = true; + return std::make_pair(resolvedAttrOr.value(), + WalkResult::advance()); + } + return std::nullopt; + })); + if (newDict) { + return std::make_pair(newDict, didReplaceAny ? WalkResult::advance() + : WalkResult::skip()); + } else { + return std::make_pair(attrDict, WalkResult::interrupt()); + } + }; + auto resolveAllPromiseAttrs = + [&](Operation *op, + MutableArrayRef attrDicts) -> WalkResult { + bool didReplaceAny = false; + for (auto &attrDict : attrDicts) { + auto resolveState = resolvePromiseAttrs(op, attrDict); + if (!resolveState) { + // Failed to resolve while recursively replacing. + return WalkResult::interrupt(); + } else if (!resolveState->second.wasSkipped()) { + // Performed a replacement. + attrDict = resolveState->first; + didReplaceAny = true; + } + } + return didReplaceAny ? WalkResult::advance() : WalkResult::skip(); + }; + auto resolvePromisesOnOp = [&](Operation *op) -> WalkResult { + auto opAttrs = op->getAttrDictionary(); + if (opAttrs) { + auto resolveState = resolvePromiseAttrs(op, opAttrs); + if (!resolveState) { + // Failed to resolve while recursively replacing. + return WalkResult::interrupt(); + } else if (!resolveState->second.wasSkipped()) { + // Performed a replacement. + op->setAttrs(resolveState->first); + } + } + if (auto funcOp = dyn_cast(op)) { + SmallVector argAttrs; + funcOp.getAllArgAttrs(argAttrs); + auto argStatus = resolveAllPromiseAttrs(op, argAttrs); + if (argStatus.wasInterrupted()) { + return argStatus; + } else if (!argStatus.wasSkipped()) { + funcOp.setAllArgAttrs(argAttrs); + } + SmallVector resultAttrs; + funcOp.getAllResultAttrs(resultAttrs); + auto resultStatus = resolveAllPromiseAttrs(op, resultAttrs); + if (resultStatus.wasInterrupted()) { + return resultStatus; + } else if (!resultStatus.wasSkipped()) { + funcOp.setAllResultAttrs(resultAttrs); + } + } + return WalkResult::advance(); + }; + + // Walk the entire module and replace promises. + // We skip any symbol table op as all devices are top-level only. + if (resolvePromisesOnOp(moduleOp).wasInterrupted()) { + return signalPassFailure(); + } + if (moduleOp + .walk([&](Operation *op) { + if (op->hasTrait()) { + return WalkResult::skip(); // ignore isolated ops + } + return resolvePromisesOnOp(op); + }) + .wasInterrupted()) { + return signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace mlir::iree_compiler::IREE::HAL diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel index 5e84ae3fa741..3d1d096fc06d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel @@ -32,6 +32,7 @@ iree_lit_test_suite( "preprocess_executables.mlir", "prune_executables.mlir", "repeat_dispatches.mlir", + "resolve_device_promises.mlir", "resolve_export_ordinals.mlir", "strip_executable_contents.mlir", "substitute_executables.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt index 8079246613d0..972947ed406e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt @@ -30,6 +30,7 @@ iree_lit_test_suite( "preprocess_executables.mlir" "prune_executables.mlir" "repeat_dispatches.mlir" + "resolve_device_promises.mlir" "resolve_export_ordinals.mlir" "strip_executable_contents.mlir" "substitute_executables.mlir" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_device_promises.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_device_promises.mlir new file mode 100644 index 000000000000..2b5985fe65ef --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_device_promises.mlir @@ -0,0 +1,43 @@ +// RUN: iree-opt --split-input-file --iree-hal-resolve-device-promises %s --mlir-print-local-scope --verify-diagnostics | FileCheck %s + +// Resolves device promises. + +// CHECK: module @module +module @module attributes { + // CHECK-SAME: stream.affinity = #hal.device.affinity<@device0, [1, 2, 3]> + stream.affinity = #hal.device.promise<@device0, [1, 2, 3]> +} { + util.global private @device0 = #hal.device.target<"vmvx"> : !hal.device + util.global private @device1 = #hal.device.target<"vmvx"> : !hal.device + // CHECK: util.func private @func + util.func private @func(%arg0: tensor { + // CHECK-SAME: arg.affinity = #hal.device.affinity<@device1> + arg.affinity = #hal.device.promise<@device1> + }) -> (tensor { + // CHECK-SAME: result.affinity = #hal.device.affinity<@device1> + result.affinity = #hal.device.promise<@device1> + }) attributes { + // CHECK-SAME: func.affinity = #hal.device.affinity<@device1> + func.affinity = #hal.device.promise<@device1> + } { + // CHECK: util.return + util.return { + // CHECK-SAME: some.affinities = [#hal.device.affinity<@device0>, #hal.device.affinity<@device1>] + some.affinities = [#hal.device.promise<@device0>, #hal.device.promise<@device1>] + } %arg0 : tensor + } +} + +// ----- + +// Verifies that promised devices exist. + +module @module { + util.global private @device = #hal.device.target<"vmvx"> : !hal.device + // expected-error@+1 {{op references a promised device that was not declared}} + util.func private @func() -> () attributes { + stream.affinity = #hal.device.promise<@unknown_device> + } { + util.return + } +} From 4a03eea3260dfe3946a5b667f4666f39ac4fad65 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 7 Mar 2024 10:08:01 -0800 Subject: [PATCH 09/25] Working around data tiling limitations a bit. This changes the passes to be module-level and lookup their targets based on their function context. The passes are not long for this world in their current form and the spaghettification that happened with the VMVX and LLVM-CPU paths makes it near impossible to factor properly without a rewrite. --- .../compiler/Codegen/Common/CPU/BUILD.bazel | 3 +- .../Codegen/Common/CPU/CMakeLists.txt | 3 +- ...ngPass.cpp => CPUMaterializeEncodings.cpp} | 242 +++++++++++------- .../compiler/Codegen/Common/CPU/PassDetail.h | 1 + .../iree/compiler/Codegen/Common/CPU/Passes.h | 10 +- .../compiler/Codegen/Common/CPU/Passes.td | 16 +- .../test/llvmcpu_materialize_encoding.mlir | 2 +- .../CPU/test/vmvx_materialize_encoding.mlir | 10 +- .../MaterializeEncodingIntoPackUnPack.cpp | 60 ++--- .../iree/compiler/Codegen/LLVMCPU/Passes.cpp | 2 +- .../iree/compiler/Dialect/HAL/IR/HALAttrs.cpp | 33 --- .../iree/compiler/Dialect/HAL/IR/HALAttrs.td | 6 - .../Dialect/HAL/Transforms/Passes.cpp | 4 +- .../Dialect/VMVX/Transforms/Passes.cpp | 12 +- .../compiler/GlobalOptimization/BUILD.bazel | 1 + .../GlobalOptimization/CMakeLists.txt | 1 + .../MaterializeHomogeneousEncodings.cpp | 18 +- .../compiler/Preprocessing/Common/BUILD.bazel | 1 + .../Preprocessing/Common/CMakeLists.txt | 1 + .../Preprocessing/Common/PadToIntrinsics.cpp | 39 ++- .../docs/community/blog/posts/microkernels.md | 4 +- 21 files changed, 255 insertions(+), 214 deletions(-) rename compiler/src/iree/compiler/Codegen/Common/CPU/{CPUMaterializeEncodingPass.cpp => CPUMaterializeEncodings.cpp} (78%) diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel index 0efe1f14e983..98da9d99c543 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel @@ -46,7 +46,7 @@ iree_compiler_cc_library( name = "CommonCPUPasses", srcs = [ "CPULowerToUKernels.cpp", - "CPUMaterializeEncodingPass.cpp", + "CPUMaterializeEncodings.cpp", "CPUPrepareUkernels.cpp", "Passes.cpp", ], @@ -62,6 +62,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Transforms", "//compiler/src/iree/compiler/Codegen/Utils", "//compiler/src/iree/compiler/Dialect/Encoding/IR", + "//compiler/src/iree/compiler/Dialect/HAL/Analysis", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//runtime/src/iree/builtins/ukernel:exported_bits", "@llvm-project//llvm:Support", diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt index fbae2accee65..72361213574c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt @@ -43,7 +43,7 @@ iree_cc_library( "Passes.h" SRCS "CPULowerToUKernels.cpp" - "CPUMaterializeEncodingPass.cpp" + "CPUMaterializeEncodings.cpp" "CPUPrepareUkernels.cpp" "Passes.cpp" DEPS @@ -84,6 +84,7 @@ iree_cc_library( iree::compiler::Codegen::Transforms iree::compiler::Codegen::Utils iree::compiler::Dialect::Encoding::IR + iree::compiler::Dialect::HAL::Analysis iree::compiler::Dialect::HAL::IR PUBLIC ) diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp similarity index 78% rename from compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp rename to compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp index e81ed85cf22e..c107bc7eec0a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp @@ -11,6 +11,7 @@ #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" #include "iree/compiler/Codegen/Utils/Utils.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" +#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -29,15 +30,12 @@ namespace mlir::iree_compiler { -using namespace IREE::Encoding; -using IREE::HAL::ExecutableTargetAttr; - // Enumerate tile sizes to choose from when no specific architecture is // targeted. For narrow-{M,N} cases, this only enumerates on narrow M. The // narrow-N cases are handled by transposition in chooseMatmulTile. static SmallVector enumerateMatmulTilesVMVX(linalg::ContractionDimensions cDims, - ExecutableTargetAttr target) { + IREE::HAL::ExecutableTargetAttr target) { // TODO(hanchung): The ukernel path does not support 3d // codegen.query_tile_sizes op, so we disable dynamic tile shapes for // batch_matmul. @@ -59,7 +57,7 @@ enumerateMatmulTilesVMVX(linalg::ContractionDimensions cDims, // For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases // are handled by transposition in chooseMatmulTile. static SmallVector -enumerateMatmulTileRiscv32(ExecutableTargetAttr target) { +enumerateMatmulTileRiscv32(IREE::HAL::ExecutableTargetAttr target) { if (hasUkernel(target)) { return { TileMxNxK{8, 8, 4}, // Some reasonable tile shape. @@ -76,7 +74,8 @@ enumerateMatmulTileRiscv32(ExecutableTargetAttr target) { // For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases // are handled by transposition in chooseMatmulTile. static SmallVector -enumerateMatmulTileArm64(TypeRange elementTypes, ExecutableTargetAttr target) { +enumerateMatmulTileArm64(TypeRange elementTypes, + IREE::HAL::ExecutableTargetAttr target) { // Data-tiling for SVE is not implemented yet. if (hasFeature(target, "+sve") || hasFeature(target, "+sve2")) { return {}; @@ -166,7 +165,8 @@ enumerateMatmulTileArm64(TypeRange elementTypes, ExecutableTargetAttr target) { // For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases // are handled by transposition in chooseMatmulTile. static SmallVector -enumerateMatmulTileX86_64(TypeRange elementTypes, ExecutableTargetAttr target) { +enumerateMatmulTileX86_64(TypeRange elementTypes, + IREE::HAL::ExecutableTargetAttr target) { assert(elementTypes.size() == 3); Type lhs = elementTypes[0]; Type rhs = elementTypes[1]; @@ -376,9 +376,10 @@ chooseMatmulTile(ArrayRef enumeratedTiles, int64_t matmulNarrowM, return bestRatedTile; } -SmallVector +static SmallVector enumerateMatmulTileMxNxK(linalg::ContractionDimensions cDims, - TypeRange elementTypes, ExecutableTargetAttr target) { + TypeRange elementTypes, + IREE::HAL::ExecutableTargetAttr target) { if (isVMVXBackend(target)) { return enumerateMatmulTilesVMVX(cDims, target); } @@ -394,41 +395,10 @@ enumerateMatmulTileMxNxK(linalg::ContractionDimensions cDims, return {}; } -struct CPUMaterializeEncodingPass - : public CPUMaterializeEncodingBase { - CPUMaterializeEncodingPass() : targetAttr(nullptr) {} - explicit CPUMaterializeEncodingPass(IREE::HAL::ExecutableTargetAttr attr) - : targetAttr(attr) {} - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() override; - -private: - IREE::HAL::ExecutableTargetAttr targetAttr; -}; - -struct CPUMaterializeUpperBoundTileSizePass - : public CPUMaterializeUpperBoundTileSizeBase< - CPUMaterializeUpperBoundTileSizePass> { - CPUMaterializeUpperBoundTileSizePass() = default; - explicit CPUMaterializeUpperBoundTileSizePass( - ArrayRef attrs) - : targetAttrs(attrs) {} - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() override; - -private: - SmallVector targetAttrs; -}; - -FailureOr +static FailureOr materializeEncodingForTarget(RankedTensorType tensorType, - ExecutableTargetAttr targetAttr) { - IREE::Encoding::EncodingAttr encoding = + IREE::HAL::ExecutableTargetAttr targetAttr) { + auto encoding = dyn_cast_or_null(tensorType.getEncoding()); if (!encoding) { return failure(); @@ -464,7 +434,7 @@ materializeEncodingForTarget(RankedTensorType tensorType, } static MaterializeEncodingFn -getMaterializeEncodingFn(ExecutableTargetAttr targetAttr) { +getMaterializeEncodingFn(IREE::HAL::ExecutableTargetAttr targetAttr) { return [targetAttr]( RankedTensorType tensorType) -> FailureOr { @@ -481,8 +451,8 @@ getMaterializeEncodingFn(ExecutableTargetAttr targetAttr) { // executable variant. There, the padding amounts only control the size of // allocated buffers, so it's OK to over-estimate (only wasting some memory) // but not under-estimate (would cause buffer overruns) padding amounts. -static MaterializeEncodingFn -getUpperBoundMaterializeEncodingFn(ArrayRef targetAttrs) { +static MaterializeEncodingFn getUpperBoundMaterializeEncodingFn( + ArrayRef targetAttrs) { return [targetAttrs]( RankedTensorType tensorType) -> FailureOr { @@ -540,73 +510,165 @@ getMaterializeEncodingValueFn(IREE::HAL::ExecutableTargetAttr targetAttr) { return {}; } -void CPUMaterializeEncodingPass::runOnOperation() { - MLIRContext *context = &getContext(); - auto operation = getOperation(); - RewritePatternSet materializeEncodingPattern(context); - if (!targetAttr) - targetAttr = ExecutableTargetAttr::lookup(operation); - auto materializeEncodingFn = getMaterializeEncodingFn(targetAttr); +static LogicalResult materializeFuncOpEncodings( + FunctionOpInterface funcOp, + IREE::HAL::ExecutableTargetAttr executableTargetAttr) { + RewritePatternSet materializeEncodingPattern(funcOp.getContext()); + auto materializeEncodingFn = getMaterializeEncodingFn(executableTargetAttr); if (!materializeEncodingFn) { - return signalPassFailure(); + return failure(); } MaterializeEncodingTypeConverter typeConverter(materializeEncodingFn); - MaterializeEncodingConversionTarget target(*context); - auto materializeEncodingValueFn = getMaterializeEncodingValueFn(targetAttr); + MaterializeEncodingConversionTarget target(*funcOp.getContext()); + auto materializeEncodingValueFn = + getMaterializeEncodingValueFn(executableTargetAttr); populateMaterializeEncodingIntoPackUnPackPatterns(materializeEncodingPattern, target, typeConverter, materializeEncodingValueFn); - if (failed(applyPartialConversion(operation, target, + if (failed(applyPartialConversion(funcOp, target, std::move(materializeEncodingPattern)))) { - operation.emitOpError("materialization failed"); - return signalPassFailure(); + funcOp.emitOpError("materialization failed"); + return failure(); } - // Add patterns to fold pack/unpack ops with pad/extract_slice ops and resolve - // dims ops. + // Add patterns to fold pack/unpack ops with pad/extract_slice ops and + // resolve dims ops. { - RewritePatternSet patterns(context); + RewritePatternSet patterns(funcOp.getContext()); tensor::populateFoldIntoPackAndUnpackPatterns(patterns); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(operation, std::move(patterns)))) { - operation.emitOpError("folding patterns failed"); - return signalPassFailure(); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + funcOp.emitOpError("folding patterns failed"); + return failure(); } } + + return success(); } -void CPUMaterializeUpperBoundTileSizePass::runOnOperation() { - MLIRContext *context = &getContext(); - auto operation = getOperation(); - if (targetAttrs.empty()) { - targetAttrs = - IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(operation); - } - RewritePatternSet patterns(context); - MaterializeEncodingFn materializeEncodingFn = - getUpperBoundMaterializeEncodingFn(targetAttrs); - if (!materializeEncodingFn) { - return signalPassFailure(); +struct CPUMaterializeHostEncodingPass + : public CPUMaterializeHostEncodingBase { + CPUMaterializeHostEncodingPass() = default; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); } - populateMaterializeUpperBoundTileSizePatterns(patterns, - materializeEncodingFn); - if (failed(applyPatternsAndFoldGreedily(operation, std::move(patterns)))) { - operation.emitOpError( - "encoding padding sizes materialization pattern failed"); - return signalPassFailure(); + + void runOnOperation() override { + auto moduleOp = getOperation(); + + // Run required analysis passes. + IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp); + if (failed(deviceAnalysis.run())) + return signalPassFailure(); + + for (auto funcOp : moduleOp.getOps()) { + // Gather the required executable targets for the function. Note that it's + // possible there are more required for ops nested within the function but + // this pass is a hack and can't handle that :shrug:. + SetVector executableTargets; + deviceAnalysis.gatherRequiredExecutableTargets(funcOp, executableTargets); + + // HACK: this pass is run on the host _but shouldn't be_. Because it's + // run on the host and IREE is a compiler capable of multi-targeting there + // may be multiple executable targets at any point in the host program. + // This pass can't handle that and assumes it's been checked earlier by + // spooky action at a distance. This needs to be fixed. + if (executableTargets.size() != 1) { + funcOp.emitOpError() << "has multiple executable targets and CPU data " + "tiling isn't built to support that"; + return signalPassFailure(); + } + + // Materialize encodings within the function. + if (failed( + materializeFuncOpEncodings(funcOp, executableTargets.front()))) { + return signalPassFailure(); + } + } } +}; + +std::unique_ptr createCPUMaterializeHostEncodingPass() { + return std::make_unique(); } -std::unique_ptr> -createCPUMaterializeEncodingPass(IREE::HAL::ExecutableTargetAttr targetAttr) { - return std::make_unique(targetAttr); +// NOTE: this runs on host modules and executables and has two paths to handle +// that. It should _not_ be running on both - target-specific codegen passes +// are not allowed on host programs and it's a big violation of layering that +// this exists. +struct CPUMaterializeDeviceEncodingPass + : public CPUMaterializeDeviceEncodingBase< + CPUMaterializeDeviceEncodingPass> { + CPUMaterializeDeviceEncodingPass() = default; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + auto funcOp = getOperation(); + auto executableTargetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp); + if (failed(materializeFuncOpEncodings(funcOp, executableTargetAttr))) { + return signalPassFailure(); + } + } +}; + +std::unique_ptr createCPUMaterializeDeviceEncodingPass() { + return std::make_unique(); } -std::unique_ptr> -createCPUMaterializeUpperBoundTileSizePass( - ArrayRef targetAttrs) { - return std::make_unique(targetAttrs); +// NOTE: this runs on host modules. +struct CPUMaterializeUpperBoundTileSizePass + : public CPUMaterializeUpperBoundTileSizeBase< + CPUMaterializeUpperBoundTileSizePass> { + CPUMaterializeUpperBoundTileSizePass() = default; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + + // Run required analysis passes. + IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp); + if (failed(deviceAnalysis.run())) + return signalPassFailure(); + + for (auto funcOp : moduleOp.getOps()) { + // Gather the required executable targets for the function. Note that it's + // possible there are more required for ops nested within the function but + // this pass is a hack and can't handle that :shrug:. + SetVector executableTargets; + deviceAnalysis.gatherRequiredExecutableTargets(funcOp, executableTargets); + + // Get patterns specialized for the executable targets used by the + // function. + RewritePatternSet patterns(&getContext()); + MaterializeEncodingFn materializeEncodingFn = + getUpperBoundMaterializeEncodingFn(executableTargets.getArrayRef()); + if (!materializeEncodingFn) + return signalPassFailure(); + populateMaterializeUpperBoundTileSizePatterns(patterns, + materializeEncodingFn); + + // Run patterns on the function. + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + funcOp.emitOpError( + "encoding padding sizes materialization pattern failed"); + return signalPassFailure(); + } + } + } +}; + +std::unique_ptr createCPUMaterializeUpperBoundTileSizePass() { + return std::make_unique(); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/PassDetail.h b/compiler/src/iree/compiler/Codegen/Common/CPU/PassDetail.h index 3a782ba2c265..25360a73c7e6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/PassDetail.h +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/PassDetail.h @@ -7,6 +7,7 @@ #ifndef IREE_COMPILER_CODEGEN_LLVMCPU_PASS_DETAIL_H_ #define IREE_COMPILER_CODEGEN_LLVMCPU_PASS_DETAIL_H_ +#include "mlir/IR/BuiltinOps.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.h b/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.h index c1fe3bfad537..f5f9a31d5b80 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.h @@ -13,6 +13,7 @@ #define IREE_COMPILER_CODEGEN_COMMON_CPU_PASSES_H_ #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" @@ -22,9 +23,8 @@ namespace mlir::iree_compiler { /// encoding.set_encoding -> tensor.pack /// encoding.unset_encoding -> tensor.unpack /// linalg.matmul -> linalg.mmt4d -std::unique_ptr> -createCPUMaterializeEncodingPass( - IREE::HAL::ExecutableTargetAttr targetAttr = nullptr); +std::unique_ptr createCPUMaterializeHostEncodingPass(); +std::unique_ptr createCPUMaterializeDeviceEncodingPass(); /// Like createLLVMCPUMaterializeEncodingPass, but specifically for /// encoding.upper_bound_tile_size, converting it to constants. @@ -41,9 +41,7 @@ createCPUMaterializeEncodingPass( /// converts upper_bound_tile_size to some specific constant size (currently 16) /// that is the largest tile size that we can use in VMVX, and can be adjusted // as needed. -std::unique_ptr> -createCPUMaterializeUpperBoundTileSizePass( - ArrayRef targetAttrs = {}); +std::unique_ptr createCPUMaterializeUpperBoundTileSizePass(); /// Adds CPU bufferization passes to the pipeline. void addCPUBufferizePasses(OpPassManager &funcPassManager); diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.td index 8e3004938b40..6329c532cc0d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/Passes.td @@ -13,14 +13,20 @@ include "mlir/Pass/PassBase.td" // Common Passes used for CPU-like backends (keep alphabetical) //===---------------------------------------------------------------------===// -def CPUMaterializeEncoding : - InterfacePass<"iree-codegen-cpu-materialize-encoding", "mlir::FunctionOpInterface"> { - let summary = "Materialize the encoding for tensor as specified by the backend"; - let constructor = "mlir::iree_compiler::createCPUMaterializeEncodingPass()"; +def CPUMaterializeHostEncoding : + Pass<"iree-codegen-cpu-materialize-host-encoding", "mlir::ModuleOp"> { + let summary = "Materialize the encoding for tensor as specified by the backend."; + let constructor = "mlir::iree_compiler::createCPUMaterializeHostEncodingPass()"; +} + +def CPUMaterializeDeviceEncoding : + InterfacePass<"iree-codegen-cpu-materialize-device-encoding", "mlir::FunctionOpInterface"> { + let summary = "Materialize the encoding for tensor as specified by the backend."; + let constructor = "mlir::iree_compiler::createCPUMaterializeDeviceEncodingPass()"; } def CPUMaterializeUpperBoundTileSize : - InterfacePass<"iree-codegen-cpu-materialize-upper-bound-tile-size", "mlir::FunctionOpInterface"> { + Pass<"iree-codegen-cpu-materialize-upper-bound-tile-size", "mlir::ModuleOp"> { let summary = "Materialize upper_bound_tile_size to constants."; let constructor = "mlir::iree_compiler::createCPUMaterializeUpperBoundTileSizePass()"; } diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir index 466bf1fa1da5..2a96d39970c8 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-materialize-encoding),canonicalize,cse)" --split-input-file %s | FileCheck %s +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-materialize-device-encoding),canonicalize,cse)" --split-input-file %s | FileCheck %s func.func @set_encoding_with_padding_semantics_bf16_x86_64_avx512f() attributes { diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir index c1bb2794d62a..4c0fd3265fbc 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/test/vmvx_materialize_encoding.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-materialize-encoding),canonicalize,cse)" --split-input-file %s | FileCheck %s +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-materialize-device-encoding),canonicalize,cse)" --split-input-file %s | FileCheck %s #map = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> @@ -77,7 +77,7 @@ func.func @matmul_lowering_i8i8i32_vmvx_ukernel() attributes { #map3 = affine_map<(d0, d1, d2) -> (d2, d1)> #map4 = affine_map<(d0, d1, d2) -> (d0, d1)> func.func @fill_matmul(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) attributes { - hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb"> + hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb"> } { %c32_i64 = arith.constant 32 : i64 %cst = arith.constant 0.000000e+00 : f32 @@ -123,7 +123,7 @@ func.func @fill_matmul(%arg0: index, %arg1: index, %arg2: index, %arg3: index, % #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> func.func @set_encoding_dynamic() attributes { - hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb"> + hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb"> } { %c0 = arith.constant 0 : index %cst = arith.constant 0.000000e+00 : f32 @@ -177,7 +177,7 @@ func.func @set_encoding_dynamic() attributes { #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> func.func @unset_encoding_dynamic() attributes { - hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb"> + hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb"> } { %c0 = arith.constant 0 : index %cst = arith.constant 0.000000e+00 : f32 @@ -225,7 +225,7 @@ func.func @unset_encoding_dynamic() attributes { #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> func.func @matmul_lowering_f32f32f32_generic() attributes { - hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb"> + hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb"> } { %c0 = arith.constant 0 : index %M = hal.interface.constant.load[0] : index diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp index c585614ba79d..281c39849f4b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp @@ -28,10 +28,6 @@ namespace mlir::iree_compiler { -using namespace IREE::Encoding; -using IREE::Encoding::getEncodingAttr; -using IREE::HAL::ExecutableTargetAttr; - //===---------------------------------------------------------------------===// // Utility methods //===---------------------------------------------------------------------===// @@ -214,11 +210,11 @@ static std::optional getPaddingValue(Value &source) { /// For now this takes a `paddingValue` as input. The source is also taken /// as input so that these could be used with `OpConversionPatterns`. static FailureOr lowerSetEncodingOpToPackOp( - RewriterBase &rewriter, SetEncodingOp encodingOp, Value source, - MaterializeEncodingFn materializeEncodingFn, + RewriterBase &rewriter, IREE::Encoding::SetEncodingOp encodingOp, + Value source, MaterializeEncodingFn materializeEncodingFn, MaterializeEncodingValueFn materializeEncodingValueFn) { RankedTensorType resultType = encodingOp.getResultType(); - auto encoding = getEncodingAttr(resultType); + auto encoding = IREE::Encoding::getEncodingAttr(resultType); if (!encoding) { return failure(); } @@ -239,9 +235,6 @@ static FailureOr lowerSetEncodingOpToPackOp( return rewriter.notifyMatchFailure( encodingOp, "failed to generate runtime tile size query"); } - if (!encoding) { - return failure(); - } std::optional paddingValue; if (encoding.getRoundDimsToArray().empty()) { paddingValue = getPaddingValue(source); @@ -266,8 +259,8 @@ static FailureOr lowerSetEncodingOpToPackOp( /// The source is taken as input so that these could be used with /// `OpConversionPatterns`. static FailureOr lowerUnsetEncodingToUnpackOp( - RewriterBase &rewriter, UnsetEncodingOp encodingOp, Value packedValue, - MaterializeEncodingFn materializeEncodingFn, + RewriterBase &rewriter, IREE::Encoding::UnsetEncodingOp encodingOp, + Value packedValue, MaterializeEncodingFn materializeEncodingFn, MaterializeEncodingValueFn materializeEncodingValueFn) { RankedTensorType sourceType = encodingOp.getSourceType(); FailureOr materializeEncodingInfo = @@ -275,7 +268,7 @@ static FailureOr lowerUnsetEncodingToUnpackOp( if (failed(materializeEncodingInfo)) { return rewriter.notifyMatchFailure(encodingOp, "unhandled source encoding"); } - if (isNarrowNResult(getEncodingAttr(sourceType))) { + if (isNarrowNResult(IREE::Encoding::getEncodingAttr(sourceType))) { transposeInPlace(*materializeEncodingInfo); } // Create an `tensor.empty` for the result of the unpack operation. @@ -297,7 +290,8 @@ static FailureOr lowerUnsetEncodingToUnpackOp( } static FailureOr> lowerUpperBoundTileSizeOpToConstants( - RewriterBase &rewriter, UpperBoundTileSizeOp upperBoundTileSizeOp, + RewriterBase &rewriter, + IREE::Encoding::UpperBoundTileSizeOp upperBoundTileSizeOp, MaterializeEncodingFn materializeEncodingFn) { Location loc = upperBoundTileSizeOp.getLoc(); RankedTensorType tensorType = upperBoundTileSizeOp.getTensorType(); @@ -340,16 +334,17 @@ lowerContractionOpWithEncoding(RewriterBase &rewriter, auto lhsType = cast(inputs[0]->get().getType()); auto rhsType = cast(inputs[1]->get().getType()); auto resultType = cast(outputs[0].getType()); - auto lhsEncoding = getEncodingAttr(lhsType); - auto rhsEncoding = getEncodingAttr(rhsType); - auto resultEncoding = getEncodingAttr(resultType); + auto lhsEncoding = IREE::Encoding::getEncodingAttr(lhsType); + auto rhsEncoding = IREE::Encoding::getEncodingAttr(rhsType); + auto resultEncoding = IREE::Encoding::getEncodingAttr(resultType); if (!lhsEncoding || !rhsEncoding || !resultEncoding) { return failure(); } - if (lhsEncoding.getOperandIndex().getValue() != MATMUL_LHS || - rhsEncoding.getOperandIndex().getValue() != MATMUL_RHS || - resultEncoding.getOperandIndex().getValue() != MATMUL_RESULT) { + if (lhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_LHS || + rhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RHS || + resultEncoding.getOperandIndex().getValue() != + IREE::Encoding::MATMUL_RESULT) { return failure(); } @@ -415,7 +410,7 @@ lowerOpWithEncoding(RewriterBase &rewriter, tensor::EmptyOp emptyOp, loc, emptyOp.getMixedSizes(), resultType.getElementType()); return newEmptyOp; } - if (isNarrowNResult(getEncodingAttr(emptyType))) { + if (isNarrowNResult(IREE::Encoding::getEncodingAttr(emptyType))) { transposeInPlace(*materializeEncodingInfo); } FailureOr> innerTileSizesOfr = @@ -524,7 +519,7 @@ static FailureOr> getPackedDimsForDispatchTensor( if (failed(encodingInfo)) { return failure(); } - if (isNarrowNResult(getEncodingAttr(boundTensorType))) { + if (isNarrowNResult(IREE::Encoding::getEncodingAttr(boundTensorType))) { transposeInPlace(*encodingInfo); } @@ -731,12 +726,12 @@ struct MaterializeFlowDispatchTensorStoreOp /// Convert `set_encoding` op to `pack` op. struct SetEncodingOpToPackOpConversion - : public OpMaterializeEncodingPattern { + : public OpMaterializeEncodingPattern { using OpMaterializeEncodingPattern< - SetEncodingOp>::OpMaterializeEncodingPattern; + IREE::Encoding::SetEncodingOp>::OpMaterializeEncodingPattern; LogicalResult - matchAndRewrite(SetEncodingOp encodingOp, OpAdaptor adaptor, + matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto converter = static_cast( getTypeConverter()); @@ -763,12 +758,12 @@ struct SetEncodingOpToPackOpConversion /// Convert `unset_encoding` op to `unpack` op. struct UnsetEncodingOpToUnPackOpConversion - : public OpMaterializeEncodingPattern { + : public OpMaterializeEncodingPattern { using OpMaterializeEncodingPattern< - UnsetEncodingOp>::OpMaterializeEncodingPattern; + IREE::Encoding::UnsetEncodingOp>::OpMaterializeEncodingPattern; LogicalResult - matchAndRewrite(UnsetEncodingOp encodingOp, OpAdaptor adaptor, + matchAndRewrite(IREE::Encoding::UnsetEncodingOp encodingOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto converter = static_cast( this->getTypeConverter()); @@ -797,14 +792,15 @@ struct UnsetEncodingOpToUnPackOpConversion /// `materializeEncodingFn` returns a failure, the pattern will materialize it /// to the same shape. struct UpperBoundTileSizeToConstantOpConversion - : public OpRewritePattern { + : public OpRewritePattern { UpperBoundTileSizeToConstantOpConversion( MLIRContext *context, MaterializeEncodingFn materializeEncodingFn) - : OpRewritePattern(context), + : OpRewritePattern(context), materializeEncodingFn(materializeEncodingFn) {} - LogicalResult matchAndRewrite(UpperBoundTileSizeOp upperBoundTileSizeOp, - PatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(IREE::Encoding::UpperBoundTileSizeOp upperBoundTileSizeOp, + PatternRewriter &rewriter) const override { auto constants = lowerUpperBoundTileSizeOpToConstants( rewriter, upperBoundTileSizeOp, materializeEncodingFn); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index f7db82676104..6a48eb7b7710 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -750,7 +750,7 @@ void buildLLVMCPUCodegenConfigurationPassPipelineImpl( // TODO(#13888): This(createExpandF16OpToF32Pass()) pass is being added // way to late and should insted be be done during lowering to LLVM. .addPass(createExpandF16OpToF32Pass) - .addPass([&]() { return createCPUMaterializeEncodingPass(); }) + .addPass(createCPUMaterializeDeviceEncodingPass) // TODO: Remove the following pass the plumb support for // #hal.descriptor_type memory space through the stack. .addPass(createEraseHALDescriptorTypeFromMemRefPass); diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp index 8460588b4c31..f477925c1636 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp @@ -181,39 +181,6 @@ void DeviceTargetAttr::getExecutableTargets( } } -// Returns a list of target devices that may be active for the given -// operation. This will recursively walk parent operations until one with -// the `hal.device.targets` attribute is found. -static SmallVector lookupDeviceTargetAttrs(Operation *op) { - auto attrId = mlir::StringAttr::get(op->getContext(), "hal.device.targets"); - while (op) { - auto targetsAttr = op->getAttrOfType(attrId); - if (targetsAttr) { - SmallVector result; - for (auto targetAttr : targetsAttr) { - result.push_back(llvm::cast(targetAttr)); - } - return result; - } - op = op->getParentOp(); - } - return {}; // No devices found; let caller decide what to do. -} - -// static -SmallVector -DeviceTargetAttr::lookupExecutableTargets(Operation *op) { - SmallVector resultAttrs; - for (auto deviceTargetAttr : lookupDeviceTargetAttrs(op)) { - for (auto executableTargetAttr : deviceTargetAttr.getExecutableTargets()) { - if (!llvm::is_contained(resultAttrs, executableTargetAttr)) { - resultAttrs.push_back(executableTargetAttr); - } - } - } - return resultAttrs; -} - void IREE::HAL::DeviceTargetAttr::printStatusDescription( llvm::raw_ostream &os) const { cast().print(os, /*elideType=*/true); diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index c28c7d1a7bcf..92ef8128c841 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td @@ -528,12 +528,6 @@ def HAL_DeviceTargetAttr : AttrDef &resultAttrs); - - // DEPRECATED: analysis is required in order to query this information. - // Returns a list of all target executable configurations that may be - // required for the given operation. - static SmallVector - lookupExecutableTargets(Operation *op); }]; let hasCustomAssemblyFormat = 1; diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index e60d6fe5209a..a58a51fa5ca6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -293,9 +293,7 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, // lowering iree_linalg_ext.upper_bound_tile_size ops that exist on the // host. We should be using stream ops for performing such calculations that // we can attach affinities to and understand what devices are being used. - FunctionLikeNest(passManager).addPass([]() { - return createCPUMaterializeUpperBoundTileSizePass(); - }); + passManager.addPass(createCPUMaterializeUpperBoundTileSizePass()); // Preprocess executables using an external tool. The tool may mutate one or // more variants and even insert or remove variants. diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp index 7d23aaac5ff3..fc25a707e962 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp @@ -32,8 +32,8 @@ namespace mlir::iree_compiler::IREE::VMVX { // Variant configuration // --------------------------------------------------------------------------- -static void -buildVMVXConfigurationPassPipelineImpl(OpPassManager &modulePassManager) { +void buildVMVXConfigurationPassPipeline(OpPassManager &variantPassManager) { + OpPassManager &modulePassManager = variantPassManager.nest(); { FunctionLikeNest funcPassManager(modulePassManager); // --------------------------------------------------------------------------- @@ -43,25 +43,19 @@ buildVMVXConfigurationPassPipelineImpl(OpPassManager &modulePassManager) { } modulePassManager.addPass(createMaterializeUserConfigsPass()); FunctionLikeNest(modulePassManager) - .addPass([&]() { return createCPUMaterializeEncodingPass(); }) + .addPass(createCPUMaterializeDeviceEncodingPass) // TODO: Remove the following pass the plumb support for // #hal.descriptor_type memory space through the stack. .addPass(createEraseHALDescriptorTypeFromMemRefPass); modulePassManager.addPass(createVMVXSelectLoweringStrategyPass()); } -void buildVMVXConfigurationPassPipeline(OpPassManager &variantPassManager) { - OpPassManager &modulePassManager = variantPassManager.nest(); - buildVMVXConfigurationPassPipelineImpl(modulePassManager); -} - // --------------------------------------------------------------------------- // Variant Translation // --------------------------------------------------------------------------- static void buildVectorVMVXTransformPassPipeline(OpPassManager &variantPassManager) { - OpPassManager &modulePassManager = variantPassManager.nest(); // --------------------------------------------------------------------------- // Tensor-level optimization, kernel dispatch and lower to buffers. diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel index 8b4d450d2962..46c4988e4fed 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel +++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel @@ -83,6 +83,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow", "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/Flow/Transforms", + "//compiler/src/iree/compiler/Dialect/HAL/Analysis", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt index 4e29821594db..9410270d374a 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt +++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt @@ -98,6 +98,7 @@ iree_cc_library( iree::compiler::Dialect::Flow::Conversion::TensorToFlow iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::Flow::Transforms + iree::compiler::Dialect::HAL::Analysis iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::LinalgExt::IR diff --git a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp index bf6286954811..143b9694967f 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Codegen/Common/CPU/Passes.h" #include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/GlobalOptimization/PassDetail.h" @@ -46,11 +47,16 @@ class MaterializeHomogeneousEncodingsPass void runOnOperation() override { auto moduleOp = getOperation(); - auto executableTargets = - IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(moduleOp); + IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp); + if (failed(deviceAnalysis.run())) + return signalPassFailure(); + + SetVector executableTargets; + deviceAnalysis.gatherAllExecutableTargets(executableTargets); if (executableTargets.size() != 1) { return runNopPipeline(moduleOp); } + // TODO: vmvx has its own logic about supporting dynamic tile // sizes. It is not fully integrated into the pipeline, so we remain the // materialization to the end. @@ -65,12 +71,8 @@ class MaterializeHomogeneousEncodingsPass } OpPassManager passManager(moduleOp.getOperationName()); - FunctionLikeNest(passManager).addPass([&]() { - return createCPUMaterializeUpperBoundTileSizePass(executableTargets); - }); - FunctionLikeNest(passManager).addPass([&]() { - return createCPUMaterializeEncodingPass(executableTarget); - }); + passManager.addPass(createCPUMaterializeUpperBoundTileSizePass()); + passManager.addPass(createCPUMaterializeHostEncodingPass()); if (failed(runPipeline(passManager, moduleOp))) { return signalPassFailure(); } diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index a48188885592..81e204e18fdb 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -54,6 +54,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Utils", "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/Flow/Transforms", + "//compiler/src/iree/compiler/Dialect/HAL/Analysis", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", "//compiler/src/iree/compiler/Dialect/Stream/IR", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index 1bc4c1e85972..723dacc1595b 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -65,6 +65,7 @@ iree_cc_library( iree::compiler::Codegen::Utils iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::Flow::Transforms + iree::compiler::Dialect::HAL::Analysis iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::LinalgExt::IR iree::compiler::Dialect::Stream::IR diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp index 721522afaef5..a5e1a86f89c7 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp @@ -5,9 +5,13 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include +#include + #include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" +#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" #include "iree/compiler/Preprocessing/Common/Passes.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -141,10 +145,8 @@ expandMapsAndIterators(SmallVector &expandedMaps, } static SmallVector -getIntrinsics(linalg::LinalgOp linalgOp) { - SmallVector executableTargets = - IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(linalgOp); - +getIntrinsics(linalg::LinalgOp linalgOp, + ArrayRef executableTargets) { IREE::GPU::TargetAttr target; if (executableTargets.size() == 1) { auto targetAttr = executableTargets.front(); @@ -165,7 +167,9 @@ getIntrinsics(linalg::LinalgOp linalgOp) { }); } -static void padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp) { +static void +padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, + ArrayRef executableTargets) { if (!isa(*linalgOp)) { return; } @@ -174,7 +178,8 @@ static void padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp) { return; // Early exit if cannot find intrinsics or if multiple executable targets. - SmallVector intrinsics = getIntrinsics(linalgOp); + SmallVector intrinsics = + getIntrinsics(linalgOp, executableTargets); if (intrinsics.empty()) return; @@ -304,8 +309,9 @@ static void padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp) { rewriter.replaceOp(linalgOp, extracted); } -static void padContractionLikeOp(RewriterBase &rewriter, - linalg::LinalgOp linalgOp) { +static void padContractionLikeOp( + RewriterBase &rewriter, linalg::LinalgOp linalgOp, + ArrayRef executableTargets) { FailureOr contractionDims = mlir::linalg::inferContractionDims(linalgOp); @@ -319,7 +325,8 @@ static void padContractionLikeOp(RewriterBase &rewriter, } // Early exit if cannot find intrinsics or if multiple executable targets. - SmallVector intrinsics = getIntrinsics(linalgOp); + SmallVector intrinsics = + getIntrinsics(linalgOp, executableTargets); if (intrinsics.empty()) return; @@ -536,7 +543,12 @@ struct PadToIntrinsicsPass void PadToIntrinsicsPass::runOnOperation() { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); + auto funcOp = getOperation(); + IREE::HAL::DeviceAnalysis deviceAnalysis(funcOp->getParentOp()); + if (failed(deviceAnalysis.run())) + return signalPassFailure(); + bool padConvOps = padTargetType == PadTargetType::ConvOp || padTargetType == PadTargetType::All; bool padContractionOps = padTargetType == PadTargetType::ContractionOp || @@ -564,11 +576,16 @@ void PadToIntrinsicsPass::runOnOperation() { IRRewriter rewriter(context); for (auto convOp : targetConvOps) { rewriter.setInsertionPoint(convOp); - padConvOp(rewriter, convOp); + SetVector executableTargets; + deviceAnalysis.gatherRequiredExecutableTargets(convOp, executableTargets); + padConvOp(rewriter, convOp, executableTargets.getArrayRef()); } for (auto contractOp : targetContractOps) { rewriter.setInsertionPoint(contractOp); - padContractionLikeOp(rewriter, contractOp); + SetVector executableTargets; + deviceAnalysis.gatherRequiredExecutableTargets(contractOp, + executableTargets); + padContractionLikeOp(rewriter, contractOp, executableTargets.getArrayRef()); } } diff --git a/docs/website/docs/community/blog/posts/microkernels.md b/docs/website/docs/community/blog/posts/microkernels.md index f56b71303db9..2cb59280b811 100644 --- a/docs/website/docs/community/blog/posts/microkernels.md +++ b/docs/website/docs/community/blog/posts/microkernels.md @@ -357,10 +357,10 @@ module attributes {hal.device.targets = [#device_target_llvm_cpu]} { } ``` -### IR Dump After CPUMaterializeEncoding +### IR Dump After CPUMaterializeHostEncoding ```mlir -// -----// IR Dump After CPUMaterializeEncoding (iree-codegen-cpu-materialize-encoding) //----- // +// -----// IR Dump After CPUMaterializeHostEncoding (iree-codegen-cpu-materialize-host-encoding) //----- // [...] // -----// IR Dump After Canonicalizer (canonicalize) //----- // [...] From 2a202714b2848b2abb5508dcd3783fda81a3c95a Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Fri, 8 Mar 2024 09:26:07 -0800 Subject: [PATCH 10/25] Disabling DumpExecutableBenchmarks when multiple devices are present. I think we can generate one benchmark per device and only include dispatches used on that device but for now that is left as follow-on work. --- .../Transforms/DumpExecutableBenchmarks.cpp | 49 ++- .../test/dump_executable_benchmarks.mlir | 389 ++++++++++++------ 2 files changed, 295 insertions(+), 143 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp index c2487b84ca2b..6e5f110b68f3 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp @@ -7,6 +7,7 @@ #include #include +#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" @@ -74,7 +75,14 @@ static DispatchParamsMap gatherDispatchParams(mlir::ModuleOp moduleOp, for (auto funcOp : moduleOp.getOps()) { funcOp.walk([&](IREE::Stream::CmdDispatchOp dispatchOp) { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(dispatchOp); + auto affinityAttr = dyn_cast_if_present( + IREE::Stream::AffinityAttr::lookup(dispatchOp)); + if (!affinityAttr) { + LLVM_DEBUG( + llvm::dbgs() + << "skipping dispatch because it has no affinity specified\n"); + return; + } auto workloadValues = dispatchOp.getWorkload(); SmallVector workload; @@ -84,7 +92,7 @@ static DispatchParamsMap gatherDispatchParams(mlir::ModuleOp moduleOp, if (!matchPattern(workloadValue, m_ConstantInt(&workloadConstValue))) { LLVM_DEBUG({ auto firstEntryPoint = *dispatchOp.getEntryPointRefs().begin(); - llvm::dbgs() << "Skipping dispatch of entry point `" + llvm::dbgs() << "skipping dispatch of entry point `" << firstEntryPoint << "` (non-constant workload)\n"; }); return; @@ -123,7 +131,7 @@ static DispatchParamsMap gatherDispatchParams(mlir::ModuleOp moduleOp, APInt resourceLengthInt; if (!matchPattern(resourceLength, m_ConstantInt(&resourceLengthInt))) { - LLVM_DEBUG(llvm::dbgs() << "Skipping dispatch of entry point `" + LLVM_DEBUG(llvm::dbgs() << "skipping dispatch of entry point `" << entryPointAttr << "` (non-constant resource length)\n";); return; @@ -402,19 +410,21 @@ static void appendDispatchBenchmark(IREE::Stream::AffinityAttr affinityAttr, static mlir::OwningOpRef buildBenchmarkModule(IREE::HAL::ExecutableOp sourceExecutableOp, IREE::HAL::ExecutableVariantOp sourceVariantOp, - const DispatchParamsMap &dispatchParamsMap) { + const DispatchParamsMap &dispatchParamsMap, + DeviceAnalysis &deviceAnalysis) { // Empty module with default name. // We could use the original module name here to make tracking nicer. mlir::OwningOpRef moduleOp = mlir::ModuleOp::create(sourceExecutableOp.getLoc()); auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp->getBody()); - // Copy over the device targets from the original module. - // TODO(benvanik): filter this by the target of the variant. - moduleOp->getOperation()->setAttr( - "hal.device.targets", - sourceExecutableOp->getParentOfType()->getAttr( - "hal.device.targets")); + // Copy over the devices from the original module. Note that not all of the + // devices may be used and we should prune them, but even better than that + // would be to generate one module per device dispatches are made on such + // that users can isolate to individual devices. For now we just deal with + // it. + for (auto globalOp : deviceAnalysis.getDeviceGlobals()) + moduleBuilder.clone(*globalOp.getOperation()); // Clone the executable variant into the new module. auto executableOp = moduleBuilder.create( @@ -489,6 +499,21 @@ struct DumpExecutableBenchmarksPass auto moduleName = moduleOp.getName().value_or("module"); SymbolTable symbolTable(moduleOp); + DeviceAnalysis deviceAnalysis(moduleOp); + if (failed(deviceAnalysis.run())) + return signalPassFailure(); + if (deviceAnalysis.getDeviceGlobals().empty()) { + mlir::emitRemark(moduleOp.getLoc()) + << "Executable benchmarks were requested but no devices were " + "declared in the module.\n"; + return; + } else if (deviceAnalysis.getDeviceGlobals().size() != 1) { + mlir::emitWarning(moduleOp.getLoc()) + << "Executable benchmarks were requested but there are multiple " + "devices in the module and the pass does not support that yet.\n"; + return; + } + // Analyze the module to find dispatch parameters. // This is a full walk of all stream.cmd.dispatch ops and will handle // filtering out dispatches that have dynamic parameters we don't @@ -511,8 +536,8 @@ struct DumpExecutableBenchmarksPass for (auto executableOp : moduleOp.getOps()) { for (auto variantOp : executableOp.getOps()) { - auto benchmarkModuleOp = - buildBenchmarkModule(executableOp, variantOp, dispatchParamsMap); + auto benchmarkModuleOp = buildBenchmarkModule( + executableOp, variantOp, dispatchParamsMap, deviceAnalysis); if (!benchmarkModuleOp) continue; auto fileName = (moduleName + "_" + executableOp.getName() + "_" + diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir index b42b38f5df01..bd6d630a4293 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir @@ -1,12 +1,15 @@ -// RUN: iree-opt --split-input-file --iree-hal-dump-executable-benchmarks %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-hal-dump-executable-benchmarks %s --verify-diagnostics | FileCheck %s // Tests dumping executable benchmarks to stdout - it's more common to use files // but this is much easier to test with lit. +// Ensure devices are copied and made available: #executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> -#device_target_cpu = #hal.device.target<"llvm-cpu", [ +// CHECK: util.global private @device +util.global private @device = #hal.device.target<"llvm-cpu", [ #executable_target_embedded_elf_x86_64 -]> +]> : !hal.device + #pipeline_layout_0 = #hal.pipeline.layout, @@ -21,144 +24,268 @@ ]> ]> -module attributes {hal.device.targets = [#device_target_cpu]} { - - // Executable should be dumped: - // CHECK: hal.executable private @ex0 - hal.executable private @ex0 { - hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) { - hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout_0) attributes { - translation_info = #iree_codegen.translation_info - } { - ^bb0(%device: !hal.device, %arg0: index): - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 - hal.return %x, %y, %z : index, index, index - } - builtin.module { - func.func @dispatch0() { - func.return - } +// Executable should be dumped: +// CHECK: hal.executable private @ex0 +hal.executable private @ex0 { + hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) { + hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout_0) attributes { + translation_info = #iree_codegen.translation_info + } { + ^bb0(%device: !hal.device, %arg0: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @dispatch0() { + func.return } + } - hal.executable.export public @dispatch1 ordinal(1) layout(#pipeline_layout_1) attributes { - translation_info = #iree_codegen.translation_info - } { - ^bb0(%device: !hal.device, %arg0: index, %arg1: index): - %c1 = arith.constant 1 : index - %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0] - %1 = arith.addi %0, %arg1 : index - hal.return %1, %c1, %c1 : index, index, index - } - builtin.module { - func.func @dispatch1() { - func.return - } + hal.executable.export public @dispatch1 ordinal(1) layout(#pipeline_layout_1) attributes { + translation_info = #iree_codegen.translation_info + } { + ^bb0(%device: !hal.device, %arg0: index, %arg1: index): + %c1 = arith.constant 1 : index + %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0] + %1 = arith.addi %0, %arg1 : index + hal.return %1, %c1, %c1 : index, index, index + } + builtin.module { + func.func @dispatch1() { + func.return } } } +} + +// =========================================================================== +// @dispatch0 benchmark logic: +// =========================================================================== + +// CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch0_512_buffer : !hal.buffer +// CHECK-NEXT: util.initializer { +// CHECK: %[[BUFFER:.+]] = hal.allocator.allocate<%{{.+}} : !hal.allocator> affinity(%{{.+}}) type("DeviceVisible|DeviceLocal") usage("{{.+}}Dispatch{{.+}}") : !hal.buffer{%c768} +// CHECK-NEXT: util.global.store %[[BUFFER]], @ex0_embedded_elf_x86_64_dispatch0_512_buffer : !hal.buffer + +// CHECK: util.func public @ex0_embedded_elf_x86_64_dispatch0_512(%arg0: i32) +// CHECK-SAME: attributes {iree.abi.stub, iree.reflection = {iree.benchmark = "dispatch"}} { +// CHECK: %[[BATCH_SIZE:.+]] = arith.index_cast %arg0 : i32 to index + +// Create command buffer: +// CHECK: %[[CMD:.+]] = hal.command_buffer.create + +// Setup dispatch constants and bindings: +// CHECK: hal.command_buffer.push_constants<%[[CMD]] : !hal.command_buffer> layout(%{{.+}} : !hal.pipeline_layout) offset(0) values([%c100_i32, %c200_i32]) : i32, i32 +// CHECK: %[[BUFFER:.+]] = util.global.load @ex0_embedded_elf_x86_64_dispatch0_512_buffer +// CHECK: hal.command_buffer.push_descriptor_set<%[[CMD]] : !hal.command_buffer> layout(%{{.+}} : !hal.pipeline_layout)[%c0] bindings([ +// CHECK-NEXT: %c0 = (%[[BUFFER]] : !hal.buffer)[%c0, %c32], +// CHECK-NEXT: %c1 = (%[[BUFFER]] : !hal.buffer)[%c256, %c32], +// CHECK-NEXT: %c2 = (%[[BUFFER]] : !hal.buffer)[%c512, %c32] +// CHECK-NEXT: ]) + +// Calculate the workgroup count, which we leave symbolic until after +// translation: +// CHECK: %[[WORKGROUP_X:.+]], %[[WORKGROUP_Y:.+]], %[[WORKGROUP_Z:.+]] = +// CHECK-SAME: hal.executable.calculate_workgroups +// CHECK-SAME: target(@ex0::@embedded_elf_x86_64::@dispatch0) +// CHECK-SAME: workload([%c512]) + +// Get executable and target ordinal (outside of the loop). +// CHECK-DAG: %[[EXECUTABLE:.+]] = hal.executable.lookup device({{.+}}) executable(@ex0) : !hal.executable +// CHECK-DAG: %[[ORDINAL_0:.+]] = hal.executable.export.ordinal target(@ex0::@embedded_elf_x86_64::@dispatch0) : index + +// Dispatch up to batch size dispatches: +// CHECK: scf.for %{{.+}} = %c0 to %[[BATCH_SIZE]] step %c1 { +// CHECK-NEXT: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> target(%[[EXECUTABLE:.+]] : !hal.executable)[%[[ORDINAL_0]]] workgroups([%[[WORKGROUP_X]], %[[WORKGROUP_Y]], %[[WORKGROUP_Z]]]) +// CHECK-NEXT: hal.command_buffer.execution_barrier +// CHECK-NEXT: } + +// Submit and wait for dispatches to complete: +// CHECK: hal.command_buffer.finalize<%[[CMD]] : !hal.command_buffer> +// CHECK: hal.fence.await + +// =========================================================================== +// @dispatch1 benchmark logic (note two deduplicated dispatches): +// =========================================================================== + +// CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch1_512x1_buffer : !hal.buffer +// CHECK: util.func public @ex0_embedded_elf_x86_64_dispatch1_512x1(%arg0: i32) +// CHECK: %[[ORDINAL_1A:.+]] = hal.executable.export.ordinal target(@ex0::@embedded_elf_x86_64::@dispatch1) : index +// CHECK: hal.command_buffer.dispatch<%{{.+}} : !hal.command_buffer> target({{.+}})[%[[ORDINAL_1A]]] + +// CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch1_128x32_buffer : !hal.buffer +// CHECK: util.func public @ex0_embedded_elf_x86_64_dispatch1_128x32(%arg0: i32) +// CHECK: %[[ORDINAL_1B:.+]] = hal.executable.export.ordinal target(@ex0::@embedded_elf_x86_64::@dispatch1) : index +// CHECK: hal.command_buffer.dispatch<%{{.+}} : !hal.command_buffer> target({{.+}})[%[[ORDINAL_1B]]] + +util.func public @main(%dynamic_arg: i32) -> !stream.timepoint attributes { + stream.affinity = #hal.device.affinity<@device> +} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %c100_i32 = arith.constant 100 : i32 + %c200_i32 = arith.constant 200 : i32 + %c300_i32 = arith.constant 300 : i32 + %result, %result_timepoint = stream.resource.alloca uninitialized : !stream.resource{%c128} => !stream.timepoint + %6 = stream.cmd.execute await(%result_timepoint) => with(%result as %result_capture: !stream.resource{%c128}) { + // Dispatches with static and dynamic args. + stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch0[%c512](%c100_i32, %c200_i32 : i32, i32) { + ro %result_capture[%c0 for %c32] : !stream.resource{%c128}, + rw %result_capture[%c32 for %c32] : !stream.resource{%c128}, + rw %result_capture[%c64 for %c32] : !stream.resource{%c128} + } + // NOTE: today the dynamic args will prevent us from generating + // benchmarks. We could handle this better by tracking alignment and such. + stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch0[%c512](%c300_i32, %dynamic_arg : i32, i32) { + ro %result_capture[%c0 for %c32] : !stream.resource{%c128}, + rw %result_capture[%c32 for %c32] : !stream.resource{%c128}, + rw %result_capture[%c64 for %c32] : !stream.resource{%c128} + } + + // Multiple dispatches to a single entry point. + // Dispatches are deduplicated and the two 128x32x1 should combine. + stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c512, %c1] { + ro %result_capture[%c0 for %c64] : !stream.resource{%c128}, + rw %result_capture[%c64 for %c32] : !stream.resource{%c128} + } + stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c128, %c32] { + ro %result_capture[%c0 for %c64] : !stream.resource{%c128}, + rw %result_capture[%c64 for %c32] : !stream.resource{%c128} + } + stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c128, %c32] { + ro %result_capture[%c0 for %c64] : !stream.resource{%c128}, + rw %result_capture[%c64 for %c32] : !stream.resource{%c128} + } + } => !stream.timepoint + %39 = stream.resource.dealloca await(%6) => %result : !stream.resource{%c128} => !stream.timepoint + util.return %39 : !stream.timepoint +} + +// ----- +// expected-warning@-2 {{multiple devices in the module}} + +// Tests that multiple devices fail today. +// We should be creating one benchmark per executable with only the dispatches +// used by that executable. - // =========================================================================== - // @dispatch0 benchmark logic: - // =========================================================================== - - // CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch0_512_buffer : !hal.buffer - // CHECK-NEXT: util.initializer { - // CHECK: %[[BUFFER:.+]] = hal.allocator.allocate<%{{.+}} : !hal.allocator> affinity(%{{.+}}) type("DeviceVisible|DeviceLocal") usage("{{.+}}Dispatch{{.+}}") : !hal.buffer{%c768} - // CHECK-NEXT: util.global.store %[[BUFFER]], @ex0_embedded_elf_x86_64_dispatch0_512_buffer : !hal.buffer - - // CHECK: util.func public @ex0_embedded_elf_x86_64_dispatch0_512(%arg0: i32) - // CHECK-SAME: attributes {iree.abi.stub, iree.reflection = {iree.benchmark = "dispatch"}} { - // CHECK: %[[BATCH_SIZE:.+]] = arith.index_cast %arg0 : i32 to index - - // Create command buffer: - // CHECK: %[[CMD:.+]] = hal.command_buffer.create - - // Setup dispatch constants and bindings: - // CHECK: hal.command_buffer.push_constants<%[[CMD]] : !hal.command_buffer> layout(%{{.+}} : !hal.pipeline_layout) offset(0) values([%c100_i32, %c200_i32]) : i32, i32 - // CHECK: %[[BUFFER:.+]] = util.global.load @ex0_embedded_elf_x86_64_dispatch0_512_buffer - // CHECK: hal.command_buffer.push_descriptor_set<%[[CMD]] : !hal.command_buffer> layout(%{{.+}} : !hal.pipeline_layout)[%c0] bindings([ - // CHECK-NEXT: %c0 = (%[[BUFFER]] : !hal.buffer)[%c0, %c32], - // CHECK-NEXT: %c1 = (%[[BUFFER]] : !hal.buffer)[%c256, %c32], - // CHECK-NEXT: %c2 = (%[[BUFFER]] : !hal.buffer)[%c512, %c32] - // CHECK-NEXT: ]) - - // Calculate the workgroup count, which we leave symbolic until after - // translation: - // CHECK: %[[WORKGROUP_X:.+]], %[[WORKGROUP_Y:.+]], %[[WORKGROUP_Z:.+]] = - // CHECK-SAME: hal.executable.calculate_workgroups - // CHECK-SAME: target(@ex0::@embedded_elf_x86_64::@dispatch0) - // CHECK-SAME: workload([%c512]) - - // Get executable and target ordinal (outside of the loop). - // CHECK-DAG: %[[EXECUTABLE:.+]] = hal.executable.lookup device({{.+}}) executable(@ex0) : !hal.executable - // CHECK-DAG: %[[ORDINAL_0:.+]] = hal.executable.export.ordinal target(@ex0::@embedded_elf_x86_64::@dispatch0) : index - - // Dispatch up to batch size dispatches: - // CHECK: scf.for %{{.+}} = %c0 to %[[BATCH_SIZE]] step %c1 { - // CHECK-NEXT: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> target(%[[EXECUTABLE:.+]] : !hal.executable)[%[[ORDINAL_0]]] workgroups([%[[WORKGROUP_X]], %[[WORKGROUP_Y]], %[[WORKGROUP_Z]]]) - // CHECK-NEXT: hal.command_buffer.execution_barrier - // CHECK-NEXT: } - - // Submit and wait for dispatches to complete: - // CHECK: hal.command_buffer.finalize<%[[CMD]] : !hal.command_buffer> - // CHECK: hal.fence.await - - // =========================================================================== - // @dispatch1 benchmark logic (note two deduplicated dispatches): - // =========================================================================== - - // CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch1_512x1_buffer : !hal.buffer - // CHECK: util.func public @ex0_embedded_elf_x86_64_dispatch1_512x1(%arg0: i32) - // CHECK: %[[ORDINAL_1A:.+]] = hal.executable.export.ordinal target(@ex0::@embedded_elf_x86_64::@dispatch1) : index - // CHECK: hal.command_buffer.dispatch<%{{.+}} : !hal.command_buffer> target({{.+}})[%[[ORDINAL_1A]]] - - // CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch1_128x32_buffer : !hal.buffer - // CHECK: util.func public @ex0_embedded_elf_x86_64_dispatch1_128x32(%arg0: i32) - // CHECK: %[[ORDINAL_1B:.+]] = hal.executable.export.ordinal target(@ex0::@embedded_elf_x86_64::@dispatch1) : index - // CHECK: hal.command_buffer.dispatch<%{{.+}} : !hal.command_buffer> target({{.+}})[%[[ORDINAL_1B]]] - - util.func public @main(%dynamic_arg: i32) -> !stream.timepoint { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %c100_i32 = arith.constant 100 : i32 - %c200_i32 = arith.constant 200 : i32 - %c300_i32 = arith.constant 300 : i32 - %result, %result_timepoint = stream.resource.alloca uninitialized : !stream.resource{%c128} => !stream.timepoint - %6 = stream.cmd.execute await(%result_timepoint) => with(%result as %result_capture: !stream.resource{%c128}) { - // Dispatches with static and dynamic args. - stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch0[%c512](%c100_i32, %c200_i32 : i32, i32) { - ro %result_capture[%c0 for %c32] : !stream.resource{%c128}, - rw %result_capture[%c32 for %c32] : !stream.resource{%c128}, - rw %result_capture[%c64 for %c32] : !stream.resource{%c128} +#executable_target_embedded_elf_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64"> +#executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> +util.global private @device_a = #hal.device.target<"llvm-cpu", [ + #executable_target_embedded_elf_aarch64 +]> : !hal.device +util.global private @device_b = #hal.device.target<"llvm-cpu", [ + #executable_target_embedded_elf_x86_64 +]> : !hal.device + +#pipeline_layout = #hal.pipeline.layout + ]> +]> + +hal.executable private @ex_0 { + hal.executable.variant public @variant_a target(#executable_target_embedded_elf_aarch64) { + hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout) attributes { + translation_info = #iree_codegen.translation_info + } { + ^bb0(%device: !hal.device, %arg0: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @dispatch0() { + func.return } - // NOTE: today the dynamic args will prevent us from generating - // benchmarks. We could handle this better by tracking alignment and such. - stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch0[%c512](%c300_i32, %dynamic_arg : i32, i32) { - ro %result_capture[%c0 for %c32] : !stream.resource{%c128}, - rw %result_capture[%c32 for %c32] : !stream.resource{%c128}, - rw %result_capture[%c64 for %c32] : !stream.resource{%c128} + } + hal.executable.export public @dispatch1 ordinal(1) layout(#pipeline_layout) attributes { + translation_info = #iree_codegen.translation_info + } { + ^bb0(%device: !hal.device, %arg0: index, %arg1: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @dispatch1() { + func.return } - - // Multiple dispatches to a single entry point. - // Dispatches are deduplicated and the two 128x32x1 should combine. - stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c512, %c1] { - ro %result_capture[%c0 for %c64] : !stream.resource{%c128}, - rw %result_capture[%c64 for %c32] : !stream.resource{%c128} + } + } + hal.executable.variant public @variant_b target(#executable_target_embedded_elf_x86_64) { + hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout) attributes { + translation_info = #iree_codegen.translation_info + } { + ^bb0(%device: !hal.device, %arg0: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @dispatch0() { + func.return } - stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c128, %c32] { - ro %result_capture[%c0 for %c64] : !stream.resource{%c128}, - rw %result_capture[%c64 for %c32] : !stream.resource{%c128} + } + hal.executable.export public @dispatch1 ordinal(1) layout(#pipeline_layout) attributes { + translation_info = #iree_codegen.translation_info + } { + ^bb0(%device: !hal.device, %arg0: index, %arg1: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @dispatch1() { + func.return } - stream.cmd.dispatch @ex0::@embedded_elf_x86_64::@dispatch1[%c128, %c32] { - ro %result_capture[%c0 for %c64] : !stream.resource{%c128}, - rw %result_capture[%c64 for %c32] : !stream.resource{%c128} + } + } +} +hal.executable private @ex_1 { + hal.executable.variant public @variant_b target(#executable_target_embedded_elf_x86_64) { + hal.executable.export public @dispatch0 ordinal(0) layout(#pipeline_layout) attributes { + translation_info = #iree_codegen.translation_info + } { + ^bb0(%device: !hal.device, %arg0: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @dispatch0() { + func.return } - } => !stream.timepoint - %39 = stream.resource.dealloca await(%6) => %result : !stream.resource{%c128} => !stream.timepoint - util.return %39 : !stream.timepoint + } } } + +util.func public @main(%resource_a_arg: !stream.resource, %resource_b_arg: !stream.resource) -> (!stream.timepoint, !stream.timepoint) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %tp_a = stream.cmd.execute on(#hal.device.affinity<@device_a>) with(%resource_a_arg as %resource_a: !stream.resource{%c128}) { + stream.cmd.dispatch @ex_0::@variant_a::@dispatch0[%c512] { + rw %resource_a[%c0 for %c32] : !stream.resource{%c128} + } + stream.cmd.dispatch @ex_0::@variant_a::@dispatch1[%c512] { + rw %resource_a[%c0 for %c64] : !stream.resource{%c128} + } + stream.cmd.dispatch @ex_0::@variant_a::@dispatch1[%c128] { + rw %resource_a[%c0 for %c64] : !stream.resource{%c128} + } + } => !stream.timepoint + %tp_b = stream.cmd.execute on(#hal.device.affinity<@device_b>) with(%resource_b_arg as %resource_b: !stream.resource{%c128}) { + stream.cmd.dispatch @ex_0::@variant_a::@dispatch0[%c512] { + rw %resource_b[%c0 for %c32] : !stream.resource{%c128} + } + stream.cmd.dispatch @ex_0::@variant_a::@dispatch1[%c512] { + rw %resource_b[%c0 for %c64] : !stream.resource{%c128} + } + stream.cmd.dispatch @ex_0::@variant_b::@dispatch0[%c128] { + rw %resource_b[%c0 for %c64] : !stream.resource{%c128} + } + } => !stream.timepoint + util.return %tp_a, %tp_b : !stream.timepoint, !stream.timepoint +} From d895934256f6a9b3b03cb413a9c4925c9b39ecc1 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 28 Mar 2024 15:51:13 -0700 Subject: [PATCH 11/25] Adding VerifyAffinitiesPass to ensure all ops have some affinity. --- .../Dialect/Stream/Transforms/BUILD.bazel | 1 + .../Dialect/Stream/Transforms/CMakeLists.txt | 1 + .../Dialect/Stream/Transforms/Passes.cpp | 5 ++ .../Dialect/Stream/Transforms/Passes.td | 5 ++ .../Stream/Transforms/VerifyAffinities.cpp | 67 +++++++++++++++++++ .../Stream/Transforms/test/BUILD.bazel | 1 + .../Stream/Transforms/test/CMakeLists.txt | 1 + .../Transforms/test/verify_affinities.mlir | 33 +++++++++ 8 files changed, 114 insertions(+) create mode 100644 compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp create mode 100644 compiler/src/iree/compiler/Dialect/Stream/Transforms/test/verify_affinities.mlir diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel index 6471943fdf2e..1a1d2e2dd5b1 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel @@ -37,6 +37,7 @@ iree_compiler_cc_library( "ScheduleConcurrency.cpp", "ScheduleExecution.cpp", "SpecializeDispatches.cpp", + "VerifyAffinities.cpp", "VerifyAsyncAccessRanges.cpp", "VerifyLowerings.cpp", ], diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt index 4f1a114a2024..9d78c8ed9ef9 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt @@ -38,6 +38,7 @@ iree_cc_library( "ScheduleConcurrency.cpp" "ScheduleExecution.cpp" "SpecializeDispatches.cpp" + "VerifyAffinities.cpp" "VerifyAsyncAccessRanges.cpp" "VerifyLowerings.cpp" DEPS diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp index a0861e7c7fd6..b99b79279239 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp @@ -95,6 +95,11 @@ void buildStreamTensorPassPipeline(OpPassManager &passManager, // TODO(benvanik): compute affinities for executables. // TODO(benvanik): annotate all dispatches with preferred executable affinity. // TODO(benvanik): DFA to specify all value affinities and pin dispatches. + + // Verify that all ops that may require affinities have them assigned or + // available (on a parent scope, etc). This allows subsequent passes to trust + // that an affinity lookup will always return a valid affinity. + passManager.addPass(IREE::Stream::createVerifyAffinitiesPass()); } //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td index 83f1d0fb4459..ca2ec3a5b61d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td @@ -486,6 +486,11 @@ def VerifyInputPass : let summary = "Verifies that input dialects are supported by the streams dialect."; } +def VerifyAffinitiesPass : + Pass<"iree-stream-verify-affinities", "mlir::ModuleOp"> { + let summary = "Verifies that all operations have affinities assigned (directly or indirectly)."; +} + def VerifyLoweringToTensorsPass : Pass<"iree-stream-verify-lowering-to-tensors", "mlir::ModuleOp"> { let summary = "Verifies that input dialects are converted to stream.tensor.* ops."; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp new file mode 100644 index 000000000000..042bbb860263 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp @@ -0,0 +1,67 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::Stream { + +#define GEN_PASS_DEF_VERIFYAFFINITIESPASS +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc" + +namespace { + +// Verifies that |op| has an affinity assigned on itself or a parent. +static LogicalResult +verifyAffinityAssigned(IREE::Stream::AffinityOpInterface op) { + if (!op.requiresAffinity()) { + return success(); // does not require an affinity + } else if (IREE::Stream::AffinityAttr::lookup(op)) { + return success(); // has an affinity + } + return op->emitOpError() + << "does not have an affinity assigned; ensure that the op or some " + "ancestor of it has a valid execution affinity assigned"; +} + +//===----------------------------------------------------------------------===// +// --iree-stream-verify-affinities +//===----------------------------------------------------------------------===// + +struct VerifyAffinitiesPass + : public IREE::Stream::impl::VerifyAffinitiesPassBase< + VerifyAffinitiesPass> { + void runOnOperation() override { + auto moduleOp = getOperation(); + if (moduleOp + .walk([&](Operation *op) { + if (isa(op)) { + return WalkResult::advance(); + } + if (auto affinityOp = + dyn_cast(op)) { + if (failed(verifyAffinityAssigned(affinityOp))) { + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }) + .wasInterrupted()) + return signalPassFailure(); + } +}; + +} // namespace + +} // namespace mlir::iree_compiler::IREE::Stream diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel index 1f2104a550a1..524a1ce109a6 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel @@ -43,6 +43,7 @@ iree_lit_test_suite( "schedule_concurrency.mlir", "schedule_execution.mlir", "specialize_dispatches.mlir", + "verify_affinities.mlir", "verify_async_access_ranges.mlir", ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt index 2e2294a00054..5ea981160d91 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt @@ -41,6 +41,7 @@ iree_lit_test_suite( "schedule_concurrency.mlir" "schedule_execution.mlir" "specialize_dispatches.mlir" + "verify_affinities.mlir" "verify_async_access_ranges.mlir" TOOLS FileCheck diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/verify_affinities.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/verify_affinities.mlir new file mode 100644 index 000000000000..ee298109fc85 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/verify_affinities.mlir @@ -0,0 +1,33 @@ +// RUN: iree-opt --iree-stream-verify-affinities --split-input-file %s --verify-diagnostics | FileCheck %s + +// Tests that affinities on ops are checked. + +// CHECK-LABEL: @affinityOnOp +util.func public @affinityOnOp(%size: index) { + // CHECK: stream.async.alloca + %0 = stream.async.alloca on(#hal.device.promise<@device>) : !stream.resource{%size} + util.return +} + +// ----- + +// Tests that affinities on ancestor ops are allowed. + +// CHECK-LABEL: @affinityOnAncestorOp +util.func public @affinityOnAncestorOp(%size: index) attributes { + stream.affinity = #hal.device.promise<@device> +} { + // CHECK: stream.async.alloca + %0 = stream.async.alloca : !stream.resource{%size} + util.return +} + +// ----- + +// Tests that ops with no affinities fail. + +util.func public @missingAffinity(%size: index) { + // expected-error @+1 {{does not have an affinity assigned}} + %0 = stream.async.alloca : !stream.resource{%size} + util.return +} From 324e4ccc286aac534d6d13ebcbf240d96c12e604 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 8 May 2024 13:59:34 -0700 Subject: [PATCH 12/25] Adding `ordinal` support on `#hal.device.target`. This allows for distinguishing multiple devices matching the same requirements such as multiple GPUs on the same node. --- .../iree/compiler/Dialect/HAL/IR/HALAttrs.cpp | 73 ++++++++++++++++--- .../iree/compiler/Dialect/HAL/IR/HALAttrs.td | 6 +- .../Transforms/test/initialize_devices.mlir | 35 ++++++--- 3 files changed, 89 insertions(+), 25 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp index f477925c1636..7bab983b6e07 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp @@ -191,23 +191,37 @@ void IREE::HAL::DeviceTargetAttr::printStatusDescription( // is effectively: // ``` // %device_count = hal.devices.count : index -// %result:2 = scf.while(%i = 0, %device = null) { +// %result:3 = scf.while(%i = 0, %match_ordinal = 0, %device = null) { // %is_null = util.cmp.eq %device, null : !hal.device // %in_bounds = arith.cmpi slt %i, %device_count : index // %continue_while = arith.andi %is_null, %in_bounds : i1 -// scf.condition(%continue_while) %i, %device : index, !hal.device +// scf.condition(%continue_while) %i, %match_ordinal %device +// : index, index, !hal.device // } do { // %device_i = hal.devices.get %i : !hal.device -// %is_match = <>(%device_i) +// %device_match = <>(%device_i) +// %ordinal_match = arith.cmpi eq %match_ordinal, %device_ordinal : index +// %is_match = arith.andi %device_match, %ordinal_match : i1 // %try_device = arith.select %is_match, %device_i, null : !hal.device // %next_i = arith.addi %i, %c1 : index -// scf.yield %next_i, %try_device : index, !hal.device +// %match_adv = arith.select %device_match, %c1, %c0 : index +// %next_match_ordinal = arith.addi %match_ordinal, %match_adv : index +// scf.yield %next_i, %next_match_ordinal, %try_device +// : index, index !hal.device // } // ``` // Upon completion %result#1 contains the device (or null). +// If the target had an ordinal specified we skip matches until a match with the +// specified ordinal is reached. Value IREE::HAL::DeviceTargetAttr::buildDeviceEnumeration( Location loc, const IREE::HAL::TargetRegistry &targetRegistry, OpBuilder &builder) const { + // Device configuration can control selection beyond just the match + // expression. + auto configAttr = getConfiguration(); + IntegerAttr deviceOrdinalAttr = + configAttr ? configAttr.getAs("ordinal") : IntegerAttr{}; + // Defers to the target backend to build the device match or does a simple // fallback for unregistered backends (usually for testing, but may be used // as a way to bypass validation for out-of-tree experiments). @@ -231,28 +245,63 @@ Value IREE::HAL::DeviceTargetAttr::buildDeviceEnumeration( Value c0 = builder.create(loc, 0); Value c1 = builder.create(loc, 1); Value nullDevice = builder.create(loc, deviceType); + Value deviceOrdinal = deviceOrdinalAttr + ? builder.create( + loc, deviceOrdinalAttr.getInt()) + : c0; Value deviceCount = builder.create(loc, indexType); auto whileOp = builder.create( - loc, TypeRange{indexType, deviceType}, ValueRange{c0, nullDevice}, + loc, + TypeRange{ + /*i=*/indexType, + /*match_ordinal=*/indexType, + /*device=*/deviceType, + }, + ValueRange{ + /*i=*/c0, + /*match_ordinal=*/c0, + /*device=*/nullDevice, + }, [&](OpBuilder &beforeBuilder, Location loc, ValueRange operands) { Value isNull = beforeBuilder.create( - loc, operands[1], nullDevice); + loc, operands[/*device=*/2], nullDevice); Value inBounds = beforeBuilder.create( - loc, arith::CmpIPredicate::slt, operands[0], deviceCount); + loc, arith::CmpIPredicate::slt, operands[/*i=*/0], deviceCount); Value continueWhile = beforeBuilder.create(loc, isNull, inBounds); beforeBuilder.create(loc, continueWhile, operands); }, [&](OpBuilder &afterBuilder, Location loc, ValueRange operands) { + // Check whether the device is a match. Value device = afterBuilder.create( - loc, deviceType, operands[0]); - Value isMatch = buildDeviceMatch(loc, device, afterBuilder); + loc, deviceType, operands[/*i=*/0]); + Value isDeviceMatch = buildDeviceMatch(loc, device, afterBuilder); + + // Check whether whether this matching device ordinal is the requested + // ordinal out of all matching devices. + Value isOrdinalMatch = afterBuilder.create( + loc, arith::CmpIPredicate::eq, operands[/*match_ordinal=*/1], + deviceOrdinal); + Value nextMatchOrdinal = afterBuilder.create( + loc, operands[/*match_ordinal=*/1], + afterBuilder.create(loc, isDeviceMatch, c1, c0)); + + // Break if the device and ordinal match, otherwise continue with null. + Value isMatch = afterBuilder.create(loc, isDeviceMatch, + isOrdinalMatch); Value tryDevice = afterBuilder.create( loc, isMatch, device, nullDevice); - Value nextI = afterBuilder.create(loc, operands[0], c1); - afterBuilder.create(loc, ValueRange{nextI, tryDevice}); + + Value nextI = + afterBuilder.create(loc, operands[/*i=*/0], c1); + afterBuilder.create( + loc, ValueRange{ + /*i=*/nextI, + /*match_ordinal=*/nextMatchOrdinal, + /*device=*/tryDevice, + }); }); - return whileOp.getResult(1); + return whileOp.getResult(/*device=*/2); } //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index 92ef8128c841..cc7c523c03c4 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td @@ -492,8 +492,10 @@ def HAL_DeviceTargetAttr : AttrDef : !hal // ----- -// Tests that #hal.device.target<*> enumerates all devices. +// Tests that #hal.device.target<*> enumerates all devices and tries to match +// a particular target with the given ordinal. The ordinal allows for multiple +// devices of the same type to be differentiated. // CHECK: util.global private @device_a : !hal.device -util.global private @device_a = #hal.device.target<"a", [ +util.global private @device_a = #hal.device.target<"a", { + ordinal = 2 : index +}, [ #hal.executable.target<"backend0", "format0">, #hal.executable.target<"backend1", "format1"> ]> : !hal.device @@ -42,19 +46,19 @@ util.global private @device_a = #hal.device.target<"a", [ // CHECK-NEXT: util.initializer // CHECK-DAG: %[[NULL_DEVICE:.+]] = util.null : !hal.device // CHECK-DAG: %[[DEVICE_COUNT:.+]] = hal.devices.count -// CHECK: %[[WHILE:.+]]:2 = scf.while (%arg0 = %c0, %arg1 = %[[NULL_DEVICE]]) -// CHECK-DAG: %[[IS_DEVICE_NULL:.+]] = util.cmp.eq %arg1, %[[NULL_DEVICE]] +// CHECK: %[[WHILE:.+]]:3 = scf.while (%arg0 = %c0, %arg1 = %c0, %arg2 = %[[NULL_DEVICE]]) +// CHECK-DAG: %[[IS_DEVICE_NULL:.+]] = util.cmp.eq %arg2, %[[NULL_DEVICE]] // CHECK-DAG: %[[IS_END:.+]] = arith.cmpi slt, %arg0, %[[DEVICE_COUNT]] // CHECK-DAG: %[[CONTINUE:.+]] = arith.andi %[[IS_DEVICE_NULL]], %[[IS_END]] -// CHECK-NEXT: scf.condition(%[[CONTINUE]]) %arg0, %arg1 +// CHECK-NEXT: scf.condition(%[[CONTINUE]]) %arg0, %arg1, %arg2 // CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%arg0: index, %arg1: !hal.device) +// CHECK-NEXT: ^bb0(%arg0: index, %arg1: index, %arg2: !hal.device) // CHECK-DAG: %[[DEVICE_N:.+]] = hal.devices.get %arg0 : !hal.device // NOTE: this is the fallback path for device matching unregistered targets. // Real targets can have much more complex logic if they so choose. // CHECK-DAG: %{{.+}}, %[[ID_MATCH:.+]] = hal.device.query<%[[DEVICE_N]] : !hal.device> key("hal.device.id" :: "a") -// CHECK-NEXT: %[[ANY_FORMAT_MATCH:.+]] = scf.if %[[ID_MATCH]] -> (i1) { +// CHECK-NEXT: %[[IS_DEVICE_MATCH:.+]] = scf.if %[[ID_MATCH]] -> (i1) { // CHECK-DAG: %{{.+}}, %[[FORMAT0_MATCH:.+]] = hal.device.query<%[[DEVICE_N]] : !hal.device> key("hal.executable.format" :: "format0") // CHECK-DAG: %{{.+}}, %[[FORMAT1_MATCH:.+]] = hal.device.query<%[[DEVICE_N]] : !hal.device> key("hal.executable.format" :: "format1") // CHECK-DAG: %[[FORMAT_MATCH_OR:.+]] = arith.ori %[[FORMAT0_MATCH]], %[[FORMAT1_MATCH]] @@ -62,13 +66,22 @@ util.global private @device_a = #hal.device.target<"a", [ // CHECK-NEXT: } else { // CHECK-DAG: scf.yield %false -// CHECK-DAG: %[[YIELD_DEVICE:.+]] = arith.select %[[ANY_FORMAT_MATCH]], %[[DEVICE_N]], %[[NULL_DEVICE]] +// Check that if the device matches this is the ordinal selected. If not the +// correct ordinal we'll skip it and continue to look for the next. +// CHECK-DAG: %[[IS_ORDINAL_MATCH:.+]] = arith.cmpi eq, %arg1, %c2 +// CHECK-DAG: %[[NEXT_MATCH_ADVANCE:.+]] = arith.select %[[IS_DEVICE_MATCH]], %c1, %c0 +// CHECK-DAG: %[[NEXT_MATCH_ORDINAL:.+]] = arith.addi %arg1, %[[NEXT_MATCH_ADVANCE]] + +// CHECK-DAG: %[[IS_MATCH:.+]] = arith.andi %[[IS_DEVICE_MATCH]], %[[IS_ORDINAL_MATCH]] +// CHECK-DAG: %[[YIELD_DEVICE:.+]] = arith.select %[[IS_MATCH]], %[[DEVICE_N]], %[[NULL_DEVICE]] // CHECK-DAG: %[[NEXT_I:.+]] = arith.addi %arg0, %c1 -// CHECK-NEXT: scf.yield %[[NEXT_I]], %[[YIELD_DEVICE]] -// CHECK-DAG: %[[IS_NULL:.+]] = util.cmp.eq %[[WHILE]]#1, %[[NULL_DEVICE]] +// CHECK-NEXT: scf.yield %[[NEXT_I]], %[[NEXT_MATCH_ORDINAL]], %[[YIELD_DEVICE]] + +// Error out if no device was found because at least one match is required. +// CHECK-DAG: %[[IS_NULL:.+]] = util.cmp.eq %[[WHILE]]#2, %[[NULL_DEVICE]] // CHECK-NEXT: scf.if %[[IS_NULL]] { // CHECK: util.status.check_ok %c5_i32, "HAL device `device_a` not found or unavailable: #hal.device.target<{{.+}}>" -// CHECK: util.global.store %[[WHILE]]#1, @device_a +// CHECK: util.global.store %[[WHILE]]#2, @device_a // ----- From 9c857e4413d1464c03f2317bba36522046e21a40 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 8 May 2024 20:33:45 -0700 Subject: [PATCH 13/25] Adding `#hal.device.alias` attribute for resolving device configs. This allows for less verbose "I don't care, pick something for me" attributes that are expanded into the full target devices and their executable configurations. Resolution happens early in the process so that any flags that may be influencing the resolved configurations are captured and no longer required by the pipeline. Tests and tooling could use these attributes in place of `#hal.device.target` but would need to run the pass as part of their pipeline in order to perform the expansion. Resolving in a pass vs doing so inline also allows for signaling errors and passing in scoped device target registries instead of relying on the globals that are not available in API usage. --- .../iree/compiler/Dialect/HAL/IR/HALAttrs.cpp | 236 +----------------- .../iree/compiler/Dialect/HAL/IR/HALAttrs.td | 62 ----- .../Dialect/HAL/IR/test/attributes.mlir | 14 ++ .../Dialect/HAL/Transforms/BUILD.bazel | 1 + .../Dialect/HAL/Transforms/CMakeLists.txt | 1 + .../Dialect/HAL/Transforms/Passes.cpp | 2 + .../compiler/Dialect/HAL/Transforms/Passes.td | 19 ++ .../HAL/Transforms/ResolveDeviceAliases.cpp | 134 ++++++++++ .../Dialect/HAL/Transforms/VerifyDevices.cpp | 4 +- .../Dialect/HAL/Transforms/test/BUILD.bazel | 1 + .../HAL/Transforms/test/CMakeLists.txt | 1 + .../test/resolve_device_aliases.mlir | 41 +++ 12 files changed, 229 insertions(+), 287 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/ResolveDeviceAliases.cpp create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_device_aliases.mlir diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp index 7bab983b6e07..6d3c4a22dcc1 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp @@ -85,225 +85,6 @@ uint32_t CollectiveAttr::getEncodedValue() const { return value.packed; } -//===----------------------------------------------------------------------===// -// #hal.device.target<*> -//===----------------------------------------------------------------------===// - -// static -DeviceTargetAttr DeviceTargetAttr::get(MLIRContext *context, - StringRef deviceID) { - // TODO(benvanik): query default configuration from the target backend. - return get(context, StringAttr::get(context, deviceID), - DictionaryAttr::get(context), {}); -} - -// static -Attribute DeviceTargetAttr::parse(AsmParser &p, Type type) { - StringAttr deviceIDAttr; - DictionaryAttr configAttr; - SmallVector executableTargetAttrs; - // `<"device-id"` - if (failed(p.parseLess()) || failed(p.parseAttribute(deviceIDAttr))) { - return {}; - } - // `, ` - if (succeeded(p.parseOptionalComma())) { - if (succeeded(p.parseOptionalLSquare())) { - // `[targets, ...]` (optional) - do { - IREE::HAL::ExecutableTargetAttr executableTargetAttr; - if (failed(p.parseAttribute(executableTargetAttr))) - return {}; - executableTargetAttrs.push_back(executableTargetAttr); - } while (succeeded(p.parseOptionalComma())); - if (failed(p.parseRSquare())) - return {}; - } else { - // `{config dict}` (optional) - if (failed(p.parseAttribute(configAttr))) - return {}; - // `, [targets, ...]` (optional) - if (succeeded(p.parseOptionalComma())) { - if (failed(p.parseLSquare())) - return {}; - do { - IREE::HAL::ExecutableTargetAttr executableTargetAttr; - if (failed(p.parseAttribute(executableTargetAttr))) - return {}; - executableTargetAttrs.push_back(executableTargetAttr); - } while (succeeded(p.parseOptionalComma())); - if (failed(p.parseRSquare())) - return {}; - } - } - } - // `>` - if (failed(p.parseGreater())) { - return {}; - } - return get(p.getContext(), deviceIDAttr, configAttr, executableTargetAttrs); -} - -void DeviceTargetAttr::print(AsmPrinter &p) const { - auto &os = p.getStream(); - os << "<"; - p.printAttribute(getDeviceID()); - auto configAttr = getConfiguration(); - if (configAttr && !configAttr.empty()) { - os << ", "; - p.printAttribute(configAttr); - } - auto executableTargetAttrs = getExecutableTargets(); - if (!executableTargetAttrs.empty()) { - os << ", ["; - llvm::interleaveComma(executableTargetAttrs, os, - [&](auto executableTargetAttr) { - p.printAttribute(executableTargetAttr); - }); - os << "]"; - } - os << ">"; -} - -std::string DeviceTargetAttr::getSymbolNameFragment() { - return sanitizeSymbolName(getDeviceID().getValue().lower()); -} - -bool DeviceTargetAttr::hasConfigurationAttr(StringRef name) { - auto configAttr = getConfiguration(); - return configAttr && configAttr.get(name); -} - -void DeviceTargetAttr::getExecutableTargets( - SetVector &resultAttrs) { - for (auto attr : getExecutableTargets()) { - resultAttrs.insert(attr); - } -} - -void IREE::HAL::DeviceTargetAttr::printStatusDescription( - llvm::raw_ostream &os) const { - cast().print(os, /*elideType=*/true); -} - -// Produces a while-loop that enumerates each device available and tries to -// match it against the target information. SCF is... not very wieldy, but this -// is effectively: -// ``` -// %device_count = hal.devices.count : index -// %result:3 = scf.while(%i = 0, %match_ordinal = 0, %device = null) { -// %is_null = util.cmp.eq %device, null : !hal.device -// %in_bounds = arith.cmpi slt %i, %device_count : index -// %continue_while = arith.andi %is_null, %in_bounds : i1 -// scf.condition(%continue_while) %i, %match_ordinal %device -// : index, index, !hal.device -// } do { -// %device_i = hal.devices.get %i : !hal.device -// %device_match = <>(%device_i) -// %ordinal_match = arith.cmpi eq %match_ordinal, %device_ordinal : index -// %is_match = arith.andi %device_match, %ordinal_match : i1 -// %try_device = arith.select %is_match, %device_i, null : !hal.device -// %next_i = arith.addi %i, %c1 : index -// %match_adv = arith.select %device_match, %c1, %c0 : index -// %next_match_ordinal = arith.addi %match_ordinal, %match_adv : index -// scf.yield %next_i, %next_match_ordinal, %try_device -// : index, index !hal.device -// } -// ``` -// Upon completion %result#1 contains the device (or null). -// If the target had an ordinal specified we skip matches until a match with the -// specified ordinal is reached. -Value IREE::HAL::DeviceTargetAttr::buildDeviceEnumeration( - Location loc, const IREE::HAL::TargetRegistry &targetRegistry, - OpBuilder &builder) const { - // Device configuration can control selection beyond just the match - // expression. - auto configAttr = getConfiguration(); - IntegerAttr deviceOrdinalAttr = - configAttr ? configAttr.getAs("ordinal") : IntegerAttr{}; - - // Defers to the target backend to build the device match or does a simple - // fallback for unregistered backends (usually for testing, but may be used - // as a way to bypass validation for out-of-tree experiments). - auto buildDeviceMatch = [&](Location loc, Value device, - OpBuilder &builder) -> Value { - // Ask the target backend to build the match expression. It may opt to - // let the default handling take care of things. - Value match; - auto targetDevice = targetRegistry.getTargetDevice(getDeviceID()); - if (targetDevice) - match = targetDevice->buildDeviceTargetMatch(loc, device, *this, builder); - if (match) - return match; - return buildDeviceIDAndExecutableFormatsMatch( - loc, device, getDeviceID(), getExecutableTargets(), builder); - }; - - // Enumerate all devices and match the first one found (if any). - Type indexType = builder.getIndexType(); - Type deviceType = builder.getType(); - Value c0 = builder.create(loc, 0); - Value c1 = builder.create(loc, 1); - Value nullDevice = builder.create(loc, deviceType); - Value deviceOrdinal = deviceOrdinalAttr - ? builder.create( - loc, deviceOrdinalAttr.getInt()) - : c0; - Value deviceCount = builder.create(loc, indexType); - auto whileOp = builder.create( - loc, - TypeRange{ - /*i=*/indexType, - /*match_ordinal=*/indexType, - /*device=*/deviceType, - }, - ValueRange{ - /*i=*/c0, - /*match_ordinal=*/c0, - /*device=*/nullDevice, - }, - [&](OpBuilder &beforeBuilder, Location loc, ValueRange operands) { - Value isNull = beforeBuilder.create( - loc, operands[/*device=*/2], nullDevice); - Value inBounds = beforeBuilder.create( - loc, arith::CmpIPredicate::slt, operands[/*i=*/0], deviceCount); - Value continueWhile = - beforeBuilder.create(loc, isNull, inBounds); - beforeBuilder.create(loc, continueWhile, operands); - }, - [&](OpBuilder &afterBuilder, Location loc, ValueRange operands) { - // Check whether the device is a match. - Value device = afterBuilder.create( - loc, deviceType, operands[/*i=*/0]); - Value isDeviceMatch = buildDeviceMatch(loc, device, afterBuilder); - - // Check whether whether this matching device ordinal is the requested - // ordinal out of all matching devices. - Value isOrdinalMatch = afterBuilder.create( - loc, arith::CmpIPredicate::eq, operands[/*match_ordinal=*/1], - deviceOrdinal); - Value nextMatchOrdinal = afterBuilder.create( - loc, operands[/*match_ordinal=*/1], - afterBuilder.create(loc, isDeviceMatch, c1, c0)); - - // Break if the device and ordinal match, otherwise continue with null. - Value isMatch = afterBuilder.create(loc, isDeviceMatch, - isOrdinalMatch); - Value tryDevice = afterBuilder.create( - loc, isMatch, device, nullDevice); - - Value nextI = - afterBuilder.create(loc, operands[/*i=*/0], c1); - afterBuilder.create( - loc, ValueRange{ - /*i=*/nextI, - /*match_ordinal=*/nextMatchOrdinal, - /*device=*/tryDevice, - }); - }); - return whileOp.getResult(/*device=*/2); -} - //===----------------------------------------------------------------------===// // #hal.executable.target<*> //===----------------------------------------------------------------------===// @@ -945,6 +726,13 @@ Value IREE::HAL::DeviceFallbackAttr::buildDeviceEnumeration( // #hal.device.select<*> //===----------------------------------------------------------------------===// +// static +DeviceSelectAttr DeviceSelectAttr::get(MLIRContext *context, + ArrayRef values) { + return DeviceSelectAttr::get(context, IREE::HAL::DeviceType::get(context), + ArrayAttr::get(context, values)); +} + // static LogicalResult DeviceSelectAttr::verify(function_ref emitError, @@ -952,10 +740,12 @@ DeviceSelectAttr::verify(function_ref emitError, if (devicesAttr.empty()) return emitError() << "must have at least one device to select"; for (auto deviceAttr : devicesAttr) { - if (!deviceAttr.isa()) { - return emitError() << "can only select between #hal.device.target, " - "#hal.device.ordinal, #hal.device.fallback, or " - "other device initialization attributes"; + if (!mlir::isa(deviceAttr) && + !mlir::isa(deviceAttr)) { + return emitError() << "can only select between #hal.device.alias, " + "#hal.device.target, #hal.device.ordinal, " + "#hal.device.fallback, or other device " + "initialization attributes"; } } // TODO(benvanik): when !hal.device is parameterized we should check that the diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index cc7c523c03c4..2d10dc32aac0 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td @@ -476,65 +476,6 @@ def HAL_InterfaceBindingArrayAttr : TypedArrayAttrBase; -//===----------------------------------------------------------------------===// -// #hal.device.target<*> -//===----------------------------------------------------------------------===// - -def HAL_DeviceTargetAttr : AttrDef, -]> { - let mnemonic = "device.target"; - let summary = [{generic device target specification}]; - let description = [{ - Specifies the properties of a target runtime device. - Target devices are specified with a canonical identifier matching those used - by the runtime (such as `cpu`, `vulkan`, etc). Target devices may support - several target executable formats specified with `#hal.executable.target`. - An optional configuration dictionary allows for overriding backend defaults. - - If used to initialize a device global returns the first device matching the - target requirements or null if no devices match. An optional `ordinal` - index may be provided that selects the N-th matching device and is used to - select between multiple homogeneous devices. - - Example: - ```mlir - #hal.device.target<"llvm-cpu", { - device_configuration = ... - }, [ - #hal.executable.target<"llvm-cpu", "embedded-elf-arm_32">, - #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64">, - ]> - ``` - }]; - let parameters = (ins - AttrParameter<"StringAttr", "">:$deviceID, - AttrParameter<"DictionaryAttr", "">:$configuration, - ArrayRefParameter<"ExecutableTargetAttr", "">:$executable_targets - ); - let builders = [ - AttrBuilder<(ins "StringRef":$deviceID)>, - ]; - - let extraClassDeclaration = [{ - Type getType() { return IREE::HAL::DeviceType::get(getContext()); } - - // Returns a symbol-compatible name that pseudo-uniquely identifies this - // target. Callers must perform deduplication when required. - std::string getSymbolNameFragment(); - - // Returns true if there's an attribute with the given name in the - // configuration dictionary. - bool hasConfigurationAttr(StringRef name); - - // Returns zero or more executable targets that this device supports. - void getExecutableTargets( - SetVector &resultAttrs); - }]; - - let hasCustomAssemblyFormat = 1; -} - //===----------------------------------------------------------------------===// // #hal.executable.target<*> //===----------------------------------------------------------------------===// @@ -930,9 +871,6 @@ def HAL_DeviceSelectAttr : AttrDef, AttrBuilder<(ins "ArrayRef":$values )>, diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir index 3a05d9cf8ecf..ec0cbdd1d60b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir @@ -61,6 +61,20 @@ // ----- +// CHECK-LABEL: "device.aliases" +"device.aliases"() { + // CHECK-SAME: alias_0 = #hal.device.alias<"a"> : !hal.device + alias_0 = #hal.device.alias<"a"> : !hal.device, + // CHECK-SAME: alias_1 = #hal.device.alias<"b", {}> : !hal.device + alias_1 = #hal.device.alias<"b", {}> : !hal.device, + // CHECK-SAME: alias_2 = #hal.device.alias<"c"[4]> : !hal.device + alias_2 = #hal.device.alias<"c"[4]> : !hal.device, + // CHECK-SAME: alias_3 = #hal.device.alias<"d", {config = 123 : index}> + alias_3 = #hal.device.alias<"d", {config = 123 : index}> : !hal.device +} : () -> () + +// ----- + // CHECK-LABEL: "device.targets" "device.targets"() { // CHECK-SAME: target_0 = #hal.device.target<"a"> : !hal.device diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel index e5be69d00346..bdad4275b98b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel @@ -35,6 +35,7 @@ iree_compiler_cc_library( "PreprocessExecutables.cpp", "PruneExecutables.cpp", "RepeatDispatches.cpp", + "ResolveDeviceAliases.cpp", "ResolveDevicePromises.cpp", "ResolveExportOrdinals.cpp", "SerializeExecutables.cpp", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt index c7b1207f5bf7..72b0b74dcbfe 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt @@ -36,6 +36,7 @@ iree_cc_library( "PreprocessExecutables.cpp" "PruneExecutables.cpp" "RepeatDispatches.cpp" + "ResolveDeviceAliases.cpp" "ResolveDevicePromises.cpp" "ResolveExportOrdinals.cpp" "SerializeExecutables.cpp" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index a58a51fa5ca6..dc6ab9d67036 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -200,6 +200,8 @@ void buildHALDeviceAssignmentPassPipeline(OpPassManager &passManager, } passManager.addPass(IREE::HAL::createMaterializeTargetDevicesPass()); passManager.addPass(IREE::HAL::createResolveDevicePromisesPass()); + passManager.addPass( + IREE::HAL::createResolveDeviceAliasesPass({&targetRegistry})); passManager.addPass(IREE::HAL::createVerifyDevicesPass({&targetRegistry})); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td index 8ff1cf7e5ee4..aa896f247402 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td @@ -91,6 +91,25 @@ def ResolveDevicePromisesPass : ]; } +def ResolveDeviceAliasesPass : + Pass<"iree-hal-resolve-device-aliases", "mlir::ModuleOp"> { + let summary = "Resolves `#hal.device.alias` attributes to their expanded configurations."; + let description = [{ + Resolves device aliases to the concrete targets using defaults, flags, and + registered device configurations. + }]; + let options = [ + Option< + "targetRegistry", "target-registry", + "llvm::cl::TargetRegistryRef", "", + "Target registry containing the list of available devices and backends." + >, + ]; + let dependentDialects = [ + "IREE::HAL::HALDialect", + ]; +} + def VerifyDevicesPass : Pass<"iree-hal-verify-devices", "mlir::ModuleOp"> { let summary = "Verifies that all devices can be targeted with the available compiler plugins."; diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ResolveDeviceAliases.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ResolveDeviceAliases.cpp new file mode 100644 index 000000000000..0108aeea3ba1 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ResolveDeviceAliases.cpp @@ -0,0 +1,134 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::HAL { + +#define GEN_PASS_DEF_RESOLVEDEVICEALIASESPASS +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// --iree-hal-resolve-device-aliases +//===----------------------------------------------------------------------===// + +static FailureOr +resolveAliasAttr(Operation *forOp, IREE::HAL::DeviceAliasAttr aliasAttr, + const TargetRegistry &targetRegistry) { + // Lookup device in the registry. + auto targetDevice = + targetRegistry.getTargetDevice(aliasAttr.getDeviceID().getValue()); + if (!targetDevice) { + auto diagnostic = forOp->emitError(); + diagnostic << "unregistered device alias " << aliasAttr.getDeviceID() + << "; ensure it is linked into the compiler (available = [ "; + for (const auto &targetName : targetRegistry.getRegisteredTargetDevices()) { + diagnostic << "'" << targetName << "' "; + } + diagnostic << "])"; + return diagnostic; + } + + // Query the default device target. + auto defaultAttr = + targetDevice->getDefaultDeviceTarget(forOp->getContext(), targetRegistry); + assert(defaultAttr && "expected a default device target attr"); + + // Merge in any additional configuration from the alias attr. + if (aliasAttr.getOrdinal().has_value() || + (aliasAttr.getConfiguration() && !aliasAttr.getConfiguration().empty())) { + NamedAttrList configAttrs; + if (auto defaultConfigAttr = defaultAttr.getConfiguration()) { + for (auto existingAttr : defaultConfigAttr) { + configAttrs.push_back(existingAttr); + } + } + if (auto overrideConfigAttr = aliasAttr.getConfiguration()) { + for (auto overrideAttr : overrideConfigAttr) { + configAttrs.set(overrideAttr.getName(), overrideAttr.getValue()); + } + } + if (aliasAttr.getOrdinal().has_value()) { + configAttrs.set("ordinal", + IntegerAttr::get(IndexType::get(forOp->getContext()), + aliasAttr.getOrdinal().value())); + } + defaultAttr = IREE::HAL::DeviceTargetAttr::get( + forOp->getContext(), defaultAttr.getDeviceID(), + DictionaryAttr::get(forOp->getContext(), configAttrs), + defaultAttr.getExecutableTargets()); + } + + return defaultAttr; +} + +static FailureOr +resolveNestedAliasAttrs(Operation *forOp, Attribute attr, + const TargetRegistry &targetRegistry) { + if (auto aliasAttr = dyn_cast(attr)) { + return resolveAliasAttr(forOp, aliasAttr, targetRegistry); + } else if (auto selectAttr = dyn_cast(attr)) { + SmallVector resolvedAttrs; + bool didChange = false; + for (auto deviceAttr : selectAttr.getDevices()) { + auto resolvedAttr = + resolveNestedAliasAttrs(forOp, deviceAttr, targetRegistry); + if (failed(resolvedAttr)) { + return failure(); + } + didChange = didChange || *resolvedAttr != deviceAttr; + resolvedAttrs.push_back(*resolvedAttr); + } + return didChange ? IREE::HAL::DeviceSelectAttr::get(attr.getContext(), + resolvedAttrs) + : attr; + } else { + return attr; // pass-through + } +} + +struct ResolveDeviceAliasesPass + : public IREE::HAL::impl::ResolveDeviceAliasesPassBase< + ResolveDeviceAliasesPass> { + using IREE::HAL::impl::ResolveDeviceAliasesPassBase< + ResolveDeviceAliasesPass>::ResolveDeviceAliasesPassBase; + void runOnOperation() override { + // Walks all device globals and resolve any aliases found. + auto moduleOp = getOperation(); + for (auto globalOp : moduleOp.getOps()) { + if (!isa(globalOp.getGlobalType())) { + continue; + } + auto initialValue = globalOp.getGlobalInitialValue(); + if (!initialValue) { + continue; + } + auto resolvedValue = resolveNestedAliasAttrs(globalOp, initialValue, + *targetRegistry.value); + if (failed(resolvedValue)) { + return signalPassFailure(); + } + globalOp.setGlobalInitialValue(*resolvedValue); + } + } +}; + +} // namespace + +} // namespace mlir::iree_compiler::IREE::HAL diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp index 21f37543f36a..e1ca6247996d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp @@ -50,7 +50,7 @@ verifyDeviceTargetAttr(Operation *deviceOp, auto diagnostic = deviceOp->emitError(); diagnostic << "unregistered target device " << deviceTargetAttr.getDeviceID() - << "; ensure it is linked in to the compiler (available = [ "; + << "; ensure it is linked into the compiler (available = [ "; for (const auto &targetName : targetRegistry.getRegisteredTargetDevices()) { diagnostic << "'" << targetName << "' "; } @@ -65,7 +65,7 @@ verifyDeviceTargetAttr(Operation *deviceOp, auto diagnostic = deviceOp->emitError(); diagnostic << "unregistered target backend " << executableTargetAttr.getBackend() - << "; ensure it is linked in to the compiler (available = [ "; + << "; ensure it is linked into the compiler (available = [ "; for (const auto &targetName : targetRegistry.getRegisteredTargetBackends()) { diagnostic << "'" << targetName << "' "; diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel index 3d1d096fc06d..812e2f9c589a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel @@ -32,6 +32,7 @@ iree_lit_test_suite( "preprocess_executables.mlir", "prune_executables.mlir", "repeat_dispatches.mlir", + "resolve_device_aliases.mlir", "resolve_device_promises.mlir", "resolve_export_ordinals.mlir", "strip_executable_contents.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt index 972947ed406e..0fa3fa2e6fbd 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt @@ -30,6 +30,7 @@ iree_lit_test_suite( "preprocess_executables.mlir" "prune_executables.mlir" "repeat_dispatches.mlir" + "resolve_device_aliases.mlir" "resolve_device_promises.mlir" "resolve_export_ordinals.mlir" "strip_executable_contents.mlir" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_device_aliases.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_device_aliases.mlir new file mode 100644 index 000000000000..82a45cc85d76 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_device_aliases.mlir @@ -0,0 +1,41 @@ +// RUN: iree-opt --split-input-file --iree-hal-resolve-device-aliases %s --mlir-print-local-scope --verify-diagnostics | FileCheck %s + +// CHECK: util.global private @device +// CHECK-SAME: #hal.device.target<"local" +// CHECK-SAME: extra_config = 4 : index +// CHECK-SAME: #hal.executable.target<"vmvx" +util.global private @device = #hal.device.alias<"vmvx", { + extra_config = 4 : index +}> : !hal.device + +// ----- + +// CHECK: util.global private @device_ordinal +// CHECK-SAME: #hal.device.target<"local" +// CHECK-SAME: ordinal = 123 : index +// CHECK-SAME: #hal.executable.target<"vmvx" +util.global private @device_ordinal = #hal.device.alias<"vmvx"[123]> : !hal.device + +// ----- + +// CHECK: util.global private @device_select +// CHECK-SAME: #hal.device.select<[ +// CHECK-SAME: #hal.device.target<"local", {ordinal = 0 : index} +// CHECK-SAME: #hal.device.target<"local", {ordinal = 1 : index} +util.global private @device_select = #hal.device.select<[ + #hal.device.alias<"vmvx"[0]> : !hal.device, + #hal.device.alias<"vmvx"[1]> : !hal.device +]> : !hal.device + +// ----- + +// expected-error@+1 {{unregistered device alias "__unregistered__"}} +util.global private @device_unregistered = #hal.device.alias<"__unregistered__"> : !hal.device + +// ----- + +// expected-error@+1 {{unregistered device alias "__unregistered__"}} +util.global private @device_select_unregistered = #hal.device.select<[ + #hal.device.alias<"vmvx"> : !hal.device, + #hal.device.alias<"__unregistered__"> : !hal.device +]> : !hal.device From 31074e728b180234272204a90a7adec4abd0696d Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 13 May 2024 14:22:42 -0700 Subject: [PATCH 14/25] Adding iree.abi.affinity arg/result attrs on the native ABI. These map to an opaque affinity on the tensor import/export ops and act as a seed to placement when lowering into stream. --- .../Torch/InputConversion/FuncConversion.cpp | 15 ++- .../Native/Transforms/WrapEntryPoints.cpp | 108 ++++++++++++------ .../TFLite/Transforms/WrapEntryPoints.cpp | 8 +- .../Dialect/Flow/IR/FlowOpFolders.cpp | 4 - .../Dialect/Flow/IR/test/tensor_folding.mlir | 8 +- .../Dialect/Flow/IR/test/tensor_ops.mlir | 26 +++++ .../Flow/Transforms/ExportBenchmarkFuncs.cpp | 3 +- .../Transforms/InsertDispatchDebugTargets.cpp | 3 +- .../compiler/Dialect/HAL/IR/HALOpFolders.cpp | 4 +- .../iree/compiler/Dialect/HAL/IR/HALOps.cpp | 20 ++-- .../iree/compiler/Dialect/HAL/IR/HALOps.td | 28 +++-- .../Conversion/FlowToStream/Patterns.cpp | 2 +- .../Conversion/HALToStream/Patterns.cpp | 28 +++-- .../Dialect/Stream/IR/StreamOpFolders.cpp | 4 +- .../Stream/Transforms/ConvertToStream.cpp | 2 +- .../Transforms/SimplifyGlobalAccesses.cpp | 10 +- .../Common/IREEImportPublic.cpp | 12 +- 17 files changed, 193 insertions(+), 92 deletions(-) diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp index 4bdd5c468729..b8a630d454b3 100644 --- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp +++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp @@ -292,13 +292,15 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() { aliasedResults.push_back( postambleBuilder.create( barrierInput.getLoc(), barrierInput.getType(), barrierInput, - barrierInputDims, exportStorage, waitFence)); + barrierInputDims, exportStorage, waitFence, + /*affinity=*/nullptr)); } else { aliasedResults.push_back(barrierInput); } } auto barrierOp = postambleBuilder.create( - funcOp.getLoc(), aliasedResults, coarseSignalFence); + funcOp.getLoc(), aliasedResults, coarseSignalFence, + /*affinity=*/nullptr); for (auto [barrierResult, meta] : llvm::zip_equal(barrierOp.getResults(), barrierResultMeta)) { Value exportStorage; @@ -308,7 +310,8 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() { Value exportedValue = postambleBuilder.create( funcOp.getLoc(), postambleBuilder.getType(), barrierResult, - TypeAttr::get(barrierResult.getType()), StringAttr()); + TypeAttr::get(barrierResult.getType()), /*name=*/nullptr, + /*affinity=*/nullptr); if (returnIndex >= 0) { newReturnOperands[returnIndex] = exportedValue; } @@ -380,7 +383,8 @@ LogicalResult ConvertedAsyncFunctionInfo::convertImmutableTensorArg( Value importedTensor = builder.create( loc, builtinTensorType, argValue, TypeAttr::get(builtinTensorType), waitFence, - /*name=*/StringAttr()); + /*name=*/nullptr, + /*affinity=*/nullptr); if (builtinTensorType != torchType) { importedTensor = builder.create( loc, torchType, importedTensor); @@ -415,7 +419,8 @@ LogicalResult ConvertedAsyncFunctionInfo::convertMutableTensorArg( loc, builtinTensorType, argValue, /*target_encoding=*/TypeAttr::get(builtinTensorType), /*wait_fence*/ fences->first, - /*name=*/StringAttr()); + /*name=*/nullptr, + /*affinity=*/nullptr); rewriter.replaceOpWithNewOp( userOp, copyToVtOp.getResult().getType(), imported); } else if (auto overwriteOp = diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp index bd2702f5f87e..e91bc54b1f0c 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp @@ -26,9 +26,9 @@ namespace mlir::iree_compiler::IREE::ABI { static IREE::ABI::InvocationModel getInvocationModel(Operation *op, IREE::ABI::InvocationModel defaultModel) { auto modelAttr = op->getAttrOfType("iree.abi.model"); - if (!modelAttr) + if (!modelAttr) { return defaultModel; - if (modelAttr == "coarse-fences") { + } else if (modelAttr == "coarse-fences") { return IREE::ABI::InvocationModel::CoarseFences; } else { return IREE::ABI::InvocationModel::Sync; @@ -51,7 +51,8 @@ static void stripABIAttrs(SmallVectorImpl &allAttrs) { for (auto attr : attrDict) { // TODO(benvanik): faster lookup. if (attr.getName() != "iree.abi.output" && - attr.getName() != "iree.abi.encoding") { + attr.getName() != "iree.abi.encoding" && + attr.getName() != "iree.abi.affinity") { attrs.push_back(attr); } } @@ -59,6 +60,17 @@ static void stripABIAttrs(SmallVectorImpl &allAttrs) { } } +static void stripABIAttrs(FunctionOpInterface op) { + SmallVector argAttrs; + op.getAllArgAttrs(argAttrs); + stripABIAttrs(argAttrs); + op.setAllArgAttrs(argAttrs); + SmallVector resultAttrs; + op.getAllResultAttrs(resultAttrs); + stripABIAttrs(resultAttrs); + op.setAllResultAttrs(resultAttrs); +} + // Creates the corresponding wrapper function for the given import function. static IREE::Util::FuncOp createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel, @@ -150,7 +162,7 @@ createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel, importOp.getLoc(), entryBuilder.getType(), device, IREE::HAL::FenceFlagBitfield::None); auto barrierOp = entryBuilder.create( - importOp.getLoc(), tensorArgs, waitFence); + importOp.getLoc(), tensorArgs, waitFence, /*affinity=*/nullptr); for (auto [argIndex, readyArg] : llvm::zip_equal(tensorArgIndices, barrierOp.getResults())) { entryArgs[argIndex] = readyArg; @@ -187,20 +199,24 @@ createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel, // NOTE: we insert a barrier on this above if needed so that the wait // fence will be signaled when the tensor is ready for consumption by the // import. - auto encoding = + auto encodingAttr = importOp.getArgAttrOfType(argIndex, "iree.abi.encoding"); - auto exportOp = entryBuilder.create( + auto tensorExportOp = entryBuilder.create( arg.getLoc(), newType, arg, - encoding ? encoding : TypeAttr::get(oldType), /*name=*/nullptr); - arguments.push_back(exportOp.getTarget()); + encodingAttr ? encodingAttr : TypeAttr::get(oldType), + /*name=*/nullptr, + /*affinity=*/nullptr); + arguments.push_back(tensorExportOp.getTarget()); } else { arguments.push_back(arg); } } - if (waitFence) + if (waitFence) { arguments.push_back(waitFence); - if (signalFence) + } + if (signalFence) { arguments.push_back(signalFence); + } // Make the call with the updated types. auto callOp = entryBuilder.create(importOp.getLoc(), @@ -225,12 +241,14 @@ createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel, // NOTE: we set the import pending on the signal fence from the import // indicating when the returned tensor is ready for consumption by the // program. - auto encoding = importOp.getResultAttrOfType( + auto encodingAttr = importOp.getResultAttrOfType( resultIndex, "iree.abi.encoding"); - results.push_back(entryBuilder.create( + auto tensorImportOp = entryBuilder.create( importOp.getLoc(), oldType, result, - encoding ? encoding : TypeAttr::get(oldType), signalFence, - /*name=*/nullptr)); + encodingAttr ? encodingAttr : TypeAttr::get(oldType), signalFence, + /*name=*/nullptr, + /*affinity=*/nullptr); + results.push_back(tensorImportOp); } else { results.push_back(result); } @@ -285,8 +303,9 @@ static LogicalResult wrapImportFunc(IREE::ABI::InvocationModel invocationModel, auto wrapperOp = createImportWrapperFunc( invocationModel, importOp, cast(importOp.getFunctionType()), newImportType, privateName); - if (!wrapperOp) + if (!wrapperOp) { return failure(); + } moduleOp.insert(++Block::iterator(importOp), wrapperOp); // Update the import to the new type and mark it as being converted so we @@ -302,15 +321,17 @@ static StringAttr getNameFromDictAttr(DictionaryAttr attr) { static StringAttr inferArgumentName(MLIRContext *context, int index, DictionaryAttr attrs) { - if (auto attrName = getNameFromDictAttr(attrs)) + if (auto attrName = getNameFromDictAttr(attrs)) { return attrName; + } return StringAttr::get(context, "input" + std::to_string(index)); } static StringAttr inferResultName(MLIRContext *context, int index, DictionaryAttr attrs) { - if (auto attrName = getNameFromDictAttr(attrs)) + if (auto attrName = getNameFromDictAttr(attrs)) { return attrName; + } return StringAttr::get(context, "output" + std::to_string(index)); } @@ -326,8 +347,9 @@ static void formatIOAttr(DictionaryAttr attrs, llvm::raw_ostream &os) { auto shouldIncludeAttr = [](const NamedAttribute &attr) { return attr.getName().getValue() != "iree.abi.name"; }; - if (!llvm::any_of(attrs, shouldIncludeAttr)) + if (!llvm::any_of(attrs, shouldIncludeAttr)) { return; + } os << " {"; llvm::interleaveComma(llvm::make_filter_range(attrs, shouldIncludeAttr), os, [&](auto argAttr) { @@ -363,8 +385,9 @@ formatSourceDeclaration(IREE::ABI::InvocationModel invocationModel, os << "func @" << publicName; os << "("; for (auto arg : exportOp.getArguments()) { - if (arg.getArgNumber() > 0) + if (arg.getArgNumber() > 0) { os << ", "; + } os << "%"; os << inferArgumentName(exportOp.getContext(), arg.getArgNumber(), getIOAttr(allArgAttrs, arg.getArgNumber())) @@ -377,8 +400,9 @@ formatSourceDeclaration(IREE::ABI::InvocationModel invocationModel, os << ") -> ("; for (auto [resultNumber, resultType] : llvm::enumerate(exportOp.getResultTypes())) { - if (resultNumber > 0) + if (resultNumber > 0) { os << ", "; + } os << "%"; os << inferResultName(exportOp.getContext(), resultNumber, getIOAttr(allResultAttrs, resultNumber)) @@ -494,8 +518,9 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, // Populate the reflection attrs based on the original types. populateReflectionAttrs(invocationModel, exportOp, wrapperOp); exportOp->removeAttr("iree.reflection"); - if (auto affinityAttr = exportOp->getAttr("stream.affinity")) + if (auto affinityAttr = exportOp->getAttr("stream.affinity")) { wrapperOp->setAttr("stream.affinity", affinityAttr); + } auto *entryBlock = wrapperOp.addEntryBlock(); auto entryBuilder = OpBuilder::atBlockBegin(entryBlock); @@ -506,8 +531,9 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, for (unsigned i = 0; i < exportOp.getNumArguments(); ++i) { auto outputAttr = exportOp.getArgAttrOfType(i, "iree.abi.output"); - if (!outputAttr) + if (!outputAttr) { continue; + } // Today all outputs need to be a !hal.buffer - we could change this // in the future to be something more generalized. auto storageArg = entryBlock->getArgument(i); @@ -544,14 +570,15 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, entryBlock->getArguments().slice(0, oldExportType.getNumInputs()))) { auto oldType = oldExportType.getInput(argIndex); if (llvm::isa(oldType)) { - auto encoding = + auto encodingAttr = exportOp.getArgAttrOfType(argIndex, "iree.abi.encoding"); - auto importOp = entryBuilder.create( + auto tensorImportOp = entryBuilder.create( arg.getLoc(), oldType, arg, - encoding ? encoding : TypeAttr::get(oldType), waitFence, + encodingAttr ? encodingAttr : TypeAttr::get(oldType), waitFence, inferArgumentName(entryBuilder.getContext(), argIndex, - exportOp.getArgAttrDict(argIndex))); - arguments.push_back(importOp.getTarget()); + exportOp.getArgAttrDict(argIndex)), + exportOp.getArgAttr(argIndex, "iree.abi.affinity")); + arguments.push_back(tensorImportOp.getTarget()); } else { arguments.push_back(arg); } @@ -565,14 +592,16 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, // Alias results to storage buffers if provided. for (unsigned resultIndex = 0; resultIndex < asyncResults.size(); ++resultIndex) { - if (!resultStorages[resultIndex]) + if (!resultStorages[resultIndex]) { continue; + } auto source = asyncResults[resultIndex]; auto sourceDims = IREE::Util::buildDynamicDimsForValue( exportOp.getLoc(), source, entryBuilder); auto aliasOp = entryBuilder.create( exportOp.getLoc(), source.getType(), source, sourceDims, - resultStorages[resultIndex], waitFence); + resultStorages[resultIndex], waitFence, + exportOp.getResultAttr(resultIndex, "iree.abi.affinity")); asyncResults[resultIndex] = cast(aliasOp.getResult()); } @@ -582,8 +611,9 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, if (signalFence) { SmallVector asyncTensors; for (auto result : asyncResults) { - if (llvm::isa(result.getType())) + if (llvm::isa(result.getType())) { asyncTensors.push_back(result); + } } if (asyncTensors.empty()) { // TODO(benvanik): maybe use a global timeline? global stores may not @@ -592,7 +622,7 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, signalFence); } else { auto barrierOp = entryBuilder.create( - exportOp.getLoc(), asyncTensors, signalFence); + exportOp.getLoc(), asyncTensors, signalFence, /*affinity=*/nullptr); asyncResults = llvm::to_vector(barrierOp.getResults()); } } @@ -603,20 +633,25 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, auto oldType = oldExportType.getResult(resultIndex); auto newType = newExportType.getResult(resultIndex); if (llvm::isa(oldType)) { - auto encoding = exportOp.getResultAttrOfType( + auto encodingAttr = exportOp.getResultAttrOfType( resultIndex, "iree.abi.encoding"); auto dynamicDims = IREE::Util::buildDynamicDimsForValue( result.getLoc(), result, entryBuilder); - results.push_back(entryBuilder.create( + auto tensorExportOp = entryBuilder.create( result.getLoc(), newType, result, - encoding ? encoding : TypeAttr::get(result.getType()), dynamicDims, + encodingAttr ? encodingAttr : TypeAttr::get(result.getType()), + dynamicDims, inferResultName(entryBuilder.getContext(), resultIndex, - exportOp.getResultAttrDict(resultIndex)))); + exportOp.getResultAttrDict(resultIndex)), + exportOp.getResultAttr(resultIndex, "iree.abi.affinity")); + results.push_back(tensorExportOp); } else { results.push_back(result); } } + stripABIAttrs(exportOp); + entryBuilder.create(exportOp.getLoc(), results); return wrapperOp; } @@ -643,8 +678,9 @@ static LogicalResult wrapExportFunc(IREE::ABI::InvocationModel invocationModel, // marshals arguments/results to the original function. auto wrapperOp = createExportWrapperFunc(invocationModel, exportOp, publicName); - if (!wrapperOp) + if (!wrapperOp) { return failure(); + } symbolTable.insert(wrapperOp, Block::iterator(exportOp)); return success(); diff --git a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp index 3e8c0fdb8d71..d21bfe0e4ce5 100644 --- a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp @@ -227,7 +227,8 @@ class WrapEntryPointsPass auto dynamicDims = inputDynamicDims.loadDynamicDims(recalculateBuilder); auto castOp = recalculateBuilder.create( loc, inputValue.getType(), inputPlaceholder, inputValue.getType(), - dynamicDims, /*wait_fence=*/Value{}, /*name=*/nullptr); + dynamicDims, /*wait_fence=*/Value{}, /*name=*/nullptr, + /*affinity=*/nullptr); inputValue.replaceAllUsesWith(castOp.getTarget()); } while (entryBlock.getNumArguments() > 0) { @@ -525,7 +526,8 @@ class WrapEntryPointsPass callOperands.push_back(entryBuilder.create( arg.getLoc(), inputDynamicDims.tensorType, arg, TypeAttr::get(inputDynamicDims.tensorType), dynamicDims, - /*wait_fence=*/Value{}, /*name=*/nullptr)); + /*wait_fence=*/Value{}, /*name=*/nullptr, + /*affinity=*/nullptr)); } auto callOp = entryBuilder.create( entryFuncOp.getLoc(), entryFuncOp, callOperands); @@ -541,7 +543,7 @@ class WrapEntryPointsPass } callResults.push_back(entryBuilder.create( result.getLoc(), bufferType, result, outputDynamicDims.tensorType, - dynamicDims, /*name=*/nullptr)); + dynamicDims, /*name=*/nullptr, /*affinity=*/nullptr)); for (auto [dynamicDim, globalOp] : llvm::zip_equal(dynamicDims, outputDynamicDims.globalOps)) { globalOp.createStoreOp(result.getLoc(), dynamicDim, entryBuilder); diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp index a2079c321f44..d9a6a0fdb3f2 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp @@ -288,7 +288,6 @@ struct ElideRedundantWorkloadValues struct ElideRedundantOperandsOfWorkgroupCountFromSliceOp : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(DispatchWorkgroupsOp op, PatternRewriter &rewriter) const override { Region &count = op.getWorkgroupCount(); @@ -369,7 +368,6 @@ void DispatchWorkgroupsOp::getCanonicalizationPatterns( // Bubble up the ordinal ops so that all uses go through this operation. struct BubbleUpOrdinalOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(DispatchWorkloadOrdinalOp ordinalOp, PatternRewriter &rewriter) const override { auto blockArg = llvm::dyn_cast(ordinalOp.getOperand()); @@ -894,7 +892,6 @@ namespace { template struct FlattenTensorCastLikeChain : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(CastOpTy reshapeOp, PatternRewriter &rewriter) const override { // We want the same result value/shape but to source from the ancestor. We @@ -1294,7 +1291,6 @@ namespace { // to be updated to use the source of the cast as the target tensor. struct FoldTensorUpdateOpWithCasts : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TensorUpdateOp updateOp, PatternRewriter &rewriter) const override { auto targetCastOp = updateOp.getTarget().getDefiningOp(); diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir index bcb1dbb3cfca..cf1e34189f98 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir @@ -351,8 +351,8 @@ util.func public @splatDynamicZeroElements(%value: f32, %dim: index) -> tensor<0 // ----- -// CHECK-LABEL: @cloneConst -util.func public @cloneConst() -> tensor<4xi32> { +// CHECK-LABEL: @cloneConstant +util.func public @cloneConstant() -> tensor<4xi32> { // CHECK-NEXT: %[[C:.+]] = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi32> %0 = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi32> %1 = flow.tensor.clone %0 : tensor<4xi32> @@ -362,8 +362,8 @@ util.func public @cloneConst() -> tensor<4xi32> { // ----- -// CHECK-LABEL: @cloneConstZeroElements -util.func public @cloneConstZeroElements() -> tensor<0x2xi32> { +// CHECK-LABEL: @cloneConstantZeroElements +util.func public @cloneConstantZeroElements() -> tensor<0x2xi32> { // CHECK-NEXT: %[[C:.+]] = arith.constant dense<> : tensor<0x2xi32> %0 = arith.constant dense<> : tensor<0x2xi32> // CHECK-NOT: flow.tensor.clone diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir index b0a19add6b49..8b01c00a1c53 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir @@ -7,6 +7,8 @@ util.func public @tensorReshape(%arg0 : tensor<4x4xf32>) -> tensor<16xf32> { util.return %0 : tensor<16xf32> } +// ----- + // CHECK-LABEL: @tensorReshapeScalar util.func public @tensorReshapeScalar(%arg0 : tensor) -> tensor { // CHECK-NEXT: %0 = flow.tensor.reshape %arg0 : tensor -> tensor @@ -14,6 +16,8 @@ util.func public @tensorReshapeScalar(%arg0 : tensor) -> tensor { util.return %0 : tensor } +// ----- + // CHECK-LABEL: @tensorReshapeDynamic util.func public @tensorReshapeDynamic(%arg0 : tensor) -> tensor { %c4 = arith.constant 4 : index @@ -23,6 +27,8 @@ util.func public @tensorReshapeDynamic(%arg0 : tensor) -> tensor } +// ----- + // CHECK-LABEL: @tensorReshapeComplex util.func public @tensorReshapeComplex(%arg0 : tensor<4x4xcomplex>) -> tensor<16xcomplex> { // CHECK-NEXT: flow.tensor.reshape %arg0 : tensor<4x4xcomplex> -> tensor<16xcomplex> @@ -48,6 +54,8 @@ util.func public @tensorLoad(%arg0 : tensor<4x4xf32>, %arg1 : index, %arg2 : ind util.return %0 : f32 } +// ----- + // CHECK-LABEL: @tensorLoadScalar util.func public @tensorLoadScalar(%arg0 : tensor) -> f32 { // CHECK-NEXT: %0 = flow.tensor.load %arg0 : tensor @@ -55,6 +63,8 @@ util.func public @tensorLoadScalar(%arg0 : tensor) -> f32 { util.return %0 : f32 } +// ----- + // CHECK-LABEL: @tensorLoadDynamic util.func public @tensorLoadDynamic(%arg0 : tensor, %arg1 : index, %arg2 : index) -> f32 { %c4 = arith.constant 4 : index @@ -72,6 +82,8 @@ util.func public @tensorStore(%arg0 : tensor<4x4xf32>, %arg1 : index, %arg2 : in util.return %0 : tensor<4x4xf32> } +// ----- + // CHECK-LABEL: @tensorStoreScalar util.func public @tensorStoreScalar(%arg0 : f32, %arg1 : tensor) -> tensor { // CHECK-NEXT: %0 = flow.tensor.store %arg0, %arg1 : tensor @@ -79,6 +91,8 @@ util.func public @tensorStoreScalar(%arg0 : f32, %arg1 : tensor) -> tensor< util.return %0 : tensor } +// ----- + // CHECK-LABEL: @tensorStoreDynamic util.func public @tensorStoreDynamic(%arg0 : tensor, %arg1 : index, %arg2 : index, %arg3 : f32) -> tensor { %c4 = arith.constant 4 : index @@ -114,6 +128,8 @@ util.func public @tensorSplat(%arg0 : f32) -> tensor<4x4xf32> { util.return %0 : tensor<4x4xf32> } +// ----- + // CHECK-LABEL: @tensorSplatScalar util.func public @tensorSplatScalar(%arg0 : f32) -> tensor { // CHECK-NEXT: %0 = flow.tensor.splat %arg0 : tensor @@ -121,6 +137,8 @@ util.func public @tensorSplatScalar(%arg0 : f32) -> tensor { util.return %0 : tensor } +// ----- + // CHECK-LABEL: @tensorSplatDynamic util.func public @tensorSplatDynamic(%arg0 : f32) -> tensor { %c4 = arith.constant 4 : index @@ -138,6 +156,8 @@ util.func public @tensorClone(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { util.return %0 : tensor<4x4xf32> } +// ----- + // CHECK-LABEL: @tensorCloneScalar util.func public @tensorCloneScalar(%arg0 : tensor) -> tensor { // CHECK-NEXT: %0 = flow.tensor.clone %arg0 : tensor @@ -145,6 +165,8 @@ util.func public @tensorCloneScalar(%arg0 : tensor) -> tensor { util.return %0 : tensor } +// ----- + // CHECK-LABEL: @tensorCloneDynamic util.func public @tensorCloneDynamic(%arg0 : tensor) -> tensor { %c4 = arith.constant 4 : index @@ -162,6 +184,8 @@ util.func public @tensorSlice(%arg0 : tensor<4x4xf32>, %arg1 : index, %arg2 : in util.return %0 : tensor<2x2xf32> } +// ----- + // CHECK-LABEL: @tensorSliceDynamic util.func public @tensorSliceDynamic(%arg0 : tensor, %arg1 : index, %arg2 : index) -> tensor { %c2 = arith.constant 2 : index @@ -180,6 +204,8 @@ util.func public @tensorUpdate(%arg0 : tensor<2x2xf32>, %arg1 : tensor<4x4xf32>, util.return %0 : tensor<4x4xf32> } +// ----- + // CHECK-LABEL: @tensorUpdateDynamic util.func public @tensorUpdateDynamic(%arg0 : tensor, %arg1 : tensor, %arg2 : index, %arg3 : index) -> tensor { %c1 = arith.constant 1 : index diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp index ecfb86a2ae36..3967d5097d0e 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp @@ -81,7 +81,8 @@ createBufferLikeGlobalOp(std::string name, Location loc, Type globalType, // hal.tensor.export auto bufferExportOp = initializerBuilder.create( loc, globalOp.getType(), splatOp.getResult(), - TypeAttr::get(splatOp.getType()), /*name=*/nullptr); + TypeAttr::get(splatOp.getType()), /*name=*/nullptr, + /*affinity=*/nullptr); // util.optimization_barrier (try to prevent optimizations across the export) auto barrierOp = initializerBuilder.create( loc, bufferExportOp.getTarget()); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp index beb995d92a79..49212b54b83e 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp @@ -103,7 +103,8 @@ static LogicalResult replaceReturnWithOpResults(mlir::ModuleOp moduleOp, if (llvm::isa(retVal.getType())) { auto type = IREE::HAL::BufferViewType::get(context); auto exportOp = builder.create( - loc, type, retVal, TypeAttr::get(retVal.getType()), /*name=*/nullptr); + loc, type, retVal, TypeAttr::get(retVal.getType()), /*name=*/nullptr, + /*affinity=*/nullptr); exports.push_back(exportOp.getResult()); newTypes.push_back(type); } else { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp index d082847334ef..23dab245d380 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp @@ -96,8 +96,8 @@ struct DeduplicateTensorBarrierSources } if (orderedSources.size() == op.getSources().size()) return failure(); - auto newOp = rewriter.create(op.getLoc(), orderedSources, - op.getSignalFence()); + auto newOp = rewriter.create( + op.getLoc(), orderedSources, op.getSignalFence(), op.getAffinityAttr()); SmallVector newResults; newResults.reserve(newOp.getNumResults()); for (unsigned newIndex : resultMapping) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 07ff982c1fbe..e28025e3ec49 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -431,15 +431,16 @@ LogicalResult ReturnOp::verify() { void TensorImportOp::build(OpBuilder &builder, OperationState &result, Type resultType, Value source, - TypeAttr targetEncoding, StringAttr name) { + TypeAttr targetEncoding, StringAttr name, + Attribute affinity) { build(builder, result, resultType, source, targetEncoding, - /*waitFence=*/Value{}, name); + /*waitFence=*/Value{}, name, affinity); } void TensorImportOp::build(OpBuilder &builder, OperationState &result, Type resultType, Value source, TypeAttr targetEncoding, Value waitFence, - StringAttr name) { + StringAttr name, Attribute affinity) { auto shapedType = llvm::cast(resultType); assert((isa(source.getType()) || shapedType.hasStaticShape()) && @@ -454,7 +455,7 @@ void TensorImportOp::build(OpBuilder &builder, OperationState &result, builder.getIndexAttr(i))); } build(builder, result, resultType, source, targetEncoding, dynamicDims, - waitFence, name); + waitFence, name, affinity); } Value TensorImportOp::getTiedResult(unsigned resultIndex) { @@ -530,10 +531,12 @@ LogicalResult TensorImportOp::verify() { void TensorExportOp::build(OpBuilder &builder, OperationState &result, Type resultType, Value source, - TypeAttr sourceEncoding, StringAttr name) { + TypeAttr sourceEncoding, StringAttr name, + Attribute affinity) { auto dynamicDims = IREE::Util::buildDynamicDimsForValue(result.location, source, builder); - build(builder, result, resultType, source, sourceEncoding, dynamicDims, name); + build(builder, result, resultType, source, sourceEncoding, dynamicDims, name, + affinity); } Value TensorExportOp::getTiedResult(unsigned resultIndex) { @@ -592,10 +595,11 @@ LogicalResult TensorAliasOp::verify() { //===----------------------------------------------------------------------===// void TensorBarrierOp::build(OpBuilder &builder, OperationState &result, - ValueRange sources, Value signalFence) { + ValueRange sources, Value signalFence, + Attribute affinity) { auto resultTypes = llvm::map_to_vector( sources, [](Value source) { return source.getType(); }); - build(builder, result, resultTypes, sources, signalFence); + build(builder, result, resultTypes, sources, signalFence, affinity); } Value TensorBarrierOp::getTiedResult(unsigned resultIndex) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 599c1ffecdc5..d35a0cb55459 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -125,13 +125,15 @@ def HAL_TensorImportOp : HAL_PureOp<"tensor.import", [ TypeAttr:$target_encoding, HAL_ShapeDynamicDims:$target_dims, Optional:$wait_fence, - OptionalAttr:$name + OptionalAttr:$name, + OptionalAttr:$affinity ); let results = (outs AnyTensor:$target ); let assemblyFormat = [{ + (`on` `(` $affinity^ `)`)? (`wait` `(` $wait_fence^ `)` `=` `` `>`)? $source ($name^)? @@ -145,14 +147,16 @@ def HAL_TensorImportOp : HAL_PureOp<"tensor.import", [ "Type":$resultType, "Value":$source, "TypeAttr":$targetEncoding, - "StringAttr":$name + "StringAttr":$name, + "Attribute":$affinity )>, OpBuilder<(ins "Type":$resultType, "Value":$source, "TypeAttr":$targetEncoding, "Value":$waitFence, - "StringAttr":$name + "StringAttr":$name, + "Attribute":$affinity )>, ]; @@ -190,13 +194,15 @@ def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [ AnyTensor:$source, TypeAttr:$source_encoding, HAL_ShapeDynamicDims:$source_dims, - OptionalAttr:$name + OptionalAttr:$name, + OptionalAttr:$affinity ); let results = (outs AnyTypeOf<[HAL_Buffer, HAL_BufferView]>:$target ); let assemblyFormat = [{ + (`on` `(` $affinity^ `)`)? $source ($name^)? `:` @@ -211,7 +217,8 @@ def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [ "Type":$resultType, "Value":$source, "TypeAttr":$sourceEncoding, - "StringAttr":$name + "StringAttr":$name, + "Attribute":$affinity )>, ]; @@ -273,13 +280,15 @@ def HAL_TensorAliasOp : HAL_PureOp<"tensor.alias", [ AnyTensor:$source, HAL_ShapeDynamicDims:$source_dims, AnyTypeOf<[HAL_Buffer, HAL_BufferView]>:$storage, - Optional:$wait_fence + Optional:$wait_fence, + OptionalAttr:$affinity ); let results = (outs AnyTensor:$result ); let assemblyFormat = [{ + (`on` `(` $affinity^ `)`)? (`wait` `(` $wait_fence^ `)` `=` `` `>`)? $source `:` type($source) (`{` $source_dims^ `}`)? `to` @@ -311,13 +320,15 @@ def HAL_TensorBarrierOp : HAL_Op<"tensor.barrier", [ let arguments = (ins Variadic:$sources, - HAL_Fence:$signal_fence + HAL_Fence:$signal_fence, + OptionalAttr:$affinity ); let results = (outs Variadic:$results ); let assemblyFormat = [{ + (`on` `(` $affinity^ `)`)? `join` `` `(` $sources `:` type($sources) `)` `=` `` `>` $signal_fence `:` type($signal_fence) @@ -327,7 +338,8 @@ def HAL_TensorBarrierOp : HAL_Op<"tensor.barrier", [ let builders = [ OpBuilder<(ins "ValueRange":$sources, - "Value":$signalFence + "Value":$signalFence, + "Attribute":$affinity )>, ]; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp index 060aeb897620..457f09fa6dd9 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp @@ -28,7 +28,7 @@ static Value buildResultSizeOf(Location loc, Value tensorValue, ConversionPatternRewriter &rewriter) { // TODO(benvanik): see if we can stash this on the side to avoid expensive // materialization of a bunch of redundant IR. - return rewriter.createOrFold( + return rewriter.create( loc, rewriter.getIndexType(), TypeAttr::get(tensorValue.getType()), dynamicDims, IREE::Stream::AffinityAttr::lookup(tensorValue.getDefiningOp())); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp index 2acd3ba4f51c..35eb31ff20da 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp @@ -49,11 +49,16 @@ struct ConvertTensorImportOp } } + auto affinityAttr = + dyn_cast_if_present(op.getAffinityAttr()); + if (!affinityAttr) { + affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + } + // Import (buffer view to stream resource). - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); auto resultType = rewriter.getType( IREE::Stream::Lifetime::External); - auto resultSize = rewriter.createOrFold( + Value resultSize = rewriter.create( op.getLoc(), rewriter.getIndexType(), TypeAttr::get(op.getTarget().getType()), adaptor.getTargetDims(), affinityAttr); @@ -77,7 +82,7 @@ struct ConvertTensorImportOp auto unknownType = rewriter.getType(); rewriter.replaceOpWithNewOp( op, unknownType, resource, resultSize, resultSize, affinityAttr, - affinityAttr); + /*target_affinity=*/IREE::Stream::AffinityAttr{}); return success(); } @@ -133,7 +138,11 @@ struct ConvertTensorExportOp return rewriter.notifyMatchFailure(op, "unsupported HAL cast conversion"); } - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto affinityAttr = + dyn_cast_if_present(op.getAffinityAttr()); + if (!affinityAttr) { + affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + } auto source = consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); @@ -145,7 +154,8 @@ struct ConvertTensorExportOp if (source.resource.getType() != externalType) { exportSource = rewriter.create( op.getLoc(), externalType, source.resource, source.resourceSize, - source.resourceSize, affinityAttr, affinityAttr); + source.resourceSize, /*source_affinity=*/IREE::Stream::AffinityAttr{}, + affinityAttr); } // Export (stream resource to buffer view). @@ -179,11 +189,15 @@ struct ConvertTensorAliasOp // All operations (if any) will happen on the device specified by the alias // as that indicates the affinity of the storage. - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto affinityAttr = + dyn_cast_if_present(op.getAffinityAttr()); + if (!affinityAttr) { + affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + } // Query the target storage buffer length; we will only populate up to // what is required for the output. - auto storageSize = rewriter.createOrFold( + Value storageSize = rewriter.create( op.getLoc(), rewriter.getIndexType(), TypeAttr::get(op.getSource().getType()), adaptor.getSourceDims(), affinityAttr); diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp index dbe9abcaa2f6..24d939a7ff73 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp @@ -1183,7 +1183,7 @@ struct TensorConstantToEmpty : public OpRewritePattern { return failure(); // Definitely empty if here. - auto resultSize = rewriter.createOrFold( + Value resultSize = rewriter.create( constantOp.getLoc(), rewriter.getIndexType(), TypeAttr::get(constantOp.getResultEncoding()), constantOp.getResultEncodingDims(), constantOp.getAffinityAttr()); @@ -1219,7 +1219,7 @@ struct TensorConstantToSplat : public OpRewritePattern { } auto resultType = IREE::Stream::ResourceType::get(constantOp.getContext()); - auto resultSize = rewriter.createOrFold( + Value resultSize = rewriter.create( constantOp.getLoc(), rewriter.getIndexType(), TypeAttr::get(constantOp.getResultEncoding()), constantOp.getResultEncodingDims(), constantOp.getAffinityAttr()); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp index 11873a2c73d7..6e26e2440e2e 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp @@ -55,7 +55,7 @@ static Value buildTensorImportOp(Location loc, Value sourceTensor, // This may differ from the external encoding of the tensor as imports are // a transfer operation that may need to reformat the tensor. auto encodingAttr = TypeAttr::get(sourceTensor.getType()); - auto resultSize = builder.createOrFold( + Value resultSize = builder.create( loc, builder.getIndexType(), encodingAttr, dynamicDims, /*affinity=*/nullptr); diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp index b4b708a534ff..318160fe0727 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp @@ -228,23 +228,23 @@ rearrangeBlockGlobalAccesses(Block &block, // op order issues. SmallVector>> sequencedBuckets; sequencedBuckets.push_back({}); // Start in a sequence. - block.walk([&](Operation *op) { + for (auto &op : block) { auto &buckets = sequencedBuckets.back(); if (auto loadOp = dyn_cast(op)) { if (!immutableGlobals.contains(loadOp.getGlobalName())) { - buckets[loadOp.getGlobalName()].push_back(op); + buckets[loadOp.getGlobalName()].push_back(&op); } } else if (auto storeOp = dyn_cast(op)) { - buckets[storeOp.getGlobalName()].push_back(op); - } else if (doesOpBlockMotion(op)) { + buckets[storeOp.getGlobalName()].push_back(&op); + } else if (doesOpBlockMotion(&op)) { // Split point - all accesses after this point must not assume anything // about accesses before it. if (!buckets.empty()) { sequencedBuckets.push_back({}); } } - }); + } bool didRemoveAny = false; for (auto &buckets : sequencedBuckets) { didRemoveAny = optimizeBuckets(block, buckets) || didRemoveAny; diff --git a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp index 3200829635bf..6e8ecc0bcbb0 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp @@ -209,13 +209,15 @@ class TensorImportPattern // the work. rewriter.replaceOpWithNewOp( srcOp, resultType, adaptor.getSource(), TypeAttr::get(resultType), - /*name=*/nullptr); + /*name=*/nullptr, + /*affinity=*/nullptr); } else { // Dynamic dims explicitly provided (or wrong, in which case the verifier // will get it). rewriter.replaceOpWithNewOp( srcOp, resultType, adaptor.getSource(), TypeAttr::get(resultType), - adaptor.getTargetDims(), /*wait_fence=*/Value{}, /*name=*/nullptr); + adaptor.getTargetDims(), /*wait_fence=*/Value{}, /*name=*/nullptr, + /*affinity=*/nullptr); } return success(); } @@ -237,14 +239,16 @@ class TensorExportPattern // the work. rewriter.replaceOpWithNewOp( srcOp, resultType, adaptor.getSource(), - TypeAttr::get(adaptor.getSource().getType()), /*name=*/nullptr); + TypeAttr::get(adaptor.getSource().getType()), /*name=*/nullptr, + /*affinity=*/nullptr); } else { // Dynamic dims explicitly provided (or wrong, in which case the verifier // will get it). rewriter.replaceOpWithNewOp( srcOp, resultType, adaptor.getSource(), TypeAttr::get(adaptor.getSource().getType()), adaptor.getSourceDims(), - /*name=*/nullptr); + /*name=*/nullptr, + /*affinity=*/nullptr); } return success(); } From 5392b83b7121dad378298c1ad2dbbcc61de484b4 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 13 May 2024 09:07:44 -0700 Subject: [PATCH 15/25] Adding flow.tensor.transfer op. This allows for frontends to specify a clone of a tensor to a target context. This is lowered into a stream.async.transfer and with analysis will allow for hinting placement. More flow-level optimizations are likely to be required in larger programs but until we start to see those things are kept simple here. --- .../Dialect/Flow/IR/FlowOpFolders.cpp | 34 +++++++++++++ .../iree/compiler/Dialect/Flow/IR/FlowOps.cpp | 14 +++++ .../iree/compiler/Dialect/Flow/IR/FlowOps.td | 51 +++++++++++++++++++ .../Dialect/Flow/IR/test/tensor_folding.mlir | 15 ++++++ .../Dialect/Flow/IR/test/tensor_ops.mlir | 9 ++++ .../Conversion/FlowToStream/Patterns.cpp | 28 ++++++++-- .../FlowToStream/test/tensor_ops.mlir | 13 +++++ 7 files changed, 161 insertions(+), 3 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp index d9a6a0fdb3f2..1a60d1cc284e 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp @@ -1153,6 +1153,40 @@ void TensorCloneOp::getCanonicalizationPatterns(RewritePatternSet &results, results.insert>(context); } +//===----------------------------------------------------------------------===// +// flow.tensor.transfer +//===----------------------------------------------------------------------===// + +namespace { + +// Attempts to identify trivial cases where we locally recognize that a tensor +// is transferred to the same context it's already on. This does not look across +// control flow edges or globals and is mostly for simplifying IR that may come +// in with a transfer on every single tensor. +struct ElideRedundantTransfer : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TensorTransferOp op, + PatternRewriter &rewriter) const override { + auto baseValue = + IREE::Util::TiedOpInterface::findTiedBaseValue(op.getOperand()); + if (auto transferOp = dyn_cast_if_present( + baseValue.getDefiningOp())) { + if (transferOp.getTarget() == op.getTarget()) { + rewriter.replaceOp(op, op.getOperand()); + return success(); + } + } + return failure(); + } +}; + +} // namespace + +void TensorTransferOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // flow.tensor.slice //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp index e0be51c3f9ff..0e47aa55e2b3 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp @@ -1785,6 +1785,20 @@ LogicalResult TensorCloneOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// flow.tensor.transfer +//===----------------------------------------------------------------------===// + +LogicalResult TensorTransferOp::verify() { + if (failed(verifyOpDynamicDims(getOperation(), {getOperand()}, + getArgumentDims())) || + failed(verifyOpDynamicDims(getOperation(), {getResult()}, + getArgumentDims()))) { + return failure(); + } + return success(); +} + //===----------------------------------------------------------------------===// // flow.tensor.slice //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td index 30b27c80b658..938e7adfe62e 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td @@ -1500,6 +1500,57 @@ def FLOW_TensorCloneOp : FLOW_PureOp<"tensor.clone", [ let hasFolder = 1; } +def FLOW_TensorTransferOp : FLOW_PureOp<"tensor.transfer", [ + AllTypesMatch<["operand", "result"]>, + DeclareOpInterfaceMethods, + Util_ShapeAwareOp, +]> { + let summary = [{transfers a tensor to a target by copying if needed}]; + let description = [{ + Transfers the tensor from whichever context it may be in to the specified + target context. If the contexts are compatible and can access each others + memory the operation may be elided and otherwise will become one or more + copies to transfer the tensor in cases where staging through an intermediate + context is required. + }]; + + let arguments = (ins + FLOW_Tensor:$operand, + FLOW_ShapeDynamicDims:$argument_dims, + AnyAttr:$target + ); + let results = (outs + FLOW_Tensor:$result + ); + + let assemblyFormat = [{ + $operand `:` type($result) (`{` $argument_dims^ `}`)? + `to` $target + attr-dict-with-keyword + }]; + + let builders = [ + OpBuilder<(ins "Value":$operand, "Attribute":$target), + [{ + build($_builder, $_state, + operand.getType(), + operand, + IREE::Util::buildDynamicDimsForValue($_state.location, operand, $_builder), + target); + }]>, + ]; + + let extraClassDeclaration = [{ + bool isHoistableLeafOp() { return false; } + + ValueRange getOperandDynamicDims(unsigned idx) { return getArgumentDims(); } + ValueRange getResultDynamicDims(unsigned idx) { return getArgumentDims(); } + }]; + + let hasVerifier = 1; + let hasCanonicalizer = 1; +} + def FLOW_TensorSliceOp : FLOW_PureOp<"tensor.slice", [ AllRanksMatch<["source", "result"]>, AllElementTypesMatch<["source", "result"]>, diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir index cf1e34189f98..959e398533c6 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir @@ -397,6 +397,21 @@ util.func public @cloneDynamicZeroElements(%arg0: tensor<0x?xf32>, %dim: index) // ----- +// CHECK-LABEL: @ElideRedundantTransfer +// CHECK-SAME: (%[[OPERAND:.+]]: tensor<4x?xf32>, %[[DIM:.+]]: index) +util.func public @ElideRedundantTransfer(%arg0: tensor<4x?xf32>, %dim: index) -> tensor<4x?xi32> { + // CHECK: %[[TRANSFER:.+]] = flow.tensor.transfer %arg0 + %transfer = flow.tensor.transfer %arg0 : tensor<4x?xf32>{%dim} to "target" + // CHECK: %[[BITCAST:.+]] = flow.tensor.bitcast %[[TRANSFER]] + %bitcast = flow.tensor.bitcast %transfer : tensor<4x?xf32>{%dim} -> tensor<4x?xi32>{%dim} + // CHECK-NOT: flow.transfer + %redundant = flow.tensor.transfer %bitcast : tensor<4x?xi32>{%dim} to "target" + // CHECK-NEXT: %[[BITCAST]] + util.return %redundant : tensor<4x?xi32> +} + +// ----- + // CHECK-LABEL: @sliceConst0D util.func public @sliceConst0D() -> tensor { %0 = arith.constant dense<0> : tensor diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir index 8b01c00a1c53..62d79a1f3dbb 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir @@ -158,6 +158,15 @@ util.func public @tensorClone(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { // ----- +// CHECK-LABEL: @tensorTransfer +util.func public @tensorTransfer(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { + // CHECK-NEXT: %0 = flow.tensor.transfer %arg0 : tensor<4x4xf32> to "dummy" + %0 = flow.tensor.transfer %arg0 : tensor<4x4xf32> to "dummy" + util.return %0 : tensor<4x4xf32> +} + +// ----- + // CHECK-LABEL: @tensorCloneScalar util.func public @tensorCloneScalar(%arg0 : tensor) -> tensor { // CHECK-NEXT: %0 = flow.tensor.clone %arg0 : tensor diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp index 457f09fa6dd9..316e15ac59c2 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp @@ -200,6 +200,27 @@ struct ConvertTensorCloneOp } }; +struct ConvertTensorTransferOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(IREE::Flow::TensorTransferOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto targetAffinityAttr = + dyn_cast(adaptor.getTarget()); + if (!targetAffinityAttr) + return rewriter.notifyMatchFailure(op, "invalid stream affinity attr"); + auto unknownType = rewriter.getType(); + auto operand = + consumeTensorOperand(op.getLoc(), adaptor.getOperand(), rewriter); + rewriter.replaceOpWithNewOp( + op, unknownType, operand.resource, operand.resourceSize, + operand.resourceSize, + /*source_affinity=*/IREE::Stream::AffinityAttr{}, targetAffinityAttr); + return success(); + } +}; + struct ConvertTensorSliceOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -969,9 +990,10 @@ void populateFlowToStreamConversionPatterns(MLIRContext *context, ConvertTensorCastLikeOp, ConvertTensorCastLikeOp, ConvertTensorAllocaOp, ConvertTensorEmptyOp, ConvertTensorSplatOp, - ConvertTensorCloneOp, ConvertTensorSliceOp, ConvertTensorUpdateOp, - ConvertTensorLoadOp, ConvertTensorStoreOp, ConvertTensorTraceOp>( - typeConverter, context); + ConvertTensorCloneOp, ConvertTensorTransferOp, + ConvertTensorSliceOp, ConvertTensorUpdateOp, ConvertTensorLoadOp, + ConvertTensorStoreOp, ConvertTensorTraceOp>(typeConverter, + context); patterns.insert(typeConverter, context); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir index 7633d8c1b849..abc96e7dd832 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir @@ -136,6 +136,19 @@ util.func public @tensorSplat(%value: i8, %dim0: index) -> tensor { // ----- +util.global private @device : !hal.device + +// CHECK-LABEL: @tensorTransfer +// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM0:.+]]: index) +util.func public @tensorTransfer(%input: tensor, %dim0: index) -> tensor { + // CHECK: %[[TRANSFER:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device>) !stream.resource<*>{%[[INPUT_SIZE]]} + %transfer = flow.tensor.transfer %input : tensor{%dim0} to #hal.device.affinity<@device> + // CHECK: util.return %[[TRANSFER]], %[[INPUT_SIZE]] + util.return %transfer : tensor +} + +// ----- + // CHECK-LABEL: @tensorSlice // CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index) util.func public @tensorSlice(%input : tensor<5x24x48xf32>) -> tensor<3x24x48xf32> { From c05323fce0318a029e36bff9592c5ddd57604baf Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 21 May 2024 09:40:28 -0700 Subject: [PATCH 16/25] New AssignTargetDevices pass to replace the legacy one. The legacy pass has been moved aside so that the old flags still work but will be removed in the future. --- .../plugins/target/CUDA/test/smoketest.mlir | 4 +- .../LLVMCPU/test/smoketest_embedded.mlir | 2 +- .../target/LLVMCPU/test/smoketest_system.mlir | 2 +- .../target/MetalSPIRV/test/smoketest.mlir | 2 +- .../plugins/target/ROCM/test/smoketest.mlir | 8 +- .../ROCM/test/target_device_features.mlir | 9 +- .../plugins/target/VMVX/test/smoketest.mlir | 2 +- .../target/VulkanSPIRV/test/smoketest.mlir | 2 +- .../target/WebGPUSPIRV/test/smoketest.mlir | 2 +- .../compiler/API/Internal/CompilerDriver.cpp | 4 +- .../iree/compiler/ConstEval/JitGlobals.cpp | 3 +- .../Dialect/HAL/Target/TargetOptions.cpp | 13 +- .../Dialect/HAL/Target/TargetOptions.h | 31 ++- .../Transforms/AssignLegacyTargetDevices.cpp | 117 +++++++++ .../HAL/Transforms/AssignTargetDevices.cpp | 233 +++++++++++++++--- .../Dialect/HAL/Transforms/BUILD.bazel | 2 + .../Dialect/HAL/Transforms/CMakeLists.txt | 2 + .../Transforms/MaterializeTargetDevices.cpp | 232 +++++++++++++---- .../Dialect/HAL/Transforms/Passes.cpp | 69 ++++-- .../compiler/Dialect/HAL/Transforms/Passes.h | 30 ++- .../compiler/Dialect/HAL/Transforms/Passes.td | 71 +++++- .../Dialect/HAL/Transforms/test/BUILD.bazel | 1 + .../HAL/Transforms/test/CMakeLists.txt | 1 + .../test/assign_legacy_target_devices.mlir | 42 ++++ .../test/assign_target_devices.mlir | 47 ++-- .../materialize_dispatch_instrumentation.mlir | 2 +- .../test/materialize_target_devices.mlir | 111 +++++++-- .../materialize_homogeneous_encodings.mlir | 10 +- .../Modules/HAL/Inline/Transforms/Passes.cpp | 6 +- .../Modules/HAL/Loader/Transforms/Passes.cpp | 6 +- .../src/iree/compiler/Pipelines/Pipelines.cpp | 35 +-- .../src/iree/compiler/Pipelines/Pipelines.h | 6 +- .../Common/test/pad_to_intrinsics_wmma.mlir | 1 - .../docs/community/blog/posts/microkernels.md | 4 +- runtime/src/iree/vm/bytecode/disassembler.c | 1 + .../cpu/embedded/example_hal.mlir | 2 +- .../cpu/embedded/example_stream.mlir | 2 +- .../cpu/embedded/example_transform.mlir | 2 +- .../custom_dispatch/cpu/mlp_plugin/mlp.mlir | 2 +- .../cpu/mlp_plugin/mlp_linalg.mlir | 2 +- .../cpu/mlp_plugin/mlp_linalg_two_matmul.mlir | 2 +- .../cpu/mlp_plugin/mlp_torch.mlir | 2 +- .../cpu/mlp_plugin/mlp_tosa.mlir | 2 +- .../custom_dispatch/cuda/kernels/example.mlir | 2 +- .../custom_dispatch/hip/kernels/example.mlir | 2 +- .../vulkan/shaders/example.mlir | 2 +- .../vulkan/shaders/example_inline.mlir | 2 +- .../vulkan/shaders/example_transform.mlir | 2 +- samples/multiple_modules/pipeline_async.mlir | 3 +- samples/multiple_modules/pipeline_sync.mlir | 3 +- samples/transform_dialect/example_module.mlir | 2 +- tests/compiler_driver/precompile.mlir | 4 +- .../compiler_driver/preprocessing_flags.mlir | 2 +- tests/e2e/regression/libm_linking.mlir | 4 +- 54 files changed, 932 insertions(+), 225 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignLegacyTargetDevices.cpp create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_legacy_target_devices.mlir diff --git a/compiler/plugins/target/CUDA/test/smoketest.mlir b/compiler/plugins/target/CUDA/test/smoketest.mlir index 54606d797899..fc7d8fcd1b55 100644 --- a/compiler/plugins/target/CUDA/test/smoketest.mlir +++ b/compiler/plugins/target/CUDA/test/smoketest.mlir @@ -5,7 +5,9 @@ module attributes { hal.device.targets = [ - #hal.device.target<"cuda", [#hal.executable.target<"cuda", "cuda-nvptx-fb">]> + #hal.device.target<"cuda", [ + #hal.executable.target<"cuda", "cuda-nvptx-fb"> + ]> : !hal.device ] } { diff --git a/compiler/plugins/target/LLVMCPU/test/smoketest_embedded.mlir b/compiler/plugins/target/LLVMCPU/test/smoketest_embedded.mlir index e772c4da3f86..f9e0a4b9e2ab 100644 --- a/compiler/plugins/target/LLVMCPU/test/smoketest_embedded.mlir +++ b/compiler/plugins/target/LLVMCPU/test/smoketest_embedded.mlir @@ -7,7 +7,7 @@ module attributes { #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", { native_vector_size = 16 : index }> - ]> + ]> : !hal.device ] } { diff --git a/compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir b/compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir index 6e7f8d5327fc..d6c6658cfd3e 100644 --- a/compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir +++ b/compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir @@ -9,7 +9,7 @@ module attributes { #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", { native_vector_size = 16 : index }> - ]> + ]> : !hal.device ] } { diff --git a/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir b/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir index 720e00b2f835..d32ac8ef561f 100644 --- a/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir +++ b/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir @@ -8,7 +8,7 @@ module attributes { compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>> }> - ]> + ]> : !hal.device ] } { diff --git a/compiler/plugins/target/ROCM/test/smoketest.mlir b/compiler/plugins/target/ROCM/test/smoketest.mlir index 91c91ba34964..b446ea2fd93c 100644 --- a/compiler/plugins/target/ROCM/test/smoketest.mlir +++ b/compiler/plugins/target/ROCM/test/smoketest.mlir @@ -2,7 +2,9 @@ module attributes { hal.device.targets = [ - #hal.device.target<"rocm", [#hal.executable.target<"rocm", "rocm-hsaco-fb">]> + #hal.device.target<"rocm", [ + #hal.executable.target<"rocm", "rocm-hsaco-fb"> + ]> : !hal.device ] } { @@ -44,7 +46,9 @@ stream.executable public @add_dispatch_0 { #loc = loc(unknown) module attributes { hal.device.targets = [ - #hal.device.target<"rocm", [#hal.executable.target<"rocm", "rocm-hsaco-fb">]> + #hal.device.target<"rocm", [ + #hal.executable.target<"rocm", "rocm-hsaco-fb"> + ]> : !hal.device ] } { diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir index 5973c05c8488..ae7676e17e52 100644 --- a/compiler/plugins/target/ROCM/test/target_device_features.mlir +++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir @@ -1,7 +1,7 @@ -// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=mi300x %s | FileCheck %s --check-prefix=GFX942 -// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=GFX940 -// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=rx7900xtx %s | FileCheck %s --check-prefix=GFX1100 -// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx941 --iree-rocm-target-features=+sramecc,-xnack %s | FileCheck %s --check-prefix=GFX941 +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=mi300x %s | FileCheck %s --check-prefix=GFX942 +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=GFX940 +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=rx7900xtx %s | FileCheck %s --check-prefix=GFX1100 +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx941 --iree-rocm-target-features=+sramecc,-xnack %s | FileCheck %s --check-prefix=GFX941 // GFX942: target = #iree_gpu.target (index, index, index) { %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 diff --git a/compiler/plugins/target/VMVX/test/smoketest.mlir b/compiler/plugins/target/VMVX/test/smoketest.mlir index b640d12e17fc..44b320890755 100644 --- a/compiler/plugins/target/VMVX/test/smoketest.mlir +++ b/compiler/plugins/target/VMVX/test/smoketest.mlir @@ -4,7 +4,7 @@ module attributes { hal.device.targets = [ #hal.device.target<"local", [ #hal.executable.target<"vmvx", "vmvx-bytecode-fb"> - ]> + ]> : !hal.device ] } { diff --git a/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir b/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir index f8d81592b778..6ef88a8a2025 100644 --- a/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir +++ b/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir @@ -8,7 +8,7 @@ module attributes { compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32, 32], max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>> }> - ]> + ]> : !hal.device ] } { diff --git a/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir b/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir index 31f361b1ab5f..69c5ceba58ba 100644 --- a/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir +++ b/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir @@ -9,7 +9,7 @@ module attributes { compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32], max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>> }> - ]> + ]> : !hal.device ] } { diff --git a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp index e648f097854b..488555af6640 100644 --- a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp +++ b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp @@ -934,10 +934,12 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { if (!getCompilationPhase(compileFrom, compileTo)) { return false; } + + // TODO: move to someplace centralized; erroring here is not great. // InlineStatic (currently) only supports the `vmvx-inline` backend. if (session.schedulingOptions.executionModel == SchedulingOptions::ExecutionModel::InlineStatic) { - for (auto target : session.halTargetOptions.targets) { + for (auto target : session.halTargetOptions.legacyTargetBackends) { if (target != "vmvx-inline") { parsedModule->emitError() << "InlineStatic execution model is not " "compatible with hal target '" diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp index e3b19f0957f1..42ede9d246a0 100644 --- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp +++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp @@ -619,7 +619,8 @@ struct JitGlobalsPass : public JitGlobalsBase { requestedTargetDevice = resolveTargetDevice(*targetRegistry.value); hasRequestedTargetDevice = targetRegistry->getTargetDevice(requestedTargetDevice) != nullptr; - compileOptions->executableOptions.targets.push_back(requestedTargetDevice); + compileOptions->executableOptions.legacyTargetBackends.push_back( + requestedTargetDevice); compileOptions->targetOptions.f32Extension = true; compileOptions->targetOptions.f64Extension = true; compileOptions->targetOptions.truncateUnsupportedFloats = false; diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp index c00cb6386156..26fcae5402a6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp @@ -22,10 +22,21 @@ void TargetOptions::bindOptions(OptionsBinder &binder) { // initialized, so targetBackendsFlags needs to be here to be initialized // first. binder.list( - "iree-hal-target-backends", targets, + "iree-hal-target-backends", legacyTargetBackends, llvm::cl::desc("Target backends for executable compilation."), llvm::cl::ZeroOrMore, llvm::cl::cat(halTargetOptionsCategory)); + binder.list("iree-hal-target-device", targetDevices, + llvm::cl::desc("Target device specifications."), + llvm::cl::ZeroOrMore, + llvm::cl::cat(halTargetOptionsCategory)); + binder.opt( + "iree-hal-default-device", defaultDevice, + llvm::cl::desc("Which device is considered the default when no device " + "affinity is specified. Either the device name when names " + "are specified or the numeric ordinal of the device."), + llvm::cl::cat(halTargetOptionsCategory)); + binder.opt( "iree-hal-executable-debug-level", debugLevel, llvm::cl::desc("Debug level for executable translation (0-3)"), diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.h b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.h index 711e0e186612..08601eacd6d7 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.h @@ -17,8 +17,35 @@ namespace mlir::iree_compiler::IREE::HAL { // TODO(benvanik): remove this and replace with the pass pipeline options. // Controls executable translation targets. struct TargetOptions { - // TODO(benvanik): multiple targets of the same type, etc. - std::vector targets; + // TODO(benvanik): remove the legacy flag once users are switched to devices. + std::vector legacyTargetBackends; + + // Specifies target devices to assign to the program. May be omitted if the + // program already has devices assigned or no devices are required (host + // program not using the HAL). + // + // Two devices, one the local host device and the other a Vulkan device: + // `local`, `vulkan` + // + // One device selecting between Vulkan if available and otherwise use the + // local host device: + // `vulkan,local` + // + // Two CUDA devices selected by runtime ordinal; at runtime two --device= + // flags are required to configure both devices: + // `cuda[0]`, `cuda[1]` + // + // A fully-defined target specification: + // `#hal.device.target<"cuda", {...}, [#hal.executable.target<...>]>` + // + // Named device for defining a reference by #hal.device.promise<@some_name>: + // `some_name=vulkan` + std::vector targetDevices; + + // Which device is considered the default when no device affinity is specified + // on a particular operation. Accepts string names matching those specified + // in the target devices list or numeric ordinals if names were omitted. + std::string defaultDevice; // Coarse debug level for executable translation across all targets. // Each target backend can use this to control its own flags, with values diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignLegacyTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignLegacyTargetDevices.cpp new file mode 100644 index 000000000000..38b0af7669e9 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignLegacyTargetDevices.cpp @@ -0,0 +1,117 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h" +#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::HAL { + +#define GEN_PASS_DEF_ASSIGNLEGACYTARGETDEVICESPASS +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// --iree-hal-assign-legacy-target-devices +//===----------------------------------------------------------------------===// + +struct AssignLegacyTargetDevicesPass + : public IREE::HAL::impl::AssignLegacyTargetDevicesPassBase< + AssignLegacyTargetDevicesPass> { + using IREE::HAL::impl::AssignLegacyTargetDevicesPassBase< + AssignLegacyTargetDevicesPass>::AssignLegacyTargetDevicesPassBase; + + void runOnOperation() override { + auto moduleOp = getOperation(); + + // If no targets are specified we can't do anything - another pass earlier + // in the pipeline will have had to add the targets. + if (targetBackends.empty()) { + return; + } + + // Check to see if targets are already specified and if so then no-op the + // pass so that we don't mess with whatever the user intended. + auto existingTargetsAttr = + moduleOp->getAttrOfType("hal.device.targets"); + if (existingTargetsAttr) { + return; + } + + // If there are any device globals declared then bail as it means the user + // has already materialized the devices they want. + for (auto globalOp : moduleOp.getOps()) { + if (isa(globalOp.getGlobalType())) { + return; + } + } + + llvm::SmallDenseSet targetAttrSet; + SmallVector targetAttrs; + for (const auto &targetBackendName : targetBackends) { + auto targetBackend = targetRegistry->getTargetBackend(targetBackendName); + if (!targetBackend) { + auto diagnostic = emitError(moduleOp.getLoc()) + << "target backend '" << targetBackendName + << "' not registered; registered backends: ["; + llvm::interleaveComma(targetRegistry->getRegisteredTargetBackends(), + diagnostic); + diagnostic << "]"; + return signalPassFailure(); + } + auto targetDeviceName = targetBackend->getLegacyDefaultDeviceID(); + auto targetDevice = targetRegistry->getTargetDevice(targetDeviceName); + if (!targetDevice) { + auto diagnostic = emitError(moduleOp.getLoc()) + << "target device '" << targetDeviceName + << "' not registered; registered devices: ["; + llvm::interleaveComma(targetRegistry->getRegisteredTargetDevices(), + diagnostic); + diagnostic << "]"; + return signalPassFailure(); + } + + // Ask the target backend for its default device specification attribute. + auto targetAttr = targetDevice->getDefaultDeviceTarget( + moduleOp.getContext(), *targetRegistry.value); + if (!targetAttr) { + emitError(moduleOp.getLoc()) << "no default device targets available"; + return signalPassFailure(); + } + if (!targetAttrSet.contains(targetAttr)) { + targetAttrSet.insert(targetAttr); + targetAttrs.push_back(targetAttr); + } + } + + Attribute targetsAttr; + if (targetAttrs.size() == 1) { + targetsAttr = targetAttrs.front(); + } else { + targetsAttr = + IREE::HAL::DeviceSelectAttr::get(moduleOp.getContext(), targetAttrs); + } + moduleOp->setAttr("hal.device.targets", + ArrayAttr::get(moduleOp.getContext(), targetsAttr)); + } +}; + +} // namespace + +} // namespace mlir::iree_compiler::IREE::HAL diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp index ddcc2ef70aa6..4892e4438e35 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AssignTargetDevices.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 The IREE Authors +// Copyright 2024 The IREE Authors // // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -14,6 +14,7 @@ #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/AsmParser/AsmParser.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" @@ -31,6 +32,160 @@ namespace { // --iree-hal-assign-target-devices //===----------------------------------------------------------------------===// +// Strips leading and trailing whitespace from |value|. +static StringRef stripWhitespace(StringRef value) { + while (!value.empty() && llvm::isSpace(value.front())) { + value = value.drop_front(1); + } + while (!value.empty() && llvm::isSpace(value.back())) { + value = value.drop_back(1); + } + return value; +} + +// Strips leading and trailing double quotes from |value| if both exist. +static StringRef stripQuotes(StringRef value) { + value = stripWhitespace(value); + StringRef unquoted = value; + if (unquoted.consume_front("\"") && unquoted.consume_back("\"")) { + return stripWhitespace(unquoted); + } + return value; +} + +// Consumes a leading `name=` literal. +// Returns the `name` and leaves remaining characters after `=` in |value|. +// Returns an empty string if no name literal is present. +static StringRef consumeNameLiteral(StringRef &value) { + value = stripWhitespace(value); + const size_t splitIdx = value.find('='); + if (splitIdx == std::string::npos) { + return ""; + } + for (size_t i = 0; i < splitIdx; ++i) { + const char c = value[i]; + if (!llvm::isAlnum(c) && c != '_') { + return value; + } + } + const StringRef name = value.substr(0, splitIdx); + value = stripWhitespace(value.substr(splitIdx + 1)); + return stripWhitespace(name); +} + +// Consumes the first portion of |value| corresponding to a device alias. +// Expects: `abc` or `abc[123]` (and allows `"abc"[123]`). +// Only valid literals will be parsed (a-z0-9_). +// Returns the device ID and optional ordinal. All other unconsumed characters +// will remain in |value| upon return. +static std::pair> +consumeAliasLiteral(StringRef &value) { + value = stripWhitespace(value); + const size_t splitIdx = value.find(','); + StringRef part = + splitIdx == std::string::npos ? value : value.substr(0, splitIdx); + + StringRef deviceID = part; + std::optional ordinal; + + const size_t ordinalIdx = part.find('['); + if (ordinalIdx != std::string::npos) { + deviceID = part.substr(0, ordinalIdx); + StringRef ordinalStr = part.substr(ordinalIdx + 1); + APInt ordinalInt; + if (!ordinalStr.consumeInteger(10, ordinalInt)) { + ordinal = ordinalInt.getSExtValue(); + } + } + + value = stripWhitespace(value.substr(part.size())); + return std::make_pair(stripQuotes(deviceID), ordinal); +} + +struct TargetSpec { + StringAttr name; + TypedAttr attr; +}; + +// Parses the user-provided string into a target spec. +// +// Supports attributes: +// #hal.device.alias<...> +// #hal.device.target<...> +// #hal.device.select<...> +// #hal.device.fallback<...> +// Supports convenience shorthand: +// ...,... -> #hal.device.select<[...,...]> +// target -> #hal.device.alias<"target"> +// target[0] -> #hal.device.alias<"target"[0]> +// "target"[0] -> #hal.device.alias<"target"[0]> +// Supports name= prefixes: +// name=... -> ... +static FailureOr parseTargetSpec(Location loc, + StringRef targetSpecStr) { + auto *context = loc.getContext(); + targetSpecStr = stripQuotes(targetSpecStr); + + // Check for a name prefix and strip it from the spec. + StringRef name = consumeNameLiteral(targetSpecStr); + StringAttr nameAttr = + name.empty() ? StringAttr{} : StringAttr::get(context, name); + + // Parse the spec attributes. + SmallVector attrs; + while (!targetSpecStr.empty()) { + TypedAttr typedAttr; + if (targetSpecStr.starts_with('#')) { + // MLIR attribute. + size_t numRead = 0; + auto parsedAttr = mlir::parseAttribute(targetSpecStr, context, + /*type=*/nullptr, &numRead); + if (!parsedAttr) { + return mlir::emitError(loc) << "failed to parse target spec prefix `" + << targetSpecStr << "`"; + } + typedAttr = dyn_cast(parsedAttr); + if (!typedAttr) { + return mlir::emitError(loc) << "unexpected target attribute type: " + "expected a `!hal.device` but got `" + << parsedAttr << "`"; + } + targetSpecStr = stripWhitespace(targetSpecStr.substr(numRead)); + } else { + // Alias string. + auto [deviceID, ordinal] = consumeAliasLiteral(targetSpecStr); + typedAttr = IREE::HAL::DeviceAliasAttr::get( + context, IREE::HAL::DeviceType::get(context), + StringAttr::get(context, deviceID), ordinal, DictionaryAttr{}); + } + + if (!typedAttr || !isa(typedAttr.getType())) { + return mlir::emitError(loc) << "unexpected target attribute type: " + "expected a `!hal.device` but got `" + << typedAttr.getType() << "`"; + } + attrs.push_back(typedAttr); + + if (targetSpecStr.empty()) { + break; // done + } else if (!targetSpecStr.starts_with(',')) { + return mlir::emitError(loc) + << "unexpected additional characters after parsing an element: `" + << targetSpecStr << "`"; + } + targetSpecStr = targetSpecStr.substr(1); // strip , + } + + if (attrs.empty()) { + return mlir::emitError(loc) << "expected one or more target attributes"; + } else if (attrs.size() == 1) { + return TargetSpec{nameAttr, cast(attrs.front())}; + } else { + return TargetSpec{nameAttr, + IREE::HAL::DeviceSelectAttr::get(context, attrs)}; + } +} + struct AssignTargetDevicesPass : public IREE::HAL::impl::AssignTargetDevicesPassBase< AssignTargetDevicesPass> { @@ -55,50 +210,58 @@ struct AssignTargetDevicesPass // If there are any device globals declared then bail as it means the user // has already materialized the devices they want. for (auto globalOp : moduleOp.getOps()) { - if (isa(globalOp.getGlobalType())) + if (isa(globalOp.getGlobalType())) { return; + } } - llvm::SmallDenseSet targetAttrSet; - SmallVector targetAttrs; - for (const auto &targetBackendName : targetBackends) { - auto targetBackend = targetRegistry->getTargetBackend(targetBackendName); - if (!targetBackend) { - auto diagnostic = emitError(moduleOp.getLoc()) - << "target backend '" << targetBackendName - << "' not registered; registered backends: ["; - llvm::interleaveComma(targetRegistry->getRegisteredTargetBackends(), - diagnostic); - diagnostic << "]"; + // Parse each spec and validate correctness. + bool hasAnyNamed = false; + bool hasAnyUnnamed = false; + SmallVector targetSpecs; + for (auto &targetDevice : targetDevices) { + auto targetSpecOr = parseTargetSpec(moduleOp.getLoc(), targetDevice); + if (failed(targetSpecOr)) { return signalPassFailure(); } - auto targetDeviceName = targetBackend->getLegacyDefaultDeviceID(); - auto targetDevice = targetRegistry->getTargetDevice(targetDeviceName); - if (!targetDevice) { - auto diagnostic = emitError(moduleOp.getLoc()) - << "target device '" << targetDeviceName - << "' not registered; registered devices: ["; - llvm::interleaveComma(targetRegistry->getRegisteredTargetDevices(), - diagnostic); - diagnostic << "]"; - return signalPassFailure(); + if (targetSpecOr->name) { + hasAnyNamed = true; + } else { + hasAnyUnnamed = true; } + targetSpecs.push_back(*targetSpecOr); + } - // Ask the target backend for its default device specification attribute. - auto targetAttr = targetDevice->getDefaultDeviceTarget( - moduleOp.getContext(), *targetRegistry.value); - if (!targetAttr) { - emitError(moduleOp.getLoc()) << "no default device targets available"; - return signalPassFailure(); + // If any spec has a name assigned then all must have names assigned. + if (hasAnyNamed && hasAnyUnnamed) { + emitError(moduleOp.getLoc()) + << "if any target device spec has a name then all must be named"; + return signalPassFailure(); + } + + if (hasAnyNamed) { + // NOTE: we allow duplicate names to override assignment. + llvm::MapVector deviceAttrMap; + for (auto targetSpec : targetSpecs) { + assert(targetSpec.name && "all devices must be named"); + deviceAttrMap[targetSpec.name] = targetSpec.attr; } - if (!targetAttrSet.contains(targetAttr)) { - targetAttrSet.insert(targetAttr); - targetAttrs.push_back(targetAttr); + SmallVector deviceAttrs; + for (auto [name, value] : deviceAttrMap) { + deviceAttrs.push_back(NamedAttribute(name, value)); } + moduleOp->setAttr( + "hal.device.targets", + DictionaryAttr::get(moduleOp.getContext(), deviceAttrs)); + } else { + SmallVector deviceAttrs; + for (auto [name, value] : targetSpecs) { + assert(!name && "no devices may have names"); + deviceAttrs.push_back(value); + } + moduleOp->setAttr("hal.device.targets", + ArrayAttr::get(moduleOp.getContext(), deviceAttrs)); } - - moduleOp->setAttr("hal.device.targets", - ArrayAttr::get(moduleOp.getContext(), targetAttrs)); } }; diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel index bdad4275b98b..e4c7cc53bf91 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_compiler_cc_library( name = "Transforms", srcs = [ + "AssignLegacyTargetDevices.cpp", "AssignTargetDevices.cpp", "CaptureExecutableSources.cpp", "ConfigureExecutables.cpp", @@ -75,6 +76,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:AffineTransforms", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:BufferizationDialect", "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:FuncDialect", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt index 72b0b74dcbfe..9af31340af3d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt @@ -16,6 +16,7 @@ iree_cc_library( HDRS "Passes.h" SRCS + "AssignLegacyTargetDevices.cpp" "AssignTargetDevices.cpp" "CaptureExecutableSources.cpp" "ConfigureExecutables.cpp" @@ -50,6 +51,7 @@ iree_cc_library( MLIRAffineToStandard MLIRAffineTransforms MLIRArithDialect + MLIRAsmParser MLIRBufferizationDialect MLIRControlFlowDialect MLIRFuncDialect diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp index 904e20aa95af..395673c87465 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp @@ -29,6 +29,153 @@ namespace { // --iree-hal-materialize-target-devices //===----------------------------------------------------------------------===// +// Returns the canonical name for a device by ordinal: +// device ordinal `N` -> `@__device_N` +static FlatSymbolRefAttr makeDefaultDeviceOrdinalRef(MLIRContext *context, + int64_t ordinal) { + return FlatSymbolRefAttr::get( + context, (StringRef("__device_") + std::to_string(ordinal)).str()); +} + +// Returns the canonical name for a device by name: +// device name `NAME` -> `@NAME` +static FlatSymbolRefAttr makeDefaultDeviceNameRef(MLIRContext *context, + StringRef name) { + return FlatSymbolRefAttr::get(context, name); +} + +// Returns a symbol ref constructed to reference the specified device. +// Supports: +// integer attrs: device ordinal `N` -> `@__device_N` +// string attrs: device name `NAME` -> `@NAME` +static FailureOr +makeDefaultDeviceAttrRef(Attribute defaultDeviceAttr) { + if (auto stringAttr = dyn_cast(defaultDeviceAttr)) { + return makeDefaultDeviceNameRef(stringAttr.getContext(), stringAttr); + } else if (auto integerAttr = dyn_cast(defaultDeviceAttr)) { + return makeDefaultDeviceOrdinalRef(integerAttr.getContext(), + integerAttr.getInt()); + } + return failure(); +} + +// Creates a named device global with the given attribute. +static FailureOr +createDeviceGlobal(Location loc, StringAttr name, Attribute targetAttr, + OpBuilder &moduleBuilder) { + auto deviceType = moduleBuilder.getType(); + auto globalOp = moduleBuilder.create( + loc, name, /*isMutable=*/false, deviceType); + globalOp.setPrivate(); + + TypedAttr attrValue; + if (auto arrayAttr = dyn_cast(targetAttr)) { + if (arrayAttr.size() == 1) { + auto typedAttr = dyn_cast(arrayAttr.getValue().front()); + if (typedAttr && isa(typedAttr.getType())) { + // Don't care exactly what the attribute is, only that it's a device. + attrValue = typedAttr; + } + } else { + // Expand arrays to selects. + attrValue = moduleBuilder.getAttr(deviceType, + arrayAttr); + } + } else if (auto typedAttr = dyn_cast(targetAttr)) { + if (isa(typedAttr.getType())) { + // Don't care exactly what the attribute is, only that it's a device. + attrValue = typedAttr; + } + } + if (!attrValue) { + return mlir::emitError(loc) + << "module has invalid device targets specified; " + "expected hal.device.targets to be an array of !hal.device " + "initialization attributes or a dictionary with named values"; + } + + globalOp.setInitialValueAttr(attrValue); + return FlatSymbolRefAttr::get(globalOp); +} + +// Creates one or more device globals based on the specified targets and returns +// the "default" device (usually just the first one specified). +static FailureOr createDeviceGlobals(mlir::ModuleOp moduleOp, + Attribute targetsAttr) { + auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); + + FlatSymbolRefAttr firstDeviceRef; + if (auto dictAttr = dyn_cast(targetsAttr)) { + for (auto namedTargetsAttr : dictAttr.getValue()) { + auto deviceRefOr = + createDeviceGlobal(moduleOp.getLoc(), namedTargetsAttr.getName(), + namedTargetsAttr.getValue(), moduleBuilder); + if (failed(deviceRefOr)) { + return failure(); + } else if (!firstDeviceRef) { + firstDeviceRef = *deviceRefOr; + } + } + } else if (auto arrayAttr = dyn_cast(targetsAttr)) { + for (auto [i, ordinalTargetsAttr] : llvm::enumerate(arrayAttr.getValue())) { + auto deviceRefOr = + createDeviceGlobal(moduleOp.getLoc(), + moduleBuilder.getStringAttr( + StringRef("__device_") + std::to_string(i)), + ordinalTargetsAttr, moduleBuilder); + if (failed(deviceRefOr)) { + return failure(); + } else if (!firstDeviceRef) { + firstDeviceRef = *deviceRefOr; + } + } + } else { + return moduleOp.emitError() + << "unexpected `hal.device.targets` attribute; must be a dictionary " + "of named devices or an array of devices to use by ordinal"; + } + + return firstDeviceRef; +} + +// Assigns the default device affinity to all top level ops that don't already +// have one set. +static void assignDefaultDeviceAffinity(mlir::ModuleOp moduleOp, + FlatSymbolRefAttr defaultDeviceRef) { + Builder builder(moduleOp); + auto affinityName = builder.getStringAttr("stream.affinity"); + auto affinityAttr = builder.getAttr( + defaultDeviceRef, /*queue_mask=*/-1ll); + + // TODO(benvanik): make this an interface that can be registered on types. + auto isAnnotatableType = [](Type type) { + return isa(type) || isa(type); + }; + for (auto &op : moduleOp.getOps()) { + bool shouldAnnotate = true; + if (auto globalOp = dyn_cast(op)) { + if (!isAnnotatableType(globalOp.getGlobalType())) { + shouldAnnotate = false; + } + } else if (op.hasTrait()) { + // Symbol table ops can't reference parent symbols properly. + shouldAnnotate = false; + } + if (!shouldAnnotate) { + continue; // skip op + } + + if (auto affinityOp = dyn_cast(op)) { + if (!affinityOp.getAffinity()) + affinityOp.setAffinity(affinityAttr); + } else { + if (!op.hasAttr(affinityName)) { + op.setAttr(affinityName, affinityAttr); + } + } + } +} + struct MaterializeTargetDevicesPass : public IREE::HAL::impl::MaterializeTargetDevicesPassBase< MaterializeTargetDevicesPass> { @@ -38,63 +185,50 @@ struct MaterializeTargetDevicesPass void runOnOperation() override { auto moduleOp = getOperation(); - // Only run if there's a module-level attribute specified. - auto deviceTargetAttrs = - moduleOp->getAttrOfType("hal.device.targets"); - if (!deviceTargetAttrs || deviceTargetAttrs.empty()) - return; - moduleOp->removeAttr("hal.device.targets"); - - // Create the default device global. - auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); - auto deviceType = moduleBuilder.getType(); - auto globalOp = moduleBuilder.create( - moduleOp.getLoc(), "__device.0", /*isMutable=*/false, deviceType); - globalOp.setPrivate(); - if (deviceTargetAttrs.size() == 1) { - auto typedAttr = - dyn_cast(deviceTargetAttrs.getValue().front()); - if (typedAttr && isa(typedAttr.getType())) { - globalOp.setInitialValueAttr(typedAttr); - } else { - moduleOp.emitOpError() - << "has invalid device targets specified; " - "expect hal.device.targets to be an " - "ArrayAttr of !hal.device initialization attributes"; + // Only materialize devices if there's a module-level attribute specified. + FlatSymbolRefAttr defaultDeviceRef; + auto deviceTargetAttrs = moduleOp->getAttr("hal.device.targets"); + if (deviceTargetAttrs) { + moduleOp->removeAttr("hal.device.targets"); + + // Create the globals and get the default device. + auto firstDeviceOr = createDeviceGlobals(moduleOp, deviceTargetAttrs); + if (failed(firstDeviceOr)) { + // Fails if invalid attributes. return signalPassFailure(); } - } else { - globalOp.setInitialValueAttr( - moduleBuilder.getAttr( - deviceType, deviceTargetAttrs)); + defaultDeviceRef = *firstDeviceOr; } - // Assign affinities to all top level ops that don't already have one set. - auto affinityName = StringAttr::get(&getContext(), "stream.affinity"); - auto affinityAttr = moduleBuilder.getAttr( - FlatSymbolRefAttr::get(globalOp), /*queue_mask=*/-1ll); - auto isAnnotatableType = [](Type type) { - return isa(type) || isa(type); - }; - for (auto &op : moduleOp.getOps()) { - bool shouldAnnotate = true; - if (auto globalOp = dyn_cast(op)) { - if (!isAnnotatableType(globalOp.getGlobalType())) - shouldAnnotate = false; - } else if (op.hasTrait()) { - // Symbol table ops can't reference parent symbols properly. - shouldAnnotate = false; + // Select the default device from what the user specified or from the first + // created. + auto defaultDeviceAttr = moduleOp->getAttr("hal.device.default"); + if (defaultDeviceAttr) { + // Always prefer the explicitly specified default device. + moduleOp->removeAttr("hal.device.default"); + auto defaultDeviceRefOr = makeDefaultDeviceAttrRef(defaultDeviceAttr); + if (failed(defaultDeviceRefOr)) { + moduleOp.emitError() << "invalid `hal.device.default` value, must be " + "an ordinal or a name"; + return signalPassFailure(); } - if (!shouldAnnotate) - continue; - if (auto affinityOp = dyn_cast(op)) { - if (!affinityOp.getAffinity()) - affinityOp.setAffinity(affinityAttr); + defaultDeviceRef = *defaultDeviceRefOr; + } else if (!defaultDevice.empty()) { + // Fallback to the option specified, if any provided. + long long defaultDeviceOrdinal = 0; + if (!llvm::getAsSignedInteger(defaultDevice, 10, defaultDeviceOrdinal)) { + defaultDeviceRef = + makeDefaultDeviceOrdinalRef(&getContext(), defaultDeviceOrdinal); } else { - if (!op.hasAttr(affinityName)) - op.setAttr(affinityName, affinityAttr); + defaultDeviceRef = + makeDefaultDeviceNameRef(&getContext(), defaultDevice); } } + + // Assign affinities to all top level ops that don't already have one set. + if (defaultDeviceRef) { + assignDefaultDeviceAffinity(moduleOp, defaultDeviceRef); + } } }; diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index dc6ab9d67036..773ed925097f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -27,10 +27,6 @@ namespace mlir::iree_compiler::IREE::HAL { namespace { struct TransformOptions : public PassPipelineOptions { - // TODO(benvanik): replace the global iree-hal-target-backends flag with this. - // ListOption targets{ - // *this, "targets", llvm::cl::desc("One or more HAL devices to target."), - // llvm::cl::ZeroOrMore}; Option serializeExecutables{ *this, "serialize-executables", @@ -184,21 +180,26 @@ static void addExecutableSubstitutionPasses(OpPassManager &passManager, // --iree-hal-device-assignment-pipeline //===----------------------------------------------------------------------===// -void buildHALDeviceAssignmentPassPipeline(OpPassManager &passManager, - const TargetRegistry &targetRegistry, - const TargetOptions &targetOptions) { +void buildHALDeviceAssignmentPassPipeline( + OpPassManager &passManager, const TargetRegistry &targetRegistry, + const AssignmentOptions &assignmentOptions) { // The HAL must know its targets early on in the process. This pass discovers/ // derives/specifies the target devices and annotates the module with that // information. This allows subsequent passes to lookup which devices they are // targeting. - if (!targetOptions.targets.empty()) { + if (!assignmentOptions.legacyTargetBackends.empty()) { // Today we just assign devices from parameters but we should instead be // performing analysis at the flow level and then doing magic device // database lookups here. + passManager.addPass(IREE::HAL::createAssignLegacyTargetDevicesPass( + {&targetRegistry, assignmentOptions.legacyTargetBackends})); + } + if (!assignmentOptions.targetDevices.empty()) { passManager.addPass(IREE::HAL::createAssignTargetDevicesPass( - {&targetRegistry, targetOptions.targets})); + {assignmentOptions.targetDevices})); } - passManager.addPass(IREE::HAL::createMaterializeTargetDevicesPass()); + passManager.addPass(IREE::HAL::createMaterializeTargetDevicesPass( + {assignmentOptions.defaultDevice})); passManager.addPass(IREE::HAL::createResolveDevicePromisesPass()); passManager.addPass( IREE::HAL::createResolveDeviceAliasesPass({&targetRegistry})); @@ -282,12 +283,17 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, // Device assignment and interface materialization //---------------------------------------------------------------------------- - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(PipelinePhase::ExecutableSources, passManager); + } if (compileFrom < PipelinePhase::ExecutableSources) { + AssignmentOptions assignmentOptions; + assignmentOptions.legacyTargetBackends = targetOptions.legacyTargetBackends; + assignmentOptions.targetDevices = targetOptions.targetDevices; + assignmentOptions.defaultDevice = targetOptions.defaultDevice; buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry, - targetOptions); + assignmentOptions); buildHALConfigurationPassPipeline(passManager, targetRegistry, targetOptions, hooks); @@ -305,17 +311,20 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, } } - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(PipelinePhase::ExecutableSources, passManager); - if (compileTo == PipelinePhase::ExecutableSources) + } + if (compileTo == PipelinePhase::ExecutableSources) { return; + } //---------------------------------------------------------------------------- // Executable translation //---------------------------------------------------------------------------- - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(PipelinePhase::ExecutableConfigurations, passManager); + } if (compileFrom < PipelinePhase::ExecutableConfigurations) { // Select a translation strategy for each hal.executable.variant and @@ -358,10 +367,12 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, } } - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(PipelinePhase::ExecutableConfigurations, passManager); - if (compileTo == PipelinePhase::ExecutableConfigurations) + } + if (compileTo == PipelinePhase::ExecutableConfigurations) { return; + } // TODO(benvanik): move translation after conversion; today translation // inserts the workgroup count logic we need to convert but we could instead @@ -374,8 +385,9 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, // After this point the executables are opaque blobs and we cannot change // their interfaces. - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(PipelinePhase::ExecutableTargets, passManager); + } if (compileFrom < PipelinePhase::ExecutableTargets) { passManager.addNestedPass( @@ -391,10 +403,12 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, IREE::HAL::createCaptureExecutableSourcesPass({"2.translated"})); } - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(PipelinePhase::ExecutableTargets, passManager); - if (compileTo == PipelinePhase::ExecutableTargets) + } + if (compileTo == PipelinePhase::ExecutableTargets) { return; + } // Substitute hal.executables we've translated with those specified on the // command line. This developer feature allows for splicing in hand-authored @@ -577,13 +591,14 @@ void registerHALPasses() { registerPasses(); // Pipelines. - PassPipelineRegistration<>("iree-hal-device-assignment-pipeline", - "Runs HAL target device assignment pipeline.", - [](OpPassManager &passManager) { - buildHALDeviceAssignmentPassPipeline( - passManager, TargetRegistry::getGlobal(), - TargetOptions::FromFlags::get()); - }); + PassPipelineRegistration( + "iree-hal-device-assignment-pipeline", + "Runs HAL target device assignment pipeline.", + [](OpPassManager &passManager, + const AssignmentOptions &assignmentOptions) { + buildHALDeviceAssignmentPassPipeline( + passManager, TargetRegistry::getGlobal(), assignmentOptions); + }); PassPipelineRegistration<>("iree-hal-configuration-pipeline", "Runs HAL target configuration pipeline.", [](OpPassManager &passManager) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h index d09080bf5a2c..e3231c3fd05a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h @@ -46,12 +46,36 @@ struct PipelineHooks { std::function afterPhase; }; +struct AssignmentOptions : public PassPipelineOptions { + // TODO(benvanik): remove the legacy flag once users are switched to devices. + ListOption legacyTargetBackends{ + *this, + "legacy-target-backends", + llvm::cl::desc("DEPRECATED: Target backend names."), + llvm::cl::ZeroOrMore, + }; + ListOption targetDevices{ + *this, + "target-devices", + llvm::cl::desc("Target device specifications."), + llvm::cl::ZeroOrMore, + }; + Option defaultDevice{ + *this, + "default-device", + llvm::cl::desc("Which device is considered the default when no device " + "affinity is specified. Either the device name when names " + "are specified or the numeric ordinal of the device."), + llvm::cl::init(""), + }; +}; + // Assigns devices from flags and coarse module-level specification. // Frontends are encouraged to create and assign devices themselves in order to // support more complex configurations (multiple devices, fallbacks, etc). -void buildHALDeviceAssignmentPassPipeline(OpPassManager &passManager, - const TargetRegistry &targetRegistry, - const TargetOptions &targetOptions); +void buildHALDeviceAssignmentPassPipeline( + OpPassManager &passManager, const TargetRegistry &targetRegistry, + const AssignmentOptions &assignmentOptions); // Adds a set of passes to the given pass manager that run the head of the HAL // pipeline to materialize interfaces, import externally specified executables, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td index aa896f247402..f5e345fe597a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td @@ -42,8 +42,8 @@ def ConvertToHALPass : // Device management //===----------------------------------------------------------------------===// -def AssignTargetDevicesPass : - Pass<"iree-hal-assign-target-devices", "mlir::ModuleOp"> { +def AssignLegacyTargetDevicesPass : + Pass<"iree-hal-assign-legacy-target-devices", "mlir::ModuleOp"> { let summary = "Assigns the HAL devices the module will target to the given list of targets."; let description = [{ Assigns target HAL devices to the module based on the given list. @@ -65,14 +65,75 @@ def AssignTargetDevicesPass : ]; } +def AssignTargetDevicesPass : + Pass<"iree-hal-assign-target-devices", "mlir::ModuleOp"> { + let summary = "Assigns the HAL devices the module will target to the given list of target specifications."; + let description = [{ + Assigns target HAL devices to the module based on the given list of target + specifications. + + Targets can be specified in several ways depending on whether there are + multiple devices, named devices, or devices imported from external files. + Human-friendly device aliases can be used as shorthand for + `IREE::HAL::TargetDevice` implementations providing their own configuration. + The aliases are identical to those used by `#hal.device.alias<>`. + + If multiple targets are specified they will be available as multiple + distinct devices. A single device may select from one or more targets such + that the first enumerated that matches at runtime will be selected. For + example a `gpu` device may select between CUDA, HIP, or Vulkan at runtime + based on what kind of device the user has and what HAL implementations were + compiled into the runtime. + + Examples using the canonical flag: + ```mlir + // Two devices, one the local host device and the other a Vulkan device: + --iree-hal-target-device=local + --iree-hal-target-device=vulkan + + // One device selecting between Vulkan if available and otherwise use the + // local host device: + --iree-hal-target-device=vulkan,local + + // Two CUDA devices selected by runtime ordinal; at runtime two --device= + // flags are required to configure both devices: + --iree-hal-target-device=cuda[0] + --iree-hal-target-device=cuda[1] + + // A fully-defined target specification: + --iree-hal-target-device=#hal.device.target<"cuda", {...}, [#hal.executable.target<...>]> + + // Named device for defining a reference by #hal.device.promise<@some_name>: + --iree-hal-target-device=some_name=vulkan + ``` + }]; + let options = [ + ListOption< + "targetDevices", "targetDevices", + "std::string", + "List of target device specifications." + >, + ]; + let dependentDialects = [ + "IREE::HAL::HALDialect", + ]; +} + def MaterializeTargetDevicesPass : Pass<"iree-hal-materialize-target-devices", "mlir::ModuleOp"> { let summary = "Materializes global device handles based on a `hal.device.targets` spec."; let description = [{ - Materializes a global `!hal.device` for the devices specified by the - `hal.device.targets` attribute on the module. It's preferred that frontends - provide IR with the globals assigned as this only supports a single device. + Materializes global `!hal.device` ops for the devices specified by the + `hal.device.targets` attribute on the module. An optional default device can + be specified to assign to ops that do not have a default device specified. }]; + let options = [ + Option< + "defaultDevice", "defaultDevice", + "std::string", "", + "Which device is considered the default when no device affinity is specified." + >, + ]; let dependentDialects = [ "IREE::HAL::HALDialect", "IREE::Util::UtilDialect", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel index 812e2f9c589a..a2a704dab53f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel @@ -16,6 +16,7 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( [ + "assign_legacy_target_devices.mlir", "assign_target_devices.mlir", "capture_executable_sources.mlir", "convert_to_hal.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt index 0fa3fa2e6fbd..28c81ad6e171 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt @@ -14,6 +14,7 @@ iree_lit_test_suite( NAME lit SRCS + "assign_legacy_target_devices.mlir" "assign_target_devices.mlir" "capture_executable_sources.mlir" "convert_to_hal.mlir" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_legacy_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_legacy_target_devices.mlir new file mode 100644 index 000000000000..1bdb7afb04b9 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_legacy_target_devices.mlir @@ -0,0 +1,42 @@ +// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-legacy-target-devices)' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-0 +// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-legacy-target-devices{targetBackends=vmvx})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-1 +// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-legacy-target-devices{targetBackends=vmvx,vmvx-inline})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-2 +// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-legacy-target-devices{targetBackends=vmvx,vmvx})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-EQ + +// TARGET-1: #device_target_local = #hal.device.target<"local" + +// TARGET-2: #device_target_local = #hal.device.target<"local" +// TARGET-2: #device_target_vmvx_inline = #hal.device.target<"vmvx-inline" + +// TARGET-EQ: #device_target_local = #hal.device.target<"local" + +// CHECK: module +// TARGET-0: @module { +// TARGET-1: @module attributes { +// TARGET-1-SAME: hal.device.targets = [#device_target_local] +// TARGET-2: @module attributes { +// TARGET-2-SAME: hal.device.targets = [#hal.device.select<[#device_target_local, #device_target_vmvx_inline]> : !hal.device] +// TARGET-EQ: @module attributes { +// TARGET-EQ-SAME: hal.device.targets = [#device_target_local]} +module @module {} + +// ----- + +// The pass is a no-op when targets are already specified. + +// CHECK: #device_target_foo = #hal.device.target<"foo" +// CHECK: module @module attributes {hal.device.targets = [#device_target_foo]} +module @module attributes { + hal.device.targets = [#hal.device.target<"foo">] +} {} + +// ----- + +// The pass does nothing when one or more devices has already been defined. + +// CHECK: module @module +// CHECK-NOT: hal.device.targets +module @module { + // CHECK: @existing_device + util.global private @existing_device : !hal.device +} diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir index e10b6494c00d..61926d34a1d7 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir @@ -1,33 +1,36 @@ -// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices)' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-0 -// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vmvx})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-1 -// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vmvx,vmvx-inline})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-2 -// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vmvx,vmvx})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-EQ - -// TARGET-1: #device_target_local = #hal.device.target<"local" - -// TARGET-2: #device_target_local = #hal.device.target<"local" -// TARGET-2: #device_target_vmvx_inline = #hal.device.target<"vmvx-inline" - -// TARGET-EQ: #device_target_local = #hal.device.target<"local" +// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices)' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-0 +// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=device})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-1 +// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=device_a,device_b})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-2 +// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=device_a[0],device_a[1]})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ORDINALS +// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=#hal.device.target<"local">})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ATTR +// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=#hal.device.alias<"device_a">})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ALIAS +// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices="device_a,#hal.device.alias<"device_b">"})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-SELECT +// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=device_a=#hal.device.alias<"device_a">,"device_bc=device_b,#hal.device.alias<"device_c">"})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-SELECT-MULTI // CHECK: module -// TARGET-0: @module { -// TARGET-1: @module attributes { -// TARGET-1-SAME: hal.device.targets = [#device_target_local] -// TARGET-2: @module attributes { -// TARGET-2-SAME: hal.device.targets = [#device_target_local, #device_target_vmvx_inline]} -// TARGET-EQ: @module attributes { -// TARGET-EQ-SAME: hal.device.targets = [#device_target_local]} -module @module {} +// TARGET-0-NOT: hal.device.targets +// TARGET-1: hal.device.targets = [#hal.device.alias<"device"> : !hal.device] +// TARGET-2: hal.device.targets = [#hal.device.alias<"device_a"> : !hal.device, #hal.device.alias<"device_b"> : !hal.device]} +// TARGET-ORDINALS: hal.device.targets = [#hal.device.alias<"device_a"[0]> : !hal.device, #hal.device.alias<"device_a"[1]> : !hal.device]} +// TARGET-ATTR: hal.device.targets = [#hal.device.target<"local"> : !hal.device] +// TARGET-ALIAS: hal.device.targets = [#hal.device.alias<"device_a"> : !hal.device] +// TARGET-SELECT: hal.device.targets = [#hal.device.select<[#hal.device.alias<"device_a"> : !hal.device, #hal.device.alias<"device_b"> : !hal.device]> : !hal.device] +// TARGET-SELECT-MULTI: hal.device.targets = { +// TARGET-SELECT-MULTI-SAME: device_a = #hal.device.alias<"device_a"> : !hal.device, +// TARGET-SELECT-MULTI-SAME: device_bc = #hal.device.select<[#hal.device.alias<"device_b"> : !hal.device, #hal.device.alias<"device_c"> : !hal.device]> : !hal.device +// TARGET-SELECT-MULTI-SAME: } +module @module { + util.global private @tensor_global : tensor<4xf32> +} // ----- // The pass is a no-op when targets are already specified. -// CHECK: #device_target_foo = #hal.device.target<"foo" -// CHECK: module @module attributes {hal.device.targets = [#device_target_foo]} +// CHECK: module @module attributes { +// CHECK-SAME: hal.device.targets = [#hal.device.target<"foo"> : !hal.device] module @module attributes { - hal.device.targets = [#hal.device.target<"foo">] + hal.device.targets = [#hal.device.target<"foo"> : !hal.device] } {} // ----- diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_dispatch_instrumentation.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_dispatch_instrumentation.mlir index 607c876db6f9..e86f1414e473 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_dispatch_instrumentation.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_dispatch_instrumentation.mlir @@ -4,7 +4,7 @@ module attributes {hal.device.targets = [ #hal.device.target<"llvm-cpu", [ #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64">, #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> - ]> + ]> : !hal.device ]} { // Instrumentation storage buffer allocated at startup (defaults to 64MB + footer): diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir index d5d011c649b9..11b39183bde5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir @@ -11,24 +11,36 @@ module @module attributes { // ----- -// Valid input with proper attributes. +// Modules without anything that needs an environment are OK as-is. -// CHECK: #device_target_llvm_cpu = #hal.device.target<"llvm-cpu"> -#device_target_llvm_cpu = #hal.device.target<"llvm-cpu"> -// CHECK: #device_target_vmvx = #hal.device.target<"vmvx"> -#device_target_vmvx = #hal.device.target<"vmvx"> +// CHECK: module @module +module @module { + // CHECK-NEXT: hal.executable private @exe + hal.executable private @exe { + // CHECK-NEXT: hal.executable.variant public @embedded_elf_arm_64 + hal.executable.variant public @embedded_elf_arm_64 target(#hal.executable.target<"backend", "format", {}>) {} + } +} + +// ----- + +// Valid input with proper attributes for a single device. + +// CHECK: #[[DEVICE_A:.+]] = #hal.device.target<"device_a" +#device_a = #hal.device.target<"device_a", [#hal.executable.target<"backend_a", "format_a">]> +// CHECK: #[[DEVICE_B:.+]] = #hal.device.target<"device_b" +#device_b = #hal.device.target<"device_b", [#hal.executable.target<"backend_b", "format_b">]> // CHECK: module @module // CHECK-NOT: hal.device.targets module @module attributes { hal.device.targets = [ - #device_target_llvm_cpu, - #device_target_vmvx + #hal.device.select<[#device_a, #device_b]> : !hal.device ] } { - // CHECK: util.global private @__device.0 = #hal.device.select<[ - // CHECK-SAME: #device_target_llvm_cpu, - // CHECK-SAME: #device_target_vmvx + // CHECK: util.global private @__device_0 = #hal.device.select<[ + // CHECK-SAME: #[[DEVICE_A]], + // CHECK-SAME: #[[DEVICE_B]] // CHECK-SAME: ]> : !hal.device // CHECK: util.global private @tensor_global @@ -46,13 +58,80 @@ module @module attributes { // ----- -// Modules without anything that needs an environment are OK. +// Multiple devices using device names. + +// CHECK: #[[DEVICE_A:.+]] = #hal.device.target<"device_a" +#device_a = #hal.device.target<"device_a", [#hal.executable.target<"backend_a", "format_a">]> +// CHECK: #[[DEVICE_B:.+]] = #hal.device.target<"device_b" +#device_b = #hal.device.target<"device_b", [#hal.executable.target<"backend_b", "format_b">]> +// CHECK: #[[DEVICE_C:.+]] = #hal.device.target<"device_c" +#device_c = #hal.device.target<"device_c", [#hal.executable.target<"backend_c", "format_c">]> // CHECK: module @module -module @module { - // CHECK-NEXT: hal.executable private @exe - hal.executable private @exe { - // CHECK-NEXT: hal.executable.variant public @embedded_elf_arm_64 - hal.executable.variant public @embedded_elf_arm_64 target(#hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {}>) {} +// CHECK-NOT: hal.device.targets +module @module attributes { + hal.device.targets = { + device_a = #device_a, + device_bc = [#device_b, #device_c] } +} { + // CHECK: util.global private @device_a = #[[DEVICE_A]] + // CHECK: util.global private @device_bc = #hal.device.select<[#[[DEVICE_B]], #[[DEVICE_C]]]> + + // CHECK: util.global private @tensor_global + // CHECK-SAME: stream.affinity = #hal.device.affinity<@device_a> + util.global private @tensor_global : tensor<4xf32> +} + +// ----- + +// Default device selection by name. + +// CHECK: #[[DEVICE_A:.+]] = #hal.device.target<"device_a" +#device_a = #hal.device.target<"device_a", [#hal.executable.target<"backend_a", "format_a">]> +// CHECK: #[[DEVICE_B:.+]] = #hal.device.target<"device_b" +#device_b = #hal.device.target<"device_b", [#hal.executable.target<"backend_b", "format_b">]> + +// CHECK: module @module +// CHECK-NOT: hal.device.targets +module @module attributes { + hal.device.targets = { + device_a = #device_a, + device_b = #device_b + }, + hal.device.default = "device_b" +} { + // CHECK: util.global private @device_a + // CHECK: util.global private @device_b + + // CHECK: util.global private @tensor_global + // CHECK-SAME: stream.affinity = #hal.device.affinity<@device_b> + util.global private @tensor_global : tensor<4xf32> } + +// ----- + +// Default device selection by ordinal. + +// CHECK: #[[DEVICE_A:.+]] = #hal.device.target<"device_a" +#device_a = #hal.device.target<"device_a", [#hal.executable.target<"backend_a", "format_a">]> +// CHECK: #[[DEVICE_B:.+]] = #hal.device.target<"device_b" +#device_b = #hal.device.target<"device_b", [#hal.executable.target<"backend_b", "format_b">]> + +// CHECK: module @module +// CHECK-NOT: hal.device.targets +module @module attributes { + hal.device.targets = [ + #device_a, + #device_b + ], + hal.device.default = 1 : index +} { + // CHECK: util.global private @__device_0 + // CHECK: util.global private @__device_1 + + // CHECK: util.global private @tensor_global + // CHECK-SAME: stream.affinity = #hal.device.affinity<@__device_1> + util.global private @tensor_global : tensor<4xf32> +} + diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir index c0d15970287d..a3eae92ad9b5 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir @@ -5,7 +5,7 @@ #map1 = affine_map<(d0, d1, d2) -> (d0, d2)> #map2 = affine_map<(d0, d1, d2) -> (d2, d1)> #map3 = affine_map<(d0, d1, d2) -> (d0, d1)> -#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> +#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> : !hal.device module attributes {hal.device.targets = [#device_target_llvm_cpu]} { util.func public @lhs_encoding(%arg0: tensor) -> tensor { %cst = arith.constant 0.000000e+00 : f32 @@ -36,7 +36,7 @@ module attributes {hal.device.targets = [#device_target_llvm_cpu]} { #map1 = affine_map<(d0, d1, d2) -> (d0, d2)> #map2 = affine_map<(d0, d1, d2) -> (d2, d1)> #map3 = affine_map<(d0, d1, d2) -> (d0, d1)> -#device_target_vulkan = #hal.device.target<"vulkan", [#executable_target_vulkan_spirv_fb]> +#device_target_vulkan = #hal.device.target<"vulkan", [#executable_target_vulkan_spirv_fb]> : !hal.device module attributes {hal.device.targets = [#device_target_vulkan]} { util.func public @lhs_encoding(%arg0: tensor) -> tensor { %cst = arith.constant 0.000000e+00 : f32 @@ -69,10 +69,10 @@ module attributes {hal.device.targets = [#device_target_vulkan]} { #map2 = affine_map<(d0, d1, d2) -> (d2, d1)> #map3 = affine_map<(d0, d1, d2) -> (d0, d1)> #executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}> -#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> +#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> : !hal.device #executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> -#device_target_vulkan = #hal.device.target<"vulkan", [#executable_target_vulkan_spirv_fb]> -module attributes {hal.device.targets = [#device_target_vulkan, #device_target_llvm_cpu]} { +#device_target_vulkan = #hal.device.target<"vulkan", [#executable_target_vulkan_spirv_fb]> : !hal.device +module attributes {hal.device.targets = [#hal.device.select<[#device_target_vulkan, #device_target_llvm_cpu]> : !hal.device]} { util.func public @lhs_encoding(%arg0: tensor) -> tensor { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp index f7a8276e9da9..0d68029cb45f 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp @@ -53,8 +53,12 @@ void buildHALInlineStaticTransformPassPipeline( // Device assignment and interface materialization //---------------------------------------------------------------------------- + IREE::HAL::AssignmentOptions assignmentOptions; + assignmentOptions.legacyTargetBackends = targetOptions.legacyTargetBackends; + assignmentOptions.targetDevices = targetOptions.targetDevices; + assignmentOptions.defaultDevice = targetOptions.defaultDevice; IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry, - targetOptions); + assignmentOptions); IREE::HAL::buildHALConfigurationPassPipeline(passManager, targetRegistry, targetOptions); diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp index 370d0abedc16..96c7eb8fbef7 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp @@ -53,8 +53,12 @@ void buildHALInlineDynamicTransformPassPipeline( // Device assignment and interface materialization //---------------------------------------------------------------------------- + IREE::HAL::AssignmentOptions assignmentOptions; + assignmentOptions.legacyTargetBackends = targetOptions.legacyTargetBackends; + assignmentOptions.targetDevices = targetOptions.targetDevices; + assignmentOptions.defaultDevice = targetOptions.defaultDevice; IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry, - targetOptions); + assignmentOptions); IREE::HAL::buildHALConfigurationPassPipeline(passManager, targetRegistry, targetOptions); diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp index 0052ed868bce..0d730cd82a24 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp @@ -77,16 +77,9 @@ void buildIREEPrecompileTransformPassPipeline( PreprocessingOptions preprocessingOptions, GlobalOptimizationOptions globalOptimizationOptions, SchedulingOptions schedulingOptions, - IREE::HAL::TargetOptions executableOptions, IREEVMPipelineHooks &hooks, + IREE::HAL::TargetOptions halTargetOptions, IREEVMPipelineHooks &hooks, OpPassManager &passManager, IREEVMPipelinePhase compileFrom, IREEVMPipelinePhase compileTo) { - // If the user specified a set of target devices we attach them to the module - // IR so that they are available for all passes that may want to use this - // information. If trying to compile in a generic mode the user should omit - // specifying targets. - IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry, - executableOptions); - // Input pipelines can result in changes to the exported functions and types // and must run before generating bindings. // After input processing, there should only be IREE legal types in @@ -143,6 +136,18 @@ void buildIREEPrecompileTransformPassPipeline( if (compileTo == IREEVMPipelinePhase::Input) return; // early-exit + // If the user specified a set of target devices we attach them to the module + // IR so that they are available for all passes that may want to use this + // information. If trying to compile in a generic mode the user should omit + // specifying targets. + IREE::HAL::AssignmentOptions halAssignmentOptions; + halAssignmentOptions.legacyTargetBackends = + halTargetOptions.legacyTargetBackends; + halAssignmentOptions.targetDevices = halTargetOptions.targetDevices; + halAssignmentOptions.defaultDevice = halTargetOptions.defaultDevice; + IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry, + halAssignmentOptions); + // Now that inputs are legalized, generate wrapper for entry functions. if (compileFrom < IREEVMPipelinePhase::ABI) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "ABI"); @@ -242,13 +247,13 @@ void buildIREEVMTransformPassPipeline( PreprocessingOptions preprocessingOptions, GlobalOptimizationOptions globalOptimizationOptions, SchedulingOptions schedulingOptions, - IREE::HAL::TargetOptions executableOptions, - IREE::VM::TargetOptions targetOptions, IREEVMPipelineHooks &hooks, + IREE::HAL::TargetOptions halTargetOptions, + IREE::VM::TargetOptions vmTargetOptions, IREEVMPipelineHooks &hooks, OpPassManager &passManager, IREEVMPipelinePhase compileFrom, IREEVMPipelinePhase compileTo) { buildIREEPrecompileTransformPassPipeline( targetRegistry, bindingOptions, inputOptions, preprocessingOptions, - globalOptimizationOptions, schedulingOptions, executableOptions, hooks, + globalOptimizationOptions, schedulingOptions, halTargetOptions, hooks, passManager, compileFrom, compileTo); if (compileTo <= IREEVMPipelinePhase::GlobalOptimization) @@ -311,16 +316,16 @@ void buildIREEVMTransformPassPipeline( case SchedulingOptions::ExecutionModel::AsyncInternal: case SchedulingOptions::ExecutionModel::AsyncExternal: IREE::HAL::buildHALTransformPassPipeline(passManager, targetRegistry, - executableOptions, hooks, + halTargetOptions, hooks, halCompileFrom, halCompileTo); break; case SchedulingOptions::ExecutionModel::InlineStatic: IREE::HAL::Inline::buildHALInlineStaticTransformPassPipeline( - passManager, targetRegistry, executableOptions); + passManager, targetRegistry, halTargetOptions); break; case SchedulingOptions::ExecutionModel::InlineDynamic: IREE::HAL::Loader::buildHALInlineDynamicTransformPassPipeline( - passManager, targetRegistry, executableOptions); + passManager, targetRegistry, halTargetOptions); break; } if (hooks.afterPhase) @@ -336,7 +341,7 @@ void buildIREEVMTransformPassPipeline( IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "VM"); if (hooks.beforePhase) hooks.beforePhase(IREEVMPipelinePhase::VM, passManager); - IREE::VM::buildVMTransformPassPipeline(passManager, targetOptions); + IREE::VM::buildVMTransformPassPipeline(passManager, vmTargetOptions); passManager.addPass(IREE::Util::createDropCompilerHintsPass()); if (hooks.afterPhase) hooks.afterPhase(IREEVMPipelinePhase::VM, passManager); diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.h b/compiler/src/iree/compiler/Pipelines/Pipelines.h index 1fba7471d76b..cdc754f06f5f 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.h +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.h @@ -102,7 +102,7 @@ void buildIREEPrecompileTransformPassPipeline( PreprocessingOptions preprocessingOptions, GlobalOptimizationOptions highLevelOptimizationOptions, SchedulingOptions schedulingOptions, - IREE::HAL::TargetOptions executableOptions, IREEVMPipelineHooks &hooks, + IREE::HAL::TargetOptions halTargetOptions, IREEVMPipelineHooks &hooks, OpPassManager &passManager, IREEVMPipelinePhase compileFrom = IREEVMPipelinePhase::Start, IREEVMPipelinePhase compileTo = IREEVMPipelinePhase::GlobalOptimization); @@ -118,8 +118,8 @@ void buildIREEVMTransformPassPipeline( PreprocessingOptions preprocessingOptions, GlobalOptimizationOptions highLevelOptimizationOptions, SchedulingOptions schedulingOptions, - IREE::HAL::TargetOptions executableOptions, - IREE::VM::TargetOptions targetOptions, IREEVMPipelineHooks &hooks, + IREE::HAL::TargetOptions halTargetOptions, + IREE::VM::TargetOptions vmTargetOptions, IREEVMPipelineHooks &hooks, OpPassManager &passManager, IREEVMPipelinePhase compileFrom = IREEVMPipelinePhase::Start, IREEVMPipelinePhase compileTo = IREEVMPipelinePhase::End); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir index aba35c8bf4b5..f7f832859c0f 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir @@ -2,7 +2,6 @@ // RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv},canonicalize))" | FileCheck %s -check-prefix=CONVOLUTION // RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=contraction},canonicalize))" | FileCheck %s -check-prefix=CONTRACT - // CHECK: func.func @matmul_static( // CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xf16>, // CHECK-SAME: %[[ARG1:.+]]: tensor<20x30xf16>, diff --git a/docs/website/docs/community/blog/posts/microkernels.md b/docs/website/docs/community/blog/posts/microkernels.md index 2cb59280b811..62c8e6a6a26f 100644 --- a/docs/website/docs/community/blog/posts/microkernels.md +++ b/docs/website/docs/community/blog/posts/microkernels.md @@ -338,7 +338,7 @@ This then goes to the LLVM x86 backend, which produces x86 assembly. [...] // -----// IR Dump After Inliner (inline) //----- // #executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver4", cpu_features = "+mmx,+popcnt,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+avx,+avx2,+sse4a,+fma,+avx512f,+bmi,+bmi2,+aes,+pclmul,+avx512vl,+avx512bw,+avx512dq,+avx512cd,+avx512vbmi,+avx512ifma,+avx512vpopcntdq,+avx512vbmi2,+gfni,+vpclmulqdq,+avx512vnni,+avx512bitalg,+avx512bf16,+adx,+clflushopt,+clwb,+clzero,+cx16,+cx8,+crc32,+f16c,+fsgsbase,+fxsr,+invpcid,+lzcnt,+movbe,+mwaitx,+pku,+prfchw,+rdpid,+rdpru,+rdrnd,+rdseed,+sahf,+sha,+shstk,+vaes,+wbnoinvd,+x87,+xsave,+xsavec,+xsaveopt,+xsaves,+evex512", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 64 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = "all"}> -#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}> +#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}> : !hal.device module attributes {hal.device.targets = [#device_target_llvm_cpu]} { func.func @matmul_dynamic(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_dynamic(%input0: tensor, %input1: tensor, %input2: tensor) -> (%output0: tensor)"}} { %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index @@ -367,7 +367,7 @@ module attributes {hal.device.targets = [#device_target_llvm_cpu]} { // -----// IR Dump After CSE (cse) //----- // #executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver4", cpu_features = "+mmx,+popcnt,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+avx,+avx2,+sse4a,+fma,+avx512f,+bmi,+bmi2,+aes,+pclmul,+avx512vl,+avx512bw,+avx512dq,+avx512cd,+avx512vbmi,+avx512ifma,+avx512vpopcntdq,+avx512vbmi2,+gfni,+vpclmulqdq,+avx512vnni,+avx512bitalg,+avx512bf16,+adx,+clflushopt,+clwb,+clzero,+cx16,+cx8,+crc32,+f16c,+fsgsbase,+fxsr,+invpcid,+lzcnt,+movbe,+mwaitx,+pku,+prfchw,+rdpid,+rdpru,+rdrnd,+rdseed,+sahf,+sha,+shstk,+vaes,+wbnoinvd,+x87,+xsave,+xsavec,+xsaveopt,+xsaves,+evex512", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 64 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = "all"}> #map = affine_map<()[s0] -> (s0 ceildiv 16)> -#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}> +#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}> : !hal.device module attributes {hal.device.targets = [#device_target_llvm_cpu]} { func.func @matmul_dynamic(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_dynamic(%input0: tensor, %input1: tensor, %input2: tensor) -> (%output0: tensor)"}} { %cst = arith.constant 0.000000e+00 : f32 diff --git a/runtime/src/iree/vm/bytecode/disassembler.c b/runtime/src/iree/vm/bytecode/disassembler.c index 853c36fa40bc..ed6684c50410 100644 --- a/runtime/src/iree/vm/bytecode/disassembler.c +++ b/runtime/src/iree/vm/bytecode/disassembler.c @@ -1326,6 +1326,7 @@ iree_status_t iree_vm_bytecode_disassemble_op( IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(b, " : ")); EMIT_REF_REG_NAME(false_value_reg); EMIT_OPTIONAL_VALUE_REF(®s->ref[false_value_reg]); + IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(b, " -> !")); EMIT_TYPE_NAME(type_def); break; } diff --git a/samples/custom_dispatch/cpu/embedded/example_hal.mlir b/samples/custom_dispatch/cpu/embedded/example_hal.mlir index e9edfd57c0d5..91a87ad67085 100644 --- a/samples/custom_dispatch/cpu/embedded/example_hal.mlir +++ b/samples/custom_dispatch/cpu/embedded/example_hal.mlir @@ -43,7 +43,7 @@ // compiled binary (CPU + Vulkan, etc). #cpu_target = #hal.device.target<"llvm-cpu", [ #x86_64_target -]> +]> : !hal.device module @example attributes {hal.device.targets = [#cpu_target]} { diff --git a/samples/custom_dispatch/cpu/embedded/example_stream.mlir b/samples/custom_dispatch/cpu/embedded/example_stream.mlir index a8b6861fa9f2..910a0070982f 100644 --- a/samples/custom_dispatch/cpu/embedded/example_stream.mlir +++ b/samples/custom_dispatch/cpu/embedded/example_stream.mlir @@ -48,7 +48,7 @@ #cpu_target = #hal.device.target<"llvm-cpu", [ #arm_64_target, #x86_64_target -]> +]> : !hal.device module @example attributes {hal.device.targets = [#cpu_target]} { diff --git a/samples/custom_dispatch/cpu/embedded/example_transform.mlir b/samples/custom_dispatch/cpu/embedded/example_transform.mlir index 709a01671654..858052cc3cd5 100644 --- a/samples/custom_dispatch/cpu/embedded/example_transform.mlir +++ b/samples/custom_dispatch/cpu/embedded/example_transform.mlir @@ -28,7 +28,7 @@ // hence we only support llvm-cpu here. #cpu_target = #hal.device.target<"llvm-cpu", [ #x86_64_target -]> +]> : !hal.device #map = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0)> diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir index 599ed8afe79f..2aa594390032 100644 --- a/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir +++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir @@ -21,7 +21,7 @@ // hence we only support llvm-cpu here. #cpu_target = #hal.device.target<"llvm-cpu", [ #x86_64_target -]> +]> : !hal.device #map = affine_map<(d0, d1) -> (d0, d1)> module @example attributes {hal.device.targets = [#cpu_target]} { diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg.mlir index 3bc9f122f60b..c725daffdb39 100644 --- a/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg.mlir +++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg.mlir @@ -72,7 +72,7 @@ // hence we only support llvm-cpu here. #cpu_target = #hal.device.target<"llvm-cpu", [ #x86_64_target -]> +]> : !hal.device #map = affine_map<(d0, d1) -> (d0, d1)> module @example attributes {hal.device.targets = [#cpu_target]} { diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg_two_matmul.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg_two_matmul.mlir index 8636c9419d95..0dcc3580785c 100644 --- a/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg_two_matmul.mlir +++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg_two_matmul.mlir @@ -29,7 +29,7 @@ // hence we only support llvm-cpu here. #cpu_target = #hal.device.target<"llvm-cpu", [ #x86_64_target -]> +]> : !hal.device #map = affine_map<(d0, d1) -> (d0, d1)> module @example attributes {hal.device.targets = [#cpu_target]} { diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_torch.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_torch.mlir index 787608555cab..6b6fbf13f968 100644 --- a/samples/custom_dispatch/cpu/mlp_plugin/mlp_torch.mlir +++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_torch.mlir @@ -41,7 +41,7 @@ // hence we only support llvm-cpu here. #cpu_target = #hal.device.target<"llvm-cpu", [ #x86_64_target -]> +]> : !hal.device #map = affine_map<(d0, d1) -> (d0, d1)> module @example attributes {hal.device.targets = [#cpu_target]} { diff --git a/samples/custom_dispatch/cpu/mlp_plugin/mlp_tosa.mlir b/samples/custom_dispatch/cpu/mlp_plugin/mlp_tosa.mlir index 0b27ae22beac..4bb05935748f 100644 --- a/samples/custom_dispatch/cpu/mlp_plugin/mlp_tosa.mlir +++ b/samples/custom_dispatch/cpu/mlp_plugin/mlp_tosa.mlir @@ -41,7 +41,7 @@ // hence we only support llvm-cpu here. #cpu_target = #hal.device.target<"llvm-cpu", [ #x86_64_target -]> +]> : !hal.device module @example attributes {hal.device.targets = [#cpu_target]} { func.func @mlp_invocation(%lhs: tensor<2x4xf32>, %rhs : tensor<4x8xf32>) -> tensor<2x8xf32> { diff --git a/samples/custom_dispatch/cuda/kernels/example.mlir b/samples/custom_dispatch/cuda/kernels/example.mlir index 15a3bb43bc31..62e49c6a94e6 100644 --- a/samples/custom_dispatch/cuda/kernels/example.mlir +++ b/samples/custom_dispatch/cuda/kernels/example.mlir @@ -27,7 +27,7 @@ #cuda_target = #hal.device.target<"cuda", [ #nvptx_sm_52_target, #nvptx_sm_80_target -]> +]> : !hal.device module @example attributes {hal.device.targets = [#cuda_target]} { diff --git a/samples/custom_dispatch/hip/kernels/example.mlir b/samples/custom_dispatch/hip/kernels/example.mlir index 8819d867cc22..3ca1bad2430f 100644 --- a/samples/custom_dispatch/hip/kernels/example.mlir +++ b/samples/custom_dispatch/hip/kernels/example.mlir @@ -23,7 +23,7 @@ // compiled binary. #rocm_target = #hal.device.target<"rocm", [ #rocm_gfx1100_target -]> +]> : !hal.device module @example attributes {hal.device.targets = [#rocm_target]} { diff --git a/samples/custom_dispatch/vulkan/shaders/example.mlir b/samples/custom_dispatch/vulkan/shaders/example.mlir index ef10fb7b7dbd..69843aa1691e 100644 --- a/samples/custom_dispatch/vulkan/shaders/example.mlir +++ b/samples/custom_dispatch/vulkan/shaders/example.mlir @@ -29,7 +29,7 @@ // compiled binary. #vulkan_target = #hal.device.target<"vulkan", [ #spirv_target -]> +]> : !hal.device module @example attributes {hal.device.targets = [#vulkan_target]} { diff --git a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir index 36912bb35df9..41576518089f 100644 --- a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir +++ b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir @@ -27,7 +27,7 @@ // These can come from compiler flags and multiple targets can be supported // It's possible, for example, to support targeting multiple devices in the same // compiled binary. -#vulkan_target = #hal.device.target<"vulkan", [#spirv_target]> +#vulkan_target = #hal.device.target<"vulkan", [#spirv_target]> : !hal.device module @example attributes {hal.device.targets = [#vulkan_target]} { diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir index b4885a03081d..4bea02d2210e 100644 --- a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir +++ b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir @@ -33,7 +33,7 @@ // hence we only support vulkan here. It is possible to hand author a custom // kernel that supports multiple targets by specifying an object per-target, but // that requires authoring the kernel for multiple targets. -#vulkan_target = #hal.device.target<"vulkan", [#spirv_target]> +#vulkan_target = #hal.device.target<"vulkan", [#spirv_target]> : !hal.device #map = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0)> diff --git a/samples/multiple_modules/pipeline_async.mlir b/samples/multiple_modules/pipeline_async.mlir index 46ad83cb3164..367635963a56 100644 --- a/samples/multiple_modules/pipeline_async.mlir +++ b/samples/multiple_modules/pipeline_async.mlir @@ -1,7 +1,8 @@ // RUN: (iree-compile --iree-execution-model=async-external --iree-hal-target-backends=vmvx %p/module_a.mlir -o=%t.module_a.vmfb && \ // RUN: iree-compile --iree-execution-model=async-external --iree-hal-target-backends=vmvx %p/module_b.mlir -o=%t.module_b.vmfb && \ // RUN: iree-compile --iree-execution-model=async-external --iree-hal-target-backends=vmvx %s | \ -// RUN: iree-run-module --device=local-task \ +// RUN: iree-run-module \ +// RUN: --device=local-task \ // RUN: --module=%t.module_a.vmfb \ // RUN: --module=%t.module_b.vmfb \ // RUN: --module=- --function=run \ diff --git a/samples/multiple_modules/pipeline_sync.mlir b/samples/multiple_modules/pipeline_sync.mlir index 3f9a6e0335ef..b9f8d15b633c 100644 --- a/samples/multiple_modules/pipeline_sync.mlir +++ b/samples/multiple_modules/pipeline_sync.mlir @@ -1,7 +1,8 @@ // RUN: (iree-compile --iree-hal-target-backends=vmvx %p/module_a.mlir -o=%t.module_a.vmfb && \ // RUN: iree-compile --iree-hal-target-backends=vmvx %p/module_b.mlir -o=%t.module_b.vmfb && \ // RUN: iree-compile --iree-hal-target-backends=vmvx %s | \ -// RUN: iree-run-module --device=local-sync \ +// RUN: iree-run-module \ +// RUN: --device=local-sync \ // RUN: --module=%t.module_a.vmfb \ // RUN: --module=%t.module_b.vmfb \ // RUN: --module=- --function=run \ diff --git a/samples/transform_dialect/example_module.mlir b/samples/transform_dialect/example_module.mlir index 585bb2591534..017f39381820 100644 --- a/samples/transform_dialect/example_module.mlir +++ b/samples/transform_dialect/example_module.mlir @@ -35,7 +35,7 @@ module attributes { #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", { iree.gpu.target = #target }> - ]> + ]> : !hal.device ] } { hal.executable private @example_module_dispatch_0 { diff --git a/tests/compiler_driver/precompile.mlir b/tests/compiler_driver/precompile.mlir index 5cdd11784fc2..b25cb34b5479 100644 --- a/tests/compiler_driver/precompile.mlir +++ b/tests/compiler_driver/precompile.mlir @@ -7,4 +7,6 @@ func.func @test(%arg0 : tensor<10x20xf32>, %arg1 : tensor<20x30xf32>, %arg2 : te } // Just check that we have the right target and executable targets. -// CHECK: module attributes {hal.device.targets = [#hal.device.target<"local", [#hal.executable.target<"vmvx" +// CHECK: module +// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@[[DEVICE:.+]]> +// CHECK: util.global private @[[DEVICE]] = #hal.device.target<"local", [#hal.executable.target<"vmvx" diff --git a/tests/compiler_driver/preprocessing_flags.mlir b/tests/compiler_driver/preprocessing_flags.mlir index 331398854177..f6e8adc49107 100644 --- a/tests/compiler_driver/preprocessing_flags.mlir +++ b/tests/compiler_driver/preprocessing_flags.mlir @@ -13,7 +13,7 @@ func.func @test(%arg0 : tensor<10x20xf32>, %arg1 : tensor<20x30xf32>, %arg2 : te // CHECK: ConvertConv2DToImg2ColPass (iree-preprocessing-convert-conv2d-to-img2col) // CHECK: PadLinalgOpsPass (iree-preprocessing-pad-linalg-ops) // CHECK-LABEL: module -// CHECK-NEXT: util.func public @test( +// CHECK: util.func public @test( // CHECK-DAG: %[[ARG0:.+]] = hal.tensor.import %{{[a-zA-Z0-9]+}} "input0" : !hal.buffer_view -> tensor<10x20xf32> // CHECK-DAG: %[[ARG1:.+]] = hal.tensor.import %{{[a-zA-Z0-9]+}} "input1" : !hal.buffer_view -> tensor<20x30xf32> // CHECK-DAG: %[[ARG2:.+]] = hal.tensor.import %{{[a-zA-Z0-9]+}} "input2" : !hal.buffer_view -> tensor<10x30xf32> diff --git a/tests/e2e/regression/libm_linking.mlir b/tests/e2e/regression/libm_linking.mlir index e63e593436ae..5cbeff009848 100644 --- a/tests/e2e/regression/libm_linking.mlir +++ b/tests/e2e/regression/libm_linking.mlir @@ -1,5 +1,5 @@ -// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=llvm-cpu},iree-transformation-pipeline)' %s | FileCheck %s -// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=llvm-cpu},iree-transformation-pipeline)' --iree-llvmcpu-link-embedded=false %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=llvm-cpu},iree-transformation-pipeline)' %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=llvm-cpu},iree-transformation-pipeline)' --iree-llvmcpu-link-embedded=false %s | FileCheck %s // When lowering to CPU code through LLVM, certain LLVM intrinsics require // linking against libm (the standard C library of math functions, `-lm`). From 0dbfc134a864303a1f828736da3638a9284a4532 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 28 May 2024 08:36:33 -0700 Subject: [PATCH 17/25] Stripping affinity attrs earlier in the pipeline. --- .../Dialect/Stream/Transforms/ConvertToStream.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp index 6e26e2440e2e..aa5cb25e5499 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp @@ -181,12 +181,10 @@ struct GenericResourcePattern : public ConversionPattern { } }; -namespace { struct OptimizationBarrierOpConversion : public OpConversionPattern { using OpConversionPattern< IREE::Util::OptimizationBarrierOp>::OpConversionPattern; - LogicalResult matchAndRewrite(IREE::Util::OptimizationBarrierOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -204,7 +202,14 @@ struct OptimizationBarrierOpConversion return success(); } }; -} // namespace + +static void stripAffinityAttrs(ModuleOp moduleOp) { + moduleOp->removeAttr("stream.affinity.default"); + auto affinityName = StringAttr::get(moduleOp.getContext(), "stream.affinity"); + for (auto &op : moduleOp.getOps()) { + op.removeDiscardableAttr(affinityName); + } +} //===----------------------------------------------------------------------===// // --iree-stream-conversion @@ -290,6 +295,9 @@ struct ConvertToStreamPass final std::move(patterns)))) { return signalPassFailure(); } + + // Strip affinity ops as they are no longer required. + stripAffinityAttrs(getOperation()); } }; From fd1277d347614df775c2a1e91bbb234c81dd8db1 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 28 May 2024 12:11:06 -0700 Subject: [PATCH 18/25] Changing flow.tensor.load/store lowerings to avoid transfers. This should make it more efficient to load/store partial values at the cost of possibly transfering multiple slices when loading/storing many values. Those should be changed to use larger staging buffer transfers anyway, though. --- .../Conversion/FlowToStream/Patterns.cpp | 121 ++++++++++++++---- .../FlowToStream/test/tensor_ops.mlir | 61 +++++---- compiler/src/iree/compiler/Utils/IntegerSet.h | 18 +++ 3 files changed, 152 insertions(+), 48 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp index 316e15ac59c2..8e4d854208d8 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp @@ -261,6 +261,17 @@ struct ConvertTensorUpdateOp } }; +static bool isScalarTensor(RankedTensorType type) { + if (type.getRank() == 0) + return true; // tensor + if (!type.hasStaticShape()) + return false; // tensor<...?...xi32> + int64_t elementCount = 1; + for (int64_t dim : type.getShape()) + elementCount *= dim; + return elementCount == 1; // tensor<1xi32> or tensor<1x1x1xi32> +} + struct ConvertTensorLoadOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -271,20 +282,71 @@ struct ConvertTensorLoadOp auto source = consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + // If the source is not a staging resource then we need to transfer it to + // a staging resource. We slice out just what is being loaded so that we + // don't transfer the entire tensor. If loading multiple values from the + // same tensor we'll either want to have batched that before this point + // by loading an entire buffer or after by coalescing the slices. + // + // If already a staging resource then we can fast-path load the value. auto stagingType = rewriter.getType( IREE::Stream::Lifetime::Staging); - auto loadSource = source.resource; - if (source.resource.getType() != stagingType) { - loadSource = rewriter.createOrFold( + if (source.resource.getType() == stagingType) { + rewriter.replaceOpWithNewOp( + op, resultType, source.resource, op.getSource().getType(), + adaptor.getSourceDims(), source.resourceSize, adaptor.getIndices()); + return success(); + } + + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + + // Scalar tensors get transferred without slicing. + auto sourceEncoding = op.getSource().getType(); + if (isScalarTensor(sourceEncoding)) { + auto transferOp = rewriter.create( op.getLoc(), stagingType, source.resource, source.resourceSize, source.resourceSize, /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op), - /*result_affinity=*/nullptr); + /*result_affinity=*/IREE::Stream::AffinityAttr::lookup(op)); + rewriter.replaceOpWithNewOp( + op, resultType, transferOp.getResult(), sourceEncoding, + adaptor.getSourceDims(), transferOp.getResultSize(), + adaptor.getIndices()); + return success(); } + // Slice out the individual element value. + IndexSet indexSet(op.getLoc(), rewriter); + indexSet.populate(adaptor.getIndices()); + SmallVector sliceIndices; + SmallVector sliceLengths; + SmallVector loadIndices; + SmallVector resultDims; + for (auto index : adaptor.getIndices()) { + // TODO(benvanik): support larger buffer slices. + sliceIndices.push_back(index); + sliceLengths.push_back(indexSet.get(1)); + loadIndices.push_back(indexSet.get(0)); + resultDims.push_back(1); + } + auto resultEncoding = + RankedTensorType::get(resultDims, sourceEncoding.getElementType(), + sourceEncoding.getEncoding()); + Value resultSize = rewriter.create( + op.getLoc(), resultEncoding, ValueRange{}, affinityAttr); + auto sliceOp = rewriter.create( + op.getLoc(), source.resource.getType(), source.resource, sourceEncoding, + adaptor.getSourceDims(), source.resourceSize, sliceIndices, + sliceLengths, resultEncoding, ValueRange{}, resultSize, affinityAttr); + auto transferOp = rewriter.create( + op.getLoc(), stagingType, sliceOp.getResult(), sliceOp.getResultSize(), + sliceOp.getResultSize(), + /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op), + /*result_affinity=*/IREE::Stream::AffinityAttr::lookup(op)); rewriter.replaceOpWithNewOp( - op, resultType, loadSource, op.getSource().getType(), - op.getSourceDims(), source.resourceSize, adaptor.getIndices()); + op, resultType, transferOp.getResult(), sliceOp.getResultEncoding(), + sliceOp.getResultEncodingDims(), transferOp.getResultSize(), + loadIndices); return success(); } }; @@ -298,32 +360,39 @@ struct ConvertTensorStoreOp auto target = consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + // If the target is a staging resource then we can directly store into it + // with a fast-path. Otherwise we need to stage an upload. auto stagingType = rewriter.getType( IREE::Stream::Lifetime::Staging); - auto storeTarget = target.resource; - if (target.resource.getType() != stagingType) { - storeTarget = rewriter.createOrFold( - op.getLoc(), stagingType, storeTarget, target.resourceSize, - target.resourceSize, - /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op), - /*result_affinity=*/nullptr); + if (target.resource.getType() == stagingType) { + rewriter.replaceOpWithNewOp( + op, target.resource.getType(), target.resource, + op.getTarget().getType(), adaptor.getTargetDims(), + target.resourceSize, adaptor.getIndices(), adaptor.getValue()); + return success(); } - auto newOp = rewriter.create( - op.getLoc(), storeTarget.getType(), storeTarget, - op.getTarget().getType(), adaptor.getTargetDims(), target.resourceSize, - adaptor.getIndices(), adaptor.getValue()); - - Value newResult = newOp.getResult(); - if (target.resource.getType() != stagingType) { - newResult = rewriter.createOrFold( - op.getLoc(), target.resource.getType(), newResult, - target.resourceSize, target.resourceSize, - /*source_affinity=*/nullptr, - /*result_affinity=*/IREE::Stream::AffinityAttr::lookup(op)); + // Scalar tensors disconnect from the original target. + auto targetEncoding = op.getTarget().getType(); + if (isScalarTensor(targetEncoding)) { + rewriter.replaceOpWithNewOp( + op, target.resource.getType(), adaptor.getValue(), targetEncoding, + adaptor.getTargetDims(), target.resourceSize, + IREE::Stream::AffinityAttr::lookup(op)); + return success(); } - rewriter.replaceOp(op, {newResult}); + // Use fill to store the value. + // TODO(benvanik): support larger buffer slices (stage + update). + IndexSet indexSet(op.getLoc(), rewriter); + indexSet.populate(adaptor.getIndices()); + SmallVector lengths; + for (auto index : adaptor.getIndices()) + lengths.push_back(indexSet.get(1)); + rewriter.replaceOpWithNewOp( + op, target.resource, targetEncoding, adaptor.getTargetDims(), + target.resourceSize, adaptor.getIndices(), lengths, adaptor.getValue(), + IREE::Stream::AffinityAttr::lookup(op)); return success(); } }; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir index abc96e7dd832..a755d44ad27b 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir @@ -179,49 +179,66 @@ util.func public @tensorUpdate(%update : tensor<1x1x10xf32>, %target : tensor<5x // ----- -util.global private @device : !hal.device - // CHECK-LABEL: @tensorLoad // CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource<*>, %[[SOURCE_SIZE:.+]]: index) util.func public @tensorLoad(%source : tensor<2x3xi32>) -> i32 { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - // CHECK: %[[T0:.+]] = stream.async.transfer - // CHECK-SAME: %[[SOURCE]] : !stream.resource<*>{%[[SOURCE_SIZE]]} - // CHECK-SAME: from(#hal.device.affinity<@device>) -> !stream.resource{%[[SOURCE_SIZE]]} - // CHECK: %[[T1:.+]] = stream.tensor.load %[[T0]][%c0, %c1] : tensor<2x3xi32> in !stream.resource{%[[SOURCE_SIZE]]} -> i32 - %0 = flow.tensor.load %source[%c0, %c1] : tensor<2x3xi32> attributes { - stream.affinity = #hal.device.affinity<@device> - } - // CHECK: util.return %[[T1]] + // CHECK: %[[SLICE_SIZE:.+]] = stream.tensor.sizeof tensor<1x1xi32> + // CHECK: %[[SLICE:.+]] = stream.tensor.slice %[[SOURCE]][%c0, %c1 for %c1, %c1] : tensor<2x3xi32> in !stream.resource<*>{%[[SOURCE_SIZE]]} -> tensor<1x1xi32> in !stream.resource<*>{%[[SLICE_SIZE]]} + // CHECK: %[[STAGING:.+]] = stream.async.transfer + // CHECK-SAME: %[[SLICE]] : !stream.resource<*>{%[[SLICE_SIZE]]} + // CHECK-SAME: !stream.resource{%[[SLICE_SIZE]]} + // CHECK: %[[VALUE:.+]] = stream.tensor.load %[[STAGING]][%c0, %c0] : tensor<1x1xi32> in !stream.resource{%[[SLICE_SIZE]]} -> i32 + %0 = flow.tensor.load %source[%c0, %c1] : tensor<2x3xi32> + // CHECK: util.return %[[VALUE]] util.return %0 : i32 } // ----- -util.global private @device : !hal.device +// CHECK-LABEL: @tensorLoadScalar +// CHECK-SAME: (%[[SOURCE:.+]]: !stream.resource<*>, %[[SOURCE_SIZE:.+]]: index) +util.func public @tensorLoadScalar(%source : tensor) -> i32 { + // CHECK: %[[STAGING:.+]] = stream.async.transfer + // CHECK-SAME: %[[SOURCE]] : !stream.resource<*>{%[[SOURCE_SIZE]]} + // CHECK-SAME: !stream.resource{%[[SOURCE_SIZE]]} + // CHECK: %[[VALUE:.+]] = stream.tensor.load %[[STAGING]] : tensor in !stream.resource{%[[SOURCE_SIZE]]} -> i32 + %0 = flow.tensor.load %source : tensor + // CHECK: util.return %[[VALUE]] + util.return %0 : i32 +} + +// ----- // CHECK-LABEL: @tensorStore // CHECK-SAME: (%[[TARGET:.+]]: !stream.resource<*>, %[[TARGET_SIZE:.+]]: index) util.func public @tensorStore(%target : tensor<2x3xi32>) -> tensor<2x3xi32> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c9 = arith.constant 9 : i32 - // CHECK: %[[T0:.+]] = stream.async.transfer %[[TARGET]] : !stream.resource<*>{%[[TARGET_SIZE]]} - // CHECK-SAME: from(#hal.device.affinity<@device>) -> !stream.resource{%[[TARGET_SIZE]]} - // CHECK: %[[T1:.+]] = stream.tensor.store %c9_i32, %[[T0]][%c0, %c1] : - // CHECK-SAME: i32 -> tensor<2x3xi32> in %[[T0]] as !stream.resource{%[[TARGET_SIZE]]} - // CHECK: %[[T2:.+]] = stream.async.transfer %[[T1]] : !stream.resource{%[[TARGET_SIZE]]} -> - // CHECK-SAME: to(#hal.device.affinity<@device>) !stream.resource<*>{%[[TARGET_SIZE]]} - %0 = flow.tensor.store %c9, %target[%c0, %c1] : tensor<2x3xi32> attributes { - stream.affinity = #hal.device.affinity<@device> - } - // CHECK: util.return %[[T2]] + // CHECK: %[[VALUE:.+]] = arith.constant 9 + %value = arith.constant 9 : i32 + // CHECK: %[[FILL:.+]] = stream.tensor.fill %[[VALUE]], %[[TARGET]][%c0, %c1 for %c1, %c1] : i32 -> tensor<2x3xi32> in %[[TARGET]] as !stream.resource<*>{%[[TARGET_SIZE]]} + %0 = flow.tensor.store %value, %target[%c0, %c1] : tensor<2x3xi32> + // CHECK: util.return %[[FILL]] util.return %0 : tensor<2x3xi32> } // ----- +// CHECK-LABEL: @tensorStoreScalar +// CHECK-SAME: (%[[TARGET:.+]]: !stream.resource<*>, %[[TARGET_SIZE:.+]]: index) +util.func public @tensorStoreScalar(%target : tensor) -> tensor { + // CHECK: %[[VALUE:.+]] = arith.constant 9 + %value = arith.constant 9 : i32 + // CHECK: %[[SPLAT:.+]] = stream.tensor.splat %[[VALUE]] : i32 -> tensor in !stream.resource<*>{%[[TARGET_SIZE]]} + %0 = flow.tensor.store %value, %target : tensor + // CHECK: util.return %[[SPLAT]] + util.return %0 : tensor +} + +// ----- + // CHECK-LABEL: @tensorTrace // CHECK-SAME: (%[[TENSOR0:.+]]: !stream.resource<*>, %[[TENSOR0_SIZE:.+]]: index, %[[TENSOR1:.+]]: !stream.resource<*>, %[[TENSOR1_SIZE:.+]]: index, %[[TENSOR1_DIM0:.+]]: index, %[[TENSOR1_DIM2:.+]]: index) util.func public @tensorTrace(%tensor0: tensor<5xf32>, %tensor1: tensor, %tensor1_dim0: index, %tensor1_dim2: index) { diff --git a/compiler/src/iree/compiler/Utils/IntegerSet.h b/compiler/src/iree/compiler/Utils/IntegerSet.h index 594eecd92f1b..aed0376fe4f7 100644 --- a/compiler/src/iree/compiler/Utils/IntegerSet.h +++ b/compiler/src/iree/compiler/Utils/IntegerSet.h @@ -33,6 +33,15 @@ class IntegerSet { return memoizedValue; } + Value add(StorageT lhs, StorageT rhs) { return get(lhs + rhs); } + Value add(Value lhs, StorageT rhs) { + APInt lhsValue; + if (matchPattern(lhs, m_ConstantInt(&lhsValue))) { + return add(lhsValue.getSExtValue(), rhs); + } + return builder.create(loc, lhs, get(rhs)); + } + void populate(ValueRange values) { for (auto value : values) { APInt intValue; @@ -66,6 +75,15 @@ class IndexSet { } Value get(APInt value) { return get(value.getSExtValue()); } + Value add(int64_t lhs, int64_t rhs) { return get(lhs + rhs); } + Value add(Value lhs, int64_t rhs) { + APInt lhsValue; + if (matchPattern(lhs, m_ConstantInt(&lhsValue))) { + return add(lhsValue.getSExtValue(), rhs); + } + return builder.create(loc, lhs, get(rhs)); + } + void populate(ValueRange values) { for (auto value : values) { APInt intValue; From 31e6e1aa18bb0fbe65a832d6a5b6c8b34bbbe5c6 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 28 May 2024 13:44:55 -0700 Subject: [PATCH 19/25] Removing affinity from stream.timepoint.await. --- .../Dialect/Stream/IR/StreamOpFolders.cpp | 16 +--------------- .../iree/compiler/Dialect/Stream/IR/StreamOps.td | 5 +---- .../Stream/Transforms/ScheduleExecution.cpp | 3 --- .../Transforms/test/schedule_execution.mlir | 2 -- 4 files changed, 2 insertions(+), 24 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp index 24d939a7ff73..84cf1eb0e11f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp @@ -3263,16 +3263,8 @@ struct GroupAwaitsByTimepoint : public OpRewritePattern { if (dominanceInfo.dominates(use.getOwner(), op)) continue; auto awaitOp = dyn_cast(use.getOwner()); - if (!awaitOp || - !AffinityAttr::areCompatible( - llvm::dyn_cast_if_present(op.getAffinityAttr()), - llvm::dyn_cast_if_present( - awaitOp.getAffinityAttr()))) { - // Can't combine if the affinities differ as the wait semantics are - // load-bearing. Probably. They really shouldn't be. - // TODO(benvanik): remove affinity from stream.timepoint.await. + if (!awaitOp) continue; - } // Ensure all dependencies of the await op are available. if (!areAllOperandsDefinedBy(awaitOp, op, dominanceInfo)) { // One or more operands is defined after op so we can't merge. @@ -3299,9 +3291,6 @@ struct GroupAwaitsByTimepoint : public OpRewritePattern { } auto newOp = rewriter.create( op.getLoc(), newOperands, newOperandSizes, op.getAwaitTimepoint()); - if (op.getAffinity().has_value()) { - newOp.setAffinityAttr(op.getAffinityAttr()); - } // Replace covered ops with the new results. unsigned resultIdx = 0; @@ -3349,9 +3338,6 @@ struct FoldDuplicateAwaitResources : public OpRewritePattern { // Create replacement op with deduped operands/results. auto newOp = rewriter.create( op.getLoc(), newOperands, newOperandSizes, op.getAwaitTimepoint()); - if (op.getAffinity().has_value()) { - newOp.setAffinityAttr(op.getAffinityAttr()); - } // Replace all duplicate results with the base results. for (auto &replacement : replacements) { diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td index cb362716081b..99e793a00f63 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td @@ -3753,7 +3753,6 @@ def Stream_TimepointBarrierOp : Stream_PureOp<"timepoint.barrier", [ def Stream_TimepointAwaitOp : Stream_PureOp<"timepoint.await", [ AttrSizedOperandSegments, - Stream_AffinityOp, Stream_TimelineOp, Util_SizeAwareOp, DeclareOpInterfaceMethods>:$resource_operands, Variadic:$resource_operand_sizes, - Stream_Timepoint:$await_timepoint, - OptionalAttr:$affinity + Stream_Timepoint:$await_timepoint ); let results = (outs Variadic` $resource_operands `:` custom(type($resource_operands), diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp index f8278ffef19d..9c5e3d4570d3 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp @@ -275,9 +275,6 @@ LogicalResult processRegion(Location loc, MLIRContext *context, Region ®ion, auto awaitOp = builder.create( executeOp.getLoc(), newResult, newResultSize, executeOp.getResultTimepoint()); - if (executeOp.getAffinity().has_value()) { - awaitOp.setAffinityAttr(executeOp.getAffinityAttr()); - } // Explicitly copy the Value since it is marked as const. Value toBeDeleted = oldResult; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir index 0f33b51cbc0f..dcdc586184be 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir @@ -75,7 +75,6 @@ util.func public @partitioningWithAffinities(%arg0: !stream.resource) // CHECK-NEXT: } => !stream.timepoint // CHECK-NEXT: %[[READY:.+]] = stream.timepoint.await - // CHECK-SAME: on(#hal.device.affinity<@device_b>) // CHECK-SAME: %[[TIMEPOINT1]] => %[[RESULT]] : !stream.resource{%c20} // CHECK-NEXT: util.return %[[READY]] util.return %dispatch2 : !stream.resource @@ -133,7 +132,6 @@ util.func public @partitioningWithConcurrentAffinities(%arg0: !stream.resource !stream.timepoint // CHECK-NEXT: %[[READY:.+]] = stream.timepoint.await - // CHECK-SAME: on(#hal.device.affinity<@device_c>) // CHECK-SAME: %[[TIMEPOINT2]] => %[[RESULT]] : !stream.resource{%c20} // CHECK-NEXT: util.return %[[READY]] util.return %dispatch2 : !stream.resource From 710287135c92ed641529aca9c31d15d99f9ed043 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 28 May 2024 15:25:04 -0700 Subject: [PATCH 20/25] Handling AffinityOpInterface on stream.async.transfer. --- .../Transforms/MaterializeTargetDevices.cpp | 5 +- .../Dialect/Stream/Analysis/Partitioning.cpp | 4 +- .../Partitioning/ReferencePartitioning.cpp | 6 +- .../Conversion/FlowToStream/Patterns.cpp | 60 ++++++++----- .../Conversion/HALToStream/Patterns.cpp | 12 +-- .../Dialect/Stream/IR/StreamInterfaces.td | 10 ++- .../compiler/Dialect/Stream/IR/StreamOps.cpp | 11 +++ .../compiler/Dialect/Stream/IR/StreamOps.td | 60 ++++--------- .../Dialect/Stream/IR/StreamTypes.cpp | 4 +- .../Stream/Transforms/ConvertToStream.cpp | 25 +++--- .../Transforms/MaterializeCopyOnWrite.cpp | 2 +- .../Dialect/Stream/Transforms/RefineUsage.cpp | 2 +- .../Stream/Transforms/ScheduleExecution.cpp | 4 +- .../compiler/ExternalInterfaces/BUILD.bazel | 1 + .../ExternalInterfaces/CMakeLists.txt | 1 + .../StreamExternalModels.cpp | 89 ++++++++++++++++--- 16 files changed, 178 insertions(+), 118 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp index 395673c87465..67bf38359f7f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp @@ -166,8 +166,9 @@ static void assignDefaultDeviceAffinity(mlir::ModuleOp moduleOp, } if (auto affinityOp = dyn_cast(op)) { - if (!affinityOp.getAffinity()) - affinityOp.setAffinity(affinityAttr); + if (!affinityOp.getAffinityAttr()) { + affinityOp.setAffinityAttr(affinityAttr); + } } else { if (!op.hasAttr(affinityName)) { op.setAttr(affinityName, affinityAttr); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp index 5ed2ff89357f..93fcd37c7506 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp @@ -58,9 +58,9 @@ LogicalResult Partition::verify(Location loc) { for (auto *op : ops) { if (auto affinityOp = dyn_cast(op)) { if (!IREE::Stream::AffinityAttr::areCompatible( - affinity, affinityOp.getAffinity())) { + affinity, affinityOp.getAffinityAttr())) { return op->emitError("op affinity ") - << affinityOp.getAffinity() + << affinityOp.getAffinityAttr() << " is not compatible with the partition affinity " << affinity; } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp index b86ff6102a9e..a4fff96c3016 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp @@ -54,8 +54,8 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config, DenseSet clonedOps; void insert(Operation *op) { if (auto affinityOp = dyn_cast(op)) { - affinity = affinity ? affinity.joinAND(affinityOp.getAffinity()) - : affinityOp.getAffinity(); + affinity = affinity ? affinity.joinAND(affinityOp.getAffinityAttr()) + : affinityOp.getAffinityAttr(); } ops.insert(op); } @@ -109,7 +109,7 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config, IREE::Stream::AffinityAttr affinityAttr; if (auto affinityOp = dyn_cast(op)) { - affinityAttr = affinityOp.getAffinity(); + affinityAttr = affinityOp.getAffinityAttr(); } LLVM_DEBUG({ diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp index 8e4d854208d8..bdc5aafc6197 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp @@ -25,13 +25,13 @@ namespace { // size of operands must be queried from the input resource. static Value buildResultSizeOf(Location loc, Value tensorValue, ValueRange dynamicDims, + IREE::Stream::AffinityAttr affinityAttr, ConversionPatternRewriter &rewriter) { // TODO(benvanik): see if we can stash this on the side to avoid expensive // materialization of a bunch of redundant IR. return rewriter.create( loc, rewriter.getIndexType(), TypeAttr::get(tensorValue.getType()), - dynamicDims, - IREE::Stream::AffinityAttr::lookup(tensorValue.getDefiningOp())); + dynamicDims, affinityAttr); } struct ConvertTensorConstantOp @@ -123,13 +123,14 @@ struct ConvertTensorCastLikeOp : public OpConversionPattern { auto unknownType = rewriter.getType(); auto source = consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); - auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), - op.getResultDims(), rewriter); + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto resultSize = + buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), + affinityAttr, rewriter); rewriter.replaceOpWithNewOp( op, unknownType, source.resource, op.getSource().getType(), op.getSourceDims(), source.resourceSize, op.getResult().getType(), - adaptor.getResultDims(), resultSize, - IREE::Stream::AffinityAttr::lookup(op)); + adaptor.getResultDims(), resultSize, affinityAttr); return success(); } }; @@ -141,10 +142,12 @@ struct ConvertTensorAllocaOp matchAndRewrite(IREE::Flow::TensorAllocaOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type unknownType = IREE::Stream::ResourceType::get(getContext()); - auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), - op.getResultDims(), rewriter); + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto resultSize = + buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), + affinityAttr, rewriter); rewriter.replaceOpWithNewOp( - op, unknownType, resultSize, IREE::Stream::AffinityAttr::lookup(op)); + op, unknownType, resultSize, affinityAttr); return success(); } }; @@ -156,11 +159,13 @@ struct ConvertTensorEmptyOp matchAndRewrite(IREE::Flow::TensorEmptyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type unknownType = IREE::Stream::ResourceType::get(getContext()); - auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), - op.getResultDims(), rewriter); + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto resultSize = + buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), + affinityAttr, rewriter); rewriter.replaceOpWithNewOp( op, unknownType, op.getResult().getType(), adaptor.getResultDims(), - resultSize, IREE::Stream::AffinityAttr::lookup(op)); + resultSize, affinityAttr); return success(); } }; @@ -172,12 +177,13 @@ struct ConvertTensorSplatOp matchAndRewrite(IREE::Flow::TensorSplatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto unknownType = rewriter.getType(); - auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), - op.getResultDims(), rewriter); + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto resultSize = + buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), + affinityAttr, rewriter); rewriter.replaceOpWithNewOp( op, unknownType, adaptor.getValue(), op.getResult().getType(), - adaptor.getResultDims(), resultSize, - IREE::Stream::AffinityAttr::lookup(op)); + adaptor.getResultDims(), resultSize, affinityAttr); return success(); } }; @@ -230,13 +236,15 @@ struct ConvertTensorSliceOp auto unknownType = rewriter.getType(); auto source = consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); - auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), - op.getResultDims(), rewriter); + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto resultSize = + buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), + affinityAttr, rewriter); rewriter.replaceOpWithNewOp( op, unknownType, source.resource, op.getSource().getType(), op.getSourceDims(), source.resourceSize, adaptor.getStartIndices(), adaptor.getLengths(), op.getResult().getType(), adaptor.getResultDims(), - resultSize, IREE::Stream::AffinityAttr::lookup(op)); + resultSize, affinityAttr); return success(); } }; @@ -676,6 +684,8 @@ struct ConvertDispatchOp : public OpConversionPattern { LogicalResult matchAndRewrite(IREE::Flow::DispatchOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + // Zero is going to be used for each operand to start. auto zeroOffset = rewriter.create(op.getLoc(), 0); @@ -723,7 +733,8 @@ struct ConvertDispatchOp : public OpConversionPattern { auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue( op.getLoc(), result.value(), rewriter); resultSizes.push_back(buildResultSizeOf(op.getLoc(), result.value(), - resultDynamicDims, rewriter)); + resultDynamicDims, affinityAttr, + rewriter)); resultTypes.push_back(unknownType); } } @@ -732,7 +743,7 @@ struct ConvertDispatchOp : public OpConversionPattern { op, resultTypes, adaptor.getWorkload(), adaptor.getEntryPointsAttr(), dispatchOperands, dispatchOperandSizes, dispatchOperandOffsets, dispatchOperandEnds, dispatchOperandLengths, resultSizes, - adaptor.getTiedOperandsAttr(), IREE::Stream::AffinityAttr::lookup(op)); + adaptor.getTiedOperandsAttr(), affinityAttr); newOp->setDialectAttrs(op->getDialectAttrs()); return success(); } @@ -778,6 +789,8 @@ struct ConvertCallOp : public OpConversionPattern { LogicalResult matchAndRewrite(IREE::Flow::CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + // Zero is going to be used for each operand to start. auto zeroOffset = rewriter.create(op.getLoc(), 0); @@ -825,7 +838,8 @@ struct ConvertCallOp : public OpConversionPattern { auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue( op.getLoc(), result.value(), rewriter); resultSizes.push_back(buildResultSizeOf(op.getLoc(), result.value(), - resultDynamicDims, rewriter)); + resultDynamicDims, affinityAttr, + rewriter)); resultTypes.push_back(unknownType); } } @@ -834,7 +848,7 @@ struct ConvertCallOp : public OpConversionPattern { op, resultTypes, adaptor.getCalleeAttr(), callOperands, callOperandSizes, callOperandOffsets, callOperandEnds, callOperandLengths, resultSizes, adaptor.getTiedOperandsAttr(), - IREE::Stream::AffinityAttr::lookup(op)); + affinityAttr); newOp->setDialectAttrs(op->getDialectAttrs()); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp index 35eb31ff20da..4323473fbb02 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp @@ -49,11 +49,7 @@ struct ConvertTensorImportOp } } - auto affinityAttr = - dyn_cast_if_present(op.getAffinityAttr()); - if (!affinityAttr) { - affinityAttr = IREE::Stream::AffinityAttr::lookup(op); - } + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); // Import (buffer view to stream resource). auto resultType = rewriter.getType( @@ -138,11 +134,7 @@ struct ConvertTensorExportOp return rewriter.notifyMatchFailure(op, "unsupported HAL cast conversion"); } - auto affinityAttr = - dyn_cast_if_present(op.getAffinityAttr()); - if (!affinityAttr) { - affinityAttr = IREE::Stream::AffinityAttr::lookup(op); - } + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); auto source = consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td index f34003daec32..f17b7538a7b4 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td @@ -145,9 +145,10 @@ def Stream_AffinityOp : OpInterface<"AffinityOpInterface"> { Returns the stream affinity for the op, indicating where it should run. }], /*retTy=*/"IREE::Stream::AffinityAttr", - /*methodName=*/"getAffinity", + /*methodName=*/"getAffinityAttr", /*args=*/(ins), - /*methodBody=*/[{ + /*methodBody=*/"", + /*defaultImplementation=*/[{ return dyn_cast_or_null($_self->getAttr("affinity")); }] >, @@ -156,9 +157,10 @@ def Stream_AffinityOp : OpInterface<"AffinityOpInterface"> { Sets the stream affinity for the op, indicating where it should run. }], /*retTy=*/"void", - /*methodName=*/"setAffinity", + /*methodName=*/"setAffinityAttr", /*args=*/(ins "IREE::Stream::AffinityAttr":$value), - /*methodBody=*/[{ + /*methodBody=*/"", + /*defaultImplementation=*/[{ if (value) $_self->setAttr("affinity", value); else $_self->removeAttr("affinity"); }] diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp index d9da2822588c..698c7b967e35 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp @@ -2019,6 +2019,17 @@ LogicalResult AsyncTransferOp::verify() { return success(); } +IREE::Stream::AffinityAttr AsyncTransferOp::getAffinityAttr() { + return getResultAffinityAttr(); +} + +void AsyncTransferOp::setAffinityAttr(IREE::Stream::AffinityAttr value) { + if (value) + setResultAffinityAttr(value); + else + removeResultAffinityAttr(); +} + void AsyncTransferOp::getAsyncAccessRanges( SmallVectorImpl &ranges) { ranges.push_back({ResourceAccessBitfield::Read, getSource(), Value{}, diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td index 99e793a00f63..dbe5207734a1 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td @@ -86,10 +86,7 @@ def OpGroupResourceOps : OpDocGroup { let opDocGroup = OpGroupResourceOps in { def Stream_ResourceAllocOp : Stream_Op<"resource.alloc", [ - DeclareOpInterfaceMethods, + Stream_AffinityOp, Util_SizeAwareOp, AlwaysSpeculatable, MemoryEffects<[MemAlloc]>, @@ -148,10 +145,7 @@ def Stream_ResourceAllocOp : Stream_Op<"resource.alloc", [ } def Stream_ResourceAllocaOp : Stream_Op<"resource.alloca", [ - DeclareOpInterfaceMethods, + Stream_AffinityOp, Stream_TimelineOp, Util_SizeAwareOp, AlwaysSpeculatable, @@ -209,10 +203,7 @@ def Stream_ResourceAllocaOp : Stream_Op<"resource.alloca", [ } def Stream_ResourceDeallocaOp : Stream_Op<"resource.dealloca", [ - DeclareOpInterfaceMethods, + Stream_AffinityOp, Stream_TimelineOp, Util_SizeAwareOp, MemoryEffects<[MemFree]>, @@ -645,10 +636,7 @@ let opDocGroup = OpGroupParameterOps in { def Stream_ParameterLoadOp : Stream_PureOp<"parameter.load", [ AttrSizedOperandSegments, AllTypesMatch<["results"]>, - DeclareOpInterfaceMethods, + Stream_AffinityOp, Stream_CmdPhaseOp, Stream_TimelineOp, Util_SizeAwareOp, @@ -702,10 +690,7 @@ def Stream_ParameterLoadOp : Stream_PureOp<"parameter.load", [ } def Stream_ParameterReadOp : Stream_Op<"parameter.read", [ - DeclareOpInterfaceMethods, + Stream_AffinityOp, Stream_CmdPhaseOp, Stream_TimelineOp, Util_SizeAwareOp, @@ -757,10 +742,7 @@ def Stream_ParameterReadOp : Stream_Op<"parameter.read", [ } def Stream_ParameterWriteOp : Stream_Op<"parameter.write", [ - DeclareOpInterfaceMethods, + Stream_AffinityOp, Stream_CmdPhaseOp, Stream_TimelineOp, Util_SizeAwareOp, @@ -813,10 +795,7 @@ def Stream_ParameterWriteOp : Stream_Op<"parameter.write", [ def Stream_ParameterGatherOp : Stream_Op<"parameter.gather", [ AttrSizedOperandSegments, - DeclareOpInterfaceMethods, + Stream_AffinityOp, Stream_CmdPhaseOp, Stream_TimelineOp, Util_SizeAwareOp, @@ -872,10 +851,7 @@ def Stream_ParameterGatherOp : Stream_Op<"parameter.gather", [ def Stream_ParameterScatterOp : Stream_Op<"parameter.scatter", [ AttrSizedOperandSegments, - DeclareOpInterfaceMethods, + Stream_AffinityOp, Stream_CmdPhaseOp, Stream_TimelineOp, Util_SizeAwareOp, @@ -982,10 +958,7 @@ def Stream_FileConstantOp : Stream_PureOp<"file.constant", [ } def Stream_FileReadOp : Stream_Op<"file.read", [ - DeclareOpInterfaceMethods, + Stream_AffinityOp, Stream_CmdPhaseOp, Stream_TimelineOp, Util_SizeAwareOp, @@ -1040,10 +1013,7 @@ def Stream_FileReadOp : Stream_Op<"file.read", [ } def Stream_FileWriteOp : Stream_Op<"file.write", [ - DeclareOpInterfaceMethods, + Stream_AffinityOp, Stream_CmdPhaseOp, Stream_TimelineOp, Util_SizeAwareOp, @@ -1783,10 +1753,7 @@ def OpGroupAsyncOps : OpDocGroup { let opDocGroup = OpGroupAsyncOps in { def Stream_AsyncAllocaOp : Stream_Op<"async.alloca", [ - DeclareOpInterfaceMethods, + Stream_AffinityOp, Stream_AsyncPhaseOp, DeclareOpInterfaceMethods, Stream_AsyncPhaseOp, Stream_StreamableOp, DeclareOpInterfaceMethods(op)) { - auto affinityAttr = affinityOp.getAffinity(); + auto affinityAttr = affinityOp.getAffinityAttr(); if (affinityAttr) { auto attr = affinityAttr.getResourceConfigAttr(); if (attr) @@ -339,7 +339,7 @@ AffinityAttr AffinityAttr::lookup(Operation *op) { auto attrId = StringAttr::get(op->getContext(), "stream.affinity"); while (op) { if (auto affinityOp = llvm::dyn_cast(op)) { - auto affinity = affinityOp.getAffinity(); + auto affinity = affinityOp.getAffinityAttr(); if (affinity) return affinity; } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp index aa5cb25e5499..b0b66ac08667 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp @@ -46,6 +46,7 @@ namespace { static Value buildTensorImportOp(Location loc, Value sourceTensor, Type targetType, SmallPtrSetImpl &consumingOps, + IREE::Stream::AffinityAttr affinityAttr, OpBuilder &builder) { // Gather dynamic dimensions from the input value. auto dynamicDims = @@ -56,8 +57,7 @@ static Value buildTensorImportOp(Location loc, Value sourceTensor, // a transfer operation that may need to reformat the tensor. auto encodingAttr = TypeAttr::get(sourceTensor.getType()); Value resultSize = builder.create( - loc, builder.getIndexType(), encodingAttr, dynamicDims, - /*affinity=*/nullptr); + loc, builder.getIndexType(), encodingAttr, dynamicDims, affinityAttr); // Associate the external SSA value, encoding, and shape information with the // stream resource. When lowering we'll then have all the metadata required @@ -66,7 +66,7 @@ static Value buildTensorImportOp(Location loc, Value sourceTensor, IREE::Stream::Lifetime::External); auto importOp = builder.create( loc, externalType, sourceTensor, encodingAttr, dynamicDims, resultSize, - /*affinity=*/nullptr); + affinityAttr); consumingOps.insert(importOp); // If needed insert a transfer to the target lifetime. @@ -75,8 +75,8 @@ static Value buildTensorImportOp(Location loc, Value sourceTensor, result = builder .create( loc, targetType, result, resultSize, resultSize, - /*source_affinity=*/nullptr, - /*result_affinity=*/nullptr) + /*source_affinity=*/affinityAttr, + /*result_affinity=*/affinityAttr) .getResult(); } @@ -90,6 +90,7 @@ static Value buildTensorImportOp(Location loc, Value sourceTensor, // external tensor value. static Value buildTensorExportOp(Location loc, Value sourceValue, TensorType targetType, ValueRange dynamicDims, + IREE::Stream::AffinityAttr affinityAttr, OpBuilder &builder) { auto source = consumeTensorOperand(loc, sourceValue, builder); @@ -101,14 +102,13 @@ static Value buildTensorExportOp(Location loc, Value sourceValue, loc, externalType, source.resource, source.resourceSize, source.resourceSize, /*source_affinity=*/nullptr, - /*result_affinity=*/nullptr); + /*result_affinity=*/affinityAttr); } // Associate the stream resource and external encoding and shape information. auto newOp = builder.create( loc, targetType, source.resource, TypeAttr::get(targetType), dynamicDims, - source.resourceSize, - /*affinity=*/nullptr); + source.resourceSize, affinityAttr); return newOp.getResult(); } @@ -141,6 +141,8 @@ struct GenericResourcePattern : public ConversionPattern { if (!doesOperationNeedWrapping(op)) return failure(); + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + // Export resources into tensor operands for the op to consume. SmallVector newOperands; newOperands.reserve(op->getNumOperands()); @@ -156,8 +158,9 @@ struct GenericResourcePattern : public ConversionPattern { auto dynamicDims = IREE::Util::buildDynamicDimsForValue( op->getLoc(), oldOperand, rewriter); - newOperands.push_back(buildTensorExportOp( - op->getLoc(), newOperand, tensorType, dynamicDims, rewriter)); + newOperands.push_back(buildTensorExportOp(op->getLoc(), newOperand, + tensorType, dynamicDims, + affinityAttr, rewriter)); } rewriter.modifyOpInPlace(op, [&]() { op->setOperands(newOperands); }); @@ -173,7 +176,7 @@ struct GenericResourcePattern : public ConversionPattern { SmallPtrSet consumingOps; auto importedValue = buildTensorImportOp( op->getLoc(), result, rewriter.getType(), - consumingOps, rewriter); + consumingOps, affinityAttr, rewriter); result.replaceAllUsesExcept(importedValue, consumingOps); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp index 34c0ef8addfb..5b7d3a9a5e48 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp @@ -103,7 +103,7 @@ static bool materializeTiedOpCOW(IREE::Util::TiedOpInterface tiedOp) { IREE::Stream::AffinityAttr affinity; if (auto affinityOp = dyn_cast(tiedOp.getOperation())) { - affinity = affinityOp.getAffinity(); + affinity = affinityOp.getAffinityAttr(); } // Clones each operand that is tied to a result and it may be required. diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp index bc73616047ae..02c2bb05adc9 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp @@ -65,7 +65,7 @@ static Lifetime convertUsageToLifetime(ResourceUsageBitfield usage) { // Returns either the affinity of |op| or nullptr. static IREE::Stream::AffinityAttr getOpAffinity(Operation *op) { if (auto affinityOp = dyn_cast(op)) { - return affinityOp.getAffinity(); + return affinityOp.getAffinityAttr(); } return {}; } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp index 9c5e3d4570d3..c850c3b3276a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp @@ -152,8 +152,8 @@ struct ExecutePartitionBuilder { // want to preserve those as long as possible. if (auto affinityOp = dyn_cast(clonedOp)) { - if (affinityOp.getAffinity() == partition->affinity) { - affinityOp.setAffinity(nullptr); + if (affinityOp.getAffinityAttr() == partition->affinity) { + affinityOp.setAffinityAttr(nullptr); } } diff --git a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel index 7bbd7f504d55..66643e6afba2 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel +++ b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel @@ -29,6 +29,7 @@ iree_compiler_cc_library( deps = [ "//compiler/src/iree/compiler/Dialect/Encoding/IR", "//compiler/src/iree/compiler/Dialect/Flow/IR", + "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", "//compiler/src/iree/compiler/Dialect/Stream/IR", "//compiler/src/iree/compiler/Dialect/Util/IR", diff --git a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt index 4e2f29a78097..a63fca33ec6b 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt +++ b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt @@ -34,6 +34,7 @@ iree_cc_library( MLIRValueBoundsOpInterface iree::compiler::Dialect::Encoding::IR iree::compiler::Dialect::Flow::IR + iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::LinalgExt::IR iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::Util::IR diff --git a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp index e3ba25755ed1..b82d59968290 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp +++ b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp @@ -6,6 +6,10 @@ #include "iree/compiler/ExternalInterfaces/StreamExternalModels.h" +#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" @@ -14,27 +18,47 @@ namespace mlir::iree_compiler { namespace { +struct FlowTransferTargetAffinityAttrExternalModel + : public IREE::Stream::AffinityOpInterface::ExternalModel< + FlowTransferTargetAffinityAttrExternalModel, + IREE::Flow::TensorTransferOp> { + static void add(MLIRContext *context) { + IREE::Flow::TensorTransferOp::attachInterface< + FlowTransferTargetAffinityAttrExternalModel>(*context); + } + + bool requiresAffinity(Operation *op) const { return true; } + + IREE::Stream::AffinityAttr getAffinity(Operation *op) const { + return op->getAttrOfType("target"); + } + + void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const { + op->setAttr("target", value); + } +}; + template -struct AffinityOpAttrExternalModel +struct HALTensorAffinityAttrExternalModel : public IREE::Stream::AffinityOpInterface::ExternalModel< - AffinityOpAttrExternalModel, OpT> { + HALTensorAffinityAttrExternalModel, OpT> { static void add(MLIRContext *context) { - OpT::template attachInterface>(*context); + OpT::template attachInterface>( + *context); } - // Most structural ops don't require affinities and after placement we don't - // use the affinities even if the ops still exist. bool requiresAffinity(Operation *op) const { return false; } IREE::Stream::AffinityAttr getAffinity(Operation *op) const { - return op->getAttrOfType("stream.affinity"); + return op->getAttrOfType("affinity"); } void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const { if (value) - op->setAttr("stream.affinity", value); - else - op->removeAttr("stream.affinity"); + op->setAttr("affinity", value); + } else { + op->removeAttr("affinity"); + } } }; @@ -61,17 +85,58 @@ struct GlobalOpAffinityAttrExternalModel void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const { if (value) op->setAttr("stream.affinity", value); - else + } else { op->removeAttr("stream.affinity"); + } + } +}; + +template +struct AffinityOpAttrExternalModel + : public IREE::Stream::AffinityOpInterface::ExternalModel< + AffinityOpAttrExternalModel, OpT> { + static void add(MLIRContext *context) { + OpT::template attachInterface< + AffinityOpAttrExternalModel>(*context); + } + + // Most structural ops don't require affinities and after placement we don't + // use the affinities even if the ops still exist. + bool requiresAffinity(Operation *op) const { return false; } + + IREE::Stream::AffinityAttr getAffinity(Operation *op) const { + return op->getAttrOfType("stream.affinity"); + } + + void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const { + if (value) + op->setAttr("stream.affinity", value); + } else { + op->removeAttr("stream.affinity"); + } } }; } // namespace void registerStreamExternalModels(DialectRegistry ®istry) { - // Must ensure that any dependent dialects are registered. - registry.insert(); + registry.insert(); + registry.addExtension( + +[](MLIRContext *context, IREE::Flow::FlowDialect *dialect) { + FlowTransferTargetAffinityAttrExternalModel::add(context); + }); + registry.insert(); + registry.addExtension(+[](MLIRContext *context, + IREE::HAL::HALDialect *dialect) { + HALTensorAffinityAttrExternalModel::add(context); + HALTensorAffinityAttrExternalModel::add(context); + HALTensorAffinityAttrExternalModel::add(context); + HALTensorAffinityAttrExternalModel::add( + context); + }); + + registry.insert(); registry.addExtension( +[](MLIRContext *context, IREE::Util::UtilDialect *dialect) { GlobalOpAffinityAttrExternalModel::add(context); From 601ebbfa720571aead41d5acc135f3c1f2b8f630 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 15 Jul 2024 18:22:45 -0700 Subject: [PATCH 21/25] Adding affinity analysis. This performs whole-program analysis to enable the querying of the ideal affinity for globals, execution ops, and resources. It can run at most phases of compilation (including on linalg/flow IR) though it's primarily used by the stream dialect passes such as conversion. The `AnnotateAffinitiesPass` has been added to aid debugging and the compiler `iree-stream-annotate-input-affinities` flag can be used to turn it on - it has no impact on the program generated but can be useful if affinity analysis fails during conversion. --- .../Dialect/Flow/IR/FlowInterfaces.td | 4 - .../Dialect/Stream/Analysis/Affinity.cpp | 1050 +++++++++++ .../Dialect/Stream/Analysis/Affinity.h | 102 ++ .../Dialect/Stream/Analysis/BUILD.bazel | 2 + .../Dialect/Stream/Analysis/CMakeLists.txt | 2 + .../Dialect/Stream/Analysis/ResourceUsage.cpp | 34 +- .../compiler/Dialect/Stream/IR/StreamBase.td | 1 + .../Dialect/Stream/IR/StreamInterfaces.td | 49 +- .../Dialect/Stream/IR/StreamOpFolders.cpp | 6 +- .../compiler/Dialect/Stream/IR/StreamOps.cpp | 53 +- .../compiler/Dialect/Stream/IR/StreamOps.td | 2 +- .../Dialect/Stream/IR/StreamTypes.cpp | 40 +- .../compiler/Dialect/Stream/IR/StreamTypes.h | 6 +- .../Stream/Transforms/AnnotateAffinities.cpp | 127 ++ .../Dialect/Stream/Transforms/BUILD.bazel | 1 + .../Dialect/Stream/Transforms/CMakeLists.txt | 1 + .../Dialect/Stream/Transforms/Passes.cpp | 23 +- .../Dialect/Stream/Transforms/Passes.td | 5 + .../Stream/Transforms/ScheduleAllocation.cpp | 3 +- .../Stream/Transforms/VerifyAffinities.cpp | 7 +- .../Stream/Transforms/test/BUILD.bazel | 1 + .../Stream/Transforms/test/CMakeLists.txt | 1 + .../Transforms/test/annotate_affinities.mlir | 1549 +++++++++++++++++ .../Dialect/Util/Analysis/Explorer.cpp | 64 +- .../compiler/Dialect/Util/Analysis/Explorer.h | 37 +- .../StreamExternalModels.cpp | 118 +- .../compiler/Preprocessing/Common/BUILD.bazel | 1 + .../Preprocessing/Common/CMakeLists.txt | 1 + .../Preprocessing/Common/PadToIntrinsics.cpp | 71 +- .../compiler/Preprocessing/Common/Passes.td | 4 +- .../Common/test/pad_to_intrinsics_mfma.mlir | 6 +- .../Common/test/pad_to_intrinsics_wmma.mlir | 6 +- 32 files changed, 3240 insertions(+), 137 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp create mode 100644 compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.h create mode 100644 compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp create mode 100644 compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_affinities.mlir diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td index 5a1227eea519..dcb0b0fd5286 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td @@ -9,8 +9,4 @@ include "iree/compiler/Dialect/Util/IR/UtilBase.td" -//===----------------------------------------------------------------------===// -// IREE::Flow::StreamableOpInterface -//===----------------------------------------------------------------------===// - #endif // IREE_DIALECT_FLOW_INTERFACES diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp new file mode 100644 index 000000000000..ac3c1660e2d3 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp @@ -0,0 +1,1050 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h" + +#include + +#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "iree/compiler/Dialect/Util/Analysis/DFX/Element.h" +#include "iree/compiler/Dialect/Util/Analysis/DFX/State.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" + +#define DEBUG_TYPE "iree-util-dfx" + +namespace mlir::iree_compiler::IREE::Stream { + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +static const std::string getAffinitySetAsStr( + const DFX::PotentialValuesState &state, + AsmState &asmState) { + std::string str; + llvm::raw_string_ostream sstream(str); + sstream << "pvs: "; + if (state.isValidState()) { + sstream << "["; + if (state.isUndefContained()) { + sstream << "undef, "; + } + llvm::interleaveComma(state.getAssumedSet(), sstream, + [&](IREE::Stream::AffinityAttr value) { + cast(value).print(sstream); + }); + sstream << "]"; + } else { + sstream << "(invalid)"; + } + sstream.flush(); + return str; +} + +//===----------------------------------------------------------------------===// +// Analysis elements +//===----------------------------------------------------------------------===// + +class ValueProducerAffinityPVS + : public DFX::StateWrapper< + DFX::PotentialValuesState, + DFX::ValueElement> { +public: + using BaseType = + DFX::StateWrapper, + DFX::ValueElement>; + using BaseType::BaseType; + + static ValueProducerAffinityPVS &createForPosition(const Position &pos, + DFX::Solver &solver) { + return *(new (solver.getAllocator()) ValueProducerAffinityPVS(pos)); + } + + // Identity definitions. + const std::string getName() const override { + return "ValueProducerAffinityPVS"; + } + const void *getID() const override { return &ID; } + static bool classof(const DFX::AbstractElement *element) { + return (element->getID() == &ID); + } + static const char ID; + + const std::string getAsStr(AsmState &asmState) const override { + return getAffinitySetAsStr(getState(), asmState); + } + +private: + void initializeValue(Value value, DFX::Solver &solver) override; + ChangeStatus updateValue(Value value, DFX::Solver &solver) override; + void updateFromUse(Value value, OpOperand &operand, StateType &newState, + DFX::Solver &solver); + + // Operations that the value is pinned to. + SetVector pinnedOps; +}; +const char ValueProducerAffinityPVS::ID = 0; + +class GlobalAffinityPVS + : public DFX::StateWrapper< + DFX::PotentialValuesState, + DFX::TypedOperationElement> { +public: + using BaseType = DFX::StateWrapper< + DFX::PotentialValuesState, + DFX::TypedOperationElement>; + using BaseType::BaseType; + + static GlobalAffinityPVS &createForPosition(const Position &pos, + DFX::Solver &solver) { + return *(new (solver.getAllocator()) GlobalAffinityPVS(pos)); + } + + // Identity definitions. + const std::string getName() const override { return "GlobalAffinityPVS"; } + const void *getID() const override { return &ID; } + static bool classof(const DFX::AbstractElement *element) { + return (element->getID() == &ID); + } + static const char ID; + + const std::string getAsStr(AsmState &asmState) const override { + return getAffinitySetAsStr(getState(), asmState); + } + +private: + void initializeOperation(IREE::Util::GlobalOpInterface globalOp, + DFX::Solver &solver) override; + ChangeStatus updateOperation(IREE::Util::GlobalOpInterface globalOp, + DFX::Solver &solver) override; +}; +const char GlobalAffinityPVS::ID = 0; + +class OpAffinityPVS : public DFX::StateWrapper< + DFX::PotentialValuesState, + DFX::OperationElement> { +public: + using BaseType = + DFX::StateWrapper, + DFX::OperationElement>; + using BaseType::BaseType; + + static OpAffinityPVS &createForPosition(const Position &pos, + DFX::Solver &solver) { + return *(new (solver.getAllocator()) OpAffinityPVS(pos)); + } + + // Identity definitions. + const std::string getName() const override { return "OpAffinityPVS"; } + const void *getID() const override { return &ID; } + static bool classof(const DFX::AbstractElement *element) { + return (element->getID() == &ID); + } + static const char ID; + + const std::string getAsStr(AsmState &asmState) const override { + return getAffinitySetAsStr(getState(), asmState); + } + +private: + void initializeOperation(Operation *op, DFX::Solver &solver) override; + ChangeStatus updateOperation(Operation *op, DFX::Solver &solver) override; +}; +const char OpAffinityPVS::ID = 0; + +//===----------------------------------------------------------------------===// +// ValueConsumerAffinityPVS +//===----------------------------------------------------------------------===// + +class ValueConsumerAffinityPVS + : public DFX::StateWrapper< + DFX::PotentialValuesState, + DFX::ValueElement> { +public: + using BaseType = + DFX::StateWrapper, + DFX::ValueElement>; + using BaseType::BaseType; + + static ValueConsumerAffinityPVS &createForPosition(const Position &pos, + DFX::Solver &solver) { + return *(new (solver.getAllocator()) ValueConsumerAffinityPVS(pos)); + } + + // Identity definitions. + const std::string getName() const override { + return "ValueConsumerAffinityPVS"; + } + const void *getID() const override { return &ID; } + static bool classof(const DFX::AbstractElement *element) { + return (element->getID() == &ID); + } + static const char ID; + + const std::string getAsStr(AsmState &asmState) const override { + return getAffinitySetAsStr(getState(), asmState); + } + +private: + void initializeValue(Value value, DFX::Solver &solver) override; + ChangeStatus updateValue(Value value, DFX::Solver &solver) override; + TraversalResult updateFromUse(Value value, OpOperand &operand, + StateType &newState, DFX::Solver &solver); +}; +const char ValueConsumerAffinityPVS::ID = 0; + +void ValueConsumerAffinityPVS::initializeValue(Value value, + DFX::Solver &solver) {} + +ChangeStatus ValueConsumerAffinityPVS::updateValue(Value value, + DFX::Solver &solver) { + StateType newState; + auto traversalResult = TraversalResult::COMPLETE; + + // Walk into all consumers of the SSA value. + // Note that we may end up at multiple global stores of different globals + // by walking down through calls/branches/etc. + traversalResult |= solver.getExplorer().walkTransitiveUses( + value, + [&](OpOperand &operand) { + traversalResult |= updateFromUse(value, operand, newState, solver); + return WalkResult::advance(); + }, + (TraversalBehavior::DEFAULT | TraversalBehavior::DONT_WALK_TIED_VALUES)); + + if (traversalResult == TraversalResult::INCOMPLETE) { + // Incomplete traversal because of external call graph edges or pointers. + newState.unionAssumedWithUndef(); + newState.indicatePessimisticFixpoint(); + } + return DFX::clampStateAndIndicateChange(getState(), newState); +} + +TraversalResult ValueConsumerAffinityPVS::updateFromUse(Value value, + OpOperand &operand, + StateType &newState, + DFX::Solver &solver) { + // If the value is consumed by an affinity-aware op then we can directly use + // the affinity specified on the op. A majority of the values we care about at + // the stream level are consumed by affinity-aware ops and earlier in the + // pipeline dialects may have transfer ops that define affinities we can + // anchor on. + if (auto affinityOp = + dyn_cast(operand.getOwner())) { + auto opPVS = solver.getElementFor( + *this, Position::forOperation(operand.getOwner()), + DFX::Resolution::REQUIRED); + LLVM_DEBUG({ + llvm::dbgs() << "[ValueConsumerAffinityPVS] value "; + value.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " affinity using consumer affinity from "; + operand.get().printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " as "; + opPVS.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState ^= opPVS; + } + + // If the consumer op has the operand tied to one or more results then we walk + // through to track the transitive consumers. When this analysis runs we are + // usually still prior to baking out copy-on-write behavior so it's possible + // that the results of the tied operation end up in different places. + if (auto tiedOp = dyn_cast(operand.getOwner())) { + auto tiedResults = tiedOp.getOperandTiedResults(operand.getOperandNumber()); + for (auto tiedResult : tiedResults) { + auto resultPVS = solver.getElementFor( + *this, Position::forValue(tiedResult), DFX::Resolution::REQUIRED); + LLVM_DEBUG({ + llvm::dbgs() << "[ValueConsumerAffinityPVS] value "; + value.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " affinity referencing tied operand "; + operand.get().printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " result "; + tiedResult.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " as "; + resultPVS.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState ^= resultPVS; + } + } + + // Handle consumers that are not affinity aware - this should have any control + // flow ops so that we can track values that flow through the program. + return TypeSwitch(operand.getOwner()) + .Case([&](mlir::arith::SelectOp op) { + auto &resultPVS = solver.getElementFor( + *this, Position::forValue(op.getResult()), + DFX::Resolution::REQUIRED); + newState ^= resultPVS.getState(); + return TraversalResult::COMPLETE; + }) + .Case([&](mlir::BranchOpInterface op) { + return solver.getExplorer().walkOutgoingBranchOperandArguments( + op, operand.getOperandNumber(), + [&](Block *targetBlock, BlockArgument arg) { + auto &argUsage = solver.getElementFor( + *this, Position::forValue(arg), DFX::Resolution::OPTIONAL); + newState ^= argUsage; + return WalkResult::advance(); + }); + }) + .Case([&](mlir::scf::ForOp op) { + if (operand.getOperandNumber() >= op.getNumControlOperands()) { + int64_t blockIdx = + operand.getOperandNumber() - op.getNumControlOperands(); + auto &beforeUsage = solver.getElementFor( + *this, Position::forValue(op.getRegionIterArg(blockIdx)), + DFX::Resolution::REQUIRED); + newState ^= beforeUsage.getState(); + } + return TraversalResult::COMPLETE; + }) + .Case([&](mlir::scf::WhileOp op) { + auto &beforeUsage = solver.getElementFor( + *this, + Position::forValue( + op.getBeforeBody()->getArgument(operand.getOperandNumber())), + DFX::Resolution::REQUIRED); + newState ^= beforeUsage.getState(); + return TraversalResult::COMPLETE; + }) + .Case([&](mlir::scf::ConditionOp op) { + auto &parentUsage = solver.getElementFor( + *this, + Position::forValue( + op->getParentOp()->getResult(operand.getOperandNumber() - 1)), + DFX::Resolution::REQUIRED); + newState ^= parentUsage.getState(); + if (auto whileOp = + dyn_cast_or_null(op->getParentOp())) { + auto value = Position::forValue( + whileOp.getAfter().getArgument(operand.getOperandNumber() - 1)); + auto &valueUsage = solver.getElementFor( + *this, value, DFX::Resolution::REQUIRED); + newState ^= valueUsage.getState(); + } + return TraversalResult::COMPLETE; + }) + .Case([&](mlir::scf::YieldOp op) { + if (isa(op->getParentOp())) { + auto &operandUsage = solver.getElementFor( + *this, + Position::forValue(op->getOperand(operand.getOperandNumber())), + DFX::Resolution::REQUIRED); + newState ^= operandUsage.getState(); + auto &parentUsage = solver.getElementFor( + *this, + Position::forValue( + op->getParentOp()->getResult(operand.getOperandNumber())), + DFX::Resolution::REQUIRED); + newState ^= parentUsage.getState(); + return TraversalResult::COMPLETE; + } else if (auto whileOp = + dyn_cast(op->getParentOp())) { + auto value = Position::forValue( + whileOp.getBefore().getArgument(operand.getOperandNumber())); + auto &valueUsage = solver.getElementFor( + *this, value, DFX::Resolution::REQUIRED); + newState ^= valueUsage.getState(); + auto &parentUsage = solver.getElementFor( + *this, + Position::forValue( + whileOp->getResult(operand.getOperandNumber())), + DFX::Resolution::REQUIRED); + newState ^= parentUsage.getState(); + return TraversalResult::COMPLETE; + } else if (auto forOp = dyn_cast(op->getParentOp())) { + auto value = Position::forValue( + forOp.getRegionIterArg(operand.getOperandNumber())); + auto &valueUsage = solver.getElementFor( + *this, value, DFX::Resolution::REQUIRED); + newState ^= valueUsage.getState(); + auto &parentUsage = solver.getElementFor( + *this, + Position::forValue(forOp->getResult(operand.getOperandNumber())), + DFX::Resolution::REQUIRED); + newState ^= parentUsage.getState(); + return TraversalResult::COMPLETE; + } else { + assert(false && "unhandled scf yield parent"); + return TraversalResult::INCOMPLETE; + } + }) + .Case([&](IREE::Util::ReturnOp op) { + return solver.getExplorer().walkIncomingCalls( + op->getParentOfType(), + [&](mlir::CallOpInterface callOp) { + auto &argUsage = solver.getElementFor( + *this, + Position::forValue( + callOp->getResult(operand.getOperandNumber())), + DFX::Resolution::OPTIONAL); + getState() ^= argUsage; + return WalkResult::advance(); + }); + }) + .Case([&](IREE::Util::OptimizationBarrierOp op) { + auto &resultPVS = solver.getElementFor( + *this, Position::forValue(op.getResult(operand.getOperandNumber())), + DFX::Resolution::REQUIRED); + newState ^= resultPVS.getState(); + return TraversalResult::COMPLETE; + }) + .Case([&](IREE::Util::GlobalStoreOpInterface op) { + auto *globalInfo = + solver.getExplorer().queryGlobalInfoFrom(op.getGlobalName(), op); + auto &globalPVS = solver.getElementFor( + *this, Position::forOperation(globalInfo->op), + DFX::Resolution::REQUIRED); + newState ^= globalPVS.getState(); + return TraversalResult::COMPLETE; + }) + .Default([&](Operation *op) { return TraversalResult::COMPLETE; }); +} + +//===----------------------------------------------------------------------===// +// ValueProducerAffinityPVS +//===----------------------------------------------------------------------===// + +void ValueProducerAffinityPVS::initializeValue(Value value, + DFX::Solver &solver) { + solver.getExplorer().walkDefiningOps(value, [&](OpResult result) { + if (!isa(result.getType())) { + return WalkResult::skip(); + } + if (auto affinityOp = + dyn_cast_if_present( + result.getOwner())) { + if (affinityOp.pinsValueAffinity()) { + pinnedOps.insert(result.getOwner()); + } + } + return WalkResult::advance(); + }); + solver.getExplorer().walkTransitiveUses(value, [&](OpOperand &operand) { + if (!isa(operand.get().getType())) { + return WalkResult::skip(); + } + if (auto affinityOp = + dyn_cast_if_present( + operand.getOwner())) { + if (affinityOp.pinsValueAffinity()) { + pinnedOps.insert(operand.getOwner()); + } + } + return WalkResult::advance(); + }); +} + +ChangeStatus ValueProducerAffinityPVS::updateValue(Value value, + DFX::Solver &solver) { + StateType newState; + + // If there are any ops that produce the value and pin to a specific affinity + // then we take those directly and ignore all others. + if (!pinnedOps.empty()) { + for (auto pinnedOp : pinnedOps) { + auto &opPVS = solver.getElementFor( + *this, Position::forOperation(pinnedOp), DFX::Resolution::REQUIRED); + newState ^= opPVS; + } + return DFX::clampStateAndIndicateChange(getState(), newState); + } + + // We special case some ops that act as barriers in the program. This prevents + // us from walking past boundaries that are not profitable to do so with; for + // example, globals are usually stored in independent contexts from where they + // are consumed. + if (auto barrierOp = dyn_cast_if_present( + value.getDefiningOp())) { + auto operand = + barrierOp.getOperand(cast(value).getResultNumber()); + auto operandPVS = solver.getElementFor( + *this, Position::forValue(operand), DFX::Resolution::REQUIRED); + LLVM_DEBUG({ + llvm::dbgs() << "[ValueProducerAffinityPVS] value "; + value.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " affinity using barrier op operand as "; + operandPVS.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState ^= operandPVS; + return DFX::clampStateAndIndicateChange(getState(), newState); + } else if (auto loadOp = + dyn_cast_if_present( + value.getDefiningOp())) { + auto *globalInfo = solver.getExplorer().queryGlobalInfoFrom( + loadOp.getGlobalName(), loadOp); + auto &globalPVS = solver.getElementFor( + *this, Position::forOperation(globalInfo->op), + DFX::Resolution::REQUIRED); + LLVM_DEBUG({ + llvm::dbgs() << "[ValueProducerAffinityPVS] value "; + value.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " affinity using global op affinity from " + << loadOp.getGlobalName() << " as "; + globalPVS.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState ^= globalPVS.getState(); + return DFX::clampStateAndIndicateChange(getState(), newState); + } + + // Walk the program up into any possible producers of the value. + auto traversalResult = TraversalResult::COMPLETE; + traversalResult |= solver.getExplorer().walkDefiningOps( + value, + [&](OpResult result) { + if (isa(result.getOwner())) { + return WalkResult::advance(); + } + + // If coming from an affinity-aware op that pins the value storage to a + // particular affinity that overrides all other logic. + if (auto affinityOp = + dyn_cast_if_present( + result.getDefiningOp())) { + if (affinityOp.pinsValueAffinity()) { + auto &opPVS = solver.getElementFor( + *this, Position::forOperation(affinityOp), + DFX::Resolution::REQUIRED); + LLVM_DEBUG({ + llvm::dbgs() << "[ValueProducerAffinityPVS] value "; + value.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " affinity using assuming pinned affinity from "; + result.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " as "; + opPVS.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState ^= opPVS; + newState.indicateOptimisticFixpoint(); + return WalkResult::advance(); + } + } + + // If the result value is tied to an operand of the defining op then + // inherit the operand affinity. + if (auto tiedOp = dyn_cast_if_present( + result.getDefiningOp())) { + auto operand = tiedOp.getTiedResultOperand(result); + if (operand) { + auto &valuePVS = solver.getElementFor( + *this, Position::forValue(operand), DFX::Resolution::OPTIONAL); + LLVM_DEBUG({ + llvm::dbgs() << "[ValueProducerAffinityPVS] value "; + value.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " affinity referencing tied operand "; + operand.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " as "; + valuePVS.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState ^= valuePVS; + return WalkResult::advance(); + } + } + + // If the value is produced by the defining op then assume that the + // execution affinity dictates the result affinity. + if (auto affinityOp = + dyn_cast_if_present( + result.getDefiningOp())) { + auto &opPVS = solver.getElementFor( + *this, Position::forOperation(result.getOwner()), + DFX::Resolution::OPTIONAL); + LLVM_DEBUG({ + llvm::dbgs() << "[ValueProducerAffinityPVS] value "; + value.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " affinity using op affinity from result "; + result.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " as "; + opPVS.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState ^= opPVS; + return WalkResult::advance(); + } + + // Special handling for specific ops. + TypeSwitch(result.getOwner()) + .Case([&](auto loadOp) { + auto *globalInfo = solver.getExplorer().queryGlobalInfoFrom( + loadOp.getGlobalName(), loadOp); + auto &globalPVS = solver.getElementFor( + *this, Position::forOperation(globalInfo->op), + DFX::Resolution::REQUIRED); + LLVM_DEBUG({ + llvm::dbgs() << "[ValueProducerAffinityPVS] value "; + value.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() + << " affinity using global op affinity from result "; + result.printAsOperand(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << " as "; + globalPVS.print(llvm::dbgs(), solver.getAsmState()); + llvm::dbgs() << "\n"; + }); + newState ^= globalPVS.getState(); + }) + .Case([&](auto op) { + auto &truePVS = solver.getElementFor( + *this, Position::forValue(op.getTrueValue()), + DFX::Resolution::REQUIRED); + newState ^= truePVS.getState(); + auto &falsePVS = solver.getElementFor( + *this, Position::forValue(op.getFalseValue()), + DFX::Resolution::REQUIRED); + newState ^= falsePVS.getState(); + }) + .Default([&](auto op) { + auto valuePVS = solver.getElementFor( + *this, Position::forValue(result), DFX::Resolution::OPTIONAL); + newState ^= valuePVS; + }); + return WalkResult::advance(); + }, + (TraversalBehavior::DEFAULT | TraversalBehavior::DONT_WALK_TIED_VALUES)); + + if (traversalResult == TraversalResult::INCOMPLETE) { + // Incomplete traversal because of external call graph edges or pointers. + newState.unionAssumedWithUndef(); + newState.indicatePessimisticFixpoint(); + } + return DFX::clampStateAndIndicateChange(getState(), newState); +} + +//===----------------------------------------------------------------------===// +// GlobalAffinityPVS +//===----------------------------------------------------------------------===// + +void GlobalAffinityPVS::initializeOperation( + IREE::Util::GlobalOpInterface globalOp, DFX::Solver &solver) { + // If an affinity is explicitly specified we take that over all analysis. + if (auto affinityAttr = IREE::Stream::AffinityAttr::lookup(globalOp)) { + LLVM_DEBUG({ + llvm::dbgs() << "[GlobalAffinityPVS] global @" + << globalOp.getGlobalName().getValue() + << " affinity explicitly specified as "; + affinityAttr.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + unionAssumed(affinityAttr); + indicateOptimisticFixpoint(); + return; + } +} + +ChangeStatus +GlobalAffinityPVS::updateOperation(IREE::Util::GlobalOpInterface globalOp, + DFX::Solver &solver) { + StateType newState; + auto traversalResult = TraversalResult::COMPLETE; + + const auto *globalInfo = solver.getExplorer().getGlobalInfo(globalOp); + if (globalInfo->isIndirect) { + traversalResult = TraversalResult::INCOMPLETE; + } + + // Traverse all transitive uses of the global. + // We try to place globals where they are used as the common case is weights + // or parameters that are read more frequently than they are written. + // The reasoning is that if there are more writes than reads there's unneeded + // work being done and otherwise there's always at least one read per write + // or more reads than writes. + bool anyLoads = false; + for (auto loadOp : globalInfo->getLoads()) { + anyLoads = true; + auto &valuePVS = solver.getElementFor( + *this, Position::forValue(loadOp.getLoadedGlobalValue()), + DFX::Resolution::OPTIONAL); + if (valuePVS.isValidState()) { + newState ^= valuePVS; + } + } + + // If there were no loads then take the affinity from stores. + // This is not common but can arise in tests or where the globals may be used + // to model side-effecting behavior. + if (!anyLoads) { + for (auto storeOp : globalInfo->getStores()) { + auto &valuePVS = solver.getElementFor( + *this, Position::forValue(storeOp.getStoredGlobalValue()), + DFX::Resolution::OPTIONAL); + if (valuePVS.isValidState()) { + newState ^= valuePVS; + } + } + } + + if (traversalResult == TraversalResult::INCOMPLETE) { + // Incomplete traversal because of external call graph edges or pointers. + newState.unionAssumedWithUndef(); + newState.indicatePessimisticFixpoint(); + } + return DFX::clampStateAndIndicateChange(getState(), newState); +} + +//===----------------------------------------------------------------------===// +// OpAffinityPVS +//===----------------------------------------------------------------------===// + +void OpAffinityPVS::initializeOperation(Operation *op, DFX::Solver &solver) { + // If an affinity is explicitly specified we take that over all analysis. + if (auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op)) { + LLVM_DEBUG({ + llvm::dbgs() << "[OpAffinityPVS] op "; + op->getName().print(llvm::dbgs()); + llvm::dbgs() << " affinity explicitly specified as "; + affinityAttr.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + unionAssumed(affinityAttr); + indicateOptimisticFixpoint(); + return; + } +} + +ChangeStatus OpAffinityPVS::updateOperation(Operation *op, + DFX::Solver &solver) { + StateType newState; + + const bool consumesAny = llvm::any_of( + op->getOperandTypes(), +[](Type type) { + return isa(type); + }); + if (consumesAny) { + for (auto operand : op->getOperands()) { + if (isa(operand.getType())) { + auto valuePVS = solver.getElementFor( + *this, Position::forValue(operand), DFX::Resolution::REQUIRED); + newState ^= valuePVS; + } + } + } else { + for (auto result : op->getResults()) { + if (isa(result.getType())) { + auto valuePVS = solver.getElementFor( + *this, Position::forValue(result), DFX::Resolution::REQUIRED); + newState ^= valuePVS; + } + } + } + + return DFX::clampStateAndIndicateChange(getState(), newState); +} + +//===----------------------------------------------------------------------===// +// AffinityAnalysis +//===----------------------------------------------------------------------===// + +// Tries to find a default affinity specified on an ancestor of |fromOp| and +// adds it to |affinities|. Returns true if an affinity was found. +static bool tryLookupDefaultAffinity( + Operation *fromOp, + SmallVectorImpl &affinities) { + while (fromOp) { + auto affinityAttr = fromOp->getAttrOfType( + "stream.affinity.default"); + if (affinityAttr) { + affinities.push_back(affinityAttr); + return true; + } + fromOp = fromOp->getParentOp(); + } + return false; +} + +// Returns the first affinity if all affinities are compatible and otherwise +// returns nullptr. +static IREE::Stream::AffinityAttr +trySelectLeadAffinity(ArrayRef affinities) { + if (affinities.empty()) { + return {}; + } + auto leadAffinityAttr = affinities.front(); + for (size_t i = 1; i < affinities.size(); ++i) { + if (!IREE::Stream::AffinityAttr::areCompatible(affinities[i], + leadAffinityAttr)) { + return {}; + } + } + return leadAffinityAttr; +} + +// Sorts |affinities| in the natural affinity sort order. +// We unfortunately have to do this as the PVS elements we source from are +// unsorted. +static void +sortAffinities(SmallVectorImpl &affinities) { + // HACK: this should probably do a type id ordering followed by a + // type-specific ordering (interface compare method?). We just need this to be + // stable as the affinities come from multiple DenseSets that have run-to-run + // ordering variance. This is very inefficient but is only used when there are + // multiple possible affinities and we try to avoid that anyway. + if (affinities.size() <= 1) { + return; + } + llvm::stable_sort(affinities, [](IREE::Stream::AffinityAttr lhs, + IREE::Stream::AffinityAttr rhs) { + std::string lhsStr; + llvm::raw_string_ostream lhsStream(lhsStr); + lhs.print(lhsStream); + std::string rhsStr; + llvm::raw_string_ostream rhsStream(rhsStr); + rhs.print(rhsStream); + return lhsStr < rhsStr; + }); +} + +AffinityAnalysis::AffinityAnalysis(Operation *rootOp) + : explorer(rootOp, TraversalAction::RECURSE), solver(explorer, allocator) { + explorer.setOpInterfaceAction( + TraversalAction::RECURSE); + + explorer.setDialectAction(TraversalAction::RECURSE); + + explorer.setDialectAction( + TraversalAction::RECURSE); + explorer.setOpAction(TraversalAction::IGNORE); + + explorer.initialize(); +} + +AffinityAnalysis::~AffinityAnalysis() = default; + +IREE::Stream::AffinityAttr +AffinityAnalysis::lookupGlobalAffinity(Operation *op) { + SmallVector affinities; + if (!tryLookupGlobalAffinity(op, affinities) || affinities.empty()) { + return {}; + } + if (affinities.size() == 1) { + return affinities.front(); + } + return trySelectLeadAffinity(affinities); +} + +bool AffinityAnalysis::tryLookupGlobalAffinity( + Operation *op, SmallVectorImpl &affinities) { + auto globalPVS = + solver.lookupElementFor(Position::forOperation(op)); + if (!globalPVS || !globalPVS->isValidState() || + globalPVS->isUndefContained()) { + // Analysis failed. + return false; + } + if (globalPVS->getAssumedSet().empty()) { + // Analysis completed but no affinity was specified; try to find a default. + return tryLookupDefaultAffinity(op, affinities); + } + for (auto affinityAttr : globalPVS->getAssumedSet()) { + affinities.push_back(affinityAttr); + } + sortAffinities(affinities); + return true; +} + +IREE::Stream::AffinityAttr +AffinityAnalysis::lookupExecutionAffinity(Operation *op) { + SmallVector affinities; + if (!tryLookupExecutionAffinity(op, affinities) || affinities.empty()) { + return {}; + } + if (affinities.size() == 1) { + return affinities.front(); + } + return trySelectLeadAffinity(affinities); +} + +bool AffinityAnalysis::tryLookupExecutionAffinity( + Operation *op, SmallVectorImpl &affinities) { + auto opPVS = + solver.lookupElementFor(Position::forOperation(op)); + if (!opPVS || !opPVS->isValidState() || opPVS->isUndefContained()) { + // Analysis failed. + return false; + } + if (opPVS->getAssumedSet().empty()) { + // Analysis completed but no affinity was specified; try to find a default. + return tryLookupDefaultAffinity(op, affinities); + } + for (auto affinityAttr : opPVS->getAssumedSet()) { + affinities.push_back(affinityAttr); + } + sortAffinities(affinities); + return true; +} + +IREE::Stream::AffinityAttr +AffinityAnalysis::inferExecutionAffinity(Operation *op) { + SmallVector affinities; + if (!tryInferExecutionAffinity(op, affinities) || affinities.empty()) { + return {}; + } + if (affinities.size() == 1) { + return affinities.front(); + } + return trySelectLeadAffinity(affinities); +} + +bool AffinityAnalysis::tryInferExecutionAffinity( + Operation *op, SmallVectorImpl &affinities) { + if (auto affinityOp = dyn_cast(op)) { + return tryLookupExecutionAffinity(op, affinities); + } + DFX::PotentialValuesState opPVS; + const bool consumesAny = llvm::any_of( + op->getOperandTypes(), +[](Type type) { + return isa(type); + }); + if (consumesAny) { + for (auto operand : op->getOperands()) { + if (isa(operand.getType())) { + auto valuePVS = solver.lookupElementFor( + Position::forValue(operand), nullptr, DFX::Resolution::REQUIRED); + if (valuePVS && valuePVS->isValidState()) { + opPVS.unionAssumed(valuePVS->getState()); + } else { + return false; + } + } + } + } else { + for (auto result : op->getResults()) { + if (isa(result.getType())) { + auto valuePVS = solver.lookupElementFor( + Position::forValue(result), nullptr, DFX::Resolution::REQUIRED); + if (valuePVS && valuePVS->isValidState()) { + opPVS.unionAssumed(valuePVS->getState()); + } else { + return false; + } + } + } + } + if (!opPVS.isValidState() || opPVS.isUndefContained()) { + // Analysis failed. + return false; + } + if (opPVS.getAssumedSet().empty()) { + // Analysis completed but no affinity was specified; try to find a default. + return tryLookupDefaultAffinity(op, affinities); + } + for (auto affinityAttr : opPVS.getAssumedSet()) { + affinities.push_back(affinityAttr); + } + sortAffinities(affinities); + return true; +} + +IREE::Stream::AffinityAttr +AffinityAnalysis::lookupResourceAffinity(Value value) { + SmallVector affinities; + if (!tryLookupResourceAffinity(value, affinities) || affinities.empty()) { + return {}; + } + if (affinities.size() == 1) { + return affinities.front(); + } + return trySelectLeadAffinity(affinities); +} + +bool AffinityAnalysis::tryLookupResourceAffinity( + Value value, SmallVectorImpl &affinities) { + auto valuePVS = solver.lookupElementFor( + Position::forValue(value)); + if (!valuePVS || !valuePVS->isValidState() || valuePVS->isUndefContained()) { + // Analysis failed. + return false; + } + if (valuePVS->getAssumedSet().empty()) { + // Analysis completed but no affinity was specified; try to find a default. + return tryLookupDefaultAffinity(value.getParentBlock()->getParentOp(), + affinities); + } + for (auto affinityAttr : valuePVS->getAssumedSet()) { + affinities.push_back(affinityAttr); + } + sortAffinities(affinities); + return true; +} + +LogicalResult AffinityAnalysis::run() { + // Initialize globals so that we can assign them affinity. + explorer.forEachGlobal([&](const auto *globalInfo) { + if (isa( + globalInfo->op.getGlobalType())) { + solver.getOrCreateElementFor( + Position::forOperation(globalInfo->op)); + } + }); + + // Initialize op execution affinities for any ops that use tracked types. + // + // TODO(benvanik): avoid doing this initialization for the entire module and + // instead rely on DFX to automatically populate the required abstract values. + // There's some missing logic in the element initialization, though, and by + // initializing all values we side-step that and work with test programs that + // may not have I/O edges that we could easily latch on to here. + explorer.forEachFunctionLikeOp([&](FunctionOpInterface funcOp) { + for (auto &block : funcOp.getBlocks()) { + for (auto arg : block.getArguments()) { + if (isa(arg.getType())) { + solver.getOrCreateElementFor( + Position::forValue(arg)); + } + } + } + funcOp.walk([&](Operation *op) { + if (auto regionOp = dyn_cast(op)) { + for (auto ®ion : regionOp->getRegions()) { + for (auto arg : region.getArguments()) { + if (isa(arg.getType())) { + solver.getOrCreateElementFor( + Position::forValue(arg)); + } + } + } + } + if (auto affinityOp = dyn_cast(op)) { + solver.getOrCreateElementFor(Position::forOperation(op)); + } + for (auto result : op->getResults()) { + if (isa(result.getType())) { + solver.getOrCreateElementFor( + Position::forValue(result)); + } + } + }); + }); + + if (failed(solver.run())) { + return failure(); // did not converge + } + + LLVM_DEBUG({ + llvm::dbgs() + << "\n\n[Analysis] affinity analysis results for the whole module:\n"; + solver.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + return success(); +} + +} // namespace mlir::iree_compiler::IREE::Stream diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.h b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.h new file mode 100644 index 000000000000..3642a5351a7d --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.h @@ -0,0 +1,102 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_STREAM_ANALYSIS_AFFINITY_H_ +#define IREE_COMPILER_DIALECT_STREAM_ANALYSIS_AFFINITY_H_ + +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Util/Analysis/DFX/Solver.h" +#include "iree/compiler/Dialect/Util/Analysis/Explorer.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Diagnostics.h" + +namespace mlir::iree_compiler::IREE::Stream { + +//===----------------------------------------------------------------------===// +// Affinity analysis +//===----------------------------------------------------------------------===// + +// Performs whole-program analysis of resource and tensor value affinity. +// All `!stream.resource` and `tensor` SSA values will be analyzed and their +// affinities where used will be available for querying via the lookup +// functions. +class AffinityAnalysis { +public: + explicit AffinityAnalysis(Operation *rootOp); + ~AffinityAnalysis(); + + // Runs analysis and populates the resource usage map. + // May fail if analysis cannot be completed due to unsupported or unknown IR. + LogicalResult run(); + + // Returns the affinity of the global |op| based on its loads. + // The global storage should be allocated with this affinity and available for + // fast access from any compatible affinity. + // + // If an explicit affinity is provided via a stream.affinity attribute then + // that will be used in place of analysis. If there are more than one consumer + // (such as multiple loads) with differing affinities or analysis fails then + // no affinity is returned. If all affinities are compatible one will be + // chosen in an unspecified way. + IREE::Stream::AffinityAttr lookupGlobalAffinity(Operation *op); + + // Populates all potential affinities of the global |op| in |affinities|. + // Returns false if analysis failed and the set of affinities is unknown. + bool tryLookupGlobalAffinity( + Operation *op, SmallVectorImpl &affinities); + + // Returns the affinity of the executable |op| based on the op-specific rules + // as to whether its operands or results control placement. The operation + // should be scheduled to execute with this affinity and efficiently consume + // or produce resources that share a compatible affinity. + // + // If an explicit affinity is provided via stream.affinity attrs or the + // affinity op interface then that will be used in place of analysis. If there + // are multiple possible affinities or analysis fails no affinity is returned. + // If all affinities are compatible one will be chosen in an unspecified way. + IREE::Stream::AffinityAttr lookupExecutionAffinity(Operation *op); + + // Populates all potential execution affinities of |op| in |affinities|. + // Returns false if analysis failed and the set of affinities is unknown. + bool tryLookupExecutionAffinity( + Operation *op, SmallVectorImpl &affinities); + + // Returns the affinity of |op| as if it were executable even if it is not. + // This relies on analysis of operands and results having resolved and + // otherwise returns nullptr indicating the op has no assumed affinity. + IREE::Stream::AffinityAttr inferExecutionAffinity(Operation *op); + + // Populates all inferred potential execution affinities of |op| in + // |affinities|. This relies on analysis of operands and results having + // resolved and otherwise returns nullptr indicating the op has no assumed + // affinity. + // Returns false if analysis failed and the set of affinities is unknown. + bool tryInferExecutionAffinity( + Operation *op, SmallVectorImpl &affinities); + + // Returns the affinity of |value| based on its producers. + // The resource should be allocated with this affinity and be usable by any + // compatible affinity. + // + // If there are more than one producer of the value (such as multiple callers) + // with differing affinities or analysis fails then no affinity is returned. + // If all affinities are compatible one will be chosen in an unspecified way. + IREE::Stream::AffinityAttr lookupResourceAffinity(Value value); + + // Populates all potential affinities of |value| in |affinities|. + // Returns false if analysis failed and the set of affinities is unknown. + bool tryLookupResourceAffinity( + Value value, SmallVectorImpl &affinities); + +private: + Explorer explorer; + llvm::BumpPtrAllocator allocator; + DFX::Solver solver; +}; + +} // namespace mlir::iree_compiler::IREE::Stream + +#endif // IREE_COMPILER_DIALECT_STREAM_ANALYSIS_AFFINITY_H_ diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Analysis/BUILD.bazel index 4e1421bf6e20..3cbb5b5492bc 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/BUILD.bazel @@ -15,12 +15,14 @@ package( iree_compiler_cc_library( name = "Analysis", srcs = [ + "Affinity.cpp", "Partitioning.cpp", "Partitioning/ReferencePartitioning.cpp", "ResourceHazards.cpp", "ResourceUsage.cpp", ], hdrs = [ + "Affinity.h", "Partitioning.h", "ResourceHazards.h", "ResourceUsage.h", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt index f1b0fc8d56bb..c2dd74c24e10 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt @@ -14,10 +14,12 @@ iree_cc_library( NAME Analysis HDRS + "Affinity.h" "Partitioning.h" "ResourceHazards.h" "ResourceUsage.h" SRCS + "Affinity.cpp" "Partitioning.cpp" "Partitioning/ReferencePartitioning.cpp" "ResourceHazards.cpp" diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp index 1708782f21be..4ff656c18282 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp @@ -17,7 +17,6 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -416,10 +415,14 @@ class ValueResourceUsage : public AbstractResourceUsage { // TODO(benvanik): remove kFavorTransients. bool isSourceExternal = !sourceUsage.isAssumed(NOT_EXTERNAL); bool isTargetInternal = isAssumed(NOT_EXTERNAL); - if (kFavorTransients && isSourceExternal && isTargetInternal) { + bool deviceChange = + op.getSourceAffinityAttr() != op.getResultAffinityAttr(); + if ((kFavorTransients || deviceChange) && isSourceExternal && + isTargetInternal) { LLVM_DEBUG({ - llvm::dbgs() << "[ValueResourceUsage] skipping forward prop of " - "external into internal: "; + llvm::dbgs() + << "[ValueResourceUsage] skipping forward prop of external " + "into internal due to kFavorTransients/device-change: "; op.print(llvm::dbgs(), solver.getAsmState()); llvm::dbgs() << "\n"; }); @@ -529,7 +532,6 @@ class ValueResourceUsage : public AbstractResourceUsage { *this, Position::forValue(op.getBeforeBody()->getArgument(operandIdx)), DFX::Resolution::REQUIRED); - getState() ^= beforeUsage.getState(); }) .Case([&](mlir::scf::ConditionOp op) { @@ -562,29 +564,30 @@ class ValueResourceUsage : public AbstractResourceUsage { Position::forValue(op->getParentOp()->getResult(operandIdx)), DFX::Resolution::REQUIRED); getState() ^= parentUsage.getState(); - } else if (auto whileOp = - dyn_cast_or_null(op->getParentOp())) { + } else if (auto whileOp = dyn_cast(op->getParentOp())) { auto value = Position::forValue(whileOp.getBefore().getArgument(operandIdx)); auto &valueUsage = solver.getElementFor( *this, value, DFX::Resolution::REQUIRED); getState() ^= valueUsage.getState(); - } else if (auto forOp = - dyn_cast_or_null(op->getParentOp())) { + auto &parentUsage = solver.getElementFor( + *this, Position::forValue(whileOp->getResult(operandIdx)), + DFX::Resolution::REQUIRED); + getState() ^= parentUsage.getState(); + } else if (auto forOp = dyn_cast(op->getParentOp())) { auto value = Position::forValue(forOp.getRegionIterArg(operandIdx)); auto &valueUsage = solver.getElementFor( *this, value, DFX::Resolution::REQUIRED); getState() ^= valueUsage.getState(); - auto &parentUsage = solver.getElementFor( *this, Position::forValue(forOp->getResult(operandIdx)), DFX::Resolution::REQUIRED); getState() ^= parentUsage.getState(); } else { - assert(false && "Unsupported test case"); + assert(false && "unhandled scf yield parent"); } }) - .Case([&](mlir::func::ReturnOp op) { + .Case([&](IREE::Util::ReturnOp op) { auto &operandUsage = solver.getElementFor( *this, Position::forValue(op.getOperand(operandIdx)), DFX::Resolution::REQUIRED); @@ -734,11 +737,14 @@ class ValueResourceUsage : public AbstractResourceUsage { // TODO(benvanik): remove kFavorTransients. bool isSourceInternal = isAssumed(NOT_EXTERNAL); bool isTargetExternal = !resultUsage.isAssumed(NOT_EXTERNAL); - if (kFavorTransients && isSourceInternal && isTargetExternal) { + bool deviceChange = + op.getSourceAffinityAttr() != op.getResultAffinityAttr(); + if ((kFavorTransients || deviceChange) && isSourceInternal && + isTargetExternal) { LLVM_DEBUG({ llvm::dbgs() << "[ValueResourceUsage] skipping back prop of external into " - "internal due to kFavorTransients: "; + "internal due to kFavorTransients/device-change: "; op.print(llvm::dbgs(), solver.getAsmState()); llvm::dbgs() << "\n"; }); diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td index bfcca44e1dcc..4c8fb8d65aa5 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td @@ -504,6 +504,7 @@ def Stream_Channel : TypeDef { // Returns an affinity active for the given operation. // This will recursively walk parent operations until one with the // `stream.affinity` attribute is found. - static AffinityAttr lookup(Operation *op); + static AffinityAttr lookup(Operation *fromOp); + + // Returns an affinity active for the given operation or the fallback + // default if none is specified. + // This will recursively walk parent operations until one with the + // `stream.affinity` attribute is found. + static AffinityAttr lookupOrDefault(Operation *fromOp); // TODO(benvanik): replace with more fine-grained compatibility checks. // "Compatible" can mean a lot of things: are they cache-coherent, are they @@ -115,11 +121,25 @@ def Stream_AffinityAttr : AttrInterface<"AffinityAttr"> { }]; } +//===----------------------------------------------------------------------===// +// IREE::Stream::AffinityTypeInterface +//===----------------------------------------------------------------------===// + +def Stream_AffinityType : TypeInterface<"AffinityTypeInterface"> { + let cppNamespace = "::mlir::iree_compiler::IREE::Stream"; + + let description = [{ + Indicates a type represents a resource that has its affinity tracked. + }]; +} + //===----------------------------------------------------------------------===// // IREE::Stream::AffinityOpInterface //===----------------------------------------------------------------------===// def Stream_AffinityOp : OpInterface<"AffinityOpInterface"> { + let cppNamespace = "::mlir::iree_compiler::IREE::Stream"; + let description = [{ TBD. Used to denote a stream affinity for ops and specify the kind of environment the ops are expected run in. @@ -140,6 +160,19 @@ def Stream_AffinityOp : OpInterface<"AffinityOpInterface"> { return true; }] >, + InterfaceMethod< + /*desc=*/[{ + Returns true if the operands and results should be pinned to the + affinity of the op. This overrides all automatic placement logic. + }], + /*retTy=*/"bool", + /*methodName=*/"pinsValueAffinity", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return false; + }] + >, InterfaceMethod< /*desc=*/[{ Returns the stream affinity for the op, indicating where it should run. @@ -149,7 +182,7 @@ def Stream_AffinityOp : OpInterface<"AffinityOpInterface"> { /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return dyn_cast_or_null($_self->getAttr("affinity")); + return dyn_cast_or_null($_op->getAttr("affinity")); }] >, InterfaceMethod< @@ -161,8 +194,8 @@ def Stream_AffinityOp : OpInterface<"AffinityOpInterface"> { /*args=*/(ins "IREE::Stream::AffinityAttr":$value), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (value) $_self->setAttr("affinity", value); - else $_self->removeAttr("affinity"); + if (value) $_op->setAttr("affinity", value); + else $_op->removeAttr("affinity"); }] >, ]; @@ -173,6 +206,8 @@ def Stream_AffinityOp : OpInterface<"AffinityOpInterface"> { //===----------------------------------------------------------------------===// def Stream_StreamableOp : OpInterface<"StreamableOpInterface"> { + let cppNamespace = "::mlir::iree_compiler::IREE::Stream"; + let description = [{ Interface for ops that can be asynchronous executed in a streaming context. }]; @@ -212,6 +247,8 @@ def Stream_StreamableOp : OpInterface<"StreamableOpInterface"> { //===----------------------------------------------------------------------===// def Stream_AsyncAccessOp : OpInterface<"AsyncAccessOpInterface"> { + let cppNamespace = "::mlir::iree_compiler::IREE::Stream"; + let description = [{ Interface for stream.async.* ops that access subviews of resources. This allows for some basic analysis and is only valid prior to allocation. @@ -240,6 +277,8 @@ def Stream_AsyncAccessOp : OpInterface<"AsyncAccessOpInterface"> { //===----------------------------------------------------------------------===// def Stream_SubviewEffectOp : OpInterface<"SubviewEffectOpInterface"> { + let cppNamespace = "::mlir::iree_compiler::IREE::Stream"; + let description = [{ Interface for ops that operate on subviews of resources used to query the memory effects for subviews on operands. @@ -258,6 +297,8 @@ def Stream_SubviewEffectOp : OpInterface<"SubviewEffectOpInterface"> { //===----------------------------------------------------------------------===// def Stream_TimelineOp : OpInterface<"TimelineOpInterface"> { + let cppNamespace = "::mlir::iree_compiler::IREE::Stream"; + let description = [{ Interface for ops that operate in an ordered sequence defined by timepoints. }]; diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp index 84cf1eb0e11f..9360eea725c9 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp @@ -1231,7 +1231,7 @@ struct TensorConstantToSplat : public OpRewritePattern { constantOp, constantOp.getResult().getType(), splatOp.getResult(), resultSize, resultSize, /*source_affinity=*/constantOp.getAffinityAttr(), - /*result_affinity=*/nullptr); + /*result_affinity=*/constantOp.getAffinityAttr()); return success(); } }; @@ -1452,9 +1452,9 @@ struct ConvertSplatConstantsIntoSplats LogicalResult matchAndRewrite(AsyncConstantOp constantOp, PatternRewriter &rewriter) const override { auto value = dyn_cast(constantOp.getValue()); - if (!value || !value.isSplat()) + if (!value || !value.isSplat()) { return failure(); - + } auto splatElementAttr = llvm::dyn_cast(value).getSplatValue(); auto splatValue = rewriter.create( diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp index 698c7b967e35..358d1fd98ef9 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp @@ -2020,14 +2020,57 @@ LogicalResult AsyncTransferOp::verify() { } IREE::Stream::AffinityAttr AsyncTransferOp::getAffinityAttr() { - return getResultAffinityAttr(); + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging && + resultType.getLifetime() == IREE::Stream::Lifetime::Staging) { + // TODO(multi-device): figure out how to model staging->staging transfers. + return getSourceAffinityAttr(); + } else if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging) { + // If source is staging then the op should execute on the consumer. + return getResultAffinityAttr(); + } else if (resultType.getLifetime() == IREE::Stream::Lifetime::Staging) { + // If result is staging then the op should execute on the producer. + return getSourceAffinityAttr(); + } else { + // Default to result affinity. + return getResultAffinityAttr(); + } } void AsyncTransferOp::setAffinityAttr(IREE::Stream::AffinityAttr value) { - if (value) - setResultAffinityAttr(value); - else - removeResultAffinityAttr(); + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging && + resultType.getLifetime() == IREE::Stream::Lifetime::Staging) { + // TODO(multi-device): figure out how to model staging->staging transfers. + if (value) { + setSourceAffinityAttr(value); + } else { + removeSourceAffinityAttr(); + } + } else if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging) { + // If source is staging then the op should execute on the consumer. + if (value) { + setResultAffinityAttr(value); + } else { + removeResultAffinityAttr(); + } + } else if (resultType.getLifetime() == IREE::Stream::Lifetime::Staging) { + // If result is staging then the op should execute on the producer. + if (value) { + setSourceAffinityAttr(value); + } else { + removeSourceAffinityAttr(); + } + } else { + // Default to result affinity. + if (value) { + setResultAffinityAttr(value); + } else { + removeResultAffinityAttr(); + } + } } void AsyncTransferOp::getAsyncAccessRanges( diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td index dbe5207734a1..871e3bb5254d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td @@ -1520,7 +1520,7 @@ def Stream_TensorFillOp : Stream_Op<"tensor.fill", [ let assemblyFormat = [{ (`on` `(` $affinity^ `)`)? - $value `,` $target `[` $start_indices `for` $lengths `]` `:` + $value `,` $target (`[` $start_indices `for` $lengths^ `]`)? `:` type($value) `->` $target_encoding (`` `{` $target_encoding_dims^ `}`)? diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp index 19c24101004f..82b8609e90e5 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp @@ -335,20 +335,40 @@ void TimepointAttr::print(AsmPrinter &p) const { // #stream.affinity //===----------------------------------------------------------------------===// -AffinityAttr AffinityAttr::lookup(Operation *op) { - auto attrId = StringAttr::get(op->getContext(), "stream.affinity"); - while (op) { - if (auto affinityOp = llvm::dyn_cast(op)) { - auto affinity = affinityOp.getAffinityAttr(); - if (affinity) +// static +AffinityAttr AffinityAttr::lookup(Operation *fromOp) { + auto attrId = StringAttr::get(fromOp->getContext(), "stream.affinity"); + while (fromOp) { + if (auto affinityOp = llvm::dyn_cast(fromOp)) { + if (auto affinity = affinityOp.getAffinityAttr()) { return affinity; + } } - auto attr = op->getAttrOfType(attrId); - if (attr) + if (auto attr = fromOp->getAttrOfType(attrId)) { return attr; - op = op->getParentOp(); + } + fromOp = fromOp->getParentOp(); + } + // No affinity found; let caller decide what to do. + return {}; +} + +// static +AffinityAttr AffinityAttr::lookupOrDefault(Operation *fromOp) { + if (auto affinityAttr = AffinityAttr::lookup(fromOp)) { + return affinityAttr; // found a specified affinity + } + auto attrId = + StringAttr::get(fromOp->getContext(), "stream.affinity.default"); + while (fromOp) { + if (auto affinityAttr = + fromOp->getAttrOfType(attrId)) { + return affinityAttr; + } + fromOp = fromOp->getParentOp(); } - return {}; // No affinity found; let caller decide what to do. + // No affinity or default found; let caller decide what to do. + return {}; } // static diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h index 42b8424e4ab5..d69e226fb868 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h @@ -69,9 +69,7 @@ class AffinityAttr; #include "iree/compiler/Dialect/Stream/IR/StreamAttrInterfaces.h.inc" // IWYU pragma: export -namespace mlir::iree_compiler::IREE::Stream { #include "iree/compiler/Dialect/Stream/IR/StreamTypeInterfaces.h.inc" // IWYU pragma: export -} // namespace mlir::iree_compiler::IREE::Stream // clang-format off: must be included after all LLVM/MLIR headers. #define GET_TYPEDEF_CLASSES @@ -99,8 +97,12 @@ struct AsyncAccessRange { const AsyncAccessRange &rhs); }; +} // namespace mlir::iree_compiler::IREE::Stream + #include "iree/compiler/Dialect/Stream/IR/StreamOpInterfaces.h.inc" // IWYU pragma: export +namespace mlir::iree_compiler::IREE::Stream { + //===----------------------------------------------------------------------===// // custom($scope, $key) //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp new file mode 100644 index 000000000000..62b9db24b69f --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp @@ -0,0 +1,127 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h" +#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::Stream { + +#define GEN_PASS_DEF_ANNOTATEAFFINITIESPASS +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// --iree-stream-annotate-affinities +//===----------------------------------------------------------------------===// + +static void annotateOp(Operation *op, + ArrayRef affinities) { + auto affinityOp = dyn_cast(op); + if (!affinityOp || !affinityOp.requiresAffinity()) { + return; + } + if (!affinities.empty()) { + op->setAttr("stream.affinities", + ArrayAttr::get(op->getContext(), + llvm::to_vector_of(affinities))); + } +} + +static void annotateGlobalOp(IREE::Util::GlobalOpInterface globalOp, + AffinityAnalysis &affinityAnalysis) { + if (!isa(globalOp.getGlobalType())) { + return; + } + SmallVector affinities; + if (affinityAnalysis.tryLookupGlobalAffinity(globalOp, affinities)) { + annotateOp(globalOp, affinities); + } +} + +static void annotateOperandsAndResults(Operation *op, + AffinityAnalysis &affinityAnalysis) { + auto emptyArray = ArrayAttr::get(op->getContext(), {}); + SmallVector operandAttrs; + for (auto operand : op->getOperands()) { + if (isa(operand.getType())) { + SmallVector affinities; + if (affinityAnalysis.tryLookupResourceAffinity(operand, affinities)) { + operandAttrs.push_back(ArrayAttr::get( + op->getContext(), llvm::to_vector_of(affinities))); + } else { + operandAttrs.push_back(emptyArray); + } + } + } + SmallVector resultAttrs; + for (auto result : op->getResults()) { + if (isa(result.getType())) { + SmallVector affinities; + if (affinityAnalysis.tryLookupResourceAffinity(result, affinities)) { + resultAttrs.push_back(ArrayAttr::get( + op->getContext(), llvm::to_vector_of(affinities))); + } else { + resultAttrs.push_back(emptyArray); + } + } + } + if (!operandAttrs.empty()) { + op->setAttr("stream.affinities.operands", + ArrayAttr::get(op->getContext(), operandAttrs)); + } + if (!resultAttrs.empty()) { + op->setAttr("stream.affinities.results", + ArrayAttr::get(op->getContext(), resultAttrs)); + } +} + +static void annotateFuncOp(FunctionOpInterface funcOp, + AffinityAnalysis &affinityAnalysis) { + funcOp.walk([&](Operation *op) { + SmallVector affinities; + if (affinityAnalysis.tryLookupExecutionAffinity(op, affinities)) { + annotateOp(op, affinities); + } + annotateOperandsAndResults(op, affinityAnalysis); + }); +} + +struct AnnotateAffinitiesPass + : public IREE::Stream::impl::AnnotateAffinitiesPassBase< + AnnotateAffinitiesPass> { + void runOnOperation() override { + // Run affinity analysis on the whole module. + AffinityAnalysis affinityAnalysis(getOperation()); + if (failed(affinityAnalysis.run())) { + return signalPassFailure(); + } + + // Annotate all ops with derived affinities. + for (auto &op : getOperation().getOps()) { + if (op.hasTrait()) + continue; + if (auto globalOp = dyn_cast(op)) { + annotateGlobalOp(globalOp, affinityAnalysis); + } else if (auto funcOp = dyn_cast(op)) { + annotateFuncOp(funcOp, affinityAnalysis); + } + } + } +}; + +} // namespace + +} // namespace mlir::iree_compiler::IREE::Stream diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel index 1a1d2e2dd5b1..d2f326c1716a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_compiler_cc_library( name = "Transforms", srcs = [ + "AnnotateAffinities.cpp", "AnnotateDispatchArguments.cpp", "ConvertToStream.cpp", "DumpStatistics.cpp", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt index 9d78c8ed9ef9..5eb3d27b4dd0 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt @@ -16,6 +16,7 @@ iree_cc_library( HDRS "Passes.h" SRCS + "AnnotateAffinities.cpp" "AnnotateDispatchArguments.cpp" "ConvertToStream.cpp" "DumpStatistics.cpp" diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp index b99b79279239..31a5bb6a2243 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp @@ -16,6 +16,12 @@ #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/Passes.h" +static llvm::cl::opt clAnnotateInputAffinities( + "iree-stream-annotate-input-affinities", + llvm::cl::desc("Annotates all tensor/resource affinities on the input to " + "the pipeline for debugging."), + llvm::cl::init(false)); + namespace mlir::iree_compiler::IREE::Stream { using FunctionLikeNest = @@ -68,6 +74,13 @@ void buildStreamTensorPassPipeline(OpPassManager &passManager, // Conversion //---------------------------------------------------------------------------- + // Annotate all ops/resources with the analyzed affinities. + // This should have no behavioral changes during conversion but allows for + // debugging of analysis errors in end-user tooling. + if (clAnnotateInputAffinities) { + passManager.addPass(IREE::Stream::createAnnotateAffinitiesPass()); + } + // Converts from all input dialects into various levels of the stream dialect. // Tensor-like things go to stream.tensor.* ops while lower level buffer-like // things will go to stream.async.* ops. @@ -81,6 +94,9 @@ void buildStreamTensorPassPipeline(OpPassManager &passManager, // Constant/variable optimization //---------------------------------------------------------------------------- + // Run inlining after having baked out affinities. + passManager.addPass(mlir::createInlinerPass()); + // Cleanup globals that were created during conversion. addCleanupPatterns(passManager); @@ -96,10 +112,15 @@ void buildStreamTensorPassPipeline(OpPassManager &passManager, // TODO(benvanik): annotate all dispatches with preferred executable affinity. // TODO(benvanik): DFA to specify all value affinities and pin dispatches. + // TODO(multi-device): it's really nice to be able to verify here but it + // prevents compiling to stream without devices specified or continuation at + // various phases. It'd be nice to find a way to enable this when the user + // expects it to work and otherwise not. + // // Verify that all ops that may require affinities have them assigned or // available (on a parent scope, etc). This allows subsequent passes to trust // that an affinity lookup will always return a valid affinity. - passManager.addPass(IREE::Stream::createVerifyAffinitiesPass()); + // passManager.addPass(IREE::Stream::createVerifyAffinitiesPass()); } //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td index ca2ec3a5b61d..f5ee39fa16d2 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td @@ -457,6 +457,11 @@ def PackDispatchOperandsPass : // Diagnostics //===----------------------------------------------------------------------===// +def AnnotateAffinitiesPass : + Pass<"iree-stream-annotate-affinities", "mlir::ModuleOp"> { + let summary = "Annotates affinities on all ops for debugging."; +} + def DumpStatisticsPass : Pass<"iree-stream-dump-statistics", "mlir::ModuleOp"> { let summary = "Dumps stream dialect usage information to a file."; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp index 1bec564dcda2..c8510a63d71b 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp @@ -668,7 +668,8 @@ static LogicalResult applyAsyncTransferOp(IREE::Stream::AsyncTransferOp asyncOp, return llvm::cast(value.getType()) .getLifetime() == IREE::Stream::Lifetime::Staging; }; - auto currentAffinityAttr = IREE::Stream::AffinityAttr::lookup(asyncOp); + auto currentAffinityAttr = + IREE::Stream::AffinityAttr::lookupOrDefault(asyncOp); bool transferIn = asyncOp.getSourceAffinityAttr() != currentAffinityAttr || isStaging(asyncOp.getSource()); bool transferOut = asyncOp.getResultAffinityAttr() != currentAffinityAttr || diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp index 042bbb860263..75792444551f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp @@ -27,7 +27,7 @@ static LogicalResult verifyAffinityAssigned(IREE::Stream::AffinityOpInterface op) { if (!op.requiresAffinity()) { return success(); // does not require an affinity - } else if (IREE::Stream::AffinityAttr::lookup(op)) { + } else if (IREE::Stream::AffinityAttr::lookupOrDefault(op)) { return success(); // has an affinity } return op->emitOpError() @@ -55,7 +55,10 @@ struct VerifyAffinitiesPass return WalkResult::interrupt(); } } - return WalkResult::advance(); + return (op->hasTrait() || + op->hasTrait()) + ? WalkResult::skip() + : WalkResult::advance(); }) .wasInterrupted()) return signalPassFailure(); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel index 524a1ce109a6..362d6728a676 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel @@ -16,6 +16,7 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( [ + "annotate_affinities.mlir", "annotate_dispatch_arguments.mlir", "convert_to_stream.mlir", "dump_statistics.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt index 5ea981160d91..fe83ee67863d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt @@ -14,6 +14,7 @@ iree_lit_test_suite( NAME lit SRCS + "annotate_affinities.mlir" "annotate_dispatch_arguments.mlir" "convert_to_stream.mlir" "dump_statistics.mlir" diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_affinities.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_affinities.mlir new file mode 100644 index 000000000000..c3e1f1ed38be --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_affinities.mlir @@ -0,0 +1,1549 @@ +// RUN: iree-opt --split-input-file --iree-stream-annotate-affinities %s | FileCheck %s + +// Tests that we can track affinity through optimization barriers. They're meant +// to block optimization but we really can't do much if we don't track affinity. +// We could change this in the future but tests would be harder to write and +// there's not a lot that can be done with an unassigned resource. + +// CHECK-LABEL: @optimization_barrier_consumer +util.func private @optimization_barrier_consumer() -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: util.optimization_barrier + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_dno = util.optimization_barrier %cst : tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.transfer %cst_dno : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %cst_a : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: @optimization_barrier_producer +util.func private @optimization_barrier_producer() -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: util.optimization_barrier + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a_dno = util.optimization_barrier %cst_a : tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %cst_a_dno : tensor<1xi32> +} + +// ----- + +// Tests that constant-like ops get placed with their consumer(s). +// We want to replicate constants where they are consumed instead of performing +// transfers at runtime to move them around and by placing with consumers we +// can know when we need to do that early on. + +// CHECK-LABEL: @constant_op +util.func private @constant_op() -> (tensor<1xi32>, tensor<1xi32>) { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %cst_b = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]] + util.return %cst_a, %cst_b : tensor<1xi32>, tensor<1xi32> +} + +// ----- + +// Tests that splats (not constant-like but no consumed values) are placed with +// their consumer(s). These are always best to rematerialize where they are +// consumed to avoid allocating/transfering a bunch of repeated values. + +// CHECK-LABEL: @splat_op +util.func private @splat_op() -> tensor<1xi32> { + %splat_value = arith.constant 123 : i32 + // CHECK: flow.tensor.splat + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %splat = flow.tensor.splat %splat_value : tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %splat_a = flow.tensor.transfer %splat : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %splat_a : tensor<1xi32> +} + +// ----- + +// Tests that imported tensor placement is inherited. +// Frontends can use this to declare where they expect their arguments to +// be living at the time the functions are invoked. Imports do not perform +// transfers so we must use whatever is declared. + +// CHECK-LABEL: @imported_tensor +util.func public @imported_tensor(%buffer_view: !hal.buffer_view, %fence: !hal.fence) -> tensor<1xi32> { + // CHECK: hal.tensor.import + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %tensor = hal.tensor.import on(#hal.device.promise<@dev_a>) wait(%fence) => %buffer_view "input" : !hal.buffer_view -> tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %tensor : tensor<1xi32> +} + +// ----- + +// Tests that consumer-placed ops exported to buffers are properly placed. +// Frontends can use this to explicitly define where exported tensors must live. +// With consumer-placed ops like constants or splats we place them directly on +// the export target. + +// CHECK-LABEL: @exported_constant +util.func public @exported_constant(%fence: !hal.fence) -> !hal.buffer_view { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: hal.tensor.barrier + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_ready = hal.tensor.barrier join(%cst : tensor<1xi32>) => %fence : !hal.fence + // CHECK: hal.tensor.export + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + %buffer_view = hal.tensor.export on(#hal.device.promise<@dev_a>) %cst_ready "output" : tensor<1xi32> -> !hal.buffer_view + util.return %buffer_view : !hal.buffer_view +} + +// ----- + +// Tests that producer-placed ops exported to buffers get the appropriate +// affinity on both devices. Frontends can use this to explicitly define where +// exported tensors must live. Transfers may need to be inserted in order to +// respect the required affinities. Note here that the operand to the export +// is on @dev_a instead of the requested @dev_b. + +// CHECK-LABEL: @exported_producer +util.func public @exported_producer(%fence: !hal.fence) -> !hal.buffer_view { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: flow.tensor.clone + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %clone_a = flow.tensor.clone %cst_a : tensor<1xi32> + // CHECK: hal.tensor.barrier + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %clone_ready_a = hal.tensor.barrier join(%clone_a : tensor<1xi32>) => %fence : !hal.fence + // CHECK: hal.tensor.export + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + %buffer_view = hal.tensor.export on(#hal.device.promise<@dev_b>) %clone_ready_a "output" : tensor<1xi32> -> !hal.buffer_view + // CHECK: util.return + util.return %buffer_view : !hal.buffer_view +} + +// ----- + +// Test in-place aliased storage for results. +// Frontends require that the storage be placed as indicated even if that means +// introducing transfers such that the operation is not in-place. + +// CHECK-LABEL: @aliased_storage +util.func public @aliased_storage(%view: !hal.buffer_view, %storage: !hal.buffer, %fence: !hal.fence) { + // CHECK: hal.tensor.import + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %arg_a = hal.tensor.import on(#hal.device.promise<@dev_a>) %view : !hal.buffer_view -> tensor<4xi32> + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %ret_b = flow.dispatch @dispatch(%arg_a) : (tensor<4xi32>) -> tensor<4xi32> + // CHECK: hal.tensor.alias + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %alias_b = hal.tensor.alias on(#hal.device.promise<@dev_b>) %ret_b : tensor<4xi32> to %storage : !hal.buffer + // CHECK: hal.tensor.barrier + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + hal.tensor.barrier join(%alias_b : tensor<4xi32>) => %fence : !hal.fence + util.return +} + +// ----- + +// Tests aliased storage through tied dispatches. + +// CHECK-LABEL: @tied_aliased_storage +util.func public @tied_aliased_storage(%view: !hal.buffer_view, %storage: !hal.buffer, %fence: !hal.fence) { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst = flow.tensor.constant dense<123> : tensor<4xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.transfer %cst : tensor<4xi32> to #hal.device.promise<@dev_a> + // CHECK: flow.dispatch @dispatch0 + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %t0 = flow.dispatch @dispatch0(%cst) : (tensor<4xi32>) -> tensor<4xi32> + // CHECK: flow.dispatch @dispatch1 + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %t1 = flow.dispatch @dispatch1(%t0) : (tensor<4xi32>) -> %t0 + // CHECK: hal.tensor.alias + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %alias = hal.tensor.alias on(#hal.device.promise<@dev_b>) %t1 : tensor<4xi32> to %storage : !hal.buffer + // CHECK: hal.tensor.barrier + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + hal.tensor.barrier join(%alias : tensor<4xi32>) => %fence : !hal.fence + util.return +} + +// ----- + +// Tests that consumer-placed ops that pass through tied ops get attributed to +// a single consumer. + +// CHECK-LABEL: @tied_constant +util.func private @tied_constant() -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: flow.dispatch @a + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %tied = flow.dispatch @a(%cst) : (tensor<1xi32>) -> %cst + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %tied_a = flow.tensor.transfer %tied : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %tied_a : tensor<1xi32> +} + +// ----- + +// Tests that consumer-placed ops that pass through tied ops get attributed to +// transitive consumers. This is not ideal but allows the application of +// replication policies. + +// CHECK-LABEL: @tied_constant_multi_consumer +util.func private @tied_constant_multi_consumer() -> (tensor<1xi32>, tensor<1xi32>) { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: flow.dispatch @a + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %tied_0 = flow.dispatch @a(%cst) : (tensor<1xi32>) -> %cst + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %tied_0_a = flow.tensor.transfer %tied_0 : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: flow.dispatch @b + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %tied_1 = flow.dispatch @b(%cst) : (tensor<1xi32>) -> %cst + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %tied_1_b = flow.tensor.transfer %tied_1 : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]] + util.return %tied_0_a, %tied_1_b : tensor<1xi32>, tensor<1xi32> +} + +// ----- + +// Tests the proper transfer of consumer-placed values prior to multiple tied +// uses don't pollute the execution affinity of ops after transfers. Note that +// the constant will still have multiple affinities to allow for policies that +// replicate the constant. + +// CHECK-LABEL: @tied_transfer_constant_multi_consumer +util.func private @tied_transfer_constant_multi_consumer() -> (tensor<1xi32>, tensor<1xi32>) { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: flow.dispatch @a + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %tied_0 = flow.dispatch @a(%cst_a) : (tensor<1xi32>) -> %cst_a + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %tied_0_a = flow.tensor.transfer %tied_0 : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %cst_b = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: flow.dispatch @b + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %tied_1 = flow.dispatch @b(%cst_b) : (tensor<1xi32>) -> %cst_b + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %tied_1_b = flow.tensor.transfer %tied_1 : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]] + util.return %tied_0_a, %tied_1_b : tensor<1xi32>, tensor<1xi32> +} + +// ----- + +// Tests that implicitly placed consumers use their transfer execution affinity. + +// CHECK-LABEL: @transfer_execution_affinity +util.func private @transfer_execution_affinity() -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %cst_b = flow.tensor.transfer %cst_a : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %dispatch_b = flow.dispatch @dispatch(%cst_b) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + util.return %dispatch_b : tensor<1xi32> +} + +// ----- + +// Tests that explicitly placed consumers use their explicit execution affinity. + +// CHECK-LABEL: @explicit_execution_affinity +util.func private @explicit_execution_affinity() -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %dispatch_b = flow.dispatch @dispatch(%cst_a) {stream.affinity = #hal.device.promise<@dev_b>} : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + util.return %dispatch_b : tensor<1xi32> +} + +// ----- + +// Tests that consumers of operands with multiple affinities inherit those +// affinities for execution. This allows policies to determine where they want +// to execute out of the resources they may be consuming. + +// CHECK-LABEL: @consume_multi_affinities +util.func private @consume_multi_affinities() -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %cst_b = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_b>} dense<456> : tensor<1xi32> + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %dispatch_ab = flow.dispatch @dispatch(%cst_a, %cst_b) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + util.return %dispatch_ab : tensor<1xi32> +} + +// ----- + +// Tests that globals are placed where they are loaded. + +// CHECK: util.global private @consumed_global_a +// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] +util.global private @consumed_global_a : tensor<1xi32> +util.func private @consumer_fn() -> tensor<1xi32> { + // CHECK: util.global.load @consumed_global_a + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %load = util.global.load @consumed_global_a : tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %load_a = flow.tensor.transfer %load : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %load_a : tensor<1xi32> +} + +// ----- + +// Tests that a global loaded from two locations is attributed to both +// affinities. This allows policies to decide whether to replicate the global. + +// CHECK: util.global private @consumed_global_ab +// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>] +util.global private @consumed_global_ab : tensor<1xi32> +util.func private @consumer_fn_a() -> tensor<1xi32> { + // CHECK: util.global.load @consumed_global_ab + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %load = util.global.load @consumed_global_ab : tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %load_a = flow.tensor.transfer %load : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %load_a : tensor<1xi32> +} +util.func private @consumer_fn_b() -> tensor<1xi32> { + // CHECK: util.global.load @consumed_global_ab + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %load = util.global.load @consumed_global_ab : tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %load_b = flow.tensor.transfer %load : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + util.return %load_b : tensor<1xi32> +} + +// ----- + +// Tests that consumer-placed ops track through global loads. + +// CHECK: util.global private mutable @global_b +// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] +util.global private mutable @global_b : tensor<1xi32> +util.func private @producer_fn() { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: util.global.store + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.global.store %cst_a, @global_b : tensor<1xi32> + util.return +} +util.func private @consumer_fn() -> tensor<1xi32> { + // CHECK: util.global.load + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %load = util.global.load @global_b : tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %load_b = flow.tensor.transfer %load : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + util.return %load_b : tensor<1xi32> +} + +// ----- + +// Tests that globals that are only stored take the fallback placement of +// their producer. This is silly but can arise prior to global optimization +// passes that may elide them. + +// CHECK: util.global private mutable @global_a +// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] +util.global private mutable @global_a : tensor<1xi32> +util.func private @producer_fn() { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: util.global.store + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.global.store %cst_a, @global_a : tensor<1xi32> + util.return +} + +// ----- + +// Tests that global consumers that take on consumed affinity track the global. + +// CHECK: util.global private @global_a +// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] +util.global private @global_a {stream.affinity = #hal.device.promise<@dev_a>} : tensor<1xi32> +// CHECK: util.global private @global_b +// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] +util.global private @global_b {stream.affinity = #hal.device.promise<@dev_b>} : tensor<1xi32> +util.func private @consumer_fn() -> tensor<1xi32> { + // CHECK: util.global.load @global_a + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %load_a = util.global.load @global_a : tensor<1xi32> + // CHECK: util.global.load @global_b + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %load_b = util.global.load @global_b : tensor<1xi32> + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %result_ab = flow.dispatch @dispatch(%load_a, %load_b) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + util.return %result_ab : tensor<1xi32> +} + +// ----- + +// Tests a global update tick that operates on the global from multiple +// affinities. + +// CHECK: util.global private mutable @global_a +// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] +util.global private mutable @global_a {stream.affinity = #hal.device.promise<@dev_a>} = dense<123> : tensor<1xi32> +util.func private @step(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK: util.global.load @global_a + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %load_a = util.global.load @global_a : tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %arg0_b = flow.tensor.transfer %arg0 : tensor<2xi32> to #hal.device.promise<@dev_b> + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>], [#hal.device.promise<@dev_b>]] + %result_b:2 = flow.dispatch @dispatch(%load_a, %arg0_b) {stream.affinity = #hal.device.promise<@dev_b>} : (tensor<1xi32>, tensor<2xi32>) -> (tensor<1xi32>, tensor<2xi32>) + // CHECK: util.global.store + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + util.global.store %result_b#0, @global_a : tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + util.return %result_b#1 : tensor<2xi32> +} + +// ----- + +// Tests that constants passed through selects are placed on the consumer. + +// CHECK-LABEL: @select_constants_consumed +util.func private @select_constants_consumed(%cond: i1) -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_123 = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_456 = flow.tensor.constant dense<456> : tensor<1xi32> + // CHECK: arith.select + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst = arith.select %cond, %cst_123, %cst_456 : tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %cst_a : tensor<1xi32> +} + +// ----- + +// Tests that placed operands passed through selects are tracked on consumers. + +// CHECK-LABEL: @select_constants_placed +util.func private @select_constants_placed(%cond: i1) -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %cst_b = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_b>} dense<456> : tensor<1xi32> + // CHECK: arith.select + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %cst_ab = arith.select %cond, %cst_a, %cst_b : tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + util.return %cst_ab : tensor<1xi32> +} + +// ----- + +// Tests that a callee that does not touch an argument still tracks the +// affinity through it. + +// CHECK-LABEL: @passthrough_caller +util.func private @passthrough_caller() -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: util.call @passthrough_callee + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %result_a = util.call @passthrough_callee(%cst_a) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %result_a : tensor<1xi32> +} +// CHECK: util.func private @passthrough_callee +util.func private @passthrough_callee(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %arg0 : tensor<1xi32> +} + +// ----- + +// Tests that callees that consumer-placed arguments that are passed to callees +// get placed based on callee usage. + +// CHECK-LABEL: @consumer_placement_caller +util.func private @consumer_placement_caller() -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: util.call @consumer_placement_callee + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %result_a = util.call @consumer_placement_callee(%cst) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %result_a : tensor<1xi32> +} +// CHECK: util.func private @consumer_placement_callee +util.func private @consumer_placement_callee(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %arg0_a : tensor<1xi32> +} + +// ----- + +// Tests that multiple potential affinities are propagated across call edges. + +// CHECK-LABEL: @select_caller +util.func private @select_caller(%arg0: tensor<1xi32>, %cond: i1) -> tensor<1xi32> { + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: util.call @select_callee + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %result_ab = util.call @select_callee(%arg0_a, %cond) : (tensor<1xi32>, i1) -> tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + util.return %result_ab : tensor<1xi32> +} +// CHECK: util.func private @select_callee +util.func private @select_callee(%arg0_a: tensor<1xi32>, %cond: i1) -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %cst_b = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_b>} dense<123> : tensor<1xi32> + // CHECK: arith.select + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %select_ab = arith.select %cond, %arg0_a, %cst_b : tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + util.return %select_ab : tensor<1xi32> +} + +// ----- + +// Tests that consumer-placed ops are propagated across call edges. + +// CHECK-LABEL: @consumer_multi_placement_caller +util.func private @consumer_multi_placement_caller() -> (tensor<1xi32>, tensor<1xi32>) { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_c>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: util.call @consumer_multi_placement_callee + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]] + %result_0_c = util.call @consumer_multi_placement_callee(%cst) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %result_0_a = flow.tensor.transfer %result_0_c : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: util.call @consumer_multi_placement_callee + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]] + %result_1_c = util.call @consumer_multi_placement_callee(%cst) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %result_1_b = flow.tensor.transfer %result_1_c : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]] + util.return %result_0_a, %result_1_b : tensor<1xi32>, tensor<1xi32> +} +// CHECK: util.func private @consumer_multi_placement_callee +util.func private @consumer_multi_placement_callee(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]] + %arg0_c = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_c> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]] + util.return %arg0_c : tensor<1xi32> +} + +// ----- + +// Tests that operand/result affinities are tracked across call edges. + +// CHECK-LABEL: @dispatch_fn_a +util.func private @dispatch_fn_a() -> tensor<4xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %0 = flow.tensor.constant dense<123> : tensor<4xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %1 = flow.tensor.transfer %0 : tensor<4xi32> to #hal.device.promise<@dev_a> + // CHECK: flow.dispatch @dispatch_a_0 + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %2 = flow.dispatch @dispatch_a_0(%1) : (tensor<4xi32>) -> tensor<4xi32> + // CHECK: util.call @dispatch_fn_b + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %3 = util.call @dispatch_fn_b(%2) : (tensor<4xi32>) -> tensor<4xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %4 = flow.tensor.transfer %3 : tensor<4xi32> to #hal.device.promise<@dev_a> + // CHECK: flow.dispatch @dispatch_a_1 + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %5 = flow.dispatch @dispatch_a_1(%4) : (tensor<4xi32>) -> tensor<4xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %5 : tensor<4xi32> +} +// CHECK: util.func private @dispatch_fn_b +util.func private @dispatch_fn_b(%arg0: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %0 = flow.tensor.transfer %arg0 : tensor<4xi32> to #hal.device.promise<@dev_b> + // CHECK: flow.dispatch @dispatch_b + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %1 = flow.dispatch @dispatch_b(%0) : (tensor<4xi32>) -> tensor<4xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + util.return %1 : tensor<4xi32> +} + +// ----- + +// Tests a realistic call graph with explicit transfers. + +// CHECK-LABEL: @dispatch_fn_a +util.func private @dispatch_fn_a() -> tensor<4xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %0 = flow.tensor.constant dense<123> : tensor<4xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %1 = flow.tensor.transfer %0 : tensor<4xi32> to #hal.device.promise<@dev_a> + // CHECK: util.call @dispatch_fn_b + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %2 = util.call @dispatch_fn_b(%1) : (tensor<4xi32>) -> tensor<4xi32> + // CHECK: util.call @dispatch_fn_c + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]] + %3 = util.call @dispatch_fn_c(%1) : (tensor<4xi32>) -> tensor<4xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %4 = flow.tensor.transfer %2 : tensor<4xi32> to #hal.device.promise<@dev_a> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %5 = flow.tensor.transfer %3 : tensor<4xi32> to #hal.device.promise<@dev_a> + // CHECK: flow.dispatch @dispatch_a + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %6 = flow.dispatch @dispatch_a(%4, %5) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %5 : tensor<4xi32> +} +// CHECK: util.func private @dispatch_fn_b +util.func private @dispatch_fn_b(%arg0: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %0 = flow.tensor.transfer %arg0 : tensor<4xi32> to #hal.device.promise<@dev_b> + // CHECK: flow.dispatch @dispatch_b + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %1 = flow.dispatch @dispatch_b(%0) : (tensor<4xi32>) -> tensor<4xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + util.return %1 : tensor<4xi32> +} +// CHECK: util.func private @dispatch_fn_c +util.func private @dispatch_fn_c(%arg0: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]] + %0 = flow.tensor.transfer %arg0 : tensor<4xi32> to #hal.device.promise<@dev_c> + // CHECK: flow.dispatch @dispatch_c + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_c>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]] + %1 = flow.dispatch @dispatch_c(%0) : (tensor<4xi32>) -> tensor<4xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]] + util.return %1 : tensor<4xi32> +} + +// ----- + +// Tests that consumer-placed ops are tracked across branch edges. + +// CHECK-LABEL: @cfg_branch_constant_consumed +util.func private @cfg_branch_constant_consumed() -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: cf.br ^bb1 + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + cf.br ^bb1(%cst : tensor<1xi32>) +^bb1(%bb1_arg0: tensor<1xi32>): + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.transfer %bb1_arg0 : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %cst_a : tensor<1xi32> +} + +// ----- + +// Tests that producer-placed ops are tracked across branch edges. + +// CHECK-LABEL: @cfg_branch_dispatch_produced +util.func private @cfg_branch_dispatch_produced() -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: cf.br ^bb1 + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + cf.br ^bb1(%cst_a : tensor<1xi32>) +^bb1(%bb1_arg0: tensor<1xi32>): + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %dispatch_a = flow.dispatch @dispatch(%bb1_arg0) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %dispatch_a : tensor<1xi32> +} + +// ----- + +// Tests that back edges on loops track affinity changes. + +// CHECK-LABEL: @cfg_loop_back_edge +util.func private @cfg_loop_back_edge() -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: cf.br ^bb1 + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + cf.br ^bb1(%cst_a : tensor<1xi32>) +^bb1(%bb1_arg0: tensor<1xi32>): + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %bb1_arg0_b = flow.tensor.transfer %bb1_arg0 : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: util.call @step + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + %cond = util.call @step(%bb1_arg0_b) : (tensor<1xi32>) -> i1 + // CHECK: cf.cond_br + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]] + cf.cond_br %cond, ^bb1(%bb1_arg0 : tensor<1xi32>), ^bb2(%bb1_arg0_b : tensor<1xi32>) +^bb2(%bb2_arg0: tensor<1xi32>): + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]] + %bb2_arg0_c = flow.tensor.transfer %bb2_arg0 : tensor<1xi32> to #hal.device.promise<@dev_c> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]] + util.return %bb2_arg0_c : tensor<1xi32> +} +util.func private @step(tensor<1xi32>) -> i1 + +// ----- + +// Tests that conditional branches acting as selects propagate both affinities. + +// CHECK-LABEL: @cfg_cond_branch_select +util.func private @cfg_cond_branch_select(%cond: i1) -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %cst_b = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_b>} dense<456> : tensor<1xi32> + // CHECK: cf.cond_br + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]] + cf.cond_br %cond, ^bb1(%cst_a : tensor<1xi32>), ^bb1(%cst_b : tensor<1xi32>) +^bb1(%bb1_arg0: tensor<1xi32>): + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + util.return %bb1_arg0 : tensor<1xi32> +} + +// ----- + +// Tests that consumer-placed ops through conditional branches acting as selects +// get placed on all targets. + +// CHECK-LABEL: @cfg_cond_branch_select_consumer +util.func private @cfg_cond_branch_select_consumer(%cond: i1) -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: cf.cond_br + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>], [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + cf.cond_br %cond, ^bb1(%cst : tensor<1xi32>), ^bb2(%cst : tensor<1xi32>) +^bb1(%bb1_arg0: tensor<1xi32>): + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.transfer %bb1_arg0 : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %cst_a : tensor<1xi32> +^bb2(%bb2_arg0: tensor<1xi32>): + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %cst_b = flow.tensor.transfer %bb2_arg0 : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + util.return %cst_b : tensor<1xi32> +} + +// ----- + +// Tests scf.if capturing consumer-placed ops tracks the affinity into nested +// regions. + +// CHECK-LABEL: @scf_if_capture_consumer +util.func private @scf_if_capture_consumer(%cond: i1) -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: scf.if + %cst_ab = scf.if %cond -> tensor<1xi32> { + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + scf.yield %cst_a : tensor<1xi32> + // CHECK: else + } else { + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %cst_b = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + scf.yield %cst_b : tensor<1xi32> + // CHECK{LITERAL}: } {stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + } + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + util.return %cst_ab : tensor<1xi32> +} + +// ----- + +// Tests scf.if capturing explicitly placed ops tracks the affinity of their +// produced results into consumers. + +// CHECK-LABEL: @scf_if_capture_producer +util.func private @scf_if_capture_producer(%cond: i1) -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: scf.if + %cst_bc = scf.if %cond -> tensor<1xi32> { + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %cst_b = flow.tensor.transfer %cst_a : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + scf.yield %cst_b : tensor<1xi32> + // CHECK: else + } else { + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]] + %cst_c = flow.tensor.transfer %cst_a : tensor<1xi32> to #hal.device.promise<@dev_c> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]] + scf.yield %cst_c : tensor<1xi32> + // CHECK{LITERAL}: } {stream.affinities.results = [[#hal.device.promise<@dev_b>, #hal.device.promise<@dev_c>]] + } + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>, #hal.device.promise<@dev_c>]] + util.return %cst_bc : tensor<1xi32> +} + +// ----- + +// Tests scf.if returning unassigned consumer-placed operations has the affinity +// tracked across scf.yields and assigned based on the consumer. + +// CHECK-LABEL: @scf_if_consumer_yield +util.func private @scf_if_consumer_yield(%cond: i1) -> tensor<1xi32> { + // CHECK: scf.if + %cst = scf.if %cond -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_0 = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + scf.yield %cst_0 : tensor<1xi32> + // CHECK: else + } else { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_1 = flow.tensor.constant dense<456> : tensor<1xi32> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + scf.yield %cst_1 : tensor<1xi32> + // CHECK{LITERAL}: } {stream.affinities.results = [[#hal.device.promise<@dev_a>]] + } + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %cst_a : tensor<1xi32> +} + +// ----- + +// Tests that consumer-placed ops get placed based on their use in the body. + +// CHECK-LABEL: @scf_for_consumer_body_transfer +util.func private @scf_for_consumer_body_transfer() -> tensor<1xi32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: scf.for + %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg0 = %cst) -> tensor<1xi32> { + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %t = flow.dispatch @dispatch(%arg0_a) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + scf.yield %t : tensor<1xi32> + // CHECK{LITERAL}: } {stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + } + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %for : tensor<1xi32> +} + +// ----- + +// Tests that scf.for ops with transfers/explicit affinities on the edges get +// the + +// CHECK-LABEL: @scf_for_boundary_transfer +util.func private @scf_for_boundary_transfer() -> (tensor<1xi32>, tensor<1xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: scf.for + %for:2 = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg0 = %cst, %arg1 = %cst) -> (tensor<1xi32>, tensor<1xi32>) { + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %t = flow.dispatch @dispatch(%arg0_a) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + scf.yield %t, %arg1 : tensor<1xi32>, tensor<1xi32> + // CHECK{LITERAL}: } {stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>], [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + } + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %for_0_b = flow.tensor.transfer %for#0 : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %for_1_b = flow.tensor.transfer %for#1 : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>], [#hal.device.promise<@dev_b>]] + util.return %for_0_b, %for_1_b : tensor<1xi32>, tensor<1xi32> +} + +// ----- + +// Tests that transfers track through iter_args. + +// CHECK-LABEL: @scf_for_body_transfer +util.func private @scf_for_body_transfer() -> tensor<1xi32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: scf.for + %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg0 = %cst_a) -> tensor<1xi32> { + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %arg0_b = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %t = flow.dispatch @dispatch(%arg0_b) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + scf.yield %t : tensor<1xi32> + // CHECK{LITERAL}: } {stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + } + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]] + %for_c = flow.tensor.transfer %for : tensor<1xi32> to #hal.device.promise<@dev_c> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]] + util.return %for_c : tensor<1xi32> +} + +// ----- + +// Tests that placed values track through iter_args to consumers in scf.for +// bodies. + +// CHECK-LABEL: @scf_for_capture_producer +util.func private @scf_for_capture_producer() -> tensor<1xi32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: scf.for + %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg0 = %cst_a) -> tensor<1xi32> { + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %t = flow.dispatch @dispatch(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + scf.yield %t : tensor<1xi32> + // CHECK{LITERAL}: } {stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + } + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %for : tensor<1xi32> +} + +// ----- + +// Tests that consumer-placed ops get placed based on their use in the body. + +// CHECK-LABEL: @scf_while_consumer_body_transfer +util.func private @scf_while_consumer_body_transfer() -> tensor<1xi32> { + %c0 = arith.constant 0 : index + %c2_i32 = arith.constant 2 : i32 + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: scf.while + %while = scf.while(%arg0 = %cst) : (tensor<1xi32>) -> tensor<1xi32> { + // CHECK: flow.tensor.load + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + %cond_i32 = flow.tensor.load %arg0[%c0] : tensor<1xi32> + %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32 + // CHECK: scf.condition + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + scf.condition(%cond) %arg0 : tensor<1xi32> + } do { + ^bb0(%arg0: tensor<1xi32>): + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %t = flow.dispatch @dispatch(%arg0_a) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + scf.yield %t : tensor<1xi32> + // CHECK: } attributes { + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + } + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %while : tensor<1xi32> +} + +// ----- + +// Tests that consumer-placed ops get placed based on their use as the result +// of an scf.while body. + +// CHECK-LABEL: @scf_while_consumer_result_transfer +util.func private @scf_while_consumer_result_transfer() -> tensor<1xi32> { + %c0 = arith.constant 0 : index + %c2_i32 = arith.constant 2 : i32 + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst = flow.tensor.constant dense<123> : tensor<1xi32> + // CHECK: scf.while + %while = scf.while(%arg0 = %cst) : (tensor<1xi32>) -> tensor<1xi32> { + // CHECK: flow.tensor.load + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + %cond_i32 = flow.tensor.load %arg0[%c0] : tensor<1xi32> + %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32 + // CHECK: scf.condition + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + scf.condition(%cond) %arg0 : tensor<1xi32> + } do { + ^bb0(%arg0: tensor<1xi32>): + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %t = flow.dispatch @dispatch(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + scf.yield %t : tensor<1xi32> + // CHECK: } attributes { + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + } + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %while_a = flow.tensor.transfer %while : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %while_a : tensor<1xi32> +} + +// ----- + +// Tests that transfers track through scf.while bodies. + +// CHECK-LABEL: @scf_while_body_transfer +util.func private @scf_while_body_transfer() -> tensor<1xi32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2_i32 = arith.constant 2 : i32 + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: scf.while + %while = scf.while(%arg0 = %cst_a) : (tensor<1xi32>) -> tensor<1xi32> { + // CHECK: flow.tensor.load + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + %cond_i32 = flow.tensor.load %arg0[%c0] : tensor<1xi32> + %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32 + // CHECK: scf.condition + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + scf.condition(%cond) %arg0 : tensor<1xi32> + } do { + ^bb0(%arg0: tensor<1xi32>): + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %arg0_b = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %t = flow.dispatch @dispatch(%arg0_b) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + scf.yield %t : tensor<1xi32> + // CHECK: } attributes { + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + } + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]] + %while_c = flow.tensor.transfer %while : tensor<1xi32> to #hal.device.promise<@dev_c> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]] + util.return %while_c : tensor<1xi32> +} + +// ----- + +// Tests that placed values track through to consumers in scf.while conditions. + +// CHECK-LABEL: @scf_while_capture_producer_condition +util.func private @scf_while_capture_producer_condition() -> tensor<1xi32> { + %c0 = arith.constant 0 : index + %c2_i32 = arith.constant 2 : i32 + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: scf.while + %while = scf.while(%arg0 = %cst_a) : (tensor<1xi32>) -> tensor<1xi32> { + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: flow.tensor.load + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + %cond_i32 = flow.tensor.load %arg0_a[%c0] : tensor<1xi32> + %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32 + // CHECK: scf.condition + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + scf.condition(%cond) %arg0 : tensor<1xi32> + } do { + ^bb0(%arg0: tensor<1xi32>): + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %t = flow.dispatch @dispatch(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + scf.yield %t : tensor<1xi32> + // CHECK: } attributes { + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + } + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %while : tensor<1xi32> +} + +// ----- + +// Tests that placed values track through to consumers in scf.while bodies. + +// CHECK-LABEL: @scf_while_capture_producer_body +util.func private @scf_while_capture_producer_body() -> tensor<1xi32> { + %c0 = arith.constant 0 : index + %c2_i32 = arith.constant 2 : i32 + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32> + // CHECK: scf.while + %while = scf.while(%arg0 = %cst_a) : (tensor<1xi32>) -> tensor<1xi32> { + // CHECK: flow.tensor.load + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + %cond_i32 = flow.tensor.load %arg0[%c0] : tensor<1xi32> + %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32 + // CHECK: scf.condition + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + scf.condition(%cond) %arg0 : tensor<1xi32> + } do { + ^bb0(%arg0: tensor<1xi32>): + // CHECK: flow.dispatch @dispatch + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %t = flow.dispatch @dispatch(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: scf.yield + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + scf.yield %t : tensor<1xi32> + // CHECK: } attributes { + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + } + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %while : tensor<1xi32> +} + +// ----- + +// Tests a realistic program with ABI ops. + +// CHECK-LABEL: @simple_program +util.func public @simple_program(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view { + // CHECK: hal.tensor.import + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %0 = hal.tensor.import on(#hal.device.promise<@dev_a>) wait(%arg1) => %arg0 "input0" : !hal.buffer_view -> tensor<1xi32> + // CHECK: util.call @_simple_program + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %1 = util.call @_simple_program(%0) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %2 = flow.tensor.transfer %1 : tensor<1xi32> to #hal.device.promise<@dev_a> + // CHECK: hal.tensor.barrier + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %3 = hal.tensor.barrier join(%2 : tensor<1xi32>) => %arg2 : !hal.fence + // CHECK: hal.tensor.export + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + %4 = hal.tensor.export on(#hal.device.promise<@dev_a>) %3 "output0" : tensor<1xi32> -> !hal.buffer_view + util.return %4 : !hal.buffer_view +} +// CHECK: util.func private @_simple_program +util.func private @_simple_program(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: util.call @dispatch_a + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %0 = util.call @dispatch_a(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: flow.tensor.transfer + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %1 = flow.tensor.transfer %0 : tensor<1xi32> to #hal.device.promise<@dev_b> + // CHECK: util.call @dispatch_b + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %2 = util.call @dispatch_b(%1) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + util.return %2 : tensor<1xi32> +} +// CHECK: util.func private @dispatch_a +util.func private @dispatch_a(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %cst = flow.tensor.constant dense<[1]> : tensor<1xi32> + // CHECK: flow.dispatch @dispatch_a + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]] + %0 = flow.dispatch @dispatch_a(%arg0, %cst) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]] + util.return %0 : tensor<1xi32> +} +// CHECK: util.func private @dispatch_b +util.func private @dispatch_b(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: flow.tensor.constant + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %cst = flow.tensor.constant dense<[2]> : tensor<1xi32> + // CHECK: flow.dispatch @dispatch_b + // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>] + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>], [#hal.device.promise<@dev_b>]] + // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]] + %0 = flow.dispatch @dispatch_b(%arg0, %cst) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: util.return + // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]] + util.return %0 : tensor<1xi32> +} diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp index e014588b055f..bb46c833eb17 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp @@ -679,7 +679,8 @@ TraversalResult Explorer::walkOutgoingBranchOperandArguments( // traversal algorithm separated from the policy here. This would let us // reuse the traversal for other kinds of walks that are more specific (like // only getting the ops or values instead of both, etc). -TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn) { +TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn, + TraversalBehavior options) { // Fast-path short-circuit for constants, which are like 25% of all IR. if (value.getDefiningOp() && value.getDefiningOp()->hasTrait()) { @@ -856,15 +857,17 @@ TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn) { // If the op is tied we may need to walk up to the operand the result is // tied to. - if (auto tiedOp = dyn_cast(definingOp)) { - auto tiedOperand = tiedOp.getTiedResultOperand(resultValue); - if (tiedOperand) { - LLVM_DEBUG({ - llvm::dbgs() << " + queuing tied operand "; - tiedOperand.printAsOperand(llvm::dbgs(), asmState); - llvm::dbgs() << "\n"; - }); - worklist.insert(tiedOperand); + if (!bitEnumContains(options, TraversalBehavior::DONT_WALK_TIED_VALUES)) { + if (auto tiedOp = dyn_cast(definingOp)) { + auto tiedOperand = tiedOp.getTiedResultOperand(resultValue); + if (tiedOperand) { + LLVM_DEBUG({ + llvm::dbgs() << " + queuing tied operand "; + tiedOperand.printAsOperand(llvm::dbgs(), asmState); + llvm::dbgs() << "\n"; + }); + worklist.insert(tiedOperand); + } } } @@ -891,7 +894,8 @@ TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn) { return result; } -TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn) { +TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn, + TraversalBehavior options) { LLVM_DEBUG(llvm::dbgs() << "[[ Explorer::walkTransitiveUses ]]\n"); TraversalResult result = TraversalResult::COMPLETE; @@ -1090,15 +1094,17 @@ TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn) { // If the op is tied we may need to walk down to the results the operand // is tied to (multiple results can tie the same operand). - if (auto tiedOp = dyn_cast(ownerOp)) { - for (auto tiedResult : - tiedOp.getOperandTiedResults(use.getOperandNumber())) { - LLVM_DEBUG({ - llvm::dbgs() << " + queuing tied result "; - tiedResult.printAsOperand(llvm::dbgs(), asmState); - llvm::dbgs() << "\n"; - }); - worklist.insert(tiedResult); + if (!bitEnumContains(options, TraversalBehavior::DONT_WALK_TIED_VALUES)) { + if (auto tiedOp = dyn_cast(ownerOp)) { + for (auto tiedResult : + tiedOp.getOperandTiedResults(use.getOperandNumber())) { + LLVM_DEBUG({ + llvm::dbgs() << " + queuing tied result "; + tiedResult.printAsOperand(llvm::dbgs(), asmState); + llvm::dbgs() << "\n"; + }); + worklist.insert(tiedResult); + } } } @@ -1149,14 +1155,18 @@ TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn) { return result; } -TraversalResult Explorer::walkTransitiveUsers(Value value, OperationWalkFn fn) { +TraversalResult Explorer::walkTransitiveUsers(Value value, OperationWalkFn fn, + TraversalBehavior options) { DenseSet visitedOwners; - return walkTransitiveUses(value, [&](OpOperand &use) { - if (visitedOwners.insert(use.getOwner()).second) { - return fn(use.getOwner()); - } - return WalkResult::advance(); - }); + return walkTransitiveUses( + value, + [&](OpOperand &use) { + if (visitedOwners.insert(use.getOwner()).second) { + return fn(use.getOwner()); + } + return WalkResult::advance(); + }, + options); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h index 1e975be96937..35ee12aa822b 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h @@ -37,6 +37,31 @@ enum class TraversalAction { IGNORE, }; +enum class TraversalBehavior : uint32_t { + // When traversing defining ops any tied result will move through its tied + // operand. When traversing uses any tied operand will move through its tied + // results (as many as are tied to the operand). + DEFAULT = 0u, + // Don't traverse through tied operands or results. + DONT_WALK_TIED_VALUES = 1 << 0u, +}; +inline TraversalBehavior operator~(TraversalBehavior value) { + return static_cast(~static_cast(value)); +} +inline TraversalBehavior operator|(TraversalBehavior lhs, + TraversalBehavior rhs) { + return static_cast(static_cast(lhs) | + static_cast(rhs)); +} +inline TraversalBehavior operator&(TraversalBehavior lhs, + TraversalBehavior rhs) { + return static_cast(static_cast(lhs) & + static_cast(rhs)); +} +inline bool bitEnumContains(TraversalBehavior bits, TraversalBehavior bit) { + return (static_cast(bits) & static_cast(bit)) != 0; +} + // Boolean operations on TraversalResult behave as though `INCOMPLETE` is // truthy to allow for |='ing results. enum class TraversalResult { @@ -313,7 +338,9 @@ class Explorer { // Walk %2: [%2 of producer.b] // Walk @some_user::%arg0: [%0 of producer.a] // Walk @some_user::ret0: [%2 of producer.b] - TraversalResult walkDefiningOps(Value value, ResultWalkFn fn); + TraversalResult + walkDefiningOps(Value value, ResultWalkFn fn, + TraversalBehavior options = TraversalBehavior::DEFAULT); // Randomly walks uses of |value| and any transitive alias of |value|. // The uses may come from any part of the program. @@ -334,13 +361,17 @@ class Explorer { // Walk %arg0: [%arg0 of producer.a] // Walk %0: [%0 of call @some_user, %arg0 of producer.b] // Walk %2: [%2 of return, %1 of return] - TraversalResult walkTransitiveUses(Value value, UseWalkFn fn); + TraversalResult + walkTransitiveUses(Value value, UseWalkFn fn, + TraversalBehavior options = TraversalBehavior::DEFAULT); // Randomly walks uses of |value| and any transitive alias of |value| and // returns each owner operation once. As a value may be used multiple times // by a single operation this is equivalent to a walkTransitiveUses with // deduplication on the owner of the use. - TraversalResult walkTransitiveUsers(Value value, OperationWalkFn fn); + TraversalResult + walkTransitiveUsers(Value value, OperationWalkFn fn, + TraversalBehavior options = TraversalBehavior::DEFAULT); private: // Maps callee callable region -> call sites. diff --git a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp index b82d59968290..ab1adf05eabc 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp +++ b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp @@ -18,6 +18,35 @@ namespace mlir::iree_compiler { namespace { +template +struct OptionalOpAffinityAttrExternalModel + : public IREE::Stream::AffinityOpInterface::ExternalModel< + OptionalOpAffinityAttrExternalModel, OpT> { + static void add(MLIRContext *context) { + OpT::template attachInterface>( + *context); + } + + // Affinity only required for results that hold resources that + // require placement. + bool requiresAffinity(Operation *op) const { + auto resultType = cast(op).getResult().getType(); + return isa(resultType); + } + + IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const { + return op->getAttrOfType("stream.affinity"); + } + + void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const { + if (value) { + op->setAttr("stream.affinity", value); + } else { + op->removeAttr("stream.affinity"); + } + } +}; + struct FlowTransferTargetAffinityAttrExternalModel : public IREE::Stream::AffinityOpInterface::ExternalModel< FlowTransferTargetAffinityAttrExternalModel, @@ -29,11 +58,11 @@ struct FlowTransferTargetAffinityAttrExternalModel bool requiresAffinity(Operation *op) const { return true; } - IREE::Stream::AffinityAttr getAffinity(Operation *op) const { + IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const { return op->getAttrOfType("target"); } - void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const { + void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const { op->setAttr("target", value); } }; @@ -49,12 +78,14 @@ struct HALTensorAffinityAttrExternalModel bool requiresAffinity(Operation *op) const { return false; } - IREE::Stream::AffinityAttr getAffinity(Operation *op) const { + bool pinsValueAffinity(Operation *op) const { return true; } + + IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const { return op->getAttrOfType("affinity"); } - void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const { - if (value) + void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const { + if (value) { op->setAttr("affinity", value); } else { op->removeAttr("affinity"); @@ -78,12 +109,12 @@ struct GlobalOpAffinityAttrExternalModel return isa(globalType); } - IREE::Stream::AffinityAttr getAffinity(Operation *op) const { + IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const { return op->getAttrOfType("stream.affinity"); } - void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const { - if (value) + void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const { + if (value) { op->setAttr("stream.affinity", value); } else { op->removeAttr("stream.affinity"); @@ -91,7 +122,7 @@ struct GlobalOpAffinityAttrExternalModel } }; -template +template struct AffinityOpAttrExternalModel : public IREE::Stream::AffinityOpInterface::ExternalModel< AffinityOpAttrExternalModel, OpT> { @@ -102,14 +133,14 @@ struct AffinityOpAttrExternalModel // Most structural ops don't require affinities and after placement we don't // use the affinities even if the ops still exist. - bool requiresAffinity(Operation *op) const { return false; } + bool requiresAffinity(Operation *op) const { return kRequiresAffinity; } - IREE::Stream::AffinityAttr getAffinity(Operation *op) const { + IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const { return op->getAttrOfType("stream.affinity"); } - void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const { - if (value) + void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const { + if (value) { op->setAttr("stream.affinity", value); } else { op->removeAttr("stream.affinity"); @@ -117,32 +148,71 @@ struct AffinityOpAttrExternalModel } }; +struct TensorAffinityTypeExternalModel + : public IREE::Stream::AffinityTypeInterface::ExternalModel< + TensorAffinityTypeExternalModel, RankedTensorType> { + static void add(MLIRContext *context) { + RankedTensorType::attachInterface( + *context); + } +}; + } // namespace void registerStreamExternalModels(DialectRegistry ®istry) { - registry.insert(); + registry.addExtension(+[](MLIRContext *context) { + TensorAffinityTypeExternalModel::add(context); + }); + + registry.insert(); registry.addExtension( - +[](MLIRContext *context, IREE::Flow::FlowDialect *dialect) { - FlowTransferTargetAffinityAttrExternalModel::add(context); + +[](MLIRContext *context, arith::ArithDialect *dialect) { + OptionalOpAffinityAttrExternalModel::add(context); }); + registry.insert(); + registry.addExtension(+[](MLIRContext *context, + IREE::Flow::FlowDialect *dialect) { + FlowTransferTargetAffinityAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add( + context); + AffinityOpAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add( + context); + AffinityOpAttrExternalModel::add( + context); + AffinityOpAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add( + context); + AffinityOpAttrExternalModel::add(context); + }); + registry.insert(); registry.addExtension(+[](MLIRContext *context, IREE::HAL::HALDialect *dialect) { HALTensorAffinityAttrExternalModel::add(context); HALTensorAffinityAttrExternalModel::add(context); HALTensorAffinityAttrExternalModel::add(context); - HALTensorAffinityAttrExternalModel::add( - context); }); registry.insert(); - registry.addExtension( - +[](MLIRContext *context, IREE::Util::UtilDialect *dialect) { - GlobalOpAffinityAttrExternalModel::add(context); - AffinityOpAttrExternalModel::add(context); - AffinityOpAttrExternalModel::add(context); - }); + registry.addExtension(+[](MLIRContext *context, + IREE::Util::UtilDialect *dialect) { + GlobalOpAffinityAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add(context); + AffinityOpAttrExternalModel::add(context); + }); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index 81e204e18fdb..13bc6bf4a6e6 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -57,6 +57,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/HAL/Analysis", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", + "//compiler/src/iree/compiler/Dialect/Stream/Analysis", "//compiler/src/iree/compiler/Dialect/Stream/IR", "//compiler/src/iree/compiler/Dialect/Util/IR", "@llvm-project//llvm:Support", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index 723dacc1595b..3764d492fa61 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -68,6 +68,7 @@ iree_cc_library( iree::compiler::Dialect::HAL::Analysis iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::LinalgExt::IR + iree::compiler::Dialect::Stream::Analysis iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::Util::IR PUBLIC diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp index a5e1a86f89c7..ba415b3fb656 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp @@ -12,6 +12,7 @@ #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h" #include "iree/compiler/Preprocessing/Common/Passes.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -25,7 +26,7 @@ namespace mlir::iree_compiler::Preprocessing { -#define GEN_PASS_DEF_PADTOINTRINSICS +#define GEN_PASS_DEF_PADTOINTRINSICSPASS #include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export namespace { @@ -533,7 +534,7 @@ static void padContractionLikeOp( } struct PadToIntrinsicsPass - : public impl::PadToIntrinsicsBase { + : public impl::PadToIntrinsicsPassBase { using Base::Base; void runOnOperation() override; }; @@ -544,10 +545,15 @@ void PadToIntrinsicsPass::runOnOperation() { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - auto funcOp = getOperation(); - IREE::HAL::DeviceAnalysis deviceAnalysis(funcOp->getParentOp()); - if (failed(deviceAnalysis.run())) + auto moduleOp = getOperation(); + IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp); + if (failed(affinityAnalysis.run())) { return signalPassFailure(); + } + IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp); + if (failed(deviceAnalysis.run())) { + return signalPassFailure(); + } bool padConvOps = padTargetType == PadTargetType::ConvOp || padTargetType == PadTargetType::All; @@ -555,37 +561,46 @@ void PadToIntrinsicsPass::runOnOperation() { padTargetType == PadTargetType::All; SmallVector targetConvOps; SmallVector targetContractOps; - funcOp.walk([&](linalg::LinalgOp linalgOp) { - if (isa(linalgOp.getOperation()) && padConvOps) { - // Add convOps into worklist. - targetConvOps.push_back(linalgOp); - } else if (isa(linalgOp.getOperation()) && - padContractionOps) { - // Add named contractionOps into worklist. - targetContractOps.push_back(linalgOp); - } else if (isa(linalgOp.getOperation()) && - linalg::isaContractionOpInterface(linalgOp) && - padContractionOps) { - // Add named generic contractionOps into worklist. - targetContractOps.push_back(linalgOp); - } - }); + for (auto funcOp : moduleOp.getOps()) { + funcOp.walk([&](linalg::LinalgOp linalgOp) { + if (isa(linalgOp.getOperation()) && + padConvOps) { + targetConvOps.push_back(linalgOp); + } else if (isa(linalgOp.getOperation()) && + padContractionOps) { + targetContractOps.push_back(linalgOp); + } else if (isa(linalgOp.getOperation()) && + linalg::isaContractionOpInterface(linalgOp) && + padContractionOps) { + targetContractOps.push_back(linalgOp); + } + }); + } // Iterate through and pad ops in the worklists. + auto getRequiredExecutableTargetAttrs = [&](Operation *op) { + SetVector executableTargetAttrs; + SmallVector affinityAttrs; + if (affinityAnalysis.tryInferExecutionAffinity(op, affinityAttrs)) { + for (auto affinityAttr : affinityAttrs) { + deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, op, + executableTargetAttrs); + } + } + return executableTargetAttrs; + }; IRRewriter rewriter(context); for (auto convOp : targetConvOps) { rewriter.setInsertionPoint(convOp); - SetVector executableTargets; - deviceAnalysis.gatherRequiredExecutableTargets(convOp, executableTargets); - padConvOp(rewriter, convOp, executableTargets.getArrayRef()); + auto executableTargetAttrs = getRequiredExecutableTargetAttrs(convOp); + padConvOp(rewriter, convOp, executableTargetAttrs.getArrayRef()); } for (auto contractOp : targetContractOps) { rewriter.setInsertionPoint(contractOp); - SetVector executableTargets; - deviceAnalysis.gatherRequiredExecutableTargets(contractOp, - executableTargets); - padContractionLikeOp(rewriter, contractOp, executableTargets.getArrayRef()); + auto executableTargetAttrs = getRequiredExecutableTargetAttrs(contractOp); + padContractionLikeOp(rewriter, contractOp, + executableTargetAttrs.getArrayRef()); } } diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index edc17057b55b..ca29a52d39bb 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -84,8 +84,8 @@ def MakeSingleDispatchForFunctionPass : ]; } -def PadToIntrinsics : - InterfacePass<"iree-preprocessing-pad-to-intrinsics", "mlir::FunctionOpInterface"> { +def PadToIntrinsicsPass : + Pass<"iree-preprocessing-pad-to-intrinsics", "ModuleOp"> { let summary = "Pad linalg ops such that we can use target's intrinsics."; let dependentDialects = [ "mlir::linalg::LinalgDialect", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir index 5761741f3787..7d9da4586c21 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir @@ -1,6 +1,6 @@ -// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics,canonicalize))" | FileCheck %s -// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv},canonicalize))" | FileCheck %s -check-prefix=CONVOLUTION -// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=contraction},canonicalize))" | FileCheck %s -check-prefix=CONTRACT +// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics,func.func(canonicalize))" | FileCheck %s +// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv},func.func(canonicalize))" | FileCheck %s -check-prefix=CONVOLUTION +// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics{pad-target-type=contraction},func.func(canonicalize))" | FileCheck %s -check-prefix=CONTRACT // CHECK-LABEL: func.func @main0( diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir index f7f832859c0f..ece028330df2 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir @@ -1,6 +1,6 @@ -// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics,canonicalize))" | FileCheck %s -// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv},canonicalize))" | FileCheck %s -check-prefix=CONVOLUTION -// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=contraction},canonicalize))" | FileCheck %s -check-prefix=CONTRACT +// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics,func.func(canonicalize))" | FileCheck %s +// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv},func.func(canonicalize))" | FileCheck %s -check-prefix=CONVOLUTION +// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics{pad-target-type=contraction},func.func(canonicalize))" | FileCheck %s -check-prefix=CONTRACT // CHECK: func.func @matmul_static( // CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xf16>, From 866c0c0db093e7f7635e4c352aa2170e33d1705d Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 28 May 2024 22:30:19 -0700 Subject: [PATCH 22/25] Changing stream conversion to use a value/op affinity analysis. This reworks some of the prior stack to support transfer ops and analysis to determine the placement of ops for execution and resource control. --- .../Torch/InputConversion/FuncConversion.cpp | 89 ++- .../InputConversion/test/func_conversion.mlir | 31 + .../Native/Transforms/WrapEntryPoints.cpp | 121 ++-- .../Transforms/test/wrap_entry_points.mlir | 26 +- .../test/wrap_entry_points_coarse_fences.mlir | 34 + .../compiler/Codegen/Common/CPU/BUILD.bazel | 1 + .../Codegen/Common/CPU/CMakeLists.txt | 1 + .../Common/CPU/CPUMaterializeEncodings.cpp | 74 ++- .../Dialect/HAL/Analysis/BindingLayout.cpp | 9 + .../Dialect/HAL/Analysis/BindingLayout.h | 3 + .../HAL/Conversion/HALToHAL/Patterns.cpp | 90 +++ .../HAL/Conversion/HALToHAL/test/BUILD.bazel | 5 +- .../Conversion/HALToHAL/test/CMakeLists.txt | 1 + .../Conversion/HALToHAL/test/device_ops.mlir | 75 +++ .../HALToVM/ConvertExecutableOps.cpp | 40 +- .../Dialect/HAL/Conversion/HALToVM/Patterns.h | 5 - .../HALToVM/test/executable_ops.mlir | 15 +- .../HAL/Conversion/StreamToHAL/Patterns.cpp | 64 +- .../HAL/Conversion/StreamToHAL/Utils.cpp | 8 +- .../StreamToHAL/test/context_ops.mlir | 61 +- .../iree/compiler/Dialect/HAL/IR/HALAttrs.cpp | 34 +- .../iree/compiler/Dialect/HAL/IR/HALAttrs.td | 4 +- .../compiler/Dialect/HAL/IR/HALOpFolders.cpp | 10 +- .../iree/compiler/Dialect/HAL/IR/HALOps.cpp | 48 +- .../iree/compiler/Dialect/HAL/IR/HALOps.td | 52 +- .../Dialect/HAL/Transforms/ConvertToHAL.cpp | 9 +- .../Transforms/DumpExecutableBenchmarks.cpp | 46 +- .../HAL/Transforms/MaterializeInterfaces.cpp | 12 + .../Transforms/MaterializeTargetDevices.cpp | 54 +- .../Dialect/HAL/Transforms/Passes.cpp | 9 + .../Dialect/HAL/Transforms/VerifyDevices.cpp | 31 +- .../test/assign_target_devices.mlir | 2 +- .../test/materialize_target_devices.mlir | 29 +- .../HAL/Transforms/test/verify_devices.mlir | 17 +- .../Dialect/Stream/Conversion/BUILD.bazel | 1 + .../Dialect/Stream/Conversion/CMakeLists.txt | 1 + .../Conversion/FlowToStream/Patterns.cpp | 585 ++++++++++-------- .../Stream/Conversion/FlowToStream/Patterns.h | 20 +- .../FlowToStream/test/dispatch_ops.mlir | 6 +- .../FlowToStream/test/tensor_ops.mlir | 4 +- .../Conversion/HALToStream/Patterns.cpp | 134 ++-- .../Stream/Conversion/HALToStream/Patterns.h | 20 +- .../Stream/Conversion/PatternUtils.cpp | 101 ++- .../Dialect/Stream/Conversion/PatternUtils.h | 118 +++- .../Conversion/StandardToStream/BUILD.bazel | 2 - .../StandardToStream/CMakeLists.txt | 2 - .../StandardToStream/ConvertConstantOps.cpp | 66 -- .../StandardToStream/ConvertStructuralOps.cpp | 406 ------------ .../Conversion/StandardToStream/Patterns.cpp | 443 ++++++++++++- .../Conversion/StandardToStream/Patterns.h | 8 +- .../Conversion/UtilToStream/Patterns.cpp | 93 ++- .../Stream/Conversion/UtilToStream/Patterns.h | 20 +- .../Stream/Transforms/ConvertToStream.cpp | 236 +++---- .../Stream/Transforms/ElideAsyncCopies.cpp | 3 +- .../Stream/Transforms/EmplaceAllocations.cpp | 24 +- .../materialize_homogeneous_encodings.mlir | 4 +- .../StreamToParams/test/parameter_ops.mlir | 42 +- .../src/iree/compiler/Pipelines/Pipelines.cpp | 24 +- tools/test/compile_pipelines.mlir | 4 +- tools/test/compile_to_continuation.mlir | 50 +- tools/test/compile_to_phase.mlir | 24 +- 61 files changed, 2094 insertions(+), 1457 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/device_ops.mlir delete mode 100644 compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp delete mode 100644 compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp index b8a630d454b3..553dbdeecf5d 100644 --- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp +++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp @@ -124,9 +124,17 @@ enum class TypeDisposition { FENCE, }; +struct BarrierResult { + BlockArgument storage; + Type torchType; + int returnIndex = -1; +}; + struct ConvertedAsyncFunctionInfo { IREE::Util::FuncOp funcOp; SmallVector returnOps; + SmallVector torchArgAttrs; + SmallVector torchResultAttrs; SmallVector torchInputTypes; SmallVector torchResultTypes; SmallVector inputDispositions; @@ -136,7 +144,7 @@ struct ConvertedAsyncFunctionInfo { // Values that must be captured in the coarse barrier. SmallVector barrierInputs; // Meta data per barrier input: storage, torchType, returnIndex (or -1) - SmallVector> barrierResultMeta; + SmallVector barrierResultMeta; LogicalResult postProcess(); LogicalResult convertImmutableTensorArg(BlockArgument argValue, @@ -144,10 +152,25 @@ struct ConvertedAsyncFunctionInfo { LogicalResult convertMutableTensorArg(BlockArgument argValue, Type torchType, OpBuilder &builder); - void addBarrierInput(Value inputTensor, Value storage, Type torchType, + void addBarrierInput(Value inputTensor, BlockArgument storage, Type torchType, int returnIndex) { barrierInputs.push_back(inputTensor); - barrierResultMeta.emplace_back(storage, torchType, returnIndex); + barrierResultMeta.emplace_back(BarrierResult{ + storage, + torchType, + returnIndex, + }); + } + + Attribute getTorchArgAttr(BlockArgument argValue, StringRef attrName) { + return torchArgAttrs.empty() + ? Attribute{} + : torchArgAttrs[argValue.getArgNumber()].get(attrName); + } + Attribute getTorchResultAttr(int returnIndex, StringRef attrName) { + return torchResultAttrs.empty() + ? Attribute{} + : torchResultAttrs[returnIndex].get(attrName); } }; @@ -232,7 +255,8 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() { } if (needsBarrier) { Value source = convertToBuiltinTensor(postambleBuilder, returnValue); - addBarrierInput(source, /*storage=*/Value{}, torchType, returnIndex); + addBarrierInput(source, /*storage=*/BlockArgument{}, torchType, + returnIndex); } break; } @@ -276,15 +300,13 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() { SmallVector aliasedResults; for (auto [barrierInput, meta] : llvm::zip_equal(barrierInputs, barrierResultMeta)) { - Value exportStorage; - Type torchType; - int returnIndex; - std::tie(exportStorage, torchType, returnIndex) = meta; - if (exportStorage) { + if (meta.storage) { // Use the wait fence indicating when the storage is available for // mutation. We need to ensure that no writes are made to the storage // until it indicates it's safe to do so. - auto waitSignalFences = getEnclosingWaitSignalFences(exportStorage); + auto storageAffinityAttr = + getTorchArgAttr(meta.storage, "iree.abi.affinity"); + auto waitSignalFences = getEnclosingWaitSignalFences(meta.storage); assert(waitSignalFences && "async function missing fences"); Value waitFence = waitSignalFences->first; auto barrierInputDims = IREE::Util::buildDynamicDimsForValue( @@ -292,28 +314,30 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() { aliasedResults.push_back( postambleBuilder.create( barrierInput.getLoc(), barrierInput.getType(), barrierInput, - barrierInputDims, exportStorage, waitFence, - /*affinity=*/nullptr)); + barrierInputDims, meta.storage, waitFence, + storageAffinityAttr)); } else { aliasedResults.push_back(barrierInput); } } auto barrierOp = postambleBuilder.create( - funcOp.getLoc(), aliasedResults, coarseSignalFence, - /*affinity=*/nullptr); + funcOp.getLoc(), aliasedResults, coarseSignalFence); for (auto [barrierResult, meta] : llvm::zip_equal(barrierOp.getResults(), barrierResultMeta)) { - Value exportStorage; - Type torchType; - int returnIndex; - std::tie(exportStorage, torchType, returnIndex) = meta; + Attribute exportAffinityAttr; + if (meta.storage) { + exportAffinityAttr = getTorchArgAttr(meta.storage, "iree.abi.affinity"); + } else if (meta.returnIndex >= 0) { + exportAffinityAttr = + getTorchResultAttr(meta.returnIndex, "iree.abi.affinity"); + } Value exportedValue = postambleBuilder.create( funcOp.getLoc(), postambleBuilder.getType(), barrierResult, TypeAttr::get(barrierResult.getType()), /*name=*/nullptr, - /*affinity=*/nullptr); - if (returnIndex >= 0) { - newReturnOperands[returnIndex] = exportedValue; + exportAffinityAttr); + if (meta.returnIndex >= 0) { + newReturnOperands[meta.returnIndex] = exportedValue; } } } @@ -377,14 +401,16 @@ LogicalResult ConvertedAsyncFunctionInfo::convertImmutableTensorArg( << torchType; } + // Propagate explicit affinities to the read. + auto affinityAttr = getTorchArgAttr(argValue, "iree.abi.affinity"); + auto waitSignalFences = getEnclosingWaitSignalFences(argValue); assert(waitSignalFences && "async function missing fences"); Value waitFence = waitSignalFences->first; Value importedTensor = builder.create( loc, builtinTensorType, argValue, TypeAttr::get(builtinTensorType), waitFence, - /*name=*/nullptr, - /*affinity=*/nullptr); + /*name=*/nullptr, affinityAttr); if (builtinTensorType != torchType) { importedTensor = builder.create( loc, torchType, importedTensor); @@ -408,6 +434,9 @@ LogicalResult ConvertedAsyncFunctionInfo::convertMutableTensorArg( .toBuiltinTensor(); } + // Propagate explicit affinities to the read and write. + auto affinityAttr = getTorchArgAttr(argValue, "iree.abi.affinity"); + // There are only a small set of possible users of a mutable tensor. // Handle them by operation here. SmallVector users(argValue.getUsers()); @@ -419,8 +448,7 @@ LogicalResult ConvertedAsyncFunctionInfo::convertMutableTensorArg( loc, builtinTensorType, argValue, /*target_encoding=*/TypeAttr::get(builtinTensorType), /*wait_fence*/ fences->first, - /*name=*/nullptr, - /*affinity=*/nullptr); + /*name=*/nullptr, affinityAttr); rewriter.replaceOpWithNewOp( userOp, copyToVtOp.getResult().getType(), imported); } else if (auto overwriteOp = @@ -444,7 +472,6 @@ void retainFunctionAttributes(Operation *srcOp, IREE::Util::FuncOp destOp) { // Allowlist of function attributes to retain when importing funcs. constexpr const char *kRetainedAttributes[] = { "iree.reflection", - "stream.affinity", }; auto retainedAttributes = ArrayRef( kRetainedAttributes, @@ -476,6 +503,9 @@ void createCoarseFencesSyncWrapper(StringRef syncFunctionName, syncFuncOp.setSymVisibilityAttr(asyncFuncOp.getSymVisibilityAttr()); retainFunctionAttributes(asyncFuncOp, syncFuncOp); syncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr()); + if (auto affinityAttr = asyncFuncOp->getAttr("iree.abi.affinity")) { + syncFuncOp->setAttr("iree.abi.affinity", affinityAttr); + } Block *entryBlock = syncFuncOp.addEntryBlock(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToEnd(entryBlock); @@ -584,6 +614,10 @@ struct FuncConversionPass : public FuncConversionBase { asyncFunctionName.append("$async"); } + // Stash arg/result attrs so they can be referenced during conversion. + torchFunc.getAllArgAttrs(convertedFuncInfo.torchArgAttrs); + torchFunc.getAllResultAttrs(convertedFuncInfo.torchResultAttrs); + // Convert function signature. Type fenceType = rewriter.getType(); FunctionType torchFuncType = torchFunc.getFunctionType(); @@ -644,6 +678,9 @@ struct FuncConversionPass : public FuncConversionBase { asyncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr()); asyncFuncOp->setAttr("iree.abi.model", rewriter.getStringAttr("coarse-fences")); + if (auto affinityAttr = torchFunc->getAttr("iree.abi.affinity")) { + asyncFuncOp->setAttr("iree.abi.affinity", affinityAttr); + } rewriter.inlineRegionBefore( torchFunc.getBody(), asyncFuncOp.getFunctionBody(), asyncFuncOp.end()); diff --git a/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir b/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir index 3e167ad7ba56..3bca01bb6dec 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir +++ b/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir @@ -110,6 +110,37 @@ func.func @main(%arg0: !torch.tensor<[5,4],f32>) -> (!torch.vtensor<[5,4],f32>) } } +// ----- +// Tests the immutable + mutable argument case with explicit affinities. +// CHECK-LABEL: @mutable_input_overwrite_no_return +// CHECK: util.func public @main$async( +// CHECK-SAME: %arg0: !hal.buffer_view, %arg1: !hal.buffer_view, +// CHECK-SAME: %arg2: !hal.fence, %arg3: !hal.fence) -> !hal.buffer_view +// CHECK-DAG: %[[WAIT_ARG0:.+]] = hal.tensor.import on(#hal.device.promise<@dev_a>) wait(%arg2) => %arg0 +// CHECK-DAG: %[[TORCH_ARG0:.+]] = torch_c.from_builtin_tensor %[[WAIT_ARG0]] +// CHECK-DAG: %[[WAIT_ARG1:.+]] = hal.tensor.import on(#hal.device.promise<@dev_b>) wait(%arg2) => %arg1 +// CHECK-DAG: %[[TORCH_ARG1:.+]] = torch_c.from_builtin_tensor %[[WAIT_ARG1]] +// CHECK-DAG: %[[TORCH_RESULT0:.+]] = torch.operator "other_calc"(%[[TORCH_ARG0]]) +// CHECK-DAG: %[[TORCH_RESULT1:.+]] = torch.operator "mutate_inplace"(%[[TORCH_ARG1]]) +// CHECK-DAG: %[[TENSOR_ARG0:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT0]] +// CHECK-DAG: %[[TENSOR_ARG1:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT1]] +// CHECK: %[[EXPORT_ALIAS1:.+]] = hal.tensor.alias on(#hal.device.promise<@dev_b>) wait(%arg2) => %[[TENSOR_ARG1]] : tensor<5x4xf32> to %arg1 : !hal.buffer_view +// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%[[EXPORT_ALIAS1]], %[[TENSOR_ARG0]] : tensor<5x4xf32>, tensor<4x5xi32>) => %arg3 : !hal.fence +// CHECK-DAG: %[[EXPORT_RESULT0:.+]] = hal.tensor.export on(#hal.device.promise<@dev_b>) %[[BARRIER_RESULTS]]#0 +// CHECK-DAG: %[[EXPORT_RESULT1:.+]] = hal.tensor.export on(#hal.device.promise<@dev_a>) %[[BARRIER_RESULTS]]#1 +// CHECK: util.return %[[EXPORT_RESULT1]] +builtin.module @mutable_input_overwrite_no_return_affinities { +func.func @main(%arg0: !torch.vtensor<[4,5],si32> {iree.abi.affinity = #hal.device.promise<@dev_a>}, + %arg1: !torch.tensor<[5,4],f32> {iree.abi.affinity = #hal.device.promise<@dev_b>}) + -> (!torch.vtensor<[4,5],si32> {iree.abi.affinity = #hal.device.promise<@dev_a>}) { + %0 = torch.copy.to_vtensor %arg1 : !torch.vtensor<[5,4],f32> + %1 = torch.operator "mutate_inplace"(%0) : (!torch.vtensor<[5,4],f32>) -> !torch.vtensor<[5,4],f32> + %2 = torch.operator "other_calc"(%arg0) : (!torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32> + torch.overwrite.tensor.contents %1 overwrites %arg1 : !torch.vtensor<[5,4],f32>, !torch.tensor<[5,4],f32> + return %2 : !torch.vtensor<[4,5],si32> +} +} + // ----- // CHECK-LABEL: @retained_attribute_reflection // CHECK: util.func public @main$async( diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp index e91bc54b1f0c..577cf52d67db 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp @@ -43,16 +43,22 @@ static Type mapToABIType(Type type) { return type; } +// Returns true if the given |attr| is a known ABI attribute that is only used +// by this pass. +static bool isABIAttr(NamedAttribute attr) { + return attr.getName() == "iree.abi.affinity" || + attr.getName() == "iree.abi.encoding" || + attr.getName() == "iree.abi.model" || + attr.getName() == "iree.abi.output"; +} + // Removes all ABI attrs handled by this pass from all dictionaries. static void stripABIAttrs(SmallVectorImpl &allAttrs) { for (auto &attrDict : allAttrs) { SmallVector attrs; attrs.reserve(attrDict.size()); for (auto attr : attrDict) { - // TODO(benvanik): faster lookup. - if (attr.getName() != "iree.abi.output" && - attr.getName() != "iree.abi.encoding" && - attr.getName() != "iree.abi.affinity") { + if (!isABIAttr(attr)) { attrs.push_back(attr); } } @@ -60,7 +66,16 @@ static void stripABIAttrs(SmallVectorImpl &allAttrs) { } } +// Removes all ABI attrs from the |op| and its args/results. static void stripABIAttrs(FunctionOpInterface op) { + NamedAttrList attrs; + for (auto attr : op->getAttrs()) { + if (!isABIAttr(attr)) { + attrs.push_back(attr); + } + } + op->setAttrs(attrs); + SmallVector argAttrs; op.getAllArgAttrs(argAttrs); stripABIAttrs(argAttrs); @@ -71,6 +86,11 @@ static void stripABIAttrs(FunctionOpInterface op) { op.setAllResultAttrs(resultAttrs); } +template +static T fallback(T optionalValue, T defaultValue) { + return optionalValue ? optionalValue : defaultValue; +} + // Creates the corresponding wrapper function for the given import function. static IREE::Util::FuncOp createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel, @@ -101,12 +121,7 @@ createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel, argAttrDict.push_back(nullptr); // signal break; } - - // Update the import type and propagate back the attributes we may have - // modified above. importOp.setType(newImportType); - importOp.setAllArgAttrs(argAttrDict); - importOp.setAllResultAttrs(resultAttrDict); auto *entryBlock = wrapperOp.addEntryBlock(); auto entryBuilder = OpBuilder::atBlockBegin(entryBlock); @@ -129,6 +144,12 @@ createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel, // users mark their functions 'nosideeffects' to avoid the host wait. const bool hasSideEffects = !importOp->hasAttr("nosideeffects"); + // Fetch and normalize any explicitly assigned affinity. + auto defaultAffinityAttr = importOp->getAttr("iree.abi.affinity"); + if (defaultAffinityAttr) { + importOp->setAttr("stream.affinity", defaultAffinityAttr); + } + // When running async we insert a barrier on tensor arguments and attach that // to the fence we pass to the import for waiting. We'll also allocate the // signal fence that the import must signal when the returned tensors are @@ -141,15 +162,24 @@ createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel, // No fences. break; case IREE::ABI::InvocationModel::CoarseFences: { - // HACK: this is relying on the fact that there's only one HAL device. - // We should instead have a way of creating fences on the device that - // is used to produce the tensors we're wrapping. - // - // TODO(multi-device): emit get with derived ordinal or lookup with attr. We - // could always say device 0 for now but could instead look for an - // iree.abi.affinity/iree.abi.device/etc. - Value device = - IREE::HAL::DeviceType::resolveAny(importOp.getLoc(), entryBuilder); + Value device; + // TODO(benvanik): support other affinity types. + if (auto deviceAffinityAttr = + dyn_cast_if_present( + defaultAffinityAttr)) { + device = entryBuilder + .create( + importOp.getLoc(), + entryBuilder.getType(), + deviceAffinityAttr) + .getResult(0); + } else { + // HACK: if no devices are available we get the first one available at + // runtime. This is suboptimal but we expect most usage to have affinities + // assigned prior to ABI conversion. + device = + IREE::HAL::DeviceType::resolveAny(importOp.getLoc(), entryBuilder); + } // When exporting a fence we need to put a barrier between the rest of the // program and the tensors consumed by the import. @@ -162,7 +192,7 @@ createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel, importOp.getLoc(), entryBuilder.getType(), device, IREE::HAL::FenceFlagBitfield::None); auto barrierOp = entryBuilder.create( - importOp.getLoc(), tensorArgs, waitFence, /*affinity=*/nullptr); + importOp.getLoc(), tensorArgs, waitFence); for (auto [argIndex, readyArg] : llvm::zip_equal(tensorArgIndices, barrierOp.getResults())) { entryArgs[argIndex] = readyArg; @@ -203,9 +233,10 @@ createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel, importOp.getArgAttrOfType(argIndex, "iree.abi.encoding"); auto tensorExportOp = entryBuilder.create( arg.getLoc(), newType, arg, - encodingAttr ? encodingAttr : TypeAttr::get(oldType), + fallback(encodingAttr, TypeAttr::get(oldType)), /*name=*/nullptr, - /*affinity=*/nullptr); + fallback(importOp.getArgAttr(argIndex, "iree.abi.affinity"), + defaultAffinityAttr)); arguments.push_back(tensorExportOp.getTarget()); } else { arguments.push_back(arg); @@ -245,9 +276,10 @@ createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel, resultIndex, "iree.abi.encoding"); auto tensorImportOp = entryBuilder.create( importOp.getLoc(), oldType, result, - encodingAttr ? encodingAttr : TypeAttr::get(oldType), signalFence, + fallback(encodingAttr, TypeAttr::get(oldType)), signalFence, /*name=*/nullptr, - /*affinity=*/nullptr); + fallback(importOp.getResultAttr(resultIndex, "iree.abi.affinity"), + defaultAffinityAttr)); results.push_back(tensorImportOp); } else { results.push_back(result); @@ -255,6 +287,9 @@ createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel, } entryBuilder.create(importOp.getLoc(), results); + + stripABIAttrs(importOp); + return wrapperOp; } @@ -518,8 +553,11 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, // Populate the reflection attrs based on the original types. populateReflectionAttrs(invocationModel, exportOp, wrapperOp); exportOp->removeAttr("iree.reflection"); - if (auto affinityAttr = exportOp->getAttr("stream.affinity")) { - wrapperOp->setAttr("stream.affinity", affinityAttr); + + // Fetch and normalize any explicitly assigned affinity. + auto defaultAffinityAttr = exportOp->getAttr("iree.abi.affinity"); + if (defaultAffinityAttr) { + exportOp->setAttr("stream.affinity", defaultAffinityAttr); } auto *entryBlock = wrapperOp.addEntryBlock(); @@ -572,12 +610,13 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, if (llvm::isa(oldType)) { auto encodingAttr = exportOp.getArgAttrOfType(argIndex, "iree.abi.encoding"); + auto argName = inferArgumentName(entryBuilder.getContext(), argIndex, + exportOp.getArgAttrDict(argIndex)); auto tensorImportOp = entryBuilder.create( arg.getLoc(), oldType, arg, - encodingAttr ? encodingAttr : TypeAttr::get(oldType), waitFence, - inferArgumentName(entryBuilder.getContext(), argIndex, - exportOp.getArgAttrDict(argIndex)), - exportOp.getArgAttr(argIndex, "iree.abi.affinity")); + fallback(encodingAttr, TypeAttr::get(oldType)), waitFence, argName, + fallback(exportOp.getArgAttr(argIndex, "iree.abi.affinity"), + defaultAffinityAttr)); arguments.push_back(tensorImportOp.getTarget()); } else { arguments.push_back(arg); @@ -601,7 +640,8 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, auto aliasOp = entryBuilder.create( exportOp.getLoc(), source.getType(), source, sourceDims, resultStorages[resultIndex], waitFence, - exportOp.getResultAttr(resultIndex, "iree.abi.affinity")); + fallback(exportOp.getResultAttr(resultIndex, "iree.abi.affinity"), + defaultAffinityAttr)); asyncResults[resultIndex] = cast(aliasOp.getResult()); } @@ -622,7 +662,7 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, signalFence); } else { auto barrierOp = entryBuilder.create( - exportOp.getLoc(), asyncTensors, signalFence, /*affinity=*/nullptr); + exportOp.getLoc(), asyncTensors, signalFence); asyncResults = llvm::to_vector(barrierOp.getResults()); } } @@ -635,15 +675,17 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, if (llvm::isa(oldType)) { auto encodingAttr = exportOp.getResultAttrOfType( resultIndex, "iree.abi.encoding"); + auto resultName = + inferResultName(entryBuilder.getContext(), resultIndex, + exportOp.getResultAttrDict(resultIndex)); auto dynamicDims = IREE::Util::buildDynamicDimsForValue( result.getLoc(), result, entryBuilder); auto tensorExportOp = entryBuilder.create( result.getLoc(), newType, result, - encodingAttr ? encodingAttr : TypeAttr::get(result.getType()), - dynamicDims, - inferResultName(entryBuilder.getContext(), resultIndex, - exportOp.getResultAttrDict(resultIndex)), - exportOp.getResultAttr(resultIndex, "iree.abi.affinity")); + fallback(encodingAttr, TypeAttr::get(result.getType())), dynamicDims, + resultName, + fallback(exportOp.getResultAttr(resultIndex, "iree.abi.affinity"), + defaultAffinityAttr)); results.push_back(tensorExportOp); } else { results.push_back(result); @@ -731,12 +773,15 @@ class WrapEntryPointsPass exportOps.push_back(funcOp); } } + if (importOps.empty() && exportOps.empty()) { + return; // no-op + } SymbolTable symbolTable(moduleOp); // Create a wrapper function for each imported function. - // This will preserve the internal types (tensors/etc) but change the import - // to taking the ABI types and rewrite calls. + // This will preserve the internal types (tensors/etc) but change the + // import to taking the ABI types and rewrite calls. for (auto importOp : importOps) { if (failed(wrapImportFunc(getInvocationModel(importOp, invocationModel), moduleOp, importOp, symbolTable))) { diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir index 3780ee21a59e..72a04416c89b 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir @@ -181,6 +181,28 @@ util.func private @caller(%arg0: tensor) -> tensor<2x?xi32> { // ----- +// Tests that explicit import affinity specification is carried through to +// the marshaling ops. + +// CHECK-LABEL: util.func private @pinnedImport(%arg0: !hal.buffer_view) -> !hal.buffer_view +util.func private @pinnedImport(tensor<2xi32> {iree.abi.affinity = #hal.device.promise<@dev_a>}) -> (tensor<2xi32> {iree.abi.affinity = #hal.device.promise<@dev_b>}) + +// CHECK: util.func private @_pinnedImport(%[[ARG_TENSOR:.+]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK: %[[ARG_VIEW:.+]] = hal.tensor.export on(#hal.device.promise<@dev_a>) %[[ARG_TENSOR]] : tensor<2xi32> -> !hal.buffer_view +// CHECK: %[[RET_VIEW:.+]] = util.call @pinnedImport(%[[ARG_VIEW]]) : (!hal.buffer_view) -> !hal.buffer_view +// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import on(#hal.device.promise<@dev_b>) %[[RET_VIEW]] : !hal.buffer_view -> tensor<2xi32> +// CHECK: util.return %[[RET_TENSOR]] +// CHECK: } + +// CHECK: util.func private @pinnedCaller(%arg0: tensor +util.func private @pinnedCaller(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK: util.call @_pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32> + %0 = util.call @pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32> + util.return %0 : tensor<2xi32> +} + +// ----- + // Tests that imports with encodings specified are propagated to the HAL ops. // CHECK-LABEL: util.func private @importEncodings(%arg0: !hal.buffer_view) -> !hal.buffer_view @@ -188,10 +210,10 @@ util.func private @importEncodings(tensor {iree.abi.encoding = tensor) -> tensor<2x?xi32> { // CHECK: %[[ARG_DIM:.+]] = tensor.dim %[[ARG_TENSOR]], %c0 -// CHECK: %[[ARG_VIEW:.+]] = hal.tensor.export %[[ARG_TENSOR]] : tensor{%[[ARG_DIM]]} -> !hal.buffer_view +// CHECK: %[[ARG_VIEW:.+]] = hal.tensor.export %[[ARG_TENSOR]] : tensor as tensor{%[[ARG_DIM]]} -> !hal.buffer_view // CHECK: %[[RET_VIEW:.+]] = util.call @importEncodings(%[[ARG_VIEW]]) : (!hal.buffer_view) -> !hal.buffer_view // CHECK: %[[RET_DIM:.+]] = hal.buffer_view.dim<%[[RET_VIEW]] : !hal.buffer_view>[1] -// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import %[[RET_VIEW]] : !hal.buffer_view -> tensor<2x?xi32>{%[[RET_DIM]]} +// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import %[[RET_VIEW]] : !hal.buffer_view -> tensor<2x?xf32> as tensor<2x?xi32>{%[[RET_DIM]]} // CHECK: util.return %[[RET_TENSOR]] // CHECK: } diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir index 4505a54da10e..f9af50581508 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir @@ -170,6 +170,40 @@ util.func private @caller(%arg0: tensor, %arg1: tensor) -> (te // ----- +// Tests that explicit import affinity specification is carried through to +// the marshaling ops. + +util.global private @dev_a : !hal.device +util.global private @dev_b : !hal.device +util.global private @dev_c : !hal.device + +// CHECK-LABEL: util.func private @pinnedImport(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view +util.func private @pinnedImport(tensor<2xi32> {iree.abi.affinity = #hal.device.affinity<@dev_a>}) -> (tensor<2xi32> {iree.abi.affinity = #hal.device.affinity<@dev_b>}) attributes { + iree.abi.affinity = #hal.device.affinity<@dev_c>, + iree.abi.model = "coarse-fences", + nosideeffects +} + +// CHECK: util.func private @_pinnedImport(%[[ARG_TENSOR:.+]]: tensor<2xi32>) -> tensor<2xi32> { +// CHECK-DAG: %[[DEVICE_C:.+]] = hal.device.resolve on(<@dev_c>) : !hal.device +// CHECK-DAG: %[[ARG_FENCE:.+]] = hal.fence.create device(%[[DEVICE_C]] : !hal.device) flags("None") : !hal.fence +// CHECK-DAG: %[[ARG_READY:.+]] = hal.tensor.barrier join(%[[ARG_TENSOR]] : tensor<2xi32>) => %[[ARG_FENCE]] : !hal.fence +// CHECK-DAG: %[[ARG_VIEW:.+]] = hal.tensor.export on(#hal.device.affinity<@dev_a>) %[[ARG_READY]] : tensor<2xi32> -> !hal.buffer_view +// CHECK-DAG: %[[RESULT_FENCE:.+]] = hal.fence.create device(%[[DEVICE_C]] : !hal.device) flags("None") : !hal.fence +// CHECK: %[[RET_VIEW:.+]] = util.call @pinnedImport(%[[ARG_VIEW]], %[[ARG_FENCE]], %[[RESULT_FENCE]]) : (!hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view +// CHECK: %[[RET_TENSOR:.+]] = hal.tensor.import on(#hal.device.affinity<@dev_b>) wait(%[[RESULT_FENCE]]) => %[[RET_VIEW]] : !hal.buffer_view -> tensor<2xi32> +// CHECK: util.return %[[RET_TENSOR]] +// CHECK: } + +// CHECK: util.func private @pinnedCaller(%arg0: tensor +util.func private @pinnedCaller(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK: util.call @_pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32> + %0 = util.call @pinnedImport(%arg0) : (tensor<2xi32>) -> tensor<2xi32> + util.return %0 : tensor<2xi32> +} + +// ----- + // Tests a side-effect-free import that doesn't take/return reference types. // CHECK-LABEL: util.func private @importI32(%arg0: i32, %arg1: !hal.fence, %arg2: !hal.fence) -> i32 diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel index 98da9d99c543..fb3b309c6d8d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel @@ -64,6 +64,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Encoding/IR", "//compiler/src/iree/compiler/Dialect/HAL/Analysis", "//compiler/src/iree/compiler/Dialect/HAL/IR", + "//compiler/src/iree/compiler/Dialect/Stream/Analysis", "//runtime/src/iree/builtins/ukernel:exported_bits", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt index 72361213574c..0b79e7d710c5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt @@ -86,6 +86,7 @@ iree_cc_library( iree::compiler::Dialect::Encoding::IR iree::compiler::Dialect::HAL::Analysis iree::compiler::Dialect::HAL::IR + iree::compiler::Dialect::Stream::Analysis PUBLIC ) diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp index c107bc7eec0a..4edaffa6159e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp @@ -13,6 +13,7 @@ #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" #include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/MathExtras.h" @@ -547,6 +548,35 @@ static LogicalResult materializeFuncOpEncodings( return success(); } +// Returns the executable targets used within |funcOp|. +// +// TODO(multi-device): delete this pass and rely on tensor-based analysis to +// materialize encodings based on where tensors are used. This pass is not able +// to handle that. +static std::optional> +getFuncExecutableTargetAttrs(FunctionOpInterface funcOp, + IREE::Stream::AffinityAnalysis &affinityAnalysis, + IREE::HAL::DeviceAnalysis &deviceAnalysis) { + // Get a set of all unique affinities used by resources within the function. + SetVector uniqueAffinityAttrs; + SmallVector lookupAffinityAttrs; + funcOp.walk([&](Operation *op) { + if (affinityAnalysis.tryLookupExecutionAffinity(op, lookupAffinityAttrs)) { + uniqueAffinityAttrs.insert(lookupAffinityAttrs.begin(), + lookupAffinityAttrs.end()); + } + lookupAffinityAttrs.clear(); + }); + + // Resolve affinities to executable targets. + SetVector executableTargetAttrs; + for (auto affinityAttr : uniqueAffinityAttrs) { + deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, funcOp, + executableTargetAttrs); + } + return executableTargetAttrs; +} + struct CPUMaterializeHostEncodingPass : public CPUMaterializeHostEncodingBase { CPUMaterializeHostEncodingPass() = default; @@ -560,23 +590,36 @@ struct CPUMaterializeHostEncodingPass auto moduleOp = getOperation(); // Run required analysis passes. + IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp); + if (failed(affinityAnalysis.run())) { + return signalPassFailure(); + } IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp); - if (failed(deviceAnalysis.run())) + if (failed(deviceAnalysis.run())) { return signalPassFailure(); + } for (auto funcOp : moduleOp.getOps()) { // Gather the required executable targets for the function. Note that it's // possible there are more required for ops nested within the function but // this pass is a hack and can't handle that :shrug:. - SetVector executableTargets; - deviceAnalysis.gatherRequiredExecutableTargets(funcOp, executableTargets); + auto executableTargets = getFuncExecutableTargetAttrs( + funcOp, affinityAnalysis, deviceAnalysis); + if (!executableTargets) { + funcOp.emitOpError() + << "could not determine executable targets for the function"; + return signalPassFailure(); + } else if (executableTargets->empty()) { + // Probably no tensors. + continue; + } // HACK: this pass is run on the host _but shouldn't be_. Because it's // run on the host and IREE is a compiler capable of multi-targeting there // may be multiple executable targets at any point in the host program. // This pass can't handle that and assumes it's been checked earlier by // spooky action at a distance. This needs to be fixed. - if (executableTargets.size() != 1) { + if (executableTargets->size() != 1) { funcOp.emitOpError() << "has multiple executable targets and CPU data " "tiling isn't built to support that"; return signalPassFailure(); @@ -584,7 +627,7 @@ struct CPUMaterializeHostEncodingPass // Materialize encodings within the function. if (failed( - materializeFuncOpEncodings(funcOp, executableTargets.front()))) { + materializeFuncOpEncodings(funcOp, executableTargets->front()))) { return signalPassFailure(); } } @@ -636,22 +679,35 @@ struct CPUMaterializeUpperBoundTileSizePass auto moduleOp = getOperation(); // Run required analysis passes. + IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp); + if (failed(affinityAnalysis.run())) { + return signalPassFailure(); + } IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp); - if (failed(deviceAnalysis.run())) + if (failed(deviceAnalysis.run())) { return signalPassFailure(); + } for (auto funcOp : moduleOp.getOps()) { // Gather the required executable targets for the function. Note that it's // possible there are more required for ops nested within the function but // this pass is a hack and can't handle that :shrug:. - SetVector executableTargets; - deviceAnalysis.gatherRequiredExecutableTargets(funcOp, executableTargets); + auto executableTargets = getFuncExecutableTargetAttrs( + funcOp, affinityAnalysis, deviceAnalysis); + if (!executableTargets) { + funcOp.emitOpError() + << "could not determine executable targets for the function"; + return signalPassFailure(); + } else if (executableTargets->empty()) { + // Probably no tensors. + continue; + } // Get patterns specialized for the executable targets used by the // function. RewritePatternSet patterns(&getContext()); MaterializeEncodingFn materializeEncodingFn = - getUpperBoundMaterializeEncodingFn(executableTargets.getArrayRef()); + getUpperBoundMaterializeEncodingFn(executableTargets->getArrayRef()); if (!materializeEncodingFn) return signalPassFailure(); populateMaterializeUpperBoundTileSizePatterns(patterns, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp index f762a1f870de..fc2672931666 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp @@ -211,6 +211,15 @@ BindingLayoutAnalysis::BindingLayoutAnalysis(Operation *rootOp, } } +bool BindingLayoutAnalysis::hasDispatches() const { + for (auto &it : exportInfos) { + if (!it.second->dispatchOps.empty()) { + return true; // found at least one dispatch + } + } + return false; +} + ArrayRef BindingLayoutAnalysis::getExportDispatches(Operation *exportOp) const { auto it = exportInfos.find(exportOp); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h index 050e18e6a801..7d08959f2490 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h @@ -59,6 +59,9 @@ class BindingLayoutAnalysis { public: explicit BindingLayoutAnalysis(Operation *rootOp, SymbolTable &symbolTable); + // Whether there are any dispatches in the program. + bool hasDispatches() const; + // Returns all of the dispatches to the given executable export. ArrayRef getExportDispatches(IREE::Stream::ExecutableExportOp exportOp) const { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp index a17a1008c87f..705292b1b546 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/Patterns.cpp @@ -15,6 +15,92 @@ namespace mlir::iree_compiler { namespace { +struct ConvertDeviceResolveAnyOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(IREE::HAL::DeviceResolveOp resolveOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (adaptor.getAffinity()) { + return rewriter.notifyMatchFailure( + resolveOp, "only resolving unspecified affinities to any device"); + } + + auto deviceType = rewriter.getType(); + Value device; + auto resolveDevice = [&]() { + if (!device) { + device = rewriter.create( + resolveOp.getLoc(), deviceType, + rewriter.create(resolveOp.getLoc(), 0)); + } + return device; + }; + + SmallVector results; + for (auto resultType : resolveOp.getResultTypes()) { + if (isa(resultType)) { + results.push_back(resolveDevice()); + } else if (isa(resultType)) { + results.push_back(rewriter.create( + resolveOp.getLoc(), resolveDevice())); + } else if (isa(resultType)) { + results.push_back(rewriter.create( + resolveOp.getLoc(), -1ll, 64)); + } + } + + rewriter.replaceOp(resolveOp, results); + return success(); + } +}; + +struct ConvertDeviceResolveAffinityOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(IREE::HAL::DeviceResolveOp resolveOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto affinityAttr = adaptor.getAffinityAttr(); + if (!affinityAttr) { + return rewriter.notifyMatchFailure( + resolveOp, "only resolving fully specified affinities"); + } + auto flatDeviceAttr = dyn_cast(affinityAttr.getDevice()); + if (!flatDeviceAttr) { + return rewriter.notifyMatchFailure( + resolveOp, "nested device references not yet supported"); + } + + auto deviceType = rewriter.getType(); + Value device; + auto resolveDevice = [&]() { + if (!device) { + device = rewriter.create( + resolveOp.getLoc(), deviceType, flatDeviceAttr.getValue(), + /*is_immutable=*/true); + } + return device; + }; + + SmallVector results; + for (auto resultType : resolveOp.getResultTypes()) { + if (isa(resultType)) { + results.push_back(resolveDevice()); + } else if (isa(resultType)) { + results.push_back(rewriter.create( + resolveOp.getLoc(), resolveDevice())); + } else if (isa(resultType)) { + results.push_back(rewriter.create( + resolveOp.getLoc(), affinityAttr.getQueueMask(), 64)); + } + } + + rewriter.replaceOp(resolveOp, results); + return success(); + } +}; + struct ConvertExecutableCalculateWorkgroupsOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -43,6 +129,10 @@ void populateHALToHALPatterns(MLIRContext *context, ConversionTarget &conversionTarget, TypeConverter &typeConverter, RewritePatternSet &patterns) { + conversionTarget.addIllegalOp(); + patterns.insert(typeConverter, context); + patterns.insert(typeConverter, context); + conversionTarget.addIllegalOp(); patterns.insert(typeConverter, context); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel index 605665250693..b38bbbea7b32 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel @@ -15,7 +15,10 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( - ["pseudo_ops.mlir"], + [ + "device_ops.mlir", + "pseudo_ops.mlir", + ], include = ["*.mlir"], ), cfg = "//compiler:lit.cfg.py", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt index 2a57a806c84e..675710985b50 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/CMakeLists.txt @@ -14,6 +14,7 @@ iree_lit_test_suite( NAME lit SRCS + "device_ops.mlir" "pseudo_ops.mlir" TOOLS FileCheck diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/device_ops.mlir new file mode 100644 index 000000000000..44fb128b5a5b --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/device_ops.mlir @@ -0,0 +1,75 @@ +// RUN: iree-opt --split-input-file --allow-unregistered-dialect --iree-hal-conversion %s | FileCheck %s + +// CHECK-LABEL: @deviceResolveAnyDevice +util.func public @deviceResolveAnyDevice() -> !hal.device { + // CHECK-DAG: %[[ANY_ORDINAL:.+]] = arith.constant 0 + // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %[[ANY_ORDINAL]] : !hal.device + %device = hal.device.resolve : !hal.device + // CHECK: util.return %[[DEVICE]] + util.return %device : !hal.device +} + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @deviceResolveDevice +util.func public @deviceResolveDevice() -> !hal.device { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + %device = hal.device.resolve on(#hal.device.affinity<@device>) : !hal.device + // CHECK: util.return %[[DEVICE]] + util.return %device : !hal.device +} + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @deviceResolveDeviceQueueAffinityAny +util.func public @deviceResolveDeviceQueueAffinityAny() -> (!hal.device, i64) { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant -1 : i64 + %device, %queue_affinity_any = hal.device.resolve on(#hal.device.affinity<@device>) : !hal.device, i64 + // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]] + util.return %device, %queue_affinity_any : !hal.device, i64 +} + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @deviceResolveDeviceQueueAffinity45 +util.func public @deviceResolveDeviceQueueAffinity45() -> (!hal.device, i64) { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64 + %device, %queue_affinity_45 = hal.device.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64 + // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]] + util.return %device, %queue_affinity_45 : !hal.device, i64 +} + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @deviceResolveAllocator +util.func public @deviceResolveAllocator() -> !hal.allocator { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator + %allocator = hal.device.resolve on(#hal.device.affinity<@device>) : !hal.allocator + // CHECK: util.return %[[ALLOCATOR]] + util.return %allocator : !hal.allocator +} + +// ----- + +util.global private @device : !hal.device + +// CHECK-LABEL: @deviceResolveAllocatorQueueAffinity45 +util.func public @deviceResolveAllocatorQueueAffinity45() -> (!hal.allocator, i64) { + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator + // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64 + %allocator, %queue_affinity_45 = hal.device.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.allocator, i64 + // CHECK: util.return %[[ALLOCATOR]], %[[QUEUE_AFFINITY]] + util.return %allocator, %queue_affinity_45 : !hal.allocator, i64 +} diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp index f8efeba05173..a911de8b2830 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp @@ -69,32 +69,6 @@ Value createPackedConstantBuffer(Location loc, ValueRange constantValues, return constantBuffer; } -IREE::VM::RodataOp -createExecutableBinaryRodata(IREE::HAL::ExecutableBinaryOp binaryOp, - OpBuilder &builder) { - auto executableOp = - binaryOp.getOperation()->getParentOfType(); - auto insertPoint = builder.saveInsertionPoint(); - builder.setInsertionPoint(builder.getInsertionBlock()->getParentOp()); - - std::string rodataName = sanitizeSymbolName( - (executableOp.getName() + "_" + binaryOp.getName()).str()); - auto rodataOp = builder.create( - binaryOp.getLoc(), rodataName, binaryOp.getData()); - rodataOp.setPrivate(); - if (binaryOp.getMimeType().has_value()) { - rodataOp.setMimeTypeAttr(binaryOp.getMimeTypeAttr()); - } - - // TODO(benvanik): should these be page aligned? memcpy fastpath is fine for - // now. - rodataOp.setAlignmentAttr(builder.getI64IntegerAttr(16)); - - builder.restoreInsertionPoint(insertPoint); - - return rodataOp; -} - namespace { class RemoveExecutableOpConversion @@ -128,9 +102,15 @@ class ExecutableCreateOpConversion auto executableBinaryOp = SymbolTable::lookupNearestSymbolFrom( createOp, createOp.getExecutableTarget()); - auto rodataOp = createExecutableBinaryRodata(executableBinaryOp, rewriter); - auto executableRodata = rewriter.createOrFold( - createOp.getLoc(), rodataOp); + auto executableOp = executableBinaryOp.getOperation() + ->getParentOfType(); + std::string rodataName = sanitizeSymbolName( + (executableOp.getName() + "_" + executableBinaryOp.getName()).str()); + auto rodataOp = rewriter.create( + executableBinaryOp.getLoc(), + IREE::VM::RefType::get(rewriter.getType()), + rewriter.getStringAttr(rodataName), executableBinaryOp.getData(), + rewriter.getI64IntegerAttr(16), executableBinaryOp.getMimeTypeAttr()); // Get format string as a rodata blob. auto executableFormatStr = rewriter.create( @@ -151,7 +131,7 @@ class ExecutableCreateOpConversion SmallVector callOperands = { adaptor.getDevice(), executableFormatStr, - executableRodata, + rodataOp, constantBuffer, }; callOperands.append(adaptor.getLayouts().begin(), diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h index 071643a9405e..059652473930 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h @@ -23,11 +23,6 @@ void populateHALToVMPatterns(MLIRContext *context, SymbolTable &importSymbols, Value createPackedConstantBuffer(Location loc, ValueRange constantValues, OpBuilder &builder); -// Creates a vm.rodata containing the contents of a hal.executable.binary. -IREE::VM::RodataOp -createExecutableBinaryRodata(IREE::HAL::ExecutableBinaryOp binaryOp, - OpBuilder &builder); - } // namespace mlir::iree_compiler #endif // IREE_COMPILER_DIALECT_HAL_CONVERSION_HALTOVM_PATTERNS_H_ diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir index 9249b56542ea..5dd534142627 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir @@ -1,7 +1,5 @@ // RUN: iree-opt --split-input-file --iree-vm-conversion %s | FileCheck %s -// CHECK: vm.rodata private @exe_binary1 {alignment = 16 : i64} dense<[0, 1, 2, 3]> : vector<4xi8> -// CHECK: vm.rodata private @exe_binary2 {alignment = 16 : i64} dense<[4, 5, 6, 7]> : vector<4xi8> hal.executable @exe { hal.executable.binary @binary1 attributes { data = dense<[0, 1, 2, 3]> : vector<4xi8>, @@ -24,7 +22,7 @@ util.func public @executableCreate( ) -> (!hal.executable, !hal.executable) { // CHECK-DAG: %[[FORMAT1:.+]] = vm.rodata.inline "_utf8_format1_ - // CHECK-DAG: %[[BINARY1:.+]] = vm.const.ref.rodata @exe_binary1 : !vm.buffer + // CHECK-DAG: %[[BINARY1:.+]] = vm.rodata.inline "exe_binary1" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8> // CHECK-DAG: %[[NULL1:.+]] = vm.const.ref.zero : !vm.buffer // CHECK: %[[EXE1:.+]] = vm.call.variadic @hal.executable.create( // CHECK-SAME: %[[DEV]], %[[FORMAT1]], %[[BINARY1]], %[[NULL1]], [%[[LAYOUT0]], %[[LAYOUT1]]] @@ -32,7 +30,7 @@ util.func public @executableCreate( %0 = hal.executable.create device(%device : !hal.device) target(@exe::@binary1) layouts([%layout0, %layout1]) : !hal.executable // CHECK-DAG: %[[FORMAT2:.+]] = vm.rodata.inline "_utf8_format2_ - // CHECK-DAG: %[[BINARY2:.+]] = vm.const.ref.rodata @exe_binary2 : !vm.buffer + // CHECK-DAG: %[[BINARY2:.+]] = vm.rodata.inline "exe_binary2" {alignment = 16 : i64} : !vm.buffer = dense<[4, 5, 6, 7]> : vector<4xi8> // CHECK-DAG: %[[NULL2:.+]] = vm.const.ref.zero : !vm.buffer // CHECK: %[[EXE2:.+]] = vm.call.variadic @hal.executable.create( // CHECK-SAME: %[[DEV]], %[[FORMAT2]], %[[BINARY2]], %[[NULL2]], [%[[LAYOUT1]], %[[LAYOUT0]]] @@ -45,14 +43,12 @@ util.func public @executableCreate( // ----- -// CHECK: vm.rodata private @exe1_binary1 {alignment = 16 : i64} dense<[0, 1, 2, 3]> : vector<4xi8> hal.executable @exe1 { hal.executable.binary @binary1 attributes { data = dense<[0, 1, 2, 3]> : vector<4xi8>, format = "format" } } -// CHECK: vm.rodata private @exe2_binary2 {alignment = 16 : i64} dense<[4, 5, 6, 7]> : vector<4xi8> hal.executable @exe2 { hal.executable.binary @binary2 attributes { data = dense<[4, 5, 6, 7]> : vector<4xi8>, @@ -67,17 +63,16 @@ util.func public @multipleExecutables( %layout1: !hal.pipeline_layout ) -> (!hal.executable, !hal.executable) { // CHECK-DAG: %[[FORMAT1:.+]] = vm.rodata.inline "_utf8_format_ - // CHECK-DAG: %[[BINARY1:.+]] = vm.const.ref.rodata @exe1_binary1 : !vm.buffer + // CHECK-DAG: %[[BINARY1:.+]] = vm.rodata.inline "exe1_binary1" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8> %0 = hal.executable.create device(%device : !hal.device) target(@exe1::@binary1) layouts([%layout0, %layout1]) : !hal.executable // CHECK-DAG: %[[FORMAT2:.+]] = vm.rodata.inline "_utf8_format_ - // CHECK-DAG: %[[BINARY2:.+]] = vm.const.ref.rodata @exe2_binary2 : !vm.buffer + // CHECK-DAG: %[[BINARY2:.+]] = vm.rodata.inline "exe2_binary2" {alignment = 16 : i64} : !vm.buffer = dense<[4, 5, 6, 7]> : vector<4xi8> %1 = hal.executable.create device(%device : !hal.device) target(@exe2::@binary2) layouts([%layout1, %layout0]) : !hal.executable util.return %0, %1 : !hal.executable, !hal.executable } // ----- -// CHECK: vm.rodata private @exe_binary {alignment = 16 : i64} dense<[0, 1, 2, 3]> : vector<4xi8> hal.executable @exe { hal.executable.binary @binary attributes { data = dense<[0, 1, 2, 3]> : vector<4xi8>, @@ -95,7 +90,7 @@ util.func public @executableConstants( %constant0: i32, %constant1: i32 ) -> !hal.executable { // CHECK-DAG: %[[FORMAT:.+]] = vm.rodata.inline "_utf8_format_ - // CHECK-DAG: %[[BINARY:.+]] = vm.const.ref.rodata @exe_binary : !vm.buffer + // CHECK-DAG: %[[BINARY:.+]] = vm.rodata.inline "exe_binary" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8> // CHECK: %[[CONSTANTS:.+]] = vm.buffer.alloc %c12, %c16 : !vm.buffer // CHECK-DAG: %[[INDEX0:.+]] = vm.const.i64 0 diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index de9bbb4b5f1a..a58f32af496d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -34,58 +34,32 @@ struct ContextResolveOpPattern // Get the affinity from the op or an ancestor. Note that there may be no // affinity specified at all. - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(resolveOp); + auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(resolveOp); + + // If no affinity was specified then resolve as 'any'. + if (!affinityAttr) { + rewriter.replaceOpWithNewOp( + resolveOp, resolveOp.getResultTypes(), + IREE::HAL::DeviceAffinityAttr{}); + return success(); + } // We currently only handle HAL device affinities. // We could make this an interface to select the device and allow users to // provide their own affinities to convert to HAL. In the future users may // also want to provide devices as function arguments post-initialization. // For now we just have one way to specify device globals. - auto deviceAffinityAttr = - dyn_cast_if_present(affinityAttr); - if (!deviceAffinityAttr) { - resolveOp.emitOpError() << "failed to resolve affinity: only HAL device " - "affinities are supported"; - return rewriter.notifyMatchFailure( - resolveOp, "only HAL device affinities are supported"); - } - - // Get the device handle and queue. - // - // TODO(multi-device): specialized types; may need analysis we don't have - // or at least a symbol lookup. An alternative would be an optional type - // on the affinity in cases where we've evaluated it early but for now - // we assume all device types are unspecialized. - auto deviceType = rewriter.getType(); - Value device = rewriter.create( - resolveOp.getLoc(), deviceType, - deviceAffinityAttr.getDevice().getValue(), - /*is_immutable=*/true); - int64_t queueMask = deviceAffinityAttr.getQueueMask(); - - SmallVector results; - if (isa(resultTypes[0])) { - results.push_back(device); - } else if (isa(resultTypes[0])) { - results.push_back(rewriter.create( - resolveOp.getLoc(), device)); - } else { - return rewriter.notifyMatchFailure( - resolveOp, "unrecognized context resolve types for a HAL target"); - } - if (resultTypes.size() > 1) { - if (isa(resultTypes[1])) { - results.push_back(rewriter.create( - resolveOp.getLoc(), queueMask, 64)); - } else { - return rewriter.notifyMatchFailure( - resolveOp, - "unrecognized context resolve types for a HAL target (extended)"); - } + if (auto deviceAffinityAttr = + dyn_cast_if_present(affinityAttr)) { + rewriter.replaceOpWithNewOp( + resolveOp, resolveOp.getResultTypes(), deviceAffinityAttr); + return success(); } - rewriter.replaceOp(resolveOp, results); - return success(); + resolveOp.emitOpError() << "failed to resolve affinity: only HAL device " + "affinities are supported"; + return rewriter.notifyMatchFailure( + resolveOp, "only HAL device affinities are supported"); } }; @@ -684,7 +658,7 @@ struct CmdDispatchOpPattern // make this difficult. For now we assume each stream region being lowered // has a singular affinity that may itself reference multiple devices in the // future but currently uniquely identifies a device. - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(dispatchOp); + auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(dispatchOp); // Get the device handle we're executing against in this execution region. // Note that this is a dynamic value: we have to treat the device as unknown diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp index 8b628da69d3d..5c685c028537 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp @@ -22,7 +22,7 @@ static llvm::cl::opt clExternalResourcesMappable( namespace mlir::iree_compiler { Value lookupDeviceFor(Operation *op, OpBuilder &builder) { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op); auto resolveOp = builder.create( op->getLoc(), TypeRange{ @@ -34,7 +34,7 @@ Value lookupDeviceFor(Operation *op, OpBuilder &builder) { std::tuple lookupDeviceAndQueueAffinityFor(Operation *op, OpBuilder &builder) { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op); auto resolveOp = builder.create( op->getLoc(), TypeRange{ @@ -46,7 +46,7 @@ std::tuple lookupDeviceAndQueueAffinityFor(Operation *op, } Value lookupAllocatorFor(Operation *op, OpBuilder &builder) { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op); auto resolveOp = builder.create( op->getLoc(), TypeRange{ @@ -58,7 +58,7 @@ Value lookupAllocatorFor(Operation *op, OpBuilder &builder) { std::tuple lookupAllocatorAndQueueAffinityFor(Operation *op, OpBuilder &builder) { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op); auto resolveOp = builder.create( op->getLoc(), TypeRange{ diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir index 20a9c59a127f..c60d0daf71af 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/context_ops.mlir @@ -1,5 +1,7 @@ // RUN: iree-opt --split-input-file --allow-unregistered-dialect --iree-hal-conversion %s | FileCheck %s +// NOTE: the hal.device.resolve lowering in HAL-to-HAL does most of the work. + util.global private @device : !hal.device // CHECK-LABEL: @contextResolveDefaultDevice @@ -16,63 +18,12 @@ util.func public @contextResolveDefaultDevice() -> !hal.device attributes { util.global private @device : !hal.device -// CHECK-LABEL: @contextResolveDevice -util.func public @contextResolveDevice() -> !hal.device { - // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device - %device = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device - // CHECK: util.return %[[DEVICE]] - util.return %device : !hal.device -} - -// ----- - -util.global private @device : !hal.device - -// CHECK-LABEL: @contextResolveDeviceQueueAffinityAny -util.func public @contextResolveDeviceQueueAffinityAny() -> (!hal.device, i64) { - // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device - // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant -1 : i64 - %device, %queue_affinity_any = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.device, i64 - // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]] - util.return %device, %queue_affinity_any : !hal.device, i64 -} - -// ----- - -util.global private @device : !hal.device - -// CHECK-LABEL: @contextResolveDeviceQueueAffinity45 -util.func public @contextResolveDeviceQueueAffinity45() -> (!hal.device, i64) { - // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device - // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64 - %device, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, i64 - // CHECK: util.return %[[DEVICE]], %[[QUEUE_AFFINITY]] - util.return %device, %queue_affinity_45 : !hal.device, i64 -} - -// ----- - -util.global private @device : !hal.device - -// CHECK-LABEL: @contextResolveAllocator -util.func public @contextResolveAllocator() -> !hal.allocator { - // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device - // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator - %allocator = stream.context.resolve on(#hal.device.affinity<@device>) : !hal.allocator - // CHECK: util.return %[[ALLOCATOR]] - util.return %allocator : !hal.allocator -} - -// ----- - -util.global private @device : !hal.device - // CHECK-LABEL: @contextResolveAllocatorQueueAffinity45 -util.func public @contextResolveAllocatorQueueAffinity45() -> (!hal.allocator, i64) { +util.func public @contextResolveAllocatorQueueAffinity45() -> (!hal.device, !hal.allocator, i64) { // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator // CHECK-DAG: %[[QUEUE_AFFINITY:.+]] = arith.constant 48 : i64 - %allocator, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.allocator, i64 - // CHECK: util.return %[[ALLOCATOR]], %[[QUEUE_AFFINITY]] - util.return %allocator, %queue_affinity_45 : !hal.allocator, i64 + %device, %allocator, %queue_affinity_45 = stream.context.resolve on(#hal.device.affinity<@device, [4, 5]>) : !hal.device, !hal.allocator, i64 + // CHECK: util.return %[[DEVICE]], %[[ALLOCATOR]], %[[QUEUE_AFFINITY]] + util.return %device, %allocator, %queue_affinity_45 : !hal.device, !hal.allocator, i64 } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp index 6d3c4a22dcc1..fe32e8bc4925 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp @@ -448,28 +448,34 @@ Attribute DeviceTargetAttr::parse(AsmParser &p, Type type) { // `[targets, ...]` (optional) do { IREE::HAL::ExecutableTargetAttr executableTargetAttr; - if (failed(p.parseAttribute(executableTargetAttr))) + if (failed(p.parseAttribute(executableTargetAttr))) { return {}; + } executableTargetAttrs.push_back(executableTargetAttr); } while (succeeded(p.parseOptionalComma())); - if (failed(p.parseRSquare())) + if (failed(p.parseRSquare())) { return {}; + } } else { // `{config dict}` (optional) - if (failed(p.parseAttribute(configAttr))) + if (failed(p.parseAttribute(configAttr))) { return {}; + } // `, [targets, ...]` (optional) if (succeeded(p.parseOptionalComma())) { - if (failed(p.parseLSquare())) + if (failed(p.parseLSquare())) { return {}; + } do { IREE::HAL::ExecutableTargetAttr executableTargetAttr; - if (failed(p.parseAttribute(executableTargetAttr))) + if (failed(p.parseAttribute(executableTargetAttr))) { return {}; + } executableTargetAttrs.push_back(executableTargetAttr); } while (succeeded(p.parseOptionalComma())); - if (failed(p.parseRSquare())) + if (failed(p.parseRSquare())) { return {}; + } } } } @@ -502,7 +508,14 @@ void DeviceTargetAttr::print(AsmPrinter &p) const { } std::string DeviceTargetAttr::getSymbolNameFragment() { - return sanitizeSymbolName(getDeviceID().getValue().lower()); + std::string name = getDeviceID().getValue().lower(); + if (auto ordinalAttr = + dyn_cast_if_present(getConfigurationAttr("ordinal"))) { + name += "_"; + name += std::to_string(ordinalAttr.getInt()); + name += "_"; // can't have trailing numbers + } + return sanitizeSymbolName(name); } bool DeviceTargetAttr::hasConfigurationAttr(StringRef name) { @@ -510,6 +523,13 @@ bool DeviceTargetAttr::hasConfigurationAttr(StringRef name) { return configAttr && configAttr.get(name); } +Attribute DeviceTargetAttr::getConfigurationAttr(StringRef name) { + if (auto configAttr = getConfiguration()) { + return configAttr.get(name); + } + return {}; +} + void DeviceTargetAttr::getExecutableTargets( SetVector &resultAttrs) { for (auto attr : getExecutableTargets()) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index 2d10dc32aac0..9511cf84b7ca 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td @@ -754,6 +754,8 @@ def HAL_DeviceTargetAttr : AttrDef:$device, + AttrParameter<"SymbolRefAttr", "">:$device, AttrParameter<"int64_t", "">:$queue_mask ); diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp index 23dab245d380..bf596ce8dcef 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp @@ -90,14 +90,16 @@ struct DeduplicateTensorBarrierSources for (auto source : op.getSources()) { auto it = uniqueSources.insert(std::make_pair(source, orderedSources.size())); - if (it.second) + if (it.second) { orderedSources.push_back(source); + } resultMapping.push_back(it.first->second); } - if (orderedSources.size() == op.getSources().size()) + if (orderedSources.size() == op.getSources().size()) { return failure(); - auto newOp = rewriter.create( - op.getLoc(), orderedSources, op.getSignalFence(), op.getAffinityAttr()); + } + auto newOp = rewriter.create(op.getLoc(), orderedSources, + op.getSignalFence()); SmallVector newResults; newResults.reserve(newOp.getNumResults()); for (unsigned newIndex : resultMapping) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index e28025e3ec49..538b32d13d99 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -458,19 +458,6 @@ void TensorImportOp::build(OpBuilder &builder, OperationState &result, waitFence, name, affinity); } -Value TensorImportOp::getTiedResult(unsigned resultIndex) { - return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource()); -} - -::std::optional -TensorImportOp::getTiedResultOperandIndex(unsigned resultIndex) { - return {0}; // source -} - -SmallVector TensorImportOp::getTiedResultOperandIndices() { - return {0}; // source -} - static LogicalResult verifyTypeStorageCompatibility(Operation *op, Type encodingType, Type storageType) { @@ -539,19 +526,6 @@ void TensorExportOp::build(OpBuilder &builder, OperationState &result, affinity); } -Value TensorExportOp::getTiedResult(unsigned resultIndex) { - return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource()); -} - -::std::optional -TensorExportOp::getTiedResultOperandIndex(unsigned resultIndex) { - return {0}; // source -} - -SmallVector TensorExportOp::getTiedResultOperandIndices() { - return {0}; // source -} - LogicalResult TensorExportOp::verify() { TensorExportOp op = *this; auto sourceType = llvm::cast(op.getSource().getType()); @@ -595,11 +569,10 @@ LogicalResult TensorAliasOp::verify() { //===----------------------------------------------------------------------===// void TensorBarrierOp::build(OpBuilder &builder, OperationState &result, - ValueRange sources, Value signalFence, - Attribute affinity) { + ValueRange sources, Value signalFence) { auto resultTypes = llvm::map_to_vector( sources, [](Value source) { return source.getType(); }); - build(builder, result, resultTypes, sources, signalFence, affinity); + build(builder, result, resultTypes, sources, signalFence); } Value TensorBarrierOp::getTiedResult(unsigned resultIndex) { @@ -1063,6 +1036,23 @@ void DescriptorSetLayoutCreateOp::getAsmResultNames( setNameFn(getResult(), "descriptor_set_layout"); } +//===----------------------------------------------------------------------===// +// hal.device.resolve +//===----------------------------------------------------------------------===// + +void DeviceResolveOp::getAsmResultNames( + function_ref setNameFn) { + for (auto result : getResults()) { + if (isa(result.getType())) { + setNameFn(result, "device"); + } else if (isa(result.getType())) { + setNameFn(result, "allocator"); + } else if (isa(result.getType())) { + setNameFn(result, "queue_affinity"); + } + } +} + //===----------------------------------------------------------------------===// // hal.device.allocator //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index d35a0cb55459..dff5438039ce 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -98,11 +98,6 @@ let opDocGroup = OpGroupPseudoOps in { def HAL_TensorImportOp : HAL_PureOp<"tensor.import", [ AttrSizedOperandSegments, - DeclareOpInterfaceMethods, Util_ShapeAwareOp, ]> { let summary = [{imports a tensor from a HAL buffer view}]; @@ -171,11 +166,6 @@ def HAL_TensorImportOp : HAL_PureOp<"tensor.import", [ } def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [ - DeclareOpInterfaceMethods, Util_ShapeAwareOp, ]> { let summary = [{exports a tensor to a HAL buffer view}]; @@ -320,15 +310,13 @@ def HAL_TensorBarrierOp : HAL_Op<"tensor.barrier", [ let arguments = (ins Variadic:$sources, - HAL_Fence:$signal_fence, - OptionalAttr:$affinity + HAL_Fence:$signal_fence ); let results = (outs Variadic:$results ); let assemblyFormat = [{ - (`on` `(` $affinity^ `)`)? `join` `` `(` $sources `:` type($sources) `)` `=` `` `>` $signal_fence `:` type($signal_fence) @@ -338,8 +326,7 @@ def HAL_TensorBarrierOp : HAL_Op<"tensor.barrier", [ let builders = [ OpBuilder<(ins "ValueRange":$sources, - "Value":$signalFence, - "Attribute":$affinity + "Value":$signalFence )>, ]; @@ -1616,6 +1603,41 @@ def OpGroupDeviceOps : OpDocGroup { let opDocGroup = OpGroupDeviceOps in { +def HAL_DeviceResolveOp : HAL_PureOp<"device.resolve", [ + DeclareOpInterfaceMethods, +]> { + let summary = [{resolves device handles based on affinity}]; + let description = [{ + Examples: + ``` + // Returns a HAL device. + = hal.device.resolve on(#something) : !hal.device + // Returns a HAL device, allocator, and (optional) queue affinity. + = hal.device.resolve on(#something) : !hal.device, !hal.allocator, i64 + // Returns a HAL allocator and (optional) queue affinity. + = hal.device.resolve on(#something) : !hal.allocator, i64 + // Returns "any" device. Should only be used as a fallback. + = hal.device.resolve : !hal.device + ``` + }]; + + let arguments = (ins + OptionalAttr:$affinity + ); + let results = (outs + Variadic>:$results + ); + + let assemblyFormat = [{ + (`on` `(` $affinity^ `)`)? + attr-dict `:` type($results) + }]; +} + def HAL_DeviceAllocatorOp : HAL_PureOp<"device.allocator", [ DeclareOpInterfaceMethods, ]> { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp index 06d236699997..eedb427a3f15 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp @@ -49,6 +49,7 @@ struct ConvertToHALPass : public IREE::HAL::impl::ConvertToHALPassBase { void runOnOperation() override { auto *context = &getContext(); + auto moduleOp = getOperation(); // Gather all interfaces from registered dialects. // These will perform the tensor->buffer mapping for their ops. @@ -64,8 +65,7 @@ struct ConvertToHALPass HALTypeConverter typeConverter(conversionInterfaces); HALConversionTarget conversionTarget(context, typeConverter); - RewritePatternSet patterns(&getContext()); - + RewritePatternSet patterns(context); populateHALToHALPatterns(context, conversionTarget, typeConverter, patterns); populateUtilToHALPatterns(context, conversionTarget, typeConverter, @@ -84,13 +84,14 @@ struct ConvertToHALPass // NOTE: we allow ops that we don't know about to allow custom dialects // that don't need anything HAL-specific to pass through. - if (failed(applyPartialConversion(getOperation(), conversionTarget, + if (failed(applyPartialConversion(moduleOp, conversionTarget, std::move(patterns)))) { return signalPassFailure(); } // Cleanup conversion attributes used for spooky action at a distance. - for (auto executableOp : getOperation().getOps()) { + moduleOp->removeAttr("stream.affinity.default"); + for (auto executableOp : moduleOp.getOps()) { for (auto variantOp : executableOp.getOps()) { for (auto exportOp : diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp index 6e5f110b68f3..7ad1f4ae959e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp @@ -165,12 +165,30 @@ static DispatchParamsMap gatherDispatchParams(mlir::ModuleOp moduleOp, return map; } +static std::pair +getDeviceAndQueueAffinity(Location loc, IREE::Stream::AffinityAttr affinityAttr, + OpBuilder &builder) { + if (auto deviceAffinityAttr = + dyn_cast_if_present(affinityAttr)) { + auto resolveOp = builder.create( + loc, + TypeRange{ + builder.getType(), + builder.getI64Type(), + }, + deviceAffinityAttr); + return std::make_pair(resolveOp.getResult(0), resolveOp.getResult(1)); + } + auto device = IREE::HAL::DeviceType::resolveAny(loc, builder); + auto queueAffinity = builder.create(loc, -1, 64); + return std::make_pair(device, queueAffinity); +} + // Appends a global hal.buffer initialized to the size required for all // of the bindings in |dispatchParams| (plus alignment). -static IREE::Util::GlobalOp -appendGlobalBuffer(Location loc, StringRef baseName, - const DispatchParams &dispatchParams, - OpBuilder &moduleBuilder) { +static IREE::Util::GlobalOp appendGlobalBuffer( + Location loc, StringRef baseName, const DispatchParams &dispatchParams, + IREE::Stream::AffinityAttr affinityAttr, OpBuilder &moduleBuilder) { // Create a global to hold the HAL buffer. auto globalOp = moduleBuilder.create( loc, (baseName + "_buffer").str(), @@ -191,12 +209,12 @@ appendGlobalBuffer(Location loc, StringRef baseName, auto initBuilder = OpBuilder::atBlockBegin(initOp.addEntryBlock()); IndexSet indexSet(loc, initBuilder); - // TODO(multi-device): support multiple devices in benchmark generation. - Value device = IREE::HAL::DeviceType::resolveAny(loc, initBuilder); + // Resolve allocator for the benchmark device. + auto [device, queueAffinity] = + getDeviceAndQueueAffinity(loc, affinityAttr, initBuilder); auto allocator = initBuilder.create(loc, device).getResult(); - auto queueAffinity = initBuilder.create(loc, -1, 64); auto memoryTypes = IREE::HAL::MemoryTypeBitfield::DeviceLocal; auto bufferUsage = IREE::HAL::BufferUsageBitfield::Transfer | IREE::HAL::BufferUsageBitfield::DispatchStorage; @@ -234,8 +252,8 @@ static void appendDispatchBenchmark(IREE::Stream::AffinityAttr affinityAttr, } // Add a global variable holding an initialized buffer for the dispatch IO. - auto bufferGlobalOp = - appendGlobalBuffer(loc, baseName, dispatchParams, moduleBuilder); + auto bufferGlobalOp = appendGlobalBuffer(loc, baseName, dispatchParams, + affinityAttr, moduleBuilder); // Create an exported benchmark function that runs the dispatches. auto funcType = @@ -261,10 +279,9 @@ static void appendDispatchBenchmark(IREE::Stream::AffinityAttr affinityAttr, auto batchSizeArg = funcBuilder.create( loc, funcBuilder.getIndexType(), entryBlock->getArgument(0)); - // TODO(multi-device): support multiple devices in benchmark generation. - // For now we should just use the affinityAttr to resolve the device. - Value device = IREE::HAL::DeviceType::resolveAny(loc, funcBuilder); - Value queueAffinity = funcBuilder.create(loc, -1, 64); + // Resolve device for this particular benchmark. + auto [device, queueAffinity] = + getDeviceAndQueueAffinity(loc, affinityAttr, funcBuilder); // Create and begin command buffer. // TODO(benvanik): reuse the command buffer (initialize once and store). @@ -423,8 +440,9 @@ buildBenchmarkModule(IREE::HAL::ExecutableOp sourceExecutableOp, // would be to generate one module per device dispatches are made on such // that users can isolate to individual devices. For now we just deal with // it. - for (auto globalOp : deviceAnalysis.getDeviceGlobals()) + for (auto globalOp : deviceAnalysis.getDeviceGlobals()) { moduleBuilder.clone(*globalOp.getOperation()); + } // Clone the executable variant into the new module. auto executableOp = moduleBuilder.create( diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp index e1437b2cf77d..221f7547d42c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp @@ -600,6 +600,18 @@ struct MaterializeInterfacesPass return signalPassFailure(); } + // If no devices were defined and there are dispatches in the program then + // error out. This provides a better error message than if we were to allow + // this pass to no-op and then fail during conversion later on. + if (layoutAnalysis.hasDispatches() && + deviceAnalysis.getDeviceGlobals().empty()) { + mlir::emitError(moduleOp.getLoc()) + << "no HAL devices defined in the module; use the module-level " + "hal.device.targets attribute, the --iree-hal-target-device= " + "flag, or provide inputs with global !hal.devices defined"; + return signalPassFailure(); + } + // Gather the required executable targets per executable and dispatch site. auto requiredExecutableTargets = buildRequiredExecutableTargetsMap( moduleOp, deviceAnalysis, layoutAnalysis); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp index 67bf38359f7f..5d70f1b1c4ff 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTargetDevices.cpp @@ -142,36 +142,32 @@ static FailureOr createDeviceGlobals(mlir::ModuleOp moduleOp, // have one set. static void assignDefaultDeviceAffinity(mlir::ModuleOp moduleOp, FlatSymbolRefAttr defaultDeviceRef) { - Builder builder(moduleOp); - auto affinityName = builder.getStringAttr("stream.affinity"); - auto affinityAttr = builder.getAttr( - defaultDeviceRef, /*queue_mask=*/-1ll); - - // TODO(benvanik): make this an interface that can be registered on types. - auto isAnnotatableType = [](Type type) { - return isa(type) || isa(type); - }; - for (auto &op : moduleOp.getOps()) { - bool shouldAnnotate = true; - if (auto globalOp = dyn_cast(op)) { - if (!isAnnotatableType(globalOp.getGlobalType())) { - shouldAnnotate = false; + auto affinityAttr = IREE::HAL::DeviceAffinityAttr::get( + moduleOp.getContext(), defaultDeviceRef, /*queue_mask=*/-1ll); + + // Default on the module that applies to any ops that don't otherwise have a + // placement. Ideally we never need this but some programs may take/return no + // tensors or have tensors come from unattributed containers (lists/dicts). + moduleOp->setAttr("stream.affinity.default", affinityAttr); + + // Set all arg/results to route through the default device unless they've + // already been assigned. + auto affinityName = StringAttr::get(moduleOp.getContext(), "stream.affinity"); + for (auto funcOp : moduleOp.getOps()) { + if (funcOp.isPublic()) { + for (auto arg : funcOp.getArguments()) { + if (isa(arg.getType())) { + if (!funcOp.getArgAttr(arg.getArgNumber(), affinityName)) { + funcOp.setArgAttr(arg.getArgNumber(), affinityName, affinityAttr); + } + } } - } else if (op.hasTrait()) { - // Symbol table ops can't reference parent symbols properly. - shouldAnnotate = false; - } - if (!shouldAnnotate) { - continue; // skip op - } - - if (auto affinityOp = dyn_cast(op)) { - if (!affinityOp.getAffinityAttr()) { - affinityOp.setAffinityAttr(affinityAttr); - } - } else { - if (!op.hasAttr(affinityName)) { - op.setAttr(affinityName, affinityAttr); + for (auto result : llvm::enumerate(funcOp.getResultTypes())) { + if (isa(result.value())) { + if (!funcOp.getResultAttr(result.index(), affinityName)) { + funcOp.setResultAttr(result.index(), affinityName, affinityAttr); + } + } } } } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index 773ed925097f..54a87b01020a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -198,11 +198,17 @@ void buildHALDeviceAssignmentPassPipeline( passManager.addPass(IREE::HAL::createAssignTargetDevicesPass( {assignmentOptions.targetDevices})); } + + // Create globals for each device (if needed). passManager.addPass(IREE::HAL::createMaterializeTargetDevicesPass( {assignmentOptions.defaultDevice})); + + // Resolve #hal.device.promise and #hal.device.alias attributes. passManager.addPass(IREE::HAL::createResolveDevicePromisesPass()); passManager.addPass( IREE::HAL::createResolveDeviceAliasesPass({&targetRegistry})); + + // Verify devices are valid. passManager.addPass(IREE::HAL::createVerifyDevicesPass({&targetRegistry})); } @@ -222,6 +228,9 @@ void buildHALConfigurationPassPipeline(OpPassManager &passManager, // and initial interface analysis (we rely on CSE and such having been run). addCleanupPatterns(passManager); + // Verify devices are valid. + passManager.addPass(IREE::HAL::createVerifyDevicesPass({&targetRegistry})); + //---------------------------------------------------------------------------- // Device-specific interface materialization //---------------------------------------------------------------------------- diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp index e1ca6247996d..6ac9cc806152 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/VerifyDevices.cpp @@ -127,13 +127,38 @@ struct VerifyDevicesPass return signalPassFailure(); } - // Must have at least one device specified. - if (deviceAnalysis.getDeviceGlobals().empty()) { + // Devices are only required if we have dialects we may lower into device + // code. For now checking for tensor types is probably sufficient though we + // may want a pluggable way to decide this (e.g. dialect/type/op + // interfaces). + auto isTensor = [](Type type) { return isa(type); }; + bool anyTensors = false; + for (auto &op : moduleOp.getOps()) { + if (op.hasTrait()) { + continue; // ignore executables + } + op.walk([&](Operation *childOp) { + if (llvm::any_of(childOp->getOperandTypes(), isTensor) || + llvm::any_of(childOp->getResultTypes(), isTensor)) { + anyTensors = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + } + // TODO(multi-device): the logic above is insufficient; we only need devices + // if the program will end up requiring them but we don't know that here. + // We have to wait until we've lowered to the point where we do require a + // device _and_ we actually want one (aren't compiling a non-HAL program). + // We could probably have an op interface, better output from the pass that + // requires the devices, etc. For now we error out in HAL conversion when we + // try to resolve devices. + if (false && anyTensors && deviceAnalysis.getDeviceGlobals().empty()) { auto diagnostic = moduleOp.emitError(); diagnostic << "no HAL devices defined in the module; use the module-level " "hal.device.targets attribute, the --iree-hal-target-device= " - "flags, or provide inputs with global !hal.devices defined; "; + "flag, or provide inputs with global !hal.devices defined; "; printAvailable(diagnostic, *targetRegistry.value); return signalPassFailure(); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir index 61926d34a1d7..adce11f9c9c9 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir @@ -4,7 +4,7 @@ // RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=device_a[0],device_a[1]})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ORDINALS // RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=#hal.device.target<"local">})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ATTR // RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=#hal.device.alias<"device_a">})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-ALIAS -// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices="device_a,#hal.device.alias<"device_b">"})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-SELECT +// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices={"device_a,#hal.device.alias<"device_b">"}})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-SELECT // RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=device_a=#hal.device.alias<"device_a">,"device_bc=device_b,#hal.device.alias<"device_c">"})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-SELECT-MULTI // CHECK: module diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir index 11b39183bde5..6360abe5cb05 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_target_devices.mlir @@ -33,6 +33,7 @@ module @module { // CHECK: module @module // CHECK-NOT: hal.device.targets +// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@__device_0> module @module attributes { hal.device.targets = [ #hal.device.select<[#device_a, #device_b]> : !hal.device @@ -42,18 +43,6 @@ module @module attributes { // CHECK-SAME: #[[DEVICE_A]], // CHECK-SAME: #[[DEVICE_B]] // CHECK-SAME: ]> : !hal.device - - // CHECK: util.global private @tensor_global - // CHECK-SAME: stream.affinity = #hal.device.affinity<@__device_0> - util.global private @tensor_global : tensor<4xf32> - - // CHECK: util.global private @primitive_global - // CHECK-NOT: stream.affinity - util.global private @primitive_global : i32 - - // CHECK: util.func private @func - // CHECK-SAME: stream.affinity = #hal.device.affinity<@__device_0> - util.func private @func() -> () } // ----- @@ -69,6 +58,7 @@ module @module attributes { // CHECK: module @module // CHECK-NOT: hal.device.targets +// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@device_a> module @module attributes { hal.device.targets = { device_a = #device_a, @@ -77,10 +67,6 @@ module @module attributes { } { // CHECK: util.global private @device_a = #[[DEVICE_A]] // CHECK: util.global private @device_bc = #hal.device.select<[#[[DEVICE_B]], #[[DEVICE_C]]]> - - // CHECK: util.global private @tensor_global - // CHECK-SAME: stream.affinity = #hal.device.affinity<@device_a> - util.global private @tensor_global : tensor<4xf32> } // ----- @@ -94,6 +80,7 @@ module @module attributes { // CHECK: module @module // CHECK-NOT: hal.device.targets +// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@device_b> module @module attributes { hal.device.targets = { device_a = #device_a, @@ -103,10 +90,6 @@ module @module attributes { } { // CHECK: util.global private @device_a // CHECK: util.global private @device_b - - // CHECK: util.global private @tensor_global - // CHECK-SAME: stream.affinity = #hal.device.affinity<@device_b> - util.global private @tensor_global : tensor<4xf32> } // ----- @@ -120,6 +103,7 @@ module @module attributes { // CHECK: module @module // CHECK-NOT: hal.device.targets +// CHECK-SAME: stream.affinity.default = #hal.device.affinity<@__device_1> module @module attributes { hal.device.targets = [ #device_a, @@ -129,9 +113,4 @@ module @module attributes { } { // CHECK: util.global private @__device_0 // CHECK: util.global private @__device_1 - - // CHECK: util.global private @tensor_global - // CHECK-SAME: stream.affinity = #hal.device.affinity<@__device_1> - util.global private @tensor_global : tensor<4xf32> } - diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir index 4511a0becdcb..b4e226418fbb 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/verify_devices.mlir @@ -1,12 +1,27 @@ // RUN: iree-opt --split-input-file --iree-hal-verify-devices %s --mlir-print-local-scope --verify-diagnostics | FileCheck %s -// expected-error@+1 {{no HAL devices defined in the module}} +// Tests that modules without tensors don't need devices. + module @module { + // CHECK: util.func private @func util.func private @func() -> () } // ----- +// TODO(multi-device): find a way to verify that devices exist if they need to. +// Currently the check is disabled as it's difficult to tell if a device will be +// needed by the time we get to the HAL layer: plugins may absorb things, etc. +// NO-expected-errorx@+1 {{no HAL devices defined in the module}} +module @module { + util.func private @func() -> () { + arith.constant dense<1.0> : tensor<4xf32> + util.return + } +} + +// ----- + module @module { // expected-error@+1 {{unregistered target device "__unregistered__"}} util.global private @device = #hal.device.target<"__unregistered__"> : !hal.device diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel index fbc0e51d4463..10d1456b0e37 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/BUILD.bazel @@ -22,6 +22,7 @@ iree_compiler_cc_library( ], deps = [ "//compiler/src/iree/compiler/Dialect/Flow/IR", + "//compiler/src/iree/compiler/Dialect/Stream/Analysis", "//compiler/src/iree/compiler/Dialect/Stream/IR", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt index 05bbb79469cf..cc472aa01ec6 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/CMakeLists.txt @@ -24,6 +24,7 @@ iree_cc_library( MLIRTransformUtils MLIRTransforms iree::compiler::Dialect::Flow::IR + iree::compiler::Dialect::Stream::Analysis iree::compiler::Dialect::Stream::IR PUBLIC ) diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp index bdc5aafc6197..31d61516e3eb 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp @@ -35,41 +35,42 @@ static Value buildResultSizeOf(Location loc, Value tensorValue, } struct ConvertTensorConstantOp - : public OpConversionPattern { + : public AffinityOpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorConstantOp constantOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorConstantOp constantOp, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { // Capture the tensor constant strongly typed with constant lifetime. - Type constantType = IREE::Stream::ResourceType::get( - getContext(), IREE::Stream::Lifetime::Constant); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp); + auto constantType = rewriter.getType( + IREE::Stream::Lifetime::Constant); auto newOp = rewriter.create( constantOp.getLoc(), constantType, convertAttributeToStream(constantOp.getValue()), - TypeAttr::get(constantOp.getType()), ValueRange{}, affinityAttr); + TypeAttr::get(constantOp.getType()), ValueRange{}, + executionAffinityAttr); // Transfer to unknown lifetime. - Type unknownType = IREE::Stream::ResourceType::get(getContext()); + auto unknownType = rewriter.getType(); auto constantSize = rewriter.createOrFold( constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult()); rewriter.replaceOpWithNewOp( constantOp, unknownType, newOp.getResult(), constantSize, constantSize, - /*source_affinity=*/affinityAttr, - /*result_affinity=*/affinityAttr); + /*source_affinity=*/executionAffinityAttr, + /*result_affinity=*/executionAffinityAttr); return success(); } }; struct ConvertTensorDynamicConstantOp - : public OpConversionPattern { + : public AffinityOpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorDynamicConstantOp constantOp, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorDynamicConstantOp constantOp, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto attrType = dyn_cast(constantOp.getValue().getType()); if (!attrType) return failure(); @@ -91,22 +92,21 @@ struct ConvertTensorDynamicConstantOp } // Capture the tensor constant strongly typed with constant lifetime. - Type constantType = IREE::Stream::ResourceType::get( - getContext(), IREE::Stream::Lifetime::Constant); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp); + auto constantType = rewriter.getType( + IREE::Stream::Lifetime::Constant); auto newOp = rewriter.create( constantOp.getLoc(), constantType, convertAttributeToStream(constantOp.getValue()), - TypeAttr::get(resultType), dynamicDims, affinityAttr); + TypeAttr::get(resultType), dynamicDims, executionAffinityAttr); // Transfer to unknown lifetime. - Type unknownType = IREE::Stream::ResourceType::get(getContext()); + auto unknownType = rewriter.getType(); auto constantSize = rewriter.createOrFold( constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult()); rewriter.replaceOpWithNewOp( constantOp, unknownType, newOp.getResult(), constantSize, constantSize, - /*source_affinity=*/affinityAttr, - /*result_affinity=*/affinityAttr); + /*source_affinity=*/executionAffinityAttr, + /*result_affinity=*/executionAffinityAttr); return success(); } }; @@ -114,157 +114,169 @@ struct ConvertTensorDynamicConstantOp // Reshapes and bitcasts become clones here to preserve shape/element type // information (which may become actual transfers depending on source/target // shape) - they'll be elided if not needed. +// +// NOTE: we transfer to the target before cloning. This may not be optimal +// as the clone may otherwise have been able to be elided on the producer +// side but we leave that for future copy elision to determine. template -struct ConvertTensorCastLikeOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ConvertTensorCastLikeOp + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern< + CastOpTy>::AffinityAwareConversionPattern; LogicalResult matchAndRewrite(CastOpTy op, typename CastOpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto unknownType = rewriter.getType(); - auto source = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto resultAffinityAttr = this->lookupResultAffinity(op.getResult()); + auto source = this->transferTensorOperand(op.getLoc(), op.getSource(), + adaptor.getSource(), + resultAffinityAttr, rewriter); auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), - affinityAttr, rewriter); + resultAffinityAttr, rewriter); + auto unknownType = rewriter.getType(); rewriter.replaceOpWithNewOp( op, unknownType, source.resource, op.getSource().getType(), op.getSourceDims(), source.resourceSize, op.getResult().getType(), - adaptor.getResultDims(), resultSize, affinityAttr); + adaptor.getResultDims(), resultSize, resultAffinityAttr); return success(); } }; struct ConvertTensorAllocaOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorAllocaOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type unknownType = IREE::Stream::ResourceType::get(getContext()); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorAllocaOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), - affinityAttr, rewriter); + executionAffinityAttr, rewriter); + auto unknownType = rewriter.getType(); rewriter.replaceOpWithNewOp( - op, unknownType, resultSize, affinityAttr); + op, unknownType, resultSize, executionAffinityAttr); return success(); } }; struct ConvertTensorEmptyOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorEmptyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type unknownType = IREE::Stream::ResourceType::get(getContext()); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorEmptyOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), - affinityAttr, rewriter); + executionAffinityAttr, rewriter); + auto unknownType = rewriter.getType(); rewriter.replaceOpWithNewOp( op, unknownType, op.getResult().getType(), adaptor.getResultDims(), - resultSize, affinityAttr); + resultSize, executionAffinityAttr); return success(); } }; struct ConvertTensorSplatOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorSplatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto unknownType = rewriter.getType(); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorSplatOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), - affinityAttr, rewriter); + executionAffinityAttr, rewriter); + auto unknownType = rewriter.getType(); rewriter.replaceOpWithNewOp( op, unknownType, adaptor.getValue(), op.getResult().getType(), - adaptor.getResultDims(), resultSize, affinityAttr); + adaptor.getResultDims(), resultSize, executionAffinityAttr); return success(); } }; struct ConvertTensorCloneOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorCloneOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorCloneOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + auto operand = transferTensorOperand(op.getLoc(), op.getOperand(), + adaptor.getOperand(), + executionAffinityAttr, rewriter); auto unknownType = rewriter.getType(); - auto operand = - consumeTensorOperand(op.getLoc(), adaptor.getOperand(), rewriter); rewriter.replaceOpWithNewOp( op, unknownType, operand.resource, op.getOperand().getType(), op.getArgumentDims(), operand.resourceSize, op.getResult().getType(), - adaptor.getArgumentDims(), operand.resourceSize, - IREE::Stream::AffinityAttr::lookup(op)); + adaptor.getArgumentDims(), operand.resourceSize, executionAffinityAttr); return success(); } }; struct ConvertTensorTransferOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorTransferOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto targetAffinityAttr = - dyn_cast(adaptor.getTarget()); - if (!targetAffinityAttr) + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorTransferOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + if (!executionAffinityAttr) { return rewriter.notifyMatchFailure(op, "invalid stream affinity attr"); + } + auto operand = resolveTensorOperand(op.getLoc(), op.getOperand(), + adaptor.getOperand(), rewriter); auto unknownType = rewriter.getType(); - auto operand = - consumeTensorOperand(op.getLoc(), adaptor.getOperand(), rewriter); rewriter.replaceOpWithNewOp( op, unknownType, operand.resource, operand.resourceSize, operand.resourceSize, - /*source_affinity=*/IREE::Stream::AffinityAttr{}, targetAffinityAttr); + /*source_affinity=*/operand.affinity, + /*result_affinity=*/executionAffinityAttr); return success(); } }; struct ConvertTensorSliceOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorSliceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto unknownType = rewriter.getType(); + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorSliceOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto source = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); auto resultSize = buildResultSizeOf(op.getLoc(), op.getResult(), op.getResultDims(), - affinityAttr, rewriter); + executionAffinityAttr, rewriter); + auto unknownType = rewriter.getType(); rewriter.replaceOpWithNewOp( op, unknownType, source.resource, op.getSource().getType(), op.getSourceDims(), source.resourceSize, adaptor.getStartIndices(), adaptor.getLengths(), op.getResult().getType(), adaptor.getResultDims(), - resultSize, affinityAttr); + resultSize, executionAffinityAttr); return success(); } }; struct ConvertTensorUpdateOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::TensorUpdateOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto update = - consumeTensorOperand(op.getLoc(), adaptor.getUpdate(), rewriter); + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::TensorUpdateOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto target = - consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(), + executionAffinityAttr, rewriter); + auto update = + transferTensorOperand(op.getLoc(), op.getUpdate(), adaptor.getUpdate(), + executionAffinityAttr, rewriter); rewriter.replaceOpWithNewOp( op, target.resource.getType(), target.resource, op.getTarget().getType(), adaptor.getTargetDims(), target.resourceSize, adaptor.getStartIndices(), update.resource, op.getUpdate().getType(), - op.getUpdateDims(), update.resourceSize, - IREE::Stream::AffinityAttr::lookup(op)); + op.getUpdateDims(), update.resourceSize, executionAffinityAttr); return success(); } }; @@ -281,14 +293,13 @@ static bool isScalarTensor(RankedTensorType type) { } struct ConvertTensorLoadOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; LogicalResult matchAndRewrite(IREE::Flow::TensorLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto resultType = getTypeConverter()->convertType(op.getResult().getType()); - auto source = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + auto source = resolveTensorOperand(op.getLoc(), op.getSource(), + adaptor.getSource(), rewriter); // If the source is not a staging resource then we need to transfer it to // a staging resource. We slice out just what is being loaded so that we @@ -299,6 +310,7 @@ struct ConvertTensorLoadOp // If already a staging resource then we can fast-path load the value. auto stagingType = rewriter.getType( IREE::Stream::Lifetime::Staging); + auto resultType = getTypeConverter()->convertType(op.getResult().getType()); if (source.resource.getType() == stagingType) { rewriter.replaceOpWithNewOp( op, resultType, source.resource, op.getSource().getType(), @@ -306,16 +318,14 @@ struct ConvertTensorLoadOp return success(); } - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); - // Scalar tensors get transferred without slicing. auto sourceEncoding = op.getSource().getType(); if (isScalarTensor(sourceEncoding)) { auto transferOp = rewriter.create( op.getLoc(), stagingType, source.resource, source.resourceSize, source.resourceSize, - /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op), - /*result_affinity=*/IREE::Stream::AffinityAttr::lookup(op)); + /*source_affinity=*/source.affinity, + /*result_affinity=*/source.affinity); rewriter.replaceOpWithNewOp( op, resultType, transferOp.getResult(), sourceEncoding, adaptor.getSourceDims(), transferOp.getResultSize(), @@ -341,16 +351,17 @@ struct ConvertTensorLoadOp RankedTensorType::get(resultDims, sourceEncoding.getElementType(), sourceEncoding.getEncoding()); Value resultSize = rewriter.create( - op.getLoc(), resultEncoding, ValueRange{}, affinityAttr); + op.getLoc(), resultEncoding, ValueRange{}, source.affinity); auto sliceOp = rewriter.create( op.getLoc(), source.resource.getType(), source.resource, sourceEncoding, adaptor.getSourceDims(), source.resourceSize, sliceIndices, - sliceLengths, resultEncoding, ValueRange{}, resultSize, affinityAttr); + sliceLengths, resultEncoding, ValueRange{}, resultSize, + source.affinity); auto transferOp = rewriter.create( op.getLoc(), stagingType, sliceOp.getResult(), sliceOp.getResultSize(), sliceOp.getResultSize(), - /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op), - /*result_affinity=*/IREE::Stream::AffinityAttr::lookup(op)); + /*source_affinity=*/source.affinity, + /*result_affinity=*/source.affinity); rewriter.replaceOpWithNewOp( op, resultType, transferOp.getResult(), sliceOp.getResultEncoding(), sliceOp.getResultEncodingDims(), transferOp.getResultSize(), @@ -360,13 +371,13 @@ struct ConvertTensorLoadOp }; struct ConvertTensorStoreOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; LogicalResult matchAndRewrite(IREE::Flow::TensorStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto target = - consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + auto target = resolveTensorOperand(op.getLoc(), op.getTarget(), + adaptor.getTarget(), rewriter); // If the target is a staging resource then we can directly store into it // with a fast-path. Otherwise we need to stage an upload. @@ -380,34 +391,23 @@ struct ConvertTensorStoreOp return success(); } - // Scalar tensors disconnect from the original target. - auto targetEncoding = op.getTarget().getType(); - if (isScalarTensor(targetEncoding)) { - rewriter.replaceOpWithNewOp( - op, target.resource.getType(), adaptor.getValue(), targetEncoding, - adaptor.getTargetDims(), target.resourceSize, - IREE::Stream::AffinityAttr::lookup(op)); - return success(); - } - // Use fill to store the value. // TODO(benvanik): support larger buffer slices (stage + update). IndexSet indexSet(op.getLoc(), rewriter); indexSet.populate(adaptor.getIndices()); - SmallVector lengths; - for (auto index : adaptor.getIndices()) - lengths.push_back(indexSet.get(1)); + SmallVector lengths(adaptor.getIndices().size(), indexSet.get(1)); + auto targetEncoding = op.getTarget().getType(); rewriter.replaceOpWithNewOp( op, target.resource, targetEncoding, adaptor.getTargetDims(), target.resourceSize, adaptor.getIndices(), lengths, adaptor.getValue(), - IREE::Stream::AffinityAttr::lookup(op)); + target.affinity); return success(); } }; struct ConvertTensorTraceOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; LogicalResult matchAndRewrite(IREE::Flow::TensorTraceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -416,8 +416,8 @@ struct ConvertTensorTraceOp SmallVector resourceEncodings; for (auto [tensorOperand, resourceOperand] : llvm::zip_equal(op.getValues(), adaptor.getValues())) { - auto source = - consumeTensorOperand(op.getLoc(), resourceOperand, rewriter); + auto source = resolveTensorOperand(op.getLoc(), tensorOperand, + resourceOperand, rewriter); auto stagingType = rewriter.getType( IREE::Stream::Lifetime::Staging); auto traceSource = source.resource; @@ -425,13 +425,14 @@ struct ConvertTensorTraceOp traceSource = rewriter.create( op.getLoc(), stagingType, source.resource, source.resourceSize, source.resourceSize, - /*source_affinity=*/IREE::Stream::AffinityAttr::lookup(op), - /*result_affinity=*/nullptr); + /*source_affinity=*/source.affinity, + /*result_affinity=*/source.affinity); } resources.push_back(traceSource); resourceSizes.push_back(source.resourceSize); resourceEncodings.push_back(TypeAttr::get(tensorOperand.getType())); } + rewriter.replaceOpWithNewOp( op, adaptor.getKey(), resources, resourceSizes, rewriter.getArrayAttr(resourceEncodings), adaptor.getValueDims()); @@ -440,16 +441,18 @@ struct ConvertTensorTraceOp }; struct ConvertChannelDefaultOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::ChannelDefaultOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::ChannelDefaultOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( - op, /*id=*/Value{}, + op, + /*id=*/Value{}, /*group=*/adaptor.getGroupAttr(), /*rank=*/Value{}, - /*count=*/Value{}, IREE::Stream::AffinityAttr::lookup(op)); + /*count=*/Value{}, executionAffinityAttr); return success(); } }; @@ -491,164 +494,190 @@ struct ConvertChannelCountOp }; struct ConvertAllGatherOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::CollectiveAllGatherOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto shape = llvm::cast(op.getSource().getType()); - auto collectiveAttr = IREE::Stream::CollectiveAttr::get( - op.getContext(), IREE::Stream::CollectiveKind::AllGather, + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::CollectiveAllGatherOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + auto collectiveAttr = rewriter.getAttr( + IREE::Stream::CollectiveKind::AllGather, /*reduction=*/std::nullopt, static_cast(op.getElementType())); auto zeroOffset = rewriter.create(op.getLoc(), 0); auto elementCount = rewriter.create( - op.getLoc(), shape.getNumElements()); + op.getLoc(), op.getType().getNumElements()); auto newTargetCast = - consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(), + executionAffinityAttr, rewriter); auto newSourceCast = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); rewriter.replaceOpWithNewOp( - op, collectiveAttr, newTargetCast.resource, + op, collectiveAttr, + /*target=*/newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, + /*target_length=*/newTargetCast.resourceSize, + /*source=*/newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, - /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, - /*source_length=*/newSourceCast.resourceSize, elementCount, - adaptor.getChannel(), - /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op)); + /*source_offset=*/zeroOffset, + /*source_end=*/newSourceCast.resourceSize, + /*source_length=*/newSourceCast.resourceSize, + /*element_count=*/elementCount, + /*channel=*/adaptor.getChannel(), + /*param=*/mlir::Value(), executionAffinityAttr); return success(); } }; struct ConvertAllReduceOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::CollectiveAllReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto shape = llvm::cast(op.getType()); - auto collectiveAttr = IREE::Stream::CollectiveAttr::get( - op.getContext(), IREE::Stream::CollectiveKind::AllReduce, + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::CollectiveAllReduceOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + auto collectiveAttr = rewriter.getAttr( + IREE::Stream::CollectiveKind::AllReduce, static_cast(op.getReductionOp()), static_cast(op.getElementType())); auto zeroOffset = rewriter.create(op.getLoc(), 0); auto elementCount = rewriter.create( - op.getLoc(), shape.getNumElements()); + op.getLoc(), op.getType().getNumElements()); auto newTargetCast = - consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(), + executionAffinityAttr, rewriter); auto newSourceCast = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); rewriter.replaceOpWithNewOp( - op, collectiveAttr, newTargetCast.resource, + op, collectiveAttr, + /*target=*/newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, + /*target_length=*/newTargetCast.resourceSize, + /*source=*/newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, - /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, - /*source_length=*/newSourceCast.resourceSize, elementCount, - adaptor.getChannel(), - /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op)); + /*source_offset=*/zeroOffset, + /*source_end=*/newSourceCast.resourceSize, + /*source_length=*/newSourceCast.resourceSize, + /*element_count=*/elementCount, + /*channel=*/adaptor.getChannel(), + /*param=*/mlir::Value(), executionAffinityAttr); return success(); } }; struct ConvertAllToAllOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::CollectiveAllToAllOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto shape = llvm::cast(op.getSource().getType()); - auto collectiveAttr = IREE::Stream::CollectiveAttr::get( - op.getContext(), IREE::Stream::CollectiveKind::AllToAll, + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::CollectiveAllToAllOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + auto collectiveAttr = rewriter.getAttr( + IREE::Stream::CollectiveKind::AllToAll, /*reduction=*/std::nullopt, static_cast(op.getElementType())); auto zeroOffset = rewriter.create(op.getLoc(), 0); auto elementCount = rewriter.create( - op.getLoc(), shape.getNumElements()); + op.getLoc(), op.getType().getNumElements()); auto newTargetCast = - consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(), + executionAffinityAttr, rewriter); auto newSourceCast = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); rewriter.replaceOpWithNewOp( - op, collectiveAttr, newTargetCast.resource, + op, collectiveAttr, + /*target=*/newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, + /*target_length=*/newTargetCast.resourceSize, + /*source=*/newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, - /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, - /*source_length=*/newSourceCast.resourceSize, elementCount, - adaptor.getChannel(), - /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op)); + /*source_offset=*/zeroOffset, + /*source_end=*/newSourceCast.resourceSize, + /*source_length=*/newSourceCast.resourceSize, + /*element_count=*/elementCount, + /*channel=*/adaptor.getChannel(), + /*param=*/mlir::Value(), executionAffinityAttr); return success(); } }; -struct ConvertReduceScatterOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::CollectiveReduceScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto shape = llvm::cast(op.getType()); - auto collectiveAttr = IREE::Stream::CollectiveAttr::get( - op.getContext(), IREE::Stream::CollectiveKind::ReduceScatter, +struct ConvertReduceScatterOp : public AffinityOpConversionPattern< + IREE::Flow::CollectiveReduceScatterOp> { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::CollectiveReduceScatterOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + auto collectiveAttr = rewriter.getAttr( + IREE::Stream::CollectiveKind::ReduceScatter, static_cast(op.getReductionOp()), static_cast(op.getElementType())); auto zeroOffset = rewriter.create(op.getLoc(), 0); auto elementCount = rewriter.create( - op.getLoc(), shape.getNumElements()); + op.getLoc(), op.getType().getNumElements()); auto newTargetCast = - consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(), + executionAffinityAttr, rewriter); auto newSourceCast = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); rewriter.replaceOpWithNewOp( - op, collectiveAttr, newTargetCast.resource, + op, collectiveAttr, + /*target=*/newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, + /*target_length=*/newTargetCast.resourceSize, + /*source=*/newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, - /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, - /*source_length=*/newSourceCast.resourceSize, elementCount, - adaptor.getChannel(), - /*param=*/mlir::Value(), IREE::Stream::AffinityAttr::lookup(op)); + /*source_offset=*/zeroOffset, + /*source_end=*/newSourceCast.resourceSize, + /*source_length=*/newSourceCast.resourceSize, + /*element_count=*/elementCount, + /*channel=*/adaptor.getChannel(), + /*param=*/mlir::Value(), executionAffinityAttr); return success(); } }; struct ConvertCollectiveSendRecvOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::CollectiveSendRecvOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto shape = llvm::cast(op.getType()); - auto collectiveAttr = IREE::Stream::CollectiveAttr::get( - op.getContext(), IREE::Stream::CollectiveKind::SendRecv, + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::CollectiveSendRecvOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + auto collectiveAttr = rewriter.getAttr( + IREE::Stream::CollectiveKind::SendRecv, /*reduction=*/std::nullopt, static_cast(op.getElementType())); auto zeroOffset = rewriter.create(op.getLoc(), 0); auto elementCount = rewriter.create( - op.getLoc(), shape.getNumElements()); + op.getLoc(), op.getType().getNumElements()); auto newTargetCast = - consumeTensorOperand(op.getLoc(), adaptor.getTarget(), rewriter); + transferTensorOperand(op.getLoc(), op.getTarget(), adaptor.getTarget(), + executionAffinityAttr, rewriter); auto newSourceCast = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); // Pack send, recv into param. The values are checked to be within the // 16-bit range during lowering to Flow dialect. @@ -665,27 +694,31 @@ struct ConvertCollectiveSendRecvOp auto param = rewriter.create(op.getLoc(), hi, lo); rewriter.replaceOpWithNewOp( - op, collectiveAttr, newTargetCast.resource, + op, collectiveAttr, + /*target=*/newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, + /*target_length=*/newTargetCast.resourceSize, + /*source=*/newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, - /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, - /*source_length=*/newSourceCast.resourceSize, elementCount, - adaptor.getChannel(), - /*param=*/param, IREE::Stream::AffinityAttr::lookup(op)); + /*source_offset=*/zeroOffset, + /*source_end=*/newSourceCast.resourceSize, + /*source_length=*/newSourceCast.resourceSize, + /*element_count=*/elementCount, + /*channel=*/adaptor.getChannel(), + /*param=*/param, executionAffinityAttr); return success(); } }; -struct ConvertDispatchOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::DispatchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); - +struct ConvertDispatchOp + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::DispatchOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { // Zero is going to be used for each operand to start. auto zeroOffset = rewriter.create(op.getLoc(), 0); @@ -700,7 +733,8 @@ struct ConvertDispatchOp : public OpConversionPattern { llvm::zip_equal(op.getArguments(), adaptor.getArguments())) { if (llvm::isa(oldOperand.getType())) { auto newOperandCast = - consumeTensorOperand(op.getLoc(), newOperand, rewriter); + transferTensorOperand(op.getLoc(), oldOperand, newOperand, + executionAffinityAttr, rewriter); newOperand = newOperandCast.resource; dispatchOperandSizes.push_back(newOperandCast.resourceSize); operandSizes.push_back(newOperandCast.resourceSize); @@ -732,9 +766,9 @@ struct ConvertDispatchOp : public OpConversionPattern { } else { auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue( op.getLoc(), result.value(), rewriter); - resultSizes.push_back(buildResultSizeOf(op.getLoc(), result.value(), - resultDynamicDims, affinityAttr, - rewriter)); + resultSizes.push_back( + buildResultSizeOf(op.getLoc(), result.value(), resultDynamicDims, + executionAffinityAttr, rewriter)); resultTypes.push_back(unknownType); } } @@ -743,7 +777,7 @@ struct ConvertDispatchOp : public OpConversionPattern { op, resultTypes, adaptor.getWorkload(), adaptor.getEntryPointsAttr(), dispatchOperands, dispatchOperandSizes, dispatchOperandOffsets, dispatchOperandEnds, dispatchOperandLengths, resultSizes, - adaptor.getTiedOperandsAttr(), affinityAttr); + adaptor.getTiedOperandsAttr(), executionAffinityAttr); newOp->setDialectAttrs(op->getDialectAttrs()); return success(); } @@ -759,8 +793,8 @@ struct ConvertFuncOp : public OpConversionPattern { // Tensors become resources without sizes. The default type converter // adds the size so we bypass that here. We may want to allow the user // to override the lifetime with attributes, too. - return IREE::Stream::ResourceType::get(type.getContext(), - IREE::Stream::Lifetime::Unknown); + return rewriter.getType( + IREE::Stream::Lifetime::Unknown); } return getTypeConverter()->convertType(type); }; @@ -784,13 +818,12 @@ struct ConvertFuncOp : public OpConversionPattern { } }; -struct ConvertCallOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Flow::CallOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); - +struct ConvertCallOp : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::Flow::CallOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { // Zero is going to be used for each operand to start. auto zeroOffset = rewriter.create(op.getLoc(), 0); @@ -805,7 +838,8 @@ struct ConvertCallOp : public OpConversionPattern { llvm::zip_equal(op.getArguments(), adaptor.getArguments())) { if (llvm::isa(oldOperand.getType())) { auto newOperandCast = - consumeTensorOperand(op.getLoc(), newOperand, rewriter); + transferTensorOperand(op.getLoc(), oldOperand, newOperand, + executionAffinityAttr, rewriter); newOperand = newOperandCast.resource; callOperandSizes.push_back(newOperandCast.resourceSize); operandSizes.push_back(newOperandCast.resourceSize); @@ -837,9 +871,9 @@ struct ConvertCallOp : public OpConversionPattern { } else { auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue( op.getLoc(), result.value(), rewriter); - resultSizes.push_back(buildResultSizeOf(op.getLoc(), result.value(), - resultDynamicDims, affinityAttr, - rewriter)); + resultSizes.push_back( + buildResultSizeOf(op.getLoc(), result.value(), resultDynamicDims, + executionAffinityAttr, rewriter)); resultTypes.push_back(unknownType); } } @@ -848,7 +882,7 @@ struct ConvertCallOp : public OpConversionPattern { op, resultTypes, adaptor.getCalleeAttr(), callOperands, callOperandSizes, callOperandOffsets, callOperandEnds, callOperandLengths, resultSizes, adaptor.getTiedOperandsAttr(), - affinityAttr); + executionAffinityAttr); newOp->setDialectAttrs(op->getDialectAttrs()); return success(); } @@ -1065,9 +1099,10 @@ struct ConvertReturnOp : public OpConversionPattern { } // namespace -void populateFlowToStreamConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns) { +void populateFlowToStreamConversionPatterns( + MLIRContext *context, TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns) { patterns .insert, @@ -1075,17 +1110,19 @@ void populateFlowToStreamConversionPatterns(MLIRContext *context, ConvertTensorAllocaOp, ConvertTensorEmptyOp, ConvertTensorSplatOp, ConvertTensorCloneOp, ConvertTensorTransferOp, ConvertTensorSliceOp, ConvertTensorUpdateOp, ConvertTensorLoadOp, - ConvertTensorStoreOp, ConvertTensorTraceOp>(typeConverter, - context); - patterns.insert(typeConverter, - context); + ConvertTensorStoreOp, ConvertTensorTraceOp>( + typeConverter, context, affinityAnalysis); + patterns.insert(typeConverter, context, + affinityAnalysis); + patterns.insert(typeConverter, context); patterns .insert(typeConverter, - context); - patterns.insert(typeConverter, context); - patterns.insert(typeConverter, context); + ConvertAllToAllOp, ConvertCollectiveSendRecvOp>( + typeConverter, context, affinityAnalysis); + patterns.insert(typeConverter, context, affinityAnalysis); + patterns.insert(typeConverter, context); + patterns.insert(typeConverter, context, affinityAnalysis); patterns.insert(typeConverter, context); patterns.insert< ConvertDispatchWorkgroupInfoOp(typeConverter, context); } -void populateFlowToStreamConversionPatterns(MLIRContext *context, - ConversionTarget &conversionTarget, - TypeConverter &typeConverter, - RewritePatternSet &patterns) { +void populateFlowToStreamConversionPatterns( + MLIRContext *context, ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns) { // Disallow all flow ops besides the ones we pass through (today). // We don't have a stream-equivalent of several of the dispatch-level flow // ops as the codegen backends directly touch them and so long as we have both @@ -1111,7 +1149,8 @@ void populateFlowToStreamConversionPatterns(MLIRContext *context, conversionTarget.addLegalOp(); conversionTarget.markOpRecursivelyLegal(); - populateFlowToStreamConversionPatterns(context, typeConverter, patterns); + populateFlowToStreamConversionPatterns(context, typeConverter, + affinityAnalysis, patterns); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h index ad2c95af8284..0379be2c835f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h @@ -11,18 +11,24 @@ #include "mlir/IR/OperationSupport.h" #include "mlir/Transforms/DialectConversion.h" +namespace mlir::iree_compiler::IREE::Stream { +class AffinityAnalysis; +} // namespace mlir::iree_compiler::IREE::Stream + namespace mlir::iree_compiler { // Populates conversion patterns that perform flow->stream conversion. // These patterns ensure that nested types are run through the provided // |typeConverter|. -void populateFlowToStreamConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns); -void populateFlowToStreamConversionPatterns(MLIRContext *context, - ConversionTarget &conversionTarget, - TypeConverter &typeConverter, - RewritePatternSet &patterns); +void populateFlowToStreamConversionPatterns( + MLIRContext *context, TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns); +void populateFlowToStreamConversionPatterns( + MLIRContext *context, ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns); } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir index 410630703f81..063389fb4dfa 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/dispatch_ops.mlir @@ -52,13 +52,15 @@ util.global private @device_b : !hal.device // CHECK-LABEL: @dispatchAffinity // CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM1:.+]]: index, %[[DIM3:.+]]: index) util.func public @dispatchAffinity(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> (tensor, tensor) { + // CHECK: %[[INPUT_A:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%[[INPUT_SIZE]]} // CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor{%[[DIM1]], %[[DIM3]]} - // CHECK: %[[RESULT0:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) + // CHECK: %[[RESULT0:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%[[INPUT_A]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) %0 = flow.dispatch @ex::@entry0(%input) { stream.affinity = #hal.device.affinity<@device_a> } : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor{%dim1, %dim3} + // CHECK: %[[INPUT_B:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_b>) !stream.resource<*>{%[[INPUT_SIZE]]} // CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor{%[[DIM3]], %[[DIM1]]} - // CHECK: %[[RESULT1:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@entry1(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) + // CHECK: %[[RESULT1:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@entry1(%[[INPUT_B]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) %1 = flow.dispatch @ex::@entry1(%input) { stream.affinity = #hal.device.affinity<@device_b> } : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor{%dim3, %dim1} diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir index a755d44ad27b..ee68211df04f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir @@ -231,9 +231,9 @@ util.func public @tensorStore(%target : tensor<2x3xi32>) -> tensor<2x3xi32> { util.func public @tensorStoreScalar(%target : tensor) -> tensor { // CHECK: %[[VALUE:.+]] = arith.constant 9 %value = arith.constant 9 : i32 - // CHECK: %[[SPLAT:.+]] = stream.tensor.splat %[[VALUE]] : i32 -> tensor in !stream.resource<*>{%[[TARGET_SIZE]]} + // CHECK: %[[FILL:.+]] = stream.tensor.fill %[[VALUE]], %[[TARGET]] : i32 -> tensor in %[[TARGET]] as !stream.resource<*>{%[[TARGET_SIZE]]} %0 = flow.tensor.store %value, %target : tensor - // CHECK: util.return %[[SPLAT]] + // CHECK: util.return %[[FILL]] util.return %0 : tensor } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp index 4323473fbb02..76eef8b8e56f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp @@ -21,11 +21,12 @@ namespace { // %1 = stream.tensor.import %0 : !hal.buffer_view -> // tensor<4xf32> in !stream.resource<*> struct ConvertTensorImportOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::HAL::TensorImportOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::HAL::TensorImportOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto sourceType = op.getSource().getType(); auto targetType = op.getTargetEncoding(); if (!llvm::isa(sourceType) && @@ -49,25 +50,23 @@ struct ConvertTensorImportOp } } - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); - // Import (buffer view to stream resource). auto resultType = rewriter.getType( IREE::Stream::Lifetime::External); Value resultSize = rewriter.create( op.getLoc(), rewriter.getIndexType(), TypeAttr::get(op.getTarget().getType()), adaptor.getTargetDims(), - affinityAttr); + executionAffinityAttr); Value resource = rewriter.create( op.getLoc(), resultType, adaptor.getSource(), TypeAttr::get(targetType), - adaptor.getTargetDims(), resultSize, affinityAttr); + adaptor.getTargetDims(), resultSize, executionAffinityAttr); // Await the fence, if needed. When not specified the resource is assumed to // be immediately available. if (auto waitFence = op.getWaitFence()) { Value waitTimepoint = rewriter.create( op.getLoc(), rewriter.getType(), - ValueRange{waitFence}, affinityAttr); + ValueRange{waitFence}, executionAffinityAttr); resource = rewriter .create( op.getLoc(), ValueRange{resource}, @@ -77,8 +76,9 @@ struct ConvertTensorImportOp auto unknownType = rewriter.getType(); rewriter.replaceOpWithNewOp( - op, unknownType, resource, resultSize, resultSize, affinityAttr, - /*target_affinity=*/IREE::Stream::AffinityAttr{}); + op, unknownType, resource, resultSize, resultSize, + /*source_affinity=*/executionAffinityAttr, + /*target_affinity=*/executionAffinityAttr); return success(); } @@ -122,11 +122,12 @@ struct ConvertTensorImportOp // %1 = stream.tensor.export %0 : tensor<4xf32> in !stream.resource<*> -> // !hal.buffer_view struct ConvertTensorExportOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::HAL::TensorExportOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::HAL::TensorExportOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto sourceType = op.getSourceEncoding(); auto targetType = op.getTarget().getType(); if (!llvm::isa(targetType) && @@ -134,9 +135,9 @@ struct ConvertTensorExportOp return rewriter.notifyMatchFailure(op, "unsupported HAL cast conversion"); } - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); auto source = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); // Exporting a produced value - transfer our source value to an externally // usable resource and directly export it. This will cause an allocation. @@ -146,14 +147,14 @@ struct ConvertTensorExportOp if (source.resource.getType() != externalType) { exportSource = rewriter.create( op.getLoc(), externalType, source.resource, source.resourceSize, - source.resourceSize, /*source_affinity=*/IREE::Stream::AffinityAttr{}, - affinityAttr); + source.resourceSize, /*source_affinity=*/source.affinity, + /*target_affinity=*/executionAffinityAttr); } // Export (stream resource to buffer view). rewriter.replaceOpWithNewOp( op, targetType, exportSource, TypeAttr::get(sourceType), - adaptor.getSourceDims(), source.resourceSize, affinityAttr); + adaptor.getSourceDims(), source.resourceSize, executionAffinityAttr); return success(); } }; @@ -170,29 +171,23 @@ struct ConvertTensorExportOp // %update = stream.async.update %0, %storage[...] // %2 = stream.async.slice %update[...] struct ConvertTensorAliasOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::HAL::TensorAliasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + IREE::HAL::TensorAliasOp op, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { auto sourceType = op.getSource().getType(); auto source = - consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); - - // All operations (if any) will happen on the device specified by the alias - // as that indicates the affinity of the storage. - auto affinityAttr = - dyn_cast_if_present(op.getAffinityAttr()); - if (!affinityAttr) { - affinityAttr = IREE::Stream::AffinityAttr::lookup(op); - } + transferTensorOperand(op.getLoc(), op.getSource(), adaptor.getSource(), + executionAffinityAttr, rewriter); // Query the target storage buffer length; we will only populate up to // what is required for the output. Value storageSize = rewriter.create( op.getLoc(), rewriter.getIndexType(), TypeAttr::get(op.getSource().getType()), adaptor.getSourceDims(), - affinityAttr); + executionAffinityAttr); // Import the target storage as a resource that we can use as an update // target. We overwrite the contents and just cast the storage to the @@ -202,7 +197,7 @@ struct ConvertTensorAliasOp auto importOp = rewriter.create( op.getLoc(), externalType, adaptor.getStorage(), TypeAttr::get(sourceType), adaptor.getSourceDims(), storageSize, - affinityAttr); + executionAffinityAttr); // Await the fence, if needed. When not specified the storage is assumed to // be immediately available. @@ -210,7 +205,7 @@ struct ConvertTensorAliasOp if (auto waitFence = op.getWaitFence()) { Value waitTimepoint = rewriter.create( op.getLoc(), rewriter.getType(), - ValueRange{waitFence}, affinityAttr); + ValueRange{waitFence}, executionAffinityAttr); storage = rewriter .create( op.getLoc(), ValueRange{storage}, @@ -223,7 +218,7 @@ struct ConvertTensorAliasOp auto updateOp = rewriter.create( op.getLoc(), externalType, storage, storageSize, zeroOffset, source.resourceSize, source.resource, source.resourceSize, - affinityAttr); + executionAffinityAttr); // Slice out the value from the updated tensor. // This preserves the use-def chain but is almost always elided by aliasing @@ -231,14 +226,14 @@ struct ConvertTensorAliasOp auto sliceOp = rewriter.create( op.getLoc(), externalType, updateOp.getResult(), updateOp.getTargetSize(), zeroOffset, source.resourceSize, - source.resourceSize, affinityAttr); + source.resourceSize, executionAffinityAttr); // Transfer to match original lifetime (if needed). Value result = sliceOp.getResult(); if (source.resource.getType() != result.getType()) { result = rewriter.create( op.getLoc(), source.resource.getType(), result, source.resourceSize, - source.resourceSize, affinityAttr, affinityAttr); + source.resourceSize, executionAffinityAttr, executionAffinityAttr); } rewriter.replaceOp(op, result); @@ -256,28 +251,38 @@ struct ConvertTensorAliasOp // %t01 = stream.timepoint.join max(%t0, %t1) // stream.timepoint.export %t01 => %fence struct ConvertTensorBarrierOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; LogicalResult matchAndRewrite(IREE::HAL::TensorBarrierOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); auto timepointType = rewriter.getType(); + IREE::Stream::AffinityAttr anyAffinityAttr; SmallVector signaledResources; SmallVector signaledTimepoints; - for (auto sourceResource : adaptor.getSources()) { - auto source = consumeTensorOperand(op.getLoc(), sourceResource, rewriter); + for (auto [sourceTensor, sourceResource] : + llvm::zip_equal(op.getSources(), adaptor.getSources())) { + auto source = resolveTensorOperand(op.getLoc(), sourceTensor, + sourceResource, rewriter); auto barrierOp = rewriter.create( sourceResource.getLoc(), source.resource.getType(), timepointType, - source.resource, source.resourceSize, affinityAttr); + source.resource, source.resourceSize, source.affinity); signaledResources.push_back(barrierOp.getResult()); signaledTimepoints.push_back(barrierOp.getResultTimepoint()); + + // When joining from multiple affinities we need to pick one to perform + // the chain. For now we do the affinity of the last tensor with the hope + // that we can perform the final signal on the affinity that is running. + // We should instead probably change this to be set after timepoint + // propagation such that we ensure it happens on the final signal when not + // acting as a join. + anyAffinityAttr = source.affinity; } Value joinedTimepoint = IREE::Stream::TimepointJoinOp::join( op.getLoc(), signaledTimepoints, rewriter); rewriter.create( op.getLoc(), joinedTimepoint, ValueRange{adaptor.getSignalFence()}, - affinityAttr); + anyAffinityAttr); rewriter.replaceOp(op, signaledResources); return success(); } @@ -285,21 +290,27 @@ struct ConvertTensorBarrierOp } // namespace -void populateHALToStreamConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns) { +void populateHALToStreamConversionPatterns( + MLIRContext *context, TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns) { typeConverter.addConversion( [](IREE::HAL::BufferViewType type) { return type; }); - patterns.insert(typeConverter, context); - patterns.insert(typeConverter, context); - patterns.insert(typeConverter, context); - patterns.insert(typeConverter, context); + patterns.insert(typeConverter, context, + affinityAnalysis); + patterns.insert(typeConverter, context, + affinityAnalysis); + patterns.insert(typeConverter, context, + affinityAnalysis); + patterns.insert(typeConverter, context, + affinityAnalysis); } -void populateHALToStreamConversionPatterns(MLIRContext *context, - ConversionTarget &conversionTarget, - TypeConverter &typeConverter, - RewritePatternSet &patterns) { +void populateHALToStreamConversionPatterns( + MLIRContext *context, ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns) { // Allow executables through without modification. conversionTarget.addLegalOp(); conversionTarget.markOpRecursivelyLegal(); @@ -315,7 +326,8 @@ void populateHALToStreamConversionPatterns(MLIRContext *context, typeConverter.isLegal(op.getTarget().getType()); }); - populateHALToStreamConversionPatterns(context, typeConverter, patterns); + populateHALToStreamConversionPatterns(context, typeConverter, + affinityAnalysis, patterns); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h index ed2a3c055305..f3e955d3d4b1 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h @@ -11,18 +11,24 @@ #include "mlir/IR/OperationSupport.h" #include "mlir/Transforms/DialectConversion.h" +namespace mlir::iree_compiler::IREE::Stream { +class AffinityAnalysis; +} // namespace mlir::iree_compiler::IREE::Stream + namespace mlir::iree_compiler { // Populates conversion patterns that perform hal->stream conversion. // These patterns ensure that nested types are run through the provided // |typeConverter|. -void populateHALToStreamConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns); -void populateHALToStreamConversionPatterns(MLIRContext *context, - ConversionTarget &conversionTarget, - TypeConverter &typeConverter, - RewritePatternSet &patterns); +void populateHALToStreamConversionPatterns( + MLIRContext *context, TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns); +void populateHALToStreamConversionPatterns( + MLIRContext *context, ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns); } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp index 6bb26f1644e1..fee06f2df4cb 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h" #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" +#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h" #include "iree/compiler/Dialect/Stream/IR/StreamOps.h" #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" @@ -23,13 +24,59 @@ TypedAttr convertAttributeToStream(TypedAttr attr) { return attr; } +IREE::Stream::AffinityAttr +tryLookupGlobalAffinity(Operation *op, + IREE::Stream::AffinityAnalysis *affinityAnalysis) { + return affinityAnalysis->lookupGlobalAffinity(op); +} + +IREE::Stream::AffinityAttr +tryLookupExecutionAffinity(Operation *op, + IREE::Stream::AffinityAnalysis *affinityAnalysis) { + assert(llvm::isa(op) && + "must be an affinity op"); + return affinityAnalysis->lookupExecutionAffinity(op); +} + +IREE::Stream::AffinityAttr +tryLookupResultAffinity(Value value, + IREE::Stream::AffinityAnalysis *affinityAnalysis) { + return affinityAnalysis->lookupResourceAffinity(value); +} + +static std::pair +resolveTensorOperand(Location loc, Value convertedOperand, OpBuilder &builder) { + auto operandType = convertedOperand.getType(); + if (llvm::isa(operandType)) { + // Prior to https://reviews.llvm.org/D111620 this is the path we'd take; + // the tensor operands would be remapped into their new resource types. + // This is still possible during rewriting if we ourselves produce a new + // resource type, but the automatic materialization will go down the + // unrealized_conversion_cast path below. + return std::make_pair(convertedOperand, + builder.createOrFold( + loc, builder.getIndexType(), convertedOperand)); + } else if (auto castOp = + convertedOperand + .getDefiningOp()) { + // We only have a single tensor type conversion and it expands to (resource, + // size) so that's all we look for here. + assert(castOp.getNumOperands() == 2 && "expected (resource, size)"); + return std::make_pair(castOp.getOperand(0), castOp.getOperand(1)); + } + assert(false && + "unexpected operand; expected either a IREE::Stream::ResourceType or " + "the result of a mlir::UnrealizedConversionCastOp"); + return std::make_pair(Value{}, Value{}); +} + void expandResourceOperand(Location loc, Value operand, SmallVectorImpl &newOperands, OpBuilder &builder) { if (llvm::isa(operand.getType())) { - auto value = consumeTensorOperand(loc, operand, builder); - newOperands.push_back(value.resource); - newOperands.push_back(value.resourceSize); + auto [resource, resourceSize] = resolveTensorOperand(loc, operand, builder); + newOperands.push_back(resource); + newOperands.push_back(resourceSize); } else if (llvm::isa(operand.getType())) { newOperands.push_back(operand); newOperands.push_back( @@ -49,34 +96,28 @@ SmallVector expandResourceOperands(Location loc, ValueRange operands, return expandedOperands; } -ConvertedTensor consumeTensorOperand(Location loc, Value operand, - OpBuilder &builder) { - auto operandType = operand.getType(); - if (llvm::isa(operandType)) { - // Prior to https://reviews.llvm.org/D111620 this is the path we'd take; - // the tensor operands would be remapped into their new resource types. - // This is still possible during rewriting if we ourselves produce a new - // resource type, but the automatic materialization will go down the - // unrealized_conversion_cast path below. - return { - operand, - builder.createOrFold( - loc, builder.getIndexType(), operand), - }; - } else if (auto castOp = - operand.getDefiningOp()) { - // We only have a single tensor type conversion and it expands to (resource, - // size) so that's all we look for here. - assert(castOp.getNumOperands() == 2 && "expected (resource, size)"); - return { - castOp.getOperand(0), - castOp.getOperand(1), - }; +ConvertedTensor resolveTensorOperand( + Location loc, Value originalOperand, Value convertedOperand, + IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder) { + auto [resource, resourceSize] = + resolveTensorOperand(loc, convertedOperand, builder); + auto affinityAttr = affinityAnalysis->lookupResourceAffinity(originalOperand); + return {affinityAttr, resource, resourceSize}; +} + +ConvertedTensor transferTensorOperand( + Location loc, Value originalOperand, Value convertedOperand, + IREE::Stream::AffinityAttr requiredAffinityAttr, + IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder) { + auto [resource, resourceSize] = + resolveTensorOperand(loc, convertedOperand, builder); + auto affinityAttr = affinityAnalysis->lookupResourceAffinity(originalOperand); + if (affinityAttr != requiredAffinityAttr) { + resource = builder.create( + loc, resource.getType(), resource, resourceSize, resourceSize, + affinityAttr, requiredAffinityAttr); } - assert(false && - "unexpected operand; expected either a IREE::Stream::ResourceType or " - "the result of a mlir::UnrealizedConversionCastOp"); - return ConvertedTensor(); + return {requiredAffinityAttr, resource, resourceSize}; } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h index fd9249e0801e..43cfbb073494 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.h @@ -7,37 +7,123 @@ #ifndef IREE_COMPILER_DIALECT_STREAM_CONVERSION_PATTERN_UTILS_H_ #define IREE_COMPILER_DIALECT_STREAM_CONVERSION_PATTERN_UTILS_H_ +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +namespace mlir::iree_compiler::IREE::Stream { +class AffinityAnalysis; +} // namespace mlir::iree_compiler::IREE::Stream + namespace mlir::iree_compiler { // Converts a supported attribute type to the corresponding stream dialect // value. Returns the provided value if it is natively supported. TypedAttr convertAttributeToStream(TypedAttr attr); -void expandResourceOperand(Location loc, Value operand, - SmallVectorImpl &newOperands, - OpBuilder &builder); - -SmallVector expandResourceOperands(Location loc, ValueRange operands, - ConversionPatternRewriter &rewriter); +IREE::Stream::AffinityAttr +tryLookupGlobalAffinity(Operation *op, + IREE::Stream::AffinityAnalysis *affinityAnalysis); +IREE::Stream::AffinityAttr +tryLookupExecutionAffinity(Operation *op, + IREE::Stream::AffinityAnalysis *affinityAnalysis); +IREE::Stream::AffinityAttr +tryLookupResultAffinity(Value value, + IREE::Stream::AffinityAnalysis *affinityAnalysis); -// https://reviews.llvm.org/D111620 broke 1->N type expansion during dialect -// conversion. It inserts unrealized_conversion_casts but then passes the -// illegal source dialect types for pattern operands, meaning that even though -// we say tensors are illegal the patterns get the new remapped values as -// tensors. This, naturally, breaks everything. To work around this we have this -// helper that tries to peek through the unrealized_conversion_casts and get out -// the actual values we expected to see from the conversion (and did before that -// change). struct ConvertedTensor { + // Optional affinity of the resource at the time it is consumed. + // May be nullptr if the affinity could not be determined. + IREE::Stream::AffinityAttr affinity; + // Resource storing the tensor. Value resource; + // Size of the resource in bytes. Value resourceSize; }; -ConvertedTensor consumeTensorOperand(Location loc, Value operand, - OpBuilder &builder); + +void expandResourceOperand(Location loc, Value convertedOperand, + SmallVectorImpl &newOperands, + OpBuilder &builder); +SmallVector expandResourceOperands(Location loc, + ValueRange convertedOperands, + ConversionPatternRewriter &rewriter); + +ConvertedTensor resolveTensorOperand( + Location loc, Value originalOperand, Value convertedOperand, + IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder); +ConvertedTensor transferTensorOperand( + Location loc, Value originalOperand, Value convertedOperand, + IREE::Stream::AffinityAttr requiredAffinityAttr, + IREE::Stream::AffinityAnalysis *affinityAnalysis, OpBuilder &builder); + +template +struct AffinityAwareConversionPattern : public OpConversionPattern { +public: + AffinityAwareConversionPattern( + const TypeConverter &typeConverter, MLIRContext *context, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + affinityAnalysis(affinityAnalysis) {} + + IREE::Stream::AffinityAnalysis *getAffinityAnalysis() const { + return affinityAnalysis; + } + +protected: + ConvertedTensor resolveTensorOperand(Location loc, Value originalOperand, + Value convertedOperand, + OpBuilder &builder) const { + return mlir::iree_compiler::resolveTensorOperand( + loc, originalOperand, convertedOperand, affinityAnalysis, builder); + } + + ConvertedTensor + transferTensorOperand(Location loc, Value originalOperand, + Value convertedOperand, + IREE::Stream::AffinityAttr requiredAffinityAttr, + OpBuilder &builder) const { + return mlir::iree_compiler::transferTensorOperand( + loc, originalOperand, convertedOperand, requiredAffinityAttr, + affinityAnalysis, builder); + } + + IREE::Stream::AffinityAttr lookupResultAffinity(Value originalResult) const { + return mlir::iree_compiler::tryLookupResultAffinity(originalResult, + affinityAnalysis); + } + + IREE::Stream::AffinityAnalysis *affinityAnalysis = nullptr; +}; + +template +struct AffinityOpConversionPattern + : public AffinityAwareConversionPattern { +public: + AffinityOpConversionPattern(const TypeConverter &typeConverter, + MLIRContext *context, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + PatternBenefit benefit = 1) + : AffinityAwareConversionPattern(typeConverter, context, + affinityAnalysis, benefit) {} + +protected: + virtual LogicalResult matchAndRewriteOnAffinity( + OpT op, typename OpConversionPattern::OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const = 0; + +private: + LogicalResult + matchAndRewrite(OpT op, typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override final { + auto executionAffinityAttr = + tryLookupExecutionAffinity(op, this->getAffinityAnalysis()); + return matchAndRewriteOnAffinity(op, adaptor, executionAffinityAttr, + rewriter); + } +}; } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel index 646d55e890f8..38ad9eea6afa 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/BUILD.bazel @@ -15,8 +15,6 @@ package( iree_compiler_cc_library( name = "StandardToStream", srcs = [ - "ConvertConstantOps.cpp", - "ConvertStructuralOps.cpp", "Patterns.cpp", ], hdrs = [ diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt index b910c60c1feb..3def71691e94 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/CMakeLists.txt @@ -16,8 +16,6 @@ iree_cc_library( HDRS "Patterns.h" SRCS - "ConvertConstantOps.cpp" - "ConvertStructuralOps.cpp" "Patterns.cpp" DEPS LLVMSupport diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp deleted file mode 100644 index 5ff99f7f84ab..000000000000 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertConstantOps.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h" -#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h" -#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" -#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Matchers.h" -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir::iree_compiler { - -namespace { - -struct ConvertTensorConstantOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Only handle tensor types - other arith.constant types (like i32) are - // ignored. - if (!llvm::isa(constantOp.getType())) - return failure(); - - Type constantType = IREE::Stream::ResourceType::get( - getContext(), IREE::Stream::Lifetime::Constant); - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(constantOp); - auto newOp = rewriter.create( - constantOp.getLoc(), constantType, - convertAttributeToStream(constantOp.getValue()), - TypeAttr::get(constantOp.getType()), - /*result_encoding_dims=*/ValueRange{}, affinityAttr); - - Type unknownType = IREE::Stream::ResourceType::get(getContext()); - auto constantSize = rewriter.createOrFold( - constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult()); - rewriter.replaceOpWithNewOp( - constantOp, unknownType, newOp.getResult(), constantSize, constantSize, - /*source_affinity=*/affinityAttr, - /*result_affinity=*/affinityAttr); - return success(); - } -}; - -} // namespace - -void populateStandardConstantToStreamPatterns( - MLIRContext *context, ConversionTarget &conversionTarget, - TypeConverter &typeConverter, RewritePatternSet &patterns) { - conversionTarget.addDynamicallyLegalOp( - [](arith::ConstantOp op) { - return !llvm::isa(op.getType()); - }); - - patterns.insert(typeConverter, context); -} - -} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp deleted file mode 100644 index 5b29504cb177..000000000000 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStructuralOps.cpp +++ /dev/null @@ -1,406 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h" -#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h" -#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" -#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Matchers.h" -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir::iree_compiler { - -namespace { - -struct BranchOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::cf::BranchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Expand any resource operands to resource + size. - auto expandedOperands = expandResourceOperands( - op.getLoc(), adaptor.getDestOperands(), rewriter); - rewriter.replaceOpWithNewOp(op, op.getDest(), - expandedOperands); - return success(); - } -}; - -struct CondBranchOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Expand any resource operands to resource + size. - auto trueDestOperands = expandResourceOperands( - op.getLoc(), adaptor.getTrueDestOperands(), rewriter); - auto falseDestOperands = expandResourceOperands( - op.getLoc(), adaptor.getFalseDestOperands(), rewriter); - rewriter.replaceOpWithNewOp( - op, adaptor.getCondition(), op.getTrueDest(), trueDestOperands, - op.getFalseDest(), falseDestOperands); - return success(); - } -}; - -static ValueRange asValueRange(ArrayRef values) { return values; } - -struct SwitchOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::cf::SwitchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Expand any resource operands to resource + size. - auto defaultOperands = expandResourceOperands( - op.getLoc(), adaptor.getDefaultOperands(), rewriter); - auto caseOperands = llvm::to_vector( - llvm::map_range(adaptor.getCaseOperands(), [&](ValueRange operands) { - return expandResourceOperands(op.getLoc(), operands, rewriter); - })); - rewriter.replaceOpWithNewOp( - op, adaptor.getFlag(), op.getDefaultDestination(), defaultOperands, - op.getCaseValuesAttr(), op.getCaseDestinations(), - llvm::to_vector(llvm::map_range(caseOperands, asValueRange))); - return success(); - } -}; - -struct SelectOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::arith::SelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Only handle selects where the operands are tensors (resources). - if (!llvm::isa(op.getTrueValue().getType())) - return failure(); - auto trueOperand = - consumeTensorOperand(op.getLoc(), adaptor.getTrueValue(), rewriter); - auto falseOperand = - consumeTensorOperand(op.getLoc(), adaptor.getFalseValue(), rewriter); - auto resourceSelectOp = rewriter.create( - op.getLoc(), adaptor.getCondition(), trueOperand.resource, - falseOperand.resource); - auto sizeSelectOp = rewriter.create( - op.getLoc(), adaptor.getCondition(), trueOperand.resourceSize, - falseOperand.resourceSize); - rewriter.replaceOpWithNewOp( - op, adaptor.getTrueValue().getType(), - ValueRange{resourceSelectOp.getResult(), sizeSelectOp.getResult()}); - return success(); - } -}; - -struct ScfIfOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::scf::IfOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Expand any resource operands to resource + size. - auto expandedOperands = - expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter); - - // Expand any resource results to resource + size. - SmallVector expandedTypes; - struct Result { - size_t originalIndex; - size_t newIndex; - Type newType; - }; - SmallVector resultMap; - for (auto originalType : llvm::enumerate(op.getResultTypes())) { - SmallVector newTypes; - if (failed(getTypeConverter()->convertType(originalType.value(), - newTypes))) { - return rewriter.notifyMatchFailure(op, - "unable to convert result types"); - } - resultMap.push_back( - Result{originalType.index(), expandedTypes.size(), newTypes.front()}); - expandedTypes.append(newTypes); - } - - // Create a new call that takes the expanded input operands and returns the - // expanded output results. We can't directly replace the original call as - // the result counts differ. - auto ifOp = rewriter.create(op.getLoc(), expandedTypes, - op.getCondition()); - - ifOp.getThenRegion().getBlocks().clear(); - rewriter.inlineRegionBefore(op.getThenRegion(), ifOp.getThenRegion(), - ifOp.getThenRegion().end()); - - ifOp.getElseRegion().getBlocks().clear(); - rewriter.inlineRegionBefore(op.getElseRegion(), ifOp.getElseRegion(), - ifOp.getElseRegion().end()); - - // Tie all resource results together so we end up with 1:1 results with the - // original op. - SmallVector results; - for (auto result : resultMap) { - if (llvm::isa(result.newType)) { - auto oldType = op.getResult(result.originalIndex).getType(); - auto resource = ifOp.getResult(result.newIndex + 0); - auto resourceSize = ifOp.getResult(result.newIndex + 1); - results.push_back(rewriter - .create( - op.getLoc(), TypeRange{oldType}, - ValueRange{resource, resourceSize}) - .getResult(0)); - } else { - results.push_back(ifOp.getResult(result.newIndex)); - } - } - rewriter.replaceOp(op, results); - return success(); - } -}; - -struct ScfForOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::scf::ForOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto &typeConverter = *getTypeConverter(); - // Expand any resource operands to resource + size. - auto expandedOperands = - expandResourceOperands(op.getLoc(), adaptor.getInitArgs(), rewriter); - - // Expand any resource results to resource + size. - SmallVector expandedTypes; - struct Result { - size_t originalIndex; - size_t newIndex; - Type newType; - }; - SmallVector resultMap; - for (auto originalType : llvm::enumerate(op.getResultTypes())) { - SmallVector newTypes; - if (failed(getTypeConverter()->convertType(originalType.value(), - newTypes))) { - return rewriter.notifyMatchFailure(op, - "unable to convert result types"); - } - resultMap.push_back( - Result{originalType.index(), expandedTypes.size(), newTypes.front()}); - expandedTypes.append(newTypes); - } - - auto &block = op.getRegion().front(); - TypeConverter::SignatureConversion newSignature(block.getNumArguments()); - for (auto arg : llvm::enumerate(block.getArgumentTypes())) { - if (failed(typeConverter.convertSignatureArg(arg.index(), arg.value(), - newSignature))) { - return failure(); - } - } - - // Create a new loop that takes the expanded input operands and returns the - // expanded output results. We can't directly replace the original loop as - // the result counts differ. - auto forOp = rewriter.create( - op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), - adaptor.getStep(), expandedOperands); - - // Inline the block and update the block arguments. - rewriter.eraseBlock(forOp.getBody()); - rewriter.inlineRegionBefore(op.getRegion(), forOp.getRegion(), - forOp.getRegion().end()); - if (failed(rewriter.convertRegionTypes(&forOp.getRegion(), typeConverter, - &newSignature))) { - return failure(); - } - - // Tie all resource results together so we end up with 1:1 results with the - // original op. - SmallVector results; - for (auto result : resultMap) { - if (llvm::isa(result.newType)) { - auto oldType = op.getResult(result.originalIndex).getType(); - auto resource = forOp.getResult(result.newIndex + 0); - auto resourceSize = forOp.getResult(result.newIndex + 1); - results.push_back(rewriter - .create( - op.getLoc(), TypeRange{oldType}, - ValueRange{resource, resourceSize}) - .getResult(0)); - } else { - results.push_back(forOp.getResult(result.newIndex)); - } - } - rewriter.replaceOp(op, results); - return success(); - } -}; - -struct ScfWhileOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::scf::WhileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto &typeConverter = *getTypeConverter(); - // Expand any resource operands to resource + size. - auto expandedOperands = - expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter); - - // Expand any resource results to resource + size. - SmallVector expandedTypes; - struct Result { - size_t originalIndex; - size_t newIndex; - Type newType; - }; - SmallVector resultMap; - for (auto originalType : llvm::enumerate(op.getResultTypes())) { - SmallVector newTypes; - if (failed(getTypeConverter()->convertType(originalType.value(), - newTypes))) { - return rewriter.notifyMatchFailure(op, - "unable to convert result types"); - } - resultMap.push_back( - Result{originalType.index(), expandedTypes.size(), newTypes.front()}); - expandedTypes.append(newTypes); - } - - TypeConverter::SignatureConversion newSignature(op.getNumOperands()); - for (auto argType : llvm::enumerate(op.getOperandTypes())) { - if (failed(typeConverter.convertSignatureArg( - argType.index(), argType.value(), newSignature))) { - return failure(); - } - } - - // Create a new call that takes the expanded input operands and returns the - // expanded output results. We can't directly replace the original call as - // the result counts differ. - auto whileOp = rewriter.create( - op.getLoc(), expandedTypes, expandedOperands); - - // Inline the `before` block and update the block arguments. - whileOp.getBefore().getBlocks().clear(); - rewriter.inlineRegionBefore(op.getBefore(), whileOp.getBefore(), - whileOp.getBefore().end()); - if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), typeConverter, - &newSignature))) { - return failure(); - } - - // Inline the `after` block and update the block arguments. - whileOp.getAfter().getBlocks().clear(); - rewriter.inlineRegionBefore(op.getAfter(), whileOp.getAfter(), - whileOp.getAfter().end()); - if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(), typeConverter, - &newSignature))) { - return failure(); - } - - // Tie all resource results together so we end up with 1:1 results with the - // original op. - SmallVector results; - for (auto result : resultMap) { - if (llvm::isa(result.newType)) { - auto oldType = op.getResult(result.originalIndex).getType(); - auto resource = whileOp.getResult(result.newIndex + 0); - auto resourceSize = whileOp.getResult(result.newIndex + 1); - results.push_back(rewriter - .create( - op.getLoc(), TypeRange{oldType}, - ValueRange{resource, resourceSize}) - .getResult(0)); - } else { - results.push_back(whileOp.getResult(result.newIndex)); - } - } - rewriter.replaceOp(op, results); - return success(); - } -}; - -struct ScfConditionOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::scf::ConditionOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Expand any resource operands to resource + size. - auto expandedOperands = - expandResourceOperands(op.getLoc(), adaptor.getArgs(), rewriter); - rewriter.replaceOpWithNewOp( - op, adaptor.getCondition(), expandedOperands); - return success(); - } -}; - -struct ScfYieldOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::scf::YieldOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Expand any resource operands to resource + size. - auto expandedOperands = - expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter); - rewriter.replaceOpWithNewOp(op, expandedOperands); - return success(); - } -}; - -} // namespace - -template -static inline void addGenericLegalOp(ConversionTarget &conversionTarget, - TypeConverter &typeConverter) { - conversionTarget.addDynamicallyLegalOp([&](OpT op) { - return llvm::all_of( - op->getOperandTypes(), - [&typeConverter](Type t) { return typeConverter.isLegal(t); }) && - llvm::all_of(op->getResultTypes(), [&typeConverter](Type t) { - return typeConverter.isLegal(t); - }); - }); -} - -void populateStandardStructuralToStreamPatterns( - MLIRContext *context, ConversionTarget &conversionTarget, - TypeConverter &typeConverter, RewritePatternSet &patterns) { - conversionTarget.addLegalOp(); - - // We need to rewrite certain types on operands/results so use the default - // dynamic legality checker to force any ops using such types to run through - // our patterns. - - addGenericLegalOp(conversionTarget, typeConverter); - addGenericLegalOp(conversionTarget, typeConverter); - addGenericLegalOp(conversionTarget, typeConverter); - patterns - .insert( - typeConverter, context); - - addGenericLegalOp(conversionTarget, typeConverter); - patterns.insert(typeConverter, context); - - addGenericLegalOp(conversionTarget, typeConverter); - addGenericLegalOp(conversionTarget, typeConverter); - addGenericLegalOp(conversionTarget, typeConverter); - addGenericLegalOp(conversionTarget, typeConverter); - addGenericLegalOp(conversionTarget, typeConverter); - patterns - .insert(typeConverter, - context); -} - -} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp index 1725fb0ddb5f..9924fd2edf1c 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp @@ -6,26 +6,419 @@ #include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h" +#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h" #include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" #include "iree/compiler/Dialect/Stream/IR/StreamOps.h" #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir::iree_compiler { -void populateStandardConstantToStreamPatterns( - MLIRContext *context, ConversionTarget &conversionTarget, - TypeConverter &typeConverter, RewritePatternSet &patterns); +namespace { -void populateStandardStructuralToStreamPatterns( - MLIRContext *context, ConversionTarget &conversionTarget, - TypeConverter &typeConverter, RewritePatternSet &patterns); +struct ConvertTensorConstantOp + : public AffinityOpConversionPattern { + using AffinityOpConversionPattern::AffinityOpConversionPattern; + LogicalResult matchAndRewriteOnAffinity( + arith::ConstantOp constantOp, OpAdaptor adaptor, + IREE::Stream::AffinityAttr executionAffinityAttr, + ConversionPatternRewriter &rewriter) const override { + // Only handle tensor types - other arith.constant types (like i32) are + // ignored. + if (!llvm::isa(constantOp.getType())) { + return failure(); + } + + auto constantType = rewriter.getType( + IREE::Stream::Lifetime::Constant); + auto newOp = rewriter.create( + constantOp.getLoc(), constantType, + convertAttributeToStream(constantOp.getValue()), + TypeAttr::get(constantOp.getType()), + /*result_encoding_dims=*/ValueRange{}, executionAffinityAttr); + + auto unknownType = rewriter.getType(); + auto constantSize = rewriter.createOrFold( + constantOp.getLoc(), rewriter.getIndexType(), newOp.getResult()); + rewriter.replaceOpWithNewOp( + constantOp, unknownType, newOp.getResult(), constantSize, constantSize, + /*source_affinity=*/executionAffinityAttr, + /*result_affinity=*/executionAffinityAttr); + return success(); + } +}; + +struct BranchOpConversion + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::cf::BranchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expand any resource operands to resource + size. + auto expandedOperands = expandResourceOperands( + op.getLoc(), adaptor.getDestOperands(), rewriter); + rewriter.replaceOpWithNewOp(op, op.getDest(), + expandedOperands); + return success(); + } +}; + +struct CondBranchOpConversion + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expand any resource operands to resource + size. + auto trueDestOperands = expandResourceOperands( + op.getLoc(), adaptor.getTrueDestOperands(), rewriter); + auto falseDestOperands = expandResourceOperands( + op.getLoc(), adaptor.getFalseDestOperands(), rewriter); + rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), op.getTrueDest(), trueDestOperands, + op.getFalseDest(), falseDestOperands); + return success(); + } +}; + +static ValueRange asValueRange(ArrayRef values) { return values; } + +struct SwitchOpConversion + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::cf::SwitchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expand any resource operands to resource + size. + auto defaultOperands = expandResourceOperands( + op.getLoc(), adaptor.getDefaultOperands(), rewriter); + auto caseOperands = llvm::to_vector( + llvm::map_range(adaptor.getCaseOperands(), [&](ValueRange operands) { + return expandResourceOperands(op.getLoc(), operands, rewriter); + })); + rewriter.replaceOpWithNewOp( + op, adaptor.getFlag(), op.getDefaultDestination(), defaultOperands, + op.getCaseValuesAttr(), op.getCaseDestinations(), + llvm::to_vector(llvm::map_range(caseOperands, asValueRange))); + return success(); + } +}; + +struct SelectOpConversion + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only handle selects where the operands are tensors (resources). + if (!llvm::isa(op.getTrueValue().getType())) + return failure(); + auto trueOperand = resolveTensorOperand(op.getLoc(), op.getTrueValue(), + adaptor.getTrueValue(), rewriter); + auto falseOperand = resolveTensorOperand(op.getLoc(), op.getFalseValue(), + adaptor.getFalseValue(), rewriter); + auto resourceSelectOp = rewriter.create( + op.getLoc(), adaptor.getCondition(), trueOperand.resource, + falseOperand.resource); + auto sizeSelectOp = rewriter.create( + op.getLoc(), adaptor.getCondition(), trueOperand.resourceSize, + falseOperand.resourceSize); + rewriter.replaceOpWithNewOp( + op, adaptor.getTrueValue().getType(), + ValueRange{resourceSelectOp.getResult(), sizeSelectOp.getResult()}); + return success(); + } +}; + +struct ScfIfOpConversion + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::scf::IfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expand any resource results to resource + size. + SmallVector expandedTypes; + struct Result { + size_t originalIndex; + size_t newIndex; + Type newType; + }; + SmallVector resultMap; + for (auto originalType : llvm::enumerate(op.getResultTypes())) { + SmallVector newTypes; + if (failed(getTypeConverter()->convertType(originalType.value(), + newTypes))) { + return rewriter.notifyMatchFailure(op, + "unable to convert result types"); + } + resultMap.push_back( + Result{originalType.index(), expandedTypes.size(), newTypes.front()}); + expandedTypes.append(newTypes); + } + + // Create a new call that takes the expanded input operands and returns the + // expanded output results. We can't directly replace the original call as + // the result counts differ. + auto ifOp = rewriter.create(op.getLoc(), expandedTypes, + op.getCondition()); + + ifOp.getThenRegion().getBlocks().clear(); + rewriter.inlineRegionBefore(op.getThenRegion(), ifOp.getThenRegion(), + ifOp.getThenRegion().end()); + + ifOp.getElseRegion().getBlocks().clear(); + rewriter.inlineRegionBefore(op.getElseRegion(), ifOp.getElseRegion(), + ifOp.getElseRegion().end()); + + // Tie all resource results together so we end up with 1:1 results with the + // original op. + SmallVector results; + for (auto result : resultMap) { + if (llvm::isa(result.newType)) { + auto oldType = op.getResult(result.originalIndex).getType(); + auto resource = ifOp.getResult(result.newIndex + 0); + auto resourceSize = ifOp.getResult(result.newIndex + 1); + results.push_back(rewriter + .create( + op.getLoc(), TypeRange{oldType}, + ValueRange{resource, resourceSize}) + .getResult(0)); + } else { + results.push_back(ifOp.getResult(result.newIndex)); + } + } + rewriter.replaceOp(op, results); + return success(); + } +}; + +struct ScfForOpConversion + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto &typeConverter = *getTypeConverter(); + + // Expand any resource operands to resource + size. + auto expandedOperands = + expandResourceOperands(op.getLoc(), adaptor.getInitArgs(), rewriter); + + // Expand any resource results to resource + size. + SmallVector expandedTypes; + struct Result { + size_t originalIndex; + size_t newIndex; + Type newType; + }; + SmallVector resultMap; + for (auto originalType : llvm::enumerate(op.getResultTypes())) { + SmallVector newTypes; + if (failed(getTypeConverter()->convertType(originalType.value(), + newTypes))) { + return rewriter.notifyMatchFailure(op, + "unable to convert result types"); + } + resultMap.push_back( + Result{originalType.index(), expandedTypes.size(), newTypes.front()}); + expandedTypes.append(newTypes); + } + + auto &block = op.getRegion().front(); + TypeConverter::SignatureConversion newSignature(block.getNumArguments()); + for (auto arg : llvm::enumerate(block.getArgumentTypes())) { + if (failed(typeConverter.convertSignatureArg(arg.index(), arg.value(), + newSignature))) { + return failure(); + } + } + + // Create a new loop that takes the expanded input operands and returns the + // expanded output results. We can't directly replace the original loop as + // the result counts differ. + auto forOp = rewriter.create( + op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), + adaptor.getStep(), expandedOperands); + + // Inline the block and update the block arguments. + rewriter.eraseBlock(forOp.getBody()); + rewriter.inlineRegionBefore(op.getRegion(), forOp.getRegion(), + forOp.getRegion().end()); + if (failed(rewriter.convertRegionTypes(&forOp.getRegion(), typeConverter, + &newSignature))) { + return failure(); + } + + // Tie all resource results together so we end up with 1:1 results with the + // original op. + SmallVector results; + for (auto result : resultMap) { + if (llvm::isa(result.newType)) { + auto oldType = op.getResult(result.originalIndex).getType(); + auto resource = forOp.getResult(result.newIndex + 0); + auto resourceSize = forOp.getResult(result.newIndex + 1); + results.push_back(rewriter + .create( + op.getLoc(), TypeRange{oldType}, + ValueRange{resource, resourceSize}) + .getResult(0)); + } else { + results.push_back(forOp.getResult(result.newIndex)); + } + } + rewriter.replaceOp(op, results); + return success(); + } +}; + +struct ScfWhileOpConversion + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::scf::WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto &typeConverter = *getTypeConverter(); + + // Expand any resource operands to resource + size. + auto expandedOperands = + expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter); + + // Expand any resource results to resource + size. + SmallVector expandedTypes; + struct Result { + size_t originalIndex; + size_t newIndex; + Type newType; + }; + SmallVector resultMap; + for (auto originalType : llvm::enumerate(op.getResultTypes())) { + SmallVector newTypes; + if (failed(getTypeConverter()->convertType(originalType.value(), + newTypes))) { + return rewriter.notifyMatchFailure(op, + "unable to convert result types"); + } + resultMap.push_back( + Result{originalType.index(), expandedTypes.size(), newTypes.front()}); + expandedTypes.append(newTypes); + } + + TypeConverter::SignatureConversion newSignature(op.getNumOperands()); + for (auto argType : llvm::enumerate(op.getOperandTypes())) { + if (failed(typeConverter.convertSignatureArg( + argType.index(), argType.value(), newSignature))) { + return failure(); + } + } + + // Create a new call that takes the expanded input operands and returns the + // expanded output results. We can't directly replace the original call as + // the result counts differ. + auto whileOp = rewriter.create( + op.getLoc(), expandedTypes, expandedOperands); + + // Inline the `before` block and update the block arguments. + whileOp.getBefore().getBlocks().clear(); + rewriter.inlineRegionBefore(op.getBefore(), whileOp.getBefore(), + whileOp.getBefore().end()); + if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), typeConverter, + &newSignature))) { + return failure(); + } + + // Inline the `after` block and update the block arguments. + whileOp.getAfter().getBlocks().clear(); + rewriter.inlineRegionBefore(op.getAfter(), whileOp.getAfter(), + whileOp.getAfter().end()); + if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(), typeConverter, + &newSignature))) { + return failure(); + } + + // Tie all resource results together so we end up with 1:1 results with the + // original op. + SmallVector results; + for (auto result : resultMap) { + if (llvm::isa(result.newType)) { + auto oldType = op.getResult(result.originalIndex).getType(); + auto resource = whileOp.getResult(result.newIndex + 0); + auto resourceSize = whileOp.getResult(result.newIndex + 1); + results.push_back(rewriter + .create( + op.getLoc(), TypeRange{oldType}, + ValueRange{resource, resourceSize}) + .getResult(0)); + } else { + results.push_back(whileOp.getResult(result.newIndex)); + } + } + rewriter.replaceOp(op, results); + return success(); + } +}; + +struct ScfConditionOpConversion + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::scf::ConditionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expand any resource operands to resource + size. + auto expandedOperands = + expandResourceOperands(op.getLoc(), adaptor.getArgs(), rewriter); + rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), expandedOperands); + return success(); + } +}; + +struct ScfYieldOpConversion + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(mlir::scf::YieldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expand any resource operands to resource + size. + auto expandedOperands = + expandResourceOperands(op.getLoc(), adaptor.getOperands(), rewriter); + rewriter.replaceOpWithNewOp(op, expandedOperands); + return success(); + } +}; + +template +static inline void addGenericLegalOp(ConversionTarget &conversionTarget, + TypeConverter &typeConverter) { + conversionTarget.addDynamicallyLegalOp([&](OpT op) { + return llvm::all_of( + op->getOperandTypes(), + [&typeConverter](Type t) { return typeConverter.isLegal(t); }) && + llvm::all_of(op->getResultTypes(), [&typeConverter](Type t) { + return typeConverter.isLegal(t); + }); + }); +} + +} // namespace void populateStandardToStreamConversionPatterns( MLIRContext *context, ConversionTarget &conversionTarget, - TypeConverter &typeConverter, RewritePatternSet &patterns) { + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns) { typeConverter.addConversion([](IndexType type) { return type; }); typeConverter.addConversion([](IntegerType type) { return type; }); typeConverter.addConversion([](FloatType type) { return type; }); @@ -35,10 +428,38 @@ void populateStandardToStreamConversionPatterns( conversionTarget.addIllegalOp(); - populateStandardConstantToStreamPatterns(context, conversionTarget, - typeConverter, patterns); - populateStandardStructuralToStreamPatterns(context, conversionTarget, - typeConverter, patterns); + conversionTarget.addDynamicallyLegalOp( + [](arith::ConstantOp op) { + return !llvm::isa(op.getType()); + }); + patterns.insert(typeConverter, context, + affinityAnalysis); + + conversionTarget.addLegalOp(); + + // We need to rewrite certain types on operands/results so use the default + // dynamic legality checker to force any ops using such types to run through + // our patterns. + + addGenericLegalOp(conversionTarget, typeConverter); + addGenericLegalOp(conversionTarget, typeConverter); + addGenericLegalOp(conversionTarget, typeConverter); + patterns + .insert( + typeConverter, context, affinityAnalysis); + + addGenericLegalOp(conversionTarget, typeConverter); + patterns.insert(typeConverter, context, affinityAnalysis); + + addGenericLegalOp(conversionTarget, typeConverter); + addGenericLegalOp(conversionTarget, typeConverter); + addGenericLegalOp(conversionTarget, typeConverter); + addGenericLegalOp(conversionTarget, typeConverter); + addGenericLegalOp(conversionTarget, typeConverter); + patterns + .insert( + typeConverter, context, affinityAnalysis); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h index 3314dcfb3461..112e60270feb 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.h @@ -10,6 +10,10 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +namespace mlir::iree_compiler::IREE::Stream { +class AffinityAnalysis; +} // namespace mlir::iree_compiler::IREE::Stream + namespace mlir::iree_compiler { // Populates conversion patterns that perform standard/builtin->stream @@ -17,7 +21,9 @@ namespace mlir::iree_compiler { // provided |typeConverter|. void populateStandardToStreamConversionPatterns( MLIRContext *context, ConversionTarget &conversionTarget, - TypeConverter &typeConverter, RewritePatternSet &patterns); + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns); } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp index 4d1fa5f8a677..35e1ca8760a8 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp @@ -67,8 +67,9 @@ struct FuncOpSignatureConversion } }; -struct CallOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct CallOpConversion + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; LogicalResult matchAndRewrite(IREE::Util::CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -122,8 +123,9 @@ struct CallOpConversion : public OpConversionPattern { } }; -struct ReturnOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ReturnOpConversion + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; LogicalResult matchAndRewrite(IREE::Util::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -142,6 +144,7 @@ struct ReturnOpConversion : public OpConversionPattern { struct ExpandedGlobalResource { IREE::Util::GlobalOp resourceOp; IREE::Util::GlobalOp resourceSizeOp; + IREE::Stream::AffinityAttr affinityAttr; }; struct GlobalExpansionState { @@ -163,13 +166,16 @@ class BaseGlobalConversionPattern : public OpConversionPattern { public: BaseGlobalConversionPattern( std::shared_ptr expansionState, - TypeConverter &typeConverter, MLIRContext *context, + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), - expansionState(std::move(expansionState)) {} + expansionState(std::move(expansionState)), + affinityAnalysis(affinityAnalysis) {} protected: mutable std::shared_ptr expansionState; + IREE::Stream::AffinityAnalysis *affinityAnalysis; }; struct GlobalOpExpansion @@ -230,9 +236,13 @@ struct GlobalOpExpansion globalOp.getIsMutable(), indexType, std::optional{}); resourceSizeOp.setVisibility(globalOp.getVisibility()); + // Resolve the affinity of the global. + // We require this to be a single value today that is usually chosen from + // consumers (we take the hit on transfer from producers if needed). + auto affinityAttr = tryLookupGlobalAffinity(globalOp, affinityAnalysis); + // Materialize the initializer if we need to setup a tensor-like constant. if (tensorInitializerRequired) { - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(globalOp); auto initializerOp = rewriter.create(globalOp.getLoc()); auto *entryBlock = rewriter.createBlock(&initializerOp.getBody()); @@ -265,6 +275,7 @@ struct GlobalOpExpansion expansionState->globalMap[globalOp.getSymName()] = ExpandedGlobalResource{ resourceOp, resourceSizeOp, + affinityAttr, }; return success(); @@ -289,7 +300,7 @@ struct GlobalLoadOpExpansion auto &expandedGlobal = expandedGlobalIt->getSecond(); // Insert a load/transfer to the unknown resource lifetime. - auto unknownType = IREE::Stream::ResourceType::get(rewriter.getContext()); + auto unknownType = rewriter.getType(); auto resource = rewriter .create( @@ -303,8 +314,8 @@ struct GlobalLoadOpExpansion .getResult(); rewriter.replaceOpWithNewOp( loadOp, unknownType, resource, resourceSize, resourceSize, - /*source_affinity=*/nullptr, - /*result_affinity=*/nullptr); + /*source_affinity=*/expandedGlobal.affinityAttr, + /*result_affinity=*/expandedGlobal.affinityAttr); return success(); } @@ -330,12 +341,14 @@ struct GlobalStoreOpExpansion // Insert a transfer/store to the global with unknown lifetime. Lifetime // refinement will make this go away if possible. auto value = - consumeTensorOperand(storeOp.getLoc(), adaptor.getValue(), rewriter); + resolveTensorOperand(storeOp.getLoc(), storeOp.getValue(), + adaptor.getValue(), affinityAnalysis, rewriter); assert(expandedGlobal.resourceOp && "Missing resource op"); auto transferOp = rewriter.create( storeOp.getLoc(), expandedGlobal.resourceOp.getType(), value.resource, - value.resourceSize, value.resourceSize, /*source_affinity=*/nullptr, - /*result_affinity=*/nullptr); + value.resourceSize, value.resourceSize, + /*source_affinity=*/value.affinity, + /*result_affinity=*/expandedGlobal.affinityAttr); rewriter.replaceOpWithNewOp( storeOp, transferOp.getResult(), expandedGlobal.resourceOp.getSymName()); @@ -347,30 +360,59 @@ struct GlobalStoreOpExpansion } }; +struct OptimizationBarrierOpConversion + : public AffinityAwareConversionPattern { + using AffinityAwareConversionPattern::AffinityAwareConversionPattern; + LogicalResult + matchAndRewrite(IREE::Util::OptimizationBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector newOperands; + for (auto [originalOperand, convertedOperand] : + llvm::zip_equal(op.getOperands(), adaptor.getOperands())) { + if (isa(convertedOperand.getType())) { + newOperands.push_back(resolveTensorOperand(op.getLoc(), originalOperand, + convertedOperand, rewriter) + .resource); + } else { + newOperands.push_back(convertedOperand); + } + } + rewriter.replaceOpWithNewOp(op, + newOperands); + return success(); + } +}; + } // namespace -void populateUtilToStreamConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns - .insert( - typeConverter, context); +void populateUtilToStreamConversionPatterns( + MLIRContext *context, TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns) { + patterns.insert(typeConverter, context); + patterns.insert(typeConverter, context, + affinityAnalysis); auto expansionState = std::make_shared(); // TODO(#7432): add indirect global expansion support to streams. patterns .insert( - expansionState, typeConverter, context); + expansionState, typeConverter, affinityAnalysis, context); patterns.add, GenericConvertTypesPattern, GenericConvertTypesPattern>( typeConverter, context); + + patterns.insert(typeConverter, context, + affinityAnalysis, + /*benefit=*/2); } -void populateUtilToStreamConversionPatterns(MLIRContext *context, - ConversionTarget &conversionTarget, - TypeConverter &typeConverter, - RewritePatternSet &patterns) { +void populateUtilToStreamConversionPatterns( + MLIRContext *context, ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns) { typeConverter.addConversion([=](IREE::Util::PtrType type, SmallVectorImpl &resultTypes) { // Expand pointers to tensors to [resource, sizeof resource] pointers. @@ -432,7 +474,8 @@ void populateUtilToStreamConversionPatterns(MLIRContext *context, return typeConverter.isLegal(op.getResultTypes()); }); - populateUtilToStreamConversionPatterns(context, typeConverter, patterns); + populateUtilToStreamConversionPatterns(context, typeConverter, + affinityAnalysis, patterns); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h index 5673c74a7f98..56fcca297924 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.h @@ -11,18 +11,24 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +namespace mlir::iree_compiler::IREE::Stream { +class AffinityAnalysis; +} // namespace mlir::iree_compiler::IREE::Stream + namespace mlir::iree_compiler { // Populates conversion patterns that perform util->stream conversion. // These patterns ensure that nested types are run through the provided // |typeConverter|. -void populateUtilToStreamConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns); -void populateUtilToStreamConversionPatterns(MLIRContext *context, - ConversionTarget &conversionTarget, - TypeConverter &typeConverter, - RewritePatternSet &patterns); +void populateUtilToStreamConversionPatterns( + MLIRContext *context, TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns); +void populateUtilToStreamConversionPatterns( + MLIRContext *context, ConversionTarget &conversionTarget, + TypeConverter &typeConverter, + IREE::Stream::AffinityAnalysis *affinityAnalysis, + RewritePatternSet &patterns); } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp index b0b66ac08667..91f5c0ffff3f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" +#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h" #include "iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.h" #include "iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h" #include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h" @@ -39,79 +40,6 @@ namespace mlir::iree_compiler::IREE::Stream { namespace { -// Builds a stream.tensor.import op that imports an external tensor value into -// a stream resource. |consumingOps| will be populated with all ops that consume -// the original |sourceTensor| and that should not be replaced with the returned -// value. -static Value buildTensorImportOp(Location loc, Value sourceTensor, - Type targetType, - SmallPtrSetImpl &consumingOps, - IREE::Stream::AffinityAttr affinityAttr, - OpBuilder &builder) { - // Gather dynamic dimensions from the input value. - auto dynamicDims = - IREE::Util::buildDynamicDimsForValue(loc, sourceTensor, builder); - - // Compute the size of the tensor once in the stream resource. - // This may differ from the external encoding of the tensor as imports are - // a transfer operation that may need to reformat the tensor. - auto encodingAttr = TypeAttr::get(sourceTensor.getType()); - Value resultSize = builder.create( - loc, builder.getIndexType(), encodingAttr, dynamicDims, affinityAttr); - - // Associate the external SSA value, encoding, and shape information with the - // stream resource. When lowering we'll then have all the metadata required - // even after erasing it all on the resource. - auto externalType = builder.getType( - IREE::Stream::Lifetime::External); - auto importOp = builder.create( - loc, externalType, sourceTensor, encodingAttr, dynamicDims, resultSize, - affinityAttr); - consumingOps.insert(importOp); - - // If needed insert a transfer to the target lifetime. - Value result = importOp.getResult(); - if (targetType != externalType) { - result = builder - .create( - loc, targetType, result, resultSize, resultSize, - /*source_affinity=*/affinityAttr, - /*result_affinity=*/affinityAttr) - .getResult(); - } - - auto castOp = builder.create( - loc, sourceTensor.getType(), ValueRange{result, resultSize}); - consumingOps.insert(castOp); - return castOp.getResult(0); -} - -// Builds a stream.tensor.export op that exports a stream resource into an -// external tensor value. -static Value buildTensorExportOp(Location loc, Value sourceValue, - TensorType targetType, ValueRange dynamicDims, - IREE::Stream::AffinityAttr affinityAttr, - OpBuilder &builder) { - auto source = consumeTensorOperand(loc, sourceValue, builder); - - // If needed insert a transfer to external resource lifetime. - auto externalType = builder.getType( - IREE::Stream::Lifetime::External); - if (source.resource.getType() != externalType) { - source.resource = builder.create( - loc, externalType, source.resource, source.resourceSize, - source.resourceSize, - /*source_affinity=*/nullptr, - /*result_affinity=*/affinityAttr); - } - - // Associate the stream resource and external encoding and shape information. - auto newOp = builder.create( - loc, targetType, source.resource, TypeAttr::get(targetType), dynamicDims, - source.resourceSize, affinityAttr); - return newOp.getResult(); -} - // Returns true if |op| has tensor I/O that is not yet imported/exported using // the stream ops that capture encodings and shapes. static bool doesOperationNeedWrapping(Operation *op) { @@ -123,8 +51,9 @@ static bool doesOperationNeedWrapping(Operation *op) { operand.getDefiningOp()); }) || llvm::any_of(op->getResults(), [](Value result) { - if (!isa(result.getType())) + if (!isa(result.getType())) { return false; + } return !llvm::all_of(result.getUsers(), llvm::IsaPred); }); @@ -133,15 +62,19 @@ static bool doesOperationNeedWrapping(Operation *op) { // Fallback handler for unknown ops taking/returning tensors that need to be // marshaled into/outof stream resource types. struct GenericResourcePattern : public ConversionPattern { - GenericResourcePattern(MLIRContext *context, TypeConverter &converter) - : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {} + GenericResourcePattern(MLIRContext *context, TypeConverter &converter, + IREE::Stream::AffinityAnalysis *affinityAnalysis) + : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context), + affinityAnalysis(affinityAnalysis) {} + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (!doesOperationNeedWrapping(op)) + if (!doesOperationNeedWrapping(op)) { return failure(); + } - auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + auto executionAffinityAttr = affinityAnalysis->inferExecutionAffinity(op); // Export resources into tensor operands for the op to consume. SmallVector newOperands; @@ -156,11 +89,14 @@ struct GenericResourcePattern : public ConversionPattern { auto tensorType = dyn_cast(oldOperand.getType()); assert(tensorType && "must have a tensor type to map to a resource"); + auto exportAffinityAttr = + affinityAnalysis->lookupResourceAffinity(oldOperand); auto dynamicDims = IREE::Util::buildDynamicDimsForValue( op->getLoc(), oldOperand, rewriter); - newOperands.push_back(buildTensorExportOp(op->getLoc(), newOperand, - tensorType, dynamicDims, - affinityAttr, rewriter)); + newOperands.push_back(buildTensorExportOp( + op->getLoc(), oldOperand, newOperand, tensorType, dynamicDims, + exportAffinityAttr ? exportAffinityAttr : executionAffinityAttr, + rewriter)); } rewriter.modifyOpInPlace(op, [&]() { op->setOperands(newOperands); }); @@ -168,46 +104,107 @@ struct GenericResourcePattern : public ConversionPattern { rewriter.setInsertionPointAfter(op); for (auto result : op->getResults()) { auto tensorType = dyn_cast(result.getType()); - if (!tensorType) + if (!tensorType) { continue; + } + auto importAffinityAttr = + affinityAnalysis->lookupResourceAffinity(result); auto dynamicDims = IREE::Util::buildDynamicDimsForValue(op->getLoc(), result, rewriter); SmallPtrSet consumingOps; auto importedValue = buildTensorImportOp( op->getLoc(), result, rewriter.getType(), - consumingOps, affinityAttr, rewriter); + consumingOps, + importAffinityAttr ? importAffinityAttr : executionAffinityAttr, + rewriter); result.replaceAllUsesExcept(importedValue, consumingOps); } return success(); } -}; -struct OptimizationBarrierOpConversion - : public OpConversionPattern { - using OpConversionPattern< - IREE::Util::OptimizationBarrierOp>::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Util::OptimizationBarrierOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - SmallVector newOperands; - for (Value v : adaptor.getOperands()) { - if (isa(v.getType())) { - newOperands.push_back( - consumeTensorOperand(op.getLoc(), v, rewriter).resource); - } else { - newOperands.push_back(v); - } + // Builds a stream.tensor.export op that exports a stream resource into an + // external tensor value. + Value buildTensorExportOp(Location loc, Value originalValue, + Value convertedValue, TensorType targetType, + ValueRange dynamicDims, + IREE::Stream::AffinityAttr executionAffinityAttr, + OpBuilder &builder) const { + auto source = + transferTensorOperand(loc, originalValue, convertedValue, + executionAffinityAttr, affinityAnalysis, builder); + + // If needed insert a transfer to external resource lifetime. + auto externalType = builder.getType( + IREE::Stream::Lifetime::External); + if (source.resource.getType() != externalType) { + source.resource = builder.create( + loc, externalType, source.resource, source.resourceSize, + source.resourceSize, + /*source_affinity=*/source.affinity, + /*result_affinity=*/executionAffinityAttr); } - rewriter.replaceOpWithNewOp(op, - newOperands); - return success(); + + // Associate the stream resource and external encoding and shape + // information. + auto newOp = builder.create( + loc, targetType, source.resource, TypeAttr::get(targetType), + dynamicDims, source.resourceSize, executionAffinityAttr); + return newOp.getResult(); + } + + // Builds a stream.tensor.import op that imports an external tensor value into + // a stream resource. |consumingOps| will be populated with all ops that + // consume the original |sourceTensor| and that should not be replaced with + // the returned value. + Value buildTensorImportOp(Location loc, Value sourceTensor, Type targetType, + SmallPtrSetImpl &consumingOps, + IREE::Stream::AffinityAttr executionAffinityAttr, + OpBuilder &builder) const { + // Gather dynamic dimensions from the input value. + auto dynamicDims = + IREE::Util::buildDynamicDimsForValue(loc, sourceTensor, builder); + + // Compute the size of the tensor once in the stream resource. + // This may differ from the external encoding of the tensor as imports are + // a transfer operation that may need to reformat the tensor. + auto encodingAttr = TypeAttr::get(sourceTensor.getType()); + Value resultSize = builder.create( + loc, builder.getIndexType(), encodingAttr, dynamicDims, + executionAffinityAttr); + + // Associate the external SSA value, encoding, and shape information with + // the stream resource. When lowering we'll then have all the metadata + // required even after erasing it all on the resource. + auto externalType = builder.getType( + IREE::Stream::Lifetime::External); + auto importOp = builder.create( + loc, externalType, sourceTensor, encodingAttr, dynamicDims, resultSize, + executionAffinityAttr); + consumingOps.insert(importOp); + + // If needed insert a transfer to the target lifetime. + Value result = importOp.getResult(); + if (targetType != externalType) { + result = builder + .create( + loc, targetType, result, resultSize, resultSize, + /*source_affinity=*/executionAffinityAttr, + /*result_affinity=*/executionAffinityAttr) + .getResult(); + } + + auto castOp = builder.create( + loc, sourceTensor.getType(), ValueRange{result, resultSize}); + consumingOps.insert(castOp); + return castOp.getResult(0); } + + IREE::Stream::AffinityAnalysis *affinityAnalysis = nullptr; }; static void stripAffinityAttrs(ModuleOp moduleOp) { - moduleOp->removeAttr("stream.affinity.default"); auto affinityName = StringAttr::get(moduleOp.getContext(), "stream.affinity"); for (auto &op : moduleOp.getOps()) { op.removeDiscardableAttr(affinityName); @@ -223,6 +220,13 @@ struct ConvertToStreamPass final void runOnOperation() override { auto *context = &getContext(); + // Run affinity analysis so that the required producer/consumer affinities + // for all SSA values we'll use during conversion are available. + AffinityAnalysis affinityAnalysis(getOperation()); + if (failed(affinityAnalysis.run())) { + return signalPassFailure(); + } + TypeConverter typeConverter; ConversionTarget conversionTarget(getContext()); RewritePatternSet patterns(&getContext()); @@ -235,10 +239,9 @@ struct ConvertToStreamPass final // Allow unknown types to pass through; these come from custom dialects that // may be mixed into the IR we are converting. typeConverter.addConversion([=](Type type) -> Type { - // convert flow.channel into stream.channel - if (llvm::isa(type)) + if (llvm::isa(type)) { return IREE::Stream::ChannelType::get(context); - + } return !llvm::isa(type) ? type : Type{}; }); @@ -275,21 +278,20 @@ struct ConvertToStreamPass final populateUtilConversionPatterns(context, conversionTarget, typeConverter, patterns); - populateUtilToStreamConversionPatterns(context, conversionTarget, - typeConverter, patterns); + populateUtilToStreamConversionPatterns( + context, conversionTarget, typeConverter, &affinityAnalysis, patterns); - populateStandardToStreamConversionPatterns(context, conversionTarget, - typeConverter, patterns); - populateFlowToStreamConversionPatterns(context, conversionTarget, - typeConverter, patterns); - populateHALToStreamConversionPatterns(context, conversionTarget, - typeConverter, patterns); + populateStandardToStreamConversionPatterns( + context, conversionTarget, typeConverter, &affinityAnalysis, patterns); + populateFlowToStreamConversionPatterns( + context, conversionTarget, typeConverter, &affinityAnalysis, patterns); + populateHALToStreamConversionPatterns( + context, conversionTarget, typeConverter, &affinityAnalysis, patterns); conversionTarget.markUnknownOpDynamicallyLegal( [&](Operation *op) -> bool { return !doesOperationNeedWrapping(op); }); - patterns.insert(context, typeConverter); - patterns.insert(typeConverter, context, - /*benefit=*/2); + patterns.insert(context, typeConverter, + &affinityAnalysis); // NOTE: we allow ops that we don't know about to allow custom dialects // that don't need anything Stream-specific to pass through. diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp index 7e9c2314f8d1..6732b9775cf0 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp @@ -399,7 +399,8 @@ static bool isSafeToElideCloneOp(IREE::Stream::AsyncCloneOp cloneOp, if (sourceType != targetType && sourceType.getLifetime() == IREE::Stream::Lifetime::Constant) { LLVM_DEBUG(llvm::dbgs() - << " - clone source is a constant; cannot elide\n"); + << " - clone is a resource lifetime cast (" << sourceType + << " to " << targetType << "); cannot elide\n"); return false; } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp index f3cd8271a0c1..2db2cd69f0d9 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/EmplaceAllocations.cpp @@ -33,7 +33,9 @@ namespace { // Emplacement //===----------------------------------------------------------------------===// -static void replaceUsesAndTransfer(Value oldValue, Value newValue) { +static void +replaceUsesAndTransfer(Value oldValue, Value newValue, + IREE::Stream::AffinityAttr usageAffinityAttr) { assert(isa(oldValue.getType())); assert(isa(newValue.getType())); if (oldValue.getType() == newValue.getType()) { @@ -44,8 +46,8 @@ static void replaceUsesAndTransfer(Value oldValue, Value newValue) { builder.setInsertionPointAfterValue(newValue); Value newValueSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( newValue.getLoc(), newValue, builder); - IREE::Stream::AffinityAttr sourceAffinity; - IREE::Stream::AffinityAttr resultAffinity; + IREE::Stream::AffinityAttr sourceAffinity = usageAffinityAttr; + IREE::Stream::AffinityAttr resultAffinity = usageAffinityAttr; Value transferValue = builder.create( newValue.getLoc(), oldValue.getType(), newValue, newValueSize, newValueSize, sourceAffinity, resultAffinity); @@ -74,7 +76,7 @@ static bool tryEmplaceDispatchOp(IREE::Stream::AsyncDispatchOp dispatchOp) { break; } - // Find potential. + // Find potential update to place the dispatch result into. Value targetResource; Value targetResourceSize; Value targetOffset; @@ -82,12 +84,22 @@ static bool tryEmplaceDispatchOp(IREE::Stream::AsyncDispatchOp dispatchOp) { Value targetLength; Value targetResult; Value targetResultSize; + Attribute targetAffinityAttr; Operation *userOp = *result.user_begin(); if (auto updateOp = dyn_cast(userOp)) { if (updateOp.getUpdate() != result) { // TODO(#14566): continue if sparse emplacement on multiple results. break; } + + // Currently only allow exactly matching affinities. + // TODO(multi-device): memory compatibility - if compatible then allow. + if (updateOp.getAffinityAttr() != dispatchOp.getAffinityAttr()) { + continue; + } + + // Try to move all SSA values required into the appropriate place. + // TODO(benvanik): undo this if there's a failure (or record/roll-back). if (!IREE::Util::tryMoveProducerBefore(updateOp.getUpdateSize(), dispatchOp) || !IREE::Util::tryMoveProducerBefore(updateOp.getTargetSize(), @@ -102,6 +114,7 @@ static bool tryEmplaceDispatchOp(IREE::Stream::AsyncDispatchOp dispatchOp) { // TODO(#14566): continue if sparse emplacement on multiple results. break; } + targetResource = updateOp.getTarget(); if (targetResource.getDefiningOp() == dispatchOp) { // NOTE: we may have already replaced the update target with one of our @@ -115,6 +128,7 @@ static bool tryEmplaceDispatchOp(IREE::Stream::AsyncDispatchOp dispatchOp) { targetLength = updateOp.getUpdateSize(); targetResult = updateOp.getResult(); targetResultSize = updateOp.getTargetSize(); + targetAffinityAttr = updateOp.getAffinityAttr(); } if (!targetResource) { // TODO(#14566): continue if sparse emplacement on multiple results. @@ -136,7 +150,7 @@ static bool tryEmplaceDispatchOp(IREE::Stream::AsyncDispatchOp dispatchOp) { dispatchOp.getResultSizesMutable().assign(resultSizes); // Replace users with the result of the dispatch op. - replaceUsesAndTransfer(targetResult, result); + replaceUsesAndTransfer(targetResult, result, dispatchOp.getAffinityAttr()); userOp->erase(); didChange = true; diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir index a3eae92ad9b5..b11839b4139b 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-hal-device-assignment-pipeline --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s #executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}> #map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)> @@ -57,7 +57,7 @@ module attributes {hal.device.targets = [#device_target_vulkan]} { } } -// vulkan uses default materialization patterns which unsets the encodings. +// Vulkan uses default materialization patterns which unsets the encodings. // CHECK-LABEL: util.func public @lhs_encoding // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK: util.return %[[ARG0]] diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir index 4842e51df003..5354b8277f4e 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/parameter_ops.mlir @@ -1,5 +1,7 @@ // RUN: iree-opt --split-input-file --iree-hal-conversion --canonicalize %s | FileCheck %s +util.global private @device : !hal.device + // CHECK-LABEL: @parameterLoad // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence) -> (!hal.buffer, !hal.buffer, !hal.fence) util.func public @parameterLoad(%wait: !stream.timepoint) -> (!stream.resource, !stream.resource, !stream.timepoint) { @@ -7,7 +9,7 @@ util.func public @parameterLoad(%wait: !stream.timepoint) -> (!stream.resource affinity(%[[AFFINITY]]) @@ -15,7 +17,7 @@ util.func public @parameterLoad(%wait: !stream.timepoint) -> (!stream.resource { + %results:2, %result_timepoint = stream.parameter.load on(#hal.device.affinity<@device>) await(%wait) => { "scope"::"key0"[%c50_i64] : !stream.resource{%c100}, "scope"::"key1"[%c51_i64] : !stream.resource{%c101} } => !stream.timepoint @@ -25,19 +27,21 @@ util.func public @parameterLoad(%wait: !stream.timepoint) -> (!stream.resource (!hal.buffer, !hal.fence) util.func public @parameterLoadNoScope(%wait: !stream.timepoint) -> (!stream.resource, !stream.timepoint) { %c50_i64 = arith.constant 50 : i64 %c100 = arith.constant 100 : index - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device) // CHECK: %[[BUFFER:.+]] = io_parameters.load<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]]) // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]]) // CHECK-SAME: type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage|SharingImmutable") // CHECK-NEXT: "key"[%c50_i64] : !hal.buffer{%c100} - %result, %result_timepoint = stream.parameter.load await(%wait) => { + %result, %result_timepoint = stream.parameter.load on(#hal.device.affinity<@device>) await(%wait) => { "key"[%c50_i64] : !stream.resource{%c100} } => !stream.timepoint // CHECK: return %[[BUFFER]], %[[SIGNAL]] @@ -46,6 +50,8 @@ util.func public @parameterLoadNoScope(%wait: !stream.timepoint) -> (!stream.res // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @parameterRead // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence util.func public @parameterRead(%wait: !stream.timepoint, %target: !stream.resource) -> !stream.timepoint { @@ -53,19 +59,21 @@ util.func public @parameterRead(%wait: !stream.timepoint, %target: !stream.resou %c100 = arith.constant 100 : index %c200 = arith.constant 200 : index %c300 = arith.constant 300 : index - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device) // CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]]) // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]]) // CHECK-NEXT: "scope"::"key"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !hal.buffer - %timepoint = stream.parameter.read await(%wait) => "scope"::"key"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource{%c300} => !stream.timepoint + %timepoint = stream.parameter.read on(#hal.device.affinity<@device>) await(%wait) => "scope"::"key"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource{%c300} => !stream.timepoint // CHECK: return %[[SIGNAL]] util.return %timepoint : !stream.timepoint } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @parameterWrite // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[SOURCE:.+]]: !hal.buffer) -> !hal.fence util.func public @parameterWrite(%wait: !stream.timepoint, %source: !stream.resource) -> !stream.timepoint { @@ -73,19 +81,21 @@ util.func public @parameterWrite(%wait: !stream.timepoint, %source: !stream.reso %c100 = arith.constant 100 : index %c200 = arith.constant 200 : index %c300 = arith.constant 300 : index - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device) // CHECK: io_parameters.scatter<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]]) // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]]) // CHECK-NEXT: %[[SOURCE]][%c100 for %c200] : !hal.buffer -> "scope"::"key"[%c50_i64] - %timepoint = stream.parameter.write await(%wait) => %source[%c100 for %c200] : !stream.resource{%c300} -> "scope"::"key"[%c50_i64] => !stream.timepoint + %timepoint = stream.parameter.write on(#hal.device.affinity<@device>) await(%wait) => %source[%c100 for %c200] : !stream.resource{%c300} -> "scope"::"key"[%c50_i64] => !stream.timepoint // CHECK: return %[[SIGNAL]] util.return %timepoint : !stream.timepoint } // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @parameterGather // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence util.func public @parameterGather(%wait: !stream.timepoint, %target: !stream.resource) -> !stream.timepoint { @@ -99,7 +109,7 @@ util.func public @parameterGather(%wait: !stream.timepoint, %target: !stream.res %c201 = arith.constant 201 : index %c202 = arith.constant 202 : index %c300 = arith.constant 300 : index - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device) // CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]]) @@ -107,7 +117,7 @@ util.func public @parameterGather(%wait: !stream.timepoint, %target: !stream.res // CHECK-NEXT: "scope"::"key0"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !hal.buffer, // CHECK-NEXT: "scope"::"key1"[%c51_i64] -> %[[TARGET]][%c101 for %c201] : !hal.buffer, // CHECK-NEXT: "scope"::"key2"[%c52_i64] -> %[[TARGET]][%c102 for %c202] : !hal.buffer - %timepoint = stream.parameter.gather await(%wait) => { + %timepoint = stream.parameter.gather on(#hal.device.affinity<@device>) await(%wait) => { "scope"::"key0"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource{%c300}, "scope"::"key1"[%c51_i64] -> %target[%c101 for %c201] : !stream.resource{%c300}, "scope"::"key2"[%c52_i64] -> %target[%c102 for %c202] : !stream.resource{%c300} @@ -118,6 +128,8 @@ util.func public @parameterGather(%wait: !stream.timepoint, %target: !stream.res // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @parameterGatherNoScope // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[TARGET:.+]]: !hal.buffer) -> !hal.fence util.func public @parameterGatherNoScope(%wait: !stream.timepoint, %target: !stream.resource) -> !stream.timepoint { @@ -128,14 +140,14 @@ util.func public @parameterGatherNoScope(%wait: !stream.timepoint, %target: !str %c200 = arith.constant 200 : index %c201 = arith.constant 201 : index %c300 = arith.constant 300 : index - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device) // CHECK: io_parameters.gather<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]]) // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]]) // CHECK-NEXT: "key0"[%c50_i64] -> %[[TARGET]][%c100 for %c200] : !hal.buffer, // CHECK-NEXT: "key1"[%c51_i64] -> %[[TARGET]][%c101 for %c201] : !hal.buffer - %timepoint = stream.parameter.gather await(%wait) => { + %timepoint = stream.parameter.gather on(#hal.device.affinity<@device>) await(%wait) => { "key0"[%c50_i64] -> %target[%c100 for %c200] : !stream.resource{%c300}, "key1"[%c51_i64] -> %target[%c101 for %c201] : !stream.resource{%c300} } => !stream.timepoint @@ -145,6 +157,8 @@ util.func public @parameterGatherNoScope(%wait: !stream.timepoint, %target: !str // ----- +util.global private @device : !hal.device + // CHECK-LABEL: @parameterScatter // CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[SOURCE:.+]]: !hal.buffer) -> !hal.fence util.func public @parameterScatter(%wait: !stream.timepoint, %source: !stream.resource) -> !stream.timepoint { @@ -158,7 +172,7 @@ util.func public @parameterScatter(%wait: !stream.timepoint, %source: !stream.re %c201 = arith.constant 201 : index %c202 = arith.constant 202 : index %c300 = arith.constant 300 : index - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 // CHECK-DAG: %[[SIGNAL:.+]] = hal.fence.create device(%[[DEVICE]] : !hal.device) // CHECK: io_parameters.scatter<%[[DEVICE]] : !hal.device> affinity(%[[AFFINITY]]) @@ -167,7 +181,7 @@ util.func public @parameterScatter(%wait: !stream.timepoint, %source: !stream.re // CHECK-NEXT: %[[SOURCE]][%c101 for %c201] : !hal.buffer -> "scope"::"key1"[%c51_i64], // CHECK-NEXT: %[[SOURCE]][%c102 for %c202] : !hal.buffer -> "scope"::"key2"[%c52_i64] // CHECK-NEXT: } - %timepoint = stream.parameter.scatter await(%wait) => { + %timepoint = stream.parameter.scatter on(#hal.device.affinity<@device>) await(%wait) => { %source[%c100 for %c200] : !stream.resource{%c300} -> "scope"::"key0"[%c50_i64], %source[%c101 for %c201] : !stream.resource{%c300} -> "scope"::"key1"[%c51_i64], %source[%c102 for %c202] : !stream.resource{%c300} -> "scope"::"key2"[%c52_i64] diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp index 0d730cd82a24..bd56dc85f19f 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp @@ -136,18 +136,6 @@ void buildIREEPrecompileTransformPassPipeline( if (compileTo == IREEVMPipelinePhase::Input) return; // early-exit - // If the user specified a set of target devices we attach them to the module - // IR so that they are available for all passes that may want to use this - // information. If trying to compile in a generic mode the user should omit - // specifying targets. - IREE::HAL::AssignmentOptions halAssignmentOptions; - halAssignmentOptions.legacyTargetBackends = - halTargetOptions.legacyTargetBackends; - halAssignmentOptions.targetDevices = halTargetOptions.targetDevices; - halAssignmentOptions.defaultDevice = halTargetOptions.defaultDevice; - IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry, - halAssignmentOptions); - // Now that inputs are legalized, generate wrapper for entry functions. if (compileFrom < IREEVMPipelinePhase::ABI) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "ABI"); @@ -172,6 +160,18 @@ void buildIREEPrecompileTransformPassPipeline( if (compileTo == IREEVMPipelinePhase::ABI) return; // early-exit + // If the user specified a set of target devices we attach them to the module + // IR so that they are available for all passes that may want to use this + // information. If trying to compile in a generic mode the user should omit + // specifying targets. + IREE::HAL::AssignmentOptions halAssignmentOptions; + halAssignmentOptions.legacyTargetBackends = + halTargetOptions.legacyTargetBackends; + halAssignmentOptions.targetDevices = halTargetOptions.targetDevices; + halAssignmentOptions.defaultDevice = halTargetOptions.defaultDevice; + IREE::HAL::buildHALDeviceAssignmentPassPipeline(passManager, targetRegistry, + halAssignmentOptions); + GlobalOptimization::TransformOptions globalTransformOptions; globalTransformOptions.options = globalOptimizationOptions; diff --git a/tools/test/compile_pipelines.mlir b/tools/test/compile_pipelines.mlir index 2fd4a6c90beb..fb6dbbe05abf 100644 --- a/tools/test/compile_pipelines.mlir +++ b/tools/test/compile_pipelines.mlir @@ -1,10 +1,10 @@ // RUN: iree-opt --iree-common-input-transformation-pipeline %s | \ // RUN: iree-opt --iree-abi-transformation-pipeline - | \ -// RUN: iree-opt --iree-common-input-transformation-pipeline - | \ +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-device-assignment-pipeline{target-devices=local})' --iree-hal-local-target-device-backends=vmvx - | \ // RUN: iree-opt --iree-global-optimization-transformation-pipeline - | \ // RUN: iree-opt --iree-flow-transformation-pipeline - | \ // RUN: iree-opt --iree-stream-transformation-pipeline - | \ -// RUN: iree-opt --iree-hal-transformation-pipeline --iree-hal-target-backends=vmvx - | \ +// RUN: iree-opt --iree-hal-transformation-pipeline - | \ // RUN: iree-opt --iree-vm-transformation-pipeline - | \ // RUN: FileCheck %s diff --git a/tools/test/compile_to_continuation.mlir b/tools/test/compile_to_continuation.mlir index 9c78153890c0..5476462df2df 100644 --- a/tools/test/compile_to_continuation.mlir +++ b/tools/test/compile_to_continuation.mlir @@ -1,79 +1,89 @@ // RUN: iree-compile --compile-to=input %s | \ -// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \ // RUN: FileCheck %s --check-prefix=INPUT-PHASE // INPUT-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref // RUN: iree-compile --compile-to=abi %s | \ -// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \ // RUN: FileCheck %s --check-prefix=ABI-PHASE // ABI-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref -// RUN: iree-compile --compile-to=flow %s | \ -// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --compile-to=flow --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ +// RUN: iree-compile --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=FLOW-PHASE // FLOW-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref -// RUN: iree-compile --compile-to=stream %s | \ -// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --compile-to=flow %s | \ +// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \ +// RUN: FileCheck %s --check-prefix=FLOW-PHASE-NO-DEVICE +// FLOW-PHASE-NO-DEVICE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref + +// RUN: iree-compile --compile-to=stream --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ +// RUN: iree-compile --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=STREAM-PHASE // STREAM-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref -// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=stream %s | \ +// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \ +// RUN: FileCheck %s --check-prefix=STREAM-PHASE-NO-DEVICE +// STREAM-PHASE-NO-DEVICE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref + +// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=EXECUTABLE-SOURCES-PHASE // EXECUTABLE-SOURCES-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref -// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=EXECUTABLE-TARGETS-PHASE // EXECUTABLE-TARGETS-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref -// RUN: iree-compile --compile-to=hal --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=hal --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=HAL-PHASE // HAL-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref -// RUN: iree-compile --compile-to=vm --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=vm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=VM-PHASE // VM-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref // RUN: iree-compile --compile-to=input %s | \ -// RUN: iree-compile --compile-from=input --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --compile-from=input --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \ // RUN: FileCheck %s --check-prefix=FROM-ABI-PHASE // FROM-INPUT-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref // RUN: iree-compile --compile-to=abi %s | \ -// RUN: iree-compile --compile-from=abi --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --compile-from=abi --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx - | \ // RUN: FileCheck %s --check-prefix=FROM-ABI-PHASE // FROM-ABI-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref -// RUN: iree-compile --compile-to=flow %s | \ -// RUN: iree-compile --compile-from=flow --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --compile-to=flow --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ +// RUN: iree-compile --compile-from=flow --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=FROM-FLOW-PHASE // FROM-FLOW-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref -// RUN: iree-compile --compile-to=stream %s | \ -// RUN: iree-compile --compile-from=stream --output-format=vm-asm --iree-hal-target-backends=vmvx - | \ +// RUN: iree-compile --compile-to=stream --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ +// RUN: iree-compile --compile-from=stream --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=FROM-STREAM-PHASE // FROM-STREAM-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref -// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --compile-from=executable-sources --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=FROM-EXECUTABLE-SOURCES-PHASE // FROM-EXECUTABLE-SOURCES-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref -// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --compile-from=executable-targets --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=FROM-EXECUTABLE-TARGETS-PHASE // FROM-EXECUTABLE-TARGETS-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref -// RUN: iree-compile --compile-to=hal --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=hal --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --compile-from=hal --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=FROM-HAL-PHASE // FROM-HAL-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref -// RUN: iree-compile --compile-to=vm --iree-hal-target-backends=vmvx %s | \ +// RUN: iree-compile --compile-to=vm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | \ // RUN: iree-compile --compile-from=vm --output-format=vm-asm - | \ // RUN: FileCheck %s --check-prefix=FROM-VM-PHASE // FROM-VM-PHASE: vm.func private @abs(%arg0: !vm.ref) -> !vm.ref diff --git a/tools/test/compile_to_phase.mlir b/tools/test/compile_to_phase.mlir index 03905649b03d..f1861a0b36c0 100644 --- a/tools/test/compile_to_phase.mlir +++ b/tools/test/compile_to_phase.mlir @@ -7,36 +7,44 @@ // ABI-PHASE: %[[INPUT:.+]] = hal.tensor.import %[[ARG0]] "input0" : !hal.buffer_view -> tensor // ABI-PHASE: math.absf %[[INPUT]] : tensor -// RUN: iree-compile --compile-to=flow %s | FileCheck %s --check-prefix=FLOW-PHASE +// RUN: iree-compile --compile-to=flow %s --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx | FileCheck %s --check-prefix=FLOW-PHASE // FLOW-PHASE: flow.executable.export public @abs_dispatch_0 // FLOW-PHASE: flow.dispatch @abs_dispatch_0 -// RUN: iree-compile --compile-to=stream %s | FileCheck %s --check-prefix=STREAM-PHASE +// RUN: iree-compile --compile-to=flow %s | FileCheck %s --check-prefix=FLOW-PHASE-NO-DEVICE +// FLOW-PHASE-NO-DEVICE: flow.executable.export public @abs_dispatch_0 +// FLOW-PHASE-NO-DEVICE: flow.dispatch @abs_dispatch_0 + +// RUN: iree-compile --compile-to=stream --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=STREAM-PHASE // STREAM-PHASE: stream.executable.export public @abs_dispatch_0 // STREAM-PHASE: stream.cmd.dispatch @abs_dispatch_0 -// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-SOURCES-PHASE +// RUN: iree-compile --compile-to=stream %s | FileCheck %s --check-prefix=STREAM-PHASE-NO-DEVICE +// STREAM-PHASE-NO-DEVICE: stream.executable.export public @abs_dispatch_0 +// STREAM-PHASE-NO-DEVICE: stream.cmd.dispatch @abs_dispatch_0 + +// RUN: iree-compile --compile-to=executable-sources --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-SOURCES-PHASE // EXECUTABLE-SOURCES-PHASE: hal.executable private @abs_dispatch_0 // EXECUTABLE-SOURCES-PHASE: hal.executable.variant // EXECUTABLE-SOURCES-PHASE: linalg.generic // EXECUTABLE-SOURCES-PHASE: math.absf -// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-TARGETS-PHASE +// RUN: iree-compile --compile-to=executable-targets --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=EXECUTABLE-TARGETS-PHASE // EXECUTABLE-TARGETS-PHASE: hal.executable private @abs_dispatch_0 // EXECUTABLE-TARGETS-PHASE: hal.executable.variant // EXECUTABLE-TARGETS-PHASE: vm.abs.f32 -// RUN: iree-compile --compile-to=hal --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=HAL-PHASE +// RUN: iree-compile --compile-to=hal --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=HAL-PHASE // HAL-PHASE: hal.executable private @abs_dispatch_0 // HAL-PHASE: hal.executable.binary // HAL-PHASE: hal.command_buffer.dispatch -// RUN: iree-compile --compile-to=vm --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=VM-PHASE +// RUN: iree-compile --compile-to=vm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=VM-PHASE // VM-PHASE: vm.rodata private @abs_dispatch_0 // VM-PHASE: vm.call @hal.command_buffer.dispatch -// RUN: iree-compile --output-format=vm-asm --compile-to=end --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE -// RUN: iree-compile --output-format=vm-asm --iree-hal-target-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE +// RUN: iree-compile --output-format=vm-asm --compile-to=end --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE +// RUN: iree-compile --output-format=vm-asm --iree-hal-target-device=local --iree-hal-local-target-device-backends=vmvx %s | FileCheck %s --check-prefix=END-PHASE // END-PHASE: vm.rodata private @abs_dispatch_0 // END-PHASE: vm.call @hal.command_buffer.dispatch From 7f5d847a079c9efc38e3391e354c3df84e79c302 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 15 Jul 2024 22:00:42 -0700 Subject: [PATCH 23/25] Adding `vm.select.ref` lowering to emitc. --- .../Conversion/VMToEmitC/ConvertVMToEmitC.cpp | 111 ++++++++++++++++-- .../VM/Conversion/VMToEmitC/EmitCBuilders.cpp | 7 +- runtime/src/iree/vm/test/assignment_ops.mlir | 2 +- 3 files changed, 111 insertions(+), 9 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index 5eef577e5694..9df179fa5d19 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -1063,7 +1063,6 @@ LogicalResult createAPIFunctions(IREE::VM::ModuleOp moduleOp, if (typeName[0] == '!') { typeName = typeName.substr(1); } - typeName = std::string("\"") + typeName + std::string("\""); Value stringView = emitc_builders::ireeMakeCstringView(builder, loc, typeName); @@ -2947,6 +2946,107 @@ class CompareRefNotZeroOpConversion } }; +class SelectRefOpConversion + : public EmitCConversionPattern { + using Adaptor = typename IREE::VM::SelectRefOp::Adaptor; + using EmitCConversionPattern::EmitCConversionPattern; + + LogicalResult + matchAndRewrite(IREE::VM::SelectRefOp selectOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto ctx = selectOp.getContext(); + auto loc = selectOp.getLoc(); + + auto moduleOp = + selectOp.getOperation()->template getParentOfType(); + auto funcOp = selectOp.getOperation() + ->template getParentOfType(); + auto &funcAnalysis = getModuleAnalysis().lookupFunction(funcOp); + + const BlockArgument moduleArg = funcOp.getArgument(CCONV_ARGUMENT_MODULE); + auto resultTypePtr = + createVmTypeDefPtr(rewriter, loc, this->getModuleAnalysis(), moduleOp, + moduleArg, selectOp.getType()); + if (!resultTypePtr.has_value()) { + return selectOp->emitError() << "generating iree_vm_type_def_t* failed"; + } + auto resultTypeAsRef = + rewriter + .create( + /*location=*/loc, + /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), + /*callee=*/StringAttr::get(ctx, "iree_vm_type_def_as_ref"), + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef{resultTypePtr.value()}) + .getResult(0); + + bool moveTrue = + funcAnalysis.isMove(selectOp.getTrueValue(), selectOp.getOperation()); + bool moveFalse = + funcAnalysis.isMove(selectOp.getFalseValue(), selectOp.getOperation()); + + Value refTrue = + this->getModuleAnalysis().lookupRef(selectOp.getTrueValue()); + Value refFalse = + this->getModuleAnalysis().lookupRef(selectOp.getFalseValue()); + Value refResult = this->getModuleAnalysis().lookupRef(selectOp.getResult()); + + Type boolType = rewriter.getI1Type(); + auto condition = rewriter.create( + loc, rewriter.getI32Type(), selectOp.getCondition()); + auto conditionI1 = rewriter.create( + /*location=*/loc, + /*type=*/boolType, + /*operand=*/condition.getResult()); + + auto *continueBlock = + rewriter.splitBlock(selectOp->getBlock(), Block::iterator(selectOp)); + + Block *trueBlock = nullptr; + { + OpBuilder::InsertionGuard guard(rewriter); + trueBlock = rewriter.createBlock(continueBlock); + returnIfError( + /*rewriter=*/rewriter, + /*location=*/loc, + /*callee=*/StringAttr::get(ctx, "iree_vm_ref_retain_or_move_checked"), + /*args=*/ + ArrayAttr::get( + ctx, {rewriter.getBoolAttr(moveTrue), rewriter.getIndexAttr(0), + rewriter.getIndexAttr(1), rewriter.getIndexAttr(2)}), + /*operands=*/ + ArrayRef{refTrue, resultTypeAsRef, refResult}, + this->getModuleAnalysis()); + rewriter.create(loc, continueBlock); + } + + Block *falseBlock = nullptr; + { + OpBuilder::InsertionGuard guard(rewriter); + falseBlock = rewriter.createBlock(continueBlock); + returnIfError( + /*rewriter=*/rewriter, + /*location=*/loc, + /*callee=*/StringAttr::get(ctx, "iree_vm_ref_retain_or_move_checked"), + /*args=*/ + ArrayAttr::get( + ctx, {rewriter.getBoolAttr(moveFalse), rewriter.getIndexAttr(0), + rewriter.getIndexAttr(1), rewriter.getIndexAttr(2)}), + /*operands=*/ + ArrayRef{refFalse, resultTypeAsRef, refResult}, + this->getModuleAnalysis()); + rewriter.create(loc, continueBlock); + } + + rewriter.setInsertionPointAfterValue(conditionI1); + rewriter.create(loc, conditionI1.getResult(), + trueBlock, falseBlock); + rewriter.replaceOp(selectOp, refResult); + + return success(); + } +}; + template class ConstOpConversion : public EmitCConversionPattern { using Adaptor = typename OpTy::Adaptor; @@ -3429,12 +3529,8 @@ class FailOpConversion : public EmitCConversionPattern { releaseRefs(rewriter, loc, funcOp, getModuleAnalysis()); - std::string messageStr = std::string("\"") + - op.getMessage().value_or("").str() + - std::string("\""); - - Value message = - emitc_builders::ireeMakeCstringView(rewriter, loc, messageStr); + Value message = emitc_builders::ireeMakeCstringView( + rewriter, loc, op.getMessage().value_or("").str()); auto messageSizeOp = emitc_builders::structMember( rewriter, loc, @@ -4430,6 +4526,7 @@ void populateVMToEmitCPatterns(ConversionTarget &conversionTarget, CallOpConversion, CallOpConversion, CompareRefNotZeroOpConversion, + SelectRefOpConversion, CondBranchOpConversion, BranchTableOpConversion, ConstOpConversion, diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp index e817f9e5c870..3076f2d361e7 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp @@ -299,6 +299,11 @@ void structPtrMemberAssign(OpBuilder builder, Location location, Value ireeMakeCstringView(OpBuilder builder, Location location, std::string str) { + std::string escapedStr; + llvm::raw_string_ostream os(escapedStr); + os.write_escaped(str); + auto quotedStr = std::string("\"") + escapedStr + std::string("\""); + auto ctx = builder.getContext(); return builder .create( @@ -306,7 +311,7 @@ Value ireeMakeCstringView(OpBuilder builder, Location location, /*type=*/emitc::OpaqueType::get(ctx, "iree_string_view_t"), /*callee=*/StringAttr::get(ctx, "iree_make_cstring_view"), /*args=*/ - ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, str)}), + ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, quotedStr)}), /*templateArgs=*/ArrayAttr{}, /*operands=*/ArrayRef{}) .getResult(0); diff --git a/runtime/src/iree/vm/test/assignment_ops.mlir b/runtime/src/iree/vm/test/assignment_ops.mlir index 1388c1e600b1..891165da8bc3 100644 --- a/runtime/src/iree/vm/test/assignment_ops.mlir +++ b/runtime/src/iree/vm/test/assignment_ops.mlir @@ -17,7 +17,7 @@ vm.module @assignment_ops { vm.return } - vm.export @test_select_ref attributes {emitc.exclude} + vm.export @test_select_ref vm.func private @test_select_ref() { %c0 = vm.const.i32 0 %list0 = vm.list.alloc %c0 : (i32) -> !vm.list From e9006927d10635ea475d34d545dba5175b2ac2b8 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 16 Jul 2024 11:00:47 -0700 Subject: [PATCH 24/25] Updating various tests to the latest changes. --- .../input/Torch/InputConversion/Passes.cpp | 1 + .../plugins/target/LLVMCPU/test/BUILD.bazel | 1 + .../target/LLVMCPU/test/CMakeLists.txt | 1 + .../materialize_homogeneous_encodings.mlir | 30 ++++++++++++++++++ compiler/plugins/target/ROCM/ROCMTarget.cpp | 8 ++--- .../plugins/target/ROCM/test/smoketest.mlir | 4 +-- .../ROCM/test/target_device_features.mlir | 8 ++--- .../target/VulkanSPIRV/test/BUILD.bazel | 1 + .../target/VulkanSPIRV/test/CMakeLists.txt | 1 + .../materialize_homogeneous_encodings.mlir | 31 ------------------- .../Dialect/Stream/Analysis/Affinity.cpp | 7 +++-- .../GlobalOptimization/test/BUILD.bazel | 1 - .../GlobalOptimization/test/CMakeLists.txt | 1 - .../shark-test-suite-models/sd3/test_clip.py | 2 +- .../shark-test-suite-models/sd3/test_mmdit.py | 2 +- .../shark-test-suite-models/sd3/test_vae.py | 2 +- .../shark-test-suite-models/sdxl/test_clip.py | 2 +- .../shark-test-suite-models/sdxl/test_unet.py | 2 +- .../shark-test-suite-models/sdxl/test_vae.py | 2 +- .../src/iree/modules/check/test/success.mlir | 1 - samples/simple_embedding/device_vmvx_sync.c | 2 +- samples/static_library/static_library_demo.c | 2 +- tools/testing/e2e/iree-e2e-conv2d-test.cc | 7 +++-- tools/testing/e2e/iree-e2e-matmul-test.cc | 7 +++-- tools/testing/e2e/test_utils.c | 2 +- tools/testing/e2e/test_utils.h | 2 +- 26 files changed, 69 insertions(+), 61 deletions(-) create mode 100644 compiler/plugins/target/LLVMCPU/test/materialize_homogeneous_encodings.mlir rename compiler/{src/iree/compiler/GlobalOptimization => plugins/target/VulkanSPIRV}/test/materialize_homogeneous_encodings.mlir (69%) diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/InputConversion/Passes.cpp index 293921892b26..8f51c617a357 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.cpp +++ b/compiler/plugins/input/Torch/InputConversion/Passes.cpp @@ -56,6 +56,7 @@ void createTorchToIREEPipeline( TorchInput::createConvertTMTensorToLinalgExtPass()); pm.addNestedPass(torch::createConvertTorchToTensorPass()); pm.addNestedPass(torch::createConvertTorchToLinalgPass()); + pm.addNestedPass(createCSEPass()); pm.addNestedPass(torch::createConvertTorchToSCFPass()); pm.addNestedPass(torch::createConvertTorchToArithPass()); pm.addPass(torch::createConvertTorchConversionToMLProgramPass()); diff --git a/compiler/plugins/target/LLVMCPU/test/BUILD.bazel b/compiler/plugins/target/LLVMCPU/test/BUILD.bazel index 14b13f9c1888..2332f86919db 100644 --- a/compiler/plugins/target/LLVMCPU/test/BUILD.bazel +++ b/compiler/plugins/target/LLVMCPU/test/BUILD.bazel @@ -16,6 +16,7 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( [ + "materialize_homogeneous_encodings.mlir", "smoketest_embedded.mlir", "smoketest_system.mlir", ], diff --git a/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt b/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt index dde56185d802..5eee1f402ef3 100644 --- a/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt +++ b/compiler/plugins/target/LLVMCPU/test/CMakeLists.txt @@ -14,6 +14,7 @@ iree_lit_test_suite( NAME lit SRCS + "materialize_homogeneous_encodings.mlir" "smoketest_embedded.mlir" "smoketest_system.mlir" TOOLS diff --git a/compiler/plugins/target/LLVMCPU/test/materialize_homogeneous_encodings.mlir b/compiler/plugins/target/LLVMCPU/test/materialize_homogeneous_encodings.mlir new file mode 100644 index 000000000000..5d5b591a81fc --- /dev/null +++ b/compiler/plugins/target/LLVMCPU/test/materialize_homogeneous_encodings.mlir @@ -0,0 +1,30 @@ +// RUN: iree-opt --split-input-file --iree-hal-device-assignment-pipeline --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s + +#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}> +#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> +#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> : !hal.device +module attributes {hal.device.targets = [#device_target_llvm_cpu]} { + util.func public @lhs_encoding(%arg0: tensor) -> tensor { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim = tensor.dim %arg0, %c0 : tensor + %dim_0 = tensor.dim %arg0, %c1 : tensor + %0:2 = iree_encoding.upper_bound_tile_size tensor> -> index, index + %1 = affine.apply #map()[%0#0, %dim] + %2 = affine.apply #map()[%0#1, %dim_0] + %padded = tensor.pad %arg0 low[0, 0] high[%1, %2] { + ^bb0(%arg1: index, %arg2: index): + tensor.yield %cst : f32 + } : tensor to tensor + %3 = iree_encoding.set_encoding %padded : tensor -> tensor> + %4 = iree_encoding.unset_encoding %3 : tensor> -> tensor + util.return %4 : tensor + } +} +// CHECK-LABEL: util.func public @lhs_encoding +// CHECK: tensor.pack +// CHECK: tensor.unpack diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index 5f97a97b6162..75e4bbdb66c0 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -226,7 +226,7 @@ class ROCMTargetDevice final : public TargetDevice { targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets( context, "rocm", configAttr, executableTargetAttrs); - return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("rocm"), + return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("hip"), configAttr, executableTargetAttrs); } @@ -238,7 +238,7 @@ class ROCMTargetBackend final : public TargetBackend { public: ROCMTargetBackend(const ROCmOptions &options) : options(options) {} - std::string getLegacyDefaultDeviceID() const override { return "rocm"; } + std::string getLegacyDefaultDeviceID() const override { return "hip"; } void getDefaultExecutableTargets( MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr, @@ -702,8 +702,8 @@ struct ROCMSession final : PluginSession { void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) { - // #hal.device.target<"rocm", ... - targets.add("rocm", + // #hal.device.target<"hip", ... + targets.add("hip", [&]() { return std::make_shared(options); }); } void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) { diff --git a/compiler/plugins/target/ROCM/test/smoketest.mlir b/compiler/plugins/target/ROCM/test/smoketest.mlir index b446ea2fd93c..1afe688467ee 100644 --- a/compiler/plugins/target/ROCM/test/smoketest.mlir +++ b/compiler/plugins/target/ROCM/test/smoketest.mlir @@ -2,7 +2,7 @@ module attributes { hal.device.targets = [ - #hal.device.target<"rocm", [ + #hal.device.target<"hip", [ #hal.executable.target<"rocm", "rocm-hsaco-fb"> ]> : !hal.device ] @@ -46,7 +46,7 @@ stream.executable public @add_dispatch_0 { #loc = loc(unknown) module attributes { hal.device.targets = [ - #hal.device.target<"rocm", [ + #hal.device.target<"hip", [ #hal.executable.target<"rocm", "rocm-hsaco-fb"> ]> : !hal.device ] diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir index ae7676e17e52..15240f92ce33 100644 --- a/compiler/plugins/target/ROCM/test/target_device_features.mlir +++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir @@ -1,7 +1,7 @@ -// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=mi300x %s | FileCheck %s --check-prefix=GFX942 -// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=GFX940 -// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=rx7900xtx %s | FileCheck %s --check-prefix=GFX1100 -// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx941 --iree-rocm-target-features=+sramecc,-xnack %s | FileCheck %s --check-prefix=GFX941 +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=mi300x %s | FileCheck %s --check-prefix=GFX942 +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=GFX940 +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=rx7900xtx %s | FileCheck %s --check-prefix=GFX1100 +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=hip},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx941 --iree-rocm-target-features=+sramecc,-xnack %s | FileCheck %s --check-prefix=GFX941 // GFX942: target = #iree_gpu.target -#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)> -#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> -#map2 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> -#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> : !hal.device -module attributes {hal.device.targets = [#device_target_llvm_cpu]} { - util.func public @lhs_encoding(%arg0: tensor) -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %dim = tensor.dim %arg0, %c0 : tensor - %dim_0 = tensor.dim %arg0, %c1 : tensor - %0:2 = iree_encoding.upper_bound_tile_size tensor> -> index, index - %1 = affine.apply #map()[%0#0, %dim] - %2 = affine.apply #map()[%0#1, %dim_0] - %padded = tensor.pad %arg0 low[0, 0] high[%1, %2] { - ^bb0(%arg1: index, %arg2: index): - tensor.yield %cst : f32 - } : tensor to tensor - %3 = iree_encoding.set_encoding %padded : tensor -> tensor> - %4 = iree_encoding.unset_encoding %3 : tensor> -> tensor - util.return %4 : tensor - } -} -// CHECK-LABEL: util.func public @lhs_encoding -// CHECK: tensor.pack -// CHECK: tensor.unpack - -// ----- - #executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb"> #map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)> #map1 = affine_map<(d0, d1, d2) -> (d0, d2)> diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp index ac3c1660e2d3..8351c9134e31 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp @@ -565,9 +565,10 @@ ChangeStatus ValueProducerAffinityPVS::updateValue(Value value, if (auto affinityOp = dyn_cast_if_present( result.getDefiningOp())) { - auto &opPVS = solver.getElementFor( - *this, Position::forOperation(result.getOwner()), - DFX::Resolution::OPTIONAL); + auto &opPVS = solver.getOrCreateElementFor( + Position::forOperation(result.getOwner()), *this, + DFX::Resolution::OPTIONAL, /*forceUpdate=*/false, + /*updateAfterInit=*/false); LLVM_DEBUG({ llvm::dbgs() << "[ValueProducerAffinityPVS] value "; value.printAsOperand(llvm::dbgs(), solver.getAsmState()); diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel index bbb78675b2dc..027c626dbe62 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel +++ b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel @@ -29,7 +29,6 @@ iree_lit_test_suite( "global_loop_invariant_code_motion.mlir", "hoist_into_globals.mlir", "infer_numeric_narrowing.mlir", - "materialize_homogeneous_encodings.mlir", "optimize_numerics.mlir", "propagate_linalg_transpose.mlir", "raise_special_ops.mlir", diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt index 79c75b3dd87d..b6823fc7e16f 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt @@ -27,7 +27,6 @@ iree_lit_test_suite( "global_loop_invariant_code_motion.mlir" "hoist_into_globals.mlir" "infer_numeric_narrowing.mlir" - "materialize_homogeneous_encodings.mlir" "optimize_numerics.mlir" "propagate_linalg_transpose.mlir" "raise_special_ops.mlir" diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py index 27253226762f..c08b4347ae1f 100644 --- a/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py +++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py @@ -113,7 +113,7 @@ def SD3_CLIP_COMMON_RUN_FLAGS( "--iree-opt-aggressively-propagate-transposes=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", "--iree-execution-model=async-external", - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics{pad-target-type=conv})", ] ############################################################################### diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py index f328211de45c..2e5b18979f9a 100644 --- a/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py +++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py @@ -97,7 +97,7 @@ def SD3_MMDIT_COMMON_RUN_FLAGS( "--iree-codegen-llvmgpu-use-vector-distribution", "--iree-rocm-waves-per-eu=2", "--iree-execution-model=async-external", - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)", ] ############################################################################### diff --git a/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py b/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py index 6d9ab660dffc..881d93dbe46d 100644 --- a/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py +++ b/experimental/regression_suite/shark-test-suite-models/sd3/test_vae.py @@ -68,7 +68,7 @@ def SD3_VAE_COMMON_RUN_FLAGS( "--iree-flow-enable-aggressive-fusion=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", "--iree-execution-model=async-external", - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)", ] ############################################################################### diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py index 41b2e61ad312..207ddafd023a 100644 --- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py +++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_clip.py @@ -99,7 +99,7 @@ def SDXL_CLIP_COMMON_RUN_FLAGS( "--iree-opt-aggressively-propagate-transposes=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", "--iree-execution-model=async-external", - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics{pad-target-type=conv})", "--iree-scheduling-dump-statistics-format=json", "--iree-scheduling-dump-statistics-file=compilation_info.json", ] diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py index 4e1bc70dcb4c..9d7f942a1226 100644 --- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py +++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py @@ -103,7 +103,7 @@ def SDXL_UNET_COMMON_RUN_FLAGS( "--iree-codegen-llvmgpu-use-vector-distribution", "--iree-rocm-waves-per-eu=2", "--iree-execution-model=async-external", - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)", "--iree-scheduling-dump-statistics-format=json", "--iree-scheduling-dump-statistics-file=compilation_info.json", ] diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py index 49e49d3aec3e..5b9ab15340c6 100644 --- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py +++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py @@ -68,7 +68,7 @@ def SDXL_VAE_COMMON_RUN_FLAGS( "--iree-flow-enable-aggressive-fusion=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", "--iree-execution-model=async-external", - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)", "--iree-scheduling-dump-statistics-format=json", "--iree-scheduling-dump-statistics-file=compilation_info.json", ] diff --git a/runtime/src/iree/modules/check/test/success.mlir b/runtime/src/iree/modules/check/test/success.mlir index 7c5012e97b81..c2a310ea1152 100644 --- a/runtime/src/iree/modules/check/test/success.mlir +++ b/runtime/src/iree/modules/check/test/success.mlir @@ -73,7 +73,6 @@ func.func @floats() { %p8 = arith.addf %p7, %cp1 : tensor %p9 = arith.addf %p8, %cp1 : tensor %approximately_1 = arith.addf %p9, %cp1 : tensor - check.expect_almost_eq(%approximately_1, %c1) : tensor return } diff --git a/samples/simple_embedding/device_vmvx_sync.c b/samples/simple_embedding/device_vmvx_sync.c index fa5981c10bac..f1f633fb914e 100644 --- a/samples/simple_embedding/device_vmvx_sync.c +++ b/samples/simple_embedding/device_vmvx_sync.c @@ -34,7 +34,7 @@ iree_status_t create_sample_device(iree_allocator_t host_allocator, iree_vm_instance_release(instance); // Use the default host allocator for buffer allocations. - iree_string_view_t identifier = iree_make_cstring_view("vmvx"); + iree_string_view_t identifier = iree_make_cstring_view("local-sync"); iree_hal_allocator_t* device_allocator = NULL; if (iree_status_is_ok(status)) { status = iree_hal_allocator_create_heap(identifier, host_allocator, diff --git a/samples/static_library/static_library_demo.c b/samples/static_library/static_library_demo.c index 76a0b6c5b53e..e8670c5baf29 100644 --- a/samples/static_library/static_library_demo.c +++ b/samples/static_library/static_library_demo.c @@ -42,7 +42,7 @@ iree_status_t create_device_with_static_loader(iree_allocator_t host_allocator, &library_loader); // Use the default host allocator for buffer allocations. - iree_string_view_t identifier = iree_make_cstring_view("sync"); + iree_string_view_t identifier = iree_make_cstring_view("local-sync"); iree_hal_allocator_t* device_allocator = NULL; if (iree_status_is_ok(status)) { status = iree_hal_allocator_create_heap(identifier, host_allocator, diff --git a/tools/testing/e2e/iree-e2e-conv2d-test.cc b/tools/testing/e2e/iree-e2e-conv2d-test.cc index 31d02e953523..c4158fdc73c9 100644 --- a/tools/testing/e2e/iree-e2e-conv2d-test.cc +++ b/tools/testing/e2e/iree-e2e-conv2d-test.cc @@ -549,14 +549,17 @@ int main(int argc, char** argv) { return EXIT_FAILURE; } + // Run the tests. Note that some modules may be compiled for other platforms + // and not have the required architectures for execution within them - to keep + // the test runner dumber we gracefully fail those cases by returning success. iree_status_t status = iree_test_utils_load_and_run_e2e_tests( iree_allocator_system(), conv2d_test_module_create); int exit_code = EXIT_SUCCESS; if (!iree_status_is_ok(status)) { iree_status_fprint(stderr, status); - bool is_unavailable = iree_status_is_unavailable(status); + bool is_device_unavailable = iree_status_is_not_found(status); iree_status_free(status); - exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE; + exit_code = is_device_unavailable ? EXIT_SUCCESS : EXIT_FAILURE; } IREE_TRACE_APP_EXIT(exit_code); diff --git a/tools/testing/e2e/iree-e2e-matmul-test.cc b/tools/testing/e2e/iree-e2e-matmul-test.cc index c9c82f90f4aa..f2773f048e79 100644 --- a/tools/testing/e2e/iree-e2e-matmul-test.cc +++ b/tools/testing/e2e/iree-e2e-matmul-test.cc @@ -725,14 +725,17 @@ int main(int argc, char** argv) { return EXIT_FAILURE; } + // Run the tests. Note that some modules may be compiled for other platforms + // and not have the required architectures for execution within them - to keep + // the test runner dumber we gracefully fail those cases by returning success. iree_status_t status = iree_test_utils_load_and_run_e2e_tests( iree_allocator_system(), matmul_test_module_create); int exit_code = EXIT_SUCCESS; if (!iree_status_is_ok(status)) { iree_status_fprint(stderr, status); - bool is_unavailable = iree_status_is_unavailable(status); + bool is_device_unavailable = iree_status_is_not_found(status); iree_status_free(status); - exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE; + exit_code = is_device_unavailable ? EXIT_SUCCESS : EXIT_FAILURE; } IREE_TRACE_APP_EXIT(exit_code); diff --git a/tools/testing/e2e/test_utils.c b/tools/testing/e2e/test_utils.c index 926b0ea0583a..29811482de5a 100644 --- a/tools/testing/e2e/test_utils.c +++ b/tools/testing/e2e/test_utils.c @@ -413,7 +413,7 @@ iree_status_t iree_test_utils_check_module_requirements( return iree_make_status( // The error status matters. We distinguish "feature not supported" // which is a normal thing to happen from actual errors. - IREE_STATUS_UNAVAILABLE, + IREE_STATUS_NOT_FOUND, "target device does not have the required feature '%.*s'", (int)required_feature.size, required_feature.data); } diff --git a/tools/testing/e2e/test_utils.h b/tools/testing/e2e/test_utils.h index f3a18d2e4a3e..f095537112e9 100644 --- a/tools/testing/e2e/test_utils.h +++ b/tools/testing/e2e/test_utils.h @@ -133,7 +133,7 @@ iree_status_t iree_test_utils_run_all_test_functions( iree_allocator_t host_allocator); // Returns OK if there are declared requirements on |module| and they are all -// met and otherwise UNAVAILABLE indicating that the module should not be run. +// met and otherwise NOT_FOUND indicating that the module should not be run. iree_status_t iree_test_utils_check_module_requirements( iree_vm_module_t* module); From f721fd0b612742d6dba50dd2f954321cd5d64aff Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Fri, 26 Jul 2024 09:02:22 -0700 Subject: [PATCH 25/25] Adding test for iree-run-module with multiple devices. --- tools/test/BUILD.bazel | 1 + tools/test/CMakeLists.txt | 1 + tools/test/iree-run-module-multi.mlir | 43 +++++++++++++++++++++++++++ 3 files changed, 45 insertions(+) create mode 100644 tools/test/iree-run-module-multi.mlir diff --git a/tools/test/BUILD.bazel b/tools/test/BUILD.bazel index cf5878c2419d..46709aab8162 100644 --- a/tools/test/BUILD.bazel +++ b/tools/test/BUILD.bazel @@ -31,6 +31,7 @@ iree_lit_test_suite( "iree-run-mlir.mlir", "iree-run-module-expected.mlir", "iree-run-module-inputs.mlir", + "iree-run-module-multi.mlir", "iree-run-module-outputs.mlir", "iree-run-module.mlir", "multiple_args.mlir", diff --git a/tools/test/CMakeLists.txt b/tools/test/CMakeLists.txt index 75dde6670d4d..a866548ddb5a 100644 --- a/tools/test/CMakeLists.txt +++ b/tools/test/CMakeLists.txt @@ -27,6 +27,7 @@ iree_lit_test_suite( "iree-run-mlir.mlir" "iree-run-module-expected.mlir" "iree-run-module-inputs.mlir" + "iree-run-module-multi.mlir" "iree-run-module-outputs.mlir" "iree-run-module.mlir" "multiple_args.mlir" diff --git a/tools/test/iree-run-module-multi.mlir b/tools/test/iree-run-module-multi.mlir new file mode 100644 index 000000000000..341259652818 --- /dev/null +++ b/tools/test/iree-run-module-multi.mlir @@ -0,0 +1,43 @@ +// Tests that multiple devices are supported through iree-run-module by +// providing two local thread pools. This is not optimal and not an intended +// route for multi-device CPU workloads but requires no additional hardware +// resources for the test and still verifies the compiler/runtime tooling +// rendezvous of devices as specified on the command line. + +// RUN: (iree-compile %s \ +// RUN: --iree-execution-model=async-external \ +// RUN: --iree-hal-target-device=device_a=local[0] \ +// RUN: --iree-hal-target-device=device_b=local[1] \ +// RUN: --iree-hal-local-target-device-backends=vmvx | \ +// RUN: iree-run-module \ +// RUN: --module=- \ +// RUN: --function=mutli_device_mul \ +// RUN: --input=4xf32=10,11,12,13 \ +// RUN: --device=local-task \ +// RUN: --device=local-task \ +// RUN: --task_topology_group_count=1) | \ +// RUN: FileCheck %s + +// CHECK: EXEC @mutli_device_mul +// CHECK-NEXT: result[0]: hal.buffer_view +// CHECK-NEXT: 4xf32=0 55 144 273 +func.func public @mutli_device_mul( + // Input argument is resident on device_a (tooling default to first device). + %input_a: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>} +) -> ( + // Output result is expected to be on device_a (though not required). + tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>} +) { + // Compute on device_a (input is there). + %constant_a = arith.constant dense<[0.0, 1.0, 2.0, 3.0]> : tensor<4xf32> + %transient_a = arith.mulf %input_a, %constant_a : tensor<4xf32> + // Transfer the result from device_a -> device_b. + %transient_b = flow.tensor.transfer %transient_a : tensor<4xf32> to #hal.device.promise<@device_b> + // Compute on device_b. + %constant_b = arith.constant dense<[4.0, 5.0, 6.0, 7.0]> : tensor<4xf32> + %result_b = arith.mulf %transient_b, %constant_b : tensor<4xf32> + // Transfer the result from device_b -> device_a. + %result_a = flow.tensor.transfer %result_b : tensor<4xf32> to #hal.device.promise<@device_a> + // Return the result on device_a (as required by ABI attr). + func.return %result_a : tensor<4xf32> +}