Skip to content

Commit

Permalink
[NODE] GetRef helper (apache#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and jroesch committed Aug 16, 2018
1 parent 4481e48 commit c507efa
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
20 changes: 20 additions & 0 deletions relay/include/relay/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,26 @@ class FnValueNode : public ValueNode {

RELAY_DEFINE_VALUE(FnValue, FnValueNode);


/*!
* \brief Get a reference type from a Node ptr type
*
* It is always important to get a reference type
* if we want to return a value as reference or keep
* the node alive beyond the scope of the function.
*
* \param ptr The node pointer
* \tparam RefType The reference type
* \tparam NodeType The node type
* \return The corresponding RefType
*/
template<typename RefType, typename NodeType>
RefType GetRef(const NodeType* ptr) {
static_assert(std::is_same<typename RefType::ContainerType, NodeType>::value,
"Can only cast to the ref of same container type");
return RefType(const_cast<NodeType*>(ptr)->shared_from_this());
}

} // namespace relay
} // namespace nnvm

Expand Down
2 changes: 2 additions & 0 deletions relay/src/relay/evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ Evaluator::Evaluator() : env() {}
Value Evaluator::Eval(const Expr& expr) { return this->operator()(expr); }

Value Evaluator::VisitExpr_(const LocalIdNode* local_node) {
// will error: GetRef<GlobalId>(local_node);
LocalId local_id = GetRef<LocalId>(local_node);
// We should instead compile this to a stack machine with statically resolved
// offsets.
// LocalId local = LocalId(local_node->GetNodeRef().node_);
Expand Down

0 comments on commit c507efa

Please sign in to comment.