[Stream] Add support for executable duplication for tied operand cases. (#19953)

Previously, the pass did not consider the case where results have tied
operands. When a result is tied to an operand, we have to skip the
result encoding, because the binding is shared between the tied operand
and the result: the dispatch function does not have a separate argument
(i.e., a duplicated binding) for the result.
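
As an illustration (a hypothetical, simplified sketch; the names, shapes,
and workload are made up), a dispatch whose result is tied to an operand
routes both through a single readwrite binding, so the executable's
function only takes the operand's binding:

  stream.executable private @ex {
    stream.executable.export public @dispatch
    builtin.module {
      // One readwrite binding serves both the tied operand and the result;
      // there is no second binding argument for the result.
      func.func @dispatch(%binding: !stream.binding) {
        %c0 = arith.constant 0 : index
        %0 = stream.binding.subspan %binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readwrite:tensor<16xf32>>
        return
      }
    }
  }
  // The trailing "in %arg0" ties the result to the %arg0 operand.
  %0 = stream.tensor.dispatch @ex::@dispatch[%c1](%arg0) : (tensor<16xf32> in !stream.resource<*>{%size}) -> tensor<16xf32> in %arg0{%size}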

This revision also tightens the definition of the `stream.tensor.dispatch`
op. Previously, the op allowed the type of a result to differ from that of
its tied operand. Now any tied result must have an encoding that matches
its tied operand: the two share the same binding, and it is hard to track
which encoding is used in which binding today. Adding the check to the
verifier lets us catch mismatches early, when the op is lowered from flow.
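
For example, the verifier now rejects IR like the following (a sketch
mirroring the new @tensorDispatchMismatch test below; %d0 and %sz are
placeholder SSA values), where the tied result's encoding differs from
the operand's:

  // error: the 0-th operandEncoding (tensor<4x?xf32>) does not match the
  // resultEncoding (tensor<?x4xf32>)
  %0 = stream.tensor.dispatch @ex::@dispatch[%c1](%arg0, %c4) : (tensor<4x?xf32>{%d0} in !stream.resource<*>{%sz}, index) -> tensor<?x4xf32>{%d0} in %arg0{%sz}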

---------

Signed-off-by: hanhanW <[email protected]>
hanhanW authored Feb 14, 2025
1 parent 3e424a9 commit c2cb11a
Showing 5 changed files with 138 additions and 7 deletions.
35 changes: 35 additions & 0 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -76,6 +76,37 @@ verifyDispatchWorkload(Operation *op, IREE::Stream::ExecutableExportOp exportOp,
  return success();
}

// Verifies that the encodings of tied operands match the corresponding
// result encodings.
static LogicalResult verifyTiedOperandEncodings(Operation *op,
                                                ArrayAttr operandEncodingsAttr,
                                                ArrayAttr resultEncodingsAttr) {
  auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
  if (!tiedOp) {
    return op->emitOpError()
           << "the op does not implement IREE::Util::TiedOpInterface";
  }

  ArrayRef<Attribute> operandEncodings = operandEncodingsAttr.getValue();
  unsigned tiedOperandBase = tiedOp.getTiedOperandsIndexAndLength().first;
  for (auto [idx, resultEncoding] :
       llvm::enumerate(resultEncodingsAttr.getValue())) {
    auto tiedOperand = tiedOp.getTiedResultOperandIndex(idx);
    if (!tiedOperand.has_value()) {
      continue;
    }
    auto operandIndex = tiedOperand.value() - tiedOperandBase;
    if (operandEncodings[operandIndex] != resultEncoding) {
      return op->emitError()
             << "the " << operandIndex << "-th operandEncoding ("
             << operandEncodings[operandIndex]
             << ") does not match the resultEncoding (" << resultEncoding
             << ")";
    }
  }

  return success();
}

// Verifies that |dynamicDims| contains the appropriate number of dims for all
// the dynamic dimensions in |type|.
static LogicalResult verifyOpDynamicDims(Operation *op, TypeRange types,
@@ -2112,6 +2143,10 @@ LogicalResult TensorDispatchOp::verify() {
                                 op.getResultEncodingDims()))) {
    return failure();
  }
  if (failed(verifyTiedOperandEncodings(op, op.getOperandEncodings(),
                                        op.getResultEncodings()))) {
    return failure();
  }
  return success();
}

17 changes: 15 additions & 2 deletions compiler/src/iree/compiler/Dialect/Stream/IR/test/tensor_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
+// RUN: iree-opt --split-input-file %s --verify-diagnostics | FileCheck %s

// CHECK-LABEL: @tensorImport
util.func private @tensorImport(%arg0: !hal.buffer_view, %arg1: index) -> !stream.resource<external> {
Expand Down Expand Up @@ -162,7 +162,20 @@ util.func private @tensorDispatch(%arg0: !stream.resource<*>, %arg1: index, %arg
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  // CHECK: = stream.tensor.dispatch @executable::@dispatch[%c1, %c2, %c3](%arg0, %c4) :
-  // CHECK-SAME: (tensor<4x?xf32>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<?x4xf32>{%arg2} in %arg0{%arg1}
+  // CHECK-SAME: (tensor<4x?xf32>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<4x?xf32>{%arg2} in %arg0{%arg1}
  %0 = stream.tensor.dispatch @executable::@dispatch[%c1, %c2, %c3](%arg0, %c4) : (tensor<4x?xf32>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<4x?xf32>{%arg2} in %arg0{%arg1}
  util.return %0 : !stream.resource<*>
}

// -----

util.func private @tensorDispatchMismatch(%arg0: !stream.resource<*>, %arg1: index, %arg2: index) -> !stream.resource<*> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  // expected-error @+1 {{the 0-th operandEncoding (tensor<4x?xf32>) does not match the resultEncoding (tensor<?x4xf32>)}}
  %0 = stream.tensor.dispatch @executable::@dispatch[%c1, %c2, %c3](%arg0, %c4) : (tensor<4x?xf32>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<?x4xf32>{%arg2} in %arg0{%arg1}
  util.return %0 : !stream.resource<*>
}
@@ -146,6 +146,53 @@ updateBindingEncodings(FunctionOpInterface funcOp,
  return success();
}

/// Returns the operand encodings and result encodings from the `dispatchOp`
/// in |operands| + |results| order, i.e., the concatenation of the operand
/// encodings and the result encodings. If a result is tied to an operand, the
/// result encoding is skipped, because it shares the same binding with the
/// tied operand.
///
/// Example 1:
///
/// %0 = stream.tensor.dispatch ...(%arg0, %c4)
/// : (tensor<4x?xf32, #encoding> in !resource, index)
/// -> tensor<4x?xf32, #encoding> in !resource
///
/// The above dispatch op does not have tied operands. Thus, it returns
/// |#resolved_encoding, whatever_without_encoding, #resolved_encoding|
///
/// Example 2:
///
/// %0 = stream.tensor.dispatch ...(%arg0, %c4) : tensor<4x?xf32, #encoding>
/// -> tensor<4x?xf32, #encoding> in %arg0
///
/// The above dispatch op ties the result to the first operand. Thus, the result
/// encoding is stripped. It returns
/// |#resolved_encoding, whatever_without_encoding|
static SmallVector<Attribute>
getBindingLayoutAttrs(IREE::Stream::TensorDispatchOp dispatchOp) {
  SmallVector<int64_t> tiedOperands(dispatchOp.getNumResults(),
                                    IREE::Util::TiedOpInterface::kUntiedIndex);
  if (std::optional<ArrayAttr> tiedOperandsAttr =
          dispatchOp.getTiedOperands()) {
    tiedOperands =
        llvm::map_to_vector(tiedOperandsAttr.value(), [](Attribute intAttr) {
          return llvm::cast<IntegerAttr>(intAttr).getInt();
        });
  }

  SmallVector<Attribute> result(dispatchOp.getOperandEncodings().getValue());
  for (auto [resultEncoding, tiedOperand] : llvm::zip_equal(
           dispatchOp.getResultEncodings().getValue(), tiedOperands)) {
    if (tiedOperand != IREE::Util::TiedOpInterface::kUntiedIndex) {
      continue;
    }
    result.push_back(resultEncoding);
  }

  return result;
}

/// Duplicates stream.executables based on the operand encodings and result
/// encodings of stream.tensor.dispatch ops. Some executables can be launched by
/// different devices. It can produce wrong codegen artifacts when bindings
@@ -183,10 +230,8 @@ duplicateExecutablesPerLayoutVariant(ModuleOp moduleOp, SymbolTable symbolTable,
  llvm::MapVector<IREE::Stream::TensorDispatchOp, SmallVector<Attribute>>
      dispatchOpBindingLayouts;
  for (auto dispatchOp : candidates) {
-    SmallVector<Attribute> bindingLayoutAttrs(
-        dispatchOp.getOperandEncodings().getValue());
-    llvm::append_range(bindingLayoutAttrs,
-                       dispatchOp.getResultEncodings().getValue());
+    SmallVector<Attribute> bindingLayoutAttrs =
+        getBindingLayoutAttrs(dispatchOp);
    dispatchOpBindingLayouts[dispatchOp] = bindingLayoutAttrs;
    dispatchOp.forEachEntryPointAttr([&](SymbolRefAttr entryPoint) {
      auto exportOp = cast<IREE::Stream::ExecutableExportOp>(
@@ -282,7 +282,7 @@ util.func public @denseTensorDispatch(
  // CHECK-SAME: %[[RESOURCE1]][%[[ZERO]] to %[[RESOURCE1_SIZE]] for %[[RESOURCE1_SIZE]]])
  // CHECK-SAME: (!stream.resource<transient>{%[[RESOURCE0_SIZE]]}, !stream.resource<external>{%[[RESOURCE1_SIZE]]}) ->
  // CHECK-SAME: (!stream.resource<external>{%[[RESOURCE1_SIZE]]}, %[[RESOURCE1]]{%[[RESOURCE1_SIZE]]})
-  %results:2 = stream.tensor.dispatch @ex::@entry(%resource0, %resource1) : (tensor<4x?xf32>{%tensor0_dim} in !stream.resource<transient>{%resource0_size}, tensor<?xi32>{%tensor1_dim} in !stream.resource<external>{%resource1_size}) -> (tensor<?xi32>{%tensor1_dim} in !stream.resource<external>{%resource1_size}, tensor<?xf32>{%tensor1_dim} in %resource1{%resource1_size})
+  %results:2 = stream.tensor.dispatch @ex::@entry(%resource0, %resource1) : (tensor<4x?xf32>{%tensor0_dim} in !stream.resource<transient>{%resource0_size}, tensor<?xi32>{%tensor1_dim} in !stream.resource<external>{%resource1_size}) -> (tensor<4x?xf32>{%tensor0_dim} in !stream.resource<external>{%resource1_size}, tensor<?xi32>{%tensor1_dim} in %resource1{%resource1_size})
  // CHECK: util.return %[[RESULTS]]#0, %[[RESULTS]]#1
  util.return %results#0, %results#1 : !stream.resource<external>, !stream.resource<external>
}
@@ -215,6 +215,44 @@ util.global private @device_a = #device_target_local_0_

// -----

#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#encoding = #iree_encoding.testing_encoding<>

util.global private @device_a = #device_target_local_0_
stream.executable private @executable {
  stream.executable.export public @dispatch
  builtin.module {
    func.func @dispatch(%arg0: !stream.binding, %arg1: index) {
      %c0 = arith.constant 0 : index
      %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readwrite:tensor<16xf32, #encoding>>
      return
    }
  }
}
util.func public @tensor_dispatch_with_tied_operands(%arg0: !stream.resource<external>, %arg1: index, %arg2: index) -> !stream.resource<*> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  %0 = stream.async.transfer %arg0 : !stream.resource<external>{%arg2} from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%arg2}
  %1 = stream.tensor.dispatch on(#hal.device.affinity<@device_a>) @executable::@dispatch[%c1, %c2, %c3](%0, %c4) : (tensor<4x?xf32, #encoding>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<4x?xf32, #encoding>{%arg2} in %0{%arg1}
  util.return %1 : !stream.resource<*>
}
// CHECK-DAG: #[[$ENCODING:.+]] = #iree_encoding.testing_encoding<[#iree_encoding.specialized_encoding<123, tensor<4x?xf32>>]>
// CHECK: #[[TARGET:.+]] = #hal.device.target
// CHECK: util.global private @[[$DEVICE:.+]] = #[[TARGET]]
// CHECK-LABEL: util.func public @tensor_dispatch_with_tied_operands
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
// CHECK: stream.tensor.dispatch on(#hal.device.affinity<@[[$DEVICE]]>)
// CHECK-SAME: tensor<4x?xf32, #[[$ENCODING]]>{%[[ARG2]]}
// CHECK-SAME: tensor<4x?xf32, #[[$ENCODING]]>{%[[ARG2]]}

// -----

#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#device_target_local_1_ = #hal.device.target<"local", {ordinal = 1 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
