Skip to content

Commit

Permalink
[Relay] PlanDevices can run after LowerTE
Browse files Browse the repository at this point in the history
  • Loading branch information
mbs-octoml committed Dec 1, 2021
1 parent a76f648 commit 93e57f6
Show file tree
Hide file tree
Showing 11 changed files with 269 additions and 116 deletions.
5 changes: 5 additions & 0 deletions include/tvm/target/se_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,11 @@ class SEScope : public ObjectRef {
return SEScope(device_type, /*virtual_device_id=*/0, std::move(target));
}

/*!
* \brief Returns the \p SEScope for \p memory_scope alone.
*
* Only the memory scope is supplied by the caller; the remaining constructor arguments are
* passed as \p kInvalidDeviceType, -1 and an empty brace-initializer.
*
* \param memory_scope The memory scope to carry. Taken by value and moved into the result.
* \return The \p SEScope constructed from those arguments.
*/
static SEScope ForMemoryScope(MemoryScope memory_scope) {
// NOTE(review): by analogy with the three-argument factory just above, the second argument is
// presumably the virtual device id and the third the target -- confirm against the constructor.
return SEScope(kInvalidDeviceType, -1, {}, std::move(memory_scope));
}

/*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */
TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target,
MemoryScope memory_scope) {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class VarNode : public PrimExprNode {
*/
String name_hint;
/*!
* \brief type annotaion of the variable.
* \brief type annotation of the variable.
*
* It is an optional field that provides a refined type of the variable than dtype.
*
Expand Down
5 changes: 5 additions & 0 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,11 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) {
for (const Expr& arg : op->args) {
args.push_back(Print(arg));
}
#if TVM_LOG_DEBUG
for (const Type& type_arg : op->type_args) {
args.push_back(Print(type_arg));
}
#endif
for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
args.push_back(d);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ class ConvertAddToSubtract : public MixedModeMutator {
// call_lowered op.
auto call_lowered_attrs = make_object<CallLoweredAttrs>();
call_lowered_attrs->metadata.Set("relay_attrs", call->attrs);
return CallLowered(std::move(new_global_var), call->args,
std::move(Attrs(call_lowered_attrs)), call->type_args, call->span);
ICHECK(call->type_args.empty()) << "lowered functions cannot be polymorphic";
return CallLowered(std::move(new_global_var), call->args, std::move(call_lowered_attrs),
call->span);
}
}

Expand Down
16 changes: 8 additions & 8 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
* to the TIR implementation, and attributes to attach to the call to identify it as
* a TIR call.
*/
Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Array<Type> type_args, Span span,
Target target) {
Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Span span, Target target) {
CCacheKey key = CCacheKey(func, target);
CachedFunc cfunc = compiler_->Lower(key, module_name_);
ICHECK(cfunc.defined());
Expand Down Expand Up @@ -632,8 +631,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars);
}

return CallLowered(cfunc->prim_fn_var, std::move(visited_args), Attrs(call_lowered_attrs),
type_args, std::move(span));
return CallLowered(cfunc->prim_fn_var, std::move(visited_args), std::move(call_lowered_attrs),
std::move(span));
}

std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) final {
Expand Down Expand Up @@ -733,8 +732,9 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
call_lowered_attrs->metadata.Set("relay_attrs", primitive_func->attrs);

process_fn_(func_with_metadata);
return CallLowered(call_node->op, std::move(new_args), Attrs(std::move(call_lowered_attrs)),
call_node->type_args, call_node->span);
ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic";
return CallLowered(prim_func_var, std::move(new_args), std::move(call_lowered_attrs),
call_node->span);
}

// Typical case: call to fused primitive Relay Function.
Expand All @@ -754,8 +754,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {

// Lower the primitive function for that target.
Function function = Downcast<Function>(primitive_func);
return MakeLoweredCall(function, std::move(new_args), call_node->type_args, call_node->span,
target);
ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic";
return MakeLoweredCall(function, std::move(new_args), call_node->span, target);
}

IRModule module_;
Expand Down
4 changes: 4 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,10 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
// let-bound functions.
pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));

// Now that we have PrimFuncs, flow and solve SEScope constraints again to account for
// any memory scopes which lowering has settled on.
pass_seqs.push_back(transform::PlanDevices(config_));

// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op`
Expand Down
44 changes: 24 additions & 20 deletions src/relay/op/call/call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,24 @@ bool CallLoweredRel(const Array<Type>& types, int num_inputs, const Attrs& attrs

const Op& CallLoweredOp() { return Op::Get("call_lowered"); }

Expr CallLowered(Expr func, Array<Expr> inputs, Attrs attrs, Array<Type> type_args, Span span) {
// Right now, call_lowered only supports func being a global var pointing to the lowered
// function.
ICHECK(func.as<GlobalVarNode>())
<< "Function to call should be GlobalVarNode, but got:" << std::endl
<< PrettyPrint(func);
ICHECK(attrs.as<CallLoweredAttrs>())
<< "Expected attributes to be CallLoweredAttrs, but got " << attrs->GetTypeKey();
return Call(CallLoweredOp(), {std::move(func), Tuple(std::move(inputs))}, std::move(attrs),
std::move(type_args), std::move(span));
// Builds a call to the "call_lowered" operator. The resulting call has exactly two arguments:
// the global var of the lowered function, and a Tuple wrapping \p args. The attributes node is
// wrapped in an Attrs reference, and the call is always constructed with empty type arguments.
// All parameters are taken by value and moved into the new call.
Call CallLowered(GlobalVar lowered_func, Array<Expr> args,
ObjectPtr<CallLoweredAttrs> call_lowered_attrs, Span span) {
return Call(CallLoweredOp(), {std::move(lowered_func), Tuple(std::move(args))},
Attrs(std::move(call_lowered_attrs)),
/*type_args=*/{}, std::move(span));
}

TVM_REGISTER_GLOBAL("relay.op.call_lowered")
.set_body_typed([](Expr func, Array<Expr> inputs, Attrs attrs, Array<Type> type_args,
Span span) {
const TupleNode* tuple_node = inputs.as<TupleNode>();
return CallLowered(func, tuple_node->fields, attrs, type_args, span);
.set_body_typed([](Expr lowered_func, Array<Expr> args, Attrs attrs, Span span) {
const auto* lowered_func_node = lowered_func.as<GlobalVarNode>();
ICHECK(lowered_func_node) << "Function to call should be GlobalVarNode, but got:" << std::endl
<< PrettyPrint(lowered_func);
const auto* call_lowered_attrs = attrs.as<CallLoweredAttrs>();
ICHECK(call_lowered_attrs) << "Expected attributes to be CallLoweredAttrs, but got "
<< attrs->GetTypeKey();
auto call_lowered_attrs_ptr = make_object<CallLoweredAttrs>(*call_lowered_attrs);
return CallLowered(GetRef<GlobalVar>(lowered_func_node), std::move(args),
std::move(call_lowered_attrs_ptr), std::move(span));
});

RELAY_REGISTER_OP("call_lowered")
Expand All @@ -105,19 +106,22 @@ CallLoweredProps GetCallLoweredProps(const CallNode* call_node) {
ICHECK(tuple_args) << "Expected second arg to call_lowered to be a Tuple of input arguments.";

ICHECK(call_node->attrs.defined()) << "Expecting call_lowered to have attributes.";
const auto* attrs = call_node->attrs.as<CallLoweredAttrs>();
ICHECK(attrs) << "Expected call_lowered op to have CallLoweredAttrs, but found "
<< call_node->attrs->GetTypeKey();
return CallLoweredProps{GetRef<GlobalVar>(function_node), tuple_args->fields, *attrs};
const auto* call_lowered_attrs = call_node->attrs.as<CallLoweredAttrs>();
ICHECK(call_lowered_attrs) << "Expected call_lowered op to have CallLoweredAttrs, but found "
<< call_node->attrs->GetTypeKey();
// If the call_node has type_args then they are for the polymorphic 'call_lowered' operator
// itself which expects the function type and argument type as parameters.
return {GetRef<GlobalVar>(function_node), tuple_args->fields, *call_lowered_attrs};
}
return {};
}

Call GetAnyCall(const CallNode* call_node) {
CallLoweredProps props = GetCallLoweredProps(call_node);
if (props.lowered_func.defined()) {
auto attrs = make_object<CallLoweredAttrs>(props.attrs);
return Call(std::move(props.lowered_func), props.arguments, Attrs(std::move(attrs)),
auto call_lowered_attrs = make_object<CallLoweredAttrs>(props.attrs);
return Call(std::move(props.lowered_func), std::move(props.arguments),
Attrs(std::move(call_lowered_attrs)),
/*type_args=*/{}, call_node->span);
} else {
return GetRef<Call>(call_node);
Expand Down
12 changes: 6 additions & 6 deletions src/relay/op/call/call.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ namespace relay {

/*!
* \brief Helper to construct a Relay call with the call_lowered op.
* \param func Lowered function to call with call_lowered.
* \param inputs Arguments to be passed to the function.
* \param attrs Function attributes, should be TIRCallAttrs.
* \param type_args Type arguments for the call.
* \param lowered_func Lowered function to call with call_lowered.
* \param args Arguments to be passed to the function.
* \param call_lowered_attrs Function attributes.
* \param span TVM span for propagating debugging info.
* \return The constructed call to the call_lowered op.
*/
Expr CallLowered(Expr func, Array<Expr> inputs, Attrs attrs, Array<Type> type_args, Span span);
Call CallLowered(GlobalVar lowered_func, Array<Expr> args,
ObjectPtr<CallLoweredAttrs> call_lowered_attrs, Span span);

/*!
* \brief Returns the Relay call_lowered op. Use this helper to avoid extraneous calls to
Expand All @@ -57,7 +57,7 @@ struct CallLoweredProps {
GlobalVar lowered_func;
/*! \brief Array of the arguments to call lowered_func with. */
Array<Expr> arguments;
/*! \brief Arguments from the call_lowered op. */
/*! \brief Attributes from the call_lowered op. */
CallLoweredAttrs attrs;
};

Expand Down
8 changes: 4 additions & 4 deletions src/relay/transforms/device_domains.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) {
DeviceCopyProps device_copy_props = GetDeviceCopyProps(call.get());
CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get());

if (on_device_props.body.defined()) {
if (call_lowered_props.lowered_func.defined()) {
return DomainFor(call_lowered_props.lowered_func);
} else if (on_device_props.body.defined()) {
// on_device(expr, se_scope=<t>, is_fixed=false)
// on_device : fn(<t>):?x?
//
Expand Down Expand Up @@ -286,11 +288,9 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) {
args_and_result.emplace_back(param_domain);
}
args_and_result.emplace_back(result_domain);
} else if (call_lowered_props.lowered_func.defined()) {
return DomainFor(call_lowered_props.lowered_func);
} else {
// We still need to handle the case where the function / op is not lowered
// because the device planner runs before and after lowering.
// because the device planner runs both before and after lowering.
return DomainFor(call->op);
}
auto domain = MakeHigherOrderDomain(std::move(args_and_result));
Expand Down
Loading

0 comments on commit 93e57f6

Please sign in to comment.