[WIP] Changing stream conversion to use a value/op affinity analysis.
benvanik committed Jun 28, 2024
1 parent 2fc41db commit adf845a
Showing 96 changed files with 5,358 additions and 1,586 deletions.
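The core of the FuncConversion.cpp diff below is that per-argument and per-result iree.abi.affinity attributes on the Torch-level function signature are now stashed during conversion and propagated onto the generated hal.tensor.import, hal.tensor.alias, and hal.tensor.export ops, as well as onto the async/sync wrapper functions. A minimal sketch of an annotated input function that this conversion now honors, adapted from the new test case at the end of this diff (the @dev_a device symbol and module name are placeholders):

builtin.module @affinity_example {
  // Hypothetical example: an argument or result may carry an explicit placement,
  // which the conversion forwards to the HAL import/export ops it creates for that value.
  func.func @main(%arg0: !torch.vtensor<[4,5],si32> {iree.abi.affinity = #hal.device.promise<@dev_a>})
      -> (!torch.vtensor<[4,5],si32> {iree.abi.affinity = #hal.device.promise<@dev_a>}) {
    return %arg0 : !torch.vtensor<[4,5],si32>
  }
}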
89 changes: 63 additions & 26 deletions compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
@@ -124,9 +124,17 @@ enum class TypeDisposition {
FENCE,
};

struct BarrierResult {
BlockArgument storage;
Type torchType;
int returnIndex = -1;
};

struct ConvertedAsyncFunctionInfo {
IREE::Util::FuncOp funcOp;
SmallVector<IREE::Util::ReturnOp> returnOps;
SmallVector<DictionaryAttr> torchArgAttrs;
SmallVector<DictionaryAttr> torchResultAttrs;
SmallVector<Type> torchInputTypes;
SmallVector<Type> torchResultTypes;
SmallVector<TypeDisposition> inputDispositions;
@@ -136,18 +144,33 @@ struct ConvertedAsyncFunctionInfo {
// Values that must be captured in the coarse barrier.
SmallVector<Value> barrierInputs;
// Meta data per barrier input: storage, torchType, returnIndex (or -1)
SmallVector<std::tuple<Value, Type, int>> barrierResultMeta;
SmallVector<BarrierResult> barrierResultMeta;

LogicalResult postProcess();
LogicalResult convertImmutableTensorArg(BlockArgument argValue,
Type torchType, OpBuilder &builder);
LogicalResult convertMutableTensorArg(BlockArgument argValue, Type torchType,
OpBuilder &builder);

void addBarrierInput(Value inputTensor, Value storage, Type torchType,
void addBarrierInput(Value inputTensor, BlockArgument storage, Type torchType,
int returnIndex) {
barrierInputs.push_back(inputTensor);
barrierResultMeta.emplace_back(storage, torchType, returnIndex);
barrierResultMeta.emplace_back(BarrierResult{
storage,
torchType,
returnIndex,
});
}

Attribute getTorchArgAttr(BlockArgument argValue, StringRef attrName) {
return torchArgAttrs.empty()
? Attribute{}
: torchArgAttrs[argValue.getArgNumber()].get(attrName);
}
Attribute getTorchResultAttr(int returnIndex, StringRef attrName) {
return torchResultAttrs.empty()
? Attribute{}
: torchResultAttrs[returnIndex].get(attrName);
}
};

@@ -232,7 +255,8 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() {
}
if (needsBarrier) {
Value source = convertToBuiltinTensor(postambleBuilder, returnValue);
addBarrierInput(source, /*storage=*/Value{}, torchType, returnIndex);
addBarrierInput(source, /*storage=*/BlockArgument{}, torchType,
returnIndex);
}
break;
}
@@ -276,44 +300,44 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() {
SmallVector<Value> aliasedResults;
for (auto [barrierInput, meta] :
llvm::zip_equal(barrierInputs, barrierResultMeta)) {
Value exportStorage;
Type torchType;
int returnIndex;
std::tie(exportStorage, torchType, returnIndex) = meta;
if (exportStorage) {
if (meta.storage) {
// Use the wait fence indicating when the storage is available for
// mutation. We need to ensure that no writes are made to the storage
// until it indicates it's safe to do so.
auto waitSignalFences = getEnclosingWaitSignalFences(exportStorage);
auto storageAffinityAttr =
getTorchArgAttr(meta.storage, "iree.abi.affinity");
auto waitSignalFences = getEnclosingWaitSignalFences(meta.storage);
assert(waitSignalFences && "async function missing fences");
Value waitFence = waitSignalFences->first;
auto barrierInputDims = IREE::Util::buildDynamicDimsForValue(
barrierInput.getLoc(), barrierInput, postambleBuilder);
aliasedResults.push_back(
postambleBuilder.create<IREE::HAL::TensorAliasOp>(
barrierInput.getLoc(), barrierInput.getType(), barrierInput,
barrierInputDims, exportStorage, waitFence,
/*affinity=*/nullptr));
barrierInputDims, meta.storage, waitFence,
storageAffinityAttr));
} else {
aliasedResults.push_back(barrierInput);
}
}
auto barrierOp = postambleBuilder.create<IREE::HAL::TensorBarrierOp>(
funcOp.getLoc(), aliasedResults, coarseSignalFence,
/*affinity=*/nullptr);
funcOp.getLoc(), aliasedResults, coarseSignalFence);
for (auto [barrierResult, meta] :
llvm::zip_equal(barrierOp.getResults(), barrierResultMeta)) {
Value exportStorage;
Type torchType;
int returnIndex;
std::tie(exportStorage, torchType, returnIndex) = meta;
Attribute exportAffinityAttr;
if (meta.storage) {
exportAffinityAttr = getTorchArgAttr(meta.storage, "iree.abi.affinity");
} else if (meta.returnIndex >= 0) {
exportAffinityAttr =
getTorchResultAttr(meta.returnIndex, "iree.abi.affinity");
}
Value exportedValue = postambleBuilder.create<IREE::HAL::TensorExportOp>(
funcOp.getLoc(),
postambleBuilder.getType<IREE::HAL::BufferViewType>(), barrierResult,
TypeAttr::get(barrierResult.getType()), /*name=*/nullptr,
/*affinity=*/nullptr);
if (returnIndex >= 0) {
newReturnOperands[returnIndex] = exportedValue;
exportAffinityAttr);
if (meta.returnIndex >= 0) {
newReturnOperands[meta.returnIndex] = exportedValue;
}
}
}
@@ -377,14 +401,16 @@ LogicalResult ConvertedAsyncFunctionInfo::convertImmutableTensorArg(
<< torchType;
}

// Propagate explicit affinities to the read.
auto affinityAttr = getTorchArgAttr(argValue, "iree.abi.affinity");

auto waitSignalFences = getEnclosingWaitSignalFences(argValue);
assert(waitSignalFences && "async function missing fences");
Value waitFence = waitSignalFences->first;
Value importedTensor = builder.create<IREE::HAL::TensorImportOp>(
loc, builtinTensorType, argValue, TypeAttr::get(builtinTensorType),
waitFence,
/*name=*/nullptr,
/*affinity=*/nullptr);
/*name=*/nullptr, affinityAttr);
if (builtinTensorType != torchType) {
importedTensor = builder.create<TorchConversion::FromBuiltinTensorOp>(
loc, torchType, importedTensor);
@@ -408,6 +434,9 @@ LogicalResult ConvertedAsyncFunctionInfo::convertMutableTensorArg(
.toBuiltinTensor();
}

// Propagate explicit affinities to the read and write.
auto affinityAttr = getTorchArgAttr(argValue, "iree.abi.affinity");

// There are only a small set of possible users of a mutable tensor.
// Handle them by operation here.
SmallVector<Operation *> users(argValue.getUsers());
@@ -419,8 +448,7 @@ LogicalResult ConvertedAsyncFunctionInfo::convertMutableTensorArg(
loc, builtinTensorType, argValue,
/*target_encoding=*/TypeAttr::get(builtinTensorType),
/*wait_fence*/ fences->first,
/*name=*/nullptr,
/*affinity=*/nullptr);
/*name=*/nullptr, affinityAttr);
rewriter.replaceOpWithNewOp<TorchConversion::FromBuiltinTensorOp>(
userOp, copyToVtOp.getResult().getType(), imported);
} else if (auto overwriteOp =
@@ -444,7 +472,6 @@ void retainFunctionAttributes(Operation *srcOp, IREE::Util::FuncOp destOp) {
// Allowlist of function attributes to retain when importing funcs.
constexpr const char *kRetainedAttributes[] = {
"iree.reflection",
"stream.affinity",
};
auto retainedAttributes = ArrayRef<const char *>(
kRetainedAttributes,
@@ -476,6 +503,9 @@ void createCoarseFencesSyncWrapper(StringRef syncFunctionName,
syncFuncOp.setSymVisibilityAttr(asyncFuncOp.getSymVisibilityAttr());
retainFunctionAttributes(asyncFuncOp, syncFuncOp);
syncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr());
if (auto affinityAttr = asyncFuncOp->getAttr("iree.abi.affinity")) {
syncFuncOp->setAttr("iree.abi.affinity", affinityAttr);
}
Block *entryBlock = syncFuncOp.addEntryBlock();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(entryBlock);
@@ -584,6 +614,10 @@ struct FuncConversionPass : public FuncConversionBase<FuncConversionPass> {
asyncFunctionName.append("$async");
}

// Stash arg/result attrs so they can be referenced during conversion.
torchFunc.getAllArgAttrs(convertedFuncInfo.torchArgAttrs);
torchFunc.getAllResultAttrs(convertedFuncInfo.torchResultAttrs);

// Convert function signature.
Type fenceType = rewriter.getType<IREE::HAL::FenceType>();
FunctionType torchFuncType = torchFunc.getFunctionType();
@@ -644,6 +678,9 @@ struct FuncConversionPass : public FuncConversionBase<FuncConversionPass> {
asyncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr());
asyncFuncOp->setAttr("iree.abi.model",
rewriter.getStringAttr("coarse-fences"));
if (auto affinityAttr = torchFunc->getAttr("iree.abi.affinity")) {
asyncFuncOp->setAttr("iree.abi.affinity", affinityAttr);
}
rewriter.inlineRegionBefore(
torchFunc.getBody(), asyncFuncOp.getFunctionBody(), asyncFuncOp.end());

@@ -110,6 +110,37 @@ func.func @main(%arg0: !torch.tensor<[5,4],f32>) -> (!torch.vtensor<[5,4],f32>)
}
}

// -----
// Tests the immutable + mutable argument case with explicit affinities.
// CHECK-LABEL: @mutable_input_overwrite_no_return_affinities
// CHECK: util.func public @main$async(
// CHECK-SAME: %arg0: !hal.buffer_view, %arg1: !hal.buffer_view,
// CHECK-SAME: %arg2: !hal.fence, %arg3: !hal.fence) -> !hal.buffer_view
// CHECK-DAG: %[[WAIT_ARG0:.+]] = hal.tensor.import on(#hal.device.promise<@dev_a>) wait(%arg2) => %arg0
// CHECK-DAG: %[[TORCH_ARG0:.+]] = torch_c.from_builtin_tensor %[[WAIT_ARG0]]
// CHECK-DAG: %[[WAIT_ARG1:.+]] = hal.tensor.import on(#hal.device.promise<@dev_b>) wait(%arg2) => %arg1
// CHECK-DAG: %[[TORCH_ARG1:.+]] = torch_c.from_builtin_tensor %[[WAIT_ARG1]]
// CHECK-DAG: %[[TORCH_RESULT0:.+]] = torch.operator "other_calc"(%[[TORCH_ARG0]])
// CHECK-DAG: %[[TORCH_RESULT1:.+]] = torch.operator "mutate_inplace"(%[[TORCH_ARG1]])
// CHECK-DAG: %[[TENSOR_ARG0:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT0]]
// CHECK-DAG: %[[TENSOR_ARG1:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT1]]
// CHECK: %[[EXPORT_ALIAS1:.+]] = hal.tensor.alias on(#hal.device.promise<@dev_b>) wait(%arg2) => %[[TENSOR_ARG1]] : tensor<5x4xf32> to %arg1 : !hal.buffer_view
// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%[[EXPORT_ALIAS1]], %[[TENSOR_ARG0]] : tensor<5x4xf32>, tensor<4x5xi32>) => %arg3 : !hal.fence
// CHECK-DAG: %[[EXPORT_RESULT0:.+]] = hal.tensor.export on(#hal.device.promise<@dev_b>) %[[BARRIER_RESULTS]]#0
// CHECK-DAG: %[[EXPORT_RESULT1:.+]] = hal.tensor.export on(#hal.device.promise<@dev_a>) %[[BARRIER_RESULTS]]#1
// CHECK: util.return %[[EXPORT_RESULT1]]
builtin.module @mutable_input_overwrite_no_return_affinities {
func.func @main(%arg0: !torch.vtensor<[4,5],si32> {iree.abi.affinity = #hal.device.promise<@dev_a>},
%arg1: !torch.tensor<[5,4],f32> {iree.abi.affinity = #hal.device.promise<@dev_b>})
-> (!torch.vtensor<[4,5],si32> {iree.abi.affinity = #hal.device.promise<@dev_a>}) {
%0 = torch.copy.to_vtensor %arg1 : !torch.vtensor<[5,4],f32>
%1 = torch.operator "mutate_inplace"(%0) : (!torch.vtensor<[5,4],f32>) -> !torch.vtensor<[5,4],f32>
%2 = torch.operator "other_calc"(%arg0) : (!torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32>
torch.overwrite.tensor.contents %1 overwrites %arg1 : !torch.vtensor<[5,4],f32>, !torch.tensor<[5,4],f32>
return %2 : !torch.vtensor<[4,5],si32>
}
}

// -----
// CHECK-LABEL: @retained_attribute_reflection
// CHECK: util.func public @main$async(