Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
HybridBlock.export() to return created filenames (#17758)
Browse files Browse the repository at this point in the history
* [Gluon] Return filenames from HybridBlock.export (#1579)

* [Gluon] Unittest filenames from HybridBlock.export (#1579)

* [Gluon] Perl HybridBlock.export return filenames (#1579)

* [Gluon] HybridBlock.export docstring grammar fix

* Linting - fix trailing whitespace
  • Loading branch information
athewsey authored Mar 18, 2020
1 parent 5996544 commit dfb1b88
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
7 changes: 5 additions & 2 deletions perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Block.pm
Original file line number Diff line number Diff line change
Expand Up @@ -1254,7 +1254,8 @@ method export(Str $path, :$epoch=0)
);
}
my $sym = $self->_cached_graph->[1];
$sym->save("$path-symbol.json");
my $sym_filename = "$path-symbol.json";
$sym->save($sym_filename);

my %arg_names = map { $_ => 1 } @{ $sym->list_arguments };
my %aux_names = map { $_ => 1 } @{ $sym->list_auxiliary_states };
Expand All @@ -1273,7 +1274,9 @@ method export(Str $path, :$epoch=0)
$arg_dict{ "aux:$name" } = $param->_reduce;
}
}
AI::MXNet::NDArray->save(sprintf('%s-%04d.params', $path, $epoch), \%arg_dict);
my $params_filename = sprintf('%s-%04d.params', $path, $epoch);
AI::MXNet::NDArray->save($params_filename, \%arg_dict);
return ($sym_filename, $params_filename);
}

__PACKAGE__->register('AI::MXNet::Gluon');
Expand Down
14 changes: 12 additions & 2 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,13 +1133,21 @@ def export(self, path, epoch=0, remove_amp_cast=True):
will be created, where xxxx is the 4 digits epoch number.
epoch : int
Epoch number of saved model.
Returns
-------
symbol_filename : str
Filename to which model symbols were saved, including `path` prefix.
params_filename : str
Filename to which model parameters were saved, including `path` prefix.
"""
if not self._cached_graph:
raise RuntimeError(
"Please first call block.hybridize() and then run forward with "
"this block at least once before calling export.")
sym = self._cached_graph[1]
sym.save('%s-symbol.json'%path, remove_amp_cast=remove_amp_cast)
sym_filename = '%s-symbol.json'%path
sym.save(sym_filename, remove_amp_cast=remove_amp_cast)

arg_names = set(sym.list_arguments())
aux_names = set(sym.list_auxiliary_states())
Expand All @@ -1151,7 +1159,9 @@ def export(self, path, epoch=0, remove_amp_cast=True):
assert name in aux_names
arg_dict['aux:%s'%name] = param._reduce()
save_fn = _mx_npx.save if is_np_array() else ndarray.save
save_fn('%s-%04d.params'%(path, epoch), arg_dict)
params_filename = '%s-%04d.params'%(path, epoch)
save_fn(params_filename, arg_dict)
return (sym_filename, params_filename)

def register_op_hook(self, callback, monitor_all=False):
"""Install op hook for block recursively.
Expand Down
4 changes: 3 additions & 1 deletion tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,7 +1191,9 @@ def test_export():
data = mx.nd.random.normal(shape=(1, 3, 32, 32))
out = model(data)

model.export('gluon')
symbol_filename, params_filename = model.export('gluon')
assert symbol_filename == 'gluon-symbol.json'
assert params_filename == 'gluon-0000.params'

module = mx.mod.Module.load('gluon', 0, label_names=None, context=ctx)
module.bind(data_shapes=[('data', data.shape)])
Expand Down

0 comments on commit dfb1b88

Please sign in to comment.