Partition API adding and deleting new params to Block and Symbol #18405
@@ -1040,41 +1040,69 @@ def _build_cache(self, *args):
             warnings.warn("Parameter %s is not used by any computation. "
                           "Is this intended?"%unused, stacklevel=4)
 
-        data_indices = []
-        param_indices = []
-        self._cached_op_args = []
-        for i, name in enumerate(input_names):
-            if name in data_names:
-                data_indices.append(i)
-                self._cached_op_args.append((True, data_names[name]))
-            else:
-                param_indices.append(i)
-                self._cached_op_args.append((False, params[name]))
-        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
-                self._flags
-
         args, _ = _flatten(args, "input")
         try:
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i.data()
+            for name in input_names:
+                if name in params:
+                    params[name].data()
         except DeferredInitializationError:
             self._deferred_infer_shape(*args)
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i._finish_deferred_init()
+            for name in input_names:
+                if name in params:
+                    params[name]._finish_deferred_init()
 
+        arg_dict, aux_dict = dict(), dict()
         if self._backend:
             ctx = args[0].context
             # get list of params in the order of out.list_arguments
-            arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                        for name in out.list_arguments()}
-            aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                        for name in out.list_auxiliary_states()}
+            arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+                             for name in out.list_arguments()})
+            aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+                             for name in out.list_auxiliary_states()})
             # Partition the graph.
             out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
 
             #update cached graph with partitioned graph
             self._cached_graph = data, out
 
+        input_names = out.list_inputs()
+        data_indices = []
+        param_indices = []
+
+        # In the default case, _cached_op_args contains all the parameters from params (the sets are identical).
+        # In the case of a Partition API optimized graph, _cached_op_args might contain some parameters from params,
+        # might contain some new parameters created during optimization and added to `arg_dict/aux_dict`,
+        # and might not contain some parameters that were deleted during optimization.
+        self._cached_op_args = []
+        for i, name in enumerate(input_names):
+            pair = None
+            if name in data_names:
+                data_indices.append(i)
+                pair = (True, data_names[name])
+            else:
+                param_indices.append(i)
+                if name in params:
+                    param = params[name]
+                else:
+                    # The param is missing from the original params dictionary, which means the param must
+                    # have been added by the Partition API backend.
+                    if name in arg_dict:
+                        param_data = arg_dict[name]
+                    elif name in aux_dict:
+                        param_data = aux_dict[name]
+                    else:
+                        raise RuntimeError('A parameter was added to the graph during optimization but it was not '
+                                           'added to the parameter dicts.\n'
+                                           'Please check the backend.')
+
+                    param = Parameter(name)
+                    param._load_init(param_data, args[0].context)
+                pair = (False, param)
+
+            self._cached_op_args.append(pair)
+
+        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
+                self._flags
         self._cached_op = ndarray.CachedOp(out, flags)
@@ -1321,12 +1349,14 @@ def export(self, path, epoch=0, remove_amp_cast=True):
         arg_names = set(sym.list_arguments())
         aux_names = set(sym.list_auxiliary_states())
         arg_dict = {}
-        for param in self.collect_params().values():
-            if param.name in arg_names:
-                arg_dict['arg:%s'%param.name] = param._reduce()
-            else:
-                assert param.name in aux_names
-                arg_dict['aux:%s'%param.name] = param._reduce()
+        for is_arg, param in self._cached_op_args:
+            if not is_arg:
+                name = param.name
+                if name in arg_names:
+                    arg_dict['arg:{}'.format(name)] = param._reduce()
+                else:
+                    assert name in aux_names
+                    arg_dict['aux:{}'.format(name)] = param._reduce()
         save_fn = _mx_npx.save if is_np_array() else ndarray.save
         params_filename = '%s-%04d.params'%(path, epoch)
         save_fn(params_filename, arg_dict)
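As a rough usage sketch (the backend name 'myBackend' is a placeholder for whatever subgraph backend has actually been registered, e.g. via mx.library.load), the change means a partitioned model can be exported without losing the parameters the backend created:

```python
import mxnet as mx
from mxnet import gluon

net = gluon.nn.Dense(4)
net.initialize()
net.hybridize(backend='myBackend')   # placeholder backend name
net(mx.nd.ones((1, 8)))              # first forward pass builds the partitioned cached op

# export() now walks self._cached_op_args, so parameters added by the backend
# are written to the .params file and parameters it deleted are skipped.
net.export('partitioned_model')      # writes partitioned_model-symbol.json / -0000.params
```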
@@ -1437,6 +1467,23 @@ def hybrid_forward(self, F, x, *args, **kwargs):
         # pylint: disable= invalid-name
         raise NotImplementedError
 
+    def reset_ctx(self, ctx):
+        """Re-assign all Parameters to other contexts. If the Block is hybridized, the
+        parameters held in _cached_op_args are reset as well.
+
+        Parameters
+        ----------
+        ctx : Context or list of Context, default :py:meth:`context.current_context()`.
+            Assign Parameter to given context. If ctx is a list of Context, a
+            copy will be made for each context.
+        """
+        params = self.collect_params()
+        if self._cached_op:
+            for is_arg, p in self._cached_op_args:
+                # reset parameters created by the partitioning backend; they are
+                # not part of collect_params() but still live in _cached_op_args
+                if not is_arg and p.name not in params:
+                    p.reset_ctx(ctx)
+        for p in params.values():
+            p.reset_ctx(ctx)
 
 
 class SymbolBlock(HybridBlock):
     """Construct block from symbol. This is useful for using pre-trained models

Review discussion on reset_ctx:

Comment: Can you explain why we need to loop over …?

Comment: Although I guess if we delete a param, then it will still be in …

Reply: I will add some comments to clarify.
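A small usage sketch of the new method (assuming a GPU build and, again, a placeholder backend name): it moves both the block's own parameters and any backend-created parameters held in _cached_op_args to the new context.

```python
import mxnet as mx
from mxnet import gluon

net = gluon.nn.Dense(4)
net.initialize(ctx=mx.cpu())
net.hybridize(backend='myBackend')       # placeholder backend name
net(mx.nd.ones((1, 8), ctx=mx.cpu()))    # builds the partitioned cached op on CPU

# Move every parameter, including ones created by the partitioning backend,
# to the GPU before running further inference there.
net.reset_ctx(mx.gpu(0))
```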
Review discussion on the `arg_dict` / `aux_dict` initialization:

Comment: Why do you initialize with `dict()` and then call update on an empty dictionary instead of just assigning?

Reply: `arg_dict` and `aux_dict` could otherwise be undefined below the `assert` in line 1083. This could be a SyntaxError or linter error?

Reply: That's right, also Python scope rules are sometimes a bit unsettling. So I thought that this would make it clearer.

Reply: Makes sense, thanks!
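A minimal illustration of the point discussed above (plain Python, not MXNet code): pre-initializing the dictionaries keeps the names bound even when the backend branch is skipped, so a later lookup fails with an ordinary KeyError rather than the names themselves being undefined.

```python
def build_dicts(backend=None):
    arg_dict, aux_dict = dict(), dict()   # always bound, even with no backend
    if backend:
        arg_dict.update({'weight': 1.0})
        aux_dict.update({'running_mean': 0.0})
    # Without the initialization above, referring to arg_dict here when
    # backend is None would raise a NameError (and upset linters).
    return arg_dict, aux_dict

print(build_dicts())             # ({}, {}) -- no backend configured
print(build_dicts('myBackend'))  # ({'weight': 1.0}, {'running_mean': 0.0})
```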