Gradient multiplier (contrib) operator #13632
Conversation
Missing test for backwards pass
@mxnet-label-bot add [Operator, pr-awaiting-review]
Shouldn't we have a more generic gradient multiplier operator? What do you think?
That is certainly possible, shall I rewrite it?
Thanks for contributing the op. The forward and backward logic can utilize existing kernels such as those in identity and broadcast_scalar_mul.
@szha Thanks for the feedback, good points. However, I have a hard time finding those kernels; to me they seem to be deeply integrated into other operators. Could you please point me to the right functions?
@szha Dumped the header file and used forward and backward from identity / scalar_mul.
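For context, a minimal sketch of what reusing those kernels can look like (the operator names, headers, and FGradient wiring here are assumptions based on MXNet's elemwise operator utilities, not the exact code in this PR):

```cpp
// Sketch only: assumes the helpers in src/operator/tensor/elemwise_unary_op.h
// and elemwise_binary_scalar_op.h. Forward is a plain identity copy; backward
// scales the incoming gradient by the "scalar" attribute (lambda).
NNVM_REGISTER_OP(_contrib_gradientmultiplier)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)  // y = x
.set_attr<nnvm::FGradient>("FGradient",
                           ElemwiseGradUseNone{"_contrib_backward_gradientmultiplier"});

NNVM_REGISTER_OP(_contrib_backward_gradientmultiplier)
.set_attr<FCompute>("FCompute<cpu>",
                    BinaryScalarOp::Compute<cpu, mshadow_op::mul>);  // dx = lambda * dy
```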
.set_attr_parser([](NodeAttrs* attrs) {
    attrs->parsed = std::stod(attrs->dict["scalar"]);
  })
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>)
Do you also plan to support sparse inputs/outputs? If not, you don't have to register FInferStorageType and FComputeEx (by default it infers dense storage and uses FCompute).
Since the operator is very simple, I thought it would be easy to support sparse data as well. What do I need to change to have full support?
I'm also thinking of renaming the operator to gradient multiplier. Any thoughts?
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1);
This method has no indentation. Is this expected?
It does; not sure why GitHub shows it wrong.
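For readability, here is a reconstruction of the full method with its indentation intact; the function name and the lines surrounding the hunk above are assumptions, not the exact PR code:

```cpp
// Hypothetical reconstruction of the storage-type inference function the
// hunk above belongs to. It delegates to MXNet's generic ElemwiseStorageType
// helper, which handles dense as well as row_sparse/csr storage for
// 1-in/1-out elemwise ops.
static bool GradientMultiplierStorageType(const nnvm::NodeAttrs& attrs,
                                          const int dev_mask,
                                          DispatchMode* dispatch_mode,
                                          std::vector<int> *in_attrs,
                                          std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1);
  CHECK_EQ(out_attrs->size(), 1);
  return ElemwiseStorageType<1, 1, false, true, true>(
      attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
}
```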
Retrigger flaky test
[](const NodeAttrs& attrs){
  return std::vector<bool>{true};
})
.add_argument("scalar", "float", "scalar input");
consider making this description more informative (e.g. X multiplier)
Good point, updated.
Improved the description of the scalar multiplier
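The updated registration presumably reads something along these lines; the exact wording of the new description is an assumption:

```cpp
// Hypothetical updated argument registration; the actual wording in the PR
// may differ.
.add_argument("scalar", "float",
              "lambda multiplier to scale the gradients in the backward pass");
```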
@szha @ThomasDelteil merge?
* Added the gradient reversal contrib operator (missing test for backwards pass)
* Fixed linting errors
* Fixed forward test
* Added random forward / backward test for gradient reversal
* Update test_contrib_operator.py
* Fixed typo in gradient reversal op description
* Replace forward code with the identity implementation
* Fixed typos in function docs
* Changed default behavior to identity
* Replaced backward code with scalar_mul
* Fixed backward operator and unit test
* Renamed operator to gradient multiplier
* Update test_contrib_operator.py (retrigger flaky test)
* Update gradient_multiplier_op.cc (improved the description of the scalar multiplier)
Description
Adds the gradient multiplier operator, which is mostly used in unsupervised adversarial domain adaptation.
In short: on the forward pass it acts as an identity transform; on the backward pass it multiplies the gradients by a scalar constant (lambda).
See the full description here: http://proceedings.mlr.press/v37/ganin15.pdf
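Concretely, for input x, output y, loss L, and multiplier lambda, the operator computes:

$$
y = x \quad \text{(forward)}, \qquad
\frac{\partial L}{\partial x} = \lambda \, \frac{\partial L}{\partial y} \quad \text{(backward)}
$$

Setting lambda to a negative value recovers the gradient reversal behavior described in the linked paper.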
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.