Skip to content

Commit

Permalink
Fix clipByValue gradient for values less than min or greater than max (
Browse files Browse the repository at this point in the history
  • Loading branch information
Lewuathe authored and Nikhil Thorat committed Jun 5, 2018
1 parent 0f81255 commit e916556
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/ops/unary_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,9 @@ export class UnaryOps {

const grad = (dy: T) => {
return {
// TODO(cais): Fix gradients for the case where x = min or x
// = max.
x: () => dy.where(
x.greater(ops.scalar(clipValueMin))
.logicalAnd(x.less(ops.scalar(clipValueMax))),
x.greaterEqual(ops.scalar(clipValueMin))
.logicalAnd(x.lessEqual(ops.scalar(clipValueMax))),
zerosLike(dy)) as T,
};
};
Expand Down
12 changes: 12 additions & 0 deletions src/ops/unary_ops_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2090,6 +2090,18 @@ describeWithFlags('clip', ALL_ENVS, () => {
expectArraysClose(gradients, [0, 0, 500]);
});

it('derivative: 1D tensor with max or min value', () => {
const min = -1;
const max = 2;
const x = tf.tensor1d([-1, 1, 2, 3]);
const dy = tf.tensor1d([1, 10, 100, 1000]);
const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy);

expect(gradients.shape).toEqual(x.shape);
expect(gradients.dtype).toEqual('float32');
expectArraysClose(gradients, [1, 10, 100, 0]);
});

it('derivative: scalar', () => {
const min = -1;
const max = 2;
Expand Down

0 comments on commit e916556

Please sign in to comment.