Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY] [AST] Add virtual_device as a first class field in Relay #9641

Merged
merged 5 commits into from
Dec 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ namespace tvm {

using tvm::runtime::String;

// Forward-declare SEScope to avoid circular imports.
class SEScope;

/*!
* \brief Base type of all the expressions.
* \sa Expr
Expand Down Expand Up @@ -165,6 +168,29 @@ class RelayExprNode : public BaseExprNode {
template <typename TTypeNode>
inline const TTypeNode* type_as() const;

/*!
* \brief The virtual device (SEScope) for this node (the result of device planning).
* For first-order expressions (non functions), this describes where the result of evaluating the
* expression should be stored. Note that currently, all composite first-order values (tuples,
* references, ADTs) must be stored on the same virtual device. This means that it is not possible
* to store two tuple fields on different devices, so we only need one virtual device for these
* types.
*
* For expressions that have the function type, the virtual device describes where the result of
* the call to the function or closure is stored (instead of where the function itself is stored).
* The SEScope's Target field describes how the body of the function should be compiled.
*
* \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular
* import.
*/
mutable ObjectRef virtual_device_;

/*!
* \return The virtual device (SEScope).
* If the virtual device is not defined, returns SEScope::FullyUnconstrained().
*/
SEScope virtual_device() const;

static constexpr const char* _type_key = "RelayExpr";
static constexpr const uint32_t _type_child_slots = 22;
TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode);
Expand Down
93 changes: 60 additions & 33 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <tvm/ir/op.h>
#include <tvm/target/se_scope.h>

#include <functional>
#include <stack>
Expand Down Expand Up @@ -151,10 +152,14 @@ class Tuple : public Expr {
* \param tuple The tuple to copy
* \param opt_fields The (optional) fields for the copied tuple. If none, ret_tuple->fields =
* tuple->fields.
* \param opt_span The (optional) span for the copied tuple. If none, ret_tuple->span = tuple->span.
* \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none,
* ret_tuple->virtual_device = tuple->virtual_device.
* \param opt_span The (optional) span for the copied tuple. If none,
* ret_tuple->span = tuple->span.
*/
Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields = Optional<Array<Expr>>(),
Optional<Span> opt_span = Optional<Span>(nullptr));
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*!
* \brief Local variables used in the let expression.
Expand Down Expand Up @@ -240,14 +245,17 @@ class Var : public Expr {
* \param opt_vid The (optional) vid for the copied var. If none, ret_var->vid = var->vid.
* \param opt_type_annotation The (optional) type_annotation for the copied var. If none,
* ret_var->type_annotation = var->type_annotation.
* \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none,
* ret_tuple->virtual_device = tuple->virtual_device.
* \param opt_span The (optional) span for the copied var. If none, ret_var->span = var->span.
* \return If all properties are null or the same as the property in the input var
* (i.e., opt_vid is null or opt_vid.value() == var->vid, etc.), then we return var. Otherwise,
* we return a copy of call with the different fields overwritten. (i.e., if
* opt_vid.value() != var->vid, then ret_var->vid = opt_.value()).
* \return If all properties are null or the same as the property in the input var (i.e., opt_vid is
* null or opt_vid.value() == var->vid, etc.), then we return var. Otherwise, we return a copy of
* call with the different fields overwritten. (i.e., if opt_vid.value() != var->vid, then
* ret_var->vid = opt_.value()).
*/
Var WithFields(Var var, Optional<Id> opt_vid = Optional<Id>(),
Optional<Type> opt_type_annotation = Optional<Type>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*!
Expand Down Expand Up @@ -362,16 +370,19 @@ class Call : public Expr {
* call->attrs.
* \param opt_type_args The (optional) type args for the copied call. If none,
* ret_call->type_args = call->type_args.
* \param opt_virtual_device The (optional) virtual_device for the copied call. If none,
* ret_call->virtual_device = call->virtual_device.
* \param opt_span The (optional) span for the copied call. If none, ret_call->span = call->span.
* \return If all properties are null or the same as the property in the input call
* (i.e., opt_op is null or opt_op.value() == call->op, etc.), then we return call. Otherwise, we
* return a copy of call with the different fields overwritten. (i.e., if opt_op.value() !=
* call->op, then ret_call->op = opt_op.value()).
* \return If all properties are null or the same as the property in the input call (i.e., opt_op is
* null or opt_op.value() == call->op, etc.), then we return call. Otherwise, we return a copy of
* call with the different fields overwritten. (i.e., if opt_op.value() != call->op, then
* ret_call->op = opt_op.value()).
*/
Call WithFields(Call call, Optional<Expr> opt_op = Optional<Expr>(),
Optional<Array<Expr>> opt_args = Optional<Array<Expr>>(),
Optional<Attrs> opt_attrs = Optional<Attrs>(),
Optional<Array<Type>> opt_type_args = Optional<Array<Type>>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*!
Expand Down Expand Up @@ -456,6 +467,8 @@ class Let : public Expr {
* \param opt_var The (optional) var for the copied let. If none, ret_let->op = let->op.
* \param opt_value The (optional) value for the copied let. If none, ret_let->args = let->args.
* \param opt_body The (optional) body for the copied let. If none, ret_let->attrs = let->attrs.
* \param opt_virtual_device The (optional) virtual_device for the copied let. If none,
* ret_let->virtual_device = let->virtual_device.
* \param opt_span The (optional) span for the copied let. If none, ret_let->span = let->span.
* \return If all properties are null or the same as the property in the input let (i.e., opt_var is
* null or opt_var.value() == let->var, etc.), then we return let. Otherwise, we return a copy of
Expand All @@ -465,6 +478,7 @@ class Let : public Expr {
Let WithFields(Let let, Optional<Var> opt_var = Optional<Var>(),
Optional<Expr> opt_value = Optional<Expr>(),
Optional<Expr> opt_body = Optional<Expr>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*!
Expand Down Expand Up @@ -539,17 +553,19 @@ class If : public Expr {
* ret_if->true_branch = ret_if->false_branch.
* \param opt_false_branch The (optional) false_branch
* for the copied if_expr. If none, ret_if->false_branch = if_expr->false_branch.
* \param opt_span
* The (optional) span for the copied if_expr. If none, ret_if->span = if_expr->span.
* \return If all
* properties are null or the same as the property in the input if_expr (i.e., opt_cond is null or
* opt_cond.value() == if_expr->cond, etc.), then we return if_expr. Otherwise, we return a copy of
* if_expr with the different fields overwritten. (i.e., if opt_cond.value() != if_expr->cond, then
* ret_if->cond = opt_cond.value()).
* \param opt_virtual_device The (optional) virtual_device for the copied if_expr. If none,
* ret_if->virtual_device = if_expr->virtual_device.
* \param opt_span The (optional) span for the copied if_expr. If none,
* ret_if->span = if_expr->span.
* \return If all properties are null or the same as the property in
* the input if_expr (i.e., opt_cond is null or opt_cond.value() == if_expr->cond, etc.), then we
* return if_expr. Otherwise, we return a copy of if_expr with the different fields overwritten.
* (i.e., if opt_cond.value() != if_expr->cond, then ret_if->cond = opt_cond.value()).
electriclilies marked this conversation as resolved.
Show resolved Hide resolved
*/
If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(),
Optional<Expr> opt_true_branch = Optional<Expr>(),
Optional<Expr> opt_false_branch = Optional<Expr>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*! \brief Get index-th field out of a tuple. */
Expand Down Expand Up @@ -603,8 +619,9 @@ class TupleGetItem : public Expr {
* ret_tuple_get_item->tuple = tuple_get_item->tuple.
* \param opt_index The (optional) index for the copied tuple_get_item. If none,
* ret_tuple_get_item->index = tuple_get_item->index.
* \param
* opt_span The (optional) span for the copied tuple_get_item. If none,
* \param opt_virtual_device The (optional) virtual_device for the copied tuple_get_item.
* If none, ret_tuple_get_item->virtual_device = tuple_get_item->virtual_device.
* \param opt_span The (optional) span for the copied tuple_get_item. If none,
* ret_tuple_get_item->span = tuple_get_item->span.
* \return If all properties are null or the same as the property in the input tuple_get_item
* (i.e., opt_tuple is null or opt_tuple.value() == tuple_get_item->tuple, etc.), then we return
Expand All @@ -614,6 +631,7 @@ class TupleGetItem : public Expr {
*/
TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple = Optional<Expr>(),
Optional<Integer> opt_index = Optional<Integer>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*! \brief Create a new Reference out of initial value. */
Expand Down Expand Up @@ -663,6 +681,8 @@ class RefCreate : public Expr {
* \param ref_create The ref_create to copy.
* \param opt_value The (optional) value for the copied ref_create. If none,
* ret_ref_create->value = ref_create->value.
* \param opt_virtual_device The (optional) virtual_device for the copied ref_create. If none,
* ret_ref_create->virtual_device = ref_create->virtual_device.
* \param opt_span The (optional) span for the copied ref_create. If none,
* ret_ref_create->span = ref_create->span.
* \return If all properties are null or the same as the property in the input ref_create
Expand All @@ -672,6 +692,7 @@ class RefCreate : public Expr {
* ret_ref_create->value = opt_value.value()).
*/
RefCreate WithFields(RefCreate ref_create, Optional<Expr> opt_value = Optional<Expr>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*! \brief Get value out of Reference. */
Expand Down Expand Up @@ -720,15 +741,18 @@ class RefRead : public Expr {
* \param ref_read The ref_read to copy.
* \param opt_ref The (optional) ref for the copied ref_read. If none, ret_ref_read->ref =
* ref_read->ref.
* \param opt_span
* The (optional) span for the copied ref_read. If none, ret_ref_read->span = ref_read->span.
* \return If all properties are null or the same as the property in the input ref_read
* (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), then we return ref_read.
* Otherwise, we return a copy of ref_read with the different fields overwritten.
* (i.e., if opt_ref.value() != ref_read->ref, then
* ret_ref_read->ref = opt_ref.value()).
* \param opt_virtual_device
* The (optional) virtual_device for the copied ref_read. If none, ret_ref_read->virtual_device =
* ref_read->virtual_device.
* \param opt_span The (optional) span for the copied ref_read. If none, ret_ref_read->span =
* ref_read->span.
* \return If all properties are null or the same as the property in the input
* ref_read (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), then we return
* ref_read. Otherwise, we return a copy of ref_read with the different fields overwritten. (i.e.,
* if opt_ref.value() != ref_read->ref, then ret_ref_read->ref = opt_ref.value()).
*/
RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref = Optional<Expr>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
Expand Down Expand Up @@ -784,16 +808,19 @@ class RefWrite : public Expr {
* ret_ref_write->ref = ref_write->ref.
* \param opt_value The (optional) value for the copied ref_write. If none,
* ret_ref_write->value = ref_write->value.
* \param opt_span
* The (optional) span for the copied ref_write. If none, ret_ref_write->span = ref_write->span.
* \return If all properties are null or the same as the property in the input ref_write
* (i.e., opt_ref is null or opt_ref.value() == ref_write->ref, etc.), then we return ref_write.
* Otherwise, we return a copy of ref_write with the different fields overwritten.
* (i.e., if ref_write.value() != ref_write->ref, then
* ret_ref_write->ref = opt_ref.value()).
* \param opt_virtual_device
* The (optional) virtual_device for the copied ref_write. If none, ret_ref_write->virtual_device =
* ref_write->virtual_device.
* \param opt_span The (optional) span for the copied ref_write. If none, ret_ref_write->span =
* ref_write->span.
* \return If all properties are null or the same as the property in the input ref_write (i.e.,
* opt_ref is null or opt_ref.value() == ref_write->ref, etc.), then we return ref_write. Otherwise,
* we return a copy of ref_write with the different fields overwritten. (i.e., if ref_write.value()
* != ref_write->ref, then ret_ref_write->ref = opt_ref.value()).
*/
RefWrite WithFields(RefWrite ref_write, Optional<Expr> opt_ref = Optional<Expr>(),
Optional<Expr> opt_value = Optional<Expr>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*!
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/relay/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ class Function : public BaseFunc {
* \param opt_attrs
* The (optional) attributes for the copied function. If none,
* ret_function->attrs = function->attrs.
* \param opt_virtual_device The (optional) virtual_device for the copied function. If none,
* ret_function->virtual_device = function->virtual_device.
* \param opt_span The (optional) span for the copied function. If none,
* ret_function->span = function->span.
* \return If all properties are null or the same as the property in the input function
Expand All @@ -146,6 +148,7 @@ Function WithFields(Function function, Optional<Array<Var>> opt_params = Optiona
Optional<Type> opt_ret_type = Optional<Type>(),
Optional<Array<TypeVar>> opt_ty_params = Optional<Array<TypeVar>>(),
Optional<DictAttrs> opt_attrs = Optional<DictAttrs>(),
Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
Optional<Span> opt_span = Optional<Span>());

/*
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,4 @@ enumn = "^0.1"
[build-dependencies]
bindgen = { version="0.57", default-features = false, features = ["runtime"] }
anyhow = "^1.0"
tvm-build = "0.2.1"
tvm-build = "0.2.4"
2 changes: 2 additions & 0 deletions rust/tvm/src/ir/relay/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ pub mod attrs;
pub struct ExprNode {
pub base: BaseExprNode,
pub checked_type: Type,
pub virtual_device: ObjectRef,
}

impl ExprNode {
pub fn base<T: IsObject>(span: Span) -> ExprNode {
ExprNode {
base: BaseExprNode::base::<T>(span.clone()),
checked_type: Type::null(),
virtual_device: ObjectRef::null(),
}
}
}
Expand Down
Loading