Skip to content

Commit

Permalink
[FIX][topi.scatter_nd] fixed shape equality assert by using analyzer …
Browse files Browse the repository at this point in the history
…to prove equality (#17537)

* fixed assert by using analyzer to the prove equality

* updated docs in Analyzer class
  • Loading branch information
PatrikPerssonInceptron authored Nov 22, 2024
1 parent 42b1e97 commit db6d205
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def int_set(self, expr, dom_map):
expr : PrimExpr
The expression.
dom_map : Dict[Var, tvm.arith.IntSet]
dom_map : Dict[tvm.tir.Var, tvm.arith.IntSet]
The domain for variables to be relaxed.
Returns
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/topi/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
"""ScatterND operator"""
from tvm import te, tir # hide redefinition of min and max
from tvm.tir import expr
from tvm.arith.analyzer import Analyzer


def _verify_scatter_nd_inputs(data, indices, updates):
analyzer = Analyzer()
mdim = int(indices.shape[0])
assert mdim <= len(data.shape), (
f"The first dimension of the indices ({mdim}) must be less than or equal to "
Expand All @@ -29,7 +31,8 @@ def _verify_scatter_nd_inputs(data, indices, updates):
for i in range(len(indices.shape) - 1):
if isinstance(indices.shape[i + 1], expr.Var) or isinstance(updates.shape[i], expr.Var):
continue
assert indices.shape[i + 1] == updates.shape[i], (

assert analyzer.can_prove_equal(indices.shape[i + 1], updates.shape[i]), (
f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
f"updates[{i}] ({updates.shape[i]})."
)
Expand Down

0 comments on commit db6d205

Please sign in to comment.