diff --git a/src/kernels/backend.ts b/src/kernels/backend.ts index 7ac104b9a7..a55b055839 100644 --- a/src/kernels/backend.ts +++ b/src/kernels/backend.ts @@ -84,6 +84,9 @@ export interface BackendTimer { * methods). */ export class KernelBackend implements TensorStorage, BackendTimer { + diagPart($x: Tensor): any { + throw new Error('Method not implemented.'); + } time(f: () => void): Promise { throw new Error('Not yet implemented.'); } diff --git a/src/kernels/backend_cpu.ts b/src/kernels/backend_cpu.ts index 29ca29e217..827731088e 100644 --- a/src/kernels/backend_cpu.ts +++ b/src/kernels/backend_cpu.ts @@ -2536,6 +2536,16 @@ export class MathBackendCPU implements KernelBackend { return res; } + diagPart(x: Tensor): Tensor { + const xVals = x.dataSync(); + const buffer = ops.buffer([Math.sqrt(x.size)], x.dtype); + const vals = buffer.values; + for (let i = 0; i < vals.length; i++) { + vals[i] = xVals[i * vals.length + i]; + } + return buffer.toTensor(); + } + oneHot(indices: Tensor1D, depth: number, onValue: number, offValue: number): Tensor2D { this.assertNotComplex(indices, 'oneHot'); diff --git a/src/kernels/backend_webgl.ts b/src/kernels/backend_webgl.ts index 99d21f0030..0e03481019 100644 --- a/src/kernels/backend_webgl.ts +++ b/src/kernels/backend_webgl.ts @@ -104,6 +104,7 @@ import {UnaryOpProgram} from './webgl/unaryop_gpu'; import {UnpackProgram} from './webgl/unpack_gpu'; import * as webgl_util from './webgl/webgl_util'; import {whereImpl} from './where_impl'; +import {DiagPartProgram} from './webgl/diagpart_gpu'; type KernelInfo = { name: string; query: Promise; @@ -1628,6 +1629,12 @@ export class MathBackendWebGL implements KernelBackend { return this.compileAndRun(program, [probs], output, customSetup); } + diagPart(x: Tensor): Tensor { + const size = Math.sqrt(x.size); + const program = new DiagPartProgram(size); + return this.compileAndRun(program, [x.reshape([size, size])]); + } + oneHot(indices: Tensor1D, depth: number, onValue: number, offValue: number): Tensor2D { const program = new OneHotProgram(indices.size, depth, onValue, offValue); diff --git a/src/kernels/webgl/diagpart_gpu.ts b/src/kernels/webgl/diagpart_gpu.ts new file mode 100644 index 0000000000..8a9490c5e8 --- /dev/null +++ b/src/kernels/webgl/diagpart_gpu.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {GPGPUProgram} from './gpgpu_math'; +export class DiagPartProgram implements GPGPUProgram { + variableNames = ['X']; + outputShape: number[]; + userCode: string; + constructor(size: number) { + this.outputShape = [size]; + this.userCode = ` + void main() { + int coord = getOutputCoords(); + setOutput(getX(coord,coord )); + } + `; + } +} diff --git a/src/ops/diagpart.ts b/src/ops/diagpart.ts new file mode 100644 index 0000000000..a747db89f8 --- /dev/null +++ b/src/ops/diagpart.ts @@ -0,0 +1,48 @@ +import {util} from '..'; +import {ENV} from '../environment'; +import {Tensor} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {op} from './operation'; + +/** + * Returns the diagonal part of the tensor. + * + * Given a tensor, this operation returns a tensor with the diagonal part of the + * input. + * + * Assume the input has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output + * is a tensor of rank k with dimensions `[D1,..., Dk]` + * + * ```js + * const x = tf.tensor2d([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, + * 4]]); + * + * tf.diagpart(x).print() + * ``` + * ```js + * const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 6, 8, 3, 4, 6, 8, 7, 7, 3, 3], [4, + * 1, 2, 2]) + * + * tf.diagpart(x).print() + * ``` + * @param x The input tensor. + */ + +function diagPart_(x: Tensor): Tensor { + util.assert( + x.rank !== 0 && x.rank % 2 === 0, + `diagpart expects a tensor of even and non zero rank, but got a rank-${ + x.rank} tensor`); + const mid = x.rank / 2; + const dim1 = x.shape.slice(0, mid); + const dim2 = x.shape.slice(mid, x.shape.length); + util.assert( + util.arraysEqual(dim1, dim2), + `diagPart expects ${dim1.toString()} to be equal to ${dim2.toString()}`); + const $x = convertToTensor(x, 'x', 'diagpart').flatten(); + const outShape = dim1; + return ENV.engine.runKernel(backend => backend.diagPart($x), {$x}) + .reshape(outShape); +} + +export const diagPart = op({diagPart_}); diff --git a/src/ops/diagpart_test.ts b/src/ops/diagpart_test.ts new file mode 100644 index 0000000000..da9cc08257 --- /dev/null +++ b/src/ops/diagpart_test.ts @@ -0,0 +1,60 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import * as tf from '../index'; +import {describeWithFlags} from '../jasmine_util'; +import {ALL_ENVS, expectArraysEqual} from '../test_util'; + +describeWithFlags('diagpart', ALL_ENVS, () => { + it('diag1d', () => { + const m = tf.tensor1d([1]); + expect(() => tf.diagPart(m)).toThrowError(); + }); + it('diag2d 2x2', () => { + const m = tf.tensor2d([[5, 4], [5, 6]]); + const d = tf.diagPart(m); + expectArraysEqual(d.shape, [2]); + expectArraysEqual(d, tf.tensor1d([5, 6])); + }); + it('diag2d 2*1', () => { + const m = tf.tensor2d([[5], [6]]); + expect(() => tf.diagPart(m)).toThrowError(); + }); + it('diag2d 3x3', () => { + const m = tf.tensor2d([[5, 4, 5], [5, 6, 3], [5, 4, 3]]); + const d = tf.diagPart(m); + expectArraysEqual(d.shape, [3]); + expectArraysEqual(d, tf.tensor1d([5, 6, 3])); + }); + it('diag3d 3*3*4', () => { + const m = tf.tensor(Array.from(Array(36).keys()), [3, 3, 4]); + expect(() => tf.diagPart(m)).toThrowError(); + }); + it('diag4d 3*2*3*2 int32', () => { + const m = tf.tensor(Array.from(Array(36).keys()), [3, 2, 3, 2], 'int32'); + const d = tf.diagPart(m); + expectArraysEqual(d.shape, [3, 2]); + expect(d.dtype).toBe('int32'); + expectArraysEqual(d, tf.tensor([0, 7, 14, 21, 28, 35], [3, 2], 'int32')); + }); + it('diag4d 3*2*3*2 bool', () => { + const m = tf.tensor(Array.from(Array(36).keys()), [3, 2, 3, 2], 'bool'); + const d = tf.diagPart(m); + expectArraysEqual(d.shape, [3, 2]); + expect(d.dtype).toBe('bool'); + expectArraysEqual(d, tf.tensor([0, 1, 1, 1, 1, 1], [3, 2], 'bool')); + }); +}); diff --git a/src/ops/ops.ts b/src/ops/ops.ts index 56ef0f3464..a7fec2ecf7 100644 --- a/src/ops/ops.ts +++ b/src/ops/ops.ts @@ -44,6 +44,7 @@ export * from './scatter_nd'; export * from './spectral_ops'; export * from './sparse_to_dense'; export * from './gather_nd'; +export * from './diagpart'; export {op} from './operation';