[CPU][DT] Improve vector lowering for mmt4d #15621

Merged (2 commits) on Nov 30, 2023

Conversation

dcaballe (Contributor)

This PR removes outdated vector lowering patterns for mmt4d and reuses the ones we already have for the generic cases. This is needed to improve mmt4d code generation for the vecmat/matvec cases.
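
For context, here is a minimal, hypothetical sketch (not taken from this PR; the function name, the 8x8x1 tile shape, and all value names are illustrative assumptions) of the kind of inner-tile contraction that vectorizing linalg.mmt4d produces and that the generic vector lowering patterns are meant to handle:

    // Hypothetical example only: an mmt4d-style inner-tile contraction,
    // C[m, n] += sum_k A[m, k] * B[n, k] (note the transposed-B indexing).
    // Dimensions (d0, d1, d2) stand for (m, n, k).
    #mapA = affine_map<(d0, d1, d2) -> (d0, d2)>
    #mapB = affine_map<(d0, d1, d2) -> (d1, d2)>
    #mapC = affine_map<(d0, d1, d2) -> (d0, d1)>
    func.func @mmt4d_inner_tile(%a: vector<8x1xf32>, %b: vector<8x1xf32>, %c: vector<8x8xf32>) -> vector<8x8xf32> {
      %0 = vector.contract {indexing_maps = [#mapA, #mapB, #mapC],
                            iterator_types = ["parallel", "parallel", "reduction"],
                            kind = #vector.kind<add>}
             %a, %b, %c : vector<8x1xf32>, vector<8x1xf32> into vector<8x8xf32>
      return %0 : vector<8x8xf32>
    }

With the change, a contraction of this shape would go through the shared vector.contract lowering (outer products, and optionally fma) rather than the mmt4d-specific patterns being removed.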

@dcaballe dcaballe requested a review from bjacob November 16, 2023 20:07
@hanhanW hanhanW added and removed the benchmarks:x86_64 and benchmarks:android-cpu labels Nov 16, 2023
hanhanW (Contributor) commented Nov 16, 2023

I just realized that the benchmark results are not meaningful because they are all with ukernels. Removing the tags... (I will review the PR later.)

hanhanW (Contributor) commented Nov 16, 2023

I was wrong: the GEMM codegen is using this expert, so it does impact Pixel performance. The regressions are fine once we flip data-tiling on. A follow-up could be to move non-mmt4d ops off this expert, e.g., ARM GEMM codegen with SVE features.

https://github.com/openxla/iree/blob/acdfeae8973f1c7eda6fd1b0a1e8fe464246e850/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp#L970-L1003

dcaballe (Contributor, Author)

I didn't expect any impact either... Let me take a look at the codegen ones. We need this to enable DT by default. It's really weird that we are seeing 50% improvements for experimental flags... This should have no impact there...

hanhanW (Contributor) commented Nov 16, 2023

It's really weird that we are seeing 50% improvements for experimental flags... This should have no impact there...

They could be noise. I suggest looking at the models whose total dispatch sizes regressed.

pzread (Contributor) commented Nov 16, 2023

There is a big spike on MobileBertSquad_fp32 (https://perf.iree.dev/serie?IREE?04f958179d9bc04eca09f2ad518a3cb494931445f23fcc2791b2d9fcee5cf1bc), so I think it is noise.

It seems that something happened on mobile phones for the baseline run. I triggered the benchmark again to pick up the new baseline.

@pzread pzread added the benchmarks:android-cpu label Nov 16, 2023
pzread (Contributor) commented Nov 29, 2023

One of the big differences I found is that with this change we are creating vector.fma and many memref.load ops for the matmul in MobileBertSquad_fp32 (aarch64):

          %13 = vector.load %subview_1[%c0, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %14 = vector.load %subview_1[%c1, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %15 = vector.load %subview_1[%c2, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %16 = vector.load %subview_1[%c3, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %17 = vector.load %subview_1[%c4, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %18 = vector.load %subview_1[%c5, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %19 = vector.load %subview_1[%c6, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %20 = vector.load %subview_1[%c7, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %21 = vector.load %subview_1[%c8, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %22 = vector.load %subview_1[%c9, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %23 = vector.load %subview_1[%c10, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %24 = vector.load %subview_1[%c11, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %25 = vector.load %subview_1[%c12, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %26 = vector.load %subview_1[%c13, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %27 = vector.load %subview_1[%c14, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %28 = vector.load %subview_1[%c15, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %29 = vector.load %subview_1[%c16, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %30 = vector.load %subview_1[%c17, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %31 = vector.load %subview_1[%c18, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %32 = vector.load %subview_1[%c19, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %33 = vector.load %subview_1[%c20, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %34 = vector.load %subview_1[%c21, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %35 = vector.load %subview_1[%c22, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %36 = vector.load %subview_1[%c23, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %37 = vector.load %subview_1[%c24, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %38 = vector.load %subview_1[%c25, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %39 = vector.load %subview_1[%c26, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %40 = vector.load %subview_1[%c27, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %41 = vector.load %subview_1[%c28, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %42 = vector.load %subview_1[%c29, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %43 = vector.load %subview_1[%c30, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %44 = vector.load %subview_1[%c31, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
          %45 = memref.load %subview[%arg2, %c0] : memref<64x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
          %46 = vector.broadcast %45 : f32 to vector<4xf32>
          %47 = vector.fma %46, %13, %cst : vector<4xf32>
          %48 = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%arg2]
          %49 = memref.load %subview[%48, %c0] : memref<64x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
          %50 = vector.broadcast %49 : f32 to vector<4xf32>
          %51 = vector.fma %50, %13, %cst : vector<4xf32>
          %52 = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%arg2]
          %53 = memref.load %subview[%52, %c0] : memref<64x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
          %54 = vector.broadcast %53 : f32 to vector<4xf32>
          %55 = vector.fma %54, %13, %cst : vector<4xf32>
          %56 = affine.apply affine_map<()[s0] -> (s0 + 3)>()[%arg2]
          %57 = memref.load %subview[%56, %c0] : memref<64x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
          %58 = vector.broadcast %57 : f32 to vector<4xf32>
          %59 = vector.fma %58, %13, %cst : vector<4xf32>

Before, we were loading a 4x4 vector at once and using vector.extract to get 4xf32 vectors for vector.outerproduct:

          %13 = vector.transfer_read %subview[%arg2, %c0], %cst_0 {in_bounds = [true, true]} : memref<64x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16x32xf32>
          %14 = vector.transfer_read %subview_2[%c0, %arg3], %cst_0 {in_bounds = [true, true]} : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<32x4xf32>
          %15 = vector.extract_strided_slice %13 {offsets = [0, 0], sizes = [4, 4], strides = [1, 1]} : vector<16x32xf32> to vector<4x4xf32>
          %16 = vector.extract_strided_slice %cst {offsets = [0, 0], sizes = [4, 4], strides = [1, 1]} : vector<16x4xf32> to vector<4x4xf32>
          %17 = vector.transpose %15, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
          %18 = vector.extract %17[0] : vector<4xf32> from vector<4x4xf32>
          %19 = vector.extract %14[0] : vector<4xf32> from vector<32x4xf32>
          %20 = vector.outerproduct %18, %19, %16 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
          %21 = vector.extract %17[1] : vector<4xf32> from vector<4x4xf32>
          %22 = vector.extract %14[1] : vector<4xf32> from vector<32x4xf32>
          %23 = vector.outerproduct %21, %22, %20 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
          %24 = vector.extract %17[2] : vector<4xf32> from vector<4x4xf32>
          %25 = vector.extract %14[2] : vector<4xf32> from vector<32x4xf32>
          %26 = vector.outerproduct %24, %25, %23 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
          %27 = vector.extract %17[3] : vector<4xf32> from vector<4x4xf32>
          %28 = vector.extract %14[3] : vector<4xf32> from vector<32x4xf32>
          %29 = vector.outerproduct %27, %28, %26 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>

dcaballe (Contributor, Author)

Interesting findings, thank you so much! Hmm... the code is kind of what I expected, though:

  • The fma vs. outerproduct difference is just due to this flag: it only enables the lowering of the outer product into fma, which shouldn't have a performance impact (a minimal sketch of that lowering is included after this comment).
  • There is also the unrolling of the second dimension, which now happens at the new vector lowering stage (also expected).
  • The scalar memref loads are also expected: if you looked at the IR of the second case after lowering the outer product and applied the same unrolling, you would see vector loads whose elements are extracted one by one and then broadcast, compared to which this should be better.

I'm more concerned about the 32 (!) vector loads. I think it's a bit too much to unroll M by 32 for an f32 mmt4d. My guess is that this aggressive unrolling plus the changes above is blowing up performance in general. Could we try less aggressive tile sizes? For Neon and f32, let's say (MxNxK) 8x8x1 or 4x16x1? You may have to do something similar to #15421.
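
To make the first bullet concrete, here is a hedged, minimal sketch (hypothetical function and value names, not IR from this model) of what the outerproduct-to-fma lowering does to a single 4x4 update:

    // Hypothetical example: one 4x4 outer-product update before the lowering.
    func.func @outerproduct_step(%lhs: vector<4xf32>, %rhs: vector<4xf32>, %acc: vector<4x4xf32>) -> vector<4x4xf32> {
      %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
      return %0 : vector<4x4xf32>
    }
    // With the outerproduct -> fma lowering enabled, each lane of %lhs is
    // extracted, broadcast, and fused with %rhs and the matching accumulator
    // row, roughly:
    //   %l0 = vector.extract %lhs[0] : f32 from vector<4xf32>
    //   %b0 = vector.broadcast %l0 : f32 to vector<4xf32>
    //   %a0 = vector.extract %acc[0] : vector<4xf32> from vector<4x4xf32>
    //   %r0 = vector.fma %b0, %rhs, %a0 : vector<4xf32>
    //   ... repeated for lanes 1..3, with the rows re-inserted into the result.

This is the same fma-plus-broadcast shape visible in the new IR above; the concern is less the fma itself than how many of these steps the current unrolling emits.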

@pzread pzread force-pushed the mmt4d-vector-lowering branch from 6bad84b to 9581296 on November 30, 2023 20:34
@pzread pzread added the benchmarks:x86_64 label Nov 30, 2023
pzread (Contributor) commented Nov 30, 2023

So the new benchmark results confirmed that the regression is due to non-optimal matmul reduction tile sizes in the legacy, non-DT ARM64 tile-size selection: https://github.com/openxla/iree/blob/7fdc31a3c41365f0baf4ce5da37a9c1e03ad5e3f/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp#L954

I believe this path is not a priority now, as we just enabled DT by default; that is also why the regression disappeared after the rebase. So I think we can merge this. Maybe create a bug to clean up the legacy ARM64 tile-size selection?

hanhanW (Contributor) commented Nov 30, 2023

I think we can merge it. It can help us get out of the suboptimal codegen issue.

dcaballe (Contributor, Author)

Thanks a lot for the investigation! Great, let's do that then. Let me open an issue for that.

@dcaballe dcaballe merged commit c85865d into iree-org:main Nov 30, 2023
59 checks passed
@dcaballe dcaballe deleted the mmt4d-vector-lowering branch November 30, 2023 22:24
hanhanW added a commit to hanhanW/iree that referenced this pull request Dec 8, 2023
hanhanW added a commit to hanhanW/iree that referenced this pull request Dec 8, 2023
ramiro050 pushed a commit to ramiro050/iree that referenced this pull request Dec 19, 2023