-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Partitioning Gluon HybridBlocks #15969
Partitioning Gluon HybridBlocks #15969
Conversation
Should this be run prior to training or prior to exporting the HybridBlock? Could/Should it be run automatically? Edit: Based on offline discussion, automatic optimization could be run if the backend can be detected automatically. We would not want to automatically export an optimized symbol. |
Waiting on #15886 to be merged to re-use optimize_for API call on symbol |
@mxnet-label-bot add [pr-awaiting-review] |
3a64283
to
abea125
Compare
abea125
to
f90b9ab
Compare
f90b9ab
to
d286167
Compare
d286167
to
4b3d076
Compare
Thanks @guanxinq for the latest update! I think we need to call optimize for again here too when we create a SymbolBlock, otherwise the partitioning wont happen: |
@eric-haibin-lin we shouldnt partition inside |
b6636da
to
7228343
Compare
90daf39
to
80dfaed
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me. Good job @guanxinq @samskalicky
@leezu your comments have been addressed. can you please review again? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changes in python/mxnet/gluon/block.py
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Request for documentation. Otherwise looks good to me
@@ -1040,7 +1052,12 @@ def register_child(self, block, name=None): | |||
super(HybridBlock, self).register_child(block, name) | |||
self._clear_cached_op() | |||
|
|||
def hybridize(self, active=True, **kwargs): | |||
def hybridize(self, active=True, backend=None, backend_args=None, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmmm. This is specific for hybridblock? Can we add documentation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Based on prior discussion with @samskalicky, documentation should describe what happens if input shapes change in subsequent forward calls. (Ie. currently no repartitioning is triggered).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added the description for hybridblock hybridize().
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't see it - did you push?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just pushed. Could you help review the description?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the concept of SubgraphBackendRegistry, PostPartition, etc are new and not very straightforward to users. Is it possible to also add a link to any tutorial that teaches user how to register a subgraph backend?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed offline, we plan to add a tutorial as part of our next PR and link it to the example. I have put together the TODO list for the next PR in this github issue #17532 .
if backend_args is None: | ||
self._backend_args = {} | ||
else: | ||
self._backend_args = backend_args |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we need to enforce users to pass a dictionary (since user may pass a string), so we need to add a check below before assign it to _backend_args
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lets add something like
if isinstance(backend_args, dict)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed. Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we still need an else
block.
else:
self._backend_args = {}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, we don't need else
as it is initialized to {}.
71c88a7
to
5aaff0b
Compare
python/mxnet/gluon/block.py
Outdated
Whether to turn hybrid on or off. | ||
backend : str | ||
The name of backend, as registered in `SubgraphBackendRegistry`, default None | ||
backend_args : dict of optional arguments, optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: optional twice
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
b332edd
to
966a383
Compare
python/mxnet/gluon/block.py
Outdated
but slower. | ||
""" | ||
""" Please refer description of HybridBlock hybridize(). | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you get rid of trailing whitespaces
966a383
to
14dcc14
Compare
agree to address doc issue in a future PR
* stub for optimizing Gluon block * Init commit for Gluon hybridblocks partition(sample test included) * Added tests for Gluon and refactored tests * call optimize_for in _build_cache * Pass in 4 paras for gluon optimize_for * Fixed auxiliary state issue, args issue and added 2 tests. * Fixed auxiliary state issue, args issue and added 2 tests. * changed parameter check * refactored param init since needed for partitioning * fixed whitespace * fixed flattened args * fixed sanity & updated tests * fixed whitespace * added context support in tests * Fix python2 errors * clean code remove cargs * Add hybridblock hybridize() description Co-authored-by: guanxinq <[email protected]>
* stub for optimizing Gluon block * Init commit for Gluon hybridblocks partition(sample test included) * Added tests for Gluon and refactored tests * call optimize_for in _build_cache * Pass in 4 paras for gluon optimize_for * Fixed auxiliary state issue, args issue and added 2 tests. * Fixed auxiliary state issue, args issue and added 2 tests. * changed parameter check * refactored param init since needed for partitioning * fixed whitespace * fixed flattened args * fixed sanity & updated tests * fixed whitespace * added context support in tests * Fix python2 errors * clean code remove cargs * Add hybridblock hybridize() description Co-authored-by: guanxinq <[email protected]>
Description
Adds partitioning support for Gluon HybridBlocks. This is a continuation of the partitioning support for Symbol #15886
Design
In Gluon, a HybridBlock contains a Symbol after hybridizing and executing a forward pass. The Symbol is contained and managed within the block. The partitioning logic will be integrated into the hybridize flow.
There are many ways to create a Gluon Hybrid block and after this process, users call the
hybridize()
function to start the flow. We add two new arguments to support partitioning:backend
which is a string corresponding to thesubgraph_backend
name, andopt_args
which is a map of arguments that should be passed to thesubgraph_property
during partitioning. These values are stored until used during the first inference call. Heres an example specifying these new arguments:Notice that in the above example, the new arguments have the same value as the example in #15886. These arguments will ultimately be passed to a call to the
optimize_for
API.In the Gluon, the hybridize flow starts before the first inference. The Symbol object is created in the
_build_cache
function:https://github.com/apache/incubator-mxnet/blob/bd67723da96e6d36e72c9a42535a4fe68f234a71/python/mxnet/gluon/block.py#L933-L934
We'll add a new line of code to partition it and pass the new arguments from the hybridize call:
This supports the partitioning flow without shape/type propagation. Some backends do not need shapes and types so there is no reason to require it for all backends. Other backends will require shapes and types in order to partition the model correctly (examples being backends that only support float16 and not float32, or only support small shapes and not large ones).
For the partitioning with with shape/type propagation we can get the args to the model from the parameters in the Gluon block. By default, the initialization of Gluon parameters may be delayed. If the parameters are not initialized yet, we'll continue with the flow shown in the code snippet above that does not infer shapes/types.
In Gluon users can force initialization (see this guide) and if all parameters are initialized after calling
hybridize
and setting thebackend
name, we will pass the arguments from the Gluon parameters into theoptimize_for
API to infer shapes/types before partitioning. This gives the user the control over partitioning in the same way that they do for Symbol API. Heres a code snippet to produce the arg array and pass it tooptimize_for
:The context will be gathered from the inputs to the model like this:
Context is required to infer storage types.
Note
Partitioning is done as part of the hybridize flow, when building the cachedOp. So if shapes change between infer calls the graph is not re-partitioned.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes