[MLIR][TOSA] Fix f16/bf16 support for MaxPool2D #69332
Currently, the MaxPool2D operation in the TOSA MLIR dialect does not accept half-precision f16 and bf16 tensors, contrary to what is stated in the [TOSA Specification](https://www.mlplatform.org/tosa/tosa_spec.html#_max_pool2d). This patch fixes the verifier to accept the two data types for input/output tensors, and adds related LIT test cases in Tosa/ops.mlir.
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir-linalg
Author: Dhruv Chauhan (dchauhan-arm)
Full diff: https://github.com/llvm/llvm-project/pull/69332.diff
2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 4214bb57563285c..ee8f52deadbd152 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -691,7 +691,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
// Determine what the initial value needs to be for the max pool op.
TypedAttr initialAttr;
- if (resultETy.isF32())
+ if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
initialAttr = rewriter.getFloatAttr(
resultETy, APFloat::getLargest(
cast<FloatType>(resultETy).getFloatSemantics(), true));
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e62bea515d06baa..8ce8fb73f29a504 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -97,12 +97,26 @@ func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -
}
// -----
-// CHECK-LABEL: max_pool2d
-func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+// CHECK-LABEL: max_pool2d_f32
+func.func @test_max_pool2d_f32(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
%0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}
+// -----
+// CHECK-LABEL: max_pool2d_bf16
+func.func @test_max_pool2d_bf16(%arg0: tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16> {
+ %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16>
+ return %0 : tensor<1x32x32x8xbf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_f16
+func.func @test_max_pool2d_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16> {
+ %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16>
+ return %0 : tensor<1x32x32x8xf16>
+}
+
// -----
// CHECK-LABEL: rfft2d
func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
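For readers following the TosaToLinalgNamed.cpp hunk: the max-pool accumulator is seeded with `APFloat::getLargest(..., true)`, i.e. the most negative finite value of the result element type, so any real input element replaces it. The sketch below is only an illustration of what that seed works out to for f16, bf16 and f32. It does not use LLVM; `largestFinite`, `maxPoolInitF16`, `maxPoolInitBF16`, `maxPoolInitF32` and `checkF32` are hypothetical names introduced here, and the exponent/mantissa widths are the standard ones for each format.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical helper (not an LLVM API): largest finite magnitude of an
// IEEE-style binary format with the given exponent and mantissa widths.
double largestFinite(int exponentBits, int mantissaBits) {
  // The all-ones exponent code is reserved for inf/NaN, so the largest
  // usable unbiased exponent equals the bias.
  int maxExponent = (1 << (exponentBits - 1)) - 1;
  // All mantissa bits set: significand is 2 - 2^-mantissaBits.
  double maxSignificand = 2.0 - std::ldexp(1.0, -mantissaBits);
  return std::ldexp(maxSignificand, maxExponent);
}

// Seeds for a running maximum: the most negative finite value per type.
double maxPoolInitF16() { return -largestFinite(5, 10); }  // f16: 5 exp, 10 mantissa bits
double maxPoolInitBF16() { return -largestFinite(8, 7); }  // bf16: 8 exp, 7 mantissa bits
double maxPoolInitF32() { return -largestFinite(8, 23); }  // f32: 8 exp, 23 mantissa bits

// Sanity check against the one format the C++ standard library exposes directly.
void checkF32() {
  assert(largestFinite(8, 23) == std::numeric_limits<float>::max());
}
```

Under these assumptions the f16 seed is -65504 and the bf16 seed is -(255 * 2^120), which is why the same `getLargest` call in the patch can serve all three float types once the `isF32()` condition is widened.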
LGTM
[MLIR][TOSA] Fix f16/bf16 support for MaxPool2D (llvm#69332)
Currently, the MaxPool2D operation in the TOSA MLIR dialect does not accept half-precision f16 and bf16 tensors, contrary to what is stated in the TOSA Specification.
This patch fixes the verifier to accept the two data types for input/output tensors, and adds related LIT test cases in Tosa/ops.mlir.