diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 18d21cb4d4ba..b939ea712c3c 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -118,6 +118,28 @@ class WellDefinedEraser : public StructInfoMutator, std::function(const Var& var)> f_var_map, arith::Analyzer* ana) : f_shape_var_map_(f_shape_var_map), f_var_map_(f_var_map), ana_(ana) {} + StructInfo VisitStructInfo_(const PrimStructInfoNode* op) final { + bool has_undefined = false; + Optional value; + + if (op->value.defined()) { + std::swap(has_undefined_, has_undefined); + value = VisitPrimExpr(op->value.value()); + std::swap(has_undefined_, has_undefined); + } + + // erase symbolic shape if we have undefined. + if (!has_undefined) { + if (value.same_as(op->value)) { + return GetRef(op); + } else { + return PrimStructInfo(value.value(), op->span); + } + } else { + return PrimStructInfo(op->dtype, op->span); + } + } + StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) final { bool has_undefined = false; Optional> values; @@ -295,7 +317,15 @@ class StructInfoBaseChecker if (other.as()) return BaseCheckResult::kFailL1; return BaseCheckResult::kFailL0; } - return lhs->dtype == rhs->dtype ? BaseCheckResult::kPass : BaseCheckResult::kFailL0; + + if (lhs->dtype != rhs->dtype) { + return BaseCheckResult::kFailL0; + } + + if (!lhs->value.defined()) return BaseCheckResult::kPass; + if (!rhs->value.defined()) return BaseCheckResult::kFailL2; + + return PrimValueMatchCheck(lhs->value.value(), rhs->value.value()); } BaseCheckResult VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index b445bde6f583..f74434bd7453 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -437,13 +437,22 @@ class BlockBuilderImpl : public BlockBuilderNode { void VisitStructInfo_(const ShapeStructInfoNode* op) final { for (const PrimExpr& s : op->values.value_or(Array())) { - // Only collect single var defined shape. Ignore something like `R.Tensor((m + 1, n + 1)) + // Only collect single var defined shape. Ignore something like `R.Shape((m + 1, n + 1)) if (const auto* var = s.as()) { shape_var_map_.Set(GetRef(var), s); } } } + void VisitStructInfo_(const PrimStructInfoNode* op) final { + // Only collect single var defined shape. Ignore something like `R.Prim(value=m + 1)` + if (op->value.defined()) { + if (auto var = op->value.as()) { + shape_var_map_.Set(var.value(), op->value.value()); + } + } + } + private: Map shape_var_map_; }; diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 9e91e0759248..efb2d0220481 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -155,6 +155,23 @@ tvm::Map InferSymbolicVarMap( } }; + auto bind_from_prim_value = [&bind_from_prim_expr](const StructInfo& var, + const StructInfo& expr) { + auto var_sinfo = var.as(); + if (!var_sinfo) return; + + auto expr_sinfo = expr.as(); + CHECK(expr_sinfo) << "Cannot bind expression with struct type " << expr + << " to variable with struct type " << var; + CHECK_EQ(var_sinfo->dtype, expr_sinfo->dtype) + << "Cannot bind expression with struct type " << expr << " to variable with struct type " + << var << ", due to conflicting PrimExpr DataType"; + + if (!var_sinfo->value.defined() || !expr_sinfo->value.defined()) return; + + bind_from_prim_expr(var_sinfo->value.value(), expr_sinfo->value.value()); + }; + auto bind_from_shape = [&bind_from_prim_expr](const StructInfo& var, const StructInfo& expr) { auto var_shape = var.as(); if (!var_shape) return; @@ -195,6 +212,7 @@ tvm::Map InferSymbolicVarMap( bind_from_tensor(var_sinfo, expr_sinfo); bind_from_shape(var_sinfo, expr_sinfo); + bind_from_prim_value(var_sinfo, expr_sinfo); } return tir_var_remap; diff --git a/tests/python/relax/test_bind_params.py b/tests/python/relax/test_bind_params.py index 189a44303d6c..bed44c4a6ac2 100644 --- a/tests/python/relax/test_bind_params.py +++ b/tests/python/relax/test_bind_params.py @@ -112,11 +112,14 @@ def expected() -> R.Shape([16]): def test_bind_prim_value(prim_value_dtype): + if prim_value_dtype != "int64": + pytest.xfail(reason="Currently, only support int64 as known symbolic value") + N = tir.Var("N", prim_value_dtype) value = tir.const(16, prim_value_dtype) @R.function - def before(A: R.Prim(value=N)): + def before(A: R.Prim(value=N)) -> R.Prim(value=N): R.func_attr({"global_symbol": "main"}) B: R.Prim(value=N) = A return B diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index b45c3c6e4a93..ce6fd8e04219 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1162,7 +1162,7 @@ def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): return w -def test_erase_to_well_defined(): +def test_erase_to_well_defined_removes_internal_vars(): @R.function def foo(x: R.Tensor): q = x @@ -1172,9 +1172,101 @@ def foo(x: R.Tensor): return w tvm.ir.assert_structural_equal(foo.ret_struct_info, R.Tensor(ndim=2)) + assert foo.ret_struct_info.shape is None _check(foo) +def test_erase_to_well_defined_keeps_variables_exposed_by_tensor_shape(): + @R.function + def foo(x: R.Tensor(["m", "n"])): + q = x + m, n = T.int64(), T.int64() + z = R.match_cast(q, R.Tensor((m, n))) + w = z + return w + + assert foo.ret_struct_info.shape is not None + _check(foo) + + +def test_erase_to_well_defined_keeps_variants_exposed_by_shape_expr(): + @R.function + def foo(x: R.Tensor, _: R.Shape(["m", "n"])): + q = x + m, n = T.int64(), T.int64() + z = R.match_cast(q, R.Tensor((m, n))) + w = z + return w + + assert foo.ret_struct_info.shape is not None + _check(foo) + + +def test_erase_to_well_defined_keeps_variants_exposed_by_prim_value(): + @R.function + def foo(x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n")): + q = x + m, n = T.int64(), T.int64() + z = R.match_cast(q, R.Tensor((m, n))) + w = z + return w + + assert foo.ret_struct_info.shape is not None + _check(foo) + + +def test_erase_to_well_defined_infers_from_shape_expr(): + @I.ir_module + class Module: + # The subroutine's symbolic variables are only in-scope for the subroutine. + @R.function + def subroutine(x: R.Tensor, _: R.Shape(["m", "n"])) -> R.Tensor(["m", "n"]): + q = x + m, n = T.int64(), T.int64() + z = R.match_cast(q, R.Tensor((m, n))) + w = z + return w + + # However, struct inference can make the symbolic variables in + # the main function to the symbolic variables in the + # subroutine. Therefore, the shape of the tensor returned + # from main can have a well-defined shape. + @R.function + def main(x: R.Tensor, shape: R.Shape(["m", "n"])): + output = Module.subroutine(x, shape) + return output + + assert Module["main"].ret_struct_info.shape is not None + _check(Module) + + +def test_erase_to_well_defined_infers_from_prim_value(): + @I.ir_module + class Module: + # The subroutine's symbolic variables are only in-scope for the subroutine. + @R.function + def subroutine( + x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n") + ) -> R.Tensor(["m", "n"]): + q = x + m, n = T.int64(), T.int64() + z = R.match_cast(q, R.Tensor((m, n))) + w = z + return w + + # However, struct inference can make the symbolic variables in + # the main function to the symbolic variables in the + # subroutine. Therefore, the shape of the tensor returned + # from main can have a well-defined shape. + @R.function + def main(x: R.Tensor, relax_m: R.Prim(value="m"), relax_n: R.Prim(value="n")): + output = Module.subroutine(x, relax_m, relax_n) + return output + + assert Module["main"].ret_struct_info.shape is not None + _check(Module) + + def test_empty_tuple(): @R.function def foo(x: R.Tuple()):