Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][TOSA] Fix f16/bf16 support for MaxPool2D #69332

Merged
merged 1 commit into from
Oct 18, 2023

Conversation

dchauhan-arm
Copy link
Contributor

Currently, the MaxPool2D operation in the TOSA MLIR dialect does not accept half-precision Fp16 and Bf16 tensors, converse to what is stated in the TOSA Specification.

This patch fixes the verifier to accept the two datatypes for input/output tensors, and adds related LIT test cases in Tosa/ops.mlir

Currently, the MaxPool2D operation in the TOSA MLIR dialect does not
accept half-precision Fp16 and Bf16 tensors, converse to what is stated
in the [TOSA Specification](https://www.mlplatform.org/tosa/tosa_spec.html#_max_pool2d).

This patch fixes the verifier to accept the two datatypes for
input/output tensors, and adds related LIT test cases in Tosa/ops.mlir
@llvmbot
Copy link
Member

llvmbot commented Oct 17, 2023

@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Dhruv Chauhan (dchauhan-arm)

Changes

Currently, the MaxPool2D operation in the TOSA MLIR dialect does not accept half-precision Fp16 and Bf16 tensors, converse to what is stated in the TOSA Specification.

This patch fixes the verifier to accept the two datatypes for input/output tensors, and adds related LIT test cases in Tosa/ops.mlir


Full diff: https://github.com/llvm/llvm-project/pull/69332.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+1-1)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+16-2)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 4214bb57563285c..ee8f52deadbd152 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -691,7 +691,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 
     // Determine what the initial value needs to be for the max pool op.
     TypedAttr initialAttr;
-    if (resultETy.isF32())
+    if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
       initialAttr = rewriter.getFloatAttr(
           resultETy, APFloat::getLargest(
                          cast<FloatType>(resultETy).getFloatSemantics(), true));
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e62bea515d06baa..8ce8fb73f29a504 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -97,12 +97,26 @@ func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -
 }
 
 // -----
-// CHECK-LABEL: max_pool2d
-func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+// CHECK-LABEL: max_pool2d_f32
+func.func @test_max_pool2d_f32(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
   %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
   return %0 : tensor<1x32x32x8xf32>
 }
 
+// -----
+// CHECK-LABEL: max_pool2d_bf16
+func.func @test_max_pool2d_bf16(%arg0: tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16> {
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16>
+  return %0 : tensor<1x32x32x8xbf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_f16
+func.func @test_max_pool2d_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16> {
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16>
+  return %0 : tensor<1x32x32x8xf16>
+}
+
 // -----
 // CHECK-LABEL: rfft2d
 func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {

Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@GeorgeARM GeorgeARM merged commit c926291 into llvm:main Oct 18, 2023
ljfitz pushed a commit to Xilinx/llvm-project that referenced this pull request Feb 22, 2024
Currently, the MaxPool2D operation in the TOSA MLIR dialect does not
accept half-precision Fp16 and Bf16 tensors, converse to what is stated
in the [TOSA
Specification](https://www.mlplatform.org/tosa/tosa_spec.html#_max_pool2d).

This patch fixes the verifier to accept the two datatypes for
input/output tensors, and adds related LIT test cases in Tosa/ops.mlir
ttjost added a commit to Xilinx/llvm-project that referenced this pull request Feb 22, 2024
[MLIR][TOSA] Fix f16/bf16 support for MaxPool2D (llvm#69332)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants