Symbol binding for sparse tensor development. #31

Merged
41 commits
1e0d38b
Initial checkin
reminisce Apr 5, 2017
34611f3
Add init functions for simple bind in graph_executor
reminisce Apr 6, 2017
45abe93
Add simple_bind c_api
reminisce Apr 6, 2017
f292157
Add simple bind c-api
reminisce Apr 8, 2017
965e839
Assign zeros to in_args, arg_grads, and aux_states
reminisce Apr 8, 2017
fe933a4
Add simple_bind2 python interface
reminisce Apr 8, 2017
1c1b054
Fix python interface bugs
reminisce Apr 8, 2017
bdeee94
Interface changes
reminisce Apr 10, 2017
6e8c352
Fix
reminisce Apr 11, 2017
08428eb
Fix core dump
reminisce Apr 11, 2017
16079d8
Add bind_ith_exec c_api
reminisce Apr 14, 2017
2c37443
Change simple_bind2
reminisce Apr 15, 2017
bf04f60
Fix seg fault
reminisce Apr 15, 2017
e7f43f2
Finish simple_bind
reminisce Apr 15, 2017
c732ceb
Change _bind_ith_exec
reminisce Apr 15, 2017
4114306
Refactor simple_bind initialization flow for bind
reminisce Apr 16, 2017
ba2c78d
Consolidate bind and simple_bind graph init flow
reminisce Apr 16, 2017
30ebf0b
Fix bug
reminisce Apr 16, 2017
6fcc965
Clean up
reminisce Apr 16, 2017
8f2b2fb
Add comments
reminisce Apr 16, 2017
4487e3b
Clean up
reminisce Apr 16, 2017
394f392
Clean up
reminisce Apr 16, 2017
db7eefd
Minor correction
reminisce Apr 17, 2017
d312d84
Rename APIs in graph executor
reminisce May 3, 2017
b663a99
Refactor
reminisce May 4, 2017
8f5684c
Rebase
reminisce May 4, 2017
b6375ad
Delete deprecated functions
reminisce May 5, 2017
b6d281e
Move more front-end work to backend
reminisce May 8, 2017
2d7fa8a
Bug fix
reminisce May 10, 2017
9b10ad5
Fix failed tests
reminisce May 11, 2017
5ef1401
Minor fix
reminisce May 12, 2017
591523b
Fix lint
reminisce May 13, 2017
db05235
Fix lint
reminisce May 13, 2017
0dd1d94
Revert unnecessary changes
reminisce May 13, 2017
53b99c9
Revert
reminisce May 13, 2017
4743a0d
Revert
reminisce May 13, 2017
b9854e4
Clean up
reminisce May 14, 2017
60f060c
Fix lint
reminisce May 14, 2017
a8d7a46
Merge branch 'master' into improve_symbol_bind
reminisce May 14, 2017
432156c
Merge branch 'master' into improve_symbol_bind
reminisce May 15, 2017
2b2a44c
Merge branch 'master' into improve_symbol_bind
reminisce May 15, 2017
30 changes: 30 additions & 0 deletions include/mxnet/c_api.h
@@ -1078,6 +1078,36 @@ MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle,
NDArrayHandle *aux_states,
ExecutorHandle shared_exec,
ExecutorHandle *out);

MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
const mx_uint num_g2c_keys,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const mx_uint provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
mx_uint* shared_buffer_len,
const char*** shared_buffer_name_list,
NDArrayHandle** shared_buffer_handle_list,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
mx_uint* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);
/*!
* \brief set a call back to notify the completion of operation
*/
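The `provided_arg_shape_names` / `provided_arg_shape_data` / `provided_arg_shape_idx` triple in the new C API suggests that user-provided shapes are passed as one flattened dimension array plus an offset array. A minimal Python sketch of that packing, assuming shape k occupies `provided_arg_shape_data[idx[k]:idx[k+1]]` (the shape dictionary below is illustrative, not taken from the PR):

```python
# Hypothetical packing of user-provided shapes for MXExecutorSimpleBind,
# assuming a flattened-data + offsets layout.
provided_shapes = {'data': (32, 3, 224, 224), 'softmax_label': (32,)}

provided_arg_shape_names = []   # one name per provided shape
provided_arg_shape_data = []    # all dimensions concatenated
provided_arg_shape_idx = [0]    # offsets into provided_arg_shape_data
for name, shape in provided_shapes.items():
    provided_arg_shape_names.append(name)
    provided_arg_shape_data.extend(shape)
    provided_arg_shape_idx.append(len(provided_arg_shape_data))

# num_provided_arg_shapes would then be len(provided_arg_shape_names).
```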
32 changes: 32 additions & 0 deletions include/mxnet/executor.h
@@ -69,6 +69,21 @@ class Executor {
* \return array of outputs in the executor.
*/
virtual const std::vector<NDArray> &outputs() const = 0;
/*!
* \brief get input argument map, key is arg name, value is arg's NDArray.
* \return input argument map in the executor.
*/
virtual const std::unordered_map<std::string, NDArray>& in_arg_map() const = 0;
/*!
* \brief get input argument gradient map, key is arg name, value is gradient's NDArray.
* \return input argument gradient map in the executor.
*/
virtual const std::unordered_map<std::string, NDArray>& arg_grad_map() const = 0;
/*!
* \brief get aux state map, key is aux state name, value is aux state's NDArray.
* \return aux state map in the executor.
*/
virtual const std::unordered_map<std::string, NDArray>& aux_state_map() const = 0;
/*!
* \brief Create an operator by bind symbol with context and arguments.
* If the user does not want to compute the gradient of the i-th argument, grad_req_type[i] can be kNullOp.
Expand All @@ -91,6 +106,23 @@ class Executor {
const std::vector<OpReqType> &grad_req_type,
const std::vector<NDArray> &aux_states,
Executor* shared_exec = NULL);

static Executor* SimpleBind(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& group2ctx,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::unordered_map<std::string, TShape>& arg_shape_map,
const std::unordered_map<std::string, int>& arg_dtype_map,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& param_names,
std::vector<NDArray>* in_args,
std::vector<NDArray>* arg_grads,
std::vector<NDArray>* aux_states,
std::unordered_map<std::string, NDArray>*
shared_data_arrays = nullptr,
Executor* shared_exec = nullptr);
/*!
* \brief the prototype of user-defined monitor callback
*/
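These name-keyed maps appear to be what the Python front end surfaces as `arg_dict` and `grad_dict` (both are used by the existing code in the executor_group.py diff below). A small sketch of driving a `simple_bind` executor through those maps; the network, names, and shapes are illustrative:

```python
import mxnet as mx

x = mx.sym.Variable('x')
net = mx.sym.FullyConnected(x, num_hidden=4, name='fc')
exe = net.simple_bind(ctx=mx.cpu(), x=(2, 8))

# Name -> NDArray views, mirroring in_arg_map / arg_grad_map on the C++ side.
exe.arg_dict['x'][:] = 1.0
exe.arg_dict['fc_weight'][:] = 0.01
exe.forward(is_train=True)
exe.backward(out_grads=mx.nd.ones((2, 4)))
print(exe.grad_dict['fc_weight'].shape)   # gradient looked up by argument name
```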
81 changes: 5 additions & 76 deletions python/mxnet/module/executor_group.py
@@ -4,7 +4,6 @@

import logging
from collections import OrderedDict

import numpy as np

from .. import context as ctx
@@ -559,6 +558,7 @@ def update_metric(self, eval_metric, labels):

def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
"""Internal utility function to bind the i-th executor.
This function utilizes the simple_bind Python interface.
"""
shared_exec = None if shared_group is None else shared_group.execs[i]
context = self.contexts[i]
@@ -568,85 +568,14 @@ def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
if label_shapes is not None:
input_shapes.update(dict(label_shapes))

arg_shapes, _, aux_shapes = self.symbol.infer_shape(**input_shapes)
assert arg_shapes is not None, "shape inference failed"

input_types = {x.name: x.dtype for x in data_shapes}
if label_shapes is not None:
input_types.update({x.name: x.dtype for x in label_shapes})
arg_types, _, aux_types = self.symbol.infer_type(**input_types)
assert arg_types is not None, "type inference failed"

arg_arrays = []
grad_arrays = {} if self.for_training else None

def _get_or_reshape(name, shared_data_arrays, arg_shape, arg_type, context, logger):
"""Internal helper to get a memory block or re-use by re-shaping."""
if name in shared_data_arrays:
arg_arr = shared_data_arrays[name]

if np.prod(arg_arr.shape) >= np.prod(arg_shape):
# nice, we can directly re-use this data blob
assert arg_arr.dtype == arg_type
arg_arr = arg_arr.reshape(arg_shape)
else:
logger.warning(('bucketing: data "%s" has a shape %s' % (name, arg_shape)) +
(', which is larger than already allocated ') +
('shape %s' % (arg_arr.shape,)) +
('. Need to re-allocate. Consider putting ') +
('default_bucket_key to') +
(' be the bucket taking the largest input for better ') +
('memory sharing.'))
arg_arr = nd.zeros(arg_shape, context, dtype=arg_type)

# replace existing shared array because the new one is bigger
shared_data_arrays[name] = arg_arr
else:
arg_arr = nd.zeros(arg_shape, context, dtype=arg_type)
shared_data_arrays[name] = arg_arr

return arg_arr

# create or borrow arguments and gradients
for j in range(len(self.arg_names)):
name = self.arg_names[j]
if name in self.param_names: # model parameters
if shared_exec is None:
arg_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
if self.grad_req[name] != 'null':
grad_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
grad_arrays[name] = grad_arr
else:
arg_arr = shared_exec.arg_dict[name]
assert arg_arr.shape == arg_shapes[j]
assert arg_arr.dtype == arg_types[j]
if self.grad_req[name] != 'null':
grad_arrays[name] = shared_exec.grad_dict[name]
else: # data, label, or states
arg_arr = _get_or_reshape(name, shared_data_arrays, arg_shapes[j], arg_types[j],
context, self.logger)

# data might also need grad if inputs_need_grad is True
if self.grad_req[name] != 'null':
grad_arrays[name] = _get_or_reshape('grad of ' + name, shared_data_arrays,
arg_shapes[j], arg_types[j], context,
self.logger)

arg_arrays.append(arg_arr)

# create or borrow aux variables
if shared_exec is None:
aux_arrays = [nd.zeros(s, context, dtype=t) for s, t in zip(aux_shapes, aux_types)]
else:
for j, arr in enumerate(shared_exec.aux_arrays):
assert aux_shapes[j] == arr.shape
assert aux_types[j] == arr.dtype
aux_arrays = shared_exec.aux_arrays[:]

executor = self.symbol.bind(ctx=context, args=arg_arrays,
args_grad=grad_arrays, aux_states=aux_arrays,
grad_req=self.grad_req, shared_exec=shared_exec)
# Get the total bytes allocated for this executor
executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req,
type_dict=input_types, param_names=self.param_names,
shared_exec=shared_exec,
shared_data_arrays=shared_data_arrays, **input_shapes)
self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
return executor
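For reference, a standalone sketch of the new binding path used above. It mirrors the `param_names` / `shared_exec` / `shared_data_arrays` keywords from the call in `_bind_ith_exec`; their exact semantics are whatever this PR implements, so treat the snippet as illustrative rather than canonical:

```python
import mxnet as mx

sym = mx.sym.FullyConnected(mx.sym.Variable('data'), num_hidden=8, name='fc')
type_dict = {'data': 'float32'}
param_names = ['fc_weight', 'fc_bias']   # model parameters, as opposed to data/label
shared_data_arrays = {}                  # buffers shared across bucketed executors

# First executor: all arrays are allocated in the backend.
exe_a = sym.simple_bind(ctx=mx.cpu(), grad_req='write', type_dict=type_dict,
                        param_names=param_names, shared_exec=None,
                        shared_data_arrays=shared_data_arrays, data=(64, 16))

# Second executor: shares parameter arrays with exe_a and reuses the data buffers.
exe_b = sym.simple_bind(ctx=mx.cpu(), grad_req='write', type_dict=type_dict,
                        param_names=param_names, shared_exec=exe_a,
                        shared_data_arrays=shared_data_arrays, data=(32, 16))
```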
