Skip to content

Commit

Permalink
New AssignTargetDevices pass to replace the legacy one.
Browse files Browse the repository at this point in the history
The legacy pass has been moved aside so that the old flags still work
but will be removed in the future.
  • Loading branch information
benvanik committed May 23, 2024
1 parent 0403062 commit 8870d72
Show file tree
Hide file tree
Showing 52 changed files with 903 additions and 249 deletions.
4 changes: 3 additions & 1 deletion compiler/plugins/target/CUDA/test/smoketest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

module attributes {
hal.device.targets = [
#hal.device.target<"cuda", [#hal.executable.target<"cuda", "cuda-nvptx-fb">]>
#hal.device.target<"cuda", [
#hal.executable.target<"cuda", "cuda-nvptx-fb">
]> : !hal.device
]
} {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module attributes {
#hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
native_vector_size = 16 : index
}>
]>
]> : !hal.device
]
} {

Expand Down
2 changes: 1 addition & 1 deletion compiler/plugins/target/LLVMCPU/test/smoketest_system.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ module attributes {
#hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
native_vector_size = 16 : index
}>
]>
]> : !hal.device
]
} {

Expand Down
2 changes: 1 addition & 1 deletion compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module attributes {
#hal.executable.target<"metal-spirv", "metal-msl-fb", {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
}>
]>
]> : !hal.device
]
} {

Expand Down
8 changes: 6 additions & 2 deletions compiler/plugins/target/ROCM/test/smoketest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

module attributes {
hal.device.targets = [
#hal.device.target<"rocm", [#hal.executable.target<"rocm", "rocm-hsaco-fb">]>
#hal.device.target<"rocm", [
#hal.executable.target<"rocm", "rocm-hsaco-fb">
]> : !hal.device
]
} {

Expand Down Expand Up @@ -44,7 +46,9 @@ stream.executable public @add_dispatch_0 {
#loc = loc(unknown)
module attributes {
hal.device.targets = [
#hal.device.target<"rocm", [#hal.executable.target<"rocm", "rocm-hsaco-fb">]>
#hal.device.target<"rocm", [
#hal.executable.target<"rocm", "rocm-hsaco-fb">
]> : !hal.device
]
} {

Expand Down
9 changes: 4 additions & 5 deletions compiler/plugins/target/ROCM/test/target_device_features.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=mi300x %s | FileCheck %s --check-prefix=GFX942
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=GFX940
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=rx7900xtx %s | FileCheck %s --check-prefix=GFX1100
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx941 --iree-rocm-target-features=+sramecc,-xnack %s | FileCheck %s --check-prefix=GFX941
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=mi300x %s | FileCheck %s --check-prefix=GFX942
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=GFX940
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=rx7900xtx %s | FileCheck %s --check-prefix=GFX1100
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetDevices=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx941 --iree-rocm-target-features=+sramecc,-xnack %s | FileCheck %s --check-prefix=GFX941

// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
Expand All @@ -21,7 +21,6 @@
// GFX941: target = #iree_gpu.target<arch = "gfx941",
// GFX941-SAME: features = "+sramecc,-xnack"


stream.executable public @reduce_dispatch {
stream.executable.export @reduce_dispatch workgroups(%arg0: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
Expand Down
2 changes: 1 addition & 1 deletion compiler/plugins/target/VMVX/test/smoketest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module attributes {
hal.device.targets = [
#hal.device.target<"local", [
#hal.executable.target<"vmvx", "vmvx-bytecode-fb">
]>
]> : !hal.device
]
} {

Expand Down
2 changes: 1 addition & 1 deletion compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module attributes {
#hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
}>
]>
]> : !hal.device
]
} {

Expand Down
2 changes: 1 addition & 1 deletion compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module attributes {
#hal.executable.target<"webgpu-spirv", "webgpu-wgsl-fb", {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
}>
]>
]> : !hal.device
]
} {

Expand Down
4 changes: 3 additions & 1 deletion compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -934,10 +934,12 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) {
if (!getCompilationPhase(compileFrom, compileTo)) {
return false;
}

// TODO: move to someplace centralized; erroring here is not great.
// InlineStatic (currently) only supports the `vmvx-inline` backend.
if (session.schedulingOptions.executionModel ==
SchedulingOptions::ExecutionModel::InlineStatic) {
for (auto target : session.halTargetOptions.targets) {
for (auto target : session.halTargetOptions.legacyTargetBackends) {
if (target != "vmvx-inline") {
parsedModule->emitError() << "InlineStatic execution model is not "
"compatible with hal target '"
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,8 @@ struct JitGlobalsPass : public JitGlobalsBase<JitGlobalsPass> {
requestedTargetDevice = resolveTargetDevice(*targetRegistry.value);
hasRequestedTargetDevice =
targetRegistry->getTargetDevice(requestedTargetDevice) != nullptr;
compileOptions->executableOptions.targets.push_back(requestedTargetDevice);
compileOptions->executableOptions.legacyTargetBackends.push_back(
requestedTargetDevice);
compileOptions->targetOptions.f32Extension = true;
compileOptions->targetOptions.f64Extension = true;
compileOptions->targetOptions.truncateUnsupportedFloats = false;
Expand Down
13 changes: 12 additions & 1 deletion compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,21 @@ void TargetOptions::bindOptions(OptionsBinder &binder) {
// initialized, so targetBackendsFlags needs to be here to be initialized
// first.
binder.list<std::string>(
"iree-hal-target-backends", targets,
"iree-hal-target-backends", legacyTargetBackends,
llvm::cl::desc("Target backends for executable compilation."),
llvm::cl::ZeroOrMore, llvm::cl::cat(halTargetOptionsCategory));

binder.list<std::string>("iree-hal-target-device", targetDevices,
llvm::cl::desc("Target device specifications."),
llvm::cl::ZeroOrMore,
llvm::cl::cat(halTargetOptionsCategory));
binder.opt<std::string>(
"iree-hal-default-device", defaultDevice,
llvm::cl::desc("Which device is considered the default when no device "
"affinity is specified. Either the device name when names "
"are specified or the numeric ordinal of the device."),
llvm::cl::cat(halTargetOptionsCategory));

binder.opt<int>(
"iree-hal-executable-debug-level", debugLevel,
llvm::cl::desc("Debug level for executable translation (0-3)"),
Expand Down
31 changes: 29 additions & 2 deletions compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,35 @@ namespace mlir::iree_compiler::IREE::HAL {
// TODO(benvanik): remove this and replace with the pass pipeline options.
// Controls executable translation targets.
struct TargetOptions {
// TODO(benvanik): multiple targets of the same type, etc.
std::vector<std::string> targets;
// TODO(benvanik): remove the legacy flag once users are switched to devices.
std::vector<std::string> legacyTargetBackends;

// Specifies target devices to assign to the program. May be omitted if the
// program already has devices assigned or no devices are required (host
// program not using the HAL).
//
// Two devices, one the local host device and the other a Vulkan device.
// `local`, `vulkan`
//
// One device selecting between Vulkan if available and otherwise use the
// local host device.
// `vulkan,local`
//
// Two CUDA devices selected by runtime ordinal; at runtime two --device=
// flags are required to configure both devices.
// `cuda[0]`, `cuda[1]`
//
// A fully-defined target specification:
// `#hal.device.target<"cuda", {...}, [#hal.executable.target<...>]>`
//
// Named device for defining a reference by #hal.device.promise<@some_name>.
// `some_name=vulkan`
std::vector<std::string> targetDevices;

// Which device is considered the default when no device affinity is specified
// on a particular operation. Accepts string names matching those specified
// in the target devices list or numeric ordinals if names were omitted.
std::string defaultDevice;

// Coarse debug level for executable translation across all targets.
// Each target backend can use this to control its own flags, with values
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <memory>
#include <utility>

#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"

namespace mlir::iree_compiler::IREE::HAL {

#define GEN_PASS_DEF_ASSIGNLEGACYTARGETDEVICESPASS
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc"

namespace {

//===----------------------------------------------------------------------===//
// --iree-hal-assign-legacy-target-devices
//===----------------------------------------------------------------------===//

struct AssignLegacyTargetDevicesPass
: public IREE::HAL::impl::AssignLegacyTargetDevicesPassBase<
AssignLegacyTargetDevicesPass> {
using IREE::HAL::impl::AssignLegacyTargetDevicesPassBase<
AssignLegacyTargetDevicesPass>::AssignLegacyTargetDevicesPassBase;

void runOnOperation() override {
auto moduleOp = getOperation();

// If no targets are specified we can't do anything - another pass earlier
// in the pipeline will have had to add the targets.
if (targetBackends.empty())
return;

// Check to see if targets are already specified and if so then no-op the
// pass so that we don't mess with whatever the user intended.
auto existingTargetsAttr =
moduleOp->getAttrOfType<ArrayAttr>("hal.device.targets");
if (existingTargetsAttr)
return;

// If there are any device globals declared then bail as it means the user
// has already materialized the devices they want.
for (auto globalOp : moduleOp.getOps<IREE::Util::GlobalOpInterface>()) {
if (isa<IREE::HAL::DeviceType>(globalOp.getGlobalType()))
return;
}

llvm::SmallDenseSet<Attribute> targetAttrSet;
SmallVector<Attribute> targetAttrs;
for (const auto &targetBackendName : targetBackends) {
auto targetBackend = targetRegistry->getTargetBackend(targetBackendName);
if (!targetBackend) {
auto diagnostic = emitError(moduleOp.getLoc())
<< "target backend '" << targetBackendName
<< "' not registered; registered backends: [";
llvm::interleaveComma(targetRegistry->getRegisteredTargetBackends(),
diagnostic);
diagnostic << "]";
return signalPassFailure();
}
auto targetDeviceName = targetBackend->getLegacyDefaultDeviceID();
auto targetDevice = targetRegistry->getTargetDevice(targetDeviceName);
if (!targetDevice) {
auto diagnostic = emitError(moduleOp.getLoc())
<< "target device '" << targetDeviceName
<< "' not registered; registered devices: [";
llvm::interleaveComma(targetRegistry->getRegisteredTargetDevices(),
diagnostic);
diagnostic << "]";
return signalPassFailure();
}

// Ask the target backend for its default device specification attribute.
auto targetAttr = targetDevice->getDefaultDeviceTarget(
moduleOp.getContext(), *targetRegistry.value);
if (!targetAttr) {
emitError(moduleOp.getLoc()) << "no default device targets available";
return signalPassFailure();
}
if (!targetAttrSet.contains(targetAttr)) {
targetAttrSet.insert(targetAttr);
targetAttrs.push_back(targetAttr);
}
}

Attribute targetsAttr;
if (targetAttrs.size() == 1) {
targetsAttr = targetAttrs.front();
} else {
targetsAttr =
IREE::HAL::DeviceSelectAttr::get(moduleOp.getContext(), targetAttrs);
}
moduleOp->setAttr("hal.device.targets",
ArrayAttr::get(moduleOp.getContext(), targetsAttr));
}
};

} // namespace

} // namespace mlir::iree_compiler::IREE::HAL
Loading

0 comments on commit 8870d72

Please sign in to comment.