From c507efa1aa640391e957cddcf1f4f758876642f3 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 11 Apr 2018 00:12:05 -0700 Subject: [PATCH] [NODE] GetRef helper (#31) --- relay/include/relay/node.h | 20 ++++++++++++++++++++ relay/src/relay/evaluator.cc | 2 ++ 2 files changed, 22 insertions(+) diff --git a/relay/include/relay/node.h b/relay/include/relay/node.h index c7baa5a0e00da..2f5aec36c7771 100644 --- a/relay/include/relay/node.h +++ b/relay/include/relay/node.h @@ -862,6 +862,26 @@ class FnValueNode : public ValueNode { RELAY_DEFINE_VALUE(FnValue, FnValueNode); + +/*! + * \brief Get a reference type from a Node ptr type + * + * It is always important to get a reference type + * if we want to return a value as reference or keep + * the node alive beyond the scope of the function. + * + * \param ptr The node pointer + * \tparam RefType The reference type + * \tparam NodeType The node type + * \return The corresponding RefType + */ +template +RefType GetRef(const NodeType* ptr) { + static_assert(std::is_same::value, + "Can only cast to the ref of same container type"); + return RefType(const_cast(ptr)->shared_from_this()); +} + } // namespace relay } // namespace nnvm diff --git a/relay/src/relay/evaluator.cc b/relay/src/relay/evaluator.cc index e7c864642e50a..673985b688cb5 100644 --- a/relay/src/relay/evaluator.cc +++ b/relay/src/relay/evaluator.cc @@ -21,6 +21,8 @@ Evaluator::Evaluator() : env() {} Value Evaluator::Eval(const Expr& expr) { return this->operator()(expr); } Value Evaluator::VisitExpr_(const LocalIdNode* local_node) { + // will error: GetRef(local_node); + LocalId local_id = GetRef(local_node); // We should instead compile this to a stack machine with statically resolved // offsets. // LocalId local = LocalId(local_node->GetNodeRef().node_);