Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

support regex of collect_params() #9348

Merged
merged 12 commits into from
Jan 12, 2018
35 changes: 30 additions & 5 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import copy
import warnings
import re

from .. import symbol, ndarray, initializer
from ..symbol import Symbol
Expand Down Expand Up @@ -227,13 +228,38 @@ def params(self):
children's parameters)."""
return self._params

def collect_params(self):
def collect_params(self, select=None):
"""Returns a :py:class:`ParameterDict` containing this :py:class:`Block` and all of its
children's Parameters."""
children's Parameters(default), also can returns the select :py:class:`ParameterDict`
which match some given regular expressions.

For example, collect the specified parameter in ['conv1_weight', 'conv1_bias', 'fc_weight',
'fc_bias']::

model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias')

or collect all paramters which their name ends with 'weight' or 'bias', this can be done
using regular expressions::

model.collect_params('.*weight|.*bias')

Parameters
----------
select : str
regular expressions

Returns
-------
The selected :py:class:`ParameterDict`
"""
ret = ParameterDict(self._params.prefix)
ret.update(self.params)
if not select:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if select = [] or {} or '' or () then not select will be True, may be wrong here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, and it should be fine as long as both empty string and None returns true.

ret.update(self.params)
else:
pattern = re.compile(select)
ret.update({name:value for name, value in self.params.items() if pattern.match(name)})
for cld in self._children:
ret.update(cld.collect_params())
ret.update(cld.collect_params(select=select))
return ret

def save_params(self, filename):
Expand Down Expand Up @@ -261,7 +287,6 @@ def load_params(self, filename, ctx, allow_missing=False,
self.collect_params().load(filename, ctx, allow_missing, ignore_extra,
self.prefix)


def register_child(self, block):
"""Registers block as a child of self. :py:class:`Block` s assigned to self as
attributes will be registered automatically."""
Expand Down
13 changes: 12 additions & 1 deletion tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,18 @@ def __init__(self, **kwargs):
assert 'numpy.float32' in lines[1]
assert lines[2] == ')'


def test_collect_paramters():
net = nn.HybridSequential(prefix="test_")
with net.name_scope():
net.add(nn.Conv2D(10, 3))
net.add(nn.Dense(10, activation='relu'))
assert set(net.collect_params().keys()) == \
set(['test_conv0_weight', 'test_conv0_bias','test_dense0_weight','test_dense0_bias'])
assert set(net.collect_params('.*weight').keys()) == \
set(['test_conv0_weight', 'test_dense0_weight'])
assert set(net.collect_params('test_conv0_bias|test_dense0_bias').keys()) == \
set(['test_conv0_bias', 'test_dense0_bias'])

def test_basic():
model = nn.Sequential()
model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False))
Expand Down