[CPU][DT] Improve vector lowering for mmt4d #15621
Conversation
I just realized that the benchmark results are not meaningful because they all use ukernels. Removing the tags. (I will review the PR later.)
I was wrong; the GEMM codegen is using the expert, so it does impact Pixel performance. The regressions are fine once we flip data-tiling on. A follow-up can be to move non-mmt4d ops off this expert, e.g., ARM GEMM codegen with SVE features.
I didn't expect any impact either... Let me take a look at the codegen ones. We need this to enable DT by default. It's really weird that we are seeing 50% improvements for experimental flags... This should have no impact there...
They could be noise. I suggest looking at models that have regressed total dispatch sizes.
There is a big spike on MobileBertSquad_fp32 (https://perf.iree.dev/serie?IREE?04f958179d9bc04eca09f2ad518a3cb494931445f23fcc2791b2d9fcee5cf1bc), so I think it is noise. It seems that something happened on the mobile phones for the baseline run. I triggered the benchmark again to pick up the new baseline.
One of the big differences I found is that with this change we are creating:
%13 = vector.load %subview_1[%c0, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%14 = vector.load %subview_1[%c1, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%15 = vector.load %subview_1[%c2, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%16 = vector.load %subview_1[%c3, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%17 = vector.load %subview_1[%c4, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%18 = vector.load %subview_1[%c5, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%19 = vector.load %subview_1[%c6, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%20 = vector.load %subview_1[%c7, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%21 = vector.load %subview_1[%c8, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%22 = vector.load %subview_1[%c9, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%23 = vector.load %subview_1[%c10, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%24 = vector.load %subview_1[%c11, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%25 = vector.load %subview_1[%c12, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%26 = vector.load %subview_1[%c13, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%27 = vector.load %subview_1[%c14, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%28 = vector.load %subview_1[%c15, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%29 = vector.load %subview_1[%c16, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%30 = vector.load %subview_1[%c17, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%31 = vector.load %subview_1[%c18, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%32 = vector.load %subview_1[%c19, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%33 = vector.load %subview_1[%c20, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%34 = vector.load %subview_1[%c21, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%35 = vector.load %subview_1[%c22, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%36 = vector.load %subview_1[%c23, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%37 = vector.load %subview_1[%c24, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%38 = vector.load %subview_1[%c25, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%39 = vector.load %subview_1[%c26, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%40 = vector.load %subview_1[%c27, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%41 = vector.load %subview_1[%c28, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%42 = vector.load %subview_1[%c29, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%43 = vector.load %subview_1[%c30, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%44 = vector.load %subview_1[%c31, %arg3] : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
%45 = memref.load %subview[%arg2, %c0] : memref<64x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%46 = vector.broadcast %45 : f32 to vector<4xf32>
%47 = vector.fma %46, %13, %cst : vector<4xf32>
%48 = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%arg2]
%49 = memref.load %subview[%48, %c0] : memref<64x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%50 = vector.broadcast %49 : f32 to vector<4xf32>
%51 = vector.fma %50, %13, %cst : vector<4xf32>
%52 = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%arg2]
%53 = memref.load %subview[%52, %c0] : memref<64x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%54 = vector.broadcast %53 : f32 to vector<4xf32>
%55 = vector.fma %54, %13, %cst : vector<4xf32>
%56 = affine.apply affine_map<()[s0] -> (s0 + 3)>()[%arg2]
%57 = memref.load %subview[%56, %c0] : memref<64x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%58 = vector.broadcast %57 : f32 to vector<4xf32>
%59 = vector.fma %58, %13, %cst : vector<4xf32>
Before, we were loading a 4x4 vector at once and using:
%13 = vector.transfer_read %subview[%arg2, %c0], %cst_0 {in_bounds = [true, true]} : memref<64x32xf32, strided<[32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<16x32xf32>
%14 = vector.transfer_read %subview_2[%c0, %arg3], %cst_0 {in_bounds = [true, true]} : memref<32x64xf32, strided<[384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<32x4xf32>
%15 = vector.extract_strided_slice %13 {offsets = [0, 0], sizes = [4, 4], strides = [1, 1]} : vector<16x32xf32> to vector<4x4xf32>
%16 = vector.extract_strided_slice %cst {offsets = [0, 0], sizes = [4, 4], strides = [1, 1]} : vector<16x4xf32> to vector<4x4xf32>
%17 = vector.transpose %15, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
%18 = vector.extract %17[0] : vector<4xf32> from vector<4x4xf32>
%19 = vector.extract %14[0] : vector<4xf32> from vector<32x4xf32>
%20 = vector.outerproduct %18, %19, %16 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
%21 = vector.extract %17[1] : vector<4xf32> from vector<4x4xf32>
%22 = vector.extract %14[1] : vector<4xf32> from vector<32x4xf32>
%23 = vector.outerproduct %21, %22, %20 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
%24 = vector.extract %17[2] : vector<4xf32> from vector<4x4xf32>
%25 = vector.extract %14[2] : vector<4xf32> from vector<32x4xf32>
%26 = vector.outerproduct %24, %25, %23 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
%27 = vector.extract %17[3] : vector<4xf32> from vector<4x4xf32>
%28 = vector.extract %14[3] : vector<4xf32> from vector<32x4xf32>
%29 = vector.outerproduct %27, %28, %26 {kind = #vector.kind<add>} : vector<4xf32>, vector<4xf32>
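For context, here is a rough scalar-C sketch of what the two IR sequences above compute, with the tile simplified to 4x4 for illustration; this is not the generated code, just the two accumulation orders:

/* New pattern: for each row of the accumulator, broadcast one scalar of A and
 * FMA it against a 4-wide slice of B (mirrors the memref.load +
 * vector.broadcast + vector.fma sequence above). */
void row_broadcast_fma(const float A[4][4], const float B[4][4], float C[4][4]) {
  for (int i = 0; i < 4; ++i)
    for (int k = 0; k < 4; ++k)
      for (int j = 0; j < 4; ++j)   /* the j loop is the vector<4xf32> lane */
        C[i][j] += A[i][k] * B[k][j];
}

/* Old pattern: transpose the A tile, then accumulate K rank-1 updates
 * (mirrors vector.transpose + vector.outerproduct). Since At[k][i] == A[i][k],
 * both functions compute the same result; only the update order differs. */
void outer_product(const float At[4][4], const float B[4][4], float C[4][4]) {
  for (int k = 0; k < 4; ++k)       /* one vector.outerproduct per k */
    for (int i = 0; i < 4; ++i)
      for (int j = 0; j < 4; ++j)
        C[i][j] += At[k][i] * B[k][j];
}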
Interesting findings! Thank you so much! Hmm... the code is kind of expected, though.
I'm more concerned about the 32 (!) vector loads. I think it's a bit too much to unroll M by 32 for an f32 mmt4d. My guess is that this aggressive unrolling plus the changes above is blowing up performance in general. Could we try less aggressive tile sizes? For Neon and f32... let's say (MxNxK) 8x8x1 or 4x16x1? You may have to do something similar to #15421.
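As a back-of-the-envelope check on the register pressure: AArch64 NEON has 32 128-bit vector registers, so an MxN f32 accumulator tile occupies M*N/4 of them. An 8x8 (or 4x16) tile needs 16 registers, leaving room for the A broadcasts and B loads, while much larger M unrolling pushes the accumulators out to the stack. A minimal NEON sketch of an 8x8x1 f32 micro-kernel tile in that spirit (illustration only, not IREE's ukernel; the [K][M]/[K][N] packing layout is an assumption):

#include <arm_neon.h>

void mmt4d_tile_8x8x1_f32(const float *a, const float *b, float *c, int K) {
  /* 8x8 f32 accumulator: 16 of the 32 NEON vector registers. */
  float32x4_t acc[8][2];
  for (int m = 0; m < 8; ++m) {
    acc[m][0] = vdupq_n_f32(0.0f);
    acc[m][1] = vdupq_n_f32(0.0f);
  }
  for (int k = 0; k < K; ++k) {
    /* N = 8 lanes of B for this k: two 4-wide loads. */
    float32x4_t b0 = vld1q_f32(b + k * 8);
    float32x4_t b1 = vld1q_f32(b + k * 8 + 4);
    for (int m = 0; m < 8; ++m) {
      /* Broadcast one element of A and fused-multiply-accumulate, mirroring
       * the vector.broadcast + vector.fma pattern above. */
      float32x4_t am = vdupq_n_f32(a[k * 8 + m]);
      acc[m][0] = vfmaq_f32(acc[m][0], am, b0);
      acc[m][1] = vfmaq_f32(acc[m][1], am, b1);
    }
  }
  for (int m = 0; m < 8; ++m) {
    vst1q_f32(c + m * 8, acc[m][0]);
    vst1q_f32(c + m * 8 + 4, acc[m][1]);
  }
}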
This PR removes outdated vector lowering patterns for mmt4d and reuses the ones we have for the generic cases. This is needed to improve mmt4d code generation for vecmat/matvec cases.
Force-pushed from 6bad84b to 9581296.
So the new benchmark results confirmed that the regression is due to non-optimal matmul reduction tile sizes in the legacy non-DT ARM64 tile size selection: https://github.com/openxla/iree/blob/7fdc31a3c41365f0baf4ce5da37a9c1e03ad5e3f/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp#L954 I believe this path is not a priority now, as we just enabled DT by default; that is also why the regression disappeared after the rebase. So I think we can merge this. Maybe create a bug to clean up the legacy ARM64 tile size selection?
I think we can merge it. It can help us get out of the suboptimal codegen issue.
Thanks a lot for the investigation! Great! Let's do that then! Let me open an issue for that.
This reverts commit c85865d.
This PR removes outdated vector lowering patterns for mmt4d and reuses the ones we have for the generic cases. This is needed to improve mmt4d code generation for vecmat/matvec cases. Co-authored-by: Jerry Wu <[email protected]>