[CPU] Performance tracking for llama2 with data tiling + ukernels #15566

Open
Max191 opened this issue Nov 13, 2023 · 6 comments
Labels
codegen/llvm LLVM code generation compiler backend

Comments

@Max191
Contributor

Max191 commented Nov 13, 2023

The work laid out in #15158 has been completed, and we are now moving forward with e2e testing of the llama2 7B model with the new changes. This issue tracks performance and the remaining e2e issues for full model testing.

A few changes have not yet landed, so the following LLVM and IREE branches are needed while I work on landing the rest:
LLVM: https://github.com/Max191/llvm-project/tree/quantized-matmul-v2-testing
IREE: https://github.com/Max191/iree/tree/quantized-matmul-v2-testing

Model file: https://storage.googleapis.com/shark_tank/dan/fp32_i4_cpu_llamas/llama2_7b_int4.mlir
iree-compile and iree-benchmark-module commands for performance testing on the llama2 7B model:

iree-compile --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-enable-ukernels=all --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-vm-bytecode-module-strip-source-map=true --iree-util-zero-fill-elided-attrs --iree-vm-target-truncate-unsupported-floats --iree-codegen-check-ir-before-llvm-conversion=false --iree-global-opt-enable-quantized-matmul-reassociation --iree-opt-data-tiling -o llama2_7b_int4_cpu.vmfb llama2_7b_int4.mlir
iree-benchmark-module --module=llama2_7b_int4_cpu.vmfb --function=second_vicuna_forward --device=local-task --input=1x1xi64 \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32  \
    --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32

The 64 1x32x1x128xf32 inputs are dynamic on the context-length dimension (1x32x?x128xf32), so that dimension's size can be increased to benchmark larger context lengths.

The main remaining task is to get ConstEval to kick in and fold away the packing of the weights.

@Max191
Contributor Author

Max191 commented Nov 14, 2023

I have updated the test branches in the above comment with the latest changes.

@Max191
Contributor Author

Max191 commented Nov 16, 2023

Update on ConstEval folding:

I have updated the branches in the above comment with some new fixes. I am now seeing ConstEval folding of the transpose and packing of the model weights, and I am getting reasonable benchmarks, but it is still slower than the V1 approach. The repro instructions are the same as described above:

  1. Check out the branches listed in the above comment
  2. Download the model linked above
  3. Run the compile and benchmark commands above

With this, we get ConstEval folding of the pack ops on the model weights. Here is the Tracy profile for the benchmark module running on a single thread:
Tracy V2

From the profile, we can see that a lot of time is spent in iree_task_dispatch_shard_execute_tile and iree_hal_cmd_dispatch_tile, presumably due to the high counts on the batch_mmt4d dispatches.

For comparison, here is the Tracy profile with the old V1 (VectorContractCustomKernels) approach:
Tracy V1

@bjacob
Contributor

bjacob commented Nov 16, 2023

Been looking at this with @Max191 this morning. Main things so far:

  1. The distribution tile sizes are too small, which explains the dispatch overhead that Max points out above. Two specifics of the ops in this model are not being accounted for: (1) these matmuls are not general matmuls, they are narrow-M (in fact they are vecmats, M=1), so the RHS matrix data is traversed only once and no cache-locality consideration needs to enter the picture here; (2) the RHS element type is particularly narrow: i4. Both (1) and (2) call for a much larger tile size than the one being selected here. As the data-tile size N0 is 32 here and the default distribution tile size is 64, the current code at https://github.com/openxla/iree/blob/ac9548469e8e46f8ad6ea29c1c7144a774ebba19/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp#L1239 computes 64/32 = 2 and so chooses 2 as the distribution tile size. It should be much larger (a small arithmetic sketch follows this list).
  2. Looking at the generated assembly for one of these batch_mmt4d, https://gist.github.com/bjacob/af9ed6ed9c704823622d0ab2c85c0438, it really is now picking up the optimized ukernel tile function, but it is unrolling it too much. As the ukernel code got fully inlined down to the architecture-optimized tile function, it was able to see all the compile-time-constant loop bounds all the way down to the inner loop, and unrolled the inner loop as a result. The general solution here would be better tuning of the loop analyses that we register around here: https://github.com/openxla/iree/blob/ac9548469e8e46f8ad6ea29c1c7144a774ebba19/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMIRPasses.cpp#L68. If that's not trivial, then we can maybe put some clang pragma in the ukernel sources to annotate some loops as not-unrollable, or we can just put `__attribute__((noinline))` on the ukernel tile function (see the second sketch after this list).
  3. The profile shows one of the most prominent ops as a matmul_transpose_b that didn't get data-tiled. We need to add a pass in GlobalOptimization, conditioned on data-tiling, that rewrites linalg.matmul_transpose_b into a transpose of the RHS feeding into a linalg.matmul.
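To make the arithmetic in item 1 concrete, here is a minimal sketch (hypothetical names and values, not the actual KernelDispatch.cpp logic) of how the current heuristic collapses to a distribution tile size of 2, and how a narrow-M-aware floor could avoid that:

```cpp
#include <algorithm>
#include <cstdint>

// Current behavior as described in item 1: divide the default distribution
// tile size by the data-tiling inner tile size N0, which gives 64 / 32 = 2.
int64_t currentDistTileSize(int64_t defaultDistTileSize, int64_t dataTileN0) {
  return defaultDistTileSize / dataTileN0;
}

// Hypothetical alternative: for narrow-M (vecmat) ops with an i4 RHS there is
// no RHS cache-reuse constraint, so enforce a floor on the number of N0 tiles
// per dispatch to amortize the per-dispatch overhead. The floor value here is
// an illustrative assumption, not a tuned constant.
int64_t narrowMDistTileSize(int64_t defaultDistTileSize, int64_t dataTileN0) {
  constexpr int64_t kMinTilesPerDispatch = 16;
  return std::max(defaultDistTileSize / dataTileN0, kMinTilesPerDispatch);
}
```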

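For item 2, here is a minimal sketch of the two source-level options mentioned there, using a made-up tile-function name that stands in for the real architecture-specific ukernel tile functions:

```cpp
#include <cstdint>

// Option B: __attribute__((noinline)) keeps the tile function out of line, so
// the caller's compile-time-constant bounds cannot be propagated into it and
// trigger full unrolling after inlining.
__attribute__((noinline))
void example_mmt4d_tile(int32_t *acc, const int8_t *lhs, const int8_t *rhs,
                        int64_t K) {
  // Option A: annotate the inner reduction loop as not-unrollable.
  #pragma clang loop unroll(disable)
  for (int64_t k = 0; k < K; ++k) {
    acc[0] += static_cast<int32_t>(lhs[k]) * static_cast<int32_t>(rhs[k]);
  }
}
```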
@MaheshRavishankar
Contributor

  2. Looking at the generated assembly for one of these batch_mmt4d, https://gist.github.com/bjacob/af9ed6ed9c704823622d0ab2c85c0438, it really is now picking up the optimized ukernel tile function, but it is unrolling it too much. As the ukernel code got fully inlined down to the architecture-optimized tile function, it was able to see all the compile-time-constant loop bounds all the way down to the inner loop, and unrolled the inner loop as a result. The general solution here would be better tuning of loop analyses that we register around here: https://github.com/openxla/iree/blob/ac9548469e8e46f8ad6ea29c1c7144a774ebba19/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMIRPasses.cpp#L68 . If that's not trivial, then we can maybe put some clang pragma in the ukernel sources to annotate some loops as not-unrollable, or we can just put `__attribute__((noinline))` on the ukernel tile function.

A fix for this would be to look at dispatches that have ukernels and turn off loop unrolling for the LLVM compilation? (A sketch of what that could look like is at the end of this comment.)

  3. The profile shows one of the most prominent ops as a matmul_transpose_b that didn't get data-tiled. We need to add a pass in globalOptimization, conditioned on data-tiling, that rewrites linalg.matmul_transpose_b into a transpose of the RHS feeding into a linalg.matmul.

I am not sure turning off that folding based on data-tiling is a good solution. It is valid for the input program to come with matmul_transpose_b (that would be the more canonical representation here), so we need to handle this on the data-tiling path. Basically, this should be handled correctly in the SetEncoding/MaterializeEncoding flow.
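A minimal sketch of the unrolling idea, assuming it would hook in near where the LLVM pass pipeline is configured (e.g. the LLVMIRPasses.cpp code linked above); `dispatchHasUKernelCalls` is a hypothetical flag the caller would have to compute, not an existing IREE API:

```cpp
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Target/TargetMachine.h"

// Sketch only: when a dispatch is known to call into ukernels, turn off
// LLVM's loop unrolling so the inlined tile loops stay rolled.
llvm::PassBuilder makePassBuilder(llvm::TargetMachine *targetMachine,
                                  bool dispatchHasUKernelCalls) {
  llvm::PipelineTuningOptions tuningOptions;
  if (dispatchHasUKernelCalls)
    tuningOptions.LoopUnrolling = false;
  return llvm::PassBuilder(targetMachine, tuningOptions);
}
```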

@hanhanW hanhanW added the codegen/llvm LLVM code generation compiler backend label Nov 16, 2023
@hanhanW
Contributor

hanhanW commented Nov 16, 2023

Been looking this morning with @Max191. Main things so far:

  1. The distribution tile sizes are too small, explaining the dispatch overhead that Max points out above. Two specifics of the ops in this model are not being accounted for: (1) These matmuls are not general matmuls, they are narrow-M (in fact they are vecmats, M=1), so the RHS matrix data is traversed only once, so no cache-locality consideration needs to enter the picture here. (2) The RHS element type is particularly narrow: i4. Both (1) and (2) call for a much larger tile size than the one being selected here - as the data-tile size N0 is 32 here, and the default distribution tile size is 64, the current code at https://github.com/openxla/iree/blob/ac9548469e8e46f8ad6ea29c1c7144a774ebba19/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp#L1239 is computing 64/32 = 2 and so is choosing 2 as the distribution tile size. Should be much larger.

I think these are mostly batch_mmt4d kernels, so we need to look at the snippet below. Looking at the profile, 85 ns per launch is definitely too small; we should have larger distribution tile sizes. For the batch_mmt4d op, the root cause is that we force the batch dim to be 1. We will be able to relax that after landing #15531; I can help land it.

https://github.com/openxla/iree/blob/e799ae93e6ee92eb01245d8575eea4690f0c4735/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp#L1296-L1302

  2. Looking at the generated assembly for one of these batch_mmt4d, https://gist.github.com/bjacob/af9ed6ed9c704823622d0ab2c85c0438, it really is now picking up the optimized ukernel tile function, but it is unrolling it too much. As the ukernel code got fully inlined down to the architecture-optimized tile function, it was able to see all the compile-time-constant loop bounds all the way down to the inner loop, and unrolled the inner loop as a result. The general solution here would be better tuning of loop analyses that we register around here: https://github.com/openxla/iree/blob/ac9548469e8e46f8ad6ea29c1c7144a774ebba19/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMIRPasses.cpp#L68 . If that's not trivial, then we can maybe put some clang pragma in the ukernel sources to annotate some loops as not-unrollable, or we can just put __attribute__((noinline)) on the ukernel tile function.

Loop unrolling is enabled by default; we can disable it here:

https://github.com/openxla/iree/blob/e799ae93e6ee92eb01245d8575eea4690f0c4735/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMTargetOptions.h#L41

  3. The profile shows one of the most prominent ops as a matmul_transpose_b that didn't get data-tiled. We need to add a pass in globalOptimization, conditioned on data-tiling, that rewrites linalg.matmul_transpose_b into a transpose of the RHS feeding into a linalg.matmul.

I think we should just make SetEncoding take ContractionOpInterface as the input, check that the ranks are all 2, and look at the affine maps to select inner tile sizes. We can then generalize to the non-rank-2 cases, which should cover matvec/vecmat as well.
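A rough sketch of that direction, with hypothetical helper names (this is not the actual SetEncoding code): read the contraction dimensions from the indexing maps instead of matching named ops, which covers matmul, vecmat/matvec, and the transposed-B variants uniformly.

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"

// Sketch: decide whether a linalg op is a contraction that SetEncoding could
// data-tile by inspecting its indexing maps rather than its op name.
// `isEncodableContraction` is a hypothetical helper, not existing IREE code.
static bool isEncodableContraction(mlir::linalg::LinalgOp op) {
  if (!mlir::linalg::isaContractionOpInterface(op))
    return false;
  mlir::FailureOr<mlir::linalg::ContractionDimensions> dims =
      mlir::linalg::inferContractionDims(op);
  if (mlir::failed(dims))
    return false;
  // Start with the plain rank-2 cases: no batch dims, at most one M and N dim,
  // and exactly one K dim. vecmat/matvec show up with an empty m or n list,
  // and matmul_transpose_b is covered because the transpose only changes the
  // RHS indexing map, not the contraction structure.
  return dims->batch.empty() && dims->m.size() <= 1 && dims->n.size() <= 1 &&
         dims->k.size() == 1;
}
```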

@hanhanW
Contributor

hanhanW commented Nov 16, 2023

There are dispatch_[0-9]* kernels, which are pack/unpack ops. We should be able to see this high-level information once we fix #15027.
