Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
change for the new interface of InputGraph attribute.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Apr 9, 2018
1 parent 1237477 commit 7b0bcd1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/executor/attach_op_execs_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class FComputeExecutor : public StorageFallbackOpExecutor {
}

bool HasSubgraph() const override {
return attrs_.g != nullptr;
return !attrs_.subgraphs.empty();
}

explicit FComputeExecutor(const NodeAttrs& attrs, FCompute fcompute,
Expand Down Expand Up @@ -217,7 +217,7 @@ class FComputeExExecutor : public OpExecutor {
void Setup() override {}

bool HasSubgraph() const override {
return attrs_.g != nullptr;
return !attrs_.subgraphs.empty();
}

ExecType exec_type() const override {
Expand Down
15 changes: 9 additions & 6 deletions src/operator/nn/control_flow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ static void ForeachComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK(attrs.g != nullptr);
nnvm::Graph &g = *attrs.g;
CHECK_EQ(attrs.subgraphs.size(), 1U);
nnvm::Graph &g = *attrs.subgraphs[0];
const auto& idx = g.indexed_graph();

// If this is inference, we only need the forward memory plan.
Expand Down Expand Up @@ -290,7 +290,8 @@ static bool ForeachShape(const nnvm::NodeAttrs& attrs,
nnvm::ShapeVector shape_inputs = *in_shape;
// foreach iterates over the first input NDArray over the first dimension.
shape_inputs[0] = TShape(in_shape->at(0).begin() + 1, in_shape->at(0).end());
auto g = attrs.g;
CHECK_EQ(attrs.subgraphs.size(), 1U);
auto g = attrs.subgraphs[0];
CHECK(g);
const auto& idx = g->indexed_graph();
CHECK_EQ(idx.input_nodes().size(), in_shape->size());
Expand Down Expand Up @@ -322,7 +323,8 @@ static bool ForeachShape(const nnvm::NodeAttrs& attrs,
static bool ForeachType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type, std::vector<int> *out_type) {
nnvm::DTypeVector dtype_inputs = *in_type;
auto g = attrs.g;
CHECK_EQ(attrs.subgraphs.size(), 1U);
auto g = attrs.subgraphs[0];
CHECK(g);
const auto& idx = g->indexed_graph();
CHECK_EQ(idx.input_nodes().size(), in_type->size());
Expand All @@ -342,7 +344,8 @@ static bool ForeachStorageType(const nnvm::NodeAttrs& attrs,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
auto g = attrs.g;
CHECK_EQ(attrs.subgraphs.size(), 1U);
auto g = attrs.subgraphs[0];
CHECK(g);
const auto& idx = g->indexed_graph();
CHECK_EQ(idx.input_nodes().size(), in_attrs->size());
Expand Down Expand Up @@ -379,7 +382,7 @@ NNVM_REGISTER_OP(_foreach)
})
.set_attr<nnvm::FInputGraph>("FInputGraph",
[](const NodeAttrs& attrs) {
return 0;
return std::vector<uint32_t>{0};
})
.set_attr<nnvm::FInferShape>("FInferShape", ForeachShape)
.set_attr<nnvm::FInferType>("FInferType", ForeachType)
Expand Down

0 comments on commit 7b0bcd1

Please sign in to comment.