Skip to content

Commit

Permalink
support regex of collect_params() (apache#9348)
Browse files Browse the repository at this point in the history
* support regex of collect_params()

* fix pylint

* change default value && make select as a sigle reg

* fix if select is None, then will not do any regex matching

* update regex compile && add test

* Update block.py

* support regex of collect_params()

* fix pylint

* change default value && make select as a sigle reg

* fix if select is None, then will not do any regex matching

* update regex compile && add test
  • Loading branch information
tornadomeet authored and Nan Zhu committed Jan 16, 2018
1 parent 4be20cc commit 10cd9dd
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 6 deletions.
35 changes: 30 additions & 5 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import copy
import warnings
import re

from .. import symbol, ndarray, initializer
from ..symbol import Symbol
Expand Down Expand Up @@ -227,13 +228,38 @@ def params(self):
children's parameters)."""
return self._params

def collect_params(self):
def collect_params(self, select=None):
"""Returns a :py:class:`ParameterDict` containing this :py:class:`Block` and all of its
children's Parameters."""
children's Parameters(default), also can returns the select :py:class:`ParameterDict`
which match some given regular expressions.
For example, collect the specified parameter in ['conv1_weight', 'conv1_bias', 'fc_weight',
'fc_bias']::
model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias')
or collect all paramters which their name ends with 'weight' or 'bias', this can be done
using regular expressions::
model.collect_params('.*weight|.*bias')
Parameters
----------
select : str
regular expressions
Returns
-------
The selected :py:class:`ParameterDict`
"""
ret = ParameterDict(self._params.prefix)
ret.update(self.params)
if not select:
ret.update(self.params)
else:
pattern = re.compile(select)
ret.update({name:value for name, value in self.params.items() if pattern.match(name)})
for cld in self._children:
ret.update(cld.collect_params())
ret.update(cld.collect_params(select=select))
return ret

def save_params(self, filename):
Expand Down Expand Up @@ -261,7 +287,6 @@ def load_params(self, filename, ctx, allow_missing=False,
self.collect_params().load(filename, ctx, allow_missing, ignore_extra,
self.prefix)


def register_child(self, block):
"""Registers block as a child of self. :py:class:`Block` s assigned to self as
attributes will be registered automatically."""
Expand Down
13 changes: 12 additions & 1 deletion tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,18 @@ def __init__(self, **kwargs):
assert 'numpy.float32' in lines[1]
assert lines[2] == ')'


def test_collect_paramters():
net = nn.HybridSequential(prefix="test_")
with net.name_scope():
net.add(nn.Conv2D(10, 3))
net.add(nn.Dense(10, activation='relu'))
assert set(net.collect_params().keys()) == \
set(['test_conv0_weight', 'test_conv0_bias','test_dense0_weight','test_dense0_bias'])
assert set(net.collect_params('.*weight').keys()) == \
set(['test_conv0_weight', 'test_dense0_weight'])
assert set(net.collect_params('test_conv0_bias|test_dense0_bias').keys()) == \
set(['test_conv0_bias', 'test_dense0_bias'])

def test_basic():
model = nn.Sequential()
model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False))
Expand Down

0 comments on commit 10cd9dd

Please sign in to comment.