Skip to content

Commit

Permalink
[checkpoint] Always on_device let-bound values
Browse files Browse the repository at this point in the history
Realized both dead_code.cc and fold_constant.cc would
happily move values into a different lexical virtual
device context since device_planner.cc was being
'clever' and eliding on_devices for let-bound values
when there's no change. Fixed so that every let-bound
value has an on_device.

Will be much better after apache/tvm-rfcs#45
  • Loading branch information
mbs-octoml committed Dec 2, 2021
1 parent 93e57f6 commit edf801e
Show file tree
Hide file tree
Showing 15 changed files with 136 additions and 96 deletions.
3 changes: 3 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ class Var : public Expr {
*/
TVM_DLL Var(Id vid, Type type_annotation, Span span = Span());


static Var GenSym(Type type_annotation = {}, Span span = {});

TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
};
Expand Down
6 changes: 0 additions & 6 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,6 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
// Prints a reference to a global variable as "@" followed by its name hint.
// When TVM_LOG_DEBUG is enabled, the checked type (if defined) and an id are
// additionally appended inside /* ... */ comments.
Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
Doc doc;
doc << "@" << op->name_hint;
#if TVM_LOG_DEBUG
if (op->checked_type_.defined()) {
doc << " /* type=" << PrintType(op->checked_type_, /*meta=*/false) << " */";
}
// The printed id is the address of the node itself.
doc << " /* id=" << reinterpret_cast<uint64_t>(op) << " */";
#endif
return doc;
}

Expand Down
56 changes: 39 additions & 17 deletions src/printer/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,49 +36,71 @@ static const char* kSemVer = "0.0.5";
Doc TextPrinter::PrintMod(const IRModule& mod) {
Doc doc;
int counter = 0;

// We'll print in alphabetical order to make a/b diffs easier to work with.

// type definitions
std::vector<GlobalTypeVar> tyvars;
for (const auto& kv : mod->type_definitions) {
tyvars.emplace_back(kv.first);
}
std::sort(tyvars.begin(), tyvars.end(),
[](const GlobalTypeVar& left, const GlobalTypeVar& right) {
return left->name_hint < right->name_hint;
});
for (const auto& tyvar : tyvars) {
if (counter++ != 0) {
doc << Doc::NewLine();
}
doc << relay_text_printer_.Print(kv.second);
doc << relay_text_printer_.Print(mod->type_definitions[tyvar]);
doc << Doc::NewLine();
}

// functions
std::vector<GlobalVar> vars;
for (const auto& kv : mod->functions) {
if (kv.second.as<relay::FunctionNode>()) {
vars.emplace_back(kv.first);
}
std::sort(vars.begin(), vars.end(), [](const GlobalVar& left, const GlobalVar& right) {
return left->name_hint < right->name_hint;
});
for (const auto& var : vars) {
const BaseFunc& base_func = mod->functions[var];
if (base_func.as<relay::FunctionNode>()) {
relay_text_printer_.dg_ =
relay::DependencyGraph::Create(&relay_text_printer_.arena_, kv.second);
relay::DependencyGraph::Create(&relay_text_printer_.arena_, base_func);
}
if (counter++ != 0) {
doc << Doc::NewLine();
}
if (kv.second.as<relay::FunctionNode>()) {
if (base_func.as<relay::FunctionNode>()) {
std::ostringstream os;
os << "def @" << kv.first->name_hint;
#if TVM_LOG_DEBUG
os << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second);
} else if (kv.second.as<tir::PrimFuncNode>()) {
doc << "@" << kv.first->name_hint;
#if TVM_LOG_DEBUG
doc << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
os << "def @" << var->name_hint;
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), base_func);
} else if (base_func.as<tir::PrimFuncNode>()) {
doc << "@" << var->name_hint;
doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(base_func));
}
doc << Doc::NewLine();
}

#if TVM_LOG_DEBUG
// attributes
// TODO(mbs): Make this official, including support from parser.
if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
doc << "attributes {" << Doc::NewLine();
std::vector<String> keys;
for (const auto& kv : mod->attrs->dict) {
doc << " '" << kv.first << "' = " << PrettyPrint(kv.second) << Doc::NewLine();
keys.emplace_back(kv.first);
}
std::sort(keys.begin(), keys.end());
doc << "attributes {" << Doc::NewLine();
for (const auto& key : keys) {
doc << " '" << key << "' = " << PrettyPrint(mod->attrs->dict[key]) << Doc::NewLine();
}
doc << "}" << Doc::NewLine();
}
#endif

return doc;
}

Expand Down
20 changes: 13 additions & 7 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
// Returns the current value of registers_num_ as the new register's index, then increments it.
size_t NewRegister() { return registers_num_++; }

inline void Emit(const Instruction& instr) {
VLOG(2) << "VMCompiler::Emit: instr=" << instr;
size_t instruction_index = instructions_.size();
VLOG(2) << "instruction[" << instruction_index << "] = " << instr;
ICHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op;
switch (instr.op) {
case Opcode::AllocADT:
Expand Down Expand Up @@ -336,10 +337,9 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
* in emitted code. Note that the host device is always at index 0.
*/
Index GetDeviceIndex(const SEScope& se_scope) {
VLOG(2) << "getting device index for " << se_scope;
ICHECK(!se_scope->IsFullyUnconstrained());
auto itr = std::find(context_->se_scopes_.begin(), context_->se_scopes_.end(), se_scope);
if (itr != context_->se_scopes_.end()) {
VLOG(2) << "reusing existing scope";
return std::distance(context_->se_scopes_.begin(), itr);
}

Expand Down Expand Up @@ -367,7 +367,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {

ICHECK(se_scope != host_se_scope_);
Index index = context_->se_scopes_.size();
VLOG(2) << "adding new scope";
VLOG(2) << "se_scope[" << index << "] = " << se_scope;
context_->se_scopes_.push_back(se_scope);

return index;
Expand All @@ -378,11 +378,13 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
void VisitExpr_(const ConstantNode* const_node) final {
// Check the shape is valid
NDArray data = const_node->data;
size_t konst_idx = context_->constants.size();
size_t const_index = context_->constants.size();
auto con = GetRef<Constant>(const_node);
context_->const_device_indexes.push_back(GetDeviceIndex(GetSEScope(con)));
Index device_index = GetDeviceIndex(GetSEScope(con));
VLOG(2) << "constant[" << const_index << "] on device[" << device_index << "]";
context_->const_device_indexes.push_back(device_index);
context_->constants.push_back(const_node->data);
Emit(Instruction::LoadConst(konst_idx, NewRegister()));
Emit(Instruction::LoadConst(const_index, NewRegister()));
}

void VisitExpr_(const VarNode* var_node) final {
Expand Down Expand Up @@ -872,6 +874,7 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host)

// The first device is always for the host.
CHECK(context_.se_scopes_.empty());
VLOG(2) << "se_scope[0] = " << config_->host_se_scope << " (host)";
context_.se_scopes_.push_back(config_->host_se_scope);

// Run the optimizations necessary to target the VM.
Expand Down Expand Up @@ -1085,9 +1088,12 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
// let-bound functions.
pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));

// #####################################
#if 1
// Now that we have PrimFuncs, flow and solve SEScope constraints again to account for
// any memory scopes which lowering has settled on.
pass_seqs.push_back(transform::PlanDevices(config_));
#endif

// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
Expand Down
7 changes: 7 additions & 0 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ Var::Var(Id vid, Type type_annotation, Span span) {
data_ = std::move(n);
}

/*!
 * \brief Creates a variable whose name hint is "x_<n>", where <n> is the value of a
 * function-local static counter that is bumped on every call.
 * \param type_annotation The type annotation to give the new variable.
 * \param span The source span to give the new variable.
 * \return The newly constructed variable.
 */
/* static */ Var Var::GenSym(Type type_annotation, Span span) {
  // Shared by all callers, so successive calls yield x_0, x_1, x_2, ...
  // NOTE(review): plain static with no synchronization visible here -- confirm callers
  // never invoke this concurrently.
  static size_t counter = 0;
  std::ostringstream name;
  name << "x_" << counter;
  ++counter;
  return Var(name.str(), std::move(type_annotation), std::move(span));
}

Var WithFields(Var var, Optional<Id> opt_vid, Optional<Type> opt_type_annotation,
Optional<Span> opt_span) {
Id vid = opt_vid.value_or(var->vid);
Expand Down
7 changes: 4 additions & 3 deletions src/relay/transforms/device_aware_visitors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ namespace transform {

LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional<IRModule>& maybe_mod) {
if (maybe_mod) {
for (const auto& pair : maybe_mod.value()->functions) {
if (const auto* function_node = pair.second.as<FunctionNode>()) {
for (const auto& kv : maybe_mod.value()->functions) {
if (const auto* function_node = kv.second.as<FunctionNode>()) {
SEScope se_scope = GetFunctionResultSEScope(function_node);
if (!se_scope->IsFullyUnconstrained()) {
global_var_se_scopes_.emplace(pair.first, se_scope);
VLOG(2) << "global '" << kv.first->name_hint << "' has scope " << se_scope;
global_var_se_scopes_.emplace(kv.first, se_scope);
}
}
}
Expand Down
15 changes: 13 additions & 2 deletions src/relay/transforms/device_aware_visitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,19 @@ class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(cons
PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i));
}
// Entering scope of function body.
PushSEScope(GetFunctionResultSEScope(function_node));
SEScope se_scope = GetFunctionResultSEScope(function_node);
VLOG(2) << "entering " << se_scope << " for function:" << std::endl
<< PrettyPrint(GetRef<Function>(function_node));
PushSEScope(se_scope);
EnterFunctionBody();

DeviceAwareVisitExpr_(function_node);

// Leaving scope of function body.
ExitFunctionBody();
PopSEScope();
VLOG(2) << "leaving " << se_scope << " for function:" << std::endl
<< PrettyPrint(GetRef<Function>(function_node));
// Function parameters go out of scope.
for (size_t i = 0; i < function_node->params.size(); ++i) {
PopBoundVar(function_node->params[i]);
Expand All @@ -168,7 +173,9 @@ class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(cons
while (const auto* inner_let_node = expr.as<LetNode>()) {
// Let-bound var (in pre visited version) goes into scope.
// (We'll just assume this is a letrec.)
PushBoundVar(inner_let_node->var, GetSEScope(inner_let_node->value));
SEScope se_scope = GetSEScope(inner_let_node->value);
VLOG(2) << "var '" << inner_let_node->var->name_hint() << "' has scope " << se_scope;
PushBoundVar(inner_let_node->var, se_scope);
PreVisitLetBinding_(inner_let_node->var, inner_let_node->value);
bindings.emplace_back(inner_let_node);
expr = inner_let_node->body;
Expand All @@ -189,10 +196,14 @@ class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(cons
OnDeviceProps props = GetOnDeviceProps(call_node);
if (props.body.defined() && props.is_fixed) {
// Entering lexical scope of fixed "on_device" call.
VLOG(2) << "entering " << props.se_scope << " for on_device:" << std::endl
<< PrettyPrint(GetRef<Call>(call_node));
PushSEScope(props.se_scope);
VisitExpr(props.body);
// Leaving lexical scope of "on_device" call.
PopSEScope();
VLOG(2) << "leaving " << props.se_scope << " for on_device:" << std::endl
<< PrettyPrint(GetRef<Call>(call_node));
} else {
DeviceAwareVisitExpr_(call_node);
}
Expand Down
20 changes: 12 additions & 8 deletions src/relay/transforms/device_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,15 @@
* the function's parameters and the result.
* - Additional "device_copy" CallNodes where a copy is required in order to respect the
* intent of the original "on_device" CallNodes.
* - Additional "on_device" CallNodes where the device type of an expression does not match
* that of the lexically enclosing "on_device" CallNode or function attribute. In practice
* - Additional "on_device" CallNodes where the device type of an expression is not trivially
* implied by the lexically enclosing "on_device" CallNode or function attribute. In practice
* this means "on_device" CallNodes may appear in two places:
* - On a let-bound expression if its device differs from the overall let expression.
* - On let-bound expressions. It is tempting to elide the "on_device" if the let-bound value
* has the same device as the overall let expression. However this would mean passes which
* inline let-bound values, such as FoldConstant and DeadCodeElimination, would need to use
* a DeviceAware visitor which in turn requires the expression to be in ANF to avoid
* deep recursion. To minimize disruption we always include the "on_device" so that it
* can follow the inline.
* - On a call argument if its device differs from the call result. In particular, the
* argument to a "device_copy" call will always be wrapped in an "on_device". (That may
* seem pedantic but simplifies downstream handling.)
Expand Down Expand Up @@ -1033,11 +1038,11 @@ class DeviceCapturer : public ExprMutator {
// We have a device transition which needs to be handled.
break;
}
// The let-bound value can be on a different device than the overall let. However if those
// devices don't agree wrap the let-bound value in an "on_device" to help downstream
// transforms track devices lexically.
// The let-bound value can be on a different device than the overall let.
// By using the fully-unconstrained SEScope for the 'lexical' scope we'll force the let-bound
// value to *always* be wrapped by an "on_device" (see introductory comment for motivation.)
Expr value =
VisitChild(/*lexical_se_scope=*/let_se_scope,
VisitChild(/*lexical_se_scope=*/SEScope::FullyUnconstrained(),
/*expected_se_scope=*/GetSEScope(inner_let_node->var),
/*child_se_scope=*/GetSEScope(inner_let_node->value), inner_let_node->value);
bindings.emplace_back(inner_let_node->var, value, inner_let_node->span);
Expand Down Expand Up @@ -1141,7 +1146,6 @@ class DeviceCapturer : public ExprMutator {
*/
Expr VisitChild(const SEScope& lexical_se_scope, const SEScope& expected_se_scope,
const SEScope& child_se_scope, const Expr& child) {
ICHECK(!lexical_se_scope->IsFullyUnconstrained());
ICHECK(!expected_se_scope->IsFullyUnconstrained());
if (child->IsInstance<OpNode>() || child->IsInstance<ConstructorNode>()) {
// Primitive operators and constructors don't need to be rewritten and can have a
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/let_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class LetList {
*
* \return a Var that holds the inserted expr.
*/
Var Push(Expr expr, Type ty) { return Push(Var("x", ty), expr); }
Var Push(Expr expr, Type ty) { return Push(Var::GenSym(ty), expr); }

/*!
* \brief insert a binding.
Expand Down
Loading

0 comments on commit edf801e

Please sign in to comment.