Skip to content

Commit

Permalink
[Arith] Allow const folding on fp16 involving one and zero (#13631)
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi authored Dec 16, 2022
1 parent 0eabbac commit cded048
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 19 deletions.
8 changes: 0 additions & 8 deletions src/arith/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,6 @@ inline Optional<PrimExpr> TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value + fb->value);
} else {
return NullOpt;
}
}
if (fa && fa->value == 0) return b;
Expand Down Expand Up @@ -171,8 +169,6 @@ inline Optional<PrimExpr> TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value - fb->value);
} else {
return NullOpt;
}
}
if (fb && fb->value == 0) return a;
Expand Down Expand Up @@ -202,8 +198,6 @@ inline Optional<PrimExpr> TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value * fb->value);
} else {
return NullOpt;
}
}
if (fa) {
Expand Down Expand Up @@ -243,8 +237,6 @@ inline Optional<PrimExpr> TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value / fb->value);
} else {
return NullOpt;
}
}
if (fa && fa->value == 0) return a;
Expand Down
29 changes: 18 additions & 11 deletions tests/python/unittest/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import te


Expand Down Expand Up @@ -124,6 +125,22 @@ def test_div_simplify():
ck.verify(fld(17 + 47 * x, 16), fld(x * 47 + 17, 16))


def test_fp16_const_fold():
ck = CanonicalChecker()
zero = tvm.tir.const(0, "float16")
one = tvm.tir.const(1, "float16")
half = tvm.tir.const(0.5, "float16")

ck.verify(zero + half, half)
ck.verify(half - zero, half)

ck.verify(zero * half, zero)
ck.verify(half * one, half)

ck.verify(half / one, half)
ck.verify(zero / half, zero)


def test_floormod_simplify():
ck = CanonicalChecker()
flm = tvm.te.floormod
Expand Down Expand Up @@ -356,14 +373,4 @@ def test_simplify_cast():


if __name__ == "__main__":
test_floormod_simplify()
test_mul_sum_simplify()
test_simplify_if_then_else()
test_div_simplify()
test_reduce_simplify()
test_reduce_combiner_simplify()

test_split_index_simplify()
test_canonical_mixed()
test_complex_cases()
test_simplify_cast()
tvm.testing.main()

0 comments on commit cded048

Please sign in to comment.