From e64c988da913f905804fdeec1983c8a8a4f53465 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sun, 24 Sep 2017 13:31:38 -0400 Subject: [PATCH 1/6] add slice1d --- .gitignore | 3 +- src/math/math.ts | 84 ++++++++++++++++++------------ src/math/math_cpu.ts | 5 ++ src/math/math_gpu.ts | 7 +++ src/math/ndarray.ts | 4 +- src/math/webgl/slice1d_gpu.ts | 51 ++++++++++++++++++ src/math/webgl/slice1d_gpu_test.ts | 84 ++++++++++++++++++++++++++++++ src/util.ts | 54 +++++++++++++++++-- src/util_test.ts | 39 ++++++++++++++ 9 files changed, 291 insertions(+), 40 deletions(-) create mode 100644 src/math/webgl/slice1d_gpu.ts create mode 100644 src/math/webgl/slice1d_gpu_test.ts diff --git a/.gitignore b/.gitignore index 67ad111bdb..2b33fc1994 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ index.ts npm-debug.log .DS_Store dist/ -.idea/ \ No newline at end of file +.idea/ +demos/performance_rnn diff --git a/src/math/math.ts b/src/math/math.ts index 9dd3904213..a71633e8c1 100644 --- a/src/math/math.ts +++ b/src/math/math.ts @@ -273,7 +273,7 @@ export abstract class NDArrayMath { `must match inner dimension of second rank 2 input, but got ` + `rank ${matrix.rank}.`); - return this.matMul(v.as2D(1, v.size), matrix).as1D(); + return this.matMul(v.as2D(1, -1), matrix).as1D(); } /** @@ -296,7 +296,7 @@ export abstract class NDArrayMath { `must match inner dimension of second rank 2 input, but got ` + `shape ${matrix.shape}.`); - return this.matMul(matrix, v.as2D(v.size, 1)).as1D(); + return this.matMul(matrix, v.as2D(-1, 1)).as1D(); } /** @@ -313,7 +313,7 @@ export abstract class NDArrayMath { v1.size === v2.size, `Error in dotProduct: size of inputs (${v1.size}) and (` + `${v2.size}) must match.`); - return this.matMul(v1.as2D(1, v1.size), v2.as2D(v2.size, 1)).asScalar(); + return this.matMul(v1.as2D(1, -1), v2.as2D(-1, 1)).asScalar(); } /** @@ -327,7 +327,7 @@ export abstract class NDArrayMath { `Error in outerProduct: inputs must be rank 1, but got ranks ` + `${v1.rank} 
and ${v2.rank}.`); - return this.matMul(v1.as2D(v1.size, 1), v2.as2D(1, v2.size)); + return this.matMul(v1.as2D(-1, 1), v2.as2D(1, -1)); } /////////////// @@ -355,12 +355,31 @@ export abstract class NDArrayMath { } /** - * Extracts a slice from a matrix. The operation extraces a slice from input - * that starts at coordinates `begin` and is of size `size`. - * @param input The input matrix to slice from. - * @param begin The 2D coordinates in the input matrix to start the slice - * from. - * @param size The sice of the 2D window to slice. + * Extracts a 1D slice from 1D array starting at coordinates `begin` and is of + * length `size`. + * + * @param input The input array to slice from. + * @param begin The offset to start the slice from. + * @param size The size of the slice. + */ + slice1D(input: Array1D, begin: number, size: number): Array1D { + util.assert( + begin + size <= input.size, + `Error in slice1D: requested start ${begin} and size ${size} ` + + `would overflow input of size ${input.size}`); + return this.executeOp( + 'slice1D', () => this.slice1DInternal(input, begin, size)); + } + protected abstract slice1DInternal( + input: Array1D, begin: number, size: number): Array1D; + + /** + * Extracts a 2D slice from a 2D array starting at coordinates `begin` and is + * of size `size`. + * + * @param input The input array to slice from. + * @param begin The [row, col] 2d coordinates to start the slice from. + * @param size The size of the slice. */ slice2D(input: Array2D, begin: [number, number], size: [number, number]): Array2D { @@ -1487,36 +1506,33 @@ export abstract class NDArrayMath { data.shape[0] === 1, `Error in multiRNNCell: first dimension of data is ` + `${data.shape[0]}, but batch sizes > 1 are not yet supported.`); - // concat(inputs, h, 1) - // There is no concat1d, so reshape inputs and h to 3d, concat, then - // reshape back to 2d. 
- const data3D = data.as3D(1, 1, data.shape[1]); - const h3D = h.as3D(1, 1, h.shape[1]); - const combined3D = this.concat3D(data3D, h3D, 2); - const combined2D = combined3D.as2D(1, data.shape[1] + h.shape[1]); - - const weighted = this.matMul(combined2D, lstmKernel); - const res = this.add(weighted, lstmBias) as Array2D; + const combined = this.concat1D(data.as1D(), h.as1D()); + console.log('combined', combined.getValues()); + const weighted = this.vectorTimesMatrix(combined, lstmKernel); + console.log('weighted', weighted.getValues()); + const res = this.addStrict(weighted, lstmBias); + console.log('res', res.getValues()); // i = input_gate, j = new_input, f = forget_gate, o = output_gate - const i = this.slice2D(res, [0, 0], [res.shape[0], res.shape[1] / 4]); - const j = this.slice2D( - res, [0, res.shape[1] / 4 * 1], [res.shape[0], res.shape[1] / 4]); - const f = this.slice2D( - res, [0, res.shape[1] / 4 * 2], [res.shape[0], res.shape[1] / 4]); - const o = this.slice2D( - res, [0, res.shape[1] / 4 * 3], [res.shape[0], res.shape[1] / 4]); - - const newC = this.add( + const sliceSize = res.size / 4; + const i = this.slice1D(res, 0, sliceSize); + const j = this.slice1D(res, sliceSize, sliceSize); + const f = this.slice1D(res, sliceSize * 2, sliceSize); + const o = this.slice1D(res, sliceSize * 3, sliceSize); + console.log('i', i.getValues()); + console.log('j', j.getValues()); + console.log('f', f.getValues()); + console.log('o', o.getValues()); + + const newC = this.addStrict( this.multiplyStrict( - c, this.sigmoid(this.scalarPlusArray(forgetBias, f))), - this.multiplyStrict(this.sigmoid(i), this.tanh(j))) as Array2D; - const newH = - this.multiplyStrict(this.tanh(newC), this.sigmoid(o)) as Array2D; + c.as1D(), this.sigmoid(this.scalarPlusArray(forgetBias, f))), + this.multiplyStrict(this.sigmoid(i), this.tanh(j))); + const newH = this.multiplyStrict(this.tanh(newC), this.sigmoid(o)); return [newC, newH]; }); - return [res[0], res[1]]; + return [res[0].as2D(1, 
-1), res[1].as2D(1, -1)]; } } diff --git a/src/math/math_cpu.ts b/src/math/math_cpu.ts index bfd0566f60..55944c52d7 100644 --- a/src/math/math_cpu.ts +++ b/src/math/math_cpu.ts @@ -34,6 +34,11 @@ export class NDArrayMathCPU extends NDArrayMath { ndarray.shape, {values: new Float32Array(ndarray.getValues())}) as T; } + protected slice1DInternal(input: Array1D, begin: number, size: number): + Array1D { + throw new Error('Method not implemented.'); + } + protected slice2DInternal( input: Array2D, beginRowCol: [number, number], sizeRowCol: [number, number]): Array2D { diff --git a/src/math/math_gpu.ts b/src/math/math_gpu.ts index 073358a83f..57492caa0f 100644 --- a/src/math/math_gpu.ts +++ b/src/math/math_gpu.ts @@ -43,6 +43,7 @@ import {MatMulProgram} from './webgl/mulmat_gpu'; import {Pool2DProgram} from './webgl/pool_gpu'; import {ReduceSumProgram} from './webgl/reducesum_gpu'; import {ResizeBilinear3DProgram} from './webgl/resize_bilinear_gpu'; +import {Slice1DProgram} from './webgl/slice1d_gpu'; import {TextureManager} from './webgl/texture_manager'; import * as unary_op from './webgl/unaryop_gpu'; import {UnaryOpProgram} from './webgl/unaryop_gpu'; @@ -85,6 +86,12 @@ export class NDArrayMathGPU extends NDArrayMath { return output.reshape(a.shape) as T; } + protected slice1DInternal(input: Array1D, begin: number, size: number): + Array1D { + const program = new Slice1DProgram(size); + return this.compileAndRun(program, [input]); + } + protected slice2DInternal( input: Array2D, beginRowCol: [number, number], sizeRowCol: [number, number]): Array2D { diff --git a/src/math/ndarray.ts b/src/math/ndarray.ts index 0cc79dd90e..ea20605cf5 100644 --- a/src/math/ndarray.ts +++ b/src/math/ndarray.ts @@ -143,10 +143,10 @@ export class NDArray { /** Reshapes the current ndarray into the provided shape. */ reshape(newShape: number[]): NDArray { + newShape = util.inferFromImplicitShape(newShape, this.size); if (util.arraysEqual(this.shape, newShape)) { // No-op. 
- // tslint:disable-next-line:no-any - return this as any; + return this; } util.assert( diff --git a/src/math/webgl/slice1d_gpu.ts b/src/math/webgl/slice1d_gpu.ts new file mode 100644 index 0000000000..6a5ab9da8e --- /dev/null +++ b/src/math/webgl/slice1d_gpu.ts @@ -0,0 +1,51 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {GPGPUContext} from './gpgpu_context'; +import {GPGPUProgram} from './gpgpu_math'; + +export class Slice1DProgram implements GPGPUProgram { + variableNames = ['source']; + params: Array<{}>; + outputShape: number[]; + userCode: string; + + // Caching uniform location for speed. 
+ startLoc: WebGLUniformLocation; + + constructor(destSize: number) { + this.outputShape = [destSize]; + this.params = []; + this.userCode = ` + uniform int start; + + void main() { + int sourceIndex = start + getOutputCoords(); + setOutput(getSource(sourceIndex)); + } + `; + } + + getCustomSetupFunc(start: number) { + return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => { + if (this.startLoc == null) { + this.startLoc = gpgpu.getUniformLocation(webGLProgram, 'start'); + } + gpgpu.gl.uniform1i(this.startLoc, start); + }; + } +} diff --git a/src/math/webgl/slice1d_gpu_test.ts b/src/math/webgl/slice1d_gpu_test.ts new file mode 100644 index 0000000000..8529666f2e --- /dev/null +++ b/src/math/webgl/slice1d_gpu_test.ts @@ -0,0 +1,84 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {Array1D, initializeGPU} from '../ndarray'; +import {GPGPUContext} from './gpgpu_context'; +import * as gpgpu_math from './gpgpu_math'; +import {Slice1DProgram} from './slice1d_gpu'; +import {TextureManager} from './texture_manager'; +import * as webgl_util from './webgl_util'; + +describe('slice1d_gpu', () => { + let gpgpu: GPGPUContext; + let texManager: TextureManager; + + beforeAll(() => { + gpgpu = new GPGPUContext(); + texManager = new TextureManager(gpgpu); + initializeGPU(gpgpu, texManager); + }); + + afterAll(() => { + texManager.dispose(); + gpgpu.dispose(); + }); + + it('slices [1] into [1] (effectively a copy)', () => { + const a = Array1D.new([5]); + const result = doSlice(a, 0, 1); + expect(result.shape).toEqual([1]); + expect(result.get(0)).toBe(5); + }); + + it('slices [5] into [2] starting at 3', () => { + const a = Array1D.new([1, 2, 3, 4, 5]); + const result = doSlice(a, 3, 2); + expect(result.shape).toEqual([2]); + expect(result.getValues()).toEqual(new Float32Array([4, 5])); + }); + + it('slices [5] into [3] starting at 1', () => { + const a = Array1D.new([1, 2, 3, 4, 5]); + const result = doSlice(a, 1, 3); + expect(result.shape).toEqual([3]); + expect(result.getValues()).toEqual(new Float32Array([2, 3, 4])); + }); + + it('slices array that is bigger than max tex size', () => { + const maxTexSize = webgl_util.queryMaxTextureSize(gpgpu.gl); + const a = Array1D.randUniform([maxTexSize + 10], -1, 1); + const expected = a.get(a.size - 1); + const result = doSlice(a, a.size - 1, 1); + expect(result.shape).toEqual([1]); + expect(result.get(0)).toEqual(expected); + }); + + + function doSlice(a: Array1D, start: number, size: number): Array1D { + const program = new Slice1DProgram(size); + const result = Array1D.zeros([size]); + + const binary = gpgpu_math.compileProgram(gpgpu, program, [a], result); + const customSetup = program.getCustomSetupFunc(start); + 
gpgpu_math.runProgram(binary, [a], result, customSetup); + + a.dispose(); + gpgpu.deleteProgram(binary.webGLProgram); + + return result; + } +}); diff --git a/src/util.ts b/src/util.ts index ea3bca537d..4ed3e6a996 100644 --- a/src/util.ts +++ b/src/util.ts @@ -15,8 +15,8 @@ * ============================================================================= */ -export type Vector = - number[]|Float64Array|Float32Array|Int32Array|Int8Array|Int16Array; +export type Vector = number[] | Float64Array | Float32Array | Int32Array | + Int8Array | Int16Array; /** Shuffles the array using Fisher-Yates algorithm. */ // tslint:disable-next-line:no-any @@ -105,7 +105,8 @@ export function flatten(arr: any[], ret?: number[]): number[] { return ret; } -export type ArrayData = number|number[]|number[][]|number[][][]|number[][][][]; +export type ArrayData = + number | number[] | number[][] | number[][][] | number[][][][]; export function inferShape(arr: ArrayData): number[] { const shape: number[] = []; @@ -262,3 +263,50 @@ function decodeParam( params: {[key: string]: string}, name: string, value?: string) { params[decodeURIComponent(name)] = decodeURIComponent(value || ''); } + +/** + * Given the full size of the array and a shape that may contain -1 as the + * implicit dimension, returns the inferred shape where -1 is replaced. + * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3]. + * + * @param shape The shape, which may contain -1 in some dimension. + * @param size The full size (number of elements) of the array. + * @return The inferred shape where -1 is replaced with the inferred size. + */ +export function inferFromImplicitShape( + shape: number[], size: number): number[] { + let shapeProd = 1; + let implicitIdx = -1; + + for (let i = 0; i < shape.length; ++i) { + if (shape[i] > 0) { + shapeProd *= shape[i]; + } else if (shape[i] === -1) { + if (implicitIdx !== -1) { + throw Error( + `Shapes can only have 1 implicit size. 
` + + `Found -1 at dim ${implicitIdx} and dim ${i}`); + } + implicitIdx = i; + } else if (shape[i] <= 0) { + throw Error(`Shapes can not be <= 0. Found ${shape[i]} at dim ${i}`); + } + } + + if (implicitIdx === -1) { + if (size > 0 && size !== shapeProd) { + throw Error(`Size (${size}) must match the product of shape ${shape}`); + } + return shape; + } + + if (size % shapeProd !== 0) { + throw Error( + `The implicit shape can't be a fractional number. ` + + `Got ${size} / ${shapeProd}`); + } + + const newShape = shape.slice(); + newShape[implicitIdx] = size / shapeProd; + return newShape; +} diff --git a/src/util_test.ts b/src/util_test.ts index c1b55e38bd..292574ee5b 100644 --- a/src/util_test.ts +++ b/src/util_test.ts @@ -179,3 +179,42 @@ describe('util.getQueryParams', () => { .toEqual({'a': '1', 'b': 'hi', 'f': 'animal'}); }); }); + +describe('util.inferFromImplicitShape', () => { + it('empty shape', () => { + const result = util.inferFromImplicitShape([], 0); + expect(result).toEqual([]); + }); + + it('[2, 3, 4] -> [2, 3, 4]', () => { + const result = util.inferFromImplicitShape([2, 3, 4], 24); + expect(result).toEqual([2, 3, 4]); + }); + + it('[2, -1, 4] -> [2, 3, 4], size=24', () => { + const result = util.inferFromImplicitShape([2, -1, 4], 24); + expect(result).toEqual([2, 3, 4]); + }); + + it('[-1, 3, 4] -> [2, 3, 4], size=24', () => { + const result = util.inferFromImplicitShape([-1, 3, 4], 24); + expect(result).toEqual([2, 3, 4]); + }); + + it('[2, 3, -1] -> [2, 3, 4], size=24', () => { + const result = util.inferFromImplicitShape([2, 3, -1], 24); + expect(result).toEqual([2, 3, 4]); + }); + + it('[2, -1, -1] throws error', () => { + expect(() => util.inferFromImplicitShape([2, -1, -1], 24)).toThrowError(); + }); + + it('[2, 3, -1] size=13 throws error', () => { + expect(() => util.inferFromImplicitShape([2, 3, -1], 13)).toThrowError(); + }); + + it('[2, 3, 4] size=25 (should be 24) throws error', () => { + expect(() => util.inferFromImplicitShape([2, 3, 
4], 25)).toThrowError(); + }); +}); From e13cb23f614bc9cc41c43939aba674f417e49b9f Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sun, 24 Sep 2017 13:44:51 -0400 Subject: [PATCH 2/6] fix lstm bug --- src/math/math.ts | 11 ++--------- src/math/math_gpu.ts | 3 ++- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/math/math.ts b/src/math/math.ts index a71633e8c1..a60fce7120 100644 --- a/src/math/math.ts +++ b/src/math/math.ts @@ -1480,8 +1480,8 @@ export abstract class NDArrayMath { const newC: Array2D[] = []; const newH: Array2D[] = []; for (let i = 0; i < res.length; i += 2) { - newC.push(res[i] as Array2D); - newH.push(res[i + 1] as Array2D); + newC.push(res[i]); + newH.push(res[i + 1]); } return [newC, newH]; } @@ -1507,11 +1507,8 @@ export abstract class NDArrayMath { `Error in multiRNNCell: first dimension of data is ` + `${data.shape[0]}, but batch sizes > 1 are not yet supported.`); const combined = this.concat1D(data.as1D(), h.as1D()); - console.log('combined', combined.getValues()); const weighted = this.vectorTimesMatrix(combined, lstmKernel); - console.log('weighted', weighted.getValues()); const res = this.addStrict(weighted, lstmBias); - console.log('res', res.getValues()); // i = input_gate, j = new_input, f = forget_gate, o = output_gate const sliceSize = res.size / 4; @@ -1519,10 +1516,6 @@ export abstract class NDArrayMath { const j = this.slice1D(res, sliceSize, sliceSize); const f = this.slice1D(res, sliceSize * 2, sliceSize); const o = this.slice1D(res, sliceSize * 3, sliceSize); - console.log('i', i.getValues()); - console.log('j', j.getValues()); - console.log('f', f.getValues()); - console.log('o', o.getValues()); const newC = this.addStrict( this.multiplyStrict( diff --git a/src/math/math_gpu.ts b/src/math/math_gpu.ts index 57492caa0f..e0fb81666d 100644 --- a/src/math/math_gpu.ts +++ b/src/math/math_gpu.ts @@ -89,7 +89,8 @@ export class NDArrayMathGPU extends NDArrayMath { protected slice1DInternal(input: Array1D, begin: 
number, size: number): Array1D { const program = new Slice1DProgram(size); - return this.compileAndRun(program, [input]); + const customSetup = program.getCustomSetupFunc(begin); + return this.compileAndRun(program, [input], null, customSetup); } protected slice2DInternal( From 12aaeb6f943b3ed4b37e5a387c68791f9d96e1ab Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sun, 24 Sep 2017 21:59:09 -0400 Subject: [PATCH 3/6] generalize concat to support 1-4D and add slice[1-4]D --- src/graph/graph.ts | 6 +- src/graph/ops/concat3d.ts | 3 +- src/graph/ops/concat3d_test.ts | 9 +- src/math/concat_util.ts | 32 ++- src/math/concat_util_test.ts | 15 +- src/math/math.ts | 77 +++++-- src/math/math_cpu.ts | 74 +++++-- src/math/math_gpu.ts | 48 ++-- src/math/slice_util.ts | 40 ++++ src/math/webgl/concat1d_gpu.ts | 45 ---- src/math/webgl/concat1d_gpu_test.ts | 103 --------- src/math/webgl/concat2d_gpu.ts | 52 ----- src/math/webgl/concat3d_gpu.ts | 54 ----- src/math/webgl/concat3d_gpu_test.ts | 92 -------- src/math/webgl/concat_gpu.ts | 99 +++++++++ ...oncat2d_gpu_test.ts => concat_gpu_test.ts} | 153 ++++++++++++- src/math/webgl/slice1d_gpu.ts | 51 ----- src/math/webgl/slice1d_gpu_test.ts | 84 ------- src/math/webgl/slice_gpu.ts | 101 +++++++++ src/math/webgl/slice_gpu_test.ts | 209 ++++++++++++++++++ 20 files changed, 783 insertions(+), 564 deletions(-) create mode 100644 src/math/slice_util.ts delete mode 100644 src/math/webgl/concat1d_gpu.ts delete mode 100644 src/math/webgl/concat1d_gpu_test.ts delete mode 100644 src/math/webgl/concat2d_gpu.ts delete mode 100644 src/math/webgl/concat3d_gpu.ts delete mode 100644 src/math/webgl/concat3d_gpu_test.ts create mode 100644 src/math/webgl/concat_gpu.ts rename src/math/webgl/{concat2d_gpu_test.ts => concat_gpu_test.ts} (50%) delete mode 100644 src/math/webgl/slice1d_gpu.ts delete mode 100644 src/math/webgl/slice1d_gpu_test.ts create mode 100644 src/math/webgl/slice_gpu.ts create mode 100644 src/math/webgl/slice_gpu_test.ts diff --git 
a/src/graph/graph.ts b/src/graph/graph.ts index f3252029d3..24b1df5748 100644 --- a/src/graph/graph.ts +++ b/src/graph/graph.ts @@ -623,12 +623,10 @@ export class Concat3DNode extends Node { public axis: number) { super( graph, 'Concat3D', {x1, x2}, - new Tensor( - concat_util.computeConcatOutputShape(x1.shape, x2.shape, axis))); + new Tensor(concat_util.computeOutShape(x1.shape, x2.shape, axis))); } validate() { - concat_util.assertConcatShapesMatch( - this.x1.shape, this.x2.shape, 3, this.axis); + concat_util.assertParams(this.x1.shape, this.x2.shape, this.axis); } } diff --git a/src/graph/ops/concat3d.ts b/src/graph/ops/concat3d.ts index b206c78971..a545286018 100644 --- a/src/graph/ops/concat3d.ts +++ b/src/graph/ops/concat3d.ts @@ -36,8 +36,7 @@ export class Concat3D extends Operation { private x1Tensor: Tensor, private x2Tensor: Tensor, private axis: number, private yTensor: Tensor) { super(); - concat_util.assertConcatShapesMatch( - x1Tensor.shape, x2Tensor.shape, 3, axis); + concat_util.assertParams(x1Tensor.shape, x2Tensor.shape, axis); } feedForward(math: NDArrayMath, inferenceArrays: TensorArrayMap) { diff --git a/src/graph/ops/concat3d_test.ts b/src/graph/ops/concat3d_test.ts index 5d24d6f01e..41d5cd9c74 100644 --- a/src/graph/ops/concat3d_test.ts +++ b/src/graph/ops/concat3d_test.ts @@ -51,8 +51,7 @@ describe('concat3d operation', () => { x1Tensor = new Tensor(x1.shape); x2Tensor = new Tensor(x2.shape); - yTensor = new Tensor( - concat_util.computeConcatOutputShape(x1.shape, x2.shape, axis)); + yTensor = new Tensor(concat_util.computeOutShape(x1.shape, x2.shape, axis)); tensorArrayMap.set(x1Tensor, x1); tensorArrayMap.set(x2Tensor, x2); @@ -75,8 +74,7 @@ describe('concat3d operation', () => { x1Tensor = new Tensor(x1.shape); x2Tensor = new Tensor(x2.shape); - yTensor = new Tensor( - concat_util.computeConcatOutputShape(x1.shape, x2.shape, axis)); + yTensor = new Tensor(concat_util.computeOutShape(x1.shape, x2.shape, axis)); tensorArrayMap.set(x1Tensor, 
x1); tensorArrayMap.set(x2Tensor, x2); @@ -99,8 +97,7 @@ describe('concat3d operation', () => { x1Tensor = new Tensor(x1.shape); x2Tensor = new Tensor(x2.shape); - yTensor = new Tensor( - concat_util.computeConcatOutputShape(x1.shape, x2.shape, axis)); + yTensor = new Tensor(concat_util.computeOutShape(x1.shape, x2.shape, axis)); tensorArrayMap.set(x1Tensor, x1); tensorArrayMap.set(x2Tensor, x2); diff --git a/src/math/concat_util.ts b/src/math/concat_util.ts index f1f0fe6947..e31d32ef6f 100644 --- a/src/math/concat_util.ts +++ b/src/math/concat_util.ts @@ -17,35 +17,33 @@ import * as util from '../util'; -export function assertConcatShapesMatch( - x1Shape: number[], x2Shape: number[], rank: number, axis: number, - errorMessagePrefix = '') { +export function assertParams(aShape: number[], bShape: number[], axis: number) { + const rank = aShape.length; + const bRank = bShape.length; util.assert( - x1Shape.length === rank, - errorMessagePrefix + `x1 shape should be of rank ${rank}.`); - util.assert( - x2Shape.length === rank, - errorMessagePrefix + `x2 shape should be of rank ${rank}.`); + aShape.length === bShape.length, + `Error in concat${rank}D: rank of x1 (${rank}) and x2 (${bRank}) ` + + `must be the same.`); util.assert( - axis >= 0 && axis < rank, `axis must be between 0 and ${rank - 1}.`); + axis >= 0 && axis < rank, + `Error in concat${rank}D: axis must be ` + + `between 0 and ${rank - 1}.`); for (let i = 0; i < rank; i++) { util.assert( - (i === axis) || (x1Shape[i] === x2Shape[i]), - errorMessagePrefix + - `Shape (${x1Shape}) does not match (${x2Shape}) along ` + - `the non-concatenated axis ${i}.`); + (i === axis) || (aShape[i] === bShape[i]), + `Error in concat${rank}D: Shape (${aShape}) does not match ` + + `(${bShape}) along the non-concatenated axis ${i}.`); } } -export function computeConcatOutputShape( - x1Shape: number[], x2Shape: number[], - axis: number): [number, number, number] { +export function computeOutShape( + x1Shape: number[], x2Shape: 
number[], axis: number): number[] { util.assert( x1Shape.length === x2Shape.length, 'x1 and x2 should have the same rank.'); const outputShape = x1Shape.slice(); outputShape[axis] += x2Shape[axis]; - return outputShape as [number, number, number]; + return outputShape; } diff --git a/src/math/concat_util_test.ts b/src/math/concat_util_test.ts index b7f441e799..b57c539bff 100644 --- a/src/math/concat_util_test.ts +++ b/src/math/concat_util_test.ts @@ -20,7 +20,7 @@ import * as concat_util from './concat_util'; describe('concat_util.assertConcatShapesMatch rank=3D', () => { it('Non-3D tensor x1', () => { const assertFn = () => { - concat_util.assertConcatShapesMatch([1], [1, 2, 3], 3, 1); + concat_util.assertParams([1], [1, 2, 3], 1); }; expect(assertFn).toThrow(); @@ -28,7 +28,7 @@ describe('concat_util.assertConcatShapesMatch rank=3D', () => { it('Non-3D tensor x2', () => { const assertFn = () => { - concat_util.assertConcatShapesMatch([1, 2, 3], [2, 3], 3, 1); + concat_util.assertParams([1, 2, 3], [2, 3], 1); }; expect(assertFn).toThrow(); @@ -36,7 +36,7 @@ describe('concat_util.assertConcatShapesMatch rank=3D', () => { it('axis out of bound', () => { const assertFn = () => { - concat_util.assertConcatShapesMatch([1, 2, 3], [1, 2, 3], 3, 4); + concat_util.assertParams([1, 2, 3], [1, 2, 3], 4); }; expect(assertFn).toThrow(); @@ -44,7 +44,7 @@ describe('concat_util.assertConcatShapesMatch rank=3D', () => { it('non-axis shape mismatch', () => { const assertFn = () => { - concat_util.assertConcatShapesMatch([2, 3, 3], [2, 2, 4], 3, 2); + concat_util.assertParams([2, 3, 3], [2, 2, 4], 2); }; expect(assertFn).toThrow(); @@ -52,7 +52,7 @@ describe('concat_util.assertConcatShapesMatch rank=3D', () => { it('shapes line up', () => { const assertFn = () => { - concat_util.assertConcatShapesMatch([2, 3, 3], [2, 3, 4], 3, 2); + concat_util.assertParams([2, 3, 3], [2, 3, 4], 2); }; expect(assertFn).not.toThrow(); @@ -61,7 +61,8 @@ describe('concat_util.assertConcatShapesMatch 
rank=3D', () => { describe('concat_util.computeConcatOutputShape', () => { it('compute output shape, axis=0', () => { - expect(concat_util.computeConcatOutputShape([2, 2, 3], [1, 2, 3], 0)) - .toEqual([3, 2, 3]); + expect(concat_util.computeOutShape([2, 2, 3], [1, 2, 3], 0)).toEqual([ + 3, 2, 3 + ]); }); }); diff --git a/src/math/math.ts b/src/math/math.ts index a60fce7120..5bb962c4f0 100644 --- a/src/math/math.ts +++ b/src/math/math.ts @@ -21,6 +21,7 @@ import * as conv_util from './conv_util'; import {ConvInfo} from './conv_util'; import * as copy2d_util from './copy2d_util'; import {Array1D, Array2D, Array3D, Array4D, NDArray, Scalar} from './ndarray'; +import * as slice_util from './slice_util'; export type ScopeResult = NDArray[] | NDArray | void; @@ -363,10 +364,7 @@ export abstract class NDArrayMath { * @param size The size of the slice. */ slice1D(input: Array1D, begin: number, size: number): Array1D { - util.assert( - begin + size <= input.size, - `Error in slice1D: requested start ${begin} and size ${size} ` + - `would overflow input of size ${input.size}`); + slice_util.assertParams(input, [begin], [size]); return this.executeOp( 'slice1D', () => this.slice1DInternal(input, begin, size)); } @@ -383,17 +381,52 @@ export abstract class NDArrayMath { */ slice2D(input: Array2D, begin: [number, number], size: [number, number]): Array2D { - util.assert( - begin[0] + size[0] <= input.shape[0] && - begin[1] + size[1] <= input.shape[1], - `Error in slice2D: requested start position ${begin} and size ` + - `${size} would overflow input of shape ${input.shape}.`); + slice_util.assertParams(input, begin, size); return this.executeOp( 'slice2D', () => this.slice2DInternal(input, begin, size)); } protected abstract slice2DInternal( input: Array2D, begin: [number, number], size: [number, number]): Array2D; + /** + * Extracts a 3D slice from a 3D array starting at coordinates `begin` and is + * of size `size`. + * + * @param input The input array to slice from. 
+ * @param begin The [row, col, depth] 3d coordinates to start the slice from. + * @param size The size of the slice. + */ + slice3D(input: Array3D, begin: [number, number, number], size: [ + number, number, number + ]): Array3D { + slice_util.assertParams(input, begin, size); + return this.executeOp( + 'slice3D', () => this.slice3DInternal(input, begin, size)); + } + protected abstract slice3DInternal( + input: Array3D, begin: [number, number, number], + size: [number, number, number]): Array3D; + + /** + * Extracts a 4D slice from a 4D array starting at coordinates `begin` and is + * of size `size`. + * + * @param input The input array to slice from. + * @param begin The [row, col, depth, depth2] 4d coordinates to start the + * slice from. + * @param size The size of the slice. + */ + slice4D(input: Array4D, begin: [number, number, number, number], size: [ + number, number, number, number + ]): Array4D { + slice_util.assertParams(input, begin, size); + return this.executeOp( + 'slice4D', () => this.slice4DInternal(input, begin, size)); + } + protected abstract slice4DInternal( + input: Array4D, begin: [number, number, number, number], + size: [number, number, number, number]): Array4D; + /** * Copies a window from the `source` matrix starting at `sourceBegin` and is * of size `sourceSize` to a window in the `dest` matrix starting at @@ -447,8 +480,7 @@ export abstract class NDArrayMath { * @return The concatenated array. */ concat1D(a: Array1D, b: Array1D): Array1D { - concat_util.assertConcatShapesMatch( - a.shape, b.shape, 1, 0, 'Error in concat1D: '); + concat_util.assertParams(a.shape, b.shape, 0); return this.executeOp('concat1D', () => this.concat1DInternal(a, b)); } protected abstract concat1DInternal(a: Array1D, b: Array1D): Array1D; @@ -482,8 +514,7 @@ export abstract class NDArrayMath { * @return The concatenated array. 
*/ concat2D(a: Array2D, b: Array2D, axis: number): Array2D { - concat_util.assertConcatShapesMatch( - a.shape, b.shape, 2, axis, 'Error in concat2D: '); + concat_util.assertParams(a.shape, b.shape, axis); return this.executeOp('concat2D', () => this.concat2DInternal(a, b, axis)); } protected abstract concat2DInternal(a: Array2D, b: Array2D, axis: number): @@ -521,14 +552,30 @@ export abstract class NDArrayMath { * @return The concatenated array. */ concat3D(ndarray1: Array3D, ndarray2: Array3D, axis: number): Array3D { - concat_util.assertConcatShapesMatch( - ndarray1.shape, ndarray2.shape, 3, axis, 'Error in concat3D: '); + concat_util.assertParams(ndarray1.shape, ndarray2.shape, axis); return this.executeOp( 'concat3D', () => this.concat3DInternal(ndarray1, ndarray2, axis)); } protected abstract concat3DInternal( ndarray1: Array3D, ndarray2: Array3D, axis: number): Array3D; + /** + * Concatenates two 4D ndarrays along a given axis. See math.concat2D() for + * documentation. + * + * @param ndarray1 The first array to concat. + * @param ndarray2 The second array to concat. + * @param axis The axis to concatenate along. + * @return The concatenated array.
+ */ + concat4D(ndarray1: Array4D, ndarray2: Array4D, axis: number): Array4D { + concat_util.assertParams(ndarray1.shape, ndarray2.shape, axis); + return this.executeOp( + 'concat4D', () => this.concat4DInternal(ndarray1, ndarray2, axis)); + } + protected abstract concat4DInternal( + ndarray1: Array4D, ndarray2: Array4D, axis: number): Array4D; + /////////////////// // Reduction ops // /////////////////// diff --git a/src/math/math_cpu.ts b/src/math/math_cpu.ts index 55944c52d7..18c51fd332 100644 --- a/src/math/math_cpu.ts +++ b/src/math/math_cpu.ts @@ -48,6 +48,17 @@ export class NDArrayMathCPU extends NDArrayMath { return result; } + protected slice3DInternal( + input: Array3D, begin: [number, number, number], + size: [number, number, number]): Array3D { + throw new Error('Method not implemented.'); + } + protected slice4DInternal( + input: Array4D, begin: [number, number, number, number], + size: [number, number, number, number]): Array4D { + throw new Error('Method not implemented.'); + } + protected copy2DInternal( source: Array2D, sourceBeginRowCol: [number, number], sourceSizeRowCol: [number, number], dest: Array2D, @@ -69,9 +80,8 @@ export class NDArrayMathCPU extends NDArrayMath { } protected concat1DInternal(a: Array1D, b: Array1D): Array1D { - const outputShape = - concat_util.computeConcatOutputShape(a.shape, b.shape, 0); - const result = Array1D.zeros(outputShape); + const outShape = concat_util.computeOutShape(a.shape, b.shape, 0); + const result = Array1D.zeros(outShape as [number]); // Use built-in TypedArray.set() method for speed. 
const aVals = a.getValues(); @@ -84,9 +94,8 @@ export class NDArrayMathCPU extends NDArrayMath { } protected concat2DInternal(a: Array2D, b: Array2D, axis: number): Array2D { - const outputShape = - concat_util.computeConcatOutputShape(a.shape, b.shape, axis); - const result = Array2D.zeros(outputShape); + const outShape = concat_util.computeOutShape(a.shape, b.shape, axis); + const result = Array2D.zeros(outShape as [number, number]); if (axis === 0) { // Use built-in TypedArray.set() method for speed. @@ -98,8 +107,8 @@ export class NDArrayMathCPU extends NDArrayMath { return result; } - for (let i = 0; i < outputShape[0]; i++) { - for (let j = 0; j < outputShape[1]; j++) { + for (let i = 0; i < outShape[0]; ++i) { + for (let j = 0; j < outShape[1]; ++j) { const index: [number, number] = [i, j]; let value: number; if (index[axis] < a.shape[axis]) { @@ -117,10 +126,9 @@ export class NDArrayMathCPU extends NDArrayMath { } protected concat3DInternal(a: Array3D, b: Array3D, axis: number): Array3D { - const outputShape = - concat_util.computeConcatOutputShape(a.shape, b.shape, axis); + const outShape = concat_util.computeOutShape(a.shape, b.shape, axis); - const result = Array3D.zeros(outputShape); + const result = Array3D.zeros(outShape as [number, number, number]); if (axis === 0) { // Use built-in TypedArray.set() method for speed. @@ -132,9 +140,9 @@ export class NDArrayMathCPU extends NDArrayMath { return result; } - for (let i = 0; i < outputShape[0]; i++) { - for (let j = 0; j < outputShape[1]; j++) { - for (let k = 0; k < outputShape[2]; k++) { + for (let i = 0; i < outShape[0]; ++i) { + for (let j = 0; j < outShape[1]; ++j) { + for (let k = 0; k < outShape[2]; ++k) { // Shader begins. 
const index: [number, number, number] = [i, j, k]; let value: number; @@ -154,6 +162,44 @@ export class NDArrayMathCPU extends NDArrayMath { return result; } + protected concat4DInternal(a: Array4D, b: Array4D, axis: number): Array4D { + const outShape = concat_util.computeOutShape(a.shape, b.shape, axis); + const result = Array4D.zeros(outShape as [number, number, number, number]); + + if (axis === 0) { + // Use built-in TypedArray.set() method for speed. + const aVals = a.getValues(); + const bVals = b.getValues(); + const vals = result.getValues(); + vals.set(aVals, 0); + vals.set(bVals, a.size); + return result; + } + + for (let i = 0; i < outShape[0]; ++i) { + for (let j = 0; j < outShape[1]; ++j) { + for (let k = 0; k < outShape[2]; ++k) { + for (let l = 0; l < outShape[3]; ++l) { + // Shader begins. + const index: [number, number, number, number] = [i, j, k, l]; + let value: number; + if (index[axis] < a.shape[axis]) { + value = a.get(i, j, k, l); + } else { + index[axis] -= a.shape[axis]; + const [i2, j2, k2, l2] = index; + value = b.get(i2, j2, k2, l2); + } + + result.set(value, i, j, k, l); + } + } + } + } + + return result; + } + protected scaledArrayAddInternal( c1: Scalar, a: T, c2: Scalar, b: T) { const newShape = util.assertAndGetBroadcastedShape(a.shape, b.shape); diff --git a/src/math/math_gpu.ts b/src/math/math_gpu.ts index e0fb81666d..5c2a2146f7 100644 --- a/src/math/math_gpu.ts +++ b/src/math/math_gpu.ts @@ -25,9 +25,7 @@ import {ArgMinMaxProgram} from './webgl/argminmax_gpu'; import {BatchNormProgram} from './webgl/batchnorm_gpu'; import * as binaryop_gpu from './webgl/binaryop_gpu'; import {BinaryOpProgram} from './webgl/binaryop_gpu'; -import {Concat1DProgram} from './webgl/concat1d_gpu'; -import {Concat2DProgram} from './webgl/concat2d_gpu'; -import {Concat3DProgram} from './webgl/concat3d_gpu'; +import {ConcatProgram} from './webgl/concat_gpu'; // tslint:disable-next-line:max-line-length import {Conv2DDerBiasProgram, Conv2DDerInputProgram, 
Conv2DDerWeightsProgram} from './webgl/conv_backprop_gpu'; import {Conv2DProgram} from './webgl/conv_gpu'; @@ -43,7 +41,7 @@ import {MatMulProgram} from './webgl/mulmat_gpu'; import {Pool2DProgram} from './webgl/pool_gpu'; import {ReduceSumProgram} from './webgl/reducesum_gpu'; import {ResizeBilinear3DProgram} from './webgl/resize_bilinear_gpu'; -import {Slice1DProgram} from './webgl/slice1d_gpu'; +import {SliceProgram} from './webgl/slice_gpu'; import {TextureManager} from './webgl/texture_manager'; import * as unary_op from './webgl/unaryop_gpu'; import {UnaryOpProgram} from './webgl/unaryop_gpu'; @@ -88,18 +86,33 @@ export class NDArrayMathGPU extends NDArrayMath { protected slice1DInternal(input: Array1D, begin: number, size: number): Array1D { - const program = new Slice1DProgram(size); + const program = new SliceProgram([size]); + const customSetup = program.getCustomSetupFunc([begin]); + return this.compileAndRun(program, [input], null, customSetup); + } + + protected slice2DInternal(input: Array2D, begin: [number, number], size: [ + number, number + ]): Array2D { + const program = new SliceProgram(size); + const customSetup = program.getCustomSetupFunc(begin); + return this.compileAndRun(program, [input], null, customSetup); + } + + protected slice3DInternal( + input: Array3D, begin: [number, number, number], + size: [number, number, number]): Array3D { + const program = new SliceProgram(size); const customSetup = program.getCustomSetupFunc(begin); return this.compileAndRun(program, [input], null, customSetup); } - protected slice2DInternal( - input: Array2D, beginRowCol: [number, number], - sizeRowCol: [number, number]): Array2D { - const result = this.makeOutputArray(sizeRowCol); - this.copy2DInternal( - input, beginRowCol, sizeRowCol, result, [0, 0], sizeRowCol); - return result; + protected slice4DInternal( + input: Array4D, begin: [number, number, number, number], + size: [number, number, number, number]): Array4D { + const program = new 
SliceProgram(size); + const customSetup = program.getCustomSetupFunc(begin); + return this.compileAndRun(program, [input], null, customSetup); } protected copy2DInternal( @@ -114,17 +127,22 @@ export class NDArrayMathGPU extends NDArrayMath { } protected concat1DInternal(a: Array1D, b: Array1D): Array1D { - const program = new Concat1DProgram(a.size, b.size); + const program = new ConcatProgram(a.shape, b.shape, 0); return this.compileAndRun(program, [a, b]); } protected concat2DInternal(a: Array2D, b: Array2D, axis: number): Array2D { - const program = new Concat2DProgram(a.shape, b.shape, axis); + const program = new ConcatProgram(a.shape, b.shape, axis); return this.compileAndRun(program, [a, b]); } protected concat3DInternal(x1: Array3D, x2: Array3D, axis: number): Array3D { - const program = new Concat3DProgram(x1.shape, x2.shape, axis); + const program = new ConcatProgram(x1.shape, x2.shape, axis); + return this.compileAndRun(program, [x1, x2]); + } + + protected concat4DInternal(x1: Array4D, x2: Array4D, axis: number): Array4D { + const program = new ConcatProgram(x1.shape, x2.shape, axis); return this.compileAndRun(program, [x1, x2]); } diff --git a/src/math/slice_util.ts b/src/math/slice_util.ts new file mode 100644 index 0000000000..da6a2e25d9 --- /dev/null +++ b/src/math/slice_util.ts @@ -0,0 +1,40 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as util from '../util'; + +import {NDArray} from './ndarray'; + +export function assertParams( + input: NDArray, begin: number[], size: number[]): void { + util.assert( + input.rank === begin.length, + `Error in slice${input.rank}D: Length of begin ${begin} must ` + + `match the rank of the array (${input.rank}).`); + util.assert( + input.rank === size.length, + `Error in slice${input.rank}D: Length of size ${size} must ` + + `match the rank of the array (${input.rank}).`); + + for (let i = 0; i < input.rank; ++i) { + util.assert( + begin[i] + size[i] <= input.shape[i], + `Error in slice${input.rank}D: begin[${i}] + size[${i}] ` + + `(${begin[i] + + size[i]}) would overflow input.shape[${i}] (${input.shape[i]})`); + } +} diff --git a/src/math/webgl/concat1d_gpu.ts b/src/math/webgl/concat1d_gpu.ts deleted file mode 100644 index c479384d99..0000000000 --- a/src/math/webgl/concat1d_gpu.ts +++ /dev/null @@ -1,45 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import * as concat_util from '../concat_util'; -import {GPGPUProgram} from './gpgpu_math'; - -export class Concat1DProgram implements GPGPUProgram { - variableNames = ['A', 'B']; - params: Array<{}> = []; - outputShape: number[] = []; - userCode: string; - - constructor(x1Size: number, x2Size: number) { - this.outputShape = - concat_util.computeConcatOutputShape([x1Size], [x2Size], 0); - this.userCode = ` - void main() { - int x = getOutputCoords(); - float value = 0.0; - - if (x < ${x1Size}) { - value = getA(x); - } else { - value = getB(x - ${x1Size}); - } - - setOutput(value); - } - `; - } -} diff --git a/src/math/webgl/concat1d_gpu_test.ts b/src/math/webgl/concat1d_gpu_test.ts deleted file mode 100644 index 6d8fb81910..0000000000 --- a/src/math/webgl/concat1d_gpu_test.ts +++ /dev/null @@ -1,103 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import * as test_util from '../../test_util'; -import {NDArrayMathCPU} from '../math_cpu'; -import {Array1D, initializeGPU} from '../ndarray'; -import {Concat1DProgram} from './concat1d_gpu'; -import {GPGPUContext} from './gpgpu_context'; -import * as gpgpu_math from './gpgpu_math'; -import {TextureManager} from './texture_manager'; -import * as webgl_util from './webgl_util'; - -describe('concat1d_gpu', () => { - let gpgpu: GPGPUContext; - let textureManager: TextureManager; - - beforeAll(() => { - gpgpu = new GPGPUContext(); - textureManager = new TextureManager(gpgpu); - initializeGPU(gpgpu, textureManager); - }); - - afterAll(() => { - textureManager.dispose(); - gpgpu.dispose(); - }); - - it('3 + 5', () => { - const a = Array1D.new([3]); - const b = Array1D.new([5]); - - const result = doConcat(a, b); - const expected = new Float32Array([3, 5]); - test_util.expectArraysClose(result, expected); - }); - - it('3 + [5,7]', () => { - const a = Array1D.new([3]); - const b = Array1D.new([5, 7]); - - const result = doConcat(a, b); - const expected = new Float32Array([3, 5, 7]); - test_util.expectArraysClose(result, expected); - }); - - it('[3,5] + 7', () => { - const a = Array1D.new([3, 5]); - const b = Array1D.new([7]); - - const result = doConcat(a, b); - const expected = new Float32Array([3, 5, 7]); - test_util.expectArraysClose(result, expected); - }); - - it('matches cpu with arrays bigger than max tex size', () => { - const maxTextureSize = webgl_util.queryMaxTextureSize(gpgpu.gl); - const a = Array1D.randUniform([maxTextureSize + 10], -1, 1); - const b = Array1D.randUniform([maxTextureSize + 10], -1, 1); - - const result = doConcat(a, b, false); - const expected = doCpuConcat(a, b); - a.dispose(); - b.dispose(); - - test_util.expectArraysClose(result, expected); - }); - - function doCpuConcat(a: Array1D, b: Array1D): Float32Array { - const mathCpu = new NDArrayMathCPU(); - 
return mathCpu.concat1D(a, b).getValues(); - } - - function doConcat(a: Array1D, b: Array1D, dispose = true): Float32Array { - const program = new Concat1DProgram(a.size, b.size); - const r = Array1D.zeros(program.outputShape as [number]); - const binary = gpgpu_math.compileProgram(gpgpu, program, [a, b], r); - gpgpu_math.runProgram(binary, [a, b], r); - const result = r.getValues(); - - if (dispose) { - a.dispose(); - b.dispose(); - } - r.dispose(); - gpgpu.deleteProgram(binary.webGLProgram); - - return result; - } -}); diff --git a/src/math/webgl/concat2d_gpu.ts b/src/math/webgl/concat2d_gpu.ts deleted file mode 100644 index 790771f6ba..0000000000 --- a/src/math/webgl/concat2d_gpu.ts +++ /dev/null @@ -1,52 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import * as concat_util from '../concat_util'; -import {GPGPUProgram} from './gpgpu_math'; - -export class Concat2DProgram implements GPGPUProgram { - variableNames = ['A', 'B']; - params: Array<{}> = []; - outputShape: number[] = []; - userCode: string; - - constructor( - x1Shape: [number, number], x2Shape: [number, number], axis: number) { - const yAxes = ['yR', 'yC']; - const concatAxis = yAxes[axis]; - this.params = [axis]; - this.outputShape = - concat_util.computeConcatOutputShape(x1Shape, x2Shape, axis); - this.userCode = ` - void main() { - ivec2 coords = getOutputCoords(); - int yR = coords.x; - int yC = coords.y; - - float value = 0.0; - if (${concatAxis} < ${x1Shape[axis]}) { - value = getA(yR, yC); - } else { - ${concatAxis} -= ${x1Shape[axis]}; - value = getB(yR, yC); - } - - setOutput(value); - } - `; - } -} diff --git a/src/math/webgl/concat3d_gpu.ts b/src/math/webgl/concat3d_gpu.ts deleted file mode 100644 index e53abcac24..0000000000 --- a/src/math/webgl/concat3d_gpu.ts +++ /dev/null @@ -1,54 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import * as concat_util from '../concat_util'; -import {GPGPUProgram} from './gpgpu_math'; - -export class Concat3DProgram implements GPGPUProgram { - variableNames = ['A', 'B']; - params: Array<{}> = []; - outputShape: number[] = []; - userCode: string; - - constructor( - x1Shape: [number, number, number], x2Shape: [number, number, number], - axis: number) { - const yAxes = ['yR', 'yC', 'yD']; - const concatAxis = yAxes[axis]; - this.params = [axis]; - this.outputShape = - concat_util.computeConcatOutputShape(x1Shape, x2Shape, axis); - this.userCode = ` - void main() { - ivec3 coords = getOutputCoords(); - int yR = coords.x; - int yC = coords.y; - int yD = coords.z; - - float value = 0.0; - if (${concatAxis} < ${x1Shape[axis]}) { - value = getA(yR, yC, yD); - } else { - ${concatAxis} -= ${x1Shape[axis]}; - value = getB(yR, yC, yD); - } - - setOutput(value); - } - `; - } -} diff --git a/src/math/webgl/concat3d_gpu_test.ts b/src/math/webgl/concat3d_gpu_test.ts deleted file mode 100644 index dbe1ef2efd..0000000000 --- a/src/math/webgl/concat3d_gpu_test.ts +++ /dev/null @@ -1,92 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import * as test_util from '../../test_util'; -import {Array3D, initializeGPU, NDArray} from '../ndarray'; -import {Concat3DProgram} from './concat3d_gpu'; -import {GPGPUContext} from './gpgpu_context'; -import * as gpgpu_math from './gpgpu_math'; -import {TextureManager} from './texture_manager'; - -describe('concat3d_gpu', () => { - it('concat axis=0', () => { - const x1 = new Float32Array([1, 11, 111, 2, 22, 222]); - const x2 = - new Float32Array([5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888]); - - const result = uploadConcat3dDownload(x1, x2, [1, 2, 3], [2, 2, 3], 0); - test_util.expectArraysClose( - result, new Float32Array([ - 1, 11, 111, 2, 22, 222, 5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888 - ]), - 1e-6); - - }); - - it('concat axis=1', () => { - const x1 = new Float32Array([1, 11, 111, 3, 33, 333]); - const x2 = - new Float32Array([5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888]); - - const result = uploadConcat3dDownload(x1, x2, [2, 1, 3], [2, 2, 3], 1); - test_util.expectArraysClose( - result, new Float32Array([ - 1, 11, 111, 5, 55, 555, 6, 66, 666, 3, 33, 333, 7, 77, 777, 8, 88, 888 - ]), - 1e-6); - }); - - it('concat axis=2', () => { - const x1 = new Float32Array([1, 11, 2, 22, 3, 33, 4, 44]); - const x2 = - new Float32Array([5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888]); - - const result = uploadConcat3dDownload(x1, x2, [2, 2, 2], [2, 2, 3], 2); - test_util.expectArraysClose( - result, new Float32Array([ - 1, 11, 5, 55, 555, 2, 22, 6, 66, 666, - 3, 33, 7, 77, 777, 4, 44, 8, 88, 888 - ]), - 1e-6); - }); -}); - -function uploadConcat3dDownload( - a: Float32Array, b: Float32Array, aShape: [number, number, number], - bShape: [number, number, number], axis: number): Float32Array { - const gpgpu = new GPGPUContext(); - gpgpu.enableAutomaticDebugValidation(true); - const textureManager = new TextureManager(gpgpu); - initializeGPU(gpgpu, textureManager); - - 
const program = new Concat3DProgram(aShape, bShape, axis); - const aArr = Array3D.new(aShape, a); - const bArr = Array3D.new(bShape, b); - const rArr = NDArray.zeros(program.outputShape); - const binary = gpgpu_math.compileProgram(gpgpu, program, [aArr, bArr], rArr); - gpgpu_math.runProgram(binary, [aArr, bArr], rArr); - const result = rArr.getValues(); - - aArr.dispose(); - bArr.dispose(); - rArr.dispose(); - textureManager.dispose(); - gpgpu.deleteProgram(binary.webGLProgram); - gpgpu.dispose(); - - return result; -} diff --git a/src/math/webgl/concat_gpu.ts b/src/math/webgl/concat_gpu.ts new file mode 100644 index 0000000000..a92c7ea077 --- /dev/null +++ b/src/math/webgl/concat_gpu.ts @@ -0,0 +1,99 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as concat_util from '../concat_util'; +import {GPGPUProgram} from './gpgpu_math'; + +export class ConcatProgram implements GPGPUProgram { + variableNames = ['A', 'B']; + params: Array<{}> = []; + outputShape: number[] = []; + userCode: string; + + constructor(aShape: number[], bShape: number[], axis: number) { + const yAxes = ['yR', 'yC', 'yD', 'yW']; + const concatAxis = yAxes[axis]; + this.params = [axis]; + this.outputShape = concat_util.computeOutShape(aShape, bShape, axis); + + const dType = getDataType(aShape.length); + const unpackSnippet = getUnpack(aShape.length); + const sampleCoords = getSampleCoords(aShape.length); + + this.userCode = ` + void main() { + ${dType} coords = getOutputCoords(); + ${unpackSnippet} + + float value = 0.0; + if (${concatAxis} < ${aShape[axis]}) { + value = getA(${sampleCoords}); + } else { + ${concatAxis} -= ${aShape[axis]}; + value = getB(${sampleCoords}); + } + + setOutput(value); + } + `; + } +} + +function getSampleCoords(rank: number): string { + if (rank === 1) { + return 'yR'; + } else if (rank === 2) { + return 'yR, yC'; + } else if (rank === 3) { + return 'yR, yC, yD'; + } else if (rank === 4) { + return 'yR, yC, yD, yW'; + } else { + throw Error(`Concat for rank ${rank} is not yet supported`); + } +} + +function getUnpack(rank: number): string { + let res = rank === 1 ? 
'int yR = coords;' : 'int yR = coords.x;'; + if (rank > 1) { + res += '\nint yC = coords.y;'; + } + if (rank > 2) { + res += '\nint yD = coords.z;'; + } + if (rank > 3) { + res += '\nint yW = coords.w;'; + } + if (rank > 4) { + throw Error(`Concat for rank ${rank} is not yet supported`); + } + return res; +} + +function getDataType(rank: number): string { + if (rank === 1) { + return 'int'; + } else if (rank === 2) { + return 'ivec2'; + } else if (rank === 3) { + return 'ivec3'; + } else if (rank === 4) { + return 'ivec4'; + } else { + throw Error(`Concat for rank ${rank} is not yet supported`); + } +} diff --git a/src/math/webgl/concat2d_gpu_test.ts b/src/math/webgl/concat_gpu_test.ts similarity index 50% rename from src/math/webgl/concat2d_gpu_test.ts rename to src/math/webgl/concat_gpu_test.ts index fa49c6af32..70edbc5c0e 100644 --- a/src/math/webgl/concat2d_gpu_test.ts +++ b/src/math/webgl/concat_gpu_test.ts @@ -17,13 +17,91 @@ import * as test_util from '../../test_util'; import {NDArrayMathCPU} from '../math_cpu'; -import {Array2D, initializeGPU} from '../ndarray'; -import {Concat2DProgram} from './concat2d_gpu'; +import {Array1D, Array2D, Array3D, initializeGPU, NDArray} from '../ndarray'; +import {ConcatProgram} from './concat_gpu'; import {GPGPUContext} from './gpgpu_context'; import * as gpgpu_math from './gpgpu_math'; import {TextureManager} from './texture_manager'; import * as webgl_util from './webgl_util'; +describe('concat1d_gpu', () => { + let gpgpu: GPGPUContext; + let textureManager: TextureManager; + + beforeAll(() => { + gpgpu = new GPGPUContext(); + textureManager = new TextureManager(gpgpu); + initializeGPU(gpgpu, textureManager); + }); + + afterAll(() => { + textureManager.dispose(); + gpgpu.dispose(); + }); + + it('3 + 5', () => { + const a = Array1D.new([3]); + const b = Array1D.new([5]); + + const result = doConcat(a, b); + const expected = new Float32Array([3, 5]); + test_util.expectArraysClose(result, expected); + }); + + it('3 + 
[5,7]', () => { + const a = Array1D.new([3]); + const b = Array1D.new([5, 7]); + + const result = doConcat(a, b); + const expected = new Float32Array([3, 5, 7]); + test_util.expectArraysClose(result, expected); + }); + + it('[3,5] + 7', () => { + const a = Array1D.new([3, 5]); + const b = Array1D.new([7]); + + const result = doConcat(a, b); + const expected = new Float32Array([3, 5, 7]); + test_util.expectArraysClose(result, expected); + }); + + it('matches cpu with arrays bigger than max tex size', () => { + const maxTextureSize = webgl_util.queryMaxTextureSize(gpgpu.gl); + const a = Array1D.randUniform([maxTextureSize + 10], -1, 1); + const b = Array1D.randUniform([maxTextureSize + 10], -1, 1); + + const result = doConcat(a, b, false); + const expected = doCpuConcat(a, b); + a.dispose(); + b.dispose(); + + test_util.expectArraysClose(result, expected); + }); + + function doCpuConcat(a: Array1D, b: Array1D): Float32Array { + const mathCpu = new NDArrayMathCPU(); + return mathCpu.concat1D(a, b).getValues(); + } + + function doConcat(a: Array1D, b: Array1D, dispose = true): Float32Array { + const program = new ConcatProgram(a.shape, b.shape, 0); + const r = Array1D.zeros(program.outputShape as [number]); + const binary = gpgpu_math.compileProgram(gpgpu, program, [a, b], r); + gpgpu_math.runProgram(binary, [a, b], r); + const result = r.getValues(); + + if (dispose) { + a.dispose(); + b.dispose(); + } + r.dispose(); + gpgpu.deleteProgram(binary.webGLProgram); + + return result; + } +}); + describe('concat2d_gpu', () => { let gpgpu: GPGPUContext; let textureManager: TextureManager; @@ -130,7 +208,7 @@ describe('concat2d_gpu', () => { function doConcat( a: Array2D, b: Array2D, axis: number, dispose = true): Array2D { - const program = new Concat2DProgram(a.shape, b.shape, axis); + const program = new ConcatProgram(a.shape, b.shape, axis); const r = Array2D.zeros(program.outputShape as [number, number]); const binary = gpgpu_math.compileProgram(gpgpu, program, [a, b], 
r); gpgpu_math.runProgram(binary, [a, b], r); @@ -144,3 +222,72 @@ describe('concat2d_gpu', () => { return r; } }); + +describe('concat3d_gpu', () => { + it('concat axis=0', () => { + const x1 = new Float32Array([1, 11, 111, 2, 22, 222]); + const x2 = + new Float32Array([5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888]); + + const result = uploadConcat3dDownload(x1, x2, [1, 2, 3], [2, 2, 3], 0); + test_util.expectArraysClose( + result, new Float32Array([ + 1, 11, 111, 2, 22, 222, 5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888 + ]), + 1e-6); + + }); + + it('concat axis=1', () => { + const x1 = new Float32Array([1, 11, 111, 3, 33, 333]); + const x2 = + new Float32Array([5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888]); + + const result = uploadConcat3dDownload(x1, x2, [2, 1, 3], [2, 2, 3], 1); + test_util.expectArraysClose( + result, new Float32Array([ + 1, 11, 111, 5, 55, 555, 6, 66, 666, 3, 33, 333, 7, 77, 777, 8, 88, 888 + ]), + 1e-6); + }); + + it('concat axis=2', () => { + const x1 = new Float32Array([1, 11, 2, 22, 3, 33, 4, 44]); + const x2 = + new Float32Array([5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888]); + + const result = uploadConcat3dDownload(x1, x2, [2, 2, 2], [2, 2, 3], 2); + test_util.expectArraysClose( + result, new Float32Array([ + 1, 11, 5, 55, 555, 2, 22, 6, 66, 666, + 3, 33, 7, 77, 777, 4, 44, 8, 88, 888 + ]), + 1e-6); + }); +}); + +function uploadConcat3dDownload( + a: Float32Array, b: Float32Array, aShape: [number, number, number], + bShape: [number, number, number], axis: number): Float32Array { + const gpgpu = new GPGPUContext(); + gpgpu.enableAutomaticDebugValidation(true); + const textureManager = new TextureManager(gpgpu); + initializeGPU(gpgpu, textureManager); + + const program = new ConcatProgram(aShape, bShape, axis); + const aArr = Array3D.new(aShape, a); + const bArr = Array3D.new(bShape, b); + const rArr = NDArray.zeros(program.outputShape); + const binary = gpgpu_math.compileProgram(gpgpu, program, [aArr, bArr], rArr); + 
gpgpu_math.runProgram(binary, [aArr, bArr], rArr); + const result = rArr.getValues(); + + aArr.dispose(); + bArr.dispose(); + rArr.dispose(); + textureManager.dispose(); + gpgpu.deleteProgram(binary.webGLProgram); + gpgpu.dispose(); + + return result; +} diff --git a/src/math/webgl/slice1d_gpu.ts b/src/math/webgl/slice1d_gpu.ts deleted file mode 100644 index 6a5ab9da8e..0000000000 --- a/src/math/webgl/slice1d_gpu.ts +++ /dev/null @@ -1,51 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {GPGPUContext} from './gpgpu_context'; -import {GPGPUProgram} from './gpgpu_math'; - -export class Slice1DProgram implements GPGPUProgram { - variableNames = ['source']; - params: Array<{}>; - outputShape: number[]; - userCode: string; - - // Caching uniform location for speed. 
- startLoc: WebGLUniformLocation; - - constructor(destSize: number) { - this.outputShape = [destSize]; - this.params = []; - this.userCode = ` - uniform int start; - - void main() { - int sourceIndex = start + getOutputCoords(); - setOutput(getSource(sourceIndex)); - } - `; - } - - getCustomSetupFunc(start: number) { - return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => { - if (this.startLoc == null) { - this.startLoc = gpgpu.getUniformLocation(webGLProgram, 'start'); - } - gpgpu.gl.uniform1i(this.startLoc, start); - }; - } -} diff --git a/src/math/webgl/slice1d_gpu_test.ts b/src/math/webgl/slice1d_gpu_test.ts deleted file mode 100644 index 8529666f2e..0000000000 --- a/src/math/webgl/slice1d_gpu_test.ts +++ /dev/null @@ -1,84 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import {Array1D, initializeGPU} from '../ndarray'; -import {GPGPUContext} from './gpgpu_context'; -import * as gpgpu_math from './gpgpu_math'; -import {Slice1DProgram} from './slice1d_gpu'; -import {TextureManager} from './texture_manager'; -import * as webgl_util from './webgl_util'; - -describe('slice1d_gpu', () => { - let gpgpu: GPGPUContext; - let texManager: TextureManager; - - beforeAll(() => { - gpgpu = new GPGPUContext(); - texManager = new TextureManager(gpgpu); - initializeGPU(gpgpu, texManager); - }); - - afterAll(() => { - texManager.dispose(); - gpgpu.dispose(); - }); - - it('slices [1] into [1] (effectively a copy)', () => { - const a = Array1D.new([5]); - const result = doSlice(a, 0, 1); - expect(result.shape).toEqual([1]); - expect(result.get(0)).toBe(5); - }); - - it('slices [5] into [2] starting at 3', () => { - const a = Array1D.new([1, 2, 3, 4, 5]); - const result = doSlice(a, 3, 2); - expect(result.shape).toEqual([2]); - expect(result.getValues()).toEqual(new Float32Array([4, 5])); - }); - - it('slices [5] into [3] starting at 1', () => { - const a = Array1D.new([1, 2, 3, 4, 5]); - const result = doSlice(a, 1, 3); - expect(result.shape).toEqual([3]); - expect(result.getValues()).toEqual(new Float32Array([2, 3, 4])); - }); - - it('slices array that is bigger than max tex size', () => { - const maxTexSize = webgl_util.queryMaxTextureSize(gpgpu.gl); - const a = Array1D.randUniform([maxTexSize + 10], -1, 1); - const expected = a.get(a.size - 1); - const result = doSlice(a, a.size - 1, 1); - expect(result.shape).toEqual([1]); - expect(result.get(0)).toEqual(expected); - }); - - - function doSlice(a: Array1D, start: number, size: number): Array1D { - const program = new Slice1DProgram(size); - const result = Array1D.zeros([size]); - - const binary = gpgpu_math.compileProgram(gpgpu, program, [a], result); - const customSetup = program.getCustomSetupFunc(start); - 
gpgpu_math.runProgram(binary, [a], result, customSetup); - - a.dispose(); - gpgpu.deleteProgram(binary.webGLProgram); - - return result; - } -}); diff --git a/src/math/webgl/slice_gpu.ts b/src/math/webgl/slice_gpu.ts new file mode 100644 index 0000000000..93832cac81 --- /dev/null +++ b/src/math/webgl/slice_gpu.ts @@ -0,0 +1,101 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {GPGPUContext} from './gpgpu_context'; +import {GPGPUProgram} from './gpgpu_math'; + +export class SliceProgram implements GPGPUProgram { + variableNames = ['source']; + params: Array<{}>; + outputShape: number[]; + userCode: string; + rank: number; + + // Caching uniform location for speed. 
+ startLoc: WebGLUniformLocation; + + constructor(destSize: number[]) { + this.outputShape = destSize; + this.rank = destSize.length; + this.params = []; + + const dtype = getDataType(this.rank); + const sourceCoords = getCoords(this.rank); + + this.userCode = ` + uniform ${dtype} start; + + void main() { + ${dtype} sourceLoc = start + getOutputCoords(); + setOutput(getSource(${sourceCoords})); + } + `; + } + + getCustomSetupFunc(start: number[]) { + if (start.length !== this.rank) { + throw Error( + `The rank (${this.rank}) of the program must match the ` + + `length of start (${start.length})`); + } + return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => { + if (this.startLoc == null) { + this.startLoc = gpgpu.getUniformLocation(webGLProgram, 'start'); + } + if (this.rank === 1) { + gpgpu.gl.uniform1i(this.startLoc, start[0]); + } else if (this.rank === 2) { + gpgpu.gl.uniform2i(this.startLoc, start[0], start[1]); + } else if (this.rank === 3) { + gpgpu.gl.uniform3i(this.startLoc, start[0], start[1], start[2]); + } else if (this.rank === 4) { + gpgpu.gl.uniform4i( + this.startLoc, start[0], start[1], start[2], start[3]); + } else { + throw Error(`Slicing for rank ${this.rank} is not yet supported`); + } + }; + } +} + +function getCoords(rank: number): string { + if (rank === 1) { + return 'sourceLoc'; + } else if (rank === 2) { + return 'sourceLoc.x, sourceLoc.y'; + } else if (rank === 3) { + return 'sourceLoc.x, sourceLoc.y, sourceLoc.z'; + } else if (rank === 4) { + return 'sourceLoc.x, sourceLoc.y, sourceLoc.z, sourceLoc.w'; + } else { + throw Error(`Slicing for rank ${rank} is not yet supported`); + } +} + +function getDataType(rank: number): string { + if (rank === 1) { + return 'int'; + } else if (rank === 2) { + return 'ivec2'; + } else if (rank === 3) { + return 'ivec3'; + } else if (rank === 4) { + return 'ivec4'; + } else { + throw Error(`Slicing for rank ${rank} is not yet supported`); + } +} diff --git a/src/math/webgl/slice_gpu_test.ts 
b/src/math/webgl/slice_gpu_test.ts new file mode 100644 index 0000000000..eaafc8ef9f --- /dev/null +++ b/src/math/webgl/slice_gpu_test.ts @@ -0,0 +1,209 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Array1D, Array2D, Array3D, initializeGPU} from '../ndarray'; +import {GPGPUContext} from './gpgpu_context'; +import * as gpgpu_math from './gpgpu_math'; +import {SliceProgram} from './slice_gpu'; +import {TextureManager} from './texture_manager'; +import * as webgl_util from './webgl_util'; + +describe('slice1d_gpu', () => { + let gpgpu: GPGPUContext; + let texManager: TextureManager; + + beforeAll(() => { + gpgpu = new GPGPUContext(); + texManager = new TextureManager(gpgpu); + initializeGPU(gpgpu, texManager); + }); + + afterAll(() => { + texManager.dispose(); + gpgpu.dispose(); + }); + + it('slices 1x1 into 1x1 (effectively a copy)', () => { + const a = Array1D.new([5]); + const result = doSlice1D(a, 0, 1); + expect(result.shape).toEqual([1]); + expect(result.get(0)).toBe(5); + }); + + it('slices 5x1 into shape 2x1 starting at 3', () => { + const a = Array1D.new([1, 2, 3, 4, 5]); + const result = doSlice1D(a, 3, 2); + expect(result.shape).toEqual([2]); + expect(result.getValues()).toEqual(new Float32Array([4, 5])); + }); + + it('slices 5x1 into shape 3x1 starting at 1', () => { + const a = 
Array1D.new([1, 2, 3, 4, 5]); + const result = doSlice1D(a, 1, 3); + expect(result.shape).toEqual([3]); + expect(result.getValues()).toEqual(new Float32Array([2, 3, 4])); + }); + + it('slices array that is bigger than max tex size', () => { + const maxTexSize = webgl_util.queryMaxTextureSize(gpgpu.gl); + const a = Array1D.randUniform([maxTexSize + 10], -1, 1); + const expected = a.get(a.size - 1); + const result = doSlice1D(a, a.size - 1, 1); + expect(result.shape).toEqual([1]); + expect(result.get(0)).toEqual(expected); + }); + + + function doSlice1D(a: Array1D, start: number, size: number): Array1D { + const program = new SliceProgram([size]); + const result = Array1D.zeros([size]); + + const binary = gpgpu_math.compileProgram(gpgpu, program, [a], result); + const customSetup = program.getCustomSetupFunc([start]); + gpgpu_math.runProgram(binary, [a], result, customSetup); + + a.dispose(); + gpgpu.deleteProgram(binary.webGLProgram); + + return result; + } +}); + +describe('slice2d_gpu', () => { + let gpgpu: GPGPUContext; + let texManager: TextureManager; + + beforeAll(() => { + gpgpu = new GPGPUContext(); + texManager = new TextureManager(gpgpu); + initializeGPU(gpgpu, texManager); + }); + + afterAll(() => { + texManager.dispose(); + gpgpu.dispose(); + }); + + it('slices 1x1 into shape 1x1 (effectively a copy)', () => { + const a = Array2D.new([1, 1], [[5]]); + const result = doSlice2D(a, [0, 0], [1, 1]); + expect(result.shape).toEqual([1, 1]); + expect(result.get(0, 0)).toBe(5); + }); + + it('slices 3x3 array into 2x2 starting at [1, 1]', () => { + const a = Array2D.new([3, 3], [1, 2, 3, 4, 5, 6, 7, 8, 9]); + const result = doSlice2D(a, [1, 1], [2, 2]); + expect(result.shape).toEqual([2, 2]); + expect(result.getValues()).toEqual(new Float32Array([5, 6, 8, 9])); + }); + + it('slices 3x3 into 2x1 starting at [1,1]', () => { + const a = Array2D.new([3, 3], [1, 2, 3, 4, 5, 6, 7, 8, 9]); + const result = doSlice2D(a, [1, 1], [2, 1]); + expect(result.shape).toEqual([2, 
1]); + expect(result.getValues()).toEqual(new Float32Array([5, 8])); + }); + + it('slices array that is bigger than max tex size', () => { + const maxTexSize = webgl_util.queryMaxTextureSize(gpgpu.gl); + const a = Array2D.randUniform([maxTexSize + 10, 1], -1, 1); + const expected = a.get(a.size - 1, 0); + const result = doSlice2D(a, [a.size - 1, 0], [1, 1]); + expect(result.shape).toEqual([1, 1]); + expect(result.get(0, 0)).toEqual(expected); + }); + + + function doSlice2D( + a: Array2D, start: [number, number], size: [number, number]): Array2D { + const program = new SliceProgram(size); + const result = Array2D.zeros(size); + + const binary = gpgpu_math.compileProgram(gpgpu, program, [a], result); + const customSetup = program.getCustomSetupFunc(start); + gpgpu_math.runProgram(binary, [a], result, customSetup); + + a.dispose(); + gpgpu.deleteProgram(binary.webGLProgram); + + return result; + } +}); + +describe('slice3d_gpu', () => { + let gpgpu: GPGPUContext; + let texManager: TextureManager; + + beforeAll(() => { + gpgpu = new GPGPUContext(); + texManager = new TextureManager(gpgpu); + initializeGPU(gpgpu, texManager); + }); + + afterAll(() => { + texManager.dispose(); + gpgpu.dispose(); + }); + + it('slices 1x1x1 into shape 1x1x1 (effectively a copy)', () => { + const a = Array3D.new([1, 1, 1], [[[5]]]); + const result = doSlice3D(a, [0, 0, 0], [1, 1, 1]); + expect(result.shape).toEqual([1, 1, 1]); + expect(result.get(0, 0, 0)).toBe(5); + }); + + it('slices 2x2x2 array into 1x2x2 starting at [1, 0, 0]', () => { + const a = Array3D.new([2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8]); + const result = doSlice3D(a, [1, 0, 0], [1, 2, 2]); + expect(result.shape).toEqual([1, 2, 2]); + expect(result.getValues()).toEqual(new Float32Array([5, 6, 7, 8])); + }); + + it('slices 2x2x2 array into 2x1x1 starting at [0, 1, 1]', () => { + const a = Array3D.new([2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8]); + const result = doSlice3D(a, [0, 1, 1], [2, 1, 1]); + expect(result.shape).toEqual([2, 1, 
1]); + expect(result.getValues()).toEqual(new Float32Array([4, 8])); + }); + + it('slices array that is bigger than max tex size', () => { + const maxTexSize = webgl_util.queryMaxTextureSize(gpgpu.gl); + const a = Array3D.randUniform([maxTexSize + 10, 1, 1], -1, 1); + const expected = a.get(a.size - 1, 0, 0); + const result = doSlice3D(a, [a.size - 1, 0, 0], [1, 1, 1]); + expect(result.shape).toEqual([1, 1, 1]); + expect(result.get(0, 0, 0)).toEqual(expected); + }); + + + function doSlice3D( + a: Array3D, start: [number, number, number], + size: [number, number, number]): Array3D { + const program = new SliceProgram(size); + const result = Array3D.zeros(size); + + const binary = gpgpu_math.compileProgram(gpgpu, program, [a], result); + const customSetup = program.getCustomSetupFunc(start); + gpgpu_math.runProgram(binary, [a], result, customSetup); + + a.dispose(); + gpgpu.deleteProgram(binary.webGLProgram); + + return result; + } +}); From 284dc94cd7aed24e428e015b26646e1c98455ea0 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sun, 24 Sep 2017 22:01:17 -0400 Subject: [PATCH 4/6] update .gitignore --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index 2b33fc1994..de5be23998 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,3 @@ npm-debug.log .DS_Store dist/ .idea/ -demos/performance_rnn From 83e76cdd9f3b45f6830d487f52531e34e5499771 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sun, 24 Sep 2017 22:45:19 -0400 Subject: [PATCH 5/6] add slice implementations for math_cpu and tests --- src/math/math_cpu.ts | 50 ++++++++++++++++++++++------ src/math/math_cpu_test.ts | 56 ++++++++++++++++++++++++++++++++ src/math/math_gpu_test.ts | 68 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 165 insertions(+), 9 deletions(-) diff --git a/src/math/math_cpu.ts b/src/math/math_cpu.ts index 18c51fd332..1558413b72 100644 --- a/src/math/math_cpu.ts +++ b/src/math/math_cpu.ts @@ -36,27 +36,59 @@ export class NDArrayMathCPU extends 
NDArrayMath { protected slice1DInternal(input: Array1D, begin: number, size: number): Array1D { - throw new Error('Method not implemented.'); + const newVals = input.getValues().slice(begin, begin + size); + return Array1D.new(newVals); } - protected slice2DInternal( - input: Array2D, beginRowCol: [number, number], - sizeRowCol: [number, number]): Array2D { - const result = Array2D.zeros(sizeRowCol); - this.copy2DInternal( - input, beginRowCol, sizeRowCol, result, [0, 0], sizeRowCol); + protected slice2DInternal(input: Array2D, begin: [number, number], size: [ + number, number + ]): Array2D { + const result = Array2D.zeros(size); + const [startI, startJ] = begin; + + for (let i = 0; i < size[0]; ++i) { + for (let j = 0; j < size[1]; ++j) { + const val = input.get(i + startI, j + startJ); + result.set(val, i, j); + } + } return result; } protected slice3DInternal( input: Array3D, begin: [number, number, number], size: [number, number, number]): Array3D { - throw new Error('Method not implemented.'); + const result = Array3D.zeros(size); + const [startI, startJ, startK] = begin; + + for (let i = 0; i < size[0]; ++i) { + for (let j = 0; j < size[1]; ++j) { + for (let k = 0; k < size[2]; ++k) { + const val = input.get(i + startI, j + startJ, k + startK); + result.set(val, i, j, k); + } + } + } + return result; } protected slice4DInternal( input: Array4D, begin: [number, number, number, number], size: [number, number, number, number]): Array4D { - throw new Error('Method not implemented.'); + const result = Array4D.zeros(size); + const [startI, startJ, startK, startL] = begin; + + for (let i = 0; i < size[0]; ++i) { + for (let j = 0; j < size[1]; ++j) { + for (let k = 0; k < size[2]; ++k) { + for (let l = 0; l < size[3]; ++l) { + const val = + input.get(i + startI, j + startJ, k + startK, l + startL); + result.set(val, i, j, k, l); + } + } + } + } + return result; } protected copy2DInternal( diff --git a/src/math/math_cpu_test.ts b/src/math/math_cpu_test.ts index 
76660170cc..5accfd4495 100644 --- a/src/math/math_cpu_test.ts +++ b/src/math/math_cpu_test.ts @@ -32,6 +32,34 @@ describe('NDArrayMathCPU clone', () => { }); }); +describe('NDArrayMathCPU slice1D', () => { + let math: NDArrayMathCPU; + beforeEach(() => { + math = new NDArrayMathCPU(); + }); + + it('slices 1x1 into 1x1 (effectively a copy)', () => { + const a = Array1D.new([5]); + const result = math.slice1D(a, 0, 1); + expect(result.shape).toEqual([1]); + expect(result.get(0)).toBe(5); + }); + + it('slices 5x1 into shape 2x1 starting at 3', () => { + const a = Array1D.new([1, 2, 3, 4, 5]); + const result = math.slice1D(a, 3, 2); + expect(result.shape).toEqual([2]); + expect(result.getValues()).toEqual(new Float32Array([4, 5])); + }); + + it('slices 5x1 into shape 3x1 starting at 1', () => { + const a = Array1D.new([1, 2, 3, 4, 5]); + const result = math.slice1D(a, 1, 3); + expect(result.shape).toEqual([3]); + expect(result.getValues()).toEqual(new Float32Array([2, 3, 4])); + }); +}); + describe('NDArrayMathCPU slice2D', () => { let math: NDArrayMathCPU; beforeEach(() => { @@ -72,6 +100,34 @@ describe('NDArrayMathCPU slice2D', () => { }); }); +describe('NDArrayMathCPU slice3D', () => { + let math: NDArrayMathCPU; + beforeEach(() => { + math = new NDArrayMathCPU(); + }); + + it('slices 1x1x1 into shape 1x1x1 (effectively a copy)', () => { + const a = Array3D.new([1, 1, 1], [[[5]]]); + const result = math.slice3D(a, [0, 0, 0], [1, 1, 1]); + expect(result.shape).toEqual([1, 1, 1]); + expect(result.get(0, 0, 0)).toBe(5); + }); + + it('slices 2x2x2 array into 1x2x2 starting at [1, 0, 0]', () => { + const a = Array3D.new([2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8]); + const result = math.slice3D(a, [1, 0, 0], [1, 2, 2]); + expect(result.shape).toEqual([1, 2, 2]); + expect(result.getValues()).toEqual(new Float32Array([5, 6, 7, 8])); + }); + + it('slices 2x2x2 array into 2x1x1 starting at [0, 1, 1]', () => { + const a = Array3D.new([2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8]); + const 
result = math.slice3D(a, [0, 1, 1], [2, 1, 1]); + expect(result.shape).toEqual([2, 1, 1]); + expect(result.getValues()).toEqual(new Float32Array([4, 8])); + }); +}); + describe('NDArrayMathCPU copy2D', () => { let math: NDArrayMathCPU; beforeEach(() => { diff --git a/src/math/math_gpu_test.ts b/src/math/math_gpu_test.ts index 0b7383280b..effa08b76d 100644 --- a/src/math/math_gpu_test.ts +++ b/src/math/math_gpu_test.ts @@ -178,6 +178,40 @@ describe('NDArrayMathGPU clone', () => { }); }); +describe('NDArrayMathCPU slice1D', () => { + let math: NDArrayMathGPU; + beforeEach(() => { + math = new NDArrayMathGPU(); + math.startScope(); + }); + + afterEach(() => { + math.endScope(null); + math.dispose(); + }); + + it('slices 1x1 into 1x1 (effectively a copy)', () => { + const a = Array1D.new([5]); + const result = math.slice1D(a, 0, 1); + expect(result.shape).toEqual([1]); + expect(result.get(0)).toBe(5); + }); + + it('slices 5x1 into shape 2x1 starting at 3', () => { + const a = Array1D.new([1, 2, 3, 4, 5]); + const result = math.slice1D(a, 3, 2); + expect(result.shape).toEqual([2]); + expect(result.getValues()).toEqual(new Float32Array([4, 5])); + }); + + it('slices 5x1 into shape 3x1 starting at 1', () => { + const a = Array1D.new([1, 2, 3, 4, 5]); + const result = math.slice1D(a, 1, 3); + expect(result.shape).toEqual([3]); + expect(result.getValues()).toEqual(new Float32Array([2, 3, 4])); + }); +}); + describe('NDArrayMathGPU slice2D', () => { let math: NDArrayMathGPU; beforeEach(() => { @@ -229,6 +263,40 @@ describe('NDArrayMathGPU slice2D', () => { }); }); +describe('NDArrayMathCPU slice3D', () => { + let math: NDArrayMathGPU; + beforeEach(() => { + math = new NDArrayMathGPU(); + math.startScope(); + }); + + afterEach(() => { + math.endScope(null); + math.dispose(); + }); + + it('slices 1x1x1 into shape 1x1x1 (effectively a copy)', () => { + const a = Array3D.new([1, 1, 1], [[[5]]]); + const result = math.slice3D(a, [0, 0, 0], [1, 1, 1]); + 
expect(result.shape).toEqual([1, 1, 1]); + expect(result.get(0, 0, 0)).toBe(5); + }); + + it('slices 2x2x2 array into 1x2x2 starting at [1, 0, 0]', () => { + const a = Array3D.new([2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8]); + const result = math.slice3D(a, [1, 0, 0], [1, 2, 2]); + expect(result.shape).toEqual([1, 2, 2]); + expect(result.getValues()).toEqual(new Float32Array([5, 6, 7, 8])); + }); + + it('slices 2x2x2 array into 2x1x1 starting at [0, 1, 1]', () => { + const a = Array3D.new([2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8]); + const result = math.slice3D(a, [0, 1, 1], [2, 1, 1]); + expect(result.shape).toEqual([2, 1, 1]); + expect(result.getValues()).toEqual(new Float32Array([4, 8])); + }); +}); + describe('NDArrayMathGPU copy2D', () => { let math: NDArrayMathGPU; beforeEach(() => { From 0718cb22d9c0b876a56329d37783d4f8bb0b9ea8 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Wed, 27 Sep 2017 17:06:39 -0400 Subject: [PATCH 6/6] address comments --- src/math/concat_util.ts | 14 +++---- src/math/math.ts | 10 ++--- src/math/math_cpu_test.ts | 40 +++++++++++++++--- src/math/math_gpu_test.ts | 41 +++++++++++++++++- src/math/slice_util.ts | 6 +-- src/math/webgl/concat_gpu_test.ts | 69 ++++++++++++++++++++++++++++++- src/math/webgl/gpgpu_context.ts | 6 +++ src/math/webgl/slice_gpu.ts | 7 +++- src/math/webgl/slice_gpu_test.ts | 69 ++++++++++++++++++++++++++++++- 9 files changed, 237 insertions(+), 25 deletions(-) diff --git a/src/math/concat_util.ts b/src/math/concat_util.ts index e31d32ef6f..05568b9e50 100644 --- a/src/math/concat_util.ts +++ b/src/math/concat_util.ts @@ -18,22 +18,22 @@ import * as util from '../util'; export function assertParams(aShape: number[], bShape: number[], axis: number) { - const rank = aShape.length; + const aRank = aShape.length; const bRank = bShape.length; util.assert( aShape.length === bShape.length, - `Error in concat${rank}D: rank of x1 (${rank}) and x2 (${bRank}) ` + + `Error in concat${aRank}D: rank of x1 (${aRank}) and x2 (${bRank}) ` + 
`must be the same.`); util.assert( - axis >= 0 && axis < rank, - `Error in concat${rank}D: axis must be ` + - `between 0 and ${rank - 1}.`); + axis >= 0 && axis < aRank, + `Error in concat${aRank}D: axis must be ` + + `between 0 and ${aRank - 1}.`); - for (let i = 0; i < rank; i++) { + for (let i = 0; i < aRank; i++) { util.assert( (i === axis) || (aShape[i] === bShape[i]), - `Error in concat${rank}D: Shape (${aShape}) does not match ` + + `Error in concat${aRank}D: Shape (${aShape}) does not match ` + `(${bShape}) along the non-concatenated axis ${i}.`); } } diff --git a/src/math/math.ts b/src/math/math.ts index 5bb962c4f0..904b63f5d1 100644 --- a/src/math/math.ts +++ b/src/math/math.ts @@ -23,7 +23,7 @@ import * as copy2d_util from './copy2d_util'; import {Array1D, Array2D, Array3D, Array4D, NDArray, Scalar} from './ndarray'; import * as slice_util from './slice_util'; -export type ScopeResult = NDArray[] | NDArray | void; +export type ScopeResult = NDArray[]|NDArray|void; export interface LSTMCell { (data: Array2D, c: Array2D, h: Array2D): [Array2D, Array2D]; @@ -364,7 +364,7 @@ export abstract class NDArrayMath { * @param size The size of the slice. 
*/ slice1D(input: Array1D, begin: number, size: number): Array1D { - slice_util.assertParams(input, [begin], [size]); + slice_util.assertParamsValid(input, [begin], [size]); return this.executeOp( 'slice1D', () => this.slice1DInternal(input, begin, size)); } @@ -381,7 +381,7 @@ export abstract class NDArrayMath { */ slice2D(input: Array2D, begin: [number, number], size: [number, number]): Array2D { - slice_util.assertParams(input, begin, size); + slice_util.assertParamsValid(input, begin, size); return this.executeOp( 'slice2D', () => this.slice2DInternal(input, begin, size)); } @@ -399,7 +399,7 @@ export abstract class NDArrayMath { slice3D(input: Array3D, begin: [number, number, number], size: [ number, number, number ]): Array3D { - slice_util.assertParams(input, begin, size); + slice_util.assertParamsValid(input, begin, size); return this.executeOp( 'slice3D', () => this.slice3DInternal(input, begin, size)); } @@ -419,7 +419,7 @@ export abstract class NDArrayMath { slice4D(input: Array4D, begin: [number, number, number, number], size: [ number, number, number, number ]): Array4D { - slice_util.assertParams(input, begin, size); + slice_util.assertParamsValid(input, begin, size); return this.executeOp( 'slice4D', () => this.slice4DInternal(input, begin, size)); } diff --git a/src/math/math_cpu_test.ts b/src/math/math_cpu_test.ts index 5accfd4495..898c33ceef 100644 --- a/src/math/math_cpu_test.ts +++ b/src/math/math_cpu_test.ts @@ -20,7 +20,7 @@ import * as util from '../util'; import {MatrixOrientation} from './math'; import {NDArrayMathCPU} from './math_cpu'; -import {Array1D, Array2D, Array3D, Scalar} from './ndarray'; +import {Array1D, Array2D, Array3D, Array4D, Scalar} from './ndarray'; describe('NDArrayMathCPU clone', () => { it('returns a ndarray with the same shape and data', () => { @@ -128,6 +128,38 @@ describe('NDArrayMathCPU slice3D', () => { }); }); +describe('NDArrayMathCPU slice4D', () => { + let math: NDArrayMathCPU; + beforeEach(() => { + math = 
new NDArrayMathCPU(); + }); + + it('slices 1x1x1x1 into shape 1x1x1x1 (effectively a copy)', () => { + const a = Array4D.new([1, 1, 1, 1], [[[[5]]]]); + const result = math.slice4D(a, [0, 0, 0, 0], [1, 1, 1, 1]); + expect(result.shape).toEqual([1, 1, 1, 1]); + expect(result.get(0, 0, 0, 0)).toBe(5); + }); + + it('slices 2x2x2x2 array into 1x2x2x2 starting at [1, 0, 0, 0]', () => { + const a = Array4D.new( + [2, 2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44, 55, 66, 77, 88]); + const result = math.slice4D(a, [1, 0, 0, 0], [1, 2, 2, 2]); + expect(result.shape).toEqual([1, 2, 2, 2]); + expect(result.getValues()).toEqual(new Float32Array([ + 11, 22, 33, 44, 55, 66, 77, 88 + ])); + }); + + it('slices 2x2x2x2 array into 2x1x1x1 starting at [0, 1, 1, 1]', () => { + const a = Array4D.new( + [2, 2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44, 55, 66, 77, 88]); + const result = math.slice4D(a, [0, 1, 1, 1], [2, 1, 1, 1]); + expect(result.shape).toEqual([2, 1, 1, 1]); + expect(result.getValues()).toEqual(new Float32Array([8, 88])); + }); +}); + describe('NDArrayMathCPU copy2D', () => { let math: NDArrayMathCPU; beforeEach(() => { @@ -770,10 +802,8 @@ describe('NDArrayMathCPU scaledNDArrayAdd', () => { const c1: any = Array1D.randNormal([10]); const c2 = Scalar.new(2); - expect(() => math.scaledArrayAdd(c1 as Scalar, a, c2, b)) - .toThrowError(); - expect(() => math.scaledArrayAdd(c2, a, c1 as Scalar, b)) - .toThrowError(); + expect(() => math.scaledArrayAdd(c1 as Scalar, a, c2, b)).toThrowError(); + expect(() => math.scaledArrayAdd(c2, a, c1 as Scalar, b)).toThrowError(); }); it('throws when NDArrays are different shape', () => { diff --git a/src/math/math_gpu_test.ts b/src/math/math_gpu_test.ts index effa08b76d..2a3dd34180 100644 --- a/src/math/math_gpu_test.ts +++ b/src/math/math_gpu_test.ts @@ -297,6 +297,44 @@ describe('NDArrayMathCPU slice3D', () => { }); }); +describe('NDArrayMathCPU slice4D', () => { + let math: NDArrayMathGPU; + beforeEach(() => { + math = new 
NDArrayMathGPU(); + math.startScope(); + }); + + afterEach(() => { + math.endScope(null); + math.dispose(); + }); + + it('slices 1x1x1x1 into shape 1x1x1x1 (effectively a copy)', () => { + const a = Array4D.new([1, 1, 1, 1], [[[[5]]]]); + const result = math.slice4D(a, [0, 0, 0, 0], [1, 1, 1, 1]); + expect(result.shape).toEqual([1, 1, 1, 1]); + expect(result.get(0, 0, 0, 0)).toBe(5); + }); + + it('slices 2x2x2x2 array into 1x2x2x2 starting at [1, 0, 0, 0]', () => { + const a = Array4D.new( + [2, 2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44, 55, 66, 77, 88]); + const result = math.slice4D(a, [1, 0, 0, 0], [1, 2, 2, 2]); + expect(result.shape).toEqual([1, 2, 2, 2]); + expect(result.getValues()).toEqual(new Float32Array([ + 11, 22, 33, 44, 55, 66, 77, 88 + ])); + }); + + it('slices 2x2x2x2 array into 2x1x1x1 starting at [0, 1, 1, 1]', () => { + const a = Array4D.new( + [2, 2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44, 55, 66, 77, 88]); + const result = math.slice4D(a, [0, 1, 1, 1], [2, 1, 1, 1]); + expect(result.shape).toEqual([2, 1, 1, 1]); + expect(result.getValues()).toEqual(new Float32Array([8, 88])); + }); +}); + describe('NDArrayMathGPU copy2D', () => { let math: NDArrayMathGPU; beforeEach(() => { @@ -1967,8 +2005,7 @@ describe('NDArrayMathGPU conv2d', () => { const stride = 1; const x = Array3D.new(inputShape, [1, 2, 3, 4]); - const w = Array4D.randNormal( - [fSize, fSize, wrongInputDepth, outputDepth]); + const w = Array4D.randNormal([fSize, fSize, wrongInputDepth, outputDepth]); const bias = Array1D.new([-1]); expect(() => math.conv2d(x, w, bias, stride, pad)).toThrowError(); diff --git a/src/math/slice_util.ts b/src/math/slice_util.ts index da6a2e25d9..3c2fa22aa7 100644 --- a/src/math/slice_util.ts +++ b/src/math/slice_util.ts @@ -19,7 +19,7 @@ import * as util from '../util'; import {NDArray} from './ndarray'; -export function assertParams( +export function assertParamsValid( input: NDArray, begin: number[], size: number[]): void { util.assert( 
input.rank === begin.length, @@ -34,7 +34,7 @@ export function assertParams( util.assert( begin[i] + size[i] <= input.shape[i], `Error in slice${input.rank}D: begin[${i}] + size[${i}] ` + - `(${begin[i] + - size[i]}) would overflow input.shape[${i}] (${input.shape[i]})`); + `(${begin[i] + size[i]}) would overflow input.shape[${i}] (${ + input.shape[i]})`); } } diff --git a/src/math/webgl/concat_gpu_test.ts b/src/math/webgl/concat_gpu_test.ts index 70edbc5c0e..d1c647bf93 100644 --- a/src/math/webgl/concat_gpu_test.ts +++ b/src/math/webgl/concat_gpu_test.ts @@ -17,7 +17,8 @@ import * as test_util from '../../test_util'; import {NDArrayMathCPU} from '../math_cpu'; -import {Array1D, Array2D, Array3D, initializeGPU, NDArray} from '../ndarray'; +// tslint:disable-next-line:max-line-length +import {Array1D, Array2D, Array3D, Array4D, initializeGPU, NDArray} from '../ndarray'; import {ConcatProgram} from './concat_gpu'; import {GPGPUContext} from './gpgpu_context'; import * as gpgpu_math from './gpgpu_math'; @@ -291,3 +292,69 @@ function uploadConcat3dDownload( return result; } + +describe('concat4d_gpu', () => { + it('concat axis=0', () => { + const x1 = Array4D.new([1, 2, 3, 1], [1, 11, 111, 2, 22, 222]); + const x2 = Array4D.new( + [2, 2, 3, 1], [5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888]); + + const result = doConcat4D(x1, x2, 0); + test_util.expectArraysClose( + result, new Float32Array([ + 1, 11, 111, 2, 22, 222, 5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888 + ]), + 1e-6); + + }); + + it('concat axis=1', () => { + const x1 = Array4D.new([2, 1, 3, 1], [1, 11, 111, 3, 33, 333]); + const x2 = Array4D.new( + [2, 2, 3, 1], [5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888]); + + const result = doConcat4D(x1, x2, 1); + test_util.expectArraysClose( + result, new Float32Array([ + 1, 11, 111, 5, 55, 555, 6, 66, 666, 3, 33, 333, 7, 77, 777, 8, 88, 888 + ]), + 1e-6); + }); + + it('concat axis=2', () => { + const x1 = Array4D.new([2, 2, 2, 1], [1, 11, 2, 22, 3, 33, 4, 44]); 
+ const x2 = Array4D.new( + [2, 2, 3, 1], [5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888]); + + const result = doConcat4D(x1, x2, 2); + test_util.expectArraysClose( + result, new Float32Array([ + 1, 11, 5, 55, 555, 2, 22, 6, 66, 666, + 3, 33, 7, 77, 777, 4, 44, 8, 88, 888 + ]), + 1e-6); + }); + + function doConcat4D(a: Array4D, b: Array4D, axis: number): Float32Array { + const gpgpu = new GPGPUContext(); + gpgpu.enableAutomaticDebugValidation(true); + const textureManager = new TextureManager(gpgpu); + initializeGPU(gpgpu, textureManager); + + const program = new ConcatProgram(a.shape, b.shape, axis); + const rArr = + Array4D.zeros(program.outputShape as [number, number, number, number]); + const binary = gpgpu_math.compileProgram(gpgpu, program, [a, b], rArr); + gpgpu_math.runProgram(binary, [a, b], rArr); + const result = rArr.getValues(); + + a.dispose(); + b.dispose(); + rArr.dispose(); + textureManager.dispose(); + gpgpu.deleteProgram(binary.webGLProgram); + gpgpu.dispose(); + + return result; + } +}); diff --git a/src/math/webgl/gpgpu_context.ts b/src/math/webgl/gpgpu_context.ts index 66fa1a8601..81541db5aa 100644 --- a/src/math/webgl/gpgpu_context.ts +++ b/src/math/webgl/gpgpu_context.ts @@ -198,6 +198,12 @@ export class GPGPUContext { this.gl, program, uniformName); } + public getUniformLocationNoThrow(program: WebGLProgram, uniformName: string): + WebGLUniformLocation { + this.throwIfDisposed(); + return this.gl.getUniformLocation(program, uniformName); + } + public setInputMatrixTexture( inputMatrixTexture: WebGLTexture, uniformLocation: WebGLUniformLocation, textureUnit: number) { diff --git a/src/math/webgl/slice_gpu.ts b/src/math/webgl/slice_gpu.ts index 93832cac81..a48b57df42 100644 --- a/src/math/webgl/slice_gpu.ts +++ b/src/math/webgl/slice_gpu.ts @@ -54,7 +54,12 @@ export class SliceProgram implements GPGPUProgram { } return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => { if (this.startLoc == null) { - this.startLoc = 
gpgpu.getUniformLocation(webGLProgram, 'start'); + this.startLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'start'); + if (this.startLoc == null) { + // This means the compiler has optimized and realized it doesn't need + // the uniform. + return; + } } if (this.rank === 1) { gpgpu.gl.uniform1i(this.startLoc, start[0]); diff --git a/src/math/webgl/slice_gpu_test.ts b/src/math/webgl/slice_gpu_test.ts index eaafc8ef9f..f9dc2be144 100644 --- a/src/math/webgl/slice_gpu_test.ts +++ b/src/math/webgl/slice_gpu_test.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {Array1D, Array2D, Array3D, initializeGPU} from '../ndarray'; +import {Array1D, Array2D, Array3D, Array4D, initializeGPU} from '../ndarray'; import {GPGPUContext} from './gpgpu_context'; import * as gpgpu_math from './gpgpu_math'; import {SliceProgram} from './slice_gpu'; @@ -207,3 +207,70 @@ describe('slice3d_gpu', () => { return result; } }); + +describe('slice4d_gpu', () => { + let gpgpu: GPGPUContext; + let texManager: TextureManager; + + beforeAll(() => { + gpgpu = new GPGPUContext(); + texManager = new TextureManager(gpgpu); + initializeGPU(gpgpu, texManager); + }); + + afterAll(() => { + texManager.dispose(); + gpgpu.dispose(); + }); + + it('slices 1x1x1x1 into shape 1x1x1x1 (effectively a copy)', () => { + const a = Array4D.new([1, 1, 1, 1], [[[[5]]]]); + const result = doSlice4D(a, [0, 0, 0, 0], [1, 1, 1, 1]); + expect(result.shape).toEqual([1, 1, 1, 1]); + expect(result.get(0, 0, 0, 0)).toBe(5); + }); + + it('slices 2x2x2x2 array into 1x2x2x2 starting at [1, 0, 0, 0]', () => { + const a = Array4D.new( + [2, 2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44, 55, 66, 77, 88]); + const result = doSlice4D(a, [1, 0, 0, 0], [1, 2, 2, 2]); + expect(result.shape).toEqual([1, 2, 2, 2]); + expect(result.getValues()).toEqual(new Float32Array([ + 11, 22, 33, 44, 55, 66, 77, 88 + ])); + }); + + it('slices 2x2x2x2 array into 2x1x1x1 starting at [0, 
1, 1, 1]', () => { + const a = Array4D.new( + [2, 2, 2, 2], [1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44, 55, 66, 77, 88]); + const result = doSlice4D(a, [0, 1, 1, 1], [2, 1, 1, 1]); + expect(result.shape).toEqual([2, 1, 1, 1]); + expect(result.getValues()).toEqual(new Float32Array([8, 88])); + }); + + it('slices array that is bigger than max tex size', () => { + const maxTexSize = webgl_util.queryMaxTextureSize(gpgpu.gl); + const a = Array4D.randUniform([maxTexSize + 10, 1, 1, 1], -1, 1); + const expected = a.get(a.size - 1, 0, 0, 0); + const result = doSlice4D(a, [a.size - 1, 0, 0, 0], [1, 1, 1, 1]); + expect(result.shape).toEqual([1, 1, 1, 1]); + expect(result.get(0, 0, 0, 0)).toEqual(expected); + }); + + + function doSlice4D( + a: Array4D, start: [number, number, number, number], + size: [number, number, number, number]): Array4D { + const program = new SliceProgram(size); + const result = Array4D.zeros(size); + + const binary = gpgpu_math.compileProgram(gpgpu, program, [a], result); + const customSetup = program.getCustomSetupFunc(start); + gpgpu_math.runProgram(binary, [a], result, customSetup); + + a.dispose(); + gpgpu.deleteProgram(binary.webGLProgram); + + return result; + } +});