From e916556beb52e561856e2c8354f3ccf3c852d84d Mon Sep 17 00:00:00 2001 From: Kai Sasaki Date: Wed, 6 Jun 2018 00:18:23 +0900 Subject: [PATCH] Fix clipByValue gradient for values less than min or greater than max (#1076) BUG --- src/ops/unary_ops.ts | 6 ++---- src/ops/unary_ops_test.ts | 12 ++++++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/ops/unary_ops.ts b/src/ops/unary_ops.ts index 0150910e8d..fd7380adaf 100644 --- a/src/ops/unary_ops.ts +++ b/src/ops/unary_ops.ts @@ -357,11 +357,9 @@ export class UnaryOps { const grad = (dy: T) => { return { - // TODO(cais): Fix gradients for the case where x = min or x - // = max. x: () => dy.where( - x.greater(ops.scalar(clipValueMin)) - .logicalAnd(x.less(ops.scalar(clipValueMax))), + x.greaterEqual(ops.scalar(clipValueMin)) + .logicalAnd(x.lessEqual(ops.scalar(clipValueMax))), zerosLike(dy)) as T, }; }; diff --git a/src/ops/unary_ops_test.ts b/src/ops/unary_ops_test.ts index 5b332201fa..acda7df149 100644 --- a/src/ops/unary_ops_test.ts +++ b/src/ops/unary_ops_test.ts @@ -2090,6 +2090,18 @@ describeWithFlags('clip', ALL_ENVS, () => { expectArraysClose(gradients, [0, 0, 500]); }); + it('derivative: 1D tensor with max or min value', () => { + const min = -1; + const max = 2; + const x = tf.tensor1d([-1, 1, 2, 3]); + const dy = tf.tensor1d([1, 10, 100, 1000]); + const gradients = tf.grad(x => x.clipByValue(min, max))(x, dy); + + expect(gradients.shape).toEqual(x.shape); + expect(gradients.dtype).toEqual('float32'); + expectArraysClose(gradients, [1, 10, 100, 0]); + }); + it('derivative: scalar', () => { const min = -1; const max = 2;