From edf801eb464b5d485afd26a0917f77d771f2c8c5 Mon Sep 17 00:00:00 2001
From: Mark Shields
Date: Wed, 1 Dec 2021 17:11:21 -0800
Subject: [PATCH] [checkpoint] Always on_device let-bound values

Realized both dead_code.cc and fold_constant.cc would happily move values
into a different lexical virtual device context, since device_planner.cc was
being 'clever' and eliding the on_device annotation on a let-bound value
whenever its device matched the let's.

Fixed so that every let-bound value carries an on_device annotation.
Illustrative sketches follow the diff.

Will be much better after https://github.com/apache/tvm-rfcs/pull/45
---
 include/tvm/relay/expr.h                      |  3 +
 src/printer/relay_text_printer.cc             |  6 --
 src/printer/text_printer.cc                   | 56 +++++++++----
 src/relay/backend/vm/compiler.cc              | 20 +++--
 src/relay/ir/expr.cc                          |  7 ++
 src/relay/transforms/device_aware_visitors.cc |  7 +-
 src/relay/transforms/device_aware_visitors.h  | 15 +++-
 src/relay/transforms/device_planner.cc        | 20 +++--
 src/relay/transforms/let_list.h               |  2 +-
 src/relay/transforms/memory_alloc.cc          | 80 +++++++++----------
 src/relay/transforms/partial_eval.cc          |  4 +-
 src/relay/transforms/to_a_normal_form.cc      |  2 +-
 src/runtime/vm/executable.cc                  |  2 +-
 src/target/se_scope.cc                        |  4 -
 src/target/target.cc                          |  4 -
 15 files changed, 136 insertions(+), 96 deletions(-)

diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index f57b2d1a1952c..a156c402cc7bf 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -229,6 +229,9 @@ class Var : public Expr {
    */
   TVM_DLL Var(Id vid, Type type_annotation, Span span = Span());
+
+  static Var GenSym(Type type_annotation = {}, Span span = {});
+
   TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
 };
diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc
index 9d6d57d701308..d0c2cfebbbd82 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -502,12 +502,6 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
 Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
   Doc doc;
   doc << "@" << op->name_hint;
-#if TVM_LOG_DEBUG
-  if (op->checked_type_.defined()) {
-    doc << " /* type=" << PrintType(op->checked_type_, /*meta=*/false) << " */";
-  }
-  doc << " /* id=" << reinterpret_cast<uint64_t>(op) << " */";
-#endif
   return doc;
 }
diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc
index 5acb9bd3f1dc2..73fc8a3050d0e 100644
--- a/src/printer/text_printer.cc
+++ b/src/printer/text_printer.cc
@@ -36,49 +36,71 @@ static const char* kSemVer = "0.0.5";
 Doc TextPrinter::PrintMod(const IRModule& mod) {
   Doc doc;
   int counter = 0;
+
+  // We'll print in alphabetical order to make a/b diffs easier to work with.
+  // type definitions
+  std::vector<GlobalTypeVar> tyvars;
   for (const auto& kv : mod->type_definitions) {
+    tyvars.emplace_back(kv.first);
+  }
+  std::sort(tyvars.begin(), tyvars.end(),
+            [](const GlobalTypeVar& left, const GlobalTypeVar& right) {
+              return left->name_hint < right->name_hint;
+            });
+  for (const auto& tyvar : tyvars) {
     if (counter++ != 0) {
       doc << Doc::NewLine();
     }
-    doc << relay_text_printer_.Print(kv.second);
+    doc << relay_text_printer_.Print(mod->type_definitions[tyvar]);
     doc << Doc::NewLine();
   }
+
+  // functions
+  std::vector<GlobalVar> vars;
   for (const auto& kv : mod->functions) {
-    if (kv.second.as<relay::FunctionNode>()) {
+    vars.emplace_back(kv.first);
+  }
+  std::sort(vars.begin(), vars.end(), [](const GlobalVar& left, const GlobalVar& right) {
+    return left->name_hint < right->name_hint;
+  });
+  for (const auto& var : vars) {
+    const BaseFunc& base_func = mod->functions[var];
+    if (base_func.as<relay::FunctionNode>()) {
       relay_text_printer_.dg_ =
-          relay::DependencyGraph::Create(&relay_text_printer_.arena_, kv.second);
+          relay::DependencyGraph::Create(&relay_text_printer_.arena_, base_func);
     }
     if (counter++ != 0) {
       doc << Doc::NewLine();
     }
-    if (kv.second.as<relay::FunctionNode>()) {
+    if (base_func.as<relay::FunctionNode>()) {
       std::ostringstream os;
-      os << "def @" << kv.first->name_hint;
-#if TVM_LOG_DEBUG
-      os << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
-#endif
-      doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second);
-    } else if (kv.second.as<tir::PrimFuncNode>()) {
-      doc << "@" << kv.first->name_hint;
-#if TVM_LOG_DEBUG
-      doc << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
-#endif
-      doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
+      os << "def @" << var->name_hint;
+      doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), base_func);
+    } else if (base_func.as<tir::PrimFuncNode>()) {
+      doc << "@" << var->name_hint;
+      doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(base_func));
     }
     doc << Doc::NewLine();
   }
+
 #if TVM_LOG_DEBUG
   // attributes
+  // TODO(mbs): Make this official, including support from parser.
   if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
-    doc << "attributes {" << Doc::NewLine();
+    std::vector<String> keys;
     for (const auto& kv : mod->attrs->dict) {
-      doc << "  '" << kv.first << "' = " << PrettyPrint(kv.second) << Doc::NewLine();
+      keys.emplace_back(kv.first);
+    }
+    std::sort(keys.begin(), keys.end());
+    doc << "attributes {" << Doc::NewLine();
+    for (const auto& key : keys) {
+      doc << "  '" << key << "' = " << PrettyPrint(mod->attrs->dict[key]) << Doc::NewLine();
     }
     doc << "}" << Doc::NewLine();
   }
 #endif
+
   return doc;
 }
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index f72b722de9e80..02f201b9ab2a5 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -301,7 +301,8 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
   size_t NewRegister() { return registers_num_++; }
 
   inline void Emit(const Instruction& instr) {
-    VLOG(2) << "VMCompiler::Emit: instr=" << instr;
+    size_t instruction_index = instructions_.size();
+    VLOG(2) << "instruction[" << instruction_index << "] = " << instr;
     ICHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op;
     switch (instr.op) {
       case Opcode::AllocADT:
@@ -336,10 +337,9 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
    * in emitted code. Note that the host device is always at index 0.
    */
   Index GetDeviceIndex(const SEScope& se_scope) {
-    VLOG(2) << "getting device index for " << se_scope;
+    ICHECK(!se_scope->IsFullyUnconstrained());
     auto itr = std::find(context_->se_scopes_.begin(), context_->se_scopes_.end(), se_scope);
     if (itr != context_->se_scopes_.end()) {
-      VLOG(2) << "reusing existing scope";
       return std::distance(context_->se_scopes_.begin(), itr);
     }
@@ -367,7 +367,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
     ICHECK(se_scope != host_se_scope_);
 
     Index index = context_->se_scopes_.size();
-    VLOG(2) << "adding new scope";
+    VLOG(2) << "se_scope[" << index << "] = " << se_scope;
     context_->se_scopes_.push_back(se_scope);
 
     return index;
   }
@@ -378,11 +378,13 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
   void VisitExpr_(const ConstantNode* const_node) final {
     // Check the shape is valid
     NDArray data = const_node->data;
-    size_t konst_idx = context_->constants.size();
+    size_t const_index = context_->constants.size();
     auto con = GetRef<Constant>(const_node);
-    context_->const_device_indexes.push_back(GetDeviceIndex(GetSEScope(con)));
+    Index device_index = GetDeviceIndex(GetSEScope(con));
+    VLOG(2) << "constant[" << const_index << "] on device[" << device_index << "]";
+    context_->const_device_indexes.push_back(device_index);
     context_->constants.push_back(const_node->data);
-    Emit(Instruction::LoadConst(konst_idx, NewRegister()));
+    Emit(Instruction::LoadConst(const_index, NewRegister()));
   }
 
   void VisitExpr_(const VarNode* var_node) final {
@@ -872,6 +874,7 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host)
   // The first device is always for the host.
   CHECK(context_.se_scopes_.empty());
+  VLOG(2) << "se_scope[0] = " << config_->host_se_scope << " (host)";
   context_.se_scopes_.push_back(config_->host_se_scope);
 
   // Run the optimizations necessary to target the VM.
@@ -1085,9 +1088,12 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
   //   let-bound functions.
   pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
 
+  // #####################################
+#if 1
   // Now that we have PrimFuncs, flow and solve SEScope constraints again to account for
   // any memory scopes which lowering has settled on.
   pass_seqs.push_back(transform::PlanDevices(config_));
+#endif
 
   // Inline the functions that are lifted to the module scope. We perform this
   // pass after all other optimization passes but before the memory allocation
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index 8998f4e1573db..fc31f4b76e5aa 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -112,6 +112,13 @@ Var::Var(Id vid, Type type_annotation, Span span) {
   data_ = std::move(n);
 }
 
+/* static */ Var Var::GenSym(Type type_annotation, Span span) {
+  static size_t next_id = 0;
+  std::ostringstream os;
+  os << "x_" << next_id++;
+  return Var(os.str(), std::move(type_annotation), std::move(span));
+}
+
 Var WithFields(Var var, Optional<Id> opt_vid, Optional<Type> opt_type_annotation,
                Optional<Span> opt_span) {
   Id vid = opt_vid.value_or(var->vid);
diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc
index e3d5a821c58e4..8650f55da921d 100644
--- a/src/relay/transforms/device_aware_visitors.cc
+++ b/src/relay/transforms/device_aware_visitors.cc
@@ -36,11 +36,12 @@ namespace transform {
 
 LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional<IRModule>& maybe_mod) {
   if (maybe_mod) {
-    for (const auto& pair : maybe_mod.value()->functions) {
-      if (const auto* function_node = pair.second.as<FunctionNode>()) {
+    for (const auto& kv : maybe_mod.value()->functions) {
+      if (const auto* function_node = kv.second.as<FunctionNode>()) {
         SEScope se_scope = GetFunctionResultSEScope(function_node);
         if (!se_scope->IsFullyUnconstrained()) {
-          global_var_se_scopes_.emplace(pair.first, se_scope);
+          VLOG(2) << "global '" << kv.first->name_hint << "' has scope " << se_scope;
+          global_var_se_scopes_.emplace(kv.first, se_scope);
         }
       }
     }
diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h
index 8cdf0db74ebd3..10a6c86827cb7 100644
--- a/src/relay/transforms/device_aware_visitors.h
+++ b/src/relay/transforms/device_aware_visitors.h
@@ -146,7 +146,10 @@ class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(con
       PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i));
     }
     // Entering scope of function body.
-    PushSEScope(GetFunctionResultSEScope(function_node));
+    SEScope se_scope = GetFunctionResultSEScope(function_node);
+    VLOG(2) << "entering " << se_scope << " for function:" << std::endl
+            << PrettyPrint(GetRef<Function>(function_node));
+    PushSEScope(se_scope);
     EnterFunctionBody();
     DeviceAwareVisitExpr_(function_node);
 
@@ -154,6 +157,8 @@ class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(con
+    VLOG(2) << "leaving " << se_scope << " for function:" << std::endl
+            << PrettyPrint(GetRef<Function>(function_node));
     // Function parameters go out of scope.
     for (size_t i = 0; i < function_node->params.size(); ++i) {
       PopBoundVar(function_node->params[i]);
     }
@@ -168,7 +173,9 @@ class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(con
       if (const auto* inner_let_node = expr.as<LetNode>()) {
         // Let-bound var (in pre visited version) goes into scope.
         // (We'll just assume this is a letrec.)
-        PushBoundVar(inner_let_node->var, GetSEScope(inner_let_node->value));
+        SEScope se_scope = GetSEScope(inner_let_node->value);
+        VLOG(2) << "var '" << inner_let_node->var->name_hint() << "' has scope " << se_scope;
+        PushBoundVar(inner_let_node->var, se_scope);
         PreVisitLetBinding_(inner_let_node->var, inner_let_node->value);
         bindings.emplace_back(inner_let_node);
         expr = inner_let_node->body;
@@ -189,10 +196,14 @@ class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(con
+      VLOG(2) << "entering " << props.se_scope << " for on_device:" << std::endl
+              << PrettyPrint(GetRef<Call>(call_node));
       PushSEScope(props.se_scope);
       VisitExpr(props.body);
       // Leaving lexical scope of "on_device" call.
       PopSEScope();
+      VLOG(2) << "leaving " << props.se_scope << " for on_device:" << std::endl
+              << PrettyPrint(GetRef<Call>(call_node));
     } else {
       DeviceAwareVisitExpr_(call_node);
     }
   }
diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc
index 8b44097aaadf7..1d00b6a9687db 100644
--- a/src/relay/transforms/device_planner.cc
+++ b/src/relay/transforms/device_planner.cc
@@ -128,10 +128,15 @@
  *    the function's parameters and the result.
  *  - Additional "device_copy" CallNodes where a copy is required in order to respect the
  *    intent of the original "on_device" CallNodes.
- *  - Additional "on_device" CallNodes where the device type of an expression does not match
- *    that of the lexically enclosing "on_device" CallNode or function attribute. In practice
+ *  - Additional "on_device" CallNodes where the device type of an expression is not trivially
+ *    implied by the lexically enclosing "on_device" CallNode or function attribute. In practice
  *    this means "on_device" CallNodes may appear in two places:
- *    - On a let-bound expression if its device differs from the overall let expression.
+ *    - On let-bound expressions. It is tempting to elide the "on_device" if the let-bound value
+ *      has the same device as the overall let expression. However this would mean passes which
+ *      inline let-bound values, such as FoldConstant and DeadCodeElimination, would need to use
+ *      a DeviceAware visitor, which in turn requires the expression to be in ANF to avoid
+ *      deep recursion. To minimize disruption we always include the "on_device" so that it
+ *      can follow the inlined value.
  *    - On a call argument if its device differs from the call result. In particular, the
  *      argument to a "device_copy" call will always be wrapped in an "on_device". (That may
  *      seem pedantic but simplifies downstream handling.)
@@ -1033,11 +1038,11 @@ class DeviceCapturer : public ExprMutator {
         // We have a device transition which needs to be handled.
         break;
       }
-      // The let-bound value can be on a different device than the overall let. However if those
-      // devices don't agree wrap the let-bound value in an "on_device" to help downstream
-      // transforms track devices lexically.
+      // The let-bound value can be on a different device than the overall let. By using the
+      // fully-unconstrained SEScope for the 'lexical' scope we'll force the let-bound value to
+      // *always* be wrapped by an "on_device" (see the introductory comment for motivation).
       Expr value =
-          VisitChild(/*lexical_se_scope=*/let_se_scope,
+          VisitChild(/*lexical_se_scope=*/SEScope::FullyUnconstrained(),
                      /*expected_se_scope=*/GetSEScope(inner_let_node->var),
                      /*child_se_scope=*/GetSEScope(inner_let_node->value), inner_let_node->value);
       bindings.emplace_back(inner_let_node->var, value, inner_let_node->span);
@@ -1141,7 +1146,6 @@ class DeviceCapturer : public ExprMutator {
    */
   Expr VisitChild(const SEScope& lexical_se_scope, const SEScope& expected_se_scope,
                   const SEScope& child_se_scope, const Expr& child) {
-    ICHECK(!lexical_se_scope->IsFullyUnconstrained());
     ICHECK(!expected_se_scope->IsFullyUnconstrained());
     if (child->IsInstance<OpNode>() || child->IsInstance<ConstructorNode>()) {
       // Primitive operators and constructors don't need to be rewritten and can have a
diff --git a/src/relay/transforms/let_list.h b/src/relay/transforms/let_list.h
index 56875f6c16a16..f449d6c3b011e 100644
--- a/src/relay/transforms/let_list.h
+++ b/src/relay/transforms/let_list.h
@@ -79,7 +79,7 @@ class LetList {
    *
    * \return a Var that holds the inserted expr.
    */
-  Var Push(Expr expr, Type ty) { return Push(Var("x", ty), expr); }
+  Var Push(Expr expr, Type ty) { return Push(Var::GenSym(ty), expr); }
 
   /*!
    * \brief insert a binding.
diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc
index 00be629eabff1..9dfaf89117142 100644
--- a/src/relay/transforms/memory_alloc.cc
+++ b/src/relay/transforms/memory_alloc.cc
@@ -69,6 +69,8 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
   Function Rewrite(const Function& expr) { return Downcast<Function>(Mutate(expr)); }
 
  private:
+  using ExprMutator::VisitExpr_;
+
   Expr VisitExpr_(const TupleNode* tuple_node) final {
     LetList& scope = scopes_.back();
     Array<Expr> new_fields;
@@ -77,8 +79,10 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
     for (auto field : tuple_node->fields) {
       auto new_field = Mutate(field);
       if (new_field->IsInstance<ConstantNode>()) {
+        SEScope se_scope = GetSEScope(field);
+        ICHECK(!se_scope->IsFullyUnconstrained());
         Var const_var("const", Type(nullptr));
-        new_field = scope.Push(const_var, new_field);
+        new_field = scope.Push(const_var, MaybeOnDevice(new_field, se_scope, /*is_fixed=*/true));
       }
       new_fields.push_back(new_field);
     }
@@ -89,7 +93,9 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
 
   std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) final {
     Expr new_value = Mutate(value);
-    scopes_.back().Push(var, new_value);
+    SEScope se_scope = GetSEScope(value);
+    ICHECK(!se_scope->IsFullyUnconstrained());
+    scopes_.back().Push(var, MaybeOnDevice(new_value, se_scope, /*is_fixed=*/true));
     // Since we always need a let block on which to bind sub-expressions the rewritten bindings
     // are tracked in the current scopes. But return the rewritten binding anyway.
     return {var, new_value};
@@ -127,6 +133,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
     VLOG(1) << "converting lowered call to DPS:" << std::endl << PrettyPrint(call);
 
     SEScope se_scope = GetSEScope(call);
+    ICHECK(!se_scope->IsFullyUnconstrained());
     LetList& scope = scopes_.back();
     std::vector<Expr> new_args;
@@ -176,7 +183,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
     Expr invoke = InvokeTVMOp(call_lowered_props.lowered_func, ins, outs,
                               Downcast<DictAttrs>(call_lowered_props.attrs.metadata.at("relay_attrs")));
-    scope.Push(OnDevice(invoke, se_scope, /*is_fixed=*/true));
+    scope.Push(MaybeOnDevice(invoke, se_scope, /*is_fixed=*/true));
     return ToTupleType(ret_type, std::vector<Expr>(outputs.begin(), outputs.end()));
   }
@@ -192,8 +199,8 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
   /*! Returns an \p alloc_tensor call for a tensor of \p shape and \p dtype over \p storage.
    */
   inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype,
                           Array<IndexExpr> assert_shape) {
-    Expr offset = OnDevice(MakeConstantScalar(DataType::Int(64), 0), host_se_scope_,
-                           /*is_fixed=*/true);
+    Expr offset = MaybeOnDevice(MakeConstantScalar(DataType::Int(64), 0), host_se_scope_,
+                                /*is_fixed=*/true);
     return tvm::relay::AllocTensor(storage, std::move(offset), std::move(shape), dtype,
                                    assert_shape);
   }
@@ -236,22 +243,20 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
       CHECK(imm) << "expect static int shape";
       int_shape.push_back(imm->value);
     }
-    Expr shape = OnDevice(MakeConstant(int_shape), host_se_scope_, /*is_fixed=*/true);
-    Expr size = OnDevice(ComputeStorage(type), host_se_scope_, /*is_fixed=*/true);
+    Expr shape = MaybeOnDevice(MakeConstant(int_shape), host_se_scope_, /*is_fixed=*/true);
+    Expr size = MaybeOnDevice(ComputeStorage(type), host_se_scope_, /*is_fixed=*/true);
     // Alignment is directly captured in the instruction rather than calculated, so we
     // don't want to wrap it with an "on_device".
     Expr alignment = ComputeAlignment(type->dtype);
     // Run type inference later to get the correct type.
     Var var("storage_" + name_hint, Type(nullptr));
-    Expr value = OnDevice(AllocStorage(size, alignment, se_scope, type->dtype), se_scope,
-                          /*is_fixed=*/true);
-    auto sto = scope->Push(var, value);
+    Expr value = AllocStorage(size, alignment, se_scope, type->dtype);
+    auto sto = scope->Push(var, MaybeOnDevice(value, se_scope, /*is_fixed=*/true));
     // TODO(@jroesch): There is a bug with typing based on the constant shape.
-    auto tensor = OnDevice(AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape),
-                           se_scope, /*is_fixed=*/true);
+    auto tensor = AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape);
     Var tensor_var("tensor_" + name_hint, Type(nullptr));
-    return scope->Push(tensor_var, tensor);
+    return scope->Push(tensor_var, MaybeOnDevice(tensor, se_scope, /*is_fixed=*/true));
   }
 
   /*!
@@ -287,23 +292,24 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
       if (state == tec::kNeedInputShape) {
         std::vector<Expr> exprs = FromTupleType(ty, arg);
         for (size_t j = 0; j < exprs.size(); ++j) {
-          Expr sh_of = Mutate(ShapeOf(exprs[j]));  // already accounts for device
+          Expr sh_of = Mutate(ShapeOf(exprs[j]));
           Var in_shape_var("in_shape_" + std::to_string(input_pos + j), Type(nullptr));
-          shape_func_ins.push_back(scope->Push(in_shape_var, sh_of));
+          shape_func_ins.push_back(
+              scope->Push(in_shape_var, MaybeOnDevice(sh_of, host_se_scope_, /*is_fixed=*/true)));
           input_pos++;
         }
       } else if (state == tec::kNeedInputData) {
         auto new_arg = Mutate(arg);  // already accounts for device
         SEScope arg_se_scope = GetSEScope(arg);
+        ICHECK(!arg_se_scope->IsFullyUnconstrained());
         // The dynamic shape function is expecting its data on the host/CPU, so insert a
         // device_copy otherwise. (We'll need to fuse & lower these copies in the same way
         // we fuse & lower other operators we insert for, e.g., dynamic tensor size calculation.)
-        if (arg_se_scope != host_se_scope_) {
-          new_arg = OnDevice(DeviceCopy(new_arg, arg_se_scope, host_se_scope_), host_se_scope_,
-                             /*is_fixed=*/true);
-        }
+        new_arg = MaybeDeviceCopy(MaybeOnDevice(new_arg, arg_se_scope, /*is_fixed=*/true),
+                                  arg_se_scope, host_se_scope_);
         Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr));
-        shape_func_ins.push_back(scope->Push(in_shape_var, new_arg));
+        shape_func_ins.push_back(
+            scope->Push(in_shape_var, MaybeOnDevice(new_arg, host_se_scope_, /*is_fixed=*/true)));
         input_pos++;
       } else {
         // TODO(@jroesch): handle kNeedBoth
@@ -322,20 +328,16 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
       ICHECK(tensor_type_node);
       // Put the shape func on the host. This also ensures that everything between
       // shape_of and shape_func is similarly on the host.
-      Expr alloc = MakeStaticAllocation(scope, GetRef<TensorType>(tensor_type_node), host_se_scope_,
-                                        std::to_string(i));
-      // TODO(mbs): Don't really need a fresh var here since alloc will always be a var.
-      Var shape_func_out_var("shape_func_out_" + std::to_string(i), Type(nullptr));
-      alloc = scope->Push(shape_func_out_var, alloc);
+      Var alloc = MakeStaticAllocation(scope, GetRef<TensorType>(tensor_type_node), host_se_scope_,
+                                       "out_shape_" + std::to_string(i));
       out_shapes.push_back(alloc);
     }
 
     // Represent the call in DPS form.
-    auto shape_call = OnDevice(InvokeTVMOp(prim_fn_var, Tuple(shape_func_ins), Tuple(out_shapes),
-                                           Downcast<DictAttrs>(attrs.metadata.at("relay_attrs"))),
-                               host_se_scope_, /*is_fixed=*/true);
+    auto shape_call = InvokeTVMOp(prim_fn_var, Tuple(shape_func_ins), Tuple(out_shapes),
+                                  Downcast<DictAttrs>(attrs.metadata.at("relay_attrs")));
     Var shape_func_var("shape_func", Type(nullptr));
-    scope->Push(shape_func_var, shape_call);
+    scope->Push(shape_func_var, MaybeOnDevice(shape_call, host_se_scope_, /*is_fixed=*/true));
     return out_shapes;
   }
@@ -349,14 +351,14 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
     for (size_t i = 0; i < out_shapes.size(); ++i) {
       auto out_shape = out_shapes[i];
       auto out_type = out_types[i];
-      auto size = OnDevice(ComputeStorageInRelay(out_shape, out_type), host_se_scope_,
-                           /*is_fixed=*/true);
+      auto size = MaybeOnDevice(ComputeStorageInRelay(out_shape, out_type), host_se_scope_,
+                                /*is_fixed=*/true);
       // Alignment is directly captured in the instruction so don't wrap in "on_device".
       auto alignment = ComputeAlignment(out_type->dtype);
       Var sto_var("storage_" + std::to_string(i), Type(nullptr));
-      auto val = OnDevice(AllocStorage(size, alignment, se_scope, out_type->dtype), se_scope,
-                          /*is_fixed=*/true);
-      storages.push_back(scope->Push(sto_var, val));
+      auto val = AllocStorage(size, alignment, se_scope, out_type->dtype);
+      storages.push_back(scope->Push(sto_var, MaybeOnDevice(val, se_scope,
+                                                            /*is_fixed=*/true)));
     }
 
     Array<Expr> outs;
     for (size_t i = 0; i < out_shapes.size(); ++i) {
       auto out_shape = out_shapes[i];
       auto out_type = out_types[i];
       auto storage = storages[i];
-      auto alloc = OnDevice(AllocTensor(storage, out_shape, out_type->dtype, out_type->shape),
-                            se_scope, /*is_fixed=*/true);
+      auto alloc = AllocTensor(storage, out_shape, out_type->dtype, out_type->shape);
       Var out_var("out_" + std::to_string(i), Type(nullptr));
-      outs.push_back(scope->Push(out_var, alloc));
+      outs.push_back(scope->Push(out_var, MaybeOnDevice(alloc, se_scope, /*is_fixed=*/true)));
     }
 
     Tuple tuple_outs(outs);
     auto call = InvokeTVMOp(func, ins, tuple_outs,
                             Downcast<DictAttrs>(attrs.metadata.at("relay_attrs")));
-    auto invoke = OnDevice(call, se_scope, /*is_fixed=*/true);
-    scope->Push(invoke);
+    scope->Push(MaybeOnDevice(call, se_scope, /*is_fixed=*/true));
     return ToTupleType(ret_type,
                        std::vector<Expr>(tuple_outs->fields.begin(), tuple_outs->fields.end()));
   }
@@ -398,7 +398,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
         CHECK(imm) << "expect static int shape";
         shape.push_back(imm->value);
       }
-      shape_expr = OnDevice(MakeConstant(shape), host_se_scope_, /*is_fixed=*/true);
+      shape_expr = MaybeOnDevice(MakeConstant(shape), host_se_scope_, /*is_fixed=*/true);
     }
     return ReshapeTensor(ins->fields[0], shape_expr, ret_ty->shape);
   }
diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc
index 8f5e9e146d54b..28d1aa5532bf7 100644
--- a/src/relay/transforms/partial_eval.cc
+++ b/src/relay/transforms/partial_eval.cc
@@ -842,7 +842,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
     });
   }
 
-  PStatic VisitFunc(const Function& func, LetList* ll, const Var& name = Var("x", Type())) {
+  PStatic VisitFunc(const Function& func, LetList* ll, const Var& name) {
     Func f = VisitFuncStatic(func, name);
     Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func))));
     // TODO(@M.K.): we seem to reduce landin knot into letrec.
@@ -851,7 +851,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
   }
 
   PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
-    return VisitFunc(GetRef<Function>(op), ll);
+    return VisitFunc(GetRef<Function>(op), ll, Var::GenSym());
   }
 
   struct ReflectError : Error {
diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc
index f958a600551e5..b9a673da9f975 100644
--- a/src/relay/transforms/to_a_normal_form.cc
+++ b/src/relay/transforms/to_a_normal_form.cc
@@ -219,7 +219,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::Lexi
   // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly.
   Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
     Expr annotated_expr = MaybeOnDevice(now, GetSEScope(orig), /*is_fixed=*/true);
-    Var var = v.defined() ? v : Var(String("x"), Type());
+    Var var = v.defined() ? v : Var::GenSym();
     bool not_included = include_set_ && include_set_->find(orig) == include_set_->end();
     if (!v.defined() && not_included) {
       return annotated_expr;
diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc
index b613a03bfc5cd..44971c0bcee98 100644
--- a/src/runtime/vm/executable.cc
+++ b/src/runtime/vm/executable.cc
@@ -178,7 +178,7 @@ std::string Executable::GetConstants() const {
   for (size_t i = 0; i < constants.size(); ++i) {
     const auto& constant = constants[i];
     auto ndarray = Downcast<NDArray>(constant);
-    oss << "VM Constant[" << i << "]: has shape " << ShapeString(ndarray.Shape(), ndarray->dtype)
+    oss << "VM Const[" << i << "]: has shape " << ShapeString(ndarray.Shape(), ndarray->dtype)
         << " on device index " << const_device_indexes[i] << std::endl;
   }
   return oss.str();
diff --git a/src/target/se_scope.cc b/src/target/se_scope.cc
index ec6b30eab1acb..8e6c6fe7f2a26 100644
--- a/src/target/se_scope.cc
+++ b/src/target/se_scope.cc
@@ -62,10 +62,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
         p->stream << "memory_scope='" << node->memory_scope << "'";
       }
     }
-#if TVM_LOG_DEBUG
-    // We rely on object identity of SEScopes, so include the object address to help debugging.
-    p->stream << ", id=" << reinterpret_cast<uint64_t>(ref.get());
-#endif
     p->stream << ")";
   });
diff --git a/src/target/target.cc b/src/target/target.cc
index 792884061db6c..a5c493a582ab2 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -566,10 +566,6 @@ String TargetNode::ToDebugString() const {
   if (host.defined()) {
     os << ", host=" << GetHost().value()->ToDebugString();
   }
-#if TVM_LOG_DEBUG
-  // We depend on pointer equality so include that in the debug representation.
-  os << ", id=" << reinterpret_cast<uint64_t>(this);
-#endif
   os << ")";
   return os.str();
 }
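
Illustration (editorial note, not part of the patch): with the elision in
device_planner.cc removed, every let-bound value keeps its "on_device"
wrapper even when its device agrees with the enclosing let. Schematically,
in Relay-like pseudo-text (the variable names, the scope S, and the exact
attribute spelling are illustrative, not the printer's output):

    // Before this patch the annotation was elided when the devices agreed:
    let %x = add(%a, %b);
    @f(%x)

    // After this patch the binding is always annotated:
    let %x = on_device(add(%a, %b), se_scope=S, is_fixed=True);
    @f(%x)

    // So when a pass such as DeadCodeElimination inlines the single use,
    // the device information travels with the value instead of silently
    // picking up a different lexical device context:
    @f(on_device(add(%a, %b), se_scope=S, is_fixed=True))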
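
The new Var::GenSym helper replaces the call sites that minted fresh
variables as Var("x", ...), all of which printed with the same name. A
minimal usage sketch (editorial; the TensorType argument and the exact
generated names are assumptions, since names come from a process-wide
counter):

    #include <tvm/relay/expr.h>
    #include <tvm/relay/type.h>

    using namespace tvm;
    using namespace tvm::relay;

    void GenSymDemo() {
      Var a = Var::GenSym();                                     // e.g. "x_0", no annotation
      Var b = Var::GenSym(TensorType({}, DataType::Float(32)));  // e.g. "x_1", scalar f32
      ICHECK_NE(a->name_hint(), b->name_hint());  // successive calls never collide
    }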
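
The mechanical OnDevice -> MaybeOnDevice substitutions in memory_alloc.cc all
follow one pattern: build the value bare, then let the helper decide whether
a wrapper actually adds information. A hedged reading of the contract being
relied on (the authoritative rules live in src/relay/op/annotation/
annotation.cc; the commented conditions below are a sketch, not a quote):

    // MaybeOnDevice wraps 'value' in an "on_device" CallNode for 'se_scope'
    // unless the wrapper would be redundant, e.g.:
    //  - a fully-unconstrained se_scope carries no information to record
    //    (the ICHECKs at the call sites above rule that case out, so there
    //    the wrapper really is added);
    //  - expressions whose device is recoverable anyway, such as variables
    //    (known from their binding site), can be returned unchanged.
    Expr bound = MaybeOnDevice(value, se_scope, /*is_fixed=*/true);
    scope->Push(var, bound);  // the shape used throughout this patch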