[CPU] Slow copy coming from tensor.insert_slice with dynamic dims #15195

Open
Max191 opened this issue Oct 16, 2023 · 8 comments
Labels
bug 🐞 Something isn't working


@Max191
Contributor

Max191 commented Oct 16, 2023

What happened?

While burning down performance on llama2 for CPU, I ran into a slow copy dispatch that came from a tensor.insert_slice op inserting on an inner dim. After doing some rewrites, the insert now looks like this:

  %inserted_slice = tensor.insert_slice %collapsed_0 into %5[0, 0, 0] [%0, 32, 128] [1, 1, 1] : tensor<?x32x128xf32> into tensor<?x32x128xf32>

This should be convertible to a flow.tensor.update op, since the slice covers the full inner dimensions and the inserted region is therefore contiguous, but the rewrite fails because of the dynamic dim:
https://github.com/openxla/iree/blob/ebdb098b216c3e59a9977902823ede613f553f71/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.cpp#L73-L81

I figure that on CPU it should be okay to have dynamic dims here, since we don't have to worry about round trips to the device. Should we enable dynamic dims in the flow.tensor.update rewrite here?
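
To see why the slice above is contiguous: in a row-major buffer, a unit-stride slice whose inner dimensions are full (offset 0, size equal to the dimension) occupies one unbroken run of elements, whatever the outermost size is. The brute-force checker below is only an illustration; the helper name is hypothetical and the dynamic %0 is modeled by a concrete value. It walks the slice in order and verifies that consecutive elements have consecutive linear indices.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative brute-force check (hypothetical helper, not IREE code): does a
// unit-stride slice of a row-major tensor cover one contiguous range of
// linear element indices? Dynamic dims are modeled with a concrete value.
bool sliceIsContiguous(const std::vector<int64_t>& baseShape,
                       const std::vector<int64_t>& offsets,
                       const std::vector<int64_t>& sizes) {
  const size_t rank = baseShape.size();
  assert(offsets.size() == rank && sizes.size() == rank);

  // Row-major strides: strides[i] is the product of baseShape[i+1 .. rank-1].
  std::vector<int64_t> strides(rank, 1);
  for (size_t i = rank; i-- > 1;) strides[i - 1] = strides[i] * baseShape[i];

  int64_t numElements = 1;
  for (int64_t s : sizes) numElements *= s;

  // Visit the slice in row-major order; each linear index must be exactly one
  // past the previous one.
  std::vector<int64_t> idx(rank, 0);
  int64_t prev = -1;
  for (int64_t n = 0; n < numElements; ++n) {
    int64_t linear = 0;
    for (size_t i = 0; i < rank; ++i)
      linear += (offsets[i] + idx[i]) * strides[i];
    if (n > 0 && linear != prev + 1) return false;
    prev = linear;
    // Advance the multi-dimensional index, innermost dimension first.
    for (size_t i = rank; i-- > 0;) {
      if (++idx[i] < sizes[i]) break;
      idx[i] = 0;
    }
  }
  return true;
}
```

For example, with %0 modeled as 3, sliceIsContiguous({3, 32, 128}, {0, 0, 0}, {3, 32, 128}) holds, while shrinking the middle size to 16 leaves gaps between rows and the check fails.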

Steps to reproduce your issue

I have been playing with the relevant workload, and am working with this IR now:
https://gist.github.com/Max191/f9705764cc3cd650f3c547071dcc03a9

Compiling on ToM with:

iree-compile --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-llvmcpu-enable-microkernels --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-vm-bytecode-module-strip-source-map=true --iree-util-zero-fill-elided-attrs --iree-vm-target-truncate-unsupported-floats --iree-codegen-check-ir-before-llvm-conversion=false --iree-opt-const-expr-hoisting=False -o concat_inputs.vmfb concat_inputs_rewrite.mlir

What component(s) does this issue relate to?

No response

Version information

No response

Additional context

Here is the print-ir-after-all dump for the above IR:
https://drive.google.com/file/d/1rIAF7zbVZ5m4IX69_XJ__-PxJ21AzWY0/view?usp=sharing

Max191 added the bug 🐞 Something isn't working label Oct 16, 2023
@MaheshRavishankar
Contributor

Just change the loop at https://github.com/openxla/iree/blob/ebdb098b216c3e59a9977902823ede613f553f71/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.cpp#L66 to

for (size_t dim = offsets.size(); dim > 1; dim--) { .. }

and try. That should fix it.
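
A simplified, hypothetical version of that contiguity check (not the actual Utils.cpp code; kDynamic and isContiguousInsert are illustrative stand-ins) shows what the new bound changes: the loop only inspects dims rank-1 down to 1, so a dynamic outermost offset or size no longer causes a bail-out. In this sketch the outermost dim still gets one check of its own, because when the inner dims are not full slices it must have size 1 for the region to stay contiguous.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical sentinel for a dynamic value, in the spirit of
// ShapedType::kDynamic (illustration only).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical, simplified contiguity check for a unit-stride insert_slice.
// Inner dims must have static offsets and sizes; the outermost dim may be
// dynamic.
bool isContiguousInsert(const std::vector<int64_t>& baseShape,
                        const std::vector<int64_t>& offsets,
                        const std::vector<int64_t>& sizes) {
  assert(offsets.size() == baseShape.size() &&
         sizes.size() == baseShape.size());
  if (baseShape.empty()) return true;
  bool fullSlices = true;  // are all dims inside the current one full slices?
  // Suggested bound: visit dims rank-1 .. 1 and leave dim 0 for the end.
  for (size_t dim = offsets.size(); dim > 1; dim--) {
    const int64_t offset = offsets[dim - 1];
    const int64_t size = sizes[dim - 1];
    if (offset == kDynamic || size == kDynamic) return false;
    if (!fullSlices) {
      if (size != 1) return false;  // outside a partial slice: must be size 1
    } else if (offset != 0 || size != baseShape[dim - 1]) {
      fullSlices = false;  // partial slice; every outer dim must be size 1
    }
  }
  // Outermost dim: any (even dynamic) offset/size keeps the region contiguous
  // if all inner dims are full slices; otherwise its size must be exactly 1.
  return fullSlices || sizes[0] == 1;
}
```

With this shape, the insert from the original report ([0, 0, 0] [%0, 32, 128] into ?x32x128) is accepted, while a dynamic offset or size on an inner dim is still rejected.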

@Max191
Contributor Author

Max191 commented Oct 16, 2023

This gives us flow.tensor.update, but a transpose is still being materialized and causing a slowdown. Since we introduced a transpose to move the dynamic dim to the outside, the batch dimension is now an inner dim, and the transpose that moves the batch dim to the outermost position gets its own dispatch.

This is essentially what we want to be able to fuse now:
https://gist.github.com/Max191/908486a43bd86c83d865d7d25face75f

So either we allow these to be fused into the same dispatch, or we just add tensor.concat and flow.tensor.concat ops and avoid introducing the transpose altogether.
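
For intuition on the concat option: concatenating dense row-major tensors along one axis can be lowered to contiguous chunk copies, one chunk per input for each index of the dims outside that axis, so no transpose is required. When all of those outer dims have size 1, it degenerates to one contiguous copy per input, essentially the contiguous-update case discussed earlier in this issue. The helper below is a hypothetical sketch and assumes the two shapes differ only along axis.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: concatenate two dense row-major tensors along `axis`.
// Shapes must agree on every dim except `axis`.
std::vector<float> concatAlongAxis(const std::vector<float>& a,
                                   const std::vector<int64_t>& aShape,
                                   const std::vector<float>& b,
                                   const std::vector<int64_t>& bShape,
                                   size_t axis) {
  assert(aShape.size() == bShape.size() && axis < aShape.size());
  int64_t outer = 1, inner = 1;
  for (size_t i = 0; i < aShape.size(); ++i) {
    if (i == axis) continue;
    assert(aShape[i] == bShape[i]);
    if (i < axis) outer *= aShape[i]; else inner *= aShape[i];
  }
  // Each input contributes one contiguous run per outer index.
  const int64_t aChunk = aShape[axis] * inner;
  const int64_t bChunk = bShape[axis] * inner;
  std::vector<float> out;
  out.reserve(a.size() + b.size());
  for (int64_t o = 0; o < outer; ++o) {
    out.insert(out.end(), a.begin() + o * aChunk, a.begin() + (o + 1) * aChunk);
    out.insert(out.end(), b.begin() + o * bChunk, b.begin() + (o + 1) * bChunk);
  }
  return out;
}
```

With outer == 1 the loop runs once and the result is simply a followed by b.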

@Max191
Contributor Author

Max191 commented Oct 19, 2023

If we generalize the batch matmul that consumes the flow.tensor.update, we no longer have to materialize the result of the transpose, but now I am seeing a new dispatch in the full llama2 model that materializes the output:

      func.func @second_vicuna_forward_dispatch_999_generic_32x1xDx128_f32(%arg0: index, %arg1: !flow.dispatch.tensor<readonly:tensor<1x?x32x128xf32>>, %arg2: index, %arg3: index, %arg4: !flow.dispatch.tensor<writeonly:tensor<1x?x32x128xf32>>) {
        %c1_i64 = arith.constant 1 : i64
        %0 = flow.dispatch.workload.ordinal %arg2, 1 : index
        %1 = flow.dispatch.workload.ordinal %arg3, 2 : index
        %2 = flow.dispatch.tie_shape %arg1 : !flow.dispatch.tensor<readonly:tensor<1x?x32x128xf32>>{%0}
        %3 = flow.dispatch.tie_shape %arg4 : !flow.dispatch.tensor<writeonly:tensor<1x?x32x128xf32>>{%1}
        %4 = flow.dispatch.workload.ordinal %arg0, 0 : index
        %5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0], sizes = [1, %0, 32, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x?x32x128xf32>>{%0} -> tensor<1x?x32x128xf32>
        %6 = arith.index_cast %4 : index to i64
        %7 = arith.addi %6, %c1_i64 : i64
        %8 = arith.index_cast %7 : i64 to index
        %9 = tensor.empty(%8) : tensor<1x?x32x128xf32>
        %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%5 : tensor<1x?x32x128xf32>) outs(%9 : tensor<1x?x32x128xf32>) {
        ^bb0(%in: f32, %out: f32):
          linalg.yield %in : f32
        } -> tensor<1x?x32x128xf32>
        flow.dispatch.tensor.store %10, %3, offsets = [0, 0, 0, 0], sizes = [1, %1, 32, 128], strides = [1, 1, 1, 1] : tensor<1x?x32x128xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x?x32x128xf32>>{%1}
        return
      }

Is there a specific reason this wouldn't be folded away? To me it looks like a plain copy with unusual indexing maps (in this case %arg0 = D-1, %arg2 = D, %arg3 = D): the input and output use the same map, so every element is written to the position it was read from. The result is also stored directly as an output, so I'd expect this dispatch to disappear entirely.
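
Concretely, the dispatch body computes out[m(d)] = in[m(d)] for every iteration point d, where m is the shared map (d1, d2, d0, d3). The small simulator below (hypothetical helper names, fixed rank 4, dense row-major buffers) makes that concrete: running a copy-only generic with identical input and output maps returns the input unchanged, while different maps produce a real transpose.

```cpp
#include <array>
#include <cassert>
#include <vector>

// A permutation indexing map, written as the loop dim used for each tensor
// dim: affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)> becomes {1, 2, 0, 3}.
using Map4 = std::array<int, 4>;
using Idx4 = std::array<int, 4>;

// Row-major linear offset of coordinates `c` in a tensor of shape `s`.
int linearize(const Idx4& c, const Idx4& s) {
  int idx = 0;
  for (int k = 0; k < 4; ++k) idx = idx * s[k] + c[k];
  return idx;
}

// Runs a copy-only generic (body: linalg.yield %in) over the iteration space
// `bounds`: for each point d, out[outMap(d)] = in[inMap(d)].
std::vector<float> runCopyGeneric(const std::vector<float>& in,
                                  const Idx4& bounds, const Map4& inMap,
                                  const Map4& outMap) {
  Idx4 inShape, outShape;
  for (int k = 0; k < 4; ++k) {
    inShape[k] = bounds[inMap[k]];
    outShape[k] = bounds[outMap[k]];
  }
  assert(static_cast<int>(in.size()) ==
         inShape[0] * inShape[1] * inShape[2] * inShape[3]);
  std::vector<float> out(in.size());
  Idx4 d{};
  for (d[0] = 0; d[0] < bounds[0]; ++d[0])
    for (d[1] = 0; d[1] < bounds[1]; ++d[1])
      for (d[2] = 0; d[2] < bounds[2]; ++d[2])
        for (d[3] = 0; d[3] < bounds[3]; ++d[3]) {
          Idx4 ic, oc;
          for (int k = 0; k < 4; ++k) {
            ic[k] = d[inMap[k]];
            oc[k] = d[outMap[k]];
          }
          out[linearize(oc, outShape)] = in[linearize(ic, inShape)];
        }
  return out;
}
```

Using bounds {32, 1, D, 128} with both maps set to {1, 2, 0, 3} models the dispatch above (tensor shape 1xDx32x128); the output is always identical to the input.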

On another note, assuming we can fold away the above dispatch and avoid materializing the concat result, generalizing the batch matmul causes about a 10x slowdown on the batch matmul itself, based on the profile. Materializing the concat result is the bigger performance hit for now, but this batch matmul will ultimately slow down the model quite a bit too, even if overall performance is better when we generalize it.

@Max191
Contributor Author

Max191 commented Oct 19, 2023

I've reduced the problem to a sequence of 2 transpose operations that are inverses of each other:

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>
module {
  func.func @double_transpose(%arg0: tensor<1x?x32x128xf32>) -> tensor<1x?x32x128xf32> {
    %c1 = arith.constant 1 : index
    %dim = tensor.dim %arg0, %c1 : tensor<1x?x32x128xf32>
    %0 = tensor.empty(%dim) : tensor<1x32x?x128xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x?x32x128xf32>) outs(%0 : tensor<1x32x?x128xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x32x?x128xf32>
    %2 = tensor.empty(%dim) : tensor<1x?x32x128xf32>
    %3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1 : tensor<1x32x?x128xf32>) outs(%2 : tensor<1x?x32x128xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x?x32x128xf32>
    return %3 : tensor<1x?x32x128xf32>
  }
}
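
Each generic above reads its input through the identity map and writes through #map1, so it moves the element at coordinates x to #map1(x), swapping dims 1 and 2. Chaining the two moves x to #map1(#map1(x)), and because this swap is its own inverse the composition is the identity, i.e. %3 should equal %arg0. A tiny sketch of that composition check, with maps encoded as permutation vectors (a hypothetical encoding, not the MLIR AffineMap API):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// A permutation indexing map written as the loop dim feeding each result:
// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)> is {0, 2, 1, 3}.
using Perm = std::vector<int>;

// Applying map p to coordinates x gives p(x)[k] = x[p[k]]. Applying `first`
// and then `second` gives x[first[second[k]]], which is the map returned here.
Perm thenApply(const Perm& first, const Perm& second) {
  assert(first.size() == second.size());
  Perm result(first.size());
  for (size_t k = 0; k < first.size(); ++k) result[k] = first[second[k]];
  return result;
}

// True if the map returns its inputs unchanged: (d0, d1, ...) -> (d0, d1, ...).
bool isIdentity(const Perm& p) {
  for (size_t k = 0; k < p.size(); ++k)
    if (p[k] != static_cast<int>(k)) return false;
  return true;
}
```

thenApply({0, 2, 1, 3}, {0, 2, 1, 3}) is the identity, so the pair of generics is a no-op; by contrast, applying a 3-cycle such as {1, 2, 0, 3} twice is not.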

I'd expect this to simply fold away, but instead FusionOfTensorOps fuses the two transposes into a single copy-like generic (IR before and after the pass):

// -----// IR Dump After CSE (cse) //----- //
func.func @double_transpose(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
  %1 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<1x?x32x128xf32>{%0}
  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<1x?x32x128xf32> into tensor<?x32x128xf32>
  %2 = tensor.empty(%0) : tensor<32x?x128xf32>
  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed : tensor<?x32x128xf32>) outs(%2 : tensor<32x?x128xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<32x?x128xf32>
  %4 = tensor.empty(%0) : tensor<?x32x128xf32>
  %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3 : tensor<32x?x128xf32>) outs(%4 : tensor<?x32x128xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<?x32x128xf32>
  %expanded = tensor.expand_shape %5 [[0, 1], [2], [3]] : tensor<?x32x128xf32> into tensor<1x?x32x128xf32>
  %6 = hal.tensor.export %expanded "output 0" : tensor<1x?x32x128xf32>{%0} -> !hal.buffer_view
  return %6 : !hal.buffer_view
}

// -----// IR Dump After FusionOfTensorOps (iree-flow-fusion-of-tensor-ops) //----- //
func.func @double_transpose(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
  %1 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<1x?x32x128xf32>{%0}
  %2 = tensor.empty(%0) : tensor<1x?x32x128xf32>
  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1 : tensor<1x?x32x128xf32>) outs(%2 : tensor<1x?x32x128xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<1x?x32x128xf32>
  %4 = hal.tensor.export %3 "output 0" : tensor<1x?x32x128xf32>{%0} -> !hal.buffer_view
  return %4 : !hal.buffer_view
}

Here is the dump compiled to flow for reference:
https://drive.google.com/file/d/1XAeZ8yd79NzkKTajqAVzq_arlWRaSDCZ/view?usp=sharing

@MaheshRavishankar
Contributor

Huh, this should just be folded away.

@MaheshRavishankar
Contributor

https://github.com/llvm/llvm-project/blob/fb5047f5244d81aa89f68210a9cd34ddddcc8af4/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp#L1081 is the pattern that should fold this away. There might be a bug there that makes it miss this case.

@Max191
Contributor Author

Max191 commented Oct 19, 2023

After fixing the above bug, the generalized batch matmul is turning out to be quite expensive. I have uploaded IR dumps and objdumps of a simple example of the transposed batch matmul in question here:
https://drive.google.com/drive/folders/1odSynjyt3kiXC3Dc7CB-ufDNBfBsKMs4?usp=sharing

There are 2 versions: one with the additional experimental optimizations for quantized vecmats enabled, and one without. These optimizations might be affecting performance, so I uploaded both.

Here is the input IR as well:

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
module {
  func.func @transposed_batch_matmul(%arg0: tensor<32x1x?xf32>, %arg1: tensor<?x32x128xf32>) -> tensor<32x1x128xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<32x1x128xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x1x128xf32>) -> tensor<32x1x128xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<32x1x?xf32>, tensor<?x32x128xf32>) outs(%1 : tensor<32x1x128xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %3 = arith.mulf %in, %in_0 : f32
      %4 = arith.addf %out, %3 : f32
      linalg.yield %4 : f32
    } -> tensor<32x1x128xf32>
    return %2 : tensor<32x1x128xf32>
  }
}
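
As a correctness oracle, the generic above can be written as a plain loop nest. With loops (d0, d1, d2, d3) = (b, m, n, k), the maps say out[b][m][n] += lhs[b][m][k] * rhs[k][b][n], where k is the dynamic reduction dim; the batch dim of %arg1 is its middle dim, which is what makes this a "transposed" batch matmul. The sketch below is a naive reference (hypothetical names, dense row-major buffers), not a model of the generated code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Naive reference for @transposed_batch_matmul (illustrative only):
//   lhs: [batch][M][K]  e.g. tensor<32x1x?xf32>,   map (d0, d1, d3)
//   rhs: [K][batch][N]  e.g. tensor<?x32x128xf32>, map (d3, d0, d2)
//   out: [batch][M][N]  e.g. tensor<32x1x128xf32>, map (d0, d1, d2)
std::vector<float> transposedBatchMatmul(const std::vector<float>& lhs,
                                         const std::vector<float>& rhs,
                                         int64_t batch, int64_t M, int64_t N,
                                         int64_t K) {
  assert(static_cast<int64_t>(lhs.size()) == batch * M * K);
  assert(static_cast<int64_t>(rhs.size()) == K * batch * N);
  // Zero-initialized accumulator, like the linalg.fill with 0.0.
  std::vector<float> out(static_cast<size_t>(batch * M * N), 0.0f);
  for (int64_t b = 0; b < batch; ++b)    // d0, parallel
    for (int64_t m = 0; m < M; ++m)      // d1, parallel
      for (int64_t n = 0; n < N; ++n)    // d2, parallel
        for (int64_t k = 0; k < K; ++k)  // d3, reduction
          out[(b * M + m) * N + n] +=
              lhs[(b * M + m) * K + k] * rhs[(k * batch + b) * N + n];
  return out;
}
```

Calling it with batch = 32, M = 1, N = 128 and any K reproduces the shapes in the IR above.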

Edit: compile commands

default:

iree-compile --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-llvmcpu-enable-microkernels --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-vm-bytecode-module-strip-source-map=true --iree-util-zero-fill-elided-attrs --iree-vm-target-truncate-unsupported-floats --iree-codegen-check-ir-before-llvm-conversion=false --iree-opt-const-expr-hoisting=False -o transposed_batch_matmul.vmfb --mlir-print-ir-after-all --mlir-disable-threading --debug-only=iree-llvmcpu-vector-lowering --iree-llvmcpu-keep-linker-artifacts=false --iree-llvmcpu-link-embedded=false transposed_batch_matmul.mlir

quantized matmul changes:

iree-compile --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-llvmcpu-enable-microkernels --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-vm-bytecode-module-strip-source-map=true --iree-util-zero-fill-elided-attrs --iree-vm-target-truncate-unsupported-floats --iree-codegen-check-ir-before-llvm-conversion=false --iree-opt-const-expr-hoisting=False --iree-llvmcpu-enable-quantized-matmul-reassociation --iree-flow-enable-quantized-matmul-reassociation -o transposed_batch_matmul.vmfb --mlir-print-ir-after-all --mlir-disable-threading --debug-only=iree-llvmcpu-vector-lowering --iree-llvmcpu-keep-linker-artifacts=false --iree-llvmcpu-link-embedded=false transposed_batch_matmul.mlir

hanhanW self-assigned this Oct 24, 2023
@Max191
Contributor Author

Max191 commented Oct 25, 2023

I found why this is slow. The quantized matmul reassociation changes enable split reduction on ops with 2 inputs, so the generalized batch matmul with a dynamic reduction dim was going through the LLVMCPUSplitReduction pass. The problem is that the pass tiles everything to 1 before doing the splitReduction, and then splitReduction fails due to the dynamic reduction dim, leaving the generic with all of its parallel dimensions tiled to 1. For now, I just added a check in the LLVMCPUSplitReduction pass that requires static reduction dimensions, and the batch matmuls now perform as expected.
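
For context on why a dynamic reduction dim is a problem: split reduction rewrites a reduction over K as a reduction over a [K / ratio][ratio] view, accumulating ratio independent partial results and combining them at the end. With a fixed ratio, that reshape is only valid when K is known to be divisible by ratio, which cannot be established at compile time when K is dynamic. The sketch below models that precondition on a plain sum; splitReductionSum and the std::nullopt "rewrite refused" result are illustrative assumptions, not the LLVMCPUSplitReduction implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Illustrative split reduction of a sum (hypothetical helper). The K-long
// reduction is viewed as [K / ratio][ratio]; lane r accumulates elements
// r, r + ratio, r + 2 * ratio, ... and the lanes are summed at the end.
// Returns std::nullopt when K is not a multiple of ratio, mirroring a rewrite
// that must refuse to split when divisibility is unknown.
std::optional<float> splitReductionSum(const std::vector<float>& values,
                                       size_t ratio) {
  assert(ratio > 0);
  const size_t K = values.size();
  if (K % ratio != 0) return std::nullopt;
  std::vector<float> partial(ratio, 0.0f);  // new parallel dim of size ratio
  for (size_t outer = 0; outer < K / ratio; ++outer)
    for (size_t r = 0; r < ratio; ++r)
      partial[r] += values[outer * ratio + r];
  float total = 0.0f;  // final combining reduction over the ratio dim
  for (float p : partial) total += p;
  return total;
}
```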

The overall model latency is marginally better now, shaving off ~5-10 ms at longer context lengths (from an original latency of 90+ ms). We could probably get somewhat better performance by adding splitReduction support for dynamic dims too and letting both generalized batch matmuls go through splitReduction.
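
One common way to extend split reduction to a dynamic reduction size (an assumed approach, not something implemented in this thread) is to split only the largest prefix of K that is a multiple of ratio and finish the leftover K % ratio elements in a scalar tail loop. A sketch on a plain sum, with a hypothetical helper name:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical split reduction that tolerates any K (e.g. a dynamic dim):
// the main part covers floor(K / ratio) * ratio elements with `ratio`
// partial accumulators; the remaining K % ratio elements are peeled into a
// scalar tail loop.
float splitReductionSumPeeled(const std::vector<float>& values, size_t ratio) {
  assert(ratio > 0);
  const size_t K = values.size();
  const size_t mainEnd = (K / ratio) * ratio;
  std::vector<float> partial(ratio, 0.0f);
  for (size_t i = 0; i < mainEnd; i += ratio)
    for (size_t r = 0; r < ratio; ++r) partial[r] += values[i + r];
  float total = 0.0f;
  for (float p : partial) total += p;
  for (size_t i = mainEnd; i < K; ++i) total += values[i];  // peeled tail
  return total;
}
```

Like any split reduction, this reassociates the floating-point sum, so on real data the result can differ in the last bits from a sequential reduction.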


3 participants