From c2cb11a5fcca5147f30bbd5d5b06a1c511f0736c Mon Sep 17 00:00:00 2001
From: Han-Chung Wang
Date: Fri, 14 Feb 2025 15:05:23 -0800
Subject: [PATCH] [Stream] Add support for executable duplication for tied
 operand cases. (#19953)

Previously, the executable duplication logic did not consider results
that have tied operands. When a result is tied to an operand, its
encoding must be skipped, because the result shares a binding with the
tied operand and the dispatch function has no separate argument (i.e.,
no duplicated binding) for the result.

This revision also tightens the definition of the
`stream.tensor.dispatch` op. Previously, a result type was allowed to
differ from its tied operand. Now any tied result must have an encoding
that matches the tied operand, because they share the same binding and
it is hard to track which encoding is used by which binding today.
Adding the check to the verifier lets us catch violations early, when
the op is lowered from flow.

---------

Signed-off-by: hanhanW
---
 .../compiler/Dialect/Stream/IR/StreamOps.cpp  | 35 ++++++++++++
 .../Dialect/Stream/IR/test/tensor_ops.mlir    | 17 +++++-
 .../Stream/Transforms/SpecializeEncodings.cpp | 53 +++++++++++++++++--
 .../Transforms/test/encode_host_tensors.mlir  |  2 +-
 .../Transforms/test/specialize_encodings.mlir | 38 +++++++++++++
 5 files changed, 138 insertions(+), 7 deletions(-)
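Note on the new verifier rule: only *tied* results are constrained; an untied
result keeps its own encoding. A minimal sketch, mirroring the operands of the
updated @tensorDispatch test below (the %ok/%err names are illustrative only):

    // Accepted: the tied result repeats the tied operand's tensor<4x?xf32>.
    %ok = stream.tensor.dispatch @executable::@dispatch[%c1, %c2, %c3](%arg0, %c4) : (tensor<4x?xf32>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<4x?xf32>{%arg2} in %arg0{%arg1}
    // Rejected: tensor<?x4xf32> does not match the tied operand's tensor<4x?xf32>.
    %err = stream.tensor.dispatch @executable::@dispatch[%c1, %c2, %c3](%arg0, %c4) : (tensor<4x?xf32>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<?x4xf32>{%arg2} in %arg0{%arg1}
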
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index 7f50ee2f355d..86c74236430e 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -76,6 +76,37 @@ verifyDispatchWorkload(Operation *op, IREE::Stream::ExecutableExportOp exportOp,
   return success();
 }
 
+// Verifies that the tied operand encodings match the result encodings.
+static LogicalResult verifyTiedOperandEncodings(Operation *op,
+                                                ArrayAttr operandEncodingsAttr,
+                                                ArrayAttr resultEncodingsAttr) {
+  auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
+  if (!tiedOp) {
+    return op->emitOpError()
+           << "the op does not implement IREE::Util::TiedOpInterface";
+  }
+
+  ArrayRef<Attribute> operandEncodings = operandEncodingsAttr.getValue();
+  unsigned tiedOperandBase = tiedOp.getTiedOperandsIndexAndLength().first;
+  for (auto [idx, resultEncoding] :
+       llvm::enumerate(resultEncodingsAttr.getValue())) {
+    auto tiedOperand = tiedOp.getTiedResultOperandIndex(idx);
+    if (!tiedOperand.has_value()) {
+      continue;
+    }
+    auto operandIndex = tiedOperand.value() - tiedOperandBase;
+    if (operandEncodings[operandIndex] != resultEncoding) {
+      return op->emitError()
+             << "the " << operandIndex << "-th operandEncoding ("
+             << operandEncodings[operandIndex]
+             << ") does not match the resultEncoding (" << resultEncoding
+             << ")";
+    }
+  }
+
+  return success();
+}
+
 // Verifies that |dynamicDims| contains the appropriate number of dims for all
 // the dynamic dimensions in |type|.
 static LogicalResult verifyOpDynamicDims(Operation *op, TypeRange types,
@@ -2112,6 +2143,10 @@ LogicalResult TensorDispatchOp::verify() {
                                  op.getResultEncodingDims()))) {
     return failure();
   }
+  if (failed(verifyTiedOperandEncodings(op, op.getOperandEncodings(),
+                                        op.getResultEncodings()))) {
+    return failure();
+  }
   return success();
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/tensor_ops.mlir
index c6d20e13acca..e82b3bf6872f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/tensor_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
+// RUN: iree-opt --split-input-file %s --verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL: @tensorImport
 util.func private @tensorImport(%arg0: !hal.buffer_view, %arg1: index) -> !stream.resource<external> {
@@ -162,7 +162,20 @@ util.func private @tensorDispatch(%arg0: !stream.resource<*>, %arg1: index, %arg2: index) -> !stream.resource<*> {
   %c3 = arith.constant 3 : index
   %c4 = arith.constant 4 : index
   // CHECK: = stream.tensor.dispatch @executable::@dispatch[%c1, %c2, %c3](%arg0, %c4) :
-  // CHECK-SAME: (tensor<4x?xf32>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<?x4xf32>{%arg2} in %arg0{%arg1}
+  // CHECK-SAME: (tensor<4x?xf32>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<4x?xf32>{%arg2} in %arg0{%arg1}
+  %0 = stream.tensor.dispatch @executable::@dispatch[%c1, %c2, %c3](%arg0, %c4) : (tensor<4x?xf32>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<4x?xf32>{%arg2} in %arg0{%arg1}
+  util.return %0 : !stream.resource<*>
+}
+
+// -----
+
+util.func private @tensorDispatchMismatch(%arg0: !stream.resource<*>, %arg1: index, %arg2: index) -> !stream.resource<*> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  // expected-error @+1 {{the 0-th operandEncoding (tensor<4x?xf32>) does not match the resultEncoding (tensor<?x4xf32>)}}
   %0 = stream.tensor.dispatch @executable::@dispatch[%c1, %c2, %c3](%arg0, %c4) : (tensor<4x?xf32>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<?x4xf32>{%arg2} in %arg0{%arg1}
   util.return %0 : !stream.resource<*>
 }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp
index 8eb2bc14e569..82a4e14459ac 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp
@@ -146,6 +146,53 @@ updateBindingEncodings(FunctionOpInterface funcOp,
   return success();
 }
 
+/// Returns the operand encodings and result encodings of |dispatchOp|,
+/// concatenated in |operands| + |results| order, i.e., the operand encodings
+/// followed by the result encodings. If a result is tied to an operand, the
+/// result encoding is skipped, because it shares the same binding with the
+/// tied operand.
+///
+/// Example 1:
+///
+///   %0 = stream.tensor.dispatch ...(%arg0, %c4)
+///     : (tensor<4x?xf32, #encoding> in !resource, index)
+///     -> tensor<4x?xf32, #encoding> in !resource
+///
+/// The above dispatch op does not have tied operands. Thus, it returns
+/// |#resolved_encoding, whatever_without_encoding, #resolved_encoding|
+///
+/// Example 2:
+///
+///   %0 = stream.tensor.dispatch ...(%arg0, %c4) : tensor<4x?xf32, #encoding>
+///     -> tensor<4x?xf32, #encoding> in %arg0
+///
+/// The above dispatch op ties the result to the first operand. Thus, the
+/// result encoding is skipped. It returns
+/// |#resolved_encoding, whatever_without_encoding|
+static SmallVector<Attribute>
+getBindingLayoutAttrs(IREE::Stream::TensorDispatchOp dispatchOp) {
+  SmallVector<int64_t> tiedOperands(dispatchOp.getNumResults(),
+                                    IREE::Util::TiedOpInterface::kUntiedIndex);
+  if (std::optional<ArrayAttr> tiedOperandsAttr =
+          dispatchOp.getTiedOperands()) {
+    tiedOperands =
+        llvm::map_to_vector(tiedOperandsAttr.value(), [](Attribute intAttr) {
+          return llvm::cast<IntegerAttr>(intAttr).getInt();
+        });
+  }
+
+  SmallVector<Attribute> result(dispatchOp.getOperandEncodings().getValue());
+  for (auto [resultEncoding, tiedOperand] : llvm::zip_equal(
+           dispatchOp.getResultEncodings().getValue(), tiedOperands)) {
+    if (tiedOperand != IREE::Util::TiedOpInterface::kUntiedIndex) {
+      continue;
+    }
+    result.push_back(resultEncoding);
+  }
+
+  return result;
+}
+
 /// Duplicates stream.executables based on the operand encodings and result
 /// encodings of stream.tensor.dispatch ops. Some executables can be launched by
 /// different devices. It can produce wrong codegen artifacts when bindings
@@ -183,10 +230,8 @@ duplicateExecutablesPerLayoutVariant(ModuleOp moduleOp, SymbolTable symbolTable,
   llvm::MapVector<IREE::Stream::TensorDispatchOp, SmallVector<Attribute>>
       dispatchOpBindingLayouts;
   for (auto dispatchOp : candidates) {
-    SmallVector<Attribute> bindingLayoutAttrs(
-        dispatchOp.getOperandEncodings().getValue());
-    llvm::append_range(bindingLayoutAttrs,
-                       dispatchOp.getResultEncodings().getValue());
+    SmallVector<Attribute> bindingLayoutAttrs =
+        getBindingLayoutAttrs(dispatchOp);
     dispatchOpBindingLayouts[dispatchOp] = bindingLayoutAttrs;
     dispatchOp.forEachEntryPointAttr([&](SymbolRefAttr entryPoint) {
       auto exportOp = cast<IREE::Stream::ExecutableExportOp>(
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/encode_host_tensors.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/encode_host_tensors.mlir
index 9a97b9ec8323..af60eeebe6e3 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/encode_host_tensors.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/encode_host_tensors.mlir
@@ -282,7 +282,7 @@ util.func public @denseTensorDispatch(
 // CHECK-SAME:     %[[RESOURCE1]][%[[ZERO]] to %[[RESOURCE1_SIZE]] for %[[RESOURCE1_SIZE]]])
 // CHECK-SAME:     (!stream.resource{%[[RESOURCE0_SIZE]]}, !stream.resource{%[[RESOURCE1_SIZE]]}) ->
 // CHECK-SAME:     (!stream.resource{%[[RESOURCE1_SIZE]]}, %[[RESOURCE1]]{%[[RESOURCE1_SIZE]]})
-  %results:2 = stream.tensor.dispatch @ex::@entry(%resource0, %resource1) : (tensor<4x?xf32>{%tensor0_dim} in !stream.resource{%resource0_size}, tensor<?x4xf32>{%tensor1_dim} in !stream.resource{%resource1_size}) -> (tensor<?x4xf32>{%tensor1_dim} in !stream.resource{%resource1_size}, tensor<?x4xf32>{%tensor1_dim} in %resource1{%resource1_size})
+  %results:2 = stream.tensor.dispatch @ex::@entry(%resource0, %resource1) : (tensor<4x?xf32>{%tensor0_dim} in !stream.resource{%resource0_size}, tensor<?x4xf32>{%tensor1_dim} in !stream.resource{%resource1_size}) -> (tensor<4x?xf32>{%tensor0_dim} in !stream.resource{%resource1_size}, tensor<?x4xf32>{%tensor1_dim} in %resource1{%resource1_size})
 // CHECK: util.return %[[RESULTS]]#0, %[[RESULTS]]#1
   util.return %results#0, %results#1 : !stream.resource, !stream.resource
 }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir
index 80a13ad8133b..8bde336344b7 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir
@@ -215,6 +215,44 @@ util.global private @device_a = #device_target_local_0_
 
 // -----
 
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
+#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
+#encoding = #iree_encoding.testing_encoding<>
+
+util.global private @device_a = #device_target_local_0_
+stream.executable private @executable {
+  stream.executable.export public @dispatch
+  builtin.module {
+    func.func @dispatch(%arg0: !stream.binding, %arg1: index) {
+      %c0 = arith.constant 0 : index
+      %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readwrite:tensor<4x?xf32, #encoding>>{%arg1}
+      return
+    }
+  }
+}
+util.func public @tensor_dispatch_with_tied_operands(%arg0: !stream.resource<external>, %arg1: index, %arg2: index) -> !stream.resource<*> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %0 = stream.async.transfer %arg0 : !stream.resource<external>{%arg2} from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%arg2}
+  %1 = stream.tensor.dispatch on(#hal.device.affinity<@device_a>) @executable::@dispatch[%c1, %c2, %c3](%0, %c4) : (tensor<4x?xf32, #encoding>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<4x?xf32, #encoding>{%arg2} in %0{%arg1}
+  util.return %1 : !stream.resource<*>
+}
+// CHECK-DAG:   #[[$ENCODING:.+]] = #iree_encoding.testing_encoding<[#iree_encoding.specialized_encoding<123, tensor<4x?xf32>>]>
+// CHECK:       #[[TARGET:.+]] = #hal.device.target
+// CHECK:       util.global private @[[$DEVICE:.+]] = #[[TARGET]]
+// CHECK-LABEL: util.func public @tensor_dispatch_with_tied_operands
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK:         stream.tensor.dispatch on(#hal.device.affinity<@[[$DEVICE]]>)
+// CHECK-SAME:      tensor<4x?xf32, #[[$ENCODING]]>{%[[ARG2]]}
+// CHECK-SAME:      tensor<4x?xf32, #[[$ENCODING]]>{%[[ARG2]]}
+
+// -----
+
 #executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
 #device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
 #device_target_local_1_ = #hal.device.target<"local", {ordinal = 1 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device