Skip to content

Commit

Permalink
[Relay] Switch the graph, VM and AOT executors to use the merged
Browse files Browse the repository at this point in the history
device_planner.cc from apache#9038, and finally remove DeviceMap from the
LowerTE Pass.

- We retire analysis/context_analysis.cc and
  transforms/device_annotation.cc (and their tests). That
  includes the CollectDeviceInfo, CollectDeviceAnnotationOps and
  ContextAnalysis entry points. These are all subsumed by the
  PlanDevices pass and the device aware visitors.
- The following passes now use the new 'Device Aware' visitors to
  recover the device for every Relay sub-expression:
     - backend/aot_executor_codegen.cc (AOTOnDemandAllocator)
     - backend/graph_plan_memory.cc (StorageAllocaBaseVisitor etc)
     - backend/te_compiler.cc (LowerTensorExprMutator)
     - transforms/memory_alloc.cc (DialectRewriter)
     - backend/vm/compiler.cc (VMFunctionCompiler)
- The following passes/utils must maintain the device information
  encoded by the device planner within "on_device" annotations and
  "param_device_types"/"result_device_type" function attributes:
     - backend/vm/lambda_lift.cc (LambdaLifter)
     - transforms/to_a_normal_form.cc (Fill)
     - ir/expr_functior.cc (Bind)
- Remove a lot ad-hoc 'homogeneous' vs 'hetrogeneous' conditionals
  in favor of just asking for the device. Also removed a lot of ad-doc
  encodings of the 'default' device.
- We no longer need to run device-planning twice (before and after
  lowering). Device planning is also decoupled from memory planning.
- The LowerTE Pass no longer needs an expression-to-device side table
  (which was the problem which kicked this series of PRs off in the first place).
  • Loading branch information
mbs-octoml committed Oct 2, 2021
1 parent f962220 commit 0aa6d79
Show file tree
Hide file tree
Showing 53 changed files with 1,300 additions and 3,346 deletions.
29 changes: 0 additions & 29 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,24 +211,6 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod);
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const IRModule& mod);

/*!
* \brief Collect the device mapping information of each expression.
*
* \param expr The expression.
*
* \return The device mapping.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);

/*!
* \brief Collect the device anntation operators.
*
* \param expr The expression.
*
* \return The annotated expression to device type mapping for annotation ops.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);

/*!
* \brief Finds cases that the given match expression does not catch, if any.
*
Expand Down Expand Up @@ -268,17 +250,6 @@ TVM_DLL IRModule GetCalibrateModule(IRModule mod);
*/
TVM_DLL Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& mod);

/*!
* \brief Analyze the device context of each IR node in a given relay module.
*
* \param mod The module for analysis.
* \param default_device The default device used by unassigned IR nodes.
*
* \return The mapping between an IR node and its associated device.
*/
TVM_DLL std::unordered_map<Expr, Device, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
ContextAnalysis(const IRModule& mod, const Device& default_device);

} // namespace relay
} // namespace tvm

Expand Down
14 changes: 14 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,21 @@ class Call : public Expr {
TVM_DLL Call(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
Array<Type> type_args = Array<Type>(), Span span = Span());

/*!
* \brief Returns a copy of this with given properties. A null property denotes 'no change'.
* Returns this if all properties are unchanged. Returns a modified this if this is the only
* reference to the underlying node.
*
* TODO(mbs): Extend to all node types.
*/
Call CopyWith(Optional<Expr> opt_op = Optional<Expr>(),
Optional<Array<Expr>> opt_args = Optional<Array<Expr>>(nullptr),
Optional<Attrs> opt_attrs = Optional<Attrs>(nullptr),
Optional<Array<Type>> opt_type_args = Optional<Array<Type>>(nullptr),
Optional<Span> opt_span = Optional<Span>(nullptr));

TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
};

/*!
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,12 @@ TVM_DLL Pass ToANormalForm();
/*!
* \brief ToANormalForm but on incomplete graph.
*
* \param maybe_mod optional module holding definitions for global vars in \p expr
* \param expr the graph.
*
* \return The transformed program.
*/
TVM_DLL Expr ToANormalForm(const Expr& expr);
TVM_DLL Expr ToANormalForm(const Optional<IRModule>& maybe_mod, const Expr& expr);

/*!
* \brief Turn an expression into continuation passing style(CPS).
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ struct VMFunction {
/*! \brief The size of the frame for this function */
Index register_file_size;
/*! \brief The device type of each parameter for this function. */
std::vector<Index> params_device_type;
std::vector<DLDeviceType> params_device_type;

VMFunction(const std::string& name, std::vector<std::string> params,
const std::vector<Instruction>& instructions, Index register_file_size,
const std::vector<Index> params_device_type = {})
const std::vector<DLDeviceType> params_device_type = {})
: name(name),
params(params),
instructions(instructions),
Expand Down
49 changes: 0 additions & 49 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,6 @@
from .feature import Feature


def context_analysis(mod, default_device):
"""Analyze the device context information of each IR node in a Relay
program.
Parameters
----------
mod : tvm.IRModule
The input module.
default_device : tvm.runtime.Device
The default context allocated to an IR node.
"""
return _ffi_api.ContextAnalysis(mod, default_device)


def post_order_visit(expr, fvisit):
"""Recursively visit the ir in post DFS order node,
apply fvisit. Each node is guaranteed to be visited
Expand Down Expand Up @@ -268,40 +253,6 @@ def all_dtypes(expr):
return set(_ffi_api.all_dtypes(expr))


def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
ret : Dict[tvm.relay.ir.expr, int]
A dictionary mapping tvm.relay.Expr to device type.
"""
return _ffi_api.CollectDeviceInfo(expr)


def collect_device_annotation_ops(expr):
"""Collect the device annotation ops for the given expression.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
ret : Dict[tvm.relay.Expr, int]
A dictionary mapping tvm.relay.Expr to device type where the keys are
annotation expressions.
"""
return _ffi_api.CollectDeviceAnnotationOps(expr)


def get_total_mac_number(expr):
"""
Count the number of MACs (multiply-accumulate) of a model
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@
register_broadcast_schedule("fast_erf")
# a fake on_device schedule.
# this will not be used in actual computation
# as on_device will be removed during DeviceAnnotation pass
register_injective_schedule("on_device")


Expand Down
21 changes: 0 additions & 21 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,27 +544,6 @@ def MergeCompilerRegions():
return _ffi_api.MergeCompilerRegions()


def RewriteAnnotatedOps(fallback_device):
"""Rewrite the annotated program where annotation operators, e.g.
`on_device`, mark which device an expression should be scheduled to.
This pass helps heterogeneous execution where different operators may need
to be allocated on various devices.
Parameters
----------
fallback_device : int
The fallback device type. It is also used as the default device for
operators with no annotated device.
Returns
-------
ret: tvm.transform.Pass
The registered pass that rewrites an expression with annotated
`on_device` operators.
"""
return _ffi_api.RewriteDeviceAnnotation(fallback_device)


def ToANormalForm():
"""Turn Graph Normal Form expression into A Normal Form Expression.
The scope of the root expression is the global scope.
Expand Down
Loading

0 comments on commit 0aa6d79

Please sign in to comment.