diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 17530380e665..60973577ac92 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -459,6 +459,57 @@ bool ContainsNode(const Stmt& stmt) { return visitor.contains_node; } +/*! + * \brief Legalize the data types of expressions to make sure they are consistent with other + * parts of the program. + * + * It enforces the following rules: + * - The data type of the index variable in a loop must be consistent with the data type of the loop + * bounds. + * - The data type of the binary and ternary expressions must be consistent with the data types of + * each of their operands. + * - The data type of the bounds and binding values of block iter vars must be consistent with the + * data type of the block iter vars. + * + * Usually we enforce the consistency of data types when constructing the IR nodes. However, such + * inconsistency may happen as a result of IR mutation in some passes. This class can be used as + * base class of such passes to ensure the consistency of data types. + */ +class DataTypeLegalizer : public StmtExprMutator { + public: + Stmt VisitStmt_(const ForNode* op) override; + + Stmt VisitStmt_(const AttrStmtNode* op) override; + Stmt VisitStmt_(const BlockRealizeNode* op) override; + Stmt VisitStmt_(const BlockNode* op) override; + PrimExpr VisitExpr_(const SelectNode* op) override; + PrimExpr VisitExpr_(const RampNode* op) override; + PrimExpr VisitExpr_(const AddNode* op) override; + PrimExpr VisitExpr_(const SubNode* op) override; + PrimExpr VisitExpr_(const MulNode* op) override; + PrimExpr VisitExpr_(const DivNode* op) override; + PrimExpr VisitExpr_(const ModNode* op) override; + PrimExpr VisitExpr_(const FloorDivNode* op) override; + PrimExpr VisitExpr_(const FloorModNode* op) override; + PrimExpr VisitExpr_(const MinNode* op) override; + PrimExpr VisitExpr_(const MaxNode* op) override; + PrimExpr VisitExpr_(const EQNode* op) override; + PrimExpr VisitExpr_(const NENode* op) override; + PrimExpr VisitExpr_(const LTNode* op) override; + PrimExpr VisitExpr_(const LENode* op) override; + PrimExpr VisitExpr_(const GTNode* op) override; + PrimExpr VisitExpr_(const GENode* op) override; + PrimExpr VisitExpr_(const CallNode* op) override; + + using StmtExprMutator::VisitExpr_; + using StmtExprMutator::VisitStmt_; + + protected: + // a map from IterVar before rewrite to that after rewrite, + // ensures one old IterVar maps to exactly one new IterVar + std::unordered_map ivmap_; +}; + } // namespace tir } // namespace tvm diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc new file mode 100644 index 000000000000..afa28d92589f --- /dev/null +++ b/src/tir/ir/data_type_rewriter.cc @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file data_type_rewriter.cc + * \brief Rewrite the data type of expressions. + */ + +#include +#include +#include + +#include "./functor_common.h" + +namespace tvm { +namespace tir { + +Stmt DataTypeLegalizer::VisitStmt_(const ForNode* op) { + Stmt s = StmtExprMutator::VisitStmt_(op); + op = s.as(); + ICHECK(op != nullptr) << "Expected type to be ForNode, but get " << s->GetTypeKey(); + PrimExpr e = VisitExpr(op->loop_var); + Var var = Downcast(e); + return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->kind, op->body, + op->thread_binding, op->annotations); +} + +Stmt DataTypeLegalizer::VisitStmt_(const BlockRealizeNode* op) { + BlockRealize realize = Downcast(StmtExprMutator::VisitStmt_(op)); + Array new_iter_values; + bool changed = false; + for (int i = 0; i < static_cast(op->iter_values.size()); ++i) { + auto dtype = realize->block->iter_vars[i]->var->dtype; + if (op->iter_values[i]->dtype != dtype) { + new_iter_values.push_back(cast(dtype, realize->iter_values[i])); + changed = true; + } else { + new_iter_values.push_back(realize->iter_values[i]); + } + } + if (changed) { + realize.CopyOnWrite()->iter_values = std::move(new_iter_values); + } + return std::move(realize); +} + +Stmt DataTypeLegalizer::VisitStmt_(const BlockNode* op) { + Block new_block = Downcast(StmtExprMutator::VisitStmt_(op)); + Array new_iter_vars = MutateArray(new_block->iter_vars, [this](const IterVar& iter) { + auto dtype = iter->var.dtype(); + if (iter->dom->min->dtype != dtype || iter->dom->extent->dtype != dtype) { + IterVar new_iter = iter; + new_iter.CopyOnWrite()->dom = + Range(cast(dtype, iter->dom->min), cast(dtype, iter->dom->extent)); + return new_iter; + } else { + return iter; + } + }); + if (!op->iter_vars.same_as(new_iter_vars)) { + new_block.CopyOnWrite()->iter_vars = std::move(new_iter_vars); + } + return std::move(new_block); +} + +Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { + Stmt s = StmtExprMutator::VisitStmt_(op); + op = s.as(); + ICHECK(op != nullptr) << "Expected type to be AttrStmtNode" + << ", but get " << s->GetTypeKey(); + const IterVarNode* iv = op->node.as(); + ICHECK(iv != nullptr) << "Expected type to be IterVarNode" + << ", but get " << op->node->GetTypeKey(); + PrimExpr e = VisitExpr(iv->var); + Var var = Downcast(e); + if (ivmap_.find(iv) == ivmap_.end()) { + Range dom = iv->dom; + if (dom.defined()) { + PrimExpr extend = dom->extent; + ICHECK(extend.dtype().is_int() && var.dtype().is_int()); + if (var.dtype().bits() != extend.dtype().bits()) { + DataType dtype = var.dtype(); + dom = Range(cast(dtype, dom->min), cast(dtype, extend), dom->span); + } + } + ivmap_[iv] = IterVar(dom, var, iv->iter_type, iv->thread_tag); + } + return AttrStmt(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body); + } + return StmtExprMutator::VisitStmt_(op); +} + +PrimExpr DataTypeLegalizer::VisitExpr_(const SelectNode* op) { + PrimExpr condition = this->VisitExpr(op->condition); + PrimExpr true_value = this->VisitExpr(op->true_value); + PrimExpr false_value = this->VisitExpr(op->false_value); + if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && + false_value.same_as(op->false_value) && true_value.dtype() == false_value.dtype()) { + return GetRef(op); + } else { + int bits = std::max(true_value.dtype().bits(), false_value.dtype().bits()); + DataType dtype = true_value.dtype().with_bits(bits); + if (true_value.dtype() != dtype) true_value = cast(dtype, true_value); + if (false_value.dtype() != dtype) false_value = cast(dtype, false_value); + return Select(condition, true_value, false_value); + } +} + +PrimExpr DataTypeLegalizer::VisitExpr_(const RampNode* op) { + PrimExpr base = VisitExpr(op->base); + PrimExpr stride = VisitExpr(op->stride); + if (base.same_as(op->base) && stride.same_as(op->stride) && base.dtype() == stride.dtype()) { + return GetRef(op); + } else { + ICHECK(base.dtype().is_int() && stride.dtype().is_int()); + int bits = std::max(base.dtype().bits(), stride.dtype().bits()); + DataType dtype = base.dtype().with_bits(bits); + if (base.dtype() != dtype) base = cast(dtype, base); + if (stride.dtype() != dtype) stride = cast(dtype, stride); + return Ramp(base, stride, op->lanes); + } +} + +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ + PrimExpr DataTypeLegalizer::VisitExpr_(const OP* op) { \ + PrimExpr a = this->VisitExpr(op->a); \ + PrimExpr b = this->VisitExpr(op->b); \ + if (op->a.same_as(a) && op->b.same_as(b) && a.dtype() == b.dtype()) { \ + return GetRef(op); \ + } else { \ + return FUNC(a, b); \ + } \ + } + +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); + +#undef DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH + +PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) { + PrimExpr e = StmtExprMutator::VisitExpr_(op); + op = e.as(); + static const Op& builtin_pow_ = Op::Get("tir.pow"); + ICHECK(op != nullptr) << "Expected type to be CallNode" + << ", but get " << e->GetTypeKey(); + if (op->op.same_as(builtin::shift_right())) { + return op->args[0] >> op->args[1]; + } else if (op->op.same_as(builtin::shift_left())) { + return op->args[0] << op->args[1]; + } else if (op->op.same_as(builtin::bitwise_and())) { + return op->args[0] & op->args[1]; + } else if (op->op.same_as(builtin::bitwise_or())) { + return op->args[0] | op->args[1]; + } else if (op->op.same_as(builtin::bitwise_xor())) { + return op->args[0] ^ op->args[1]; + } else if (op->op.same_as(builtin_pow_)) { + return pow(op->args[0], op->args[1]); + } else if (op->op.same_as(builtin::if_then_else())) { + return if_then_else(op->args[0], op->args[1], op->args[2]); + } + return e; +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 047295180712..7f9c76f5257d 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -187,7 +187,9 @@ class DataTypeVisitor final : public StmtExprVisitor { arith::ConstIntBoundAnalyzer::BoundMapType bound_; }; -class DataTypeRewriter : public StmtExprMutator { +class DataTypeRewriter : public DataTypeLegalizer { + using Parent = DataTypeLegalizer; + public: explicit DataTypeRewriter(int target_bits) : visitor_(target_bits) {} @@ -253,19 +255,8 @@ class DataTypeRewriter : public StmtExprMutator { return indices; } - Stmt VisitStmt_(const ForNode* op) final { - Stmt s = StmtExprMutator::VisitStmt_(op); - op = s.as(); - ICHECK(op != nullptr) << "Expected type to be ForNode" - << ", but get " << s->GetTypeKey(); - PrimExpr e = VisitExpr(op->loop_var); - Var var = Downcast(e); - return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->kind, op->body, - op->thread_binding, op->annotations); - } - Stmt VisitStmt_(const IfThenElseNode* op) final { - IfThenElse updated = Downcast(StmtExprMutator::VisitStmt_(op)); + IfThenElse updated = Downcast(Parent::VisitStmt_(op)); is_condition_ = true; PrimExpr cond = VisitExpr(op->condition); is_condition_ = false; @@ -275,34 +266,6 @@ class DataTypeRewriter : public StmtExprMutator { return std::move(updated); } - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { - Stmt s = StmtExprMutator::VisitStmt_(op); - op = s.as(); - ICHECK(op != nullptr) << "Expected type to be AttrStmtNode" - << ", but get " << s->GetTypeKey(); - const IterVarNode* iv = op->node.as(); - ICHECK(iv != nullptr) << "Expected type to be IterVarNode" - << ", but get " << op->node->GetTypeKey(); - PrimExpr e = VisitExpr(iv->var); - Var var = Downcast(e); - if (ivmap_.find(iv) == ivmap_.end()) { - Range dom = iv->dom; - if (dom.defined()) { - PrimExpr extend = dom->extent; - if (extend.dtype().is_int() && var.dtype().is_int() && - var.dtype().bits() != extend.dtype().bits()) { - DataType dtype = var.dtype(); - dom = Range(cast(dtype, dom->min), cast(dtype, extend), dom->span); - } - } - ivmap_[iv] = IterVar(dom, var, iv->iter_type, iv->thread_tag); - } - return AttrStmt(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body); - } - return StmtExprMutator::VisitStmt_(op); - } - PrimExpr VisitExpr_(const VarNode* op) final { if (visitor_.vmap.find(op) != visitor_.vmap.end()) { if (vmap_.find(op) == vmap_.end()) { @@ -310,42 +273,7 @@ class DataTypeRewriter : public StmtExprMutator { } return vmap_[op]; } - return StmtExprMutator::VisitExpr_(op); - } - - PrimExpr VisitExpr_(const SelectNode* op) final { - PrimExpr condition = this->VisitExpr(op->condition); - PrimExpr true_value = this->VisitExpr(op->true_value); - PrimExpr false_value = this->VisitExpr(op->false_value); - if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && - false_value.same_as(op->false_value)) { - return GetRef(op); - } else { - if (op->true_value.dtype().is_int() && op->false_value.dtype().is_int()) { - int bits = std::max(true_value.dtype().bits(), false_value.dtype().bits()); - DataType dtype = true_value.dtype().with_bits(bits); - if (true_value.dtype() != dtype) true_value = cast(dtype, true_value); - if (false_value.dtype() != dtype) false_value = cast(dtype, false_value); - } - return Select(condition, true_value, false_value); - } - } - - PrimExpr VisitExpr_(const RampNode* op) final { - PrimExpr base = VisitExpr(op->base); - PrimExpr stride = VisitExpr(op->stride); - if (base.same_as(op->base) && stride.same_as(op->stride)) { - return GetRef(op); - } else { - if (base.dtype().is_int()) { - ICHECK(stride.dtype().is_int()) << "Ramp base is int but stride is " << stride.dtype(); - int bits = std::max(base.dtype().bits(), stride.dtype().bits()); - DataType dtype = base.dtype().with_bits(bits); - if (base.dtype() != dtype) base = cast(dtype, base); - if (stride.dtype() != dtype) stride = cast(dtype, stride); - } - return Ramp(base, stride, op->lanes); - } + return Parent::VisitExpr_(op); } PrimExpr VisitExpr_(const SizeVarNode* op) final { @@ -355,7 +283,7 @@ class DataTypeRewriter : public StmtExprMutator { } return vmap_[op]; } - return StmtExprMutator::VisitExpr_(op); + return Parent::VisitExpr_(op); } PrimExpr VisitExpr_(const IntImmNode* op) final { @@ -364,29 +292,20 @@ class DataTypeRewriter : public StmtExprMutator { return IntImm(visitor_.vmap[op], op->value); } } - return StmtExprMutator::VisitExpr_(op); + return Parent::VisitExpr_(op); } PrimExpr VisitExpr_(const CastNode* op) final { if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) { - PrimExpr e = StmtExprMutator::VisitExpr_(op); + PrimExpr e = Parent::VisitExpr_(op); const CastNode* new_op = e.as(); ICHECK(new_op != nullptr) << "Expected type to be CastNode" << ", but get " << e->GetTypeKey(); return Cast(visitor_.vmap[op], new_op->value); } - return StmtExprMutator::VisitExpr_(op); + return Parent::VisitExpr_(op); } - PrimExpr VisitExpr_(const AddNode* op) final; - PrimExpr VisitExpr_(const SubNode* op) final; - PrimExpr VisitExpr_(const MulNode* op) final; - PrimExpr VisitExpr_(const DivNode* op) final; - PrimExpr VisitExpr_(const ModNode* op) final; - PrimExpr VisitExpr_(const FloorDivNode* op) final; - PrimExpr VisitExpr_(const FloorModNode* op) final; - PrimExpr VisitExpr_(const MinNode* op) final; - PrimExpr VisitExpr_(const MaxNode* op) final; PrimExpr VisitExpr_(const EQNode* op) final; PrimExpr VisitExpr_(const NENode* op) final; PrimExpr VisitExpr_(const LTNode* op) final; @@ -401,28 +320,12 @@ class DataTypeRewriter : public StmtExprMutator { // a map from Var before rewrite to that after rewrite, // ensures one old Var maps to exactly one new Var std::unordered_map vmap_; - // a map from IterVar before rewrite to that after rewrite, - // ensures one old IterVar maps to exactly one new IterVar - std::unordered_map ivmap_; // indicator of index expr to rewrite bool is_index_{false}; // indicator of condition bool is_condition_{false}; - // cached ops - const Op& builtin_pow_ = Op::Get("tir.pow"); }; -#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ - PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \ - PrimExpr a = this->VisitExpr(op->a); \ - PrimExpr b = this->VisitExpr(op->b); \ - if (a.same_as(op->a) && b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - return FUNC(a, b); \ - } \ - } - #define DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \ bool is_index = is_index_; \ @@ -430,25 +333,11 @@ class DataTypeRewriter : public StmtExprMutator { if (rewrite) { \ is_index_ = true; \ } \ - PrimExpr a = this->VisitExpr(op->a); \ - PrimExpr b = this->VisitExpr(op->b); \ + auto result = Parent::VisitExpr_(op); \ is_index_ = is_index; \ - if (a.same_as(op->a) && b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - return FUNC(a, b); \ - } \ + return std::move(result); \ } -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max); DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=); DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=); @@ -465,26 +354,7 @@ PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { is_condition_ = is_condition; return if_then_else(cond, VisitExpr(op->args[1]), VisitExpr(op->args[2])); } - - PrimExpr e = StmtExprMutator::VisitExpr_(op); - op = e.as(); - ICHECK(op != nullptr) << "Expected type to be CallNode" - << ", but get " << e->GetTypeKey(); - if (op->op.same_as(builtin::shift_right())) { - return op->args[0] >> op->args[1]; - } else if (op->op.same_as(builtin::shift_left())) { - return op->args[0] << op->args[1]; - } else if (op->op.same_as(builtin::bitwise_and())) { - return op->args[0] & op->args[1]; - } else if (op->op.same_as(builtin::bitwise_or())) { - return op->args[0] | op->args[1]; - } else if (op->op.same_as(builtin::bitwise_xor())) { - return op->args[0] ^ op->args[1]; - } else if (op->op.same_as(builtin_pow_)) { - return pow(op->args[0], op->args[1]); - } - - return e; + return Parent::VisitExpr_(op); } Stmt NarrowDataType(Stmt stmt, int target_bits) { return DataTypeRewriter(target_bits)(stmt); } diff --git a/tests/cpp/data_type_rewriter_test.cc b/tests/cpp/data_type_rewriter_test.cc new file mode 100644 index 000000000000..d1ac9d782ce5 --- /dev/null +++ b/tests/cpp/data_type_rewriter_test.cc @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +using namespace tvm; +using namespace tvm::tir; +using namespace tvm::runtime; + +using BinaryOpTypes = + ::testing::Types; + +template +class DataTypeLegalizerBinaryOp : public ::testing::Test {}; + +TYPED_TEST_SUITE(DataTypeLegalizerBinaryOp, BinaryOpTypes); + +TYPED_TEST(DataTypeLegalizerBinaryOp, Basic) { + using RefType = TypeParam; + using NodeType = typename RefType::ContainerType; + auto node = make_object(); + node->a = Var("a", DataType::Int(32)); + node->b = IntImm(DataType::Int(64), 2); + DataTypeLegalizer legalizer; + auto new_expr = Downcast(legalizer(RefType(node))); + auto target_dtype = DataType::Int(64); + ASSERT_EQ(new_expr->a.dtype(), target_dtype); + ASSERT_EQ(new_expr->b.dtype(), target_dtype); +} + +TEST(DataTypeLegalizer, Select) { + auto node = make_object(); + node->condition = Var("cond", DataType::Bool()); + node->true_value = Var("a", DataType::Int(64)); + node->false_value = IntImm(DataType::Int(32), 2); + DataTypeLegalizer legalizer; + Select new_select = Downcast