From 8f082f1a3f604e23c536079c22e9dc62460ec9cf Mon Sep 17 00:00:00 2001 From: edoh Date: Tue, 4 Dec 2018 23:45:00 +0100 Subject: [PATCH 1/6] initial commit diagPart operator FEATURE add operator diagPart https://github.com/tensorflow/tfjs/issues/655 --- src/kernels/backend.ts | 3 +++ src/kernels/backend_cpu.ts | 10 ++++++++++ src/kernels/backend_webgl.ts | 7 +++++++ src/kernels/webgl/diagpart_gpu.ts | 31 +++++++++++++++++++++++++++++++ src/ops/diagPart.ts | 27 +++++++++++++++++++++++++++ src/ops/diagPart_test.ts | 28 ++++++++++++++++++++++++++++ src/ops/ops.ts | 1 + 7 files changed, 107 insertions(+) create mode 100644 src/kernels/webgl/diagpart_gpu.ts create mode 100644 src/ops/diagPart.ts create mode 100644 src/ops/diagPart_test.ts diff --git a/src/kernels/backend.ts b/src/kernels/backend.ts index 7ac104b9a7..a55b055839 100644 --- a/src/kernels/backend.ts +++ b/src/kernels/backend.ts @@ -84,6 +84,9 @@ export interface BackendTimer { * methods). */ export class KernelBackend implements TensorStorage, BackendTimer { + diagPart($x: Tensor): any { + throw new Error('Method not implemented.'); + } time(f: () => void): Promise { throw new Error('Not yet implemented.'); } diff --git a/src/kernels/backend_cpu.ts b/src/kernels/backend_cpu.ts index 29ca29e217..827731088e 100644 --- a/src/kernels/backend_cpu.ts +++ b/src/kernels/backend_cpu.ts @@ -2536,6 +2536,16 @@ export class MathBackendCPU implements KernelBackend { return res; } + diagPart(x: Tensor): Tensor { + const xVals = x.dataSync(); + const buffer = ops.buffer([Math.sqrt(x.size)], x.dtype); + const vals = buffer.values; + for (let i = 0; i < vals.length; i++) { + vals[i] = xVals[i * vals.length + i]; + } + return buffer.toTensor(); + } + oneHot(indices: Tensor1D, depth: number, onValue: number, offValue: number): Tensor2D { this.assertNotComplex(indices, 'oneHot'); diff --git a/src/kernels/backend_webgl.ts b/src/kernels/backend_webgl.ts index 99d21f0030..0e03481019 100644 --- a/src/kernels/backend_webgl.ts +++ b/src/kernels/backend_webgl.ts @@ -104,6 +104,7 @@ import {UnaryOpProgram} from './webgl/unaryop_gpu'; import {UnpackProgram} from './webgl/unpack_gpu'; import * as webgl_util from './webgl/webgl_util'; import {whereImpl} from './where_impl'; +import {DiagPartProgram} from './webgl/diagpart_gpu'; type KernelInfo = { name: string; query: Promise; @@ -1628,6 +1629,12 @@ export class MathBackendWebGL implements KernelBackend { return this.compileAndRun(program, [probs], output, customSetup); } + diagPart(x: Tensor): Tensor { + const size = Math.sqrt(x.size); + const program = new DiagPartProgram(size); + return this.compileAndRun(program, [x.reshape([size, size])]); + } + oneHot(indices: Tensor1D, depth: number, onValue: number, offValue: number): Tensor2D { const program = new OneHotProgram(indices.size, depth, onValue, offValue); diff --git a/src/kernels/webgl/diagpart_gpu.ts b/src/kernels/webgl/diagpart_gpu.ts new file mode 100644 index 0000000000..8a9490c5e8 --- /dev/null +++ b/src/kernels/webgl/diagpart_gpu.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {GPGPUProgram} from './gpgpu_math'; +export class DiagPartProgram implements GPGPUProgram { + variableNames = ['X']; + outputShape: number[]; + userCode: string; + constructor(size: number) { + this.outputShape = [size]; + this.userCode = ` + void main() { + int coord = getOutputCoords(); + setOutput(getX(coord,coord )); + } + `; + } +} diff --git a/src/ops/diagPart.ts b/src/ops/diagPart.ts new file mode 100644 index 0000000000..c3b666a431 --- /dev/null +++ b/src/ops/diagPart.ts @@ -0,0 +1,27 @@ +import {util} from '..'; +import {ENV} from '../environment'; +// import {tensor} from '../ops/tensor_ops'; +import {Tensor} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; + +import {op} from './operation'; + +function diagPart_(x: Tensor): Tensor { + util.assert( + x.rank % 2 === 0, + `diagpart expects a tensor of even dimension, but got a rank-${ + x.rank} tensor`); + const mid = x.rank / 2; + const dim1 = x.shape.slice(0, mid); + const dim2 = x.shape.slice(mid, x.shape.length); + util.assert( + util.arraysEqual(dim1, dim2), + `diagPart expects ${dim1.toString()} to be equal to ${ + dim2.toString()}`); + const $x = convertToTensor(x, 'x', 'diag_part').flatten(); + const outShape = dim1; + return ENV.engine.runKernel(backend => backend.diagPart($x), {$x}) + .reshape(outShape); +} + +export const diagPart = op({diagPart_}); diff --git a/src/ops/diagPart_test.ts b/src/ops/diagPart_test.ts new file mode 100644 index 0000000000..67fc0e13ed --- /dev/null +++ b/src/ops/diagPart_test.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import * as tf from '../index'; +import {describeWithFlags} from '../jasmine_util'; +import {ALL_ENVS, expectArraysEqual} from '../test_util'; + +describeWithFlags('diagonal', ALL_ENVS, () => { + it('diag2d', () => { + const m = tf.tensor2d([[5, 4], [5, 6]]); + const d = tf.diagPart(m); + expectArraysEqual(d.shape, [2]); + expectArraysEqual(d, tf.tensor1d([5, 6])); + }); +}); diff --git a/src/ops/ops.ts b/src/ops/ops.ts index 56ef0f3464..80626ea993 100644 --- a/src/ops/ops.ts +++ b/src/ops/ops.ts @@ -44,6 +44,7 @@ export * from './scatter_nd'; export * from './spectral_ops'; export * from './sparse_to_dense'; export * from './gather_nd'; +export * from './diagPart'; export {op} from './operation'; From f1037b04e94ed8eaf3af0b76455c9d5e6cf3c400 Mon Sep 17 00:00:00 2001 From: edoh Date: Wed, 5 Dec 2018 19:41:16 +0100 Subject: [PATCH 2/6] add more tests --- src/ops/diagPart_test.ts | 28 -------------- src/ops/{diagPart.ts => diagpart.ts} | 9 ++--- src/ops/diagpart_test.ts | 58 ++++++++++++++++++++++++++++ src/ops/ops.ts | 2 +- 4 files changed, 63 insertions(+), 34 deletions(-) delete mode 100644 src/ops/diagPart_test.ts rename src/ops/{diagPart.ts => diagpart.ts} (71%) create mode 100644 src/ops/diagpart_test.ts diff --git a/src/ops/diagPart_test.ts b/src/ops/diagPart_test.ts deleted file mode 100644 index 67fc0e13ed..0000000000 --- a/src/ops/diagPart_test.ts +++ /dev/null @@ -1,28 +0,0 @@ -/** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ -import * as tf from '../index'; -import {describeWithFlags} from '../jasmine_util'; -import {ALL_ENVS, expectArraysEqual} from '../test_util'; - -describeWithFlags('diagonal', ALL_ENVS, () => { - it('diag2d', () => { - const m = tf.tensor2d([[5, 4], [5, 6]]); - const d = tf.diagPart(m); - expectArraysEqual(d.shape, [2]); - expectArraysEqual(d, tf.tensor1d([5, 6])); - }); -}); diff --git a/src/ops/diagPart.ts b/src/ops/diagpart.ts similarity index 71% rename from src/ops/diagPart.ts rename to src/ops/diagpart.ts index c3b666a431..7c1222f077 100644 --- a/src/ops/diagPart.ts +++ b/src/ops/diagpart.ts @@ -8,17 +8,16 @@ import {op} from './operation'; function diagPart_(x: Tensor): Tensor { util.assert( - x.rank % 2 === 0, - `diagpart expects a tensor of even dimension, but got a rank-${ + x.rank !== 0 && x.rank % 2 === 0, + `diagpart expects a tensor of even and non zero rank, but got a rank-${ x.rank} tensor`); const mid = x.rank / 2; const dim1 = x.shape.slice(0, mid); const dim2 = x.shape.slice(mid, x.shape.length); util.assert( util.arraysEqual(dim1, dim2), - `diagPart expects ${dim1.toString()} to be equal to ${ - dim2.toString()}`); - const $x = convertToTensor(x, 'x', 'diag_part').flatten(); + `diagPart expects ${dim1.toString()} to be equal to ${dim2.toString()}`); + const $x = convertToTensor(x, 'x', 'diagpart').flatten(); const outShape = dim1; return ENV.engine.runKernel(backend => backend.diagPart($x), {$x}) .reshape(outShape); diff --git a/src/ops/diagpart_test.ts b/src/ops/diagpart_test.ts new file mode 100644 index 0000000000..db953cd6cf --- /dev/null +++ b/src/ops/diagpart_test.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import * as tf from '../index'; +import {describeWithFlags} from '../jasmine_util'; +import {ALL_ENVS, expectArraysEqual} from '../test_util'; + +describeWithFlags('diagpart', ALL_ENVS, () => { + it('diag1d', () => { + const m = tf.tensor1d([1]); + expect(() => tf.diagPart(m)).toThrowError(); + }); + it('diag2d 2x2', () => { + const m = tf.tensor2d([[5, 4], [5, 6]]); + const d = tf.diagPart(m); + expectArraysEqual(d.shape, [2]); + expectArraysEqual(d, tf.tensor1d([5, 6])); + }); + it('diag2d 2*1', () => { + const m = tf.tensor2d([[5], [6]]); + expect(() => tf.diagPart(m)).toThrowError(); + }); + it('diag2d 3x3', () => { + const m = tf.tensor2d([[5, 4, 5], [5, 6, 3], [5, 4, 3]]); + const d = tf.diagPart(m); + expectArraysEqual(d.shape, [3]); + expectArraysEqual(d, tf.tensor1d([5, 6, 3])); + }); + it('diag3d 3*3*4', () => { + const m = tf.tensor(Array.from(Array(36).keys()), [3, 3, 4]); + expect(() => tf.diagPart(m)).toThrowError(); + }); + it('diag4d 3*2*3*2', () => { + const m = tf.tensor(Array.from(Array(36).keys()), [3, 2, 3, 2], 'int32'); + const d = tf.diagPart(m); + expectArraysEqual(d.shape, [3, 2]); + expectArraysEqual(d, tf.tensor([0, 7, 14, 21, 28, 35], [3, 2], 'int32')); + }); + it('diag4d 3*2*3*2', () => { + const m = tf.tensor(Array.from(Array(36).keys()), [3, 2, 3, 2], 'bool'); + const d = tf.diagPart(m); + expectArraysEqual(d.shape, [3, 2]); + expectArraysEqual(d, tf.tensor([0, 1, 1, 1, 1, 1], [3, 2], 'bool')); + }); +}); diff --git a/src/ops/ops.ts b/src/ops/ops.ts index 80626ea993..a7fec2ecf7 100644 --- a/src/ops/ops.ts +++ b/src/ops/ops.ts @@ -44,7 +44,7 @@ export * from './scatter_nd'; export * from './spectral_ops'; export * from './sparse_to_dense'; export * from './gather_nd'; -export * from './diagPart'; +export * from './diagpart'; export {op} from './operation'; From a54fe529f2fdf763d611eb5b11348c329b951376 Mon Sep 17 00:00:00 2001 From: edoh Date: Wed, 5 Dec 2018 19:53:54 +0100 Subject: [PATCH 3/6] add jsdoc --- src/ops/diagpart.ts | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/ops/diagpart.ts b/src/ops/diagpart.ts index 7c1222f077..3eb70fa8af 100644 --- a/src/ops/diagpart.ts +++ b/src/ops/diagpart.ts @@ -6,6 +6,30 @@ import {convertToTensor} from '../tensor_util_env'; import {op} from './operation'; +/** + * Returns the diagonal part of the tensor. + * + * Given a tensor, this operation returns a tensor with the diagonal part of the + * input. + * + * Assume the input has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output + * is a tensor of rank k with dimensions `[D1,..., Dk]` + * + * ```js + * const x = tf.tensor2d([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, + * 4]]); + * + * tf.diagpart(x).print() + * ``` + * ```js + * const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 6, 8, 3, 4, 6, 8, 7, 7, 3, 3], [4, + * 1, 2, 2]) + * + * tf.diagpart(x).print() + * ``` + * @param x The input tensor. + */ + function diagPart_(x: Tensor): Tensor { util.assert( x.rank !== 0 && x.rank % 2 === 0, From ac783184f996a19af445db537dbbb4ff4cbeb79e Mon Sep 17 00:00:00 2001 From: edoh Date: Wed, 5 Dec 2018 20:04:24 +0100 Subject: [PATCH 4/6] remove unused operator --- src/ops/diagpart.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ops/diagpart.ts b/src/ops/diagpart.ts index 3eb70fa8af..e12e7f23e8 100644 --- a/src/ops/diagpart.ts +++ b/src/ops/diagpart.ts @@ -1,6 +1,5 @@ import {util} from '..'; import {ENV} from '../environment'; -// import {tensor} from '../ops/tensor_ops'; import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; From 0a3d63d8fd857f5acb17cdc059634b8353e3acc6 Mon Sep 17 00:00:00 2001 From: edoh Date: Fri, 7 Dec 2018 21:55:13 +0100 Subject: [PATCH 5/6] check dtype --- src/ops/diagpart_test.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/ops/diagpart_test.ts b/src/ops/diagpart_test.ts index db953cd6cf..da9cc08257 100644 --- a/src/ops/diagpart_test.ts +++ b/src/ops/diagpart_test.ts @@ -43,16 +43,18 @@ describeWithFlags('diagpart', ALL_ENVS, () => { const m = tf.tensor(Array.from(Array(36).keys()), [3, 3, 4]); expect(() => tf.diagPart(m)).toThrowError(); }); - it('diag4d 3*2*3*2', () => { + it('diag4d 3*2*3*2 int32', () => { const m = tf.tensor(Array.from(Array(36).keys()), [3, 2, 3, 2], 'int32'); const d = tf.diagPart(m); expectArraysEqual(d.shape, [3, 2]); + expect(d.dtype).toBe('int32'); expectArraysEqual(d, tf.tensor([0, 7, 14, 21, 28, 35], [3, 2], 'int32')); }); - it('diag4d 3*2*3*2', () => { + it('diag4d 3*2*3*2 bool', () => { const m = tf.tensor(Array.from(Array(36).keys()), [3, 2, 3, 2], 'bool'); const d = tf.diagPart(m); expectArraysEqual(d.shape, [3, 2]); + expect(d.dtype).toBe('bool'); expectArraysEqual(d, tf.tensor([0, 1, 1, 1, 1, 1], [3, 2], 'bool')); }); }); From 5d1ab549f1bda428b1dc784d8e5011e5b8f82a97 Mon Sep 17 00:00:00 2001 From: edoh Date: Fri, 7 Dec 2018 21:56:05 +0100 Subject: [PATCH 6/6] remove useless space --- src/ops/diagpart.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ops/diagpart.ts b/src/ops/diagpart.ts index e12e7f23e8..a747db89f8 100644 --- a/src/ops/diagpart.ts +++ b/src/ops/diagpart.ts @@ -2,7 +2,6 @@ import {util} from '..'; import {ENV} from '../environment'; import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; - import {op} from './operation'; /**