Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Partitioning Gluon HybridBlocks #15969

Merged
merged 20 commits into from
Feb 6, 2020
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,8 @@ def __init__(self, prefix=None, params=None):
self._flags = []
self._callback = None
self._monitor_all = False
self._backend = None
self._backend_args = {}

def __setattr__(self, name, value):
"""Registers parameters."""
Expand Down Expand Up @@ -935,7 +937,6 @@ def _build_cache(self, *args):
data_names = {data.name: i for i, data in enumerate(data)}
params = self.collect_params()
input_names = out.list_inputs()

param_names = set(params.keys())
expected_names = set(input_names)
for name in expected_names:
Expand All @@ -954,6 +955,26 @@ def _build_cache(self, *args):
unused = ', '.join(list(param_names - set(used_param_names)))
warnings.warn("Parameter %s is not used by any computation. "
"Is this intended?"%unused, stacklevel=4)
if self._backend:
ctx = args[0].context
samskalicky marked this conversation as resolved.
Show resolved Hide resolved
arg_array = []
# Build args list if params are initialized
try:
for name in out.list_arguments():
if name in data_names.keys():
arg_array.append(args[data_names[name]])
else:
arg_array.append(params.get(name).data())
# Exceptions are thrown, because the params are not initialized.
# In this case, we don't care and will just not use the params.
except DeferredInitializationError:
samskalicky marked this conversation as resolved.
Show resolved Hide resolved
self._deferred_infer_shape(*args)
# To do: get arg_array.
arg_array = None
except RuntimeError:
arg_array = None
# Partition the graph.
out = out.optimize_for(self._backend, arg_array, ctx, **self._backend_args)

data_indices = []
param_indices = []
Expand Down Expand Up @@ -1040,7 +1061,12 @@ def register_child(self, block, name=None):
super(HybridBlock, self).register_child(block, name)
self._clear_cached_op()

def hybridize(self, active=True, **kwargs):
def hybridize(self, active=True, backend=None, backend_args=None, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm. This is specific for hybridblock? Can we add documentation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on prior discussion with @samskalicky, documentation should describe what happens if input shapes change in subsequent forward calls. (Ie. currently no repartitioning is triggered).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the description for hybridblock hybridize().

Copy link
Member

@eric-haibin-lin eric-haibin-lin Feb 5, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't see it - did you push?

Copy link
Contributor

@guanxinq guanxinq Feb 5, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just pushed. Could you help review the description?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the concept of SubgraphBackendRegistry, PostPartition, etc are new and not very straightforward to users. Is it possible to also add a link to any tutorial that teaches user how to register a subgraph backend?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, we plan to add a tutorial as part of our next PR and link it to the example. I have put together the TODO list for the next PR in this github issue #17532 .

self._backend = backend
if backend_args is None:
self._backend_args = {}
samskalicky marked this conversation as resolved.
Show resolved Hide resolved
else:
self._backend_args = backend_args
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to enforce users to pass a dictionary (since user may pass a string), so we need to add a check below before assign it to _backend_args

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets add something like

if isinstance(backend_args, dict)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we still need an else block.

else:
    self._backend_args = {}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, we don't need else as it is initialized to {}.

self._active = active
self._flags = list(kwargs.items())
self._clear_cached_op()
Expand Down Expand Up @@ -1160,7 +1186,6 @@ def forward(self, x, *args):
params = {k: v.data(ctx) for k, v in self._reg_params.items()}

return self.hybrid_forward(ndarray, x, *args, **params)

params = {i: j.var() for i, j in self._reg_params.items()}
with self.name_scope():
return self.hybrid_forward(symbol, x, *args, **params)
Expand Down
Loading