diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index db93a51775ffcd..09ce2981d38268 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -216,33 +216,58 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
 
 def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
-        [BufferizableOpInterface, SameOperandsAndResultType,
-         DestinationStyleOpInterface,
+        [AllShapesMatch<["source", "dest"]>,
+         AllElementTypesMatch<["source", "dest"]>,
+         BufferizableOpInterface, DestinationStyleOpInterface,
          DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
          DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
-            ["getSourceOperand", "getValuesNeededForSubsetOp",
-             "buildSubsetExtraction", "isEquivalentSubset"]>]> {
+            ["getSourceOperand", "getValuesNeededForSubsetOp",
+             "buildSubsetExtraction", "isEquivalentSubset"]>,
+         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "copy a tensor";
 
   let description = [{
     This op indicates that the data of the `source` tensor should materialize
-    in the future buffer of the `dest` tensors. Both tensors must have the same
-    shape and element type at runtime.
+    in `dest`, which can be a tensor or a memref. In case of a tensor, `source`
+    should materialize in the future buffer of `dest` and the updated
+    destination tensor is returned. In case of a memref, `source` should
+    materialize in `dest`, which is already a buffer. The op has no results in
+    that case.
+
+    `source`, `dest` and `result` (if present) must have the same shape and
+    element type. If the op has a result, the types of `result` and `dest` must
+    match exactly (e.g., including any tensor encodings).
 
     By default, this op bufferizes to a memcpy from the future buffer of the
-    `source` tensor to the future buffer of the `dest` tensor. However,
-    transformations such as "empty tensor elimination" may rewrite IR such that
-    a computation is performed directly in the future buffer of the `dest`
-    tensor and no memcpy is needed.
-
-    Note: "tensor.insert_slice" could be used for the same purpose, but since
-    tensor dialect ops only indicate *what* should be computed but not *where*,
-    it could fold away, causing the computation to materialize in a different
-    buffer.
+    `source` tensor to the future buffer of the `dest` tensor or to the `dest`
+    buffer. However, transformations such as "empty tensor elimination" may
+    rewrite IR such that a computation is performed directly in `dest` and no
+    memcpy is needed.
+
+    If `dest` is a buffer, the `restrict` and `writable` attributes must be
+    specified. These attributes have the same meaning as the respective
+    attributes of `bufferization.to_tensor`. `writable` indicates that the
+    `dest` buffer is considered writable. It does not make sense to materialize
+    a computation in a read-only buffer, so `writable` is required. `restrict`
+    indicates that this op is the only way for the tensor IR to access `dest`
+    (or an alias thereof). E.g., there must be no other `to_tensor` ops with
+    `dest` or with an alias of `dest`. Such IR is not supported by
+    One-Shot Bufferize.
+
+    Note: `restrict` and `writable` could be removed from this op because they
+    must always be set for memref destinations. This op has these attributes
+    to make the requirements on the `dest` operand explicit in the op assembly
+    format. Moreover, these requirements may be relaxed in the future.
+
+    Note: If `dest` is a tensor, `tensor.insert_slice` could be used for the
+    same purpose, but since tensor dialect ops only indicate *what* should be
+    computed but not *where*, it could fold away, causing the computation to
+    materialize in a different buffer.
   }];
 
-  let arguments = (ins AnyTensor:$source, AnyTensor:$dest);
-  let results = (outs AnyTensor:$result);
+  let arguments = (ins AnyTensor:$source, AnyShaped:$dest,
+                       UnitAttr:$restrict, UnitAttr:$writable);
+  let results = (outs Optional<AnyTensor>:$result);
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
@@ -264,10 +289,23 @@ def Bufferization_MaterializeInDestinationOp
       return ::llvm::cast<TensorType>(getResult().getType());
     }
 
-    MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+    MutableOperandRange getDpsInitsMutable();
+
+    bool isWritable(Value value, const AnalysisState &state);
   }];
 
-  let assemblyFormat = "$source `in` $dest attr-dict `:` type($source)";
+  let builders = [
+    // Builder that materializes a source tensor in a tensor destination.
+    // Asserts that `dest` has tensor type. Infers the result type of this op
+    // from the destination tensor.
+    OpBuilder<(ins "Value":$source, "Value":$dest)>
+  ];
+
+  let assemblyFormat = [{
+    $source `in` (`restrict` $restrict^)? (`writable` $writable^)? $dest
+    attr-dict `:` functional-type(operands, results)
+  }];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -361,10 +399,15 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     thereof) will bufferize out-of-place to prevent emitting any writes to
     `memref` during bufferization.
 
-    If the given memref does not alias with any other memref passed to another
-    `to_tensor` op, the `restrict` unit attribute can be set. Only such
-    operations are supported by One-Shot Bufferize. (Otherwise, potential memref
-    aliasing relationships would have to be captured in One-Shot Bufferize.)
+    The `restrict` unit attribute (similar to the C `restrict` keyword)
+    indicates that the produced tensor result is the only way for the tensor
+    IR to gain access to the `memref` operand (or an alias thereof). E.g.,
+    there must be no other `to_tensor` op with the same or with an aliasing
+    `memref` operand.
+
+    Note: Only `to_tensor` ops with the `restrict` unit attribute are supported
+    by One-Shot Bufferize. Other IR is rejected. (To support `to_tensor`
+    without `restrict`, One-Shot Bufferize would have to analyze memref IR.)
 
     Example:
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 01cbacc96fd42d..1c33f444d15850 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -542,25 +542,40 @@ bool MaterializeInDestinationOp::bufferizesToMemoryRead(
 
 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
     OpOperand &opOperand, const AnalysisState &state) {
-  return &opOperand == &getDestMutable();
+  if (&opOperand == &getDestMutable()) {
+    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
+    return true;
+  }
+  return false;
 }
 
 AliasingValueList
 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                               const AnalysisState &state) {
-  if (&opOperand == &getDestMutable())
+  if (&opOperand == &getDestMutable()) {
+    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
     return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
+  }
   return {};
 }
 
 LogicalResult
 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
-  FailureOr<Value> buffer = getBuffer(rewriter, getDest(), options);
-  if (failed(buffer))
-    return failure();
-  rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), *buffer);
-  replaceOpWithBufferizedValues(rewriter, getOperation(), *buffer);
+  bool tensorDest = isa<TensorType>(getDest().getType());
+  Value buffer;
+  if (tensorDest) {
+    FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
+    if (failed(maybeBuffer))
+      return failure();
+    buffer = *maybeBuffer;
+  } else {
+    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
+    buffer = getDest();
+  }
+  rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), buffer);
+  replaceOpWithBufferizedValues(rewriter, getOperation(),
+                                tensorDest ? ValueRange(buffer) : ValueRange());
   return success();
 }
 
@@ -573,15 +588,29 @@ bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
 
 LogicalResult MaterializeInDestinationOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
-  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
+  if (getOperation()->getNumResults() == 1) {
+    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
+    reifiedReturnShapes.resize(1,
+                               SmallVector<OpFoldResult>(getType().getRank()));
+    reifiedReturnShapes[0] =
+        tensor::getMixedSizes(builder, getLoc(), getDest());
+  }
   return success();
 }
 
 Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                         Location loc) {
-  // The subset is the entire destination tensor.
-  return getDest();
+  if (isa<TensorType>(getDest().getType())) {
+    // The subset is the entire destination tensor.
+    return getDest();
+  }
+
+  // Build a bufferization.to_tensor op.
+  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
+  assert(getRestrict() &&
+         "expected that ops with memref dest have 'restrict'");
+  return builder.create<ToTensorOp>(loc, getDest(), getRestrict(),
+                                    getWritable());
 }
 
 bool MaterializeInDestinationOp::isEquivalentSubset(
@@ -598,6 +627,51 @@ OpOperand &MaterializeInDestinationOp::getSourceOperand() {
   return getOperation()->getOpOperand(0) /*source*/;
 }
 
+LogicalResult MaterializeInDestinationOp::verify() {
+  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
+    return emitOpError("'dest' must be a tensor or a memref");
+  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
+    if (getOperation()->getNumResults() != 1)
+      return emitOpError("tensor 'dest' implies exactly one tensor result");
+    if (destType != getResult().getType())
+      return emitOpError("result and 'dest' types must match");
+  }
+  if (isa<BaseMemRefType>(getDest().getType()) &&
+      getOperation()->getNumResults() != 0)
+    return emitOpError("memref 'dest' implies zero results");
+  if (getRestrict() != isa<BaseMemRefType>(getDest().getType()))
+    return emitOpError("'restrict' must be specified if and only if the "
+                       "destination is of memref type");
+  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
+    return emitOpError("'writable' must be specified if and only if the "
+                       "destination is of memref type");
+  return success();
+}
+
+void MaterializeInDestinationOp::build(OpBuilder &builder,
+                                       OperationState &state, Value source,
+                                       Value dest) {
+  assert(isa<TensorType>(dest.getType()) && "expected tensor type");
+  build(builder, state, /*result=*/dest.getType(), source, dest);
+}
+
+bool MaterializeInDestinationOp::isWritable(Value value,
+                                            const AnalysisState &state) {
+  return isa<TensorType>(getDest().getType()) ? true : getWritable();
+}
+
+MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
+  return getDestMutable();
+}
+
+void MaterializeInDestinationOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (isa<BaseMemRefType>(getDest().getType()))
+    effects.emplace_back(MemoryEffects::Write::get(), getDest(),
+                         SideEffects::DefaultResource::get());
+}
+
 //===----------------------------------------------------------------------===//
 // ToTensorOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index a74a3c2c500406..e6d80a39650ccf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -248,8 +248,10 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                LinalgPaddingOptions::CopyBackOp::
                    BufferizationMaterializeInDestination) {
       replacements.push_back(
-          rewriter.create<bufferization::MaterializeInDestinationOp>(
-              loc, std::get<0>(it), std::get<1>(it).get()));
+          rewriter
+              .create<bufferization::MaterializeInDestinationOp>(
+                  loc, std::get<0>(it), std::get<1>(it).get())
+              ->getResult(0));
     } else {
       llvm_unreachable("unsupported copy back op");
     }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index a2fbb06d179ebd..c3e44c426797f3 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -172,7 +172,7 @@ func.func @materialize_in_destination_aliasing(%t: tensor<?xf32>, %p1: index, %p
   %dest = tensor.extract_slice %t[%p2][5][1] : tensor<?xf32> to tensor<5xf32>
   // CHECK: bufferization.materialize_in_destination
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
-  %r = bufferization.materialize_in_destination %src in %dest : tensor<5xf32>
+  %r = bufferization.materialize_in_destination %src in %dest : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %r : tensor<5xf32>
 }
 
@@ -183,6 +183,6 @@ func.func @materialize_in_destination(%t: tensor<?xf32>, %sz: index) -> tensor
   // CHECK: bufferization.materialize_in_destination
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
-  %r = bufferization.materialize_in_destination %buffer in %buffer : tensor<?xf32>
+  %r = bufferization.materialize_in_destination %buffer in %buffer : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   return %r : tensor<?xf32>
 }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index b68682a459ed2c..99b974b9ef3c67 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -301,12 +301,25 @@ func.func @regression_multiple_insertion_points(%t1: tensor<?xf32>) -> tensor<
 func.func @materialize_in_destination(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
   %0 = tensor.empty() : tensor<5xf32>
   %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
-  %1 = bufferization.materialize_in_destination %filled in %t : tensor<5xf32>
+  %1 = bufferization.materialize_in_destination %filled in %t : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %1 : tensor<5xf32>
 }
 
 // -----
 
+// CHECK-LABEL: func @materialize_in_destination_buffer(
+// CHECK-SAME: %[[m:.*]]: memref<5xf32>,
+// CHECK-NEXT: linalg.fill {{.*}} outs(%[[m]]
+// CHECK-NEXT: return
+func.func @materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32) {
+  %0 = tensor.empty() : tensor<5xf32>
+  %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+  bufferization.materialize_in_destination %filled in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @linalg_copy(
 // CHECK-SAME: %[[m:.*]]: memref<5xf32, strided<[?], offset: ?>>,
 // CHECK: linalg.fill {{.*}} outs(%[[m]]
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index 3f468750cc2840..272423de5564b0 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -218,6 +218,20 @@ func.func @tensor_copy(%arg0: tensor<5xf32>) -> tensor<5xf32> {
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
   // CHECK: return %[[r]]
   %dest = bufferization.alloc_tensor() : tensor<5xf32>
-  %0 = bufferization.materialize_in_destination %arg0 in %dest : tensor<5xf32>
+  %0 = bufferization.materialize_in_destination %arg0 in %dest
+      : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %0 : tensor<5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @materialize_in_destination_buffer(
+// CHECK-SAME: %[[t:.*]]: tensor<5xf32>, %[[m:.*]]: memref<5xf32>)
+// CHECK: %[[b:.*]] = bufferization.to_memref %[[t]] : memref<5xf32, strided<[?], offset: ?>>
+// CHECK: memref.copy %[[b]], %[[m]]
+func.func @materialize_in_destination_buffer(%t: tensor<5xf32>, %m: memref<5xf32>) {
+  bufferization.materialize_in_destination %t in restrict writable %m
+      : (tensor<5xf32>, memref<5xf32>) -> ()
+  return
+}
+
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 8004ec632453e8..ce56f89c1f1bbe 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -66,10 +66,58 @@ func.func @invalid_writable_on_op() {
 
 // -----
 
-// expected-note @below{{prior use here}}
 func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
-  // expected-error @below{{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<5xf32>'}}
-  bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
+  // expected-error @below{{failed to verify that all of {source, dest} have same shape}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_dest_type(%arg0: tensor<5xf32>, %arg1: vector<5xf32>) {
+  // expected-error @below{{'dest' must be a tensor or a memref}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5xf32>, vector<5xf32>) -> ()
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_restrict_missing(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
+  // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, memref<?xf32>) -> ()
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
+  // expected-error @below{{memref 'dest' implies zero results}}
+  bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, memref<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{tensor 'dest' implies exactly one tensor result}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> ()
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
+  bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{'writable' must be specified if and only if the destination is of memref type}}
+  bufferization.materialize_in_destination %arg0 in writable %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination_result_shape(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
+  // expected-error @below{{result and 'dest' types must match}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<6xf32>)
 }
 
 // -----
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index dc53e535bfe0d5..d4bda0632189d4 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -59,10 +59,12 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
 }
 
 // CHECK-LABEL: func @test_materialize_in_destination_op
-func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>)
+func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32>)
     -> tensor<?xf32> {
-  // CHECK: bufferization.materialize_in_destination {{.*}} : tensor<?xf32>
-  %1 = bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
+  // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %1 = bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, memref<?xf32>) -> ()
+  bufferization.materialize_in_destination %arg0 in restrict writable %arg2 : (tensor<?xf32>, memref<?xf32>) -> ()
   return %1 : tensor<?xf32>
 }
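Usage sketch, mirroring the test cases in this patch: the op now comes in two
forms. With a tensor destination it returns the updated tensor; with a memref
destination it requires `restrict` and `writable` and has no results. The SSA
value names and element types below are illustrative only.

  // Tensor destination: returns the updated destination tensor.
  %r = bufferization.materialize_in_destination %src in %dest
      : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>

  // Memref destination: 'restrict' and 'writable' are required, no results.
  bufferization.materialize_in_destination %src in restrict writable %m
      : (tensor<5xf32>, memref<5xf32>) -> ()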