Skip to content

Commit

Permalink
[TIR] Modify IntImmNode deep_equal to match regardless of type
Browse files Browse the repository at this point in the history
This patch makes a small change to compare the values of IntImmNode to
see if they're equal when performing a deep_equal of expressions. This
is to try and align it with how the [`PEqualChecker<IntImm>`](https://github.com/apache/tvm/blob/b2204ae6988c7745ea9736340ccd900bc21ae821/src/arith/pattern_match.h#L166)
works where we only compare the values if both are IntImm.

This caused some simplifications to be inconsistent based on whether we
used IntImmNode or PrimExpr to pass an integer between different passes,
and it seemed to make more sense to say that if the values are equal,
then we can conclude the immediates are equal.
  • Loading branch information
quic-sanirudh committed Mar 26, 2024
1 parent b2204ae commit 0675ed4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/tir/analysis/deep_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
if (lhs->type_index() != rhs->type_index()) return false;
if (auto* plhs = lhs.as<IntImmNode>()) {
auto* prhs = rhs.as<IntImmNode>();
return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
return plhs->value == prhs->value;
}
if (lhs.as<AnyNode>()) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,9 @@ def func2():
assert not tvm.tir.analysis.expr_deep_equal(func2(), func1())


def test_equal_ints():
assert tvm.tir.analysis.expr_deep_equal(128, tvm.tir.IntImm(dtype="int64", value=128))


if __name__ == "__main__":
test_equal_expr()

0 comments on commit 0675ed4

Please sign in to comment.