From 4636e3885743c21358123a7a1c6e0fbd13c40707 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar
Date: Thu, 6 Feb 2025 09:53:28 -0800
Subject: [PATCH] [LLVMGPU] Add fixes and tests for horizontally fused gemms
 through GPU pipeline.

Signed-off-by: MaheshRavishankar
---
 .../Codegen/Common/VectorLayoutAnalysis.cpp   |  10 +-
 .../iree/compiler/Codegen/LLVMGPU/Passes.cpp  |   1 +
 .../compiler/Codegen/LLVMGPU/test/BUILD.bazel |   1 +
 .../Codegen/LLVMGPU/test/CMakeLists.txt       |   1 +
 .../test/horizontal_fusion_pipeline.mlir      | 239 ++++++++++++++++++
 5 files changed, 249 insertions(+), 3 deletions(-)
 create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/test/horizontal_fusion_pipeline.mlir

diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
index a15d397c8d2c6..93d453f7d8d3e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
@@ -203,14 +203,18 @@ ChangeResult DistributionLayout::resolveWithPossibleConflict(
   IRRewriter builder(opOperand.getOwner());
   // Handle case where constantOp may have multiple consumers with different
   // layouts by creating a copy of constOp for other users.
-  if (!opOperand.get().hasOneUse() && !vectorLayout &&
+  if (!opOperand.get().hasOneUse() &&
       llvm::isa_and_nonnull<arith::ConstantOp>(
           opOperand.get().getDefiningOp())) {
     builder.setInsertionPoint(opOperand.get().getDefiningOp());
     Operation *copiedConstOp = builder.clone(*opOperand.get().getDefiningOp());
     Value copiedConst = copiedConstOp->getResult(0);
-    builder.replaceAllUsesExcept(opOperand.get(), copiedConst,
-                                 opOperand.getOwner());
+    DistributionLayout *newConstLayout =
+        propagation->getLatticeElement(copiedConst);
+    newConstLayout->subscribeEnforcement(enforcement);
+    (void)newConstLayout->resolve(rhs);
+    opOperand.set(copiedConst);
+    return ChangeResult::NoChange;
   }
 
   ResolutionResult result = doResolution(rhs);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 545a29fd4e45d..27256231c1f86 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -838,6 +838,7 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
   // Set anchors at tensor level for vector distribution later and hoist out
   // loop invariant anchors.
+ funcPassManager.addPass(createDecomposeHorizontallyFusedGemmsPass()); funcPassManager.addPass(createLLVMGPUConfigureTensorLayoutsPass()); funcPassManager.addPass(createIREELoopInvariantCodeMotionPass()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel index 909c4ce15c18b..9c31db4f9c43b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel @@ -36,6 +36,7 @@ iree_lit_test_suite( "extract_address_computation_gpu.mlir", "gpu_set_num_workgroups.mlir", "gpu_pipeline_generalize_named_ops.mlir", + "horizontal_fusion_pipeline.mlir", "link_executables.mlir", "nvvm_extract_address_computation.mlir", "nvvm_pipeline_test.mlir", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt index 284c966535b55..72ae96f6b3217 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt @@ -33,6 +33,7 @@ iree_lit_test_suite( "extract_address_computation_gpu.mlir" "gpu_pipeline_generalize_named_ops.mlir" "gpu_set_num_workgroups.mlir" + "horizontal_fusion_pipeline.mlir" "illegal_configuration.mlir" "legalize.mlir" "linalg_transform.mlir" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/horizontal_fusion_pipeline.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/horizontal_fusion_pipeline.mlir new file mode 100644 index 0000000000000..f7992de302f3d --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/horizontal_fusion_pipeline.mlir @@ -0,0 +1,239 @@ +// RUN: iree-opt --iree-gpu-test-target=gfx942 --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target))" %s --split-input-file | FileCheck %s + +func.func @fused_contraction_1(%arg0: tensor<2x4096x640xf16>, + %arg1 : tensor<10x64x640xf16>, %arg2 : tensor<10x64x640xf16>, + %arg3 : tensor<10x64x640xf16>) + -> (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>) { + %11 = tensor.empty() : tensor<2x10x4096x64xf16> + %12 = tensor.empty() : tensor<2x10x4096x64xf32> + %cst = arith.constant 0.0: f32 + %13 = linalg.fill ins(%cst : f32) + outs(%12 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32> + %14:3 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} + ins(%arg0, %arg1, %arg2, %arg3 + : tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>, + tensor<10x64x640xf16>) + outs(%13, %13, %13 + : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>) { + ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32): + %18 = arith.extf %in : f16 to f32 + %19 = arith.extf %in_0 : f16 to f32 + %20 = arith.mulf %18, %19 : f32 + %21 = arith.addf %out, %20 : f32 + %22 = arith.extf %in_1 : f16 to f32 + %23 = arith.mulf %18, %22 : f32 + %24 = arith.addf %out_3, %23 : f32 + %25 = arith.extf %in_2 : f16 to f32 + %26 = arith.mulf 
%18, %25 : f32 + %27 = arith.addf %out_4, %26 : f32 + linalg.yield %21, %24, %27 : f32, f32, f32 + } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>) + %15 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%14#0 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) { + ^bb0(%in: f32, %out: f16): + %18 = arith.truncf %in : f32 to f16 + linalg.yield %18 : f16 + } -> tensor<2x10x4096x64xf16> + %16 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%14#1 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) { + ^bb0(%in: f32, %out: f16): + %18 = arith.truncf %in : f32 to f16 + linalg.yield %18 : f16 + } -> tensor<2x10x4096x64xf16> + %17 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%14#2 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) { + ^bb0(%in: f32, %out: f16): + %18 = arith.truncf %in : f32 to f16 + linalg.yield %18 : f16 + } -> tensor<2x10x4096x64xf16> + return %15, %16, %17 + : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16> +} +// CHECK-LABEL: func @fused_contraction_1 +// CHECK-COUNT-24: amdgpu.mfma + +// ----- + +func.func @fused_contraction_2(%arg0: tensor<4096x640xf32>, + %arg1 : tensor<640x640xf32>, %arg2 : tensor<640x640xf32>, + %arg3 : tensor<640x640xf32>) + -> (tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>) { + %11 = tensor.empty() : tensor<4096x640xf32> + %12 = tensor.empty() : tensor<4096x640xf32> + %cst = arith.constant 0.0: f32 + %13 = linalg.fill ins(%cst : f32) + outs(%12 : tensor<4096x640xf32>) -> tensor<4096x640xf32> + %14:3 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0, %arg1, %arg2, %arg3 + : tensor<4096x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>, + tensor<640x640xf32>) + outs(%13, %13, %13 + : tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>) { + ^bb0(%in: f32, %in_0: f32, %in_1: f32, %in_2: f32, %out: f32, %out_3: f32, %out_4: f32): + %20 = arith.mulf %in, %in_0 : f32 + %21 = arith.addf %out, %20 : f32 + %23 = arith.mulf %in, %in_1 : f32 + %24 = arith.addf %out_3, %23 : f32 + %26 = arith.mulf %in, %in_2 : f32 + %27 = arith.addf %out_4, %26 : f32 + linalg.yield %21, %24, %27 : f32, f32, f32 + } -> (tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>) + return %14#0, %14#1, %14#2 + : tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32> +} +// CHECK-LABEL: func @fused_contraction_2 +// CHECK-COUNT-24: amdgpu.mfma + +// ----- + +func.func @fused_contraction_3(%arg0 : tensor<2x4096x640xi8>, + %arg1 : tensor<2x640x640xi8>, %arg2 : tensor<2x640x640xi8>) + -> (tensor<2x4096x640xf16>, tensor<2x4096x640xf16>) { + %c0_i32 = arith.constant 0 : i32 + %18 = tensor.empty() : 
tensor<2x4096x640xf16> + %19 = tensor.empty() : tensor<2x4096x640xi32> + %20 = linalg.fill ins(%c0_i32 : i32) + outs(%19 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32> + %21:2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel", "reduction"]} + ins(%arg0, %arg1, %arg2 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>, tensor<2x640x640xi8>) + outs(%20, %20 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) { + ^bb0(%in: i8, %in_0: i8, %in_1: i8, %out: i32, %out_2: i32): + %24 = arith.extsi %in : i8 to i32 + %25 = arith.extsi %in_0 : i8 to i32 + %26 = arith.muli %24, %25 : i32 + %27 = arith.addi %out, %26 : i32 + %28 = arith.extsi %in_1 : i8 to i32 + %29 = arith.muli %24, %28 : i32 + %30 = arith.addi %out_2, %29 : i32 + linalg.yield %27, %30 : i32, i32 + } -> (tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) + %22 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%21#0 : tensor<2x4096x640xi32>) outs(%18 : tensor<2x4096x640xf16>) { + ^bb0(%in: i32, %out: f16): + %27 = arith.sitofp %in : i32 to f32 + %29 = arith.truncf %27 : f32 to f16 + linalg.yield %29 : f16 + } -> tensor<2x4096x640xf16> + %23 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%21#1 : tensor<2x4096x640xi32>) outs(%18 : tensor<2x4096x640xf16>) { + ^bb0(%in: i32, %out: f16): + %27 = arith.sitofp %in : i32 to f32 + %29 = arith.truncf %27 : f32 to f16 + linalg.yield %29 : f16 + } -> tensor<2x4096x640xf16> + return %22, %23 : tensor<2x4096x640xf16>, tensor<2x4096x640xf16> +} +// CHECK-LABEL: func @fused_contraction_3 +// CHECK-COUNT-24: amdgpu.mfma + +// ----- + +func.func @fused_contraction_4(%arg0: tensor<2x4096x640xf16>, + %arg1 : tensor<10x64x640xf16>, %arg2 : tensor<10x64x640xf16>, + %arg3 : tensor<10x64x640xf16>) + -> (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>) { + %9 = tensor.empty() : tensor<2x10x64x4096xf16> + %10 = tensor.empty() : tensor<2x10x64x4096xf32> + %11 = tensor.empty() : tensor<2x10x4096x64xf16> + %12 = tensor.empty() : tensor<2x10x4096x64xf32> + %cst = arith.constant 0.0: f32 + %fill0 = linalg.fill ins(%cst : f32) + outs(%12 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32> + %fill1 = linalg.fill ins(%cst : f32) + outs(%10 : tensor<2x10x64x4096xf32>) -> tensor<2x10x64x4096xf32> + %14:3 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} + ins(%arg0, %arg1, %arg2, %arg3 + : tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>, + tensor<10x64x640xf16>) + outs(%fill0, %fill0, %fill1 + : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x64x4096xf32>) { + ^bb0(%in: 
f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32): + %18 = arith.extf %in : f16 to f32 + %19 = arith.extf %in_0 : f16 to f32 + %20 = arith.mulf %18, %19 : f32 + %21 = arith.addf %out, %20 : f32 + %22 = arith.extf %in_1 : f16 to f32 + %23 = arith.mulf %18, %22 : f32 + %24 = arith.addf %out_3, %23 : f32 + %25 = arith.extf %in_2 : f16 to f32 + %26 = arith.mulf %18, %25 : f32 + %27 = arith.addf %out_4, %26 : f32 + linalg.yield %21, %24, %27 : f32, f32, f32 + } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x64x4096xf32>) + %15 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%14#0 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) { + ^bb0(%in: f32, %out: f16): + %18 = arith.truncf %in : f32 to f16 + linalg.yield %18 : f16 + } -> tensor<2x10x4096x64xf16> + %16 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%14#1 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) { + ^bb0(%in: f32, %out: f16): + %18 = arith.truncf %in : f32 to f16 + linalg.yield %18 : f16 + } -> tensor<2x10x4096x64xf16> + %17 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%14#2 : tensor<2x10x64x4096xf32>) outs(%9 : tensor<2x10x64x4096xf16>) { + ^bb0(%in: f32, %out: f16): + %18 = arith.truncf %in : f32 to f16 + linalg.yield %18 : f16 + } -> tensor<2x10x64x4096xf16> + return %15, %16, %17 + : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16> +} +// CHECK-LABEL: func @fused_contraction_4 +// CHECK-COUNT-24: amdgpu.mfma