Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

support setattr for some specified paramter of ParameterDict #9323

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,10 +593,10 @@ def reset_ctx(self, ctx):
for i in self.values():
i.reset_ctx(ctx)

def setattr(self, name, value):
"""Set an attribute to a new value for all Parameters.
def setattr(self, name, value, selected_param_names="all"):
"""Set an attribute to a new value for all(default) or some selected Parameters.

For example, set grad_req to null if you don't need gradient w.r.t a
For example, set grad_req to null if you don't need gradient w.r.t all
model's Parameters::

model.collect_params().setattr('grad_req', 'null')
Expand All @@ -605,15 +605,27 @@ def setattr(self, name, value):

model.collect_params().setattr('lr_mult', 0.5)

or set grad_req to null if you don't need gradient w.r.t model's
Parameters in ['conv1_weight', 'conv1_bias']::

model.collect_params().setattr('grad_req', 'null', ['conv1_weight', 'conv1_bias'])

Parameters
----------
name : str
Name of the attribute.
value : valid type for attribute name
The new value for the attribute.
selected_param_names : str or list
The selected paramters to be setattr.
"""
for i in self.values():
setattr(i, name, value)
if selected_param_names == "all":
selected_param_names = self.keys()
elif isinstance(selected_param_names, str):
selected_param_names = [selected_param_names]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"all" can be a valid parameter name.

Copy link
Contributor Author

@tornadomeet tornadomeet Jan 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree, '"all" is bad here.

for k, v in self.items():
if k in selected_param_names:
setattr(v, name, value)

def save(self, filename, strip_prefix=''):
"""Save parameters to file.
Expand Down