This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-432] Add Foreach #10451

Closed
wants to merge 135 commits

Conversation

zheng-da
Contributor

@zheng-da zheng-da commented Apr 7, 2018

Description

This PR adds a control flow operator: foreach. It takes a Python function as input and runs the function over the elements in the input array. foreach is similar to scan in TensorFlow.

This PR is part of the proposal of adding a set of control flow operators to MXNet.
https://cwiki.apache.org/confluence/display/MXNET/Optimize+dynamic+neural+network+models+with+control+flow+operators
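
As a rough illustration only (the call path mx.nd.contrib.foreach and the body's return convention below are assumptions based on later revisions discussed in this review, not a final API), usage would look roughly like:

import mxnet as mx

# Hypothetical body: gets one slice of `data` (along dimension 0) plus the
# carried states, and returns (per-step output, new states).
def step(data, states):
    out = data + states[0]
    return out, [out]

data = mx.nd.arange(6).reshape((3, 2))
init_states = [mx.nd.zeros((2,))]

# foreach runs `step` over data[0], data[1], ..., stacks the per-step outputs
# along dimension 0, and also returns the final states (similar to tf.scan).
outs, final_states = mx.nd.contrib.foreach(step, data, init_states)
print(outs.shape)  # (3, 2)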

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

@zheng-da zheng-da requested a review from szha as a code owner April 7, 2018 01:19
@marcoabreu
Contributor

Very interesting! Out of curiosity, is this operator going to be entirely parallelized, since it basically splits the graph into multiple subgraphs, or what is the approach here?

@zheng-da
Contributor Author

zheng-da commented Apr 7, 2018

We could potentially parallelize across iterations, but most likely there are dependencies between iterations, so parallelizing across iterations may not be very effective.

Contributor

@reminisce reminisce left a comment

Very interesting work! I have some comments and questions. Thanks.

"the number of output states (%d) should be the same as input states (%d)" \
% (len(sym_out[1]), len(init_states))

if (isinstance(sym_out[0], list)):
Contributor

No parentheses needed. They would result in a coding style error in PyCharm.

@@ -91,3 +98,99 @@ def rand_zipfian(true_classes, num_sampled, range_max):
expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range
expected_count_sampled = expected_prob_sampled * num_sampled
return sampled_classes, expected_count_true, expected_count_sampled

def _get_graph_inputs(subg, name, prefix):
Contributor

Where is prefix used?

Contributor Author

The prefix is used for pruning. That part hasn't been implemented yet; I'll probably implement it in the next PR, since this PR is already very large.

syms.append(s)
return syms

def foreach(func, input, init_states, back_prop=False, name="foreach"):
Contributor

input shadows a Python built-in. Does it make sense to call it data, which can be both singular and plural?

for in_name in g.list_inputs():
assert in_name in gin_names, "The input variable %s can't be found in graph inputs: %s" \
% (in_name, str(gin_names))
if (in_name in state_names):
Contributor

No parentheses.

if (in_name in state_names):
ordered_ins.append(states_map[in_name])
in_state_locs.append(len(ordered_ins) - 1)
elif (in_name in data_names):
Contributor

Same here. No parentheses.

.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", ForeachComputeExCPU)
.set_attr<std::string>("key_var_num_args", "num_args")
.add_argument("fn", "Symbol", "Input graph.")
.add_argument("inputs", "NDArray-or-Symbol[]",
Contributor

It's called data by default. Does it make sense to keep the naming aligned?

this->params = params;
}

void Forward(std::vector<NDArray> cinputs,
Contributor

Why copy std::vector<NDArray>?

Contributor Author

I'm not sure about this part. I'll figure it out tomorrow.

})
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"fn", "data1", "data2"};
Contributor

Should it be generated by param.num_args?

// in, state0, state1, ...
// We need to reorder them in the same order as the input nodes of the subgraph.
template<typename T>
static std::vector<T> ReorderInputs(const std::vector<T> &in, const nnvm::IndexedGraph& idx) {
Contributor

Where is this used?

}
};

void ForeachState::Forward(std::vector<NDArray> cinputs,
Contributor

Why copy std::vector<NDArray>?

shape_inputs[loc] = TShape(in_shape->at(loc).begin() + 1, in_shape->at(loc).end());
}
CHECK_EQ(attrs.subgraphs.size(), 1U);
auto g = std::make_shared<nnvm::Graph>();
Contributor

Seems shared_ptr<Graph> is not necessary here. Graph g would suffice, right?

const auto& idx = g->indexed_graph();
CHECK_EQ(idx.input_nodes().size(), in_shape->size());
CHECK_EQ(idx.outputs().size(), out_shape->size());
imperative::CheckAndInferShape(g.get(), std::move(shape_inputs), true);
Contributor

This may return false, and the return value should be saved and serve as the return value of ForeachShape.

auto eid = idx.entry_id(input_nids[i], 0);
// If the input shape is none, we should update them.
if ((*in_shape)[i].ndim() == 0 || (*in_shape)[i].Size() == 0)
(*in_shape)[i] = shapes[eid];
Contributor

To be more correct, use SHAPE_ASSIGN_CHECK. Same for the other places of assigning shapes to in_shape and out_shape.

const auto& idx = g->indexed_graph();
CHECK_EQ(idx.input_nodes().size(), in_type->size());
CHECK_EQ(idx.outputs().size(), out_type->size());
imperative::CheckAndInferType(g.get(), std::move(dtype_inputs), true);
Contributor

Save return value and return in the end.

CHECK_EQ(input_nids.size(), in_type->size());
for (size_t i = 0; i < in_type->size(); i++) {
auto eid = idx.entry_id(input_nids[i], 0);
(*in_type)[i] = dtypes[eid];
Contributor

TYPE_ASSIGN_CHECK.

CHECK_EQ(idx.outputs().size(), out_attrs->size());
exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
StorageTypeVector storage_type_inputs = *in_attrs;
imperative::CheckAndInferStorageType(g.get(), std::move(dev_masks),
Contributor

Save return value and return in the end.

CHECK_EQ(input_nids.size(), in_attrs->size());
for (size_t i = 0; i < in_attrs->size(); i++) {
auto eid = idx.entry_id(input_nids[i], 0);
(*in_attrs)[i] = stypes[eid];
Contributor

STORAGE_TYPE_ASSIGN_CHECK.

@zheng-da zheng-da changed the title [WIP] Add Foreach [MXNET-432] Add Foreach May 18, 2018
@zheng-da
Contributor Author

@piiswrong @eric-haibin-lin @reminisce @tqchen Could you please review this PR?

const nnvm::NodeAttrs& attrs,
const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs);
static OpStatePtr Invoke(const Context& default_ctx,
Contributor

use Imperative::Get()

@@ -96,3 +96,18 @@ def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
expected_count_sampled = expected_prob_sampled * num_sampled
return sampled_classes, expected_count_true, expected_count_sampled
# pylint: enable=line-too-long

def foreach(func, input, init_states, back_prop=False, name="foreach"):
Contributor

back_prop -> Imperative::Get()->is_recording()
add OpContext::is_record at backend

ele = input[i]
outs, states = func(ele, states)
outs = _as_list(outs)
if (i == 0):
Contributor

outputs.append(outs)
...

outputs = zip(*outputs)

[(a, b, c), (a2, b2, c2), ...] -> [(a, a2, ...), (b, b2, ...), ...]
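
A tiny plain-Python sketch of the suggested transformation (illustration only, not code from this PR):

# Each element of `outputs` is the tuple of outputs from one iteration.
outputs = [("a1", "b1", "c1"), ("a2", "b2", "c2"), ("a3", "b3", "c3")]

# zip(*outputs) transposes the per-iteration tuples into one sequence per
# output slot: [("a1", "a2", "a3"), ("b1", "b2", "b3"), ("c1", "c2", "c3")].
per_output = list(zip(*outputs))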

})
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"fn", "data1", "data2"};
Contributor

needs to be variable length

* \param outs The input symbols of the graph.
* \param out_size the number of input symbols returned.
*/
MXNET_DLL int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **outs,
Contributor

We already have a ListInput API, right?
Should it be **inputs?

Contributor Author

Here I need to get a list of input symbols instead of names. Do you suggest merging these two APIs?

@@ -134,15 +138,16 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
return state_.get_var();
}

explicit StatefulComputeExecutor(const OpStatePtr& state,
explicit StatefulComputeExecutor(const NodeAttrs& attrs, const OpStatePtr& state,
Contributor

line break

@@ -379,7 +379,8 @@ inline void PushFCompute(const FCompute& fn,
&input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst,
&post_temp_src, &post_temp_dst, &in_temp_idx_map, mutate_idx);
// setup context
OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
bool need_grad = Imperative::Get()->is_recording();
Contributor

need_grad shouldn't be fetched from the worker thread. It should be set outside, similar to is_train.

if (exec_type == ExecType::kSync) {
// For operators with subgraphs, we need to invoke them in the main thread
// instead of the threaded engine.
if (!attrs.subgraphs.empty()) {
Contributor

You wouldn't imperatively call an op with subgraphs right?

Contributor Author

I think it can happen. For example, if we hybridize a block with control flow operators, the execution of these operators will happen here.
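
A hypothetical sketch of that scenario (it assumes the contrib foreach proposed in this PR is callable as F.contrib.foreach from hybrid_forward; the block and its logic are made up for illustration):

from mxnet.gluon import HybridBlock

class CumSumBlock(HybridBlock):
    """Made-up example: cumulative sum over dimension 0 via foreach."""
    def hybrid_forward(self, F, data, init_state):
        def step(x, states):
            s = x + states[0]
            return s, [s]
        # After block.hybridize(), this builds a graph containing the foreach
        # operator; executing the hybridized block then invokes that operator
        # (and its subgraph) through the imperative path discussed above.
        outs, _ = F.contrib.foreach(step, data, [init_state])
        return outs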

void Forward(std::vector<NDArray> cinputs,
const std::vector<OpReqType>& req,
std::vector<NDArray> coutputs, bool is_recording);
void Backward(int iter_no, std::vector<NDArray> ograds,
Contributor

line break between args

std::unordered_map<std::string, std::vector<NDArray> > params;
CachedOpPtr op = std::make_shared<Imperative::CachedOp>(subgraph_sym, kwargs,
arg_names, params);
// TODO(zhengda) we need to avoid shape inference and memory plan whenever the op is
Contributor

why not allocate memory?

Contributor Author

Allocating memory, in general, is expensive. If we can avoid shape inference and memory allocation for each iteration, we should.

return syms

def foreach(func, data, init_states, name="foreach"):
"""Run a for loop with user-defined computation over NDArrays on dimension 0.
Member

NDArrays -> Symbols?


Parameters
----------
func : a Python function.
Member

Generic python function as an argument seems too broad. Since the interface for func is well defined, do we want to restrict it to a well-defined python class? For example,

class ForeachBody(object):
    def forward(self, data, states):
        raise NotImplementedError()

    def __call__(self, data, states):
        """
        data: NDArray or list of NDArrays
        states: NDArray or list of NDArrays
        ...
        """
        check_input(data, states)
        return self.forward(data, states)

def foreach(body, data, state):
    """
    Parameters
    ----------
    body : a ForeachBody.
    """

Then you don't have to do check_input inside contrib.foreach

Contributor Author

This is to follow the interface of TensorFlow: https://www.tensorflow.org/api_docs/python/tf/while_loop
Using a class does make the API more well defined, but it requires users to write more code to define it. I don't know what the best way is.
@piiswrong what's your opinion?

from ..base import _LIB, c_array, check_call
from ..base import SymbolHandle, _as_list
from ..attribute import AttrScope

__all__ = ["rand_zipfian"]
Member

Add foreach to __all__ ?

Contributor Author

what is this for?

Member

Otherwise from contrib import * won't include the foreach function
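
i.e., roughly (assuming foreach is the only new public name added to this module):

__all__ = ["rand_zipfian", "foreach"]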

# the python function, we need to prune the computation graph constructed from
# the function. One way of doing it is to mark the nodes in the computation graph
# with AttrScope and prune the nodes without the special attribute.
with AttrScope(subgraph_name=name):
Member

What's the alternative?

Contributor Author

alternative of what?

@@ -1537,6 +1555,9 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start,
OpNode& op_node = op_nodes_[nid];
if (op_node.skip_exec_node) continue;
if (inode.source->is_variable()) continue;
// We shouldn't add control flow operators to a segment.
// We can't execute these operators in the engine.
if (op_node.exec->HasSubgraph()) continue;
Member

Why not return ret instead of continue?

Contributor Author

using return ret means breaking the graph into two pieces?

#if MXNET_USE_MKLDNN == 1
InvalidateOutputs(outputs, req);
#endif
fcompute_ex(state, opctx, inputs, req, outputs);
if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) {
Member

Good catch. I think we need to check
bool is_gpu = rctx.get_ctx().dev_mask() == gpu::kDevMask;

Contributor Author

why?

@@ -505,12 +515,16 @@ inline void PushOperator(const OpStatePtr& state,
fcompute(state, opctx, input_blobs, tmp_req, output_blobs);
// post-fcompute fallback, cast to original storage type, if necessary
CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
if (is_gpu && exec_type == ExecType::kSync) {
if (is_gpu && exec_type == ExecType::kSync
Member

Why is && rctx.get_stream<gpu>() required?

Contributor Author

@zheng-da zheng-da May 21, 2018

Because subgraph operators don't run in the threaded engine, they don't have a GPU stream.

if (len % 2 == 1) {
for (size_t i = 1; i < subg_outputs1.size(); i++) {
subg_outputs1[i] = outputs[i];
subg_outputs2[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true,
Member

This assumes all NDArrays are dense?

Contributor Author

indeed.

ograds[i] = inputs[i];
std::vector<OpReqType> iter_req(req.size());
for (auto r : req)
CHECK_NE(r, kWriteInplace);
Member

Is req guaranteed not to be equal to kWriteInplace?

* under the License.
*/

#ifndef MXNET_OPERATOR_NN_SUBGRAPH_OP_COMMON_H_
Member

Why is subgraph op inside nn/ folder? Isn't it more general?

Contributor Author

You are right. I probably should move the control flow ops out as well.

@zheng-da zheng-da mentioned this pull request May 25, 2018
7 tasks

struct ForeachParam : public dmlc::Parameter<ForeachParam> {
int num_args;
int dim;
Member

Is the field "int dim" used anywhere?

Contributor Author

It's not. I can remove it for now.

const std::vector<NDArray>& outputs) {
ForeachState &state = state_ptr.get_state<ForeachState>();
const ForeachParam& params = state.params;
size_t iter_dim = 0;
Member

would you mind adding a "constexpr" specifier?

DMLC_DECLARE_PARAMETER(ForeachParam) {
DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
.describe("Number of inputs.");
DMLC_DECLARE_FIELD(dim).set_default(1)
Member

Should this default to 1 or 0?

std::vector<std::pair<std::string, std::string> > kwargs;
kwargs.push_back(std::pair<std::string, std::string>("inline_limit", "0"));
// Get input names.
const auto& idx = subgraph.indexed_graph();
Member

Seems the "idx" here is unused

is_NDArray_or_list = isinstance(inputs, in_type)
assert is_NDArray_or_list, msg

check_data(data, symbol.Symbol, "data should be an NDArray or a list of NDArrays")
Member

Should use symbol instead of ndarray in the error message
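
i.e., presumably something along the lines of (sketch of the suggested fix):

check_data(data, symbol.Symbol,
           "data should be a Symbol or a list of Symbols")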

"init_states should be an NDArray or a list of NDArrays")

not_data_list = isinstance(data, ndarray.NDArray)
not_state_list = isinstance(init_states, ndarray.NDArray)
Member

This boolean variable is not used, you probably should check it during the loop
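
One possible way to use the flags in the loop (a rough sketch of the idea only, not this PR's code): unwrap single-NDArray inputs so the body sees arguments in the same form the caller passed, and re-wrap the returned state.

# Sketch: `data` / `init_states` may each be a single NDArray or a list.
states = [init_states] if not_state_list else init_states
length = data.shape[0] if not_data_list else data[0].shape[0]
for i in range(length):
    ele = data[i] if not_data_list else [arr[i] for arr in data]
    cur_states = states[0] if not_state_list else states
    outs, states = body(ele, cur_states)
    if not_state_list:
        states = [states]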

@tqchen
Member

tqchen commented Jun 14, 2018

One suggestion: since this op is commonly known as scan, why not just use the common name instead of inventing a new API name?

@zheng-da
Contributor Author

@tqchen Originally, I considered this a control flow operator, so I used foreach, because a lot of languages use foreach as a keyword for this. But DL frameworks consider this operator a higher-order function and call it scan. It makes sense to follow the convention of the DL frameworks, although the definition of this operator is a little different from the one in TensorFlow (foreach here splits the outputs of an iteration into outputs of the loop and loop variables).


@@ -91,3 +98,154 @@ def rand_zipfian(true_classes, num_sampled, range_max):
expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range
expected_count_sampled = expected_prob_sampled * num_sampled
return sampled_classes, expected_count_true, expected_count_sampled

def _get_graph_inputs(subg):
num_handles = ctypes.c_int(1000)
Member

fix?

* This contains the states for running a loop and provides methods
* of running the subgraph computation for an iteration.
*/
class LoopState {
Member

Is this state going to be shared by while loop or just foreach?

Contributor Author

It's supposed to be shared by the while loop as well.

}

void LoopState::Backward(int iter_no,
std::vector<NDArray> ograds,
Member

const reference?

Contributor Author

The reason it's not a const reference is that we need a copy of the NDArray vector anyway: we need to pass their pointers to the cached op, which requires pointers instead of const pointers.

// TODO(zhengda) we need to avoid shape inference and memory plan whenever the op is
// called. Currently, CachedOp allocates memory each time Forward is called.
// I need to fix this once the PR for static memory allocation in CachedOp is
// merged. https://github.com/apache/incubator-mxnet/pull/10817
Member

The PR for static memory allocation was merged.

const auto& idx = g.indexed_graph();
CHECK_EQ(idx.input_nodes().size(), in_type->size());
CHECK_EQ(idx.outputs().size(), out_type->size());
imperative::CheckAndInferType(&g, std::move(dtype_inputs), true);
Member

What if some out_type is known and we want to perform mutual inference based on outputs for elemwise ops?

CHECK_EQ(len, outputs[i].shape()[iter_dim]);
for (const auto &arr : outputs)
CHECK_EQ(arr.storage_type(), kDefaultStorage)
<< "The for operator doesn't support the sparse format";
Member

Curious: is there anything special to handle for sparse NDArrays in foreach?

Contributor Author

Because I create NDArrays below. To handle sparse arrays, I would need to create sparse arrays explicitly. I'm not sure whether foreach needs to handle sparse arrays in general, so this version just handles dense arrays first.

ograds[i] = inputs[i];
std::vector<OpReqType> iter_req(req.size());
for (auto r : req)
CHECK_NE(r, kWriteInplace);
Member

Is this guaranteed? If memory planning introduced WriteInplace in the backward, how would a user work around that?

Contributor Author

To enable WriteInplace, an operator needs to enable FInplaceOption, right? foreach doesn't have the attribute for either forward or backward.

rctx.get_stream<gpu>()->Wait();
}
};

if (exec_type == ExecType::kSync) {
if (!attrs.subgraphs.empty()) {
Member

what about setting exec_type for Foreach op to kSubgraph instead of checking attrs.subgraphs?

@@ -96,3 +98,97 @@ def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
expected_count_sampled = expected_prob_sampled * num_sampled
return sampled_classes, expected_count_true, expected_count_sampled
# pylint: enable=line-too-long

def foreach(body, data, init_states):
Member

You probably want to also update ndarray/contrib.md and symbol/contrib.md

@zheng-da zheng-da mentioned this pull request Jul 2, 2018
5 tasks
@zheng-da zheng-da closed this Jul 2, 2018