Skip to content

Commit

Permalink
[mlir][spirv] Use folding in IndexToSPIRV
Browse files Browse the repository at this point in the history
Allow for constant propogation when converting from index to SPIR-V for:
ceildiv[s|u] and floordivs.

Aims to improve readability of generated SPIR-V code.

Part of work llvm#70704
  • Loading branch information
inbelic committed Mar 14, 2024
1 parent 88b1575 commit 9271be8
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 26 deletions.
58 changes: 32 additions & 26 deletions mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,25 +120,28 @@ struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
loc, n_type, IntegerAttr::get(n_type, -1));

// Compute `x`.
Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero);
Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne);
Value mPos = rewriter.createOrFold<spirv::SGreaterThanOp>(loc, m, zero);
Value x = rewriter.createOrFold<spirv::SelectOp>(loc, mPos, negOne, posOne);

// Compute the positive result.
Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x);
Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m);
Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
Value nPlusX = rewriter.createOrFold<spirv::IAddOp>(loc, n, x);
Value nPlusXDivM = rewriter.createOrFold<spirv::SDivOp>(loc, nPlusX, m);
Value posRes =
rewriter.createOrFold<spirv::IAddOp>(loc, nPlusXDivM, posOne);

// Compute the negative result.
Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n);
Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m);
Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM);
Value negN = rewriter.createOrFold<spirv::ISubOp>(loc, zero, n);
Value negNDivM = rewriter.createOrFold<spirv::SDivOp>(loc, negN, m);
Value negRes = rewriter.createOrFold<spirv::ISubOp>(loc, zero, negNDivM);

// Pick the positive result if `n` and `m` have the same sign and `n` is
// non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero);
Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos);
Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
Value nPos = rewriter.createOrFold<spirv::SGreaterThanOp>(loc, n, zero);
Value sameSign =
rewriter.createOrFold<spirv::LogicalEqualOp>(loc, nPos, mPos);
Value nNonZero = rewriter.createOrFold<spirv::INotEqualOp>(loc, n, zero);
Value cmp =
rewriter.createOrFold<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
return success();
}
Expand Down Expand Up @@ -168,12 +171,12 @@ struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
IntegerAttr::get(n_type, 1));

// Compute the non-zero result.
Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one);
Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m);
Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one);
Value minusOne = rewriter.createOrFold<spirv::ISubOp>(loc, n, one);
Value quotient = rewriter.createOrFold<spirv::UDivOp>(loc, minusOne, m);
Value plusOne = rewriter.createOrFold<spirv::IAddOp>(loc, quotient, one);

// Pick the result
Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero);
Value cmp = rewriter.createOrFold<spirv::IEqualOp>(loc, n, zero);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
return success();
}
Expand Down Expand Up @@ -206,24 +209,27 @@ struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
loc, n_type, IntegerAttr::get(n_type, -1));

// Compute `x`.
Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
Value mNeg = rewriter.createOrFold<spirv::SLessThanOp>(loc, m, zero);
Value x = rewriter.createOrFold<spirv::SelectOp>(loc, mNeg, posOne, negOne);

// Compute the negative result
Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
Value xMinusN = rewriter.createOrFold<spirv::ISubOp>(loc, x, n);
Value xMinusNDivM = rewriter.createOrFold<spirv::SDivOp>(loc, xMinusN, m);
Value negRes =
rewriter.createOrFold<spirv::ISubOp>(loc, negOne, xMinusNDivM);

// Compute the positive result.
Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
Value posRes = rewriter.createOrFold<spirv::SDivOp>(loc, n, m);

// Pick the negative result if `n` and `m` have different signs and `n` is
// non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
Value nNeg = rewriter.createOrFold<spirv::SLessThanOp>(loc, n, zero);
Value diffSign =
rewriter.createOrFold<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
Value nNonZero = rewriter.createOrFold<spirv::INotEqualOp>(loc, n, zero);

Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
Value cmp =
rewriter.createOrFold<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
return success();
}
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ func.func @ceildivs(%n: index, %m: index) -> index {
return %result : index
}

// CHECK-LABEL: @ceildivs_fold
func.func @ceildivs_fold() -> index {
%n = index.constant -42
%m = index.constant 5
%result = index.ceildivs %n, %m

// CHECK: %[[RESULT:.*]] = spirv.Constant -8
// %[[RESULTI:.*] = builtin.unrealized_conversion_cast %[[RESULT]]
// return %[[RESULTI]]
return %result : index
}

// CHECK-LABEL: @ceildivu
// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
func.func @ceildivu(%n: index, %m: index) -> index {
Expand All @@ -117,6 +129,18 @@ func.func @ceildivu(%n: index, %m: index) -> index {
return %result : index
}

// CHECK-LABEL: @ceildivu_fold
func.func @ceildivu_fold() -> index {
%n = index.constant -42
%m = index.constant 5
%result = index.ceildivu %n, %m

// CHECK: %[[RESULT:.*]] = spirv.Constant 8
// %[[RESULTI:.*] = builtin.unrealized_conversion_cast %[[RESULT]]
// return %[[RESULTI]]
return %result : index
}

// CHECK-LABEL: @floordivs
// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
func.func @floordivs(%n: index, %m: index) -> index {
Expand Down Expand Up @@ -148,6 +172,18 @@ func.func @floordivs(%n: index, %m: index) -> index {
return %result : index
}

// CHECK-LABEL: @floordivs_fold
func.func @floordivs_fold() -> index {
%n = index.constant -42
%m = index.constant 5
%result = index.floordivs %n, %m

// CHECK: %[[RESULT:.*]] = spirv.Constant -9
// %[[RESULTI:.*] = builtin.unrealized_conversion_cast %[[RESULT]]
// return %[[RESULTI]]
return %result : index
}

// CHECK-LABEL: @index_cmp
func.func @index_cmp(%a : index, %b : index) {
// CHECK: spirv.IEqual
Expand Down

0 comments on commit 9271be8

Please sign in to comment.