Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][bufferization] Implement BufferDeallocationopInterface for scf.forall.in_parallel #66351

Merged
merged 1 commit into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- BufferDeallocationOpInterfaceImpl.h ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
#define MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H

namespace mlir {

class DialectRegistry;

namespace scf {
void registerBufferDeallocationOpInterfaceExternalModels(
DialectRegistry &registry);
} // namespace scf
} // namespace mlir

#endif // MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
3 changes: 3 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
Expand Down Expand Up @@ -149,6 +151,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerValueBoundsOpInterfaceExternalModels(registry);
memref::registerMemorySlotExternalModels(registry);
scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerValueBoundsOpInterfaceExternalModels(registry);
shape::registerBufferizableOpInterfaceExternalModels(registry);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;
using namespace mlir::bufferization;

namespace {
/// The `scf.forall.in_parallel` terminator is special in a few ways:
/// * It does not implement the BranchOpInterface or
/// RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
/// which is not supported by BufferDeallocation.
/// * It has a graph-like region which only allows one specific tensor op
/// * After bufferization the nested region is always empty
/// For these reasons we provide custom deallocation logic via this external
/// model.
///
/// Example:
/// ```mlir
/// scf.forall (%arg1) in (%arg0) {
/// %alloc = memref.alloc() : memref<2xf32>
/// ...
/// <implicit in_parallel terminator here>
/// }
/// ```
/// gets transformed to
/// ```mlir
/// scf.forall (%arg1) in (%arg0) {
/// %alloc = memref.alloc() : memref<2xf32>
/// ...
/// bufferization.dealloc (%alloc : memref<2xf32>) if (%true)
/// <implicit in_parallel terminator here>
/// }
/// ```
struct InParallelOpInterface
: public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
scf::InParallelOp> {
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
const DeallocationOptions &options) const {
auto inParallelOp = cast<scf::InParallelOp>(op);
OpBuilder builder(op);
if (!inParallelOp.getBody()->empty())
return op->emitError("only supported when nested region is empty");

// Collect the values to deallocate and retain and use them to create the
// dealloc operation.
Block *block = op->getBlock();
SmallVector<Value> memrefs, conditions, toRetain;
if (failed(state.getMemrefsAndConditionsToDeallocate(
builder, op->getLoc(), block, memrefs, conditions)))
return failure();

state.getMemrefsToRetain(block, /*toBlock=*/nullptr, {}, toRetain);
if (memrefs.empty() && toRetain.empty())
return op;

auto deallocOp = builder.create<bufferization::DeallocOp>(
op->getLoc(), memrefs, conditions, toRetain);

// We want to replace the current ownership of the retained values with the
// result values of the dealloc operation as they are always unique.
state.resetOwnerships(deallocOp.getRetained(), block);
for (auto [retained, ownership] :
llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
state.updateOwnership(retained, ownership, block);

return op;
}
};

} // namespace

void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
});
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRSCFTransforms
BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ForToWhile.cpp
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Dialect/SCF/buffer-deallocation.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
// RUN: -buffer-deallocation-simplification -split-input-file %s | FileCheck %s

func.func @parallel_insert_slice(%arg0: index) {
%c0 = arith.constant 0 : index
%alloc = memref.alloc() : memref<2xf32>
scf.forall (%arg1) in (%arg0) {
%alloc0 = memref.alloc() : memref<2xf32>
%0 = memref.load %alloc[%c0] : memref<2xf32>
linalg.fill ins(%0 : f32) outs(%alloc0 : memref<2xf32>)
}
return
}

// CHECK-LABEL: func @parallel_insert_slice
// CHECK-SAME: (%arg0: index)
// CHECK: [[ALLOC0:%.+]] = memref.alloc(
// CHECK: scf.forall
// CHECK: [[ALLOC1:%.+]] = memref.alloc(
// CHECK: bufferization.dealloc ([[ALLOC1]] : memref<2xf32>) if (%true
// CHECK-NOT: retain
// CHECK: }
// CHECK: bufferization.dealloc ([[ALLOC0]] : memref<2xf32>) if (%true
// CHECK-NOT: retain