Skip to content

Commit

Permalink
Revert "HybridBlock.export() to return created filenames (apache#17758)"
Browse files Browse the repository at this point in the history
This reverts commit dfb1b88.
  • Loading branch information
rondogency authored Mar 18, 2020
1 parent dfb1b88 commit 2b75e7f
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 20 deletions.
7 changes: 2 additions & 5 deletions perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Block.pm
Original file line number Diff line number Diff line change
Expand Up @@ -1254,8 +1254,7 @@ method export(Str $path, :$epoch=0)
);
}
my $sym = $self->_cached_graph->[1];
my $sym_filename = "$path-symbol.json";
$sym->save($sym_filename);
$sym->save("$path-symbol.json");

my %arg_names = map { $_ => 1 } @{ $sym->list_arguments };
my %aux_names = map { $_ => 1 } @{ $sym->list_auxiliary_states };
Expand All @@ -1274,9 +1273,7 @@ method export(Str $path, :$epoch=0)
$arg_dict{ "aux:$name" } = $param->_reduce;
}
}
my $params_filename = sprintf('%s-%04d.params', $path, $epoch);
AI::MXNet::NDArray->save($params_filename, \%arg_dict);
return ($sym_filename, $params_filename);
AI::MXNet::NDArray->save(sprintf('%s-%04d.params', $path, $epoch), \%arg_dict);
}

__PACKAGE__->register('AI::MXNet::Gluon');
Expand Down
14 changes: 2 additions & 12 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,21 +1133,13 @@ def export(self, path, epoch=0, remove_amp_cast=True):
will be created, where xxxx is the 4 digits epoch number.
epoch : int
Epoch number of saved model.
Returns
-------
symbol_filename : str
Filename to which model symbols were saved, including `path` prefix.
params_filename : str
Filename to which model parameters were saved, including `path` prefix.
"""
if not self._cached_graph:
raise RuntimeError(
"Please first call block.hybridize() and then run forward with "
"this block at least once before calling export.")
sym = self._cached_graph[1]
sym_filename = '%s-symbol.json'%path
sym.save(sym_filename, remove_amp_cast=remove_amp_cast)
sym.save('%s-symbol.json'%path, remove_amp_cast=remove_amp_cast)

arg_names = set(sym.list_arguments())
aux_names = set(sym.list_auxiliary_states())
Expand All @@ -1159,9 +1151,7 @@ def export(self, path, epoch=0, remove_amp_cast=True):
assert name in aux_names
arg_dict['aux:%s'%name] = param._reduce()
save_fn = _mx_npx.save if is_np_array() else ndarray.save
params_filename = '%s-%04d.params'%(path, epoch)
save_fn(params_filename, arg_dict)
return (sym_filename, params_filename)
save_fn('%s-%04d.params'%(path, epoch), arg_dict)

def register_op_hook(self, callback, monitor_all=False):
"""Install op hook for block recursively.
Expand Down
4 changes: 1 addition & 3 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,9 +1191,7 @@ def test_export():
data = mx.nd.random.normal(shape=(1, 3, 32, 32))
out = model(data)

symbol_filename, params_filename = model.export('gluon')
assert symbol_filename == 'gluon-symbol.json'
assert params_filename == 'gluon-0000.params'
model.export('gluon')

module = mx.mod.Module.load('gluon', 0, label_names=None, context=ctx)
module.bind(data_shapes=[('data', data.shape)])
Expand Down

0 comments on commit 2b75e7f

Please sign in to comment.