Skip to content

Commit

Permalink
[mlir] [arith] add shl overflow flag in Arith and lower to SPIR-V and…
Browse files Browse the repository at this point in the history
… LLVMIR (#79828)

There is no `SHL` used in canonicalization in `arith`

---------

Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Tobias Gysi <[email protected]>
  • Loading branch information
3 people authored Jan 30, 2024
1 parent e5054fb commit f7ef73e
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 5 deletions.
12 changes: 9 additions & 3 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ def Arith_XOrIOp : Arith_TotalIntBinaryOp<"xori", [Commutative]> {
// ShLIOp
//===----------------------------------------------------------------------===//

def Arith_ShLIOp : Arith_TotalIntBinaryOp<"shli"> {
def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
let summary = "integer left-shift";
let description = [{
The `shli` operation shifts the integer value of the first operand to the left
Expand All @@ -791,12 +791,18 @@ def Arith_ShLIOp : Arith_TotalIntBinaryOp<"shli"> {
operand is greater than the bitwidth of the first operand, then the
operation returns poison.

This op supports `nuw`/`nsw` overflow flags which stands stand for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.

Example:

```mlir
%1 = arith.constant 5 : i8 // %1 is 0b00000101
%1 = arith.constant 5 : i8 // %1 is 0b00000101
%2 = arith.constant 3 : i8
%3 = arith.shli %1, %2 : (i8, i8) -> i8 // %3 is 0b00101000
%3 = arith.shli %1, %2 : i8 // %3 is 0b00101000
%4 = arith.shli %1, %2 overflow<nsw, nuw> : i8
```
}];
let hasFolder = 1;
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ using RemUIOpLowering =
VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
using SelectOpLowering =
VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>;
using ShLIOpLowering =
VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
arith::AttrConvertOverflowToLLVM>;
using ShRSIOpLowering =
VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;
using ShRUIOpLowering =
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
XOrIOpLogicalPattern, XOrIOpBooleanPattern,
spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
Expand Down
2 changes: 2 additions & 0 deletions mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -586,5 +586,7 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
// CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
// CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}
4 changes: 4 additions & 0 deletions mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1422,6 +1422,8 @@ func.func @ops_flags(%arg0: i64, %arg1: i64) {
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
// CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap} : i64
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
// CHECK: %{{.*}} = spirv.ShiftLeftLogical %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap} : i64
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}

Expand All @@ -1443,6 +1445,8 @@ func.func @ops_flags(%arg0: i64, %arg1: i64) {
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
// CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
// CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
%3 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}

Expand Down
2 changes: 2 additions & 0 deletions mlir/test/Dialect/Arith/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1147,5 +1147,7 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) {
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
// CHECK: %{{.*}} = arith.muli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
// CHECK: %{{.*}} = arith.shli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}

0 comments on commit f7ef73e

Please sign in to comment.