Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Partitioning Gluon HybridBlocks #15969

Merged
merged 20 commits into from
Feb 6, 2020
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,8 @@ def __init__(self, prefix=None, params=None):
self._flags = []
self._callback = None
self._monitor_all = False
self._backend = None
self._backend_args = {}

def __setattr__(self, name, value):
"""Registers parameters."""
Expand Down Expand Up @@ -935,7 +937,6 @@ def _build_cache(self, *args):
data_names = {data.name: i for i, data in enumerate(data)}
params = self.collect_params()
input_names = out.list_inputs()

param_names = set(params.keys())
expected_names = set(input_names)
for name in expected_names:
Expand Down Expand Up @@ -967,6 +968,26 @@ def _build_cache(self, *args):
self._cached_op_args.append((False, params[name]))
flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
self._flags

args, _ = _flatten(args, "input")
try:
for is_arg, i in self._cached_op_args:
if not is_arg:
i.data()
except DeferredInitializationError:
self._deferred_infer_shape(*args)
for is_arg, i in self._cached_op_args:
if not is_arg:
i._finish_deferred_init()

if self._backend:
ctx = args[0].context
# get list of params in the order of out.list_arguments
arg_array = [args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_arguments()]
# Partition the graph.
out = out.optimize_for(self._backend, arg_array, ctx, **self._backend_args)

samskalicky marked this conversation as resolved.
Show resolved Hide resolved
self._cached_op = ndarray.CachedOp(out, flags)

def _deferred_infer_shape(self, *args):
Expand Down Expand Up @@ -1008,19 +1029,10 @@ def _call_cached_op(self, *args):
raise ValueError("The argument structure of HybridBlock does not match"
" the cached version. Stored format = {}, input format = {}"
.format(fmt, self._in_format))

args_without_none = [ele for ele in args if ele is not None]
try:
cargs = [args_without_none[i] if is_arg else i.data()
for is_arg, i in self._cached_op_args]
except DeferredInitializationError:
self._deferred_infer_shape(*args)
cargs = []
for is_arg, i in self._cached_op_args:
if is_arg:
cargs.append(args_without_none[i])
else:
i._finish_deferred_init()
cargs.append(i.data())
cargs = [args_without_none[i] if is_arg else i.data()
for is_arg, i in self._cached_op_args]
out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
Expand All @@ -1040,7 +1052,12 @@ def register_child(self, block, name=None):
super(HybridBlock, self).register_child(block, name)
self._clear_cached_op()

def hybridize(self, active=True, **kwargs):
def hybridize(self, active=True, backend=None, backend_args=None, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm. This is specific for hybridblock? Can we add documentation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on prior discussion with @samskalicky, documentation should describe what happens if input shapes change in subsequent forward calls. (Ie. currently no repartitioning is triggered).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the description for hybridblock hybridize().

Copy link
Member

@eric-haibin-lin eric-haibin-lin Feb 5, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't see it - did you push?

Copy link
Contributor

@guanxinq guanxinq Feb 5, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just pushed. Could you help review the description?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the concept of SubgraphBackendRegistry, PostPartition, etc are new and not very straightforward to users. Is it possible to also add a link to any tutorial that teaches user how to register a subgraph backend?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, we plan to add a tutorial as part of our next PR and link it to the example. I have put together the TODO list for the next PR in this github issue #17532 .

self._backend = backend
if backend_args is None:
self._backend_args = {}
samskalicky marked this conversation as resolved.
Show resolved Hide resolved
else:
self._backend_args = backend_args
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to enforce users to pass a dictionary (since user may pass a string), so we need to add a check below before assign it to _backend_args

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets add something like

if isinstance(backend_args, dict)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we still need an else block.

else:
    self._backend_args = {}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, we don't need else as it is initialized to {}.

self._active = active
self._flags = list(kwargs.items())
self._clear_cached_op()
Expand Down Expand Up @@ -1160,7 +1177,6 @@ def forward(self, x, *args):
params = {k: v.data(ctx) for k, v in self._reg_params.items()}

return self.hybrid_forward(ndarray, x, *args, **params)

params = {i: j.var() for i, j in self._reg_params.items()}
with self.name_scope():
return self.hybrid_forward(symbol, x, *args, **params)
Expand Down
Loading