diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Block.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Block.pm index c4ac933ad1a1..29cacffcdd05 100644 --- a/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Block.pm +++ b/perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Block.pm @@ -1254,7 +1254,8 @@ method export(Str $path, :$epoch=0) ); } my $sym = $self->_cached_graph->[1]; - $sym->save("$path-symbol.json"); + my $sym_filename = "$path-symbol.json"; + $sym->save($sym_filename); my %arg_names = map { $_ => 1 } @{ $sym->list_arguments }; my %aux_names = map { $_ => 1 } @{ $sym->list_auxiliary_states }; @@ -1273,7 +1274,9 @@ method export(Str $path, :$epoch=0) $arg_dict{ "aux:$name" } = $param->_reduce; } } - AI::MXNet::NDArray->save(sprintf('%s-%04d.params', $path, $epoch), \%arg_dict); + my $params_filename = sprintf('%s-%04d.params', $path, $epoch); + AI::MXNet::NDArray->save($params_filename, \%arg_dict); + return ($sym_filename, $params_filename); } __PACKAGE__->register('AI::MXNet::Gluon'); diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index a9448801a681..da76b3efcd87 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -1133,13 +1133,21 @@ def export(self, path, epoch=0, remove_amp_cast=True): will be created, where xxxx is the 4 digits epoch number. epoch : int Epoch number of saved model. + + Returns + ------- + symbol_filename : str + Filename to which model symbols were saved, including `path` prefix. + params_filename : str + Filename to which model parameters were saved, including `path` prefix. """ if not self._cached_graph: raise RuntimeError( "Please first call block.hybridize() and then run forward with " "this block at least once before calling export.") sym = self._cached_graph[1] - sym.save('%s-symbol.json'%path, remove_amp_cast=remove_amp_cast) + sym_filename = '%s-symbol.json'%path + sym.save(sym_filename, remove_amp_cast=remove_amp_cast) arg_names = set(sym.list_arguments()) aux_names = set(sym.list_auxiliary_states()) @@ -1151,7 +1159,9 @@ def export(self, path, epoch=0, remove_amp_cast=True): assert name in aux_names arg_dict['aux:%s'%name] = param._reduce() save_fn = _mx_npx.save if is_np_array() else ndarray.save - save_fn('%s-%04d.params'%(path, epoch), arg_dict) + params_filename = '%s-%04d.params'%(path, epoch) + save_fn(params_filename, arg_dict) + return (sym_filename, params_filename) def register_op_hook(self, callback, monitor_all=False): """Install op hook for block recursively. diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index a02682557954..453c31ebc5de 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -1191,7 +1191,9 @@ def test_export(): data = mx.nd.random.normal(shape=(1, 3, 32, 32)) out = model(data) - model.export('gluon') + symbol_filename, params_filename = model.export('gluon') + assert symbol_filename == 'gluon-symbol.json' + assert params_filename == 'gluon-0000.params' module = mx.mod.Module.load('gluon', 0, label_names=None, context=ctx) module.bind(data_shapes=[('data', data.shape)])