Merging multi-device branch to main. (#17987)

**TLDR**: nothing should break, `--iree-hal-target-backends=` is
deprecated, use `--iree-hal-target-device=` and appropriate
target-specific flags instead.

This reworks the target device concept in the IREE pipeline - in some
cases introducing the concept (flow and HAL) and in others replacing
placeholder mechanisms around stream affinity. It builds upon prior work
that added support for enumerating available devices via the HAL and for
providing multiple devices to the runtime tools: the compiler gains the
ability to define devices, execution and storage resources can be
assigned a device, and passes are upgraded to support multiple devices.
"Multi-device" here means several things, all accomplished with the same
mechanism:
* a single device that may be one of multiple types (multiple CPU/GPU
archs, CPU _or_ GPU, etc)
* multiple homogeneous devices (4 of the same exact GPUs accessed through
the same runtime HAL driver)
* multiple heterogeneous devices (a CPU and a GPU/NPU/etc)
* optional devices (a CPU with some portions offloaded to a GPU/NPU if
it's compatible and available at runtime)

In this way we can provide cross-compilation/targeting, multi-targeting,
and multiple devices with one set of flags, compiler analysis, passes
dealing with the devices, and runtime infrastructure.

Early warning: **it's strongly discouraged to use device information
prior to codegen** - any pass using such information earlier on is a red
flag that will receive pushback. IREE is designed first and foremost as
a cross-compiler with multi-targeting at its core and radically changing
program behavior near the frontend makes it nearly impossible to have
configuration control over the compilation pipeline. Consider
specializing on device prior to codegen as tantamount to using C
preprocessor macros based on operating system or architecture: it means
that a problem has not been solved and a workaround has been taken.
There are exceptionally few cases that require device information early,
and those that do can do so in generic ways that do not disturb the
debuggability of the program. For example, far better than preprocessor
macros in C++ are function calls and if statements (as we can do in our
programs), and even better than that are virtual interfaces (ops that
are only lowered to one of multiple implementations later on). That
disclaimer out of the way: it's now possible to query device information
after the input pipeline (global opt/preprocessing/flow). Upstream will
push back against doing so in nearly all cases but it is a useful
mechanism for downstream projects.

The key change here is that the `--iree-hal-target-backends=` compiler
flag has been deprecated. It continues to work for now with the same
behavior as before but usage will shift to the replacement
`--iree-hal-target-device=` flag. A single instance of this flag defines
a single device within the program and repeated uses of it will define
new devices. Devices may be named ("my_device") or anonymous (in which
case they will be assigned an ordinal like 0 or 1), and each device may
be backed by one or more target devices (Vulkan, local host, HIP, etc).
Each target device in the compiler (represented by
`IREE::HAL::TargetDevice`) may have any number of backends with various
configurations (multiple archs, different deployment formats, etc,
represented by one or more `IREE::HAL::ExecutableTargetAttr` values).

Example flags:
```sh
# Two devices, one the local host device and the other a Vulkan device:
--iree-hal-target-device=local --iree-hal-target-device=vulkan

# One device that uses Vulkan if available and otherwise falls back to the local host device:
--iree-hal-target-device=vulkan,local

# Two CUDA devices selected by runtime ordinal; at runtime two --device=
# flags are required to configure both devices:
--iree-hal-target-device=cuda[0] --iree-hal-target-device=cuda[1]

# A fully-defined target specification:
--iree-hal-target-device=#hal.device.target<"cuda", {...}, [#hal.executable.target<...>]>

# A named device that can be referenced via #hal.device.promise<@some_name>:
--iree-hal-target-device=some_name=vulkan
```

The device metadata as specified in the compiler is used to produce
enumeration code that executes at runtime and queries the available
devices to find the appropriate matches. This means that if the program
is compiled to target two CUDA devices then at runtime there must be two
CUDA devices specified - the indirection allows for the compiled
artifact to work with any two CUDA devices targeted by UUID, device
ordinal, etc and not just the first and second CUDA device in the
system. E.g. `iree-compile --iree-hal-target-device=cuda[0]
--iree-hal-target-device=cuda[1]` and `iree-run-module
--device=cuda://UUID_B`. Device targets in the compiler can now specify
the ordinal of the device in order to differentiate between multiple
devices at runtime (the `cuda[0]` and `cuda[1]` above indicate the first
and second CUDA devices provided to the runtime).
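
Concretely, the pairing might look like this (the module file, entry
point name, and UUIDs are placeholders):
```sh
# Compile once, targeting two CUDA devices identified by runtime ordinal.
iree-compile model.mlir \
  --iree-hal-target-device=cuda[0] \
  --iree-hal-target-device=cuda[1] \
  -o model.vmfb

# At runtime provide two CUDA devices - any two, identified however the
# deployment requires (UUID, ordinal, etc).
iree-run-module \
  --module=model.vmfb \
  --device=cuda://UUID_A \
  --device=cuda://UUID_B \
  --function=main
```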

Major new attributes (a combined sketch follows this list):
* `#hal.device.promise<@device>` is a reference to a device that will be
provided at a later stage. Frontends can use this as a placeholder for
devices that are specified on the command line without needing to say
what those devices are when exporting.
* `#hal.device.alias<"name">` specifies an `IREE::HAL::TargetDevice` in
the compiler (`vulkan`, `local`, `hip`, etc) and expands to a full
`#hal.device.target` based on target-specific flags.
* `#hal.device.select<[...]>` controls selection by enumerating each
device in turn and matching the first found.
* `#hal.device.fallback<@other_device>` provides a fallback reference
that the device will match if no other device matches. Note that defining
two devices with the same target will create two device copies at
runtime; to share one device the fallback mechanism must be used.
* `#hal.device.affinity<@device>` (and optional queue mask) is used on
ops to indicate on which device they should execute.
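
Taken together, a hypothetical module using these attributes might look
like the following sketch (device names are arbitrary; the printed forms
follow the same conventions as the pipeline examples below):
```mlir
builtin.module attributes {
  // Ops without an explicit affinity default to @primary.
  stream.affinity.default = #hal.device.affinity<@primary>
} {
  // Selects whichever of CUDA or HIP is found available at runtime.
  util.global private @primary = #hal.device.select<[
    #hal.device.alias<"cuda"> : !hal.device,
    #hal.device.alias<"hip"> : !hal.device
  ]> : !hal.device
  // Matches the same device as @primary instead of creating a second copy.
  util.global private @offload = #hal.device.fallback<@primary> : !hal.device
}
```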

All of these flags are just syntactic sugar that add the attributes above
to the program IR, and frontends may insert the attributes or ops
directly depending on use-case. In most cases leaving placeholders in the
IR such that the exact target can be specified during compilation is
ideal: this allows one output from the frontend to be used with any
number of targets and configurations. Online compilers, though, may want
to bake in their exact configuration and can do so without the need for
flags that may lose information. The general flow of the
`buildHALDeviceAssignmentPassPipeline`/`iree-opt
--iree-hal-device-assignment-pipeline` is:
1. `--iree-hal-target-device=` flags are parsed and a
`hal.device.targets` attribute is added to the module.
* `--iree-hal-target-device=cpu_device=local` becomes
`hal.device.targets = [#hal.device.alias<"local"> : !hal.device]`
* `--iree-hal-target-device=cpu_device=local
--iree-hal-target-device=gpu_device=cuda,hip` becomes
  ```mlir
  hal.device.targets = {
    cpu_device = #hal.device.alias<"local"> : !hal.device,
    gpu_device = #hal.device.select<[
      #hal.device.alias<"cuda"> : !hal.device,
      #hal.device.alias<"hip"> : !hal.device
    ]> : !hal.device
  }
  ```
2. The `hal.device.targets` attribute (if any) is expanded into
`util.global` ops for each device. These globals are initialized with
one of the supported attributes, which are much later turned into
enumeration/selection logic. The above multi-device example becomes:
  ```mlir
  builtin.module attributes {
    stream.affinity.default = #hal.device.affinity<@cpu_device>
  } {
    util.global private @cpu_device = #hal.device.alias<"local"> : !hal.device
    util.global private @gpu_device = #hal.device.select<[
      #hal.device.alias<"cuda"> : !hal.device,
      #hal.device.alias<"hip"> : !hal.device
    ]> : !hal.device
  }
  ```
3. Any `#hal.device.promise` attributes are changed to reference the
globals with the same name (see the sketch after this list). This allows
for retargeting of inputs by letting a frontend specify named devices
before they have been passed on the command line (or inserted by some
other pipeline).
4. Any `#hal.device.alias` attributes are converted to full
`#hal.device.target` attributes using the appropriate
`IREE::HAL::TargetDevice` implementation.
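
As a sketch of steps 3 and 4, a placeholder reference like the one below
resolves against the global of the same name (the exact printed form of
the resolved attribute may differ):
```mlir
// Before resolution: frontend placeholder naming a not-yet-defined device.
%0 = flow.tensor.transfer %arg0 : tensor<4xi32> to #hal.device.promise<@gpu_device>
// After resolution: the reference binds to the @gpu_device util.global.
%0 = flow.tensor.transfer %arg0 : tensor<4xi32> to #hal.device.affinity<@gpu_device>
```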

Upon completion of the pipeline there are globals initialized with
either a specific device target or a selection mechanism to pick between
targets. From that point onward devices are a structural part of the
program and can be referenced by symbol name via attributes like
`#hal.device.affinity`.

Programs are expected to specify the device affinity for all operations
either explicitly or implicitly. By default (as today) the first device
defined will be used but going forward we will want frontends to start
specifying devices. To that end the `flow.tensor.transfer` operation was
added to allow a tensor to have a device affinity assigned to it. A new
analysis is added that allows all tensors (or stream resources) and ops
interacting with them to be queried for which device they should be
placed on. For example, a frontend can make a computation use multiple
devices by transferring the tensors involved:
```mlir
util.func private @my_func(%arg0: tensor<4xi32>) -> tensor<4xi32> {
  %arg0_device_a = flow.tensor.transfer %arg0 : tensor<4xi32> to #hal.device.promise<@device_a>
  %compute_device_a = arith.addi %arg0_device_a, %arg0_device_a : tensor<4xi32>
  %transient_device_b = flow.tensor.transfer %compute_device_a : tensor<4xi32> to #hal.device.promise<@device_b>
  %compute_device_b = arith.muli %transient_device_b, %transient_device_b : tensor<4xi32>
  util.return %compute_device_b : tensor<4xi32>
}
```

To avoid copies there are also ways for frontends to indicate where
argument and result tensors are placed. The best way (in that it's most
general/powerful) is for the frontends to emit `hal.tensor.import`,
`hal.tensor.export`, and `hal.tensor.alias` ops directly as they all now
take affinities. When using the default ABI translation pass it's
possible to add arg/result attrs to public functions, e.g. `util.func
public @my_func(%arg0: tensor<2xi32> {iree.abi.affinity =
#hal.device.promise<@device_a>}) -> (tensor<2xi32> {iree.abi.affinity =
#hal.device.promise<@device_b>})`. Shorthand is provided to allow
specifying an `iree.abi.affinity` on the function itself for when all
arguments and results are placed on the same device, as sketched below.
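
A minimal sketch of that shorthand, assuming every argument and result of
the (hypothetical) function lives on one device:
```mlir
// All arguments and results of @my_func are placed on @device_a.
util.func public @my_func(%arg0: tensor<2xi32>) -> tensor<2xi32>
    attributes {iree.abi.affinity = #hal.device.promise<@device_a>} {
  util.return %arg0 : tensor<2xi32>
}
```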

Once devices are specified, materialized in the program as globals, and
referenced (via the default affinity attribute, scoped attributes, or
explicit transfer operations), most of the mechanics are implementation
details of the stream and HAL dialect lowerings. Partitioning,
allocation, and scheduling in the stream dialect were always
affinity-aware and required only minor tweaks as part of this work,
while the HAL TODOs for multi-device were implemented by memoizing
resources per-device and adding the machinery to enumerate and select
devices.

This was reviewed in the following chunks and tested in a roll-up PR
#17482:
* #17915
* #17917
* #17916
* #17918
* #17919
* #17920
benvanik authored Jul 30, 2024
2 parents 1ea9ee0 + f721fd0 commit d39c3c5
Showing 257 changed files with 12,645 additions and 3,765 deletions.
83 changes: 63 additions & 20 deletions compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
@@ -124,9 +124,17 @@ enum class TypeDisposition {
FENCE,
};

struct BarrierResult {
BlockArgument storage;
Type torchType;
int returnIndex = -1;
};

struct ConvertedAsyncFunctionInfo {
IREE::Util::FuncOp funcOp;
SmallVector<IREE::Util::ReturnOp> returnOps;
SmallVector<DictionaryAttr> torchArgAttrs;
SmallVector<DictionaryAttr> torchResultAttrs;
SmallVector<Type> torchInputTypes;
SmallVector<Type> torchResultTypes;
SmallVector<TypeDisposition> inputDispositions;
@@ -136,18 +144,33 @@ struct ConvertedAsyncFunctionInfo {
// Values that must be captured in the coarse barrier.
SmallVector<Value> barrierInputs;
// Meta data per barrier input: storage, torchType, returnIndex (or -1)
SmallVector<std::tuple<Value, Type, int>> barrierResultMeta;
SmallVector<BarrierResult> barrierResultMeta;

LogicalResult postProcess();
LogicalResult convertImmutableTensorArg(BlockArgument argValue,
Type torchType, OpBuilder &builder);
LogicalResult convertMutableTensorArg(BlockArgument argValue, Type torchType,
OpBuilder &builder);

void addBarrierInput(Value inputTensor, Value storage, Type torchType,
void addBarrierInput(Value inputTensor, BlockArgument storage, Type torchType,
int returnIndex) {
barrierInputs.push_back(inputTensor);
barrierResultMeta.emplace_back(storage, torchType, returnIndex);
barrierResultMeta.emplace_back(BarrierResult{
storage,
torchType,
returnIndex,
});
}

Attribute getTorchArgAttr(BlockArgument argValue, StringRef attrName) {
return torchArgAttrs.empty()
? Attribute{}
: torchArgAttrs[argValue.getArgNumber()].get(attrName);
}
Attribute getTorchResultAttr(int returnIndex, StringRef attrName) {
return torchResultAttrs.empty()
? Attribute{}
: torchResultAttrs[returnIndex].get(attrName);
}
};

@@ -232,7 +255,8 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() {
}
if (needsBarrier) {
Value source = convertToBuiltinTensor(postambleBuilder, returnValue);
addBarrierInput(source, /*storage=*/Value{}, torchType, returnIndex);
addBarrierInput(source, /*storage=*/BlockArgument{}, torchType,
returnIndex);
}
break;
}
@@ -276,23 +300,22 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() {
SmallVector<Value> aliasedResults;
for (auto [barrierInput, meta] :
llvm::zip_equal(barrierInputs, barrierResultMeta)) {
Value exportStorage;
Type torchType;
int returnIndex;
std::tie(exportStorage, torchType, returnIndex) = meta;
if (exportStorage) {
if (meta.storage) {
// Use the wait fence indicating when the storage is available for
// mutation. We need to ensure that no writes are made to the storage
// until it indicates it's safe to do so.
auto waitSignalFences = getEnclosingWaitSignalFences(exportStorage);
auto storageAffinityAttr =
getTorchArgAttr(meta.storage, "iree.abi.affinity");
auto waitSignalFences = getEnclosingWaitSignalFences(meta.storage);
assert(waitSignalFences && "async function missing fences");
Value waitFence = waitSignalFences->first;
auto barrierInputDims = IREE::Util::buildDynamicDimsForValue(
barrierInput.getLoc(), barrierInput, postambleBuilder);
aliasedResults.push_back(
postambleBuilder.create<IREE::HAL::TensorAliasOp>(
barrierInput.getLoc(), barrierInput.getType(), barrierInput,
barrierInputDims, exportStorage, waitFence));
barrierInputDims, meta.storage, waitFence,
storageAffinityAttr));
} else {
aliasedResults.push_back(barrierInput);
}
@@ -301,16 +324,20 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() {
funcOp.getLoc(), aliasedResults, coarseSignalFence);
for (auto [barrierResult, meta] :
llvm::zip_equal(barrierOp.getResults(), barrierResultMeta)) {
Value exportStorage;
Type torchType;
int returnIndex;
std::tie(exportStorage, torchType, returnIndex) = meta;
Attribute exportAffinityAttr;
if (meta.storage) {
exportAffinityAttr = getTorchArgAttr(meta.storage, "iree.abi.affinity");
} else if (meta.returnIndex >= 0) {
exportAffinityAttr =
getTorchResultAttr(meta.returnIndex, "iree.abi.affinity");
}
Value exportedValue = postambleBuilder.create<IREE::HAL::TensorExportOp>(
funcOp.getLoc(),
postambleBuilder.getType<IREE::HAL::BufferViewType>(), barrierResult,
TypeAttr::get(barrierResult.getType()), StringAttr());
if (returnIndex >= 0) {
newReturnOperands[returnIndex] = exportedValue;
TypeAttr::get(barrierResult.getType()), /*name=*/nullptr,
exportAffinityAttr);
if (meta.returnIndex >= 0) {
newReturnOperands[meta.returnIndex] = exportedValue;
}
}
}
@@ -374,13 +401,16 @@ LogicalResult ConvertedAsyncFunctionInfo::convertImmutableTensorArg(
<< torchType;
}

// Propagate explicit affinities to the read.
auto affinityAttr = getTorchArgAttr(argValue, "iree.abi.affinity");

auto waitSignalFences = getEnclosingWaitSignalFences(argValue);
assert(waitSignalFences && "async function missing fences");
Value waitFence = waitSignalFences->first;
Value importedTensor = builder.create<IREE::HAL::TensorImportOp>(
loc, builtinTensorType, argValue, TypeAttr::get(builtinTensorType),
waitFence,
/*name=*/StringAttr());
/*name=*/nullptr, affinityAttr);
if (builtinTensorType != torchType) {
importedTensor = builder.create<TorchConversion::FromBuiltinTensorOp>(
loc, torchType, importedTensor);
@@ -404,6 +434,9 @@ LogicalResult ConvertedAsyncFunctionInfo::convertMutableTensorArg(
.toBuiltinTensor();
}

// Propagate explicit affinities to the read and write.
auto affinityAttr = getTorchArgAttr(argValue, "iree.abi.affinity");

// There are only a small set of possible users of a mutable tensor.
// Handle them by operation here.
SmallVector<Operation *> users(argValue.getUsers());
@@ -415,7 +448,7 @@ LogicalResult ConvertedAsyncFunctionInfo::convertMutableTensorArg(
loc, builtinTensorType, argValue,
/*target_encoding=*/TypeAttr::get(builtinTensorType),
/*wait_fence*/ fences->first,
/*name=*/StringAttr());
/*name=*/nullptr, affinityAttr);
rewriter.replaceOpWithNewOp<TorchConversion::FromBuiltinTensorOp>(
userOp, copyToVtOp.getResult().getType(), imported);
} else if (auto overwriteOp =
@@ -470,6 +503,9 @@ void createCoarseFencesSyncWrapper(StringRef syncFunctionName,
syncFuncOp.setSymVisibilityAttr(asyncFuncOp.getSymVisibilityAttr());
retainFunctionAttributes(asyncFuncOp, syncFuncOp);
syncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr());
if (auto affinityAttr = asyncFuncOp->getAttr("iree.abi.affinity")) {
syncFuncOp->setAttr("iree.abi.affinity", affinityAttr);
}
Block *entryBlock = syncFuncOp.addEntryBlock();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(entryBlock);
@@ -578,6 +614,10 @@ struct FuncConversionPass : public FuncConversionBase<FuncConversionPass> {
asyncFunctionName.append("$async");
}

// Stash arg/result attrs so they can be referenced during conversion.
torchFunc.getAllArgAttrs(convertedFuncInfo.torchArgAttrs);
torchFunc.getAllResultAttrs(convertedFuncInfo.torchResultAttrs);

// Convert function signature.
Type fenceType = rewriter.getType<IREE::HAL::FenceType>();
FunctionType torchFuncType = torchFunc.getFunctionType();
@@ -638,6 +678,9 @@ struct FuncConversionPass : public FuncConversionBase<FuncConversionPass> {
asyncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr());
asyncFuncOp->setAttr("iree.abi.model",
rewriter.getStringAttr("coarse-fences"));
if (auto affinityAttr = torchFunc->getAttr("iree.abi.affinity")) {
asyncFuncOp->setAttr("iree.abi.affinity", affinityAttr);
}
rewriter.inlineRegionBefore(
torchFunc.getBody(), asyncFuncOp.getFunctionBody(), asyncFuncOp.end());

1 change: 1 addition & 0 deletions compiler/plugins/input/Torch/InputConversion/Passes.cpp
@@ -56,6 +56,7 @@ void createTorchToIREEPipeline(
TorchInput::createConvertTMTensorToLinalgExtPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToTensorPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToLinalgPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToSCFPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToArithPass());
pm.addPass(torch::createConvertTorchConversionToMLProgramPass());
compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir
@@ -110,6 +110,37 @@ func.func @main(%arg0: !torch.tensor<[5,4],f32>) -> (!torch.vtensor<[5,4],f32>)
}
}

// -----
// Tests the immutable + mutable argument case with explicit affinities.
// CHECK-LABEL: @mutable_input_overwrite_no_return
// CHECK: util.func public @main$async(
// CHECK-SAME: %arg0: !hal.buffer_view, %arg1: !hal.buffer_view,
// CHECK-SAME: %arg2: !hal.fence, %arg3: !hal.fence) -> !hal.buffer_view
// CHECK-DAG: %[[WAIT_ARG0:.+]] = hal.tensor.import on(#hal.device.promise<@dev_a>) wait(%arg2) => %arg0
// CHECK-DAG: %[[TORCH_ARG0:.+]] = torch_c.from_builtin_tensor %[[WAIT_ARG0]]
// CHECK-DAG: %[[WAIT_ARG1:.+]] = hal.tensor.import on(#hal.device.promise<@dev_b>) wait(%arg2) => %arg1
// CHECK-DAG: %[[TORCH_ARG1:.+]] = torch_c.from_builtin_tensor %[[WAIT_ARG1]]
// CHECK-DAG: %[[TORCH_RESULT0:.+]] = torch.operator "other_calc"(%[[TORCH_ARG0]])
// CHECK-DAG: %[[TORCH_RESULT1:.+]] = torch.operator "mutate_inplace"(%[[TORCH_ARG1]])
// CHECK-DAG: %[[TENSOR_ARG0:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT0]]
// CHECK-DAG: %[[TENSOR_ARG1:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT1]]
// CHECK: %[[EXPORT_ALIAS1:.+]] = hal.tensor.alias on(#hal.device.promise<@dev_b>) wait(%arg2) => %[[TENSOR_ARG1]] : tensor<5x4xf32> to %arg1 : !hal.buffer_view
// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%[[EXPORT_ALIAS1]], %[[TENSOR_ARG0]] : tensor<5x4xf32>, tensor<4x5xi32>) => %arg3 : !hal.fence
// CHECK-DAG: %[[EXPORT_RESULT0:.+]] = hal.tensor.export on(#hal.device.promise<@dev_b>) %[[BARRIER_RESULTS]]#0
// CHECK-DAG: %[[EXPORT_RESULT1:.+]] = hal.tensor.export on(#hal.device.promise<@dev_a>) %[[BARRIER_RESULTS]]#1
// CHECK: util.return %[[EXPORT_RESULT1]]
builtin.module @mutable_input_overwrite_no_return_affinities {
func.func @main(%arg0: !torch.vtensor<[4,5],si32> {iree.abi.affinity = #hal.device.promise<@dev_a>},
%arg1: !torch.tensor<[5,4],f32> {iree.abi.affinity = #hal.device.promise<@dev_b>})
-> (!torch.vtensor<[4,5],si32> {iree.abi.affinity = #hal.device.promise<@dev_a>}) {
%0 = torch.copy.to_vtensor %arg1 : !torch.vtensor<[5,4],f32>
%1 = torch.operator "mutate_inplace"(%0) : (!torch.vtensor<[5,4],f32>) -> !torch.vtensor<[5,4],f32>
%2 = torch.operator "other_calc"(%arg0) : (!torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32>
torch.overwrite.tensor.contents %1 overwrites %arg1 : !torch.vtensor<[5,4],f32>, !torch.tensor<[5,4],f32>
return %2 : !torch.vtensor<[4,5],si32>
}
}

// -----
// CHECK-LABEL: @retained_attribute_reflection
// CHECK: util.func public @main$async(
4 changes: 3 additions & 1 deletion compiler/plugins/target/CUDA/test/smoketest.mlir
@@ -5,7 +5,9 @@

module attributes {
hal.device.targets = [
#hal.device.target<"cuda", [#hal.executable.target<"cuda", "cuda-nvptx-fb">]>
#hal.device.target<"cuda", [
#hal.executable.target<"cuda", "cuda-nvptx-fb">
]> : !hal.device
]
} {

1 change: 1 addition & 0 deletions compiler/plugins/target/LLVMCPU/test/BUILD.bazel
@@ -16,6 +16,7 @@ iree_lit_test_suite(
name = "lit",
srcs = enforce_glob(
[
"materialize_homogeneous_encodings.mlir",
"smoketest_embedded.mlir",
"smoketest_system.mlir",
],
1 change: 1 addition & 0 deletions compiler/plugins/target/LLVMCPU/test/CMakeLists.txt
@@ -14,6 +14,7 @@ iree_lit_test_suite(
NAME
lit
SRCS
"materialize_homogeneous_encodings.mlir"
"smoketest_embedded.mlir"
"smoketest_system.mlir"
TOOLS
30 changes: 30 additions & 0 deletions compiler/plugins/target/LLVMCPU/test/materialize_homogeneous_encodings.mlir
@@ -0,0 +1,30 @@
// RUN: iree-opt --split-input-file --iree-hal-device-assignment-pipeline --iree-global-opt-materialize-homogeneous-encodings %s | FileCheck %s

#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}>
#map = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", [#executable_target_embedded_elf_x86_64_]> : !hal.device
module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
util.func public @lhs_encoding(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%0:2 = iree_encoding.upper_bound_tile_size tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> index, index
%1 = affine.apply #map()[%0#0, %dim]
%2 = affine.apply #map()[%0#1, %dim_0]
%padded = tensor.pad %arg0 low[0, 0] high[%1, %2] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %cst : f32
} : tensor<?x?xf32> to tensor<?x?xf32>
%3 = iree_encoding.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>>
%4 = iree_encoding.unset_encoding %3 : tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map1, #map2, #map3]>> -> tensor<?x?xf32>
util.return %4 : tensor<?x?xf32>
}
}
// CHECK-LABEL: util.func public @lhs_encoding
// CHECK: tensor.pack
// CHECK: tensor.unpack
6 changes: 4 additions & 2 deletions compiler/plugins/target/LLVMCPU/test/smoketest_embedded.mlir
@@ -4,8 +4,10 @@
module attributes {
hal.device.targets = [
#hal.device.target<"llvm-cpu", [
#hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", { native_vector_size = 16 : index }>
]>
#hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
native_vector_size = 16 : index
}>
]> : !hal.device
]
} {

6 changes: 4 additions & 2 deletions compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir
@@ -6,8 +6,10 @@
module attributes {
hal.device.targets = [
#hal.device.target<"llvm-cpu", [
#hal.executable.target<"llvm-cpu", "embedded-elf-x86_64",{ native_vector_size = 16 : index } >
]>
#hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
native_vector_size = 16 : index
}>
]> : !hal.device
]
} {

2 changes: 1 addition & 1 deletion compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
@@ -8,7 +8,7 @@ module attributes {
compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32],
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
}>
]>
]> : !hal.device
]
} {

8 changes: 4 additions & 4 deletions compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -226,7 +226,7 @@ class ROCMTargetDevice final : public TargetDevice {
targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets(
context, "rocm", configAttr, executableTargetAttrs);

return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("rocm"),
return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("hip"),
configAttr, executableTargetAttrs);
}

@@ -238,7 +238,7 @@ class ROCMTargetBackend final : public TargetBackend {
public:
ROCMTargetBackend(const ROCmOptions &options) : options(options) {}

std::string getLegacyDefaultDeviceID() const override { return "rocm"; }
std::string getLegacyDefaultDeviceID() const override { return "hip"; }

void getDefaultExecutableTargets(
MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr,
@@ -702,8 +702,8 @@ struct ROCMSession final
: PluginSession<ROCMSession, ROCmOptions,
PluginActivationPolicy::DefaultActivated> {
void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) {
// #hal.device.target<"rocm", ...
targets.add("rocm",
// #hal.device.target<"hip", ...
targets.add("hip",
[&]() { return std::make_shared<ROCMTargetDevice>(options); });
}
void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) {
8 changes: 6 additions & 2 deletions compiler/plugins/target/ROCM/test/smoketest.mlir
@@ -2,7 +2,9 @@

module attributes {
hal.device.targets = [
#hal.device.target<"rocm", [#hal.executable.target<"rocm", "rocm-hsaco-fb">]>
#hal.device.target<"hip", [
#hal.executable.target<"rocm", "rocm-hsaco-fb">
]> : !hal.device
]
} {

@@ -44,7 +46,9 @@ stream.executable public @add_dispatch_0 {
#loc = loc(unknown)
module attributes {
hal.device.targets = [
#hal.device.target<"rocm", [#hal.executable.target<"rocm", "rocm-hsaco-fb">]>
#hal.device.target<"hip", [
#hal.executable.target<"rocm", "rocm-hsaco-fb">
]> : !hal.device
]
} {
