Conversation
Very interesting! Out of curiosity, is this operator going to be entirely parallelized, since it basically splits the graph into multiple subgraphs, or what is the approach here?

We potentially can parallelize across iterations, but most likely there are dependencies between iterations, so parallelizing across iterations may not be very effective.
e75405a to 1fa88a7 (compare)
Very interesting work! I have some comments and questions. Thanks.
python/mxnet/symbol/contrib.py

    "the number of output states (%d) should be the same as input states (%d)" \
        % (len(sym_out[1]), len(init_states))

    if (isinstance(sym_out[0], list)):

No parentheses needed. They would result in a coding style error in PyCharm.
python/mxnet/symbol/contrib.py

    @@ -91,3 +98,99 @@ def rand_zipfian(true_classes, num_sampled, range_max):
        expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range
        expected_count_sampled = expected_prob_sampled * num_sampled
        return sampled_classes, expected_count_true, expected_count_sampled

    def _get_graph_inputs(subg, name, prefix):

Where is `prefix` used?

The prefix is used for pruning. That part hasn't been implemented yet; I'll probably implement it in the next PR, since this PR is already very large.
python/mxnet/symbol/contrib.py

    syms.append(s)
    return syms

    def foreach(func, input, init_states, back_prop=False, name="foreach"):

`input` is a built-in in Python. Does it make sense to call it `data`, which can be both singular and plural?
python/mxnet/symbol/contrib.py

    for in_name in g.list_inputs():
        assert in_name in gin_names, "The input variable %s can't be found in graph inputs: %s" \
            % (in_name, str(gin_names))
        if (in_name in state_names):

No parentheses.
python/mxnet/symbol/contrib.py

    if (in_name in state_names):
        ordered_ins.append(states_map[in_name])
        in_state_locs.append(len(ordered_ins) - 1)
    elif (in_name in data_names):

Same here. No parentheses.
src/operator/nn/control_flow.cc

    .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", ForeachComputeExCPU)
    .set_attr<std::string>("key_var_num_args", "num_args")
    .add_argument("fn", "Symbol", "Input graph.")
    .add_argument("inputs", "NDArray-or-Symbol[]",

It's called `data` by default. Does it make sense to keep the naming aligned?
src/operator/nn/control_flow.cc

    this->params = params;
    }

    void Forward(std::vector<NDArray> cinputs,

Why copy `std::vector<NDArray>`?

I'm not sure about this part. I'll figure it out tomorrow.
src/operator/nn/control_flow.cc

    })
    .set_attr<nnvm::FListInputNames>("FListInputNames",
      [](const NodeAttrs& attrs) {
        return std::vector<std::string>{"fn", "data1", "data2"};

Should it be generated from `param.num_args`?
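(A rough Python paraphrase of the reviewer's point; the real attribute is a C++ lambda over `attrs`, and whether `num_args` counts the subgraph symbol is an assumption here.)

    # Sketch: derive the input-name list from num_args instead of hard-coding it.
    def foreach_input_names(num_args):
        return ["fn"] + ["data%d" % i for i in range(1, num_args)]

    print(foreach_input_names(4))  # ['fn', 'data1', 'data2', 'data3']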
src/operator/nn/control_flow.cc

    // in, state0, state1, ...
    // We need to reorder them in the same order as the input nodes of the subgraph.
    template<typename T>
    static std::vector<T> ReorderInputs(const std::vector<T> &in, const nnvm::IndexedGraph& idx) {

Where is this used?
src/operator/nn/control_flow.cc

    }
    };

    void ForeachState::Forward(std::vector<NDArray> cinputs,

Why copy `std::vector<NDArray>`?
src/operator/nn/control_flow.cc

    shape_inputs[loc] = TShape(in_shape->at(loc).begin() + 1, in_shape->at(loc).end());
    }
    CHECK_EQ(attrs.subgraphs.size(), 1U);
    auto g = std::make_shared<nnvm::Graph>();

Seems `shared_ptr<Graph>` is not necessary here. `Graph g` would suffice, right?
src/operator/nn/control_flow.cc

    const auto& idx = g->indexed_graph();
    CHECK_EQ(idx.input_nodes().size(), in_shape->size());
    CHECK_EQ(idx.outputs().size(), out_shape->size());
    imperative::CheckAndInferShape(g.get(), std::move(shape_inputs), true);

This may return `false`; the return value should be saved and serve as the return value of `ForeachShape`.
src/operator/nn/control_flow.cc

    auto eid = idx.entry_id(input_nids[i], 0);
    // If the input shape is none, we should update them.
    if ((*in_shape)[i].ndim() == 0 || (*in_shape)[i].Size() == 0)
      (*in_shape)[i] = shapes[eid];

To be more correct, use `SHAPE_ASSIGN_CHECK`. Same for the other places that assign shapes to `in_shape` and `out_shape`.
src/operator/nn/control_flow.cc

    const auto& idx = g->indexed_graph();
    CHECK_EQ(idx.input_nodes().size(), in_type->size());
    CHECK_EQ(idx.outputs().size(), out_type->size());
    imperative::CheckAndInferType(g.get(), std::move(dtype_inputs), true);

Save the return value and return it at the end.
src/operator/nn/control_flow.cc

    CHECK_EQ(input_nids.size(), in_type->size());
    for (size_t i = 0; i < in_type->size(); i++) {
      auto eid = idx.entry_id(input_nids[i], 0);
      (*in_type)[i] = dtypes[eid];

Use `TYPE_ASSIGN_CHECK`.
src/operator/nn/control_flow.cc

    CHECK_EQ(idx.outputs().size(), out_attrs->size());
    exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
    StorageTypeVector storage_type_inputs = *in_attrs;
    imperative::CheckAndInferStorageType(g.get(), std::move(dev_masks),

Save the return value and return it at the end.
src/operator/nn/control_flow.cc

    CHECK_EQ(input_nids.size(), in_attrs->size());
    for (size_t i = 0; i < in_attrs->size(); i++) {
      auto eid = idx.entry_id(input_nids[i], 0);
      (*in_attrs)[i] = stypes[eid];

Use `STORAGE_TYPE_ASSIGN_CHECK`.
@piiswrong @eric-haibin-lin @reminisce @tqchen Could you please review this PR?
include/mxnet/imperative.h

    const nnvm::NodeAttrs& attrs,
    const std::vector<NDArray*>& inputs,
    const std::vector<NDArray*>& outputs);
    static OpStatePtr Invoke(const Context& default_ctx,

Use `Imperative::Get()`.
python/mxnet/ndarray/contrib.py

    @@ -96,3 +96,18 @@ def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
        expected_count_sampled = expected_prob_sampled * num_sampled
        return sampled_classes, expected_count_true, expected_count_sampled
    # pylint: enable=line-too-long

    def foreach(func, input, init_states, back_prop=False, name="foreach"):

`back_prop` -> `Imperative::Get()->is_recording()`; add `OpContext::is_record` in the backend.
python/mxnet/ndarray/contrib.py

    ele = input[i]
    outs, states = func(ele, states)
    outs = _as_list(outs)
    if (i == 0):

Suggestion:

    outputs.append(outs)
    ...
    outputs = zip(*outputs)

i.e. [(a, b, c), (a2, b2, c2), ...] -> [(a, a2, ...), (b, b2, ...), (c, c2, ...)].
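(A minimal standalone illustration of the suggested pattern, using plain Python values rather than NDArrays.)

    # Collect one tuple of outputs per iteration, then transpose with zip so that
    # each logical output gets its own sequence across iterations.
    per_iteration = [("a1", "b1"), ("a2", "b2"), ("a3", "b3")]
    per_output = list(zip(*per_iteration))
    print(per_output)  # [('a1', 'a2', 'a3'), ('b1', 'b2', 'b3')]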
src/operator/nn/control_flow.cc

    })
    .set_attr<nnvm::FListInputNames>("FListInputNames",
      [](const NodeAttrs& attrs) {
        return std::vector<std::string>{"fn", "data1", "data2"};

Needs to be variable length.
include/mxnet/c_api.h

    * \param outs The input symbols of the graph.
    * \param out_size the number of input symbols returned.
    */
    MXNET_DLL int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **outs,

We already have a ListInput API, right? Should it be `**inputs`?

Here I need to get a list of input symbols instead of names. Do you suggest merging these two APIs?
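(For context, a small sketch of the difference; it assumes the existing `Symbol.list_inputs()` method, which returns only names, whereas the new C API is meant to return the input symbol handles themselves.)

    import mxnet as mx

    a = mx.sym.var("a")
    b = mx.sym.var("b")
    c = a + b
    # Names are easy to get from Python, but rebuilding the subgraph's inputs
    # needs the actual symbol handles, hence the new C API call.
    print(c.list_inputs())  # ['a', 'b']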
src/executor/attach_op_execs_pass.cc

    @@ -134,15 +138,16 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
        return state_.get_var();
      }

      explicit StatefulComputeExecutor(const OpStatePtr& state,
      explicit StatefulComputeExecutor(const NodeAttrs& attrs, const OpStatePtr& state,

Line break.
src/imperative/imperative_utils.h

    @@ -379,7 +379,8 @@ inline void PushFCompute(const FCompute& fn,
        &input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst,
        &post_temp_src, &post_temp_dst, &in_temp_idx_map, mutate_idx);
      // setup context
      OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
      bool need_grad = Imperative::Get()->is_recording();

`need_grad` shouldn't be read from the worker thread. It should be set outside, similar to `is_train`.
src/imperative/imperative_utils.h

    if (exec_type == ExecType::kSync) {
      // For operators with subgraphs, we need to invoke them in the main thread
      // instead of the threaded engine.
      if (!attrs.subgraphs.empty()) {

You wouldn't imperatively call an op with subgraphs, right?

I think it can happen. For example, if we hybridize a block with control flow operators, the execution of these operators will happen here.
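(A hedged sketch of that scenario, assuming the `contrib.foreach` interface added in this PR; details may differ. Hybridizing a Gluon block whose `hybrid_forward` uses the loop operator means the subgraph op gets invoked through exactly this code path.)

    import mxnet as mx
    from mxnet.gluon import HybridBlock

    class CumSum(HybridBlock):
        def hybrid_forward(self, F, data):
            def step(x, states):
                s = states[0] + x        # running sum carried in the loop state
                return s, [s]
            outs, _ = F.contrib.foreach(step, data, [F.zeros((1,))])
            return outs

    block = CumSum()
    block.hybridize()                    # the cached graph now contains the foreach op
    print(block(mx.nd.arange(5).reshape((5, 1))))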
src/operator/nn/control_flow.cc

    void Forward(std::vector<NDArray> cinputs,
                 const std::vector<OpReqType>& req,
                 std::vector<NDArray> coutputs, bool is_recording);
    void Backward(int iter_no, std::vector<NDArray> ograds,

Line break between args.
src/operator/nn/control_flow.cc

    std::unordered_map<std::string, std::vector<NDArray> > params;
    CachedOpPtr op = std::make_shared<Imperative::CachedOp>(subgraph_sym, kwargs,
                                                            arg_names, params);
    // TODO(zhengda) we need to avoid shape inference and memory plan whenever the op is

Why not allocate memory?

Allocating memory is, in general, expensive. If we can avoid shape inference and memory allocation for each iteration, we should.
python/mxnet/symbol/contrib.py

    return syms

    def foreach(func, data, init_states, name="foreach"):
        """Run a for loop with user-defined computation over NDArrays on dimension 0.

NDArrays -> Symbols?
python/mxnet/ndarray/contrib.py

    Parameters
    ----------
    func : a Python function.

A generic Python function as an argument seems too broad. Since the interface for func is well defined, do we want to restrict it to a well-defined Python class? For example:

    class ForeachBody(object):
        def forward(self, data, states):
            raise NotImplementedError()

        def __call__(self, data, states):
            """
            data: NDArray or list of NDArrays
            states: NDArray or list of NDArrays
            """
            check_input(data, states)
            return self.forward(data, states)

    def foreach(body, data, states):
        ...

with the docstring updated to:

    Parameters
    ----------
    func : a ForeachBody.

Then you don't have to do check_input inside contrib.foreach.

This is to follow the interface of TensorFlow: https://www.tensorflow.org/api_docs/python/tf/while_loop. Using a class does make the API more well defined, but it requires users to write more code to define it. I don't know what the best way is. @piiswrong, what's your opinion?
python/mxnet/symbol/contrib.py

    from ..base import _LIB, c_array, check_call
    from ..base import SymbolHandle, _as_list
    from ..attribute import AttrScope

    __all__ = ["rand_zipfian"]

Add foreach to `__all__`?

What is this for?

Otherwise `from contrib import *` won't include the foreach function.
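(Concretely, the request is just to extend the module's export list, in both python/mxnet/symbol/contrib.py and python/mxnet/ndarray/contrib.py, e.g.:)

    __all__ = ["rand_zipfian", "foreach"]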
python/mxnet/symbol/contrib.py

    # the python function, we need to prune the computation graph constructed from
    # the function. One way of doing it is to mark the nodes in the computation graph
    # with AttrScope and prune the nodes without the special attribute.
    with AttrScope(subgraph_name=name):

What's the alternative?

Alternative of what?
src/executor/graph_executor.cc

    @@ -1537,6 +1555,9 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start,
      OpNode& op_node = op_nodes_[nid];
      if (op_node.skip_exec_node) continue;
      if (inode.source->is_variable()) continue;
      // We shouldn't add control flow operators to a segment.
      // We can't execute these operators in the engine.
      if (op_node.exec->HasSubgraph()) continue;

Why not `return ret` instead of `continue`?

Using `return ret` means breaking the graph into two pieces?
    #if MXNET_USE_MKLDNN == 1
    InvalidateOutputs(outputs, req);
    #endif
    fcompute_ex(state, opctx, inputs, req, outputs);
    if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) {

Good catch. I think we need to check

    bool is_gpu = rctx.get_ctx().dev_mask() == gpu::kDevMask;

Why?
    @@ -505,12 +515,16 @@ inline void PushOperator(const OpStatePtr& state,
      fcompute(state, opctx, input_blobs, tmp_req, output_blobs);
      // post-fcompute fallback, cast to original storage type, if necessary
      CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
      if (is_gpu && exec_type == ExecType::kSync) {
      if (is_gpu && exec_type == ExecType::kSync

Why is `&& rctx.get_stream<gpu>()` required?

Because subgraph operators don't run in the threaded engine and don't have a GPU stream.
src/operator/nn/control_flow.cc

    if (len % 2 == 1) {
      for (size_t i = 1; i < subg_outputs1.size(); i++) {
        subg_outputs1[i] = outputs[i];
        subg_outputs2[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true,

This assumes all NDArrays are dense?

Indeed.
src/operator/nn/control_flow.cc

    ograds[i] = inputs[i];
    std::vector<OpReqType> iter_req(req.size());
    for (auto r : req)
      CHECK_NE(r, kWriteInplace);

Is req guaranteed not to be equal to `kWriteInplace`?
src/operator/nn/subgraph_op_common.h

    * under the License.
    */

    #ifndef MXNET_OPERATOR_NN_SUBGRAPH_OP_COMMON_H_

Why is the subgraph op inside the nn/ folder? Isn't it more general?

You are right. I probably should move the control flow op out as well.
src/operator/control_flow.cc

    struct ForeachParam : public dmlc::Parameter<ForeachParam> {
      int num_args;
      int dim;

Is the field `int dim` used anywhere?

It's not. I can remove it for now.
src/operator/control_flow.cc

    const std::vector<NDArray>& outputs) {
      ForeachState &state = state_ptr.get_state<ForeachState>();
      const ForeachParam& params = state.params;
      size_t iter_dim = 0;

Would you mind adding a `constexpr` specifier?
src/operator/control_flow.cc

    DMLC_DECLARE_PARAMETER(ForeachParam) {
      DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
      .describe("Number of inputs.");
      DMLC_DECLARE_FIELD(dim).set_default(1)

Should this default to 1 or 0?
src/operator/subgraph_op_common.cc

    std::vector<std::pair<std::string, std::string> > kwargs;
    kwargs.push_back(std::pair<std::string, std::string>("inline_limit", "0"));
    // Get input names.
    const auto& idx = subgraph.indexed_graph();

Seems the `idx` here is unused.
python/mxnet/symbol/contrib.py

    is_NDArray_or_list = isinstance(inputs, in_type)
    assert is_NDArray_or_list, msg

    check_data(data, symbol.Symbol, "data should be an NDArray or a list of NDArrays")

Should use symbol instead of NDArray in the error message.
python/mxnet/ndarray/contrib.py

    "init_states should be an NDArray or a list of NDArrays")

    not_data_list = isinstance(data, ndarray.NDArray)
    not_state_list = isinstance(init_states, ndarray.NDArray)

This boolean variable is not used; you probably should check it during the loop.
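(One way the flags could be used, as an illustrative sketch with a hypothetical helper, not the PR's code: normalize single arrays to lists up front and remember it so the results can be unwrapped at the end.)

    from mxnet import ndarray

    def _to_list(arrays):
        """Return (list_of_arrays, was_single) so the caller can unwrap results later."""
        if isinstance(arrays, ndarray.NDArray):
            return [arrays], True
        return list(arrays), False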
One suggestion: since this op is commonly known as scan, why not just use the common name instead of inventing a new API name?

@tqchen Originally I considered this a control flow operator, so I used foreach, because a lot of languages use ...
python/mxnet/symbol/contrib.py

    @@ -91,3 +98,154 @@ def rand_zipfian(true_classes, num_sampled, range_max):
        expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range
        expected_count_sampled = expected_prob_sampled * num_sampled
        return sampled_classes, expected_count_true, expected_count_sampled

    def _get_graph_inputs(subg):
        num_handles = ctypes.c_int(1000)

Fix?
    * This contains the states for running a loop and provides methods
    * of running the subgraph computation for an iteration.
    */
    class LoopState {

Is this state going to be shared by while loop or just foreach?

It's supposed to be shared by while loop.
src/operator/subgraph_op_common.cc

    }

    void LoopState::Backward(int iter_no,
                             std::vector<NDArray> ograds,

Const reference?

The reason it's not a const reference is that we need a copy of the NDArray vector anyway: we need to pass their pointers to the cached op, which requires pointers instead of const pointers.
src/operator/subgraph_op_common.cc

    // TODO(zhengda) we need to avoid shape inference and memory plan whenever the op is
    // called. Currently, CachedOp allocates memory each time Forward is called.
    // I need to fix this once the PR for static memory allocation in CachedOp is
    // merged. https://github.com/apache/incubator-mxnet/pull/10817

The PR for static memory allocation was merged.
src/operator/subgraph_op_common.cc

    const auto& idx = g.indexed_graph();
    CHECK_EQ(idx.input_nodes().size(), in_type->size());
    CHECK_EQ(idx.outputs().size(), out_type->size());
    imperative::CheckAndInferType(&g, std::move(dtype_inputs), true);

What if some `out_type` is known and we want to perform mutual inference based on outputs for elemwise ops?
    CHECK_EQ(len, outputs[i].shape()[iter_dim]);
    for (const auto &arr : outputs)
      CHECK_EQ(arr.storage_type(), kDefaultStorage)
        << "The for operator doesn't support the sparse format";

Curious: is there anything special to handle for sparse NDArrays in foreach?

Because I create NDArrays below. To handle sparse arrays, I would need to create sparse arrays explicitly. I'm not sure if foreach needs to handle sparse arrays in general, so this version will just handle dense arrays first.
src/operator/control_flow.cc

    ograds[i] = inputs[i];
    std::vector<OpReqType> iter_req(req.size());
    for (auto r : req)
      CHECK_NE(r, kWriteInplace);

Is this guaranteed? If memory planning introduced WriteInplace in the backward, how does a user work around that?

To enable WriteInplace, an operator needs to enable `FInplaceOption`, right? foreach doesn't have that attribute for either forward or backward.
src/imperative/imperative_utils.h

    rctx.get_stream<gpu>()->Wait();
    }
    };

    if (exec_type == ExecType::kSync) {
      if (!attrs.subgraphs.empty()) {

What about setting `exec_type` for the foreach op to `kSubgraph` instead of checking `attrs.subgraphs`?
    @@ -96,3 +98,97 @@ def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
        expected_count_sampled = expected_prob_sampled * num_sampled
        return sampled_classes, expected_count_true, expected_count_sampled
    # pylint: enable=line-too-long

    def foreach(body, data, init_states):

You probably want to also update ndarray/contrib.md and symbol/contrib.md.
When two node entries refer to the same output of a node, we should create only one var node for these two node entries.
We can't get a list of variable names from a hashtable: the order isn't guaranteed, and Python 2 and Python 3 produce different orders.
Description
This PR adds a control flow operator: foreach. It takes a Python function as input and runs the function over the elements of the input array along dimension 0. foreach is similar to scan in TensorFlow.
This PR is part of the proposal of adding a set of control flow operators to MXNet.
https://cwiki.apache.org/confluence/display/MXNET/Optimize+dynamic+neural+network+models+with+control+flow+operators
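A hedged usage sketch of the NDArray version (the signature roughly follows the interface discussed above and could still change during review):

    import mxnet as mx

    def step(x, states):
        # x is one slice of `data` along dimension 0; states carries the running sum
        s = states[0] + x
        return s, [s]

    data = mx.nd.arange(6).reshape((3, 2))
    outs, final_states = mx.nd.contrib.foreach(step, data, [mx.nd.zeros((2,))])
    print(outs)          # per-iteration outputs stacked along dimension 0 -> shape (3, 2)
    print(final_states)  # [running sum after the last iteration]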
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.