diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 22555e0fb3a4..f8069a717da3 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -218,7 +218,7 @@ def int_set(self, expr, dom_map): expr : PrimExpr The expression. - dom_map : Dict[Var, tvm.arith.IntSet] + dom_map : Dict[tvm.tir.Var, tvm.arith.IntSet] The domain for variables to be relaxed. Returns diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index 799b3d16733f..9cf19e2e61f4 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -18,9 +18,11 @@ """ScatterND operator""" from tvm import te, tir # hide redefinition of min and max from tvm.tir import expr +from tvm.arith.analyzer import Analyzer def _verify_scatter_nd_inputs(data, indices, updates): + analyzer = Analyzer() mdim = int(indices.shape[0]) assert mdim <= len(data.shape), ( f"The first dimension of the indices ({mdim}) must be less than or equal to " @@ -29,7 +31,8 @@ def _verify_scatter_nd_inputs(data, indices, updates): for i in range(len(indices.shape) - 1): if isinstance(indices.shape[i + 1], expr.Var) or isinstance(updates.shape[i], expr.Var): continue - assert indices.shape[i + 1] == updates.shape[i], ( + + assert analyzer.can_prove_equal(indices.shape[i + 1], updates.shape[i]), ( f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " f"updates[{i}] ({updates.shape[i]})." )