Skip to content

Commit

Permalink
[FIX][TOPI] Clip with IntImm/FloatImm
Browse files Browse the repository at this point in the history
Prior to this PR, TOPI clip op only accepts the min/max values with
Python native float/int type, and rejects FloatImm and IntImm.

This PR enhances the clip op to allow it accept FloatImm and IntImm.

Co-authored-by: Siyuan Feng <[email protected]>
  • Loading branch information
MasterJH5574 and Hzfengsy committed Feb 17, 2023
1 parent d12a636 commit 0e8bc14
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
18 changes: 14 additions & 4 deletions python/tvm/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
# pylint: disable=redefined-builtin,unused-argument
import tvm
from tvm import te
from tvm.tir import PrimExpr

from . import tag
from . import cpp
from .utils import get_const_tuple
Expand Down Expand Up @@ -620,9 +622,9 @@ def clip(x, a_min, a_max):
----------
x : tvm.te.Tensor
Input argument.
a_min : int or float
a_min : tvm.tir.PrimExpr
Minimum value.
a_max : int or float
a_max : tvm.tir.PrimExpr
Maximum value.
Returns
Expand All @@ -633,8 +635,16 @@ def clip(x, a_min, a_max):

def _compute(*indices):
value = x(*indices)
const_min = tvm.tir.const(a_min, value.dtype)
const_max = tvm.tir.const(a_max, value.dtype)
const_min = (
tvm.tir.Cast(value.dtype, a_min)
if isinstance(a_min, PrimExpr)
else tvm.tir.const(a_min, value.dtype)
)
const_max = (
tvm.tir.Cast(value.dtype, a_max)
if isinstance(a_max, PrimExpr)
else tvm.tir.const(a_max, value.dtype)
)
return tvm.te.max(tvm.te.min(value, const_max), const_min)

return te.compute(x.shape, _compute)
Expand Down
16 changes: 13 additions & 3 deletions tests/python/topi/python/test_topi_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Test code for clip operator"""
import numpy as np
import tvm
from tvm import te
from tvm import te, tir
from tvm import topi
import tvm.testing
import tvm.topi.testing
Expand All @@ -32,12 +32,14 @@ def verify_clip(N, a_min, a_max, dtype):

# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_clip")
def get_ref_data():
def get_ref_data(a_min, a_max):
a_np = np.random.uniform(a_min * 2, a_max * 2, size=(N, N)).astype(dtype)
b_np = np.clip(a_np, a_min, a_max)
return a_np, b_np

a_np, b_np = get_ref_data()
a_min = a_min.value if isinstance(a_min, (tir.FloatImm, tir.IntImm)) else a_min
a_max = a_max.value if isinstance(a_max, (tir.FloatImm, tir.IntImm)) else a_max
a_np, b_np = get_ref_data(a_min, a_max)

def check_target(target, dev):
print("Running on target: %s" % target)
Expand All @@ -61,5 +63,13 @@ def test_clip():
verify_clip(1024, -127, 127, "int8")


@tvm.testing.uses_gpu
def test_clip_floaimm_intimm():
verify_clip(1024, tir.FloatImm("float32", -127), tir.FloatImm("float32", 127), "float32")
verify_clip(1024, tir.IntImm("int32", -127), tir.IntImm("int32", 127), "int16")
verify_clip(1024, tir.IntImm("int32", -127), tir.IntImm("int32", 127), "int8")


if __name__ == "__main__":
test_clip()
test_clip_floaimm_intimm()

0 comments on commit 0e8bc14

Please sign in to comment.