diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index ac128d0a97bda4..9386d0fd0f04fa 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -119,7 +119,11 @@ struct CollapseShapeOpInterface tensor::CollapseShapeOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return false; + // tensor.collapse_shape may reallocate, at which point the source buffer is + // copied. I.e., there will be a memory read side effect on the bufferized + // source. This function conservatively returns "true" because whether a + // copy will be created or not is not known at this point. + return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, @@ -291,6 +295,8 @@ struct ExpandShapeOpInterface tensor::ExpandShapeOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { + // In contrast to tensor.collapse_shape, this op can always be bufferized + // without a copy. return false; } @@ -841,6 +847,7 @@ struct ReshapeOpInterface tensor::ReshapeOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { + // Depending on the layout map, the source buffer may have to be copied. auto reshapeOp = cast(op); return &opOperand == &reshapeOp.getShapeMutable(); } @@ -870,15 +877,20 @@ struct ReshapeOpInterface return failure(); // memref.reshape requires the source buffer to have an identity layout. - // If the source memref does not have an identity layout, clone the source + // If the source memref does not have an identity layout, copy the source // into a new buffer with an identity layout. auto srcType = llvm::dyn_cast(srcBuffer->getType()); if (srcType && !srcType.getLayout().isIdentity()) { - auto identityType = - MemRefType::get(srcType.getShape(), srcType.getElementType()); + FailureOr tensorAlloc = allocateTensorForShapedValue( + rewriter, op->getLoc(), reshapeOp.getSource(), options); + if (failed(tensorAlloc)) + return failure(); + auto memrefType = MemRefType::get( + srcType.getShape(), srcType.getElementType(), AffineMap(), + cast(srcBuffer->getType()).getMemorySpace()); srcBuffer = rewriter - .create(op->getLoc(), - identityType, *srcBuffer) + .create( + op->getLoc(), memrefType, *tensorAlloc) .getResult(); } diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir index 9052744a1d3f98..38c3bb8af8107d 100644 --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -384,20 +384,45 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> { // ----- // CHECK-LABEL: @reshape_with_non_identity_layout( -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>>, -// CHECK-SAME: %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>) -func.func @reshape_with_non_identity_layout(%arg0: tensor<2x2xf32>, %arg1: tensor<2xi32>) -> tensor<1x2xf32> { - - // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[INPUT]][1, 0] [1, 2] [1, 1] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2xf32, strided<[?], offset: ?>> - %extracted_slice = tensor.extract_slice %arg0[1, 0] [1, 2] [1, 1] : tensor<2x2xf32> to tensor<2xf32> +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>, +// CHECK-SAME: %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>, +func.func @reshape_with_non_identity_layout(%arg0: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>, %arg1: tensor<2xi32>, %idx: index) -> f32 { + %t = bufferization.to_tensor %arg0 restrict : memref<2x2xf32, strided<[?, ?], offset: ?>, 3> + + // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[INPUT]][1, 0] [1, 2] [1, 1] : memref<2x2xf32, strided<[?, ?], offset: ?>, 3> to memref<2xf32, strided<[?], offset: ?>, 3> + %extracted_slice = tensor.extract_slice %t[1, 0] [1, 2] [1, 1] : tensor<2x2xf32> to tensor<2xf32> + + // To satisify the constraints of memref.reshape, the subview must be + // reallocated a buffer with an identity layout. + // CHECK: %[[ALLOC:.+]] = memref.alloc() {{.*}} : memref<2xf32, 3> + // CHECK: memref.copy %[[SUBVIEW]], %[[ALLOC]] + // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[ALLOC]](%[[LAYOUT]]) : (memref<2xf32, 3>, memref<2xi32, strided<[?], offset: ?>>) -> memref<1x2xf32, 3> + %reshape = tensor.reshape %extracted_slice(%arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<1x2xf32> - // To satisify the constraints of memref.reshape, the subview must be cloned into - // a buffer with an identity layout. - // CHECK: %[[CLONED:.+]] = bufferization.clone %[[SUBVIEW]] : memref<2xf32, strided<[?], offset: ?>> to memref<2xf32> - // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[CLONED]](%[[LAYOUT]]) : (memref<2xf32>, memref<2xi32, strided<[?], offset: ?>>) -> memref<1x2xf32> + %r = tensor.extract %reshape[%idx, %idx] : tensor<1x2xf32> + return %r : f32 +} - %reshape = tensor.reshape %extracted_slice(%arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<1x2xf32> +// ----- - // CHECK: return %[[RESHAPED]] : memref<1x2xf32> - return %reshape : tensor<1x2xf32> +// CHECK-LABEL: func @collapse_shape_regression( +// CHECK-SAME: %[[t:.*]]: memref<10x20xf32, +func.func @collapse_shape_regression( + %t: tensor<10x20xf32>, %f: f32, %idx: index) { + // CHECK: %[[subview:.*]] = memref.subview %[[t]] + %0 = tensor.extract_slice %t [2, 3] [5, 6] [1, 1] + : tensor<10x20xf32> to tensor<5x6xf32> + + // Insert a copy because the original %0 is read later. + // CHECK: %[[alloc1:.*]] = memref.alloc() {{.*}} : memref<5x6xf32> + // CHECK: memref.copy %[[subview]], %[[alloc1]] + // CHECK: memref.store {{.*}}, %[[alloc1]] + tensor.insert %f into %0[%idx, %idx] : tensor<5x6xf32> + + // Insert a copy because the shape cannot be collapsed. + // CHECK: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5x6xf32> + // CHECK: memref.copy %[[subview]], %[[alloc2]] + // CHECK: memref.collapse_shape %[[alloc2]] + tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32> + return }