[mlir][bufferization] MaterializeInDestinationOp: Support memref destinations #68074

Merged

@matthias-springer matthias-springer commented Oct 3, 2023

Extend bufferization.materialize_in_destination to support memref destinations. This op can now be used to indicate that a tensor computation should materialize in a given buffer (that may have been allocated by another component/runtime). The op still participates in "empty tensor elimination".

Example:

func.func @test(%out: memref<10xf32>) {
  %t = tensor.empty() : tensor<10xf32>
  %c = linalg.generic ... outs(%t: tensor<10xf32>) -> tensor<10xf32>
  bufferization.materialize_in_destination %c in restrict writable %out : (tensor<10xf32>, memref<10xf32>) -> ()
  return
}

After "empty tensor elimination", the above IR can bufferize without an allocation:

func.func @test(%out: memref<10xf32>) {
  linalg.generic ... outs(%out: memref<10xf32>)
  return
}

This change also clarifies the meaning of the restrict unit attribute on bufferization.to_tensor ops.
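
The structural rules that the new verifier enforces (see `MaterializeInDestinationOp::verify` in the patch) can be modelled in a few lines of standalone C++. This is only a sketch for reasoning about which combinations are accepted: `DestKind`, `OpModel` and `verifyModel` are hypothetical names and not part of MLIR, and the shape and element-type checks done by the `AllShapesMatch` / `AllElementTypesMatch` traits are not modelled.

```cpp
#include <string>

// Hypothetical stand-ins for the MLIR types involved; not part of MLIR.
enum class DestKind { Tensor, MemRef, Other };

struct OpModel {
  DestKind dest;          // kind of the `dest` operand
  unsigned numResults;    // number of op results
  bool resultMatchesDest; // only consulted for a tensor `dest` with one result
  bool restrictAttr;      // `restrict` unit attribute present
  bool writableAttr;      // `writable` unit attribute present
};

// Returns the diagnostic text used by the patch's verifier, or an empty
// string if the modelled op is accepted. Checks run in the patch's order.
std::string verifyModel(const OpModel &op) {
  if (op.dest == DestKind::Other)
    return "'dest' must be a tensor or a memref";
  const bool isMemRef = op.dest == DestKind::MemRef;
  if (!isMemRef) {
    if (op.numResults != 1)
      return "tensor 'dest' implies exactly one tensor result";
    if (!op.resultMatchesDest)
      return "result and 'dest' types must match";
  }
  if (isMemRef && op.numResults != 0)
    return "memref 'dest' implies zero results";
  if (op.restrictAttr != isMemRef)
    return "'restrict' must be specified if and only if the destination is "
           "of memref type";
  if (op.writableAttr != isMemRef)
    return "'writable' must be specified if and only if the destination is "
           "of memref type";
  return "";
}
```

In this model, a memref destination is accepted only with both attributes and no result, while a tensor destination is accepted only with neither attribute and exactly one result of the destination's type.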

llvmbot commented Oct 3, 2023

@llvm/pr-subscribers-mlir-bufferization
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Changes

Extend bufferization.materialize_in_destination to support memref destinations. This op can now be used to indicate that a tensor computation should materialize in a given buffer (that may have been allocated by another component/runtime). The op still participates in "empty tensor elimination".

Example:

func.func @test(%out: memref<10xf32>) {
  %t = tensor.empty() : tensor<10xf32>
  %c = linalg.generic ... outs(%t: tensor<10xf32>) -> tensor<10xf32>
  bufferization.materialize_in_destination %c in %out : (tensor<10xf32>, memref<10xf32>) -> ()

After "empty tensor elimination", the above IR can bufferize without an allocation. The "linalg.generic" is computed directly on %out.


Patch is 20.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68074.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+54-18)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+81-11)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp (+1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Padding.cpp (+4-2)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir (+2-2)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir (+14-1)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+14-1)
  • (modified) mlir/test/Dialect/Bufferization/invalid.mlir (+44-3)
  • (modified) mlir/test/Dialect/Bufferization/ops.mlir (+5-3)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 9761ab12134ad28..68d64e685eeabcb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -216,33 +216,56 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
 
 def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
-        [BufferizableOpInterface, SameOperandsAndResultType,
-         DestinationStyleOpInterface,
+        [AllShapesMatch<["source", "dest"]>,
+         AllElementTypesMatch<["source", "dest"]>,
+         BufferizableOpInterface, DestinationStyleOpInterface,
          DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
          DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
             ["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
-             "buildSubsetExtraction", "isEquivalentSubset"]>]> {
+             "buildSubsetExtraction", "isEquivalentSubset"]>,
+         DeclareOpInterfaceMethods<MemoryEffectsOpInterface, ["getEffects"]>]> {
   let summary = "copy a tensor";
 
   let description = [{
     This op indicates that the data of the `source` tensor should materialize
-    in the future buffer of the `dest` tensors. Both tensors must have the same
-    shape and element type at runtime.
+    in `dest`, which can be a tensor or a memref. In case of a tensor, `source`
+    should materialize in the future buffer of `dest` and the updated
+    destination tensor is returned. In case of a memref, `source` should
+    materialize in `dest`, which is already a buffer. The op has no results in
+    that case.
+
+    `source`, `dest` and `result` (if present) must have the same shape and
+    element type. If the op has a result, the types of `result` and `dest` must
+    match exactly (e.g., including any tensor encodings).
 
     By default, this op bufferizes to a memcpy from the future buffer of the
-    `source` tensor to the future buffer of the `dest` tensor. However,
-    transformations such as "empty tensor elimination" may rewrite IR such that
-    a computation is performed directly in the future buffer of the `dest`
-    tensor and no memcpy is needed.
-
-    Note: "tensor.insert_slice" could be used for the same purpose, but since
-    tensor dialect ops only indicate *what* should be computed but not *where*,
-    it could fold away, causing the computation to materialize in a different
-    buffer.
+    `source` tensor to the future buffer of the `dest` tensor or to the `dest`
+    buffer. However, transformations such as "empty tensor elimination" may
+    rewrite IR such that a computation is performed directly in `dest` and no
+    memcpy is needed.
+
+    If `dest` is a buffer, the "restrict" and "writable" attributes must be
+    specified. These attributes have the same meaning as the respective
+    attributes of "bufferization.to_tensor". "writable" indicates that the
+    `dest` buffer is considered writable. It does not make sense to materialize
+    a computation in a read-only buffer, so "writable" is required. "restrict"
+    indicates that `dest` does not alias with any memref passed to a "to_tensor"
+    op. Such aliasing is not supported by One-Shot Bufferize.
+
+    Note: "restrict" and "writable" could be removed from this op because they
+    must always be set for memref destinations. This op has these attributes to
+    make clear the requirements on the `dest` operand in the op assembly format.
+    Moreover, these requirements may be relaxed at some point in the future.
+
+    Note: If `dest` is a tensor, "tensor.insert_slice" could be used for the
+    same purpose, but since tensor dialect ops only indicate *what* should be
+    computed but not *where*, it could fold away, causing the computation to
+    materialize in a different buffer.
   }];
 
-  let arguments = (ins AnyTensor:$source, AnyTensor:$dest);
-  let results = (outs AnyTensor:$result);
+  let arguments = (ins AnyTensor:$source, AnyShaped:$dest,
+                       UnitAttr:$restrict, UnitAttr:$writable);
+  let results = (outs Optional<AnyTensor>:$result);
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
@@ -264,10 +287,23 @@ def Bufferization_MaterializeInDestinationOp
       return ::llvm::cast<RankedTensorType>(getResult().getType());
     }
 
-    MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+    MutableOperandRange getDpsInitsMutable();
+
+    bool isWritable(Value value, const AnalysisState &state);
   }];
 
-  let assemblyFormat = "$source `in` $dest attr-dict `:` type($source)";
+  let builders = [
+    // Builder that materializes a source tensor in a tensor destination.
+    // Asserts that `dest` has tensor type. Infers the result type of this op
+    // from the destination tensor.
+    OpBuilder<(ins "Value":$source, "Value":$dest)>
+  ];
+
+  let assemblyFormat = [{
+    $source `in` (`restrict` $restrict^)? (`writable` $writable^)? $dest
+        attr-dict `:` functional-type(operands, results)
+  }];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 7c6c1be351cced1..5b88b0201f05d40 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -542,13 +542,15 @@ bool MaterializeInDestinationOp::bufferizesToMemoryRead(
 
 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
     OpOperand &opOperand, const AnalysisState &state) {
-  return &opOperand == &getDestMutable()[0];
+  return isa<TensorType>(getDest().getType()) &&
+         &opOperand == &getDestMutable()[0];
 }
 
 AliasingValueList
 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                               const AnalysisState &state) {
-  if (&opOperand == &getDestMutable()[0])
+  if (isa<TensorType>(getDest().getType()) &&
+      &opOperand == &getDestMutable()[0])
     return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
   return {};
 }
@@ -556,11 +558,20 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
 LogicalResult
 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
-  FailureOr<Value> buffer = getBuffer(rewriter, getDest(), options);
-  if (failed(buffer))
-    return failure();
-  rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), *buffer);
-  replaceOpWithBufferizedValues(rewriter, getOperation(), *buffer);
+  bool tensorDest = isa<TensorType>(getDest().getType());
+  Value buffer;
+  if (tensorDest) {
+    FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
+    if (failed(maybeBuffer))
+      return failure();
+    buffer = *maybeBuffer;
+  } else {
+    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
+    buffer = getDest();
+  }
+  rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), buffer);
+  replaceOpWithBufferizedValues(rewriter, getOperation(),
+                                tensorDest ? ValueRange(buffer) : ValueRange());
   return success();
 }
 
@@ -573,15 +584,29 @@ bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
 
 LogicalResult MaterializeInDestinationOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
-  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
+  if (getOperation()->getNumResults() == 1) {
+    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
+    reifiedReturnShapes.resize(1,
+                               SmallVector<OpFoldResult>(getType().getRank()));
+    reifiedReturnShapes[0] =
+        tensor::getMixedSizes(builder, getLoc(), getDest());
+  }
   return success();
 }
 
 Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                         Location loc) {
-  // The subset is the entire destination tensor.
-  return getDest();
+  if (isa<TensorType>(getDest().getType())) {
+    // The subset is the entire destination tensor.
+    return getDest();
+  }
+
+  // Build a bufferization.to_tensor op.
+  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
+  assert(getRestrict() &&
+         "expected that ops with memref dest have 'restrict'");
+  return builder.create<ToTensorOp>(loc, getDest(), getRestrict(),
+                                    getWritable());
 }
 
 bool MaterializeInDestinationOp::isEquivalentSubset(
@@ -598,6 +623,51 @@ OpOperand &MaterializeInDestinationOp::getSourceOperand() {
   return getOperation()->getOpOperand(0) /*source*/;
 }
 
+LogicalResult MaterializeInDestinationOp::verify() {
+  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
+    return emitOpError("'dest' must be a tensor or a memref");
+  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
+    if (getOperation()->getNumResults() != 1)
+      return emitOpError("tensor 'dest' implies exactly one tensor result");
+    if (destType != getResult().getType())
+      return emitOpError("result and 'dest' types must match");
+  }
+  if (isa<BaseMemRefType>(getDest().getType()) &&
+      getOperation()->getNumResults() != 0)
+    return emitOpError("memref 'dest' implies zero results");
+  if (getRestrict() != isa<BaseMemRefType>(getDest().getType()))
+    return emitOpError("'restrict' must be specified if and only if the "
+                       "destination is of memref type");
+  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
+    return emitOpError("'writable' must be specified if and only if the "
+                       "destination is of memref type");
+  return success();
+}
+
+void MaterializeInDestinationOp::build(OpBuilder &builder,
+                                       OperationState &state, Value source,
+                                       Value dest) {
+  assert(isa<TensorType>(dest.getType()) && "expected tensor type");
+  build(builder, state, /*result=*/dest.getType(), source, dest);
+}
+
+bool MaterializeInDestinationOp::isWritable(Value value,
+                                            const AnalysisState &state) {
+  return isa<TensorType>(getDest().getType()) ? true : getWritable();
+}
+
+MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
+  return getDestMutable();
+}
+
+void MaterializeInDestinationOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (isa<BaseMemRefType>(getDest().getType()))
+    effects.emplace_back(MemoryEffects::Write::get(), getDest(),
+                         SideEffects::DefaultResource::get());
+}
+
 //===----------------------------------------------------------------------===//
 // ToTensorOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 77ad13dacaa9838..e37c20dc68c88a6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -149,6 +149,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
           op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
       if (!replacement)
         continue;
+
       if (replacement.getType() != v.getType()) {
         rewriter.setInsertionPointAfterValue(replacement);
         replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index a74a3c2c500406f..e6d80a39650ccf0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -248,8 +248,10 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                LinalgPaddingOptions::CopyBackOp::
                    BufferizationMaterializeInDestination) {
       replacements.push_back(
-          rewriter.create<bufferization::MaterializeInDestinationOp>(
-              loc, std::get<0>(it), std::get<1>(it).get()));
+          rewriter
+              .create<bufferization::MaterializeInDestinationOp>(
+                  loc, std::get<0>(it), std::get<1>(it).get())
+              ->getResult(0));
     } else {
       llvm_unreachable("unsupported copy back op");
     }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index a2fbb06d179ebda..c3e44c426797f39 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -172,7 +172,7 @@ func.func @materialize_in_destination_aliasing(%t: tensor<?xf32>, %p1: index, %p
   %dest = tensor.extract_slice %t[%p2][5][1] : tensor<?xf32> to tensor<5xf32>
   // CHECK: bufferization.materialize_in_destination
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
-  %r = bufferization.materialize_in_destination %src in %dest : tensor<5xf32>
+  %r = bufferization.materialize_in_destination %src in %dest : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %r : tensor<5xf32>
 }
 
@@ -183,6 +183,6 @@ func.func @materialize_in_destination(%t: tensor<?xf32>, %sz: index) -> tensor<?
   %buffer = tensor.empty(%sz) : tensor<?xf32>
   // CHECK: bufferization.materialize_in_destination
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
-  %r = bufferization.materialize_in_destination %buffer in %buffer : tensor<?xf32>
+  %r = bufferization.materialize_in_destination %buffer in %buffer : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   return %r : tensor<?xf32>
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 41e43047657daff..f3c063826e31a90 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -301,12 +301,25 @@ func.func @regression_multiple_insertion_points(%t1: tensor<?x?xf32>) -> tensor<
 func.func @materialize_in_destination(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
   %0 = tensor.empty() : tensor<5xf32>
   %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
-  %1 = bufferization.materialize_in_destination %filled in %t : tensor<5xf32>
+  %1 = bufferization.materialize_in_destination %filled in %t : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %1 : tensor<5xf32>
 }
 
 // -----
 
+// CHECK-LABEL: func @materialize_in_destination_buffer(
+//  CHECK-SAME:     %[[m:.*]]: memref<5xf32>,
+//  CHECK-NEXT:   linalg.fill {{.*}} outs(%[[m]]
+//  CHECK-NEXT:   return
+func.func @materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32) {
+  %0 = tensor.empty() : tensor<5xf32>
+  %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+  bufferization.materialize_in_destination %filled in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @linalg_copy(
 //  CHECK-SAME:     %[[m:.*]]: memref<5xf32, strided<[?], offset: ?>>,
 //       CHECK:   linalg.fill {{.*}} outs(%[[m]]
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index 3f468750cc28405..c8a9b0b1fefc940 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -218,6 +218,19 @@ func.func @tensor_copy(%arg0: tensor<5xf32>) -> tensor<5xf32> {
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
   // CHECK: return %[[r]]
   %dest = bufferization.alloc_tensor() : tensor<5xf32>
-  %0 = bufferization.materialize_in_destination %arg0 in %dest : tensor<5xf32>
+  %0 = bufferization.materialize_in_destination %arg0 in %dest
+      : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %0 : tensor<5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @materialize_in_destination_buffer(
+//  CHECK-SAME:     %[[t:.*]]: tensor<5xf32>, %[[m:.*]]: memref<5xf32>)
+//       CHECK:   %[[b:.*]] = bufferization.to_memref %[[t]] : memref<5xf32, strided<[?], offset: ?>>
+//       CHECK:   memref.copy %[[b]], %[[m]]
+func.func @materialize_in_destination_buffer(%t: tensor<5xf32>, %m: memref<5xf32>) {
+  bufferization.materialize_in_destination %t in restrict writable %m
+      : (tensor<5xf32>, memref<5xf32>) -> ()
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 3dfd1eb17e8d64f..5020ab9cb7368b1 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -66,10 +66,51 @@ func.func @invalid_writable_on_op() {
 
 // -----
 
-// expected-note @below{{prior use here}}
 func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
-  // expected-error @below{{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<5xf32>'}}
-  bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
+  // expected-error @below{{failed to verify that all of {source, dest} have same shape}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_restrict_missing(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
+  // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, memref<?xf32>) -> ()
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
+  // expected-error @below{{memref 'dest' implies zero results}}
+  bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, memref<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{tensor 'dest' implies exactly one tensor result}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> ()
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
+  bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{'writable' must be specified if and only if the destination is of memref type}}
+  bufferization.materialize_in_destination %arg0 in writable %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result_shape(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{result and 'dest' types must match}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<6xf32>)
 }
 
 // -----
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index dc53e535bfe0d57..d4bda0632189d41 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -59,10 +59,12 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
 }
...
[truncated]

@matthias-springer matthias-springer force-pushed the materialize_in_dest_buffer branch from ea217da to 4b7d289 Compare October 4, 2023 06:51

github-actions bot commented Oct 4, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@matthias-springer matthias-springer force-pushed the materialize_in_dest_buffer branch 3 times, most recently from 7dbb2b3 to 696d5c0 Compare October 5, 2023 13:02
@matthias-springer matthias-springer force-pushed the materialize_in_dest_buffer branch from 696d5c0 to 037663f Compare October 6, 2023 09:48
@matthias-springer matthias-springer merged commit 0fcaca2 into llvm:main Oct 6, 2023
@@ -172,7 +172,7 @@ func.func @materialize_in_destination_aliasing(%t: tensor<?xf32>, %p1: index, %p
   %dest = tensor.extract_slice %t[%p2][5][1] : tensor<?xf32> to tensor<5xf32>
   // CHECK: bufferization.materialize_in_destination
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
-  %r = bufferization.materialize_in_destination %src in %dest : tensor<5xf32>
+  %r = bufferization.materialize_in_destination %src in %dest : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
Contributor

Not sure this actually reads better.
Using only the type of the tensor or the memref is fully unambiguous, no?

Member Author

This is because the result is optional. In case of a memref, there is no result (like linalg.generic). TableGen helpers such as AllTypesMatch etc. do not work with that. I could write the printer/parser in C++ though, then we could stay with the original syntax.
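
To make the two resulting spellings concrete, here is a small standalone sketch that assembles the textual form of the op for both destination kinds, following the `assemblyFormat` in this patch. `formatMaterializeInDestination` is a hypothetical helper written for illustration; it is not an MLIR API and does not use MLIR's printer, and the result name `%r` is an arbitrary choice.

```cpp
#include <string>

// Hypothetical helper, not an MLIR API: builds the textual form of the op
// as "$source `in` [restrict writable] $dest : functional-type".
std::string formatMaterializeInDestination(const std::string &source,
                                           const std::string &sourceType,
                                           const std::string &dest,
                                           const std::string &destType,
                                           bool destIsMemRef) {
  std::string text;
  // Only the tensor form produces a result value.
  if (!destIsMemRef)
    text += "%r = ";
  text += "bufferization.materialize_in_destination " + source + " in ";
  // Both unit attributes are required for memref destinations.
  if (destIsMemRef)
    text += "restrict writable ";
  text += dest + " : (" + sourceType + ", " + destType + ") -> ";
  // The result type list is empty for memrefs and equals the dest type
  // for tensors.
  text += destIsMemRef ? std::string("()") : destType;
  return text;
}
```

Because the result list is empty in the memref case, the single-type suffix of the old syntax (`: tensor<5xf32>`) no longer describes every form of the op, which is why the functional type is spelled out.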

a computation in a read-only buffer, so `writable` is required. `restrict`
indicates that this op is the only way for the tensor IR to access `dest`
(or an alias thereof). E.g., there must be no other `to_tensor` ops with
`dest` or with an alias of `dest`. Such IR is not supported by
Contributor

There are deeper safety concerns here, right?
We should make it more explicit, here and below, that using this op improperly will result in undefined behavior.

I think the current wording here and above is potentially error-prone in that the user may expect the analysis to be conservative or produce warnings and/or errors.

In reality, this is a very strict directive that is easy to misuse.

Member Author

Added this sentence here and at ToTensorOp:
Ops that have incorrect usage of restrict may bufferize incorrectly.

Moreover, these requirements may be relaxed at some point in the future.

Note: If `dest` is a tensor, `tensor.insert_slice` could be used for the
same purpose, but since tensor dialect ops only indicate *what* should be
Contributor

Is this true?

If one uses tensor.insert_slice, the analysis will insert copies if required, whereas materialize_in_destination is prescriptive IIUC.

Member Author

You are right, that is the desired behavior. It was actually not like that for tensor destinations.

`memref` operand.

Note: Only `to_tensor` ops with the `restrict` unit attribute are supported
by One-Shot Bufferize. Other IR is rejected. (To support `to_tensor`
Contributor

Same here: can we be more explicit about the user's responsibility and UB versus what the analysis will catch?

Member Author

@matthias-springer matthias-springer Oct 7, 2023

Added this sentence:
Ops that have incorrect usage of restrict may bufferize incorrectly.

Note: If `dest` is a tensor, `tensor.insert_slice` could be used for the
same purpose, but since tensor dialect ops only indicate *what* should be
computed but not *where*, it could fold away, causing the computation to
materialize in a different buffer.
Contributor

@nicolasvasilache nicolasvasilache Oct 6, 2023

Another note, maybe: is there any correctness guarantee when, e.g., using this op as the last op in a function?

Member Author

As long as restrict is not used incorrectly, the IR is guaranteed to bufferize correctly. If the computation cannot materialize in the specified tensor due to a RaW conflict or a read-only tensor, the IR fails to bufferize. (Added test cases.)

matthias-springer added a commit to matthias-springer/llvm-project that referenced this pull request Oct 7, 2023
Address additional comments in llvm#68074. This should have been part of llvm#68074.
matthias-springer added a commit that referenced this pull request Oct 7, 2023
Address additional comments in #68074. This should have been part of
#68074.
Labels
mlir:bufferization Bufferization infrastructure mlir:linalg mlir

4 participants