From b97c9df57fdd55daacf3917530655c5d6dd263d4 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Wed, 7 Feb 2018 23:24:16 -0500 Subject: [PATCH 01/14] save --- src/graph/ops/matmul.ts | 2 +- src/index.ts | 2 +- src/math/backends/backend_cpu.ts | 4 +- src/math/backends/backend_engine.ts | 32 +++++++++++++++- src/math/backends/kernel_registry.ts | 9 +---- src/math/backends/types/matmul.ts | 36 ------------------ src/math/matmul.ts | 56 ++++++++++++++-------------- src/math/matmul_test.ts | 8 ++-- 8 files changed, 65 insertions(+), 84 deletions(-) delete mode 100644 src/math/backends/types/matmul.ts diff --git a/src/graph/ops/matmul.ts b/src/graph/ops/matmul.ts index 7ca2bf086c..8b334689f1 100644 --- a/src/graph/ops/matmul.ts +++ b/src/graph/ops/matmul.ts @@ -16,8 +16,8 @@ */ import {keep, tidy} from '../../globals'; -import {MatrixOrientation} from '../../math/backends/types/matmul'; import {NDArrayMath} from '../../math/math'; +import {MatrixOrientation} from '../../math/matmul'; import {Tensor1D, Tensor2D} from '../../math/tensor'; import {SymbolicTensor} from '../graph'; import * as graph_util from '../graph_util'; diff --git a/src/index.ts b/src/index.ts index 9d16738a74..d018a3271c 100644 --- a/src/index.ts +++ b/src/index.ts @@ -44,10 +44,10 @@ export {CostReduction, FeedEntry, Session} from './graph/session'; export {ConstantInitializer, Initializer, OnesInitializer, RandomNormalInitializer, RandomTruncatedNormalInitializer, RandomUniformInitializer, TensorInitializer, VarianceScalingInitializer, ZerosInitializer} from './initializers'; export {MathBackendCPU, NDArrayMathCPU} from './math/backends/backend_cpu'; export {MathBackendWebGL, NDArrayMathGPU} from './math/backends/backend_webgl'; -export {MatrixOrientation} from './math/backends/types/matmul'; export {GPGPUContext} from './math/backends/webgl/gpgpu_context'; export {LSTMCell} from './math/lstm'; export {NDArrayMath} from './math/math'; +export {MatrixOrientation} from './math/matmul'; export {MomentumOptimizer} from './math/optimizers/momentum_optimizer'; export {Optimizer} from './math/optimizers/optimizer'; export {SGDOptimizer} from './math/optimizers/sgd_optimizer'; diff --git a/src/math/backends/backend_cpu.ts b/src/math/backends/backend_cpu.ts index a5f0bdfbd0..b9b8ad90dc 100644 --- a/src/math/backends/backend_cpu.ts +++ b/src/math/backends/backend_cpu.ts @@ -16,23 +16,21 @@ */ import * as seedrandom from 'seedrandom'; - import {ENV} from '../../environment'; import * as util from '../../util'; import * as broadcast_util from '../broadcast_util'; import * as concat_util from '../concat_util'; import {Conv2DInfo} from '../conv_util'; import {NDArrayMath} from '../math'; +import {MatrixOrientation} from '../matmul'; import * as ops from '../ops'; import {tensor2d, tensor3d, tensor4d} from '../ops'; import * as selu_util from '../selu_util'; import {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'; import * as types from '../types'; import {DataType, DataTypeMap, Rank, TypedArray} from '../types'; - import * as axis_util from './../axis_util'; import {MathBackend} from './backend'; -import {MatrixOrientation} from './types/matmul'; export class MathBackendCPU implements MathBackend { private data: {[dataId: number]: DataTypeMap[DataType]} = {}; diff --git a/src/math/backends/backend_engine.ts b/src/math/backends/backend_engine.ts index 9db66cc2e6..cc1f0e5eb7 100644 --- a/src/math/backends/backend_engine.ts +++ b/src/math/backends/backend_engine.ts @@ -25,7 +25,7 @@ import {Rank} from '../types'; import 
{MathBackend, TensorStorage} from './backend'; import * as kernel_registry from './kernel_registry'; -import {KernelConfigRegistry} from './kernel_registry'; +import {Kernel, KernelConfigRegistry} from './kernel_registry'; import {Profiler} from './profiler'; // tslint:disable-next-line:max-line-length import {KernelNode, Tape, TapeNode, TapeNodeInputGradientTensors} from './tape_types'; @@ -73,6 +73,36 @@ export class BackendEngine implements TensorManager, TensorStorage { this.profiler = new Profiler(backend); } + runKernel( + kernelFn: (backend: MathBackend) => T, inputs: I, + grad?: (dy: T, y: T) => {[P in keyof I]: () => I[P]}): T { + let result: T; + // TODO(smilkov): Figure out kernel name. + const kernelName = '' as Kernel; + if (!ENV.get('DEBUG')) { + result = kernelFn(this.backend); + } else { + result = + this.profiler.profileKernel(kernelName, () => kernelFn(this.backend)); + } + + const recordKernel = + this.activeTape != null && this.customGradientDepth === 0; + if (recordKernel) { + const evaluatedNode: KernelNode = { + id: this.nextTapeNodeId++, + type: 'kernel', + name: kernelName, + kernel: kernelName, + inputAndArgs: {inputs}, + output: result, + gradient: grad + }; + this.activeTape.push(evaluatedNode); + } + return result; + } + executeKernel, C extends KernelConfigRegistry[K]['inputAndArgs']>( kernelName: K, config: C, grad?: KernelConfigRegistry[K]['gradient']): diff --git a/src/math/backends/kernel_registry.ts b/src/math/backends/kernel_registry.ts index 67f3047394..cb9121ef01 100644 --- a/src/math/backends/kernel_registry.ts +++ b/src/math/backends/kernel_registry.ts @@ -32,7 +32,6 @@ import {Conv2DDerBiasNode, Conv2DDerFilterNode, Conv2DDerInputNode, Conv2DNode, import {GatherNode} from './types/gather'; import {EqualNode, LogicalNode, WhereNode} from './types/logical'; import {LRN4DNode} from './types/lrn'; -import {MatMulNode} from './types/matmul'; import {MaximumNode, MaxNode, MinimumNode, MinNode} from './types/minmax'; import {MultinomialNode} from './types/multinomial'; import {OneHotNode} from './types/onehot'; @@ -56,12 +55,7 @@ executeKernel, O extends KernelConfigRegistry[K]['output']>( backend: MathBackend, kernelName: K, inputAndArgs: KernelConfigRegistry[K]['inputAndArgs']): O { - if (kernelName === 'MatMul') { - const config = inputAndArgs as MatMulNode['inputAndArgs']; - return backend.matMul( - config.inputs.a, config.inputs.b, config.args.aOrientation, - config.args.bOrientation) as O; - } else if (kernelName === 'Slice1D') { + if (kernelName === 'Slice1D') { const config = inputAndArgs as Slice1DNode['inputAndArgs']; return backend.slice1D( config.inputs.x, config.args.begin, config.args.size) as O; @@ -366,7 +360,6 @@ executeKernel, O extends } export interface KernelConfigRegistry { - MatMul: MatMulNode; Slice1D: Slice1DNode; Slice2D: Slice2DNode; Slice3D: Slice3DNode; diff --git a/src/math/backends/types/matmul.ts b/src/math/backends/types/matmul.ts deleted file mode 100644 index 1a063dac89..0000000000 --- a/src/math/backends/types/matmul.ts +++ /dev/null @@ -1,36 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {Tensor2D} from '../../tensor'; -import {KernelNode} from '../tape_types'; - -export interface MatMulNode extends KernelNode { - inputAndArgs: { - inputs: {a: Tensor2D; b: Tensor2D;}; - args: {aOrientation: MatrixOrientation; bOrientation: MatrixOrientation}; - }; - output: Tensor2D; - gradient: (dy: Tensor2D, y: Tensor2D) => { - a: () => Tensor2D; - b: () => Tensor2D; - }; -} - -export enum MatrixOrientation { - REGULAR, - TRANSPOSED -} diff --git a/src/math/matmul.ts b/src/math/matmul.ts index 22d65f6bee..0a93851758 100644 --- a/src/math/matmul.ts +++ b/src/math/matmul.ts @@ -17,11 +17,17 @@ import {ENV} from '../environment'; import * as util from '../util'; - -import {MatrixOrientation} from './backends/types/matmul'; import {doc, operation} from './decorators'; import {Scalar, Tensor1D, Tensor2D} from './tensor'; +export enum MatrixOrientation { + REGULAR, + TRANSPOSED +} + +const REGULAR = MatrixOrientation.REGULAR; +const TRANSPOSED = MatrixOrientation.TRANSPOSED; + export class Ops { /** * Computes the dot product of two matrices, A * B. These must be matrices, @@ -29,20 +35,18 @@ export class Ops { * in other cases. * @param a First matrix in dot product operation. * @param b Second matrix in dot product operation. - * @param aOrientation The MatrixOrientation of A. If using TRANSPOSED, will + * @param transposeA The MatrixOrientation of A. If using TRANSPOSED, will * compute A^T * B. - * @param bOrientation The MatrixOrientation of B. If using TRANSPOSED, will + * @param transposeB The MatrixOrientation of B. If using TRANSPOSED, will * compute A * B^T. */ @doc({heading: 'Operations', subheading: 'Matrices'}) @operation static matMul( - a: Tensor2D, b: Tensor2D, aOrientation = MatrixOrientation.REGULAR, - bOrientation = MatrixOrientation.REGULAR): Tensor2D { - const innerShapeA = - (aOrientation === MatrixOrientation.REGULAR) ? a.shape[1] : a.shape[0]; - const innerShapeB = - (bOrientation === MatrixOrientation.REGULAR) ? b.shape[0] : b.shape[1]; + a: Tensor2D, b: Tensor2D, transposeA = REGULAR, transposeB = REGULAR): + Tensor2D { + const innerShapeA = (transposeA === REGULAR) ? a.shape[1] : a.shape[0]; + const innerShapeB = (transposeB === REGULAR) ? 
b.shape[0] : b.shape[1]; util.assert( a.rank === 2 && b.rank === 2, @@ -53,26 +57,20 @@ export class Ops { innerShapeA === innerShapeB, `Error in matMul: inner shapes (${innerShapeA}) and (` + `${innerShapeB}) of Tensors with shapes ${a.shape} and ` + - `${b.shape} and orientations ${MatrixOrientation[aOrientation]}` + - ` and ${MatrixOrientation[bOrientation]} must match.`); + `${b.shape} and orientations ${MatrixOrientation[transposeA]}` + + ` and ${MatrixOrientation[transposeB]} must match.`); - return ENV.engine.executeKernel( - 'MatMul', {inputs: {a, b}, args: {aOrientation, bOrientation}}, - (dy: Tensor2D, y: Tensor2D) => { - if (aOrientation === MatrixOrientation.TRANSPOSED || - bOrientation === MatrixOrientation.TRANSPOSED) { - throw new Error( - `Backprop for transposed MatMul not yet implemented.`); - } - return { - a: () => dy.matMul( - b.toFloat(), MatrixOrientation.REGULAR, - MatrixOrientation.TRANSPOSED) as Tensor2D, - b: () => a.toFloat().matMul( - dy, MatrixOrientation.TRANSPOSED, - MatrixOrientation.REGULAR) as Tensor2D - }; - }) as Tensor2D; + const grad = (dy: Tensor2D, y: Tensor2D) => { + if (transposeA === TRANSPOSED || transposeB === TRANSPOSED) { + throw new Error(`Backprop for transposed MatMul not yet implemented.`); + } + return { + a: () => dy.matMul(b.toFloat(), REGULAR, TRANSPOSED), + b: () => a.toFloat().matMul(dy, TRANSPOSED, REGULAR) + }; + }; + return ENV.engine.runKernel( + backend => backend.matMul(a, b, transposeA, transposeB), {a, b}, grad); } /** diff --git a/src/math/matmul_test.ts b/src/math/matmul_test.ts index 2c36dcfc4b..82d1402dc2 100644 --- a/src/math/matmul_test.ts +++ b/src/math/matmul_test.ts @@ -18,7 +18,7 @@ import * as dl from '../index'; import * as test_util from '../test_util'; import {MathTests} from '../test_util'; -import {MatrixOrientation} from './backends/types/matmul'; +import {MatrixOrientation} from './matmul'; import {Rank} from './types'; const commonTests: MathTests = it => { @@ -70,8 +70,7 @@ const commonTests: MathTests = it => { const b = dl.zeros([3, 2]); const f = () => { - dl.matMul( - a, b, MatrixOrientation.REGULAR, MatrixOrientation.TRANSPOSED); + dl.matMul(a, b, MatrixOrientation.REGULAR, MatrixOrientation.TRANSPOSED); }; expect(f).toThrowError(); }); @@ -81,8 +80,7 @@ const commonTests: MathTests = it => { const b = dl.zeros([3, 2]); const f = () => { - dl.matMul( - a, b, MatrixOrientation.TRANSPOSED, MatrixOrientation.REGULAR); + dl.matMul(a, b, MatrixOrientation.TRANSPOSED, MatrixOrientation.REGULAR); }; expect(f).toThrowError(); }); From fbc1e32e4aa3f06c9f66a8b3576d9ad7108ba160 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sat, 10 Feb 2018 15:52:05 -0500 Subject: [PATCH 02/14] save --- src/math/backends/types/cast.ts | 28 ------------ src/math/backends/types/pow.ts | 30 ------------ src/math/backends/types/prelu.ts | 31 ------------- src/math/backends/types/reshape.ts | 27 ----------- src/math/backends/types/unary.ts | 73 ------------------------------ 5 files changed, 189 deletions(-) delete mode 100644 src/math/backends/types/cast.ts delete mode 100644 src/math/backends/types/pow.ts delete mode 100644 src/math/backends/types/prelu.ts delete mode 100644 src/math/backends/types/reshape.ts delete mode 100644 src/math/backends/types/unary.ts diff --git a/src/math/backends/types/cast.ts b/src/math/backends/types/cast.ts deleted file mode 100644 index aff3d385b9..0000000000 --- a/src/math/backends/types/cast.ts +++ /dev/null @@ -1,28 +0,0 @@ -/** - * @license - * Copyright 2018 Google Inc. 
All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {Tensor} from '../../tensor'; -import {DataType} from '../../types'; -import {KernelNode} from '../tape_types'; - -export interface CastNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor}; args: {newDType: DataType};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor - }; -} diff --git a/src/math/backends/types/pow.ts b/src/math/backends/types/pow.ts deleted file mode 100644 index 1da70328b1..0000000000 --- a/src/math/backends/types/pow.ts +++ /dev/null @@ -1,30 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {Tensor} from '../../tensor'; -import {Rank} from '../../types'; -import {KernelNode} from '../tape_types'; - -export interface PowNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {base: T; exp: Tensor;};}; - output: T; - gradient: (dy: Tensor, y: T) => { - base: () => Tensor; - exp: () => Tensor; - }; -} diff --git a/src/math/backends/types/prelu.ts b/src/math/backends/types/prelu.ts deleted file mode 100644 index e4511ea456..0000000000 --- a/src/math/backends/types/prelu.ts +++ /dev/null @@ -1,31 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import {Tensor} from '../../tensor'; -import {Rank} from '../../types'; -import {KernelNode} from '../tape_types'; - -// PReLU -export interface PReLUNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T; alpha: T;};}; - output: T; - gradient: (dy: Tensor, y: T) => { - x: () => Tensor; - alpha: () => Tensor; - }; -} diff --git a/src/math/backends/types/reshape.ts b/src/math/backends/types/reshape.ts deleted file mode 100644 index 8469592d8b..0000000000 --- a/src/math/backends/types/reshape.ts +++ /dev/null @@ -1,27 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {Tensor} from '../../tensor'; -import {KernelNode} from '../tape_types'; - -export interface ReshapeNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor}; args: {newShape: number[]};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor - }; -} diff --git a/src/math/backends/types/unary.ts b/src/math/backends/types/unary.ts deleted file mode 100644 index 45203884f3..0000000000 --- a/src/math/backends/types/unary.ts +++ /dev/null @@ -1,73 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import {Tensor} from '../../tensor'; -import {Rank} from '../../types'; -import {KernelNode} from '../tape_types'; - -export interface UnaryNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} - -export interface LeakyReluNode< - R extends Rank, T extends Tensor = Tensor> extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {alpha: number;};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} -export interface StepNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {alpha: number;};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} - -export interface ClipNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {min: number; max: number;};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} - -export interface TransposeNode< - R extends Rank, T extends Tensor = Tensor> extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {perm: number[];};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} - -export interface TileNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {reps: number[];};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} From 64bda74cb36dc5a7439ebefcfd00bc9d20664928 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sat, 10 Feb 2018 15:52:09 -0500 Subject: [PATCH 03/14] save --- src/math/array_ops.ts | 63 +++--- src/math/backends/backend.ts | 17 +- src/math/backends/backend_engine.ts | 2 +- src/math/backends/backend_webgl.ts | 3 +- src/math/backends/kernel_registry.ts | 157 -------------- src/math/backends/profiler.ts | 6 +- src/math/backends/webgl/mulmat_gpu.ts | 3 +- src/math/backends/webgl/mulmat_packed_gpu.ts | 3 +- .../backends/webgl/mulmat_packed_gpu_test.ts | 2 +- src/math/binary_ops.ts | 20 +- src/math/logical_ops.ts | 4 +- src/math/math.ts | 2 +- src/math/math_test.ts | 2 +- src/math/tensor.ts | 50 ++--- src/math/transpose.ts | 12 +- src/math/unary_ops.ts | 198 +++++++++--------- 16 files changed, 193 insertions(+), 351 deletions(-) diff --git a/src/math/array_ops.ts b/src/math/array_ops.ts index 8e42c490f8..908d9410fb 100644 --- a/src/math/array_ops.ts +++ b/src/math/array_ops.ts @@ -359,29 +359,24 @@ export class Ops { probabilities: Tensor1D|Tensor2D, numSamples: number, seed?: number): Tensor1D|Tensor2D { const numOutcomes = probabilities.size; + const origRank = probabilities.rank; if (numOutcomes < 2) { throw new Error( `Error in multinomial: you need at least 2 outcomes, but got ` + `${numOutcomes}.`); } - if (probabilities.rank > 2) { + if (origRank > 2) { throw new Error( - `Rank of probabilities must be 1 or 2, but is ${probabilities.rank}`); + `Rank of probabilities must be 1 or 2, but is ${origRank}`); } seed = seed || Math.random(); - const origRank = probabilities.rank; - if (probabilities.rank === 1) { - probabilities = probabilities.as2D(1, -1); - } - const res = ENV.engine.executeKernel('Multinomial', { - inputs: {probs: (probabilities as Tensor2D)}, - args: {numSamples, seed} - }); - if (origRank === 1) { - return res.as1D(); - } - return res; + const prob2D = + origRank === 1 ? probabilities.as2D(1, -1) : probabilities as Tensor2D; + const res = ENV.engine.runKernel( + backend => backend.multinomial(prob2D, numSamples, seed)); + + return origRank === 1 ? 
res.as1D() : res; } /** @@ -403,8 +398,8 @@ export class Ops { if (depth < 2) { throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`); } - return ENV.engine.executeKernel( - 'OneHot', {inputs: {indices}, args: {depth, onValue, offValue}}); + return ENV.engine.runKernel( + backend => backend.oneHot(indices, depth, onValue, offValue)); } /** @@ -458,9 +453,8 @@ export class Ops { const grad = (dy: Tensor, y: Tensor) => { return {x: () => dy.reshape(x.shape)}; }; - return ENV.engine.executeKernel( - 'Reshape', {inputs: {x}, args: {newShape: shape}}, grad) as - Tensor; + return ENV.engine.runKernel( + backend => Tensor.make(shape, {dataId: x.dataId}, x.dtype), {x}, grad); } /** @@ -472,10 +466,22 @@ export class Ops { @operation static cast(x: T, dtype: DataType): T { const grad = (dy: T, y: T) => { - return {x: () => dy.reshape(dy.shape)}; + return {x: () => dy.clone()}; }; - return ENV.engine.executeKernel( - 'Cast', {inputs: {x}, args: {newDType: dtype}}, grad) as T; + return ENV.engine.runKernel(backend => { + if (!util.hasEncodingLoss(x.dtype, dtype)) { + // We don't change the underlying data, since we cast to higher + // precision. + return Tensor.make(x.shape, {dataId: x.dataId}, dtype) as T; + } + if (dtype === 'int32') { + return backend.int(x); + } else if (dtype === 'bool') { + return backend.notEqual(x, Ops.scalar(0, x.dtype)); + } else { + throw new Error(`Error in Cast: unknown dtype argument (${dtype})`); + } + }, {x}, grad) as T; } /** @@ -497,7 +503,7 @@ export class Ops { x.rank === reps.length, `Error in transpose: rank of input ${x.rank} ` + `must match length of reps ${reps}.`); - return ENV.engine.executeKernel('Tile', {inputs: {x}, args: {reps}}) as T; + return ENV.engine.runKernel(backend => backend.tile(x, reps)); } /** @@ -510,8 +516,7 @@ export class Ops { @doc({heading: 'Tensors', subheading: 'Slicing and Joining'}) @operation static gather(x: T, indices: Tensor1D, axis = 0): T { - return ENV.engine.executeKernel( - 'Gather', {inputs: {x, indices}, args: {axis}}) as T; + return ENV.engine.runKernel(backend => backend.gather(x, indices, axis)); } /** @@ -534,8 +539,8 @@ export class Ops { util.assert( paddings.length === 2, 'Invalid number of paddings. Must be length of 2.'); - return ENV.engine.executeKernel( - 'Pad1D', {inputs: {x}, args: {paddings, constantValue}}); + return ENV.engine.runKernel( + backend => backend.pad1D(x, paddings, constantValue)); } /** @@ -561,8 +566,8 @@ export class Ops { paddings.length === 2 && paddings[0].length === 2 && paddings[1].length === 2, 'Invalid number of paddings. 
Must be length of 2 each.'); - return ENV.engine.executeKernel( - 'Pad2D', {inputs: {x}, args: {paddings, constantValue}}); + return ENV.engine.runKernel( + backend => backend.pad2D(x, paddings, constantValue)); } /** diff --git a/src/math/backends/backend.ts b/src/math/backends/backend.ts index 4c2607b091..cd02d22bdd 100644 --- a/src/math/backends/backend.ts +++ b/src/math/backends/backend.ts @@ -17,12 +17,11 @@ */ import {Conv2DInfo} from '../conv_util'; +import {MatrixOrientation} from '../matmul'; // tslint:disable-next-line:max-line-length import {DataId, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'; import {DataType, Rank, TypedArray} from '../types'; -import {MatrixOrientation} from './types/matmul'; - export interface TensorStorage { read(dataId: DataId): Promise; readSync(dataId: DataId): TypedArray; @@ -94,7 +93,7 @@ export interface MathBackend extends TensorStorage, BackendTimer { greater(a: Tensor, b: Tensor): Tensor; greaterEqual(a: Tensor, b: Tensor): Tensor; - logicalNot(a: Tensor): Tensor; + logicalNot(a: T): T; logicalAnd(a: Tensor, b: Tensor): Tensor; logicalOr(a: Tensor, b: Tensor): Tensor; logicalXor(a: Tensor, b: Tensor): Tensor; @@ -185,16 +184,16 @@ export interface MathBackend extends TensorStorage, BackendTimer { batchNormalization2D( x: Tensor2D, mean: Tensor2D|Tensor1D, variance: Tensor2D|Tensor1D, - varianceEpsilon: number, scale?: Tensor2D|Tensor1D, - offset?: Tensor2D|Tensor1D): Tensor2D; + varianceEpsilon: number, scale: Tensor2D|Tensor1D, + offset: Tensor2D|Tensor1D): Tensor2D; batchNormalization3D( x: Tensor3D, mean: Tensor3D|Tensor1D, variance: Tensor3D|Tensor1D, - varianceEpsilon: number, scale?: Tensor3D|Tensor1D, - offset?: Tensor3D|Tensor1D): Tensor3D; + varianceEpsilon: number, scale: Tensor3D|Tensor1D, + offset: Tensor3D|Tensor1D): Tensor3D; batchNormalization4D( x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D, - varianceEpsilon: number, scale?: Tensor4D|Tensor1D, - offset?: Tensor4D|Tensor1D): Tensor4D; + varianceEpsilon: number, scale: Tensor4D|Tensor1D, + offset: Tensor4D|Tensor1D): Tensor4D; localResponseNormalization4D( x: Tensor4D, radius: number, bias: number, alpha: number, beta: number, diff --git a/src/math/backends/backend_engine.ts b/src/math/backends/backend_engine.ts index 7a0d5c0a4b..e36c86cc7e 100644 --- a/src/math/backends/backend_engine.ts +++ b/src/math/backends/backend_engine.ts @@ -82,7 +82,7 @@ export class BackendEngine implements TensorManager { } runKernel( - kernelFn: (backend: MathBackend) => T, inputs: I, + kernelFn: (backend: MathBackend) => T, inputs?: I, grad?: (dy: T, y: T) => {[P in keyof I]: () => I[P]}): T { let result: T; // TODO(smilkov): Figure out kernel name. 
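
Note on the API this series converges on: `runKernel` replaces the string-keyed `executeKernel`. An op hands the engine a forward closure that calls the backend directly, a map of the named inputs it wants differentiated, and optionally a gradient function returning one closure per input; per the patch-01 hunk, the tape records the kernel only when a gradient tape is active and no custom gradient is in flight. Below is a minimal sketch of a migrated op in the style of the transpose.ts and unary_ops.ts hunks later in this patch. It is an illustration, not part of the series: the `double` op is hypothetical, and only pieces visible in these patches (`ENV`, `Tensor`, `ops.scalar` style helpers, `backend.add`) are assumed.

    import {ENV} from '../environment';
    import * as ops from './ops';
    import {Tensor} from './tensor';

    // Hypothetical op on the new engine API: the forward closure runs
    // against the active MathBackend, and `grad` supplies one gradient
    // closure per named input for the tape to record.
    function double<T extends Tensor>(x: T): T {
      const grad = (dy: T, y: T) => {
        // d(x + x)/dx = 2, so the upstream gradient is dy + dy.
        return {x: () => dy.add(dy) as T};
      };
      return ENV.engine.runKernel(
          backend => backend.add(x, x) as T, {x}, grad);
    }
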
diff --git a/src/math/backends/backend_webgl.ts b/src/math/backends/backend_webgl.ts index 8d4f1804ee..ca6736ee06 100644 --- a/src/math/backends/backend_webgl.ts +++ b/src/math/backends/backend_webgl.ts @@ -20,14 +20,15 @@ import * as util from '../../util'; import * as axis_util from '../axis_util'; import {Conv2DInfo} from '../conv_util'; import {NDArrayMath} from '../math'; +import {MatrixOrientation} from '../matmul'; import * as reduce_util from '../reduce_util'; // tslint:disable-next-line:max-line-length import {DataId, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'; import * as types from '../types'; // tslint:disable-next-line:max-line-length import {DataType, DataTypeMap, Rank, RecursiveArray, TypedArray} from '../types'; + import {MathBackend} from './backend'; -import {MatrixOrientation} from './types/matmul'; import {ArgMinMaxProgram} from './webgl/argminmax_gpu'; import {AvgPool2DBackpropProgram} from './webgl/avg_pool_backprop_gpu'; import {BatchNormProgram} from './webgl/batchnorm_gpu'; diff --git a/src/math/backends/kernel_registry.ts b/src/math/backends/kernel_registry.ts index cb9121ef01..9391b36a8e 100644 --- a/src/math/backends/kernel_registry.ts +++ b/src/math/backends/kernel_registry.ts @@ -15,16 +15,12 @@ * ============================================================================= */ -import * as util from '../../util'; -import * as ops from '../ops'; -import {Tensor} from '../tensor'; import {Rank} from '../types'; import {MathBackend} from './backend'; import {ArgMaxNode, ArgMinNode} from './types/argminmax'; // tslint:disable-next-line:max-line-length import {BatchNorm2DNode, BatchNorm3DNode, BatchNorm4DNode} from './types/batchnorm'; import {BinaryNode} from './types/binary'; -import {CastNode} from './types/cast'; // tslint:disable-next-line:max-line-length import {ConcatNode} from './types/concat'; // tslint:disable-next-line:max-line-length @@ -38,17 +34,12 @@ import {OneHotNode} from './types/onehot'; import {Pad1DNode, Pad2DNode} from './types/pad'; // tslint:disable-next-line:max-line-length import {PoolBackpropNode, PoolNode} from './types/pool'; -import {PowNode} from './types/pow'; -import {PReLUNode} from './types/prelu'; -import {ReshapeNode} from './types/reshape'; import {ResizeBilinearNode} from './types/resize_bilinear'; import {Reverse4DNode} from './types/reverse'; // tslint:disable-next-line:max-line-length import {Slice1DNode, Slice2DNode, Slice3DNode, Slice4DNode} from './types/slice'; import {SumNode} from './types/sum'; import {TopKIndicesNode, TopKValuesNode} from './types/topk'; -// tslint:disable-next-line:max-line-length -import {ClipNode, LeakyReluNode, StepNode, TileNode, TransposeNode, UnaryNode} from './types/unary'; export function executeKernel, O extends @@ -77,9 +68,6 @@ executeKernel, O extends } else if (kernelName === 'Concat') { const config = inputAndArgs as ConcatNode['inputAndArgs']; return backend.concat(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Neg') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.neg(config.inputs.x) as O; } else if (kernelName === 'Add') { const config = inputAndArgs as BinaryNode['inputAndArgs']; return backend.add(config.inputs.a, config.inputs.b) as O; @@ -119,9 +107,6 @@ executeKernel, O extends } else if (kernelName === 'GreaterEqual') { const config = inputAndArgs as EqualNode['inputAndArgs']; return backend.greaterEqual(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'LogicalNot') { - const config 
= inputAndArgs as UnaryNode['inputAndArgs']; - return backend.logicalNot(config.inputs.x) as O; } else if (kernelName === 'LogicalAnd') { const config = inputAndArgs as LogicalNode['inputAndArgs']; return backend.logicalAnd(config.inputs.a, config.inputs.b) as O; @@ -154,112 +139,6 @@ executeKernel, O extends } else if (kernelName === 'Maximum') { const config = inputAndArgs as MaximumNode['inputAndArgs']; return backend.maximum(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Ceil') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.ceil(config.inputs.x) as O; - } else if (kernelName === 'Floor') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.floor(config.inputs.x) as O; - } else if (kernelName === 'Pow') { - const config = inputAndArgs as PowNode['inputAndArgs']; - return backend.pow(config.inputs.base, config.inputs.exp) as O; - } else if (kernelName === 'Exp') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.exp(config.inputs.x) as O; - } else if (kernelName === 'Log') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.log(config.inputs.x) as O; - } else if (kernelName === 'Sqrt') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.sqrt(config.inputs.x) as O; - } else if (kernelName === 'Square') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.square(config.inputs.x) as O; - } else if (kernelName === 'Relu') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.relu(config.inputs.x) as O; - } else if (kernelName === 'Reshape') { - const config = inputAndArgs as ReshapeNode['inputAndArgs']; - const x = config.inputs.x; - const newShape = config.args.newShape; - return Tensor.make(newShape, {dataId: x.dataId}, x.dtype) as O; - } else if (kernelName === 'Cast') { - const config = inputAndArgs as CastNode['inputAndArgs']; - const x = config.inputs.x; - const newDType = config.args.newDType; - - if (!util.hasEncodingLoss(x.dtype, newDType)) { - // We don't change the underlying data, since we cast to higher - // precision. 
- return Tensor.make(x.shape, {dataId: x.dataId}, newDType) as O; - } - if (newDType === 'int32') { - return backend.int(x) as O; - } else if (newDType === 'bool') { - return backend.notEqual(x, ops.scalar(0, x.dtype)) as O; - } else { - throw new Error(`Error in Cast: unknown dtype argument (${newDType})`); - } - } else if (kernelName === 'LeakyRelu') { - const config = inputAndArgs as LeakyReluNode['inputAndArgs']; - return backend.leakyRelu(config.inputs.x, config.args.alpha) as O; - } else if (kernelName === 'PReLU') { - const config = inputAndArgs as PReLUNode['inputAndArgs']; - return backend.prelu(config.inputs.x, config.inputs.alpha) as O; - } else if (kernelName === 'PReLUDer') { - const config = inputAndArgs as PReLUNode['inputAndArgs']; - return backend.preluDer(config.inputs.x, config.inputs.alpha) as O; - } else if (kernelName === 'Elu') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.elu(config.inputs.x) as O; - } else if (kernelName === 'EluDer') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.eluDer(config.inputs.x) as O; - } else if (kernelName === 'Selu') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.selu(config.inputs.x) as O; - } else if (kernelName === 'Abs') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.abs(config.inputs.x) as O; - } else if (kernelName === 'Sigmoid') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.sigmoid(config.inputs.x) as O; - } else if (kernelName === 'Step') { - const config = inputAndArgs as StepNode['inputAndArgs']; - return backend.step(config.inputs.x, config.args.alpha) as O; - } else if (kernelName === 'Sin') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.sin(config.inputs.x) as O; - } else if (kernelName === 'Cos') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.cos(config.inputs.x) as O; - } else if (kernelName === 'Tan') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.tan(config.inputs.x) as O; - } else if (kernelName === 'Asin') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.asin(config.inputs.x) as O; - } else if (kernelName === 'Acos') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.acos(config.inputs.x) as O; - } else if (kernelName === 'Atan') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.atan(config.inputs.x) as O; - } else if (kernelName === 'Sinh') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.sinh(config.inputs.x) as O; - } else if (kernelName === 'Cosh') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.cosh(config.inputs.x) as O; - } else if (kernelName === 'Tanh') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.tanh(config.inputs.x) as O; - } else if (kernelName === 'Clip') { - const config = inputAndArgs as ClipNode['inputAndArgs']; - return backend.clip(config.inputs.x, config.args.min, config.args.max) as O; - } else if (kernelName === 'Tile') { - const config = inputAndArgs as TileNode['inputAndArgs']; - return backend.tile(config.inputs.x, config.args.reps) as O; } else if (kernelName === 'Gather') { const config = inputAndArgs as GatherNode['inputAndArgs']; return backend.gather( @@ -274,9 +153,6 @@ executeKernel, O extends return backend.pad2D( config.inputs.x, config.args.paddings, 
config.args.constantValue) as O; - } else if (kernelName === 'Transpose') { - const config = inputAndArgs as TransposeNode['inputAndArgs']; - return backend.transpose(config.inputs.x, config.args.perm) as O; } else if (kernelName === 'Conv2D') { const config = inputAndArgs as Conv2DNode['inputAndArgs']; return backend.conv2d( @@ -366,7 +242,6 @@ export interface KernelConfigRegistry { Slice4D: Slice4DNode; Reverse4D: Reverse4DNode; Concat: ConcatNode; - Neg: UnaryNode; Add: BinaryNode; Sub: BinaryNode; Mul: BinaryNode; @@ -380,7 +255,6 @@ export interface KernelConfigRegistry { LessEqual: EqualNode; Greater: EqualNode; GreaterEqual: EqualNode; - LogicalNot: UnaryNode; LogicalAnd: LogicalNode; LogicalOr: LogicalNode; LogicalXor: LogicalNode; @@ -391,39 +265,8 @@ export interface KernelConfigRegistry { Minimum: MinimumNode; Max: MaxNode; Maximum: MaximumNode; - Ceil: UnaryNode; - Floor: UnaryNode; - Pow: PowNode; - Exp: UnaryNode; - Log: UnaryNode; - Sqrt: UnaryNode; - Square: UnaryNode; - Relu: UnaryNode; - LeakyRelu: LeakyReluNode; - PReLU: PReLUNode; - PReLUDer: PReLUNode; - Reshape: ReshapeNode; - Cast: CastNode; - Elu: UnaryNode; - EluDer: UnaryNode; - Selu: UnaryNode; - Abs: UnaryNode; - Sigmoid: UnaryNode; - Step: StepNode; - Sin: UnaryNode; - Cos: UnaryNode; - Tan: UnaryNode; - Asin: UnaryNode; - Acos: UnaryNode; - Atan: UnaryNode; - Sinh: UnaryNode; - Cosh: UnaryNode; - Tanh: UnaryNode; - Clip: ClipNode; - Transpose: TransposeNode; Pad1D: Pad1DNode; Pad2D: Pad2DNode; - Tile: TileNode; Gather: GatherNode; Conv2D: Conv2DNode; Conv2DDerInput: Conv2DDerInputNode; diff --git a/src/math/backends/profiler.ts b/src/math/backends/profiler.ts index fb25600015..32d6284920 100644 --- a/src/math/backends/profiler.ts +++ b/src/math/backends/profiler.ts @@ -18,9 +18,7 @@ import * as util from '../../util'; import {Tensor} from '../tensor'; import {TypedArray} from '../types'; - import {BackendTimer} from './backend'; -import {Kernel} from './kernel_registry'; export class Profiler { constructor(private backendTimer: BackendTimer, private logger?: Logger) { @@ -29,7 +27,7 @@ export class Profiler { } } - profileKernel(kernelName: Kernel, f: () => T): T { + profileKernel(kernelName: string, f: () => T): T { let result: T; const holdResultWrapperFn = () => { result = f(); @@ -49,7 +47,7 @@ export class Profiler { export class Logger { logKernelProfile( - kernelName: Kernel, result: Tensor, vals: TypedArray, timeMs: number) { + kernelName: string, result: Tensor, vals: TypedArray, timeMs: number) { const time = util.rightPad(`${timeMs}ms`, 9); const paddedName = util.rightPad(kernelName, 25); const rank = result.rank; diff --git a/src/math/backends/webgl/mulmat_gpu.ts b/src/math/backends/webgl/mulmat_gpu.ts index fd4aa607a5..a6745e1ab5 100644 --- a/src/math/backends/webgl/mulmat_gpu.ts +++ b/src/math/backends/webgl/mulmat_gpu.ts @@ -15,8 +15,7 @@ * ============================================================================= */ -import {MatrixOrientation} from '../types/matmul'; - +import {MatrixOrientation} from '../../matmul'; import {GPGPUProgram} from './gpgpu_math'; export class MatMulProgram implements GPGPUProgram { diff --git a/src/math/backends/webgl/mulmat_packed_gpu.ts b/src/math/backends/webgl/mulmat_packed_gpu.ts index df0998cab7..293ace12a7 100644 --- a/src/math/backends/webgl/mulmat_packed_gpu.ts +++ b/src/math/backends/webgl/mulmat_packed_gpu.ts @@ -15,8 +15,7 @@ * ============================================================================= */ -import {MatrixOrientation} from 
'../types/matmul'; - +import {MatrixOrientation} from '../../matmul'; import {GPGPUContext} from './gpgpu_context'; import * as webgl_util from './webgl_util'; diff --git a/src/math/backends/webgl/mulmat_packed_gpu_test.ts b/src/math/backends/webgl/mulmat_packed_gpu_test.ts index eedf3d6f3a..e30db824ff 100644 --- a/src/math/backends/webgl/mulmat_packed_gpu_test.ts +++ b/src/math/backends/webgl/mulmat_packed_gpu_test.ts @@ -17,7 +17,7 @@ // tslint:disable-next-line:max-line-length import {expectArraysClose, expectNumbersClose} from '../../../test_util'; -import {MatrixOrientation} from '../types/matmul'; +import {MatrixOrientation} from '../../matmul'; import {GPGPUContext} from './gpgpu_context'; import * as mulmat_packed_gpu from './mulmat_packed_gpu'; diff --git a/src/math/binary_ops.ts b/src/math/binary_ops.ts index b5020ce594..278cb0815a 100644 --- a/src/math/binary_ops.ts +++ b/src/math/binary_ops.ts @@ -136,30 +136,26 @@ export class Ops { */ @doc({heading: 'Operations', subheading: 'Arithmetic'}) @operation - static pow(base: Tensor, exp: Tensor): T { + static pow(base: T, exp: Tensor): T { util.assert( exp.dtype === 'int32', 'only supports int32 data type for the exponent parameter.'); broadcast_util.assertAndGetBroadcastShape(base.shape, exp.shape); - const gradient = (dy: Tensor, y: Tensor) => { + const grad = (dy: T, y: T) => { if (!util.arraysEqual(base.shape, exp.shape)) { throw new Error( `Gradient of pow not yet supported for broadcasted shapes.`); } const derBase = () => { - const dx = - exp.toFloat().mul(base.pow(exp.sub(scalar(1, 'int32'))).toFloat()); - return dy.mul(dx); + const dx = exp.toFloat().mul( + base.pow(exp.sub(scalar(1, 'int32'))).toFloat()) as T; + return dy.mulStrict(dx); }; - const derExp = () => { - throw new Error(`Backprop through exponent not implemented yet.`); - }; - return {base: derBase, exp: derExp}; + return {base: derBase}; }; - - return ENV.engine.executeKernel('Pow', {inputs: {base, exp}}, gradient) as - T; + return ENV.engine.runKernel( + backend => backend.pow(base, exp), {base}, grad); } /** diff --git a/src/math/logical_ops.ts b/src/math/logical_ops.ts index e13870fe14..f777b35467 100644 --- a/src/math/logical_ops.ts +++ b/src/math/logical_ops.ts @@ -32,9 +32,9 @@ export class Ops { */ @doc({heading: 'Operations', subheading: 'Logical'}) @operation - static logicalNot(x: Tensor): T { + static logicalNot(x: T): T { util.assert(x.dtype === 'bool', 'Error Array must be of type bool.'); - return ENV.engine.executeKernel('LogicalNot', {inputs: {x}}) as T; + return ENV.engine.runKernel(backend => backend.logicalNot(x)); } /** diff --git a/src/math/math.ts b/src/math/math.ts index 0015feaefa..b93953c146 100644 --- a/src/math/math.ts +++ b/src/math/math.ts @@ -281,7 +281,7 @@ export class NDArrayMath { /** @deprecated Use dl.transpose() instead. */ switchDim(x: Tensor, perm?: number[]): Tensor { - return ops.transpose(x, perm); + return ops.transpose(x, perm); } /** @deprecated Use dl.add(c, A) instead. 
*/ diff --git a/src/math/math_test.ts b/src/math/math_test.ts index fc84a3fa74..dce64ad299 100644 --- a/src/math/math_test.ts +++ b/src/math/math_test.ts @@ -20,7 +20,7 @@ import * as dl from '../index'; import {ALL_ENVS, describeWithFlags, expectArraysClose, expectArraysEqual, expectNumbersClose} from '../test_util'; import {Gradients} from './backends/gradients'; -import {MatrixOrientation} from './backends/types/matmul'; +import {MatrixOrientation} from './matmul'; import {Scalar, Tensor} from './tensor'; const gradientsScope = Gradients.gradientsScope; diff --git a/src/math/tensor.ts b/src/math/tensor.ts index b498926381..d63a3a3a79 100644 --- a/src/math/tensor.ts +++ b/src/math/tensor.ts @@ -17,7 +17,7 @@ import {ENV} from '../environment'; import * as util from '../util'; -import {MatrixOrientation} from './backends/types/matmul'; +import {MatrixOrientation} from './matmul'; import * as ops from './ops'; import {RandNormalDataTypes} from './rand'; // tslint:disable-next-line:max-line-length @@ -484,7 +484,7 @@ export class Tensor { this.throwIfDisposed(); return ops.subStrict(this, x); } - pow(exp: Tensor): T { + pow(this: T, exp: Tensor): T { this.throwIfDisposed(); return ops.pow(this, exp); } @@ -524,7 +524,7 @@ export class Tensor { this.throwIfDisposed(); return ops.maximumStrict(this, x); } - transpose(perm?: number[]): Tensor { + transpose(this: T, perm?: number[]): T { this.throwIfDisposed(); return ops.transpose(this, perm); } @@ -599,35 +599,35 @@ export class Tensor { } // Unary ops. - neg(): Tensor { + neg(this: T): T { this.throwIfDisposed(); return ops.neg(this); } - ceil(): Tensor { + ceil(this: T): T { this.throwIfDisposed(); return ops.ceil(this); } - floor(): Tensor { + floor(this: T): T { this.throwIfDisposed(); return ops.floor(this); } - exp(): Tensor { + exp(this: T): T { this.throwIfDisposed(); return ops.exp(this); } - log(): Tensor { + log(this: T): T { this.throwIfDisposed(); return ops.log(this); } - sqrt(): Tensor { + sqrt(this: T): T { this.throwIfDisposed(); return ops.sqrt(this); } - square(): Tensor { + square(this: T): T { this.throwIfDisposed(); return ops.square(this); } - abs(): Tensor { + abs(this: T): T { this.throwIfDisposed(); return ops.abs(this); } @@ -635,15 +635,15 @@ export class Tensor { this.throwIfDisposed(); return ops.clip(this, min, max); } - relu(): Tensor { + relu(this: T): T { this.throwIfDisposed(); return ops.relu(this); } - elu(): Tensor { + elu(this: T): T { this.throwIfDisposed(); return ops.elu(this); } - selu(): Tensor { + selu(this: T): T { this.throwIfDisposed(); return ops.selu(this); } @@ -655,47 +655,47 @@ export class Tensor { this.throwIfDisposed(); return ops.prelu(this, alpha); } - sigmoid(): Tensor { + sigmoid(this: T): T { this.throwIfDisposed(); return ops.sigmoid(this); } - sin(): Tensor { + sin(this: T): T { this.throwIfDisposed(); return ops.sin(this); } - cos(): Tensor { + cos(this: T): T { this.throwIfDisposed(); return ops.cos(this); } - tan(): Tensor { + tan(this: T): T { this.throwIfDisposed(); return ops.tan(this); } - asin(): Tensor { + asin(this: T): T { this.throwIfDisposed(); return ops.asin(this); } - acos(): Tensor { + acos(this: T): T { this.throwIfDisposed(); return ops.acos(this); } - atan(): Tensor { + atan(this: T): T { this.throwIfDisposed(); return ops.atan(this); } - sinh(): Tensor { + sinh(this: T): T { this.throwIfDisposed(); return ops.sinh(this); } - cosh(): Tensor { + cosh(this: T): T { this.throwIfDisposed(); return ops.cosh(this); } - tanh(): Tensor { + tanh(this: T): T { 
this.throwIfDisposed(); return ops.tanh(this); } - step(alpha = 0.0): Tensor { + step(this: T, alpha = 0.0): T { this.throwIfDisposed(); return ops.step(this, alpha); } diff --git a/src/math/transpose.ts b/src/math/transpose.ts index 0de60da228..b1a03e87b9 100644 --- a/src/math/transpose.ts +++ b/src/math/transpose.ts @@ -21,7 +21,6 @@ import * as util from '../util'; import * as axis_util from './axis_util'; import {doc, operation} from './decorators'; import {Tensor} from './tensor'; -import {Rank} from './types'; export class Ops { /** @@ -37,20 +36,19 @@ export class Ops { */ @doc({heading: 'Operations', subheading: 'Matrices'}) @operation - static transpose(x: Tensor, perm?: number[]): Tensor { + static transpose(x: T, perm?: number[]): T { if (perm == null) { perm = x.shape.map((s, i) => i).reverse(); } - const der = (dy: Tensor) => { + const der = (dy: T) => { const undoPerm = axis_util.getUndoAxesPermutation(perm); - const derX = () => dy.transpose(undoPerm); - return {x: derX}; + return {x: () => dy.transpose(undoPerm)}; }; util.assert( x.rank === perm.length, `Error in transpose: rank of input ${x.rank} ` + `must match length of perm ${perm}.`); - return ENV.engine.executeKernel( - 'Transpose', {inputs: {x}, args: {perm}}, der) as Tensor; + return ENV.engine.runKernel( + backend => backend.transpose(x, perm), {x}, der); } } diff --git a/src/math/unary_ops.ts b/src/math/unary_ops.ts index 762c71c945..9ee34042d0 100644 --- a/src/math/unary_ops.ts +++ b/src/math/unary_ops.ts @@ -16,11 +16,11 @@ */ import {ENV} from '../environment'; -import {zerosLike} from './ops'; import * as util from '../util'; import {doc, operation} from './decorators'; import * as ops from './ops'; +import {zerosLike} from './ops'; import * as selu_util from './selu_util'; import {Tensor} from './tensor'; @@ -32,9 +32,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static neg(x: T): T { - return ENV.engine.executeKernel('Neg', {inputs: {x}}, (dy: T, y: T) => { + const grad = (dy: T, y: T) => { return {x: () => dy.neg()}; - }) as T; + }; + return ENV.engine.runKernel(backend => backend.neg(x), {x}, grad); } /** @@ -46,10 +47,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static ceil(x: T): T { - const gradient = (dy: T, y: T) => { - return {x: () => ops.zeros(y.shape)}; + const grad = (dy: T, y: T) => { + return {x: () => ops.zerosLike(y)}; }; - return ENV.engine.executeKernel('Ceil', {inputs: {x}}, gradient) as T; + return ENV.engine.runKernel(backend => backend.ceil(x), {x}, grad); } /** @@ -60,10 +61,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static floor(x: T): T { - const gradient = (dy: T, y: T) => { - return {x: () => ops.zeros(y.shape)}; + const grad = (dy: T, y: T) => { + return {x: () => ops.zerosLike(y)}; }; - return ENV.engine.executeKernel('Floor', {inputs: {x}}, gradient) as T; + return ENV.engine.runKernel(backend => backend.floor(x), {x}, grad); } /** @@ -73,9 +74,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static exp(x: T): T { - return ENV.engine.executeKernel('Exp', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.mul(y)}; - }) as T; + const grad = (dy: T, y: T) => { + return {x: () => dy.mulStrict(y)}; + }; + return ENV.engine.runKernel(backend => backend.exp(x), {x}, grad); } /** @@ -85,9 +87,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static log(x: T): 
T { - return ENV.engine.executeKernel('Log', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.div(x.toFloat())}; - }) as T; + const grad = (dy: T, y: T) => { + return {x: () => dy.divStrict(x.toFloat())}; + }; + return ENV.engine.runKernel(backend => backend.log(x), {x}, grad); } /** @@ -97,9 +100,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static sqrt(x: T): T { - return ENV.engine.executeKernel('Sqrt', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.div(x.toFloat().sqrt().mul(ops.scalar(2)))}; - }) as T; + const grad = (dy: T, y: T) => { + return {x: () => dy.divStrict(x.toFloat().sqrt().mul(ops.scalar(2)))}; + }; + return ENV.engine.runKernel(backend => backend.sqrt(x), {x}, grad); } /** @@ -110,9 +114,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static square(x: T): T { - return ENV.engine.executeKernel('Square', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.mul(x.toFloat().mul(ops.scalar(2)))}; - }) as T; + const grad = (dy: T, y: T) => { + return {x: () => dy.mulStrict(x.toFloat().mul(ops.scalar(2)))}; + }; + return ENV.engine.runKernel(backend => backend.square(x), {x}, grad); } /** @@ -122,9 +127,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static abs(x: T): T { - return ENV.engine.executeKernel('Abs', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.mul(x.toFloat().step(-1))}; - }) as T; + const grad = (dy: T, y: T) => { + return {x: () => dy.mulStrict(x.toFloat().step(-1))}; + }; + return ENV.engine.runKernel(backend => backend.abs(x), {x}, grad); } /** @@ -140,15 +146,18 @@ export class Ops { (min <= max), `Error in clip: min (${min}) must be` + `less than or equal to max (${max}).`); - return ENV.engine.executeKernel( - 'Clip', {inputs: {x}, args: {min, max}}, (dy: T, y: T) => { + const grad = (dy: T, y: T) => { return { - // TODO(cais): Fix gradients for the case where x = min or x = max. - x: () => dy.where( - x.greater(ops.scalar(min)).logicalAnd(x.less(ops.scalar(max))), - zerosLike(dy)), + // TODO(cais): Fix gradients for the case where x = min or x + // = max. 
+ x: () => + dy.where( + x.greater(ops.scalar(min)).logicalAnd(x.less(ops.scalar(max))), + zerosLike(dy)) as T, }; - }) as T; + }; + return ENV.engine.runKernel( + backend => backend.clip(x, min, max), {x}, grad); } /** @@ -158,10 +167,11 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static relu(x: T): T { - return ENV.engine.executeKernel('Relu', {inputs: {x}}, (dy: T, y: T) => { - const stepRes = x.step() as Tensor; - return {x: () => dy.mul(stepRes.toFloat())}; - }) as T; + const grad = (dy: T, y: T) => { + const stepRes = x.step(); + return {x: () => dy.mulStrict(stepRes.toFloat())}; + }; + return ENV.engine.runKernel(backend => backend.relu(x), {x}, grad); } /** @@ -171,17 +181,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static elu(x: T): T { - const der = (dy: Tensor) => { - return { - x: () => dy.mul(eluDer(x)), - alpha: () => { - throw new Error( - 'Derivative of prelu with respect to alpha is ' + - 'not implemented yet'); - } - }; + const grad = (dy: T) => { + return {x: () => dy.mulStrict(eluDer(x))}; }; - return ENV.engine.executeKernel('Elu', {inputs: {x}}, der) as T; + return ENV.engine.runKernel(backend => backend.elu(x), {x}, grad); } /** @@ -191,7 +194,7 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static selu(x: T): T { - const gradient = (dy: T, y: T) => { + const grad = (dy: T, y: T) => { return { x: () => { // Currently, Scalars are not supported by ops.where @@ -204,13 +207,11 @@ export class Ops { const greaterThanZeroDer = dy.mul(scale); const lessEqualZeroDer = dy.mul(scaleAlpha).mul(x.toFloat().exp()); - const res = ops.where(mask, greaterThanZeroDer, lessEqualZeroDer); - - return res; + return ops.where(mask, greaterThanZeroDer, lessEqualZeroDer) as T; } }; }; - return ENV.engine.executeKernel('Selu', {inputs: {x}}, gradient) as T; + return ENV.engine.runKernel(backend => backend.selu(x), {x}, grad); } /** @@ -222,8 +223,7 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static leakyRelu(x: T, alpha = 0.2): T { - return ENV.engine.executeKernel( - 'LeakyRelu', {inputs: {x}, args: {alpha}}) as T; + return ENV.engine.runKernel(backend => backend.leakyRelu(x, alpha)); } /** @@ -235,17 +235,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static prelu(x: T, alpha: T): T { - const der = (dy: Tensor) => { - return { - x: () => dy.mul(preluDer(x, alpha)), - alpha: () => { - throw new Error( - 'Derivative of prelu with respect to alpha is ' + - 'not implemented yet'); - } - }; + const grad = (dy: T) => { + return {x: () => dy.mulStrict(preluDer(x, alpha))}; }; - return ENV.engine.executeKernel('PReLU', {inputs: {x, alpha}}, der) as T; + return ENV.engine.runKernel(backend => backend.prelu(x, alpha), {x}, grad); } /** @@ -255,9 +248,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static sigmoid(x: T): T { - return ENV.engine.executeKernel('Sigmoid', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.mul(y.mul(ops.scalar(1).sub(y)))}; - }) as T; + const grad = (dy: T, y: T) => { + return {x: () => dy.mulStrict(y.mul(ops.scalar(1).sub(y)))}; + }; + return ENV.engine.runKernel(backend => backend.sigmoid(x), {x}, grad); } /** @@ -269,9 +263,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static sin(x: T): T { - return ENV.engine.executeKernel('Sin', {inputs: {x}}, (dy: T, 
y: T) => { - return {x: () => x.toFloat().cos().mul(dy)}; - }) as T; + const grad = (dy: T, y: T) => { + return {x: () => x.toFloat().cos().mulStrict(dy)}; + }; + return ENV.engine.runKernel(backend => backend.sin(x), {x}, grad); } /** @@ -281,9 +276,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static cos(x: T): T { - return ENV.engine.executeKernel('Cos', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => x.toFloat().sin().neg().mul(dy)}; - }) as T; + const grad = (dy: T, y: T) => { + return {x: () => x.toFloat().sin().neg().mulStrict(dy)}; + }; + return ENV.engine.runKernel(backend => backend.cos(x), {x}, grad); } /** @@ -293,9 +289,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static tan(x: T): T { - return ENV.engine.executeKernel('Tan', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.div(x.cos().square())}; - }) as T; + const grad = (dy: T, y: T) => { + return {x: () => dy.divStrict(x.cos().square())}; + }; + return ENV.engine.runKernel(backend => backend.tan(x), {x}, grad); } /** @@ -305,11 +302,12 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static asin(x: T): T { - return ENV.engine.executeKernel('Asin', {inputs: {x}}, (dy: T, y: T) => { + const grad = (dy: T, y: T) => { return { - x: () => dy.div(Ops.sqrt(ops.scalar(1).sub(x.toFloat().square()))) + x: () => dy.divStrict(Ops.sqrt(ops.scalar(1).sub(x.toFloat().square()))) }; - }) as T; + }; + return ENV.engine.runKernel(backend => backend.asin(x), {x}, grad); } /** @@ -319,11 +317,13 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static acos(x: T): T { - return ENV.engine.executeKernel('Acos', {inputs: {x}}, (dy: T, y: T) => { + const grad = (dy: T, y: T) => { return { - x: () => dy.div(Ops.sqrt(ops.scalar(1).sub(x.toFloat().square()))).neg() + x: () => dy.divStrict(Ops.sqrt(ops.scalar(1).sub(x.toFloat().square()))) + .neg() }; - }) as T; + }; + return ENV.engine.runKernel(backend => backend.acos(x), {x}, grad); } /** @@ -333,9 +333,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static atan(x: T): T { - return ENV.engine.executeKernel('Atan', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.div(ops.scalar(1).add(x.toFloat().square()))}; - }) as T; + const grad = (dy: T, y: T) => { + return {x: () => dy.divStrict(ops.scalar(1).add(x.toFloat().square()))}; + }; + return ENV.engine.runKernel(backend => backend.atan(x), {x}, grad); } /** @@ -345,9 +346,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static sinh(x: T): T { - return ENV.engine.executeKernel('Sinh', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => x.toFloat().cosh().mul(dy)}; - }) as T; + const grad = (dy: T, y: T) => { + return {x: () => x.toFloat().cosh().mulStrict(dy)}; + }; + return ENV.engine.runKernel(backend => backend.sinh(x), {x}, grad); } /** @@ -357,9 +359,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static cosh(x: T): T { - return ENV.engine.executeKernel('Cosh', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => x.toFloat().sinh().mul(dy)}; - }) as T; + const grad = (dy: T, y: T) => { + return {x: () => x.toFloat().sinh().mulStrict(dy)}; + }; + return ENV.engine.runKernel(backend => backend.cosh(x), {x}, grad); } /** @@ -369,9 +372,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static 
tanh(x: T): T { - return ENV.engine.executeKernel('Tanh', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => ops.scalar(1).sub(y.square()).mul(dy)}; - }) as T; + const grad = (dy: T, y: T) => { + return {x: () => ops.scalar(1).sub(y.square()).mulStrict(dy) as T}; + }; + return ENV.engine.runKernel(backend => backend.tanh(x), {x}, grad); } /** @@ -384,14 +388,14 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static step(x: T, alpha = 0.0): T { - return ENV.engine.executeKernel('Step', {inputs: {x}, args: {alpha}}) as T; + return ENV.engine.runKernel(backend => backend.step(x, alpha)); } } -function preluDer(x: Tensor, alpha: Tensor): Tensor { - return ENV.engine.executeKernel('PReLUDer', {inputs: {x, alpha}}) as Tensor; +function preluDer(x: T, alpha: T): T { + return ENV.engine.runKernel(backend => backend.preluDer(x, alpha)); } -function eluDer(x: Tensor): Tensor { - return ENV.engine.executeKernel('EluDer', {inputs: {x}}) as Tensor; +function eluDer(x: T): T { + return ENV.engine.runKernel(backend => backend.eluDer(x)); } From 949432783769e6c04531a38c55155a382ae274cf Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sat, 10 Feb 2018 18:31:13 -0500 Subject: [PATCH 04/14] save --- src/math/array_ops.ts | 17 ++++--- src/math/backends/backend.ts | 4 +- src/math/backends/backend_cpu.ts | 4 +- src/math/backends/backend_engine.ts | 21 ++++++--- src/math/backends/backend_webgl.ts | 6 +-- src/math/binary_ops.ts | 2 +- src/math/matmul.ts | 2 +- src/math/unary_ops.ts | 73 ++++++++++++++++++----------- 8 files changed, 78 insertions(+), 51 deletions(-) diff --git a/src/math/array_ops.ts b/src/math/array_ops.ts index 908d9410fb..ab5bf01957 100644 --- a/src/math/array_ops.ts +++ b/src/math/array_ops.ts @@ -17,6 +17,8 @@ import {ENV} from '../environment'; import * as util from '../util'; + +import {ForwardFunc} from './backends/backend_engine'; import {doc, operation} from './decorators'; import {MPRandGauss, RandNormalDataTypes} from './rand'; // tslint:disable-next-line:max-line-length @@ -450,7 +452,7 @@ export class Ops { x.size === util.sizeFromShape(shape), 'new shape and old shape must have the same number of elements.'); - const grad = (dy: Tensor, y: Tensor) => { + const grad = (dy: Tensor) => { return {x: () => dy.reshape(x.shape)}; }; return ENV.engine.runKernel( @@ -465,10 +467,7 @@ export class Ops { @doc({heading: 'Tensors', subheading: 'Transformations'}) @operation static cast(x: T, dtype: DataType): T { - const grad = (dy: T, y: T) => { - return {x: () => dy.clone()}; - }; - return ENV.engine.runKernel(backend => { + const forw: ForwardFunc = backend => { if (!util.hasEncodingLoss(x.dtype, dtype)) { // We don't change the underlying data, since we cast to higher // precision. 
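// For orientation, here is the op-level pattern these patches converge on,
// as a commented sketch (illustrative, with explicit generic bounds; it is
// not part of the diff itself):
//
//   function square<T extends Tensor>(x: T): T {
//     // d(x^2)/dx = 2x. mulStrict is the shape-checked variant of mul, so a
//     // mismatched dy fails loudly instead of broadcasting silently.
//     const grad = (dy: T) =>
//         ({x: () => dy.mulStrict(x.toFloat().mul(ops.scalar(2)))});
//     return ENV.engine.runKernel(backend => backend.square(x), {x}, grad);
//   }
//
// The forward pass closes over the backend, inputs travel as a named map, and
// the gradient is an inline map from input name to a gradient function. This
// is what later allows the per-kernel tape-node types to be deleted.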
@@ -477,11 +476,15 @@ export class Ops { if (dtype === 'int32') { return backend.int(x); } else if (dtype === 'bool') { - return backend.notEqual(x, Ops.scalar(0, x.dtype)); + return backend.notEqual(x, Ops.scalar(0, x.dtype)) as T; } else { throw new Error(`Error in Cast: unknown dtype argument (${dtype})`); } - }, {x}, grad) as T; + }; + const grad = (dy: T) => { + return {x: () => dy.clone()}; + }; + return ENV.engine.runKernel(forw, {x}, grad) as T; } /** diff --git a/src/math/backends/backend.ts b/src/math/backends/backend.ts index cd02d22bdd..978a97810f 100644 --- a/src/math/backends/backend.ts +++ b/src/math/backends/backend.ts @@ -20,7 +20,7 @@ import {Conv2DInfo} from '../conv_util'; import {MatrixOrientation} from '../matmul'; // tslint:disable-next-line:max-line-length import {DataId, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'; -import {DataType, Rank, TypedArray} from '../types'; +import {DataType, TypedArray} from '../types'; export interface TensorStorage { read(dataId: DataId): Promise; @@ -126,7 +126,7 @@ export interface MathBackend extends TensorStorage, BackendTimer { leakyRelu(x: T, alpha: number): T; prelu(x: T, alpha: T): T; preluDer(x: T, alpha: T): T; - int(x: Tensor): Tensor; + int(x: T): T; clip(x: T, min: number, max: number): T; diff --git a/src/math/backends/backend_cpu.ts b/src/math/backends/backend_cpu.ts index 6a65153c45..e4e6a02ae2 100644 --- a/src/math/backends/backend_cpu.ts +++ b/src/math/backends/backend_cpu.ts @@ -798,13 +798,13 @@ export class MathBackendCPU implements MathBackend { return Tensor.make(x.shape, {values: resultValues}) as T; } - int(x: Tensor): Tensor { + int(x: T): T { const resultValues = new Int32Array(x.size); const values = x.dataSync(); for (let i = 0; i < values.length; ++i) { resultValues[i] = values[i]; } - return Tensor.make(x.shape, {values: resultValues}, 'int32') as Tensor; + return Tensor.make(x.shape, {values: resultValues}, 'int32'); } sigmoid(x: T): T { diff --git a/src/math/backends/backend_engine.ts b/src/math/backends/backend_engine.ts index e36c86cc7e..5e339d7537 100644 --- a/src/math/backends/backend_engine.ts +++ b/src/math/backends/backend_engine.ts @@ -20,9 +20,9 @@ import {tidy} from '../../globals'; import * as util from '../../util'; import * as ops from '../ops'; import {DataId, Tensor, Tensor3D, Variable} from '../tensor'; +// tslint:disable-next-line:max-line-length import {NamedTensorMap, NamedVariableMap, TypedArray} from '../types'; import {Rank} from '../types'; - import {MathBackend} from './backend'; import * as kernel_registry from './kernel_registry'; import {Kernel, KernelConfigRegistry} from './kernel_registry'; @@ -37,6 +37,9 @@ interface ScopeState { track: Tensor[]; } +export type ForwardFunc = + (backend: MathBackend, save?: (map: NamedTensorMap) => void) => T; + export type CustomGradientFunc = () => { value: T, gradients: (dy: T, y: T) => TapeNodeInputGradientTensors }; @@ -82,16 +85,20 @@ export class BackendEngine implements TensorManager { } runKernel( - kernelFn: (backend: MathBackend) => T, inputs?: I, - grad?: (dy: T, y: T) => {[P in keyof I]: () => I[P]}): T { + forwardFunc: ForwardFunc, + inputs?: I, + backwardsFunc?: + (dy: T, saved: NamedTensorMap) => {[P in keyof I]: () => I[P]}, + ): T { let result: T; // TODO(smilkov): Figure out kernel name. 
const kernelName = '' as Kernel; + let saved: NamedTensorMap = null; if (!ENV.get('DEBUG')) { - result = kernelFn(this.backend); + result = forwardFunc(this.backend, x => saved = x); } else { - result = - this.profiler.profileKernel(kernelName, () => kernelFn(this.backend)); + result = this.profiler.profileKernel( + kernelName, () => forwardFunc(this.backend, x => saved = x)); } const recordKernel = @@ -104,7 +111,7 @@ export class BackendEngine implements TensorManager { kernel: kernelName, inputAndArgs: {inputs}, output: result, - gradient: grad + gradient: (dy: T) => backwardsFunc(dy, saved) }; this.activeTape.push(evaluatedNode); } diff --git a/src/math/backends/backend_webgl.ts b/src/math/backends/backend_webgl.ts index ca6736ee06..a7afc7c73c 100644 --- a/src/math/backends/backend_webgl.ts +++ b/src/math/backends/backend_webgl.ts @@ -26,7 +26,7 @@ import * as reduce_util from '../reduce_util'; import {DataId, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'; import * as types from '../types'; // tslint:disable-next-line:max-line-length -import {DataType, DataTypeMap, Rank, RecursiveArray, TypedArray} from '../types'; +import {DataType, DataTypeMap, RecursiveArray, TypedArray} from '../types'; import {MathBackend} from './backend'; import {ArgMinMaxProgram} from './webgl/argminmax_gpu'; @@ -729,10 +729,10 @@ export class MathBackendWebGL implements MathBackend { return this.compileAndRun(program, [a, b]) as T; } - int(x: Tensor): Tensor { + int(x: T): T { const program = new UnaryOpProgram(x.shape, unary_op.TO_INT); const output = this.makeOutputArray(program.outputShape, 'int32'); - return this.compileAndRun(program, [x], output) as Tensor; + return this.compileAndRun(program, [x], output) as T; } clip(x: T, min: number, max: number): T { diff --git a/src/math/binary_ops.ts b/src/math/binary_ops.ts index 278cb0815a..562db8fc10 100644 --- a/src/math/binary_ops.ts +++ b/src/math/binary_ops.ts @@ -142,7 +142,7 @@ export class Ops { 'only supports int32 data type for the exponent parameter.'); broadcast_util.assertAndGetBroadcastShape(base.shape, exp.shape); - const grad = (dy: T, y: T) => { + const grad = (dy: T) => { if (!util.arraysEqual(base.shape, exp.shape)) { throw new Error( `Gradient of pow not yet supported for broadcasted shapes.`); diff --git a/src/math/matmul.ts b/src/math/matmul.ts index 0a93851758..5c6e976ace 100644 --- a/src/math/matmul.ts +++ b/src/math/matmul.ts @@ -60,7 +60,7 @@ export class Ops { `${b.shape} and orientations ${MatrixOrientation[transposeA]}` + ` and ${MatrixOrientation[transposeB]} must match.`); - const grad = (dy: Tensor2D, y: Tensor2D) => { + const grad = (dy: Tensor2D) => { if (transposeA === TRANSPOSED || transposeB === TRANSPOSED) { throw new Error(`Backprop for transposed MatMul not yet implemented.`); } diff --git a/src/math/unary_ops.ts b/src/math/unary_ops.ts index 9ee34042d0..97bdb9c2ef 100644 --- a/src/math/unary_ops.ts +++ b/src/math/unary_ops.ts @@ -17,7 +17,7 @@ import {ENV} from '../environment'; import * as util from '../util'; - +import {ForwardFunc} from './backends/backend_engine'; import {doc, operation} from './decorators'; import * as ops from './ops'; import {zerosLike} from './ops'; @@ -32,7 +32,7 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static neg(x: T): T { - const grad = (dy: T, y: T) => { + const grad = (dy: T) => { return {x: () => dy.neg()}; }; return ENV.engine.runKernel(backend => backend.neg(x), {x}, grad); @@ -47,8 +47,8 @@ export class Ops { 
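// Runtime outline of the backend_engine.ts change above, as a paraphrase
// (condensed from the hunk, not additional diff content):
//
//   let saved: NamedTensorMap = null;
//   const result = forwardFunc(this.backend, x => saved = x);  // forward
//   if (this.activeTape != null && this.customGradientDepth === 0) {
//     this.activeTape.push(
//         {..., output: result, gradient: (dy) => backwardsFunc(dy, saved)});
//   }
//
// A gradient now sees exactly the tensors its own forward pass chose to
// save, instead of always being handed the kernel output y.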
@doc({heading: 'Operations', subheading: 'Basic math'})
   @operation
   static ceil(x: T): T {
-    const grad = (dy: T, y: T) => {
-      return {x: () => ops.zerosLike(y)};
+    const grad = (dy: T) => {
+      return {x: () => ops.zerosLike(dy)};
     };
     return ENV.engine.runKernel(backend => backend.ceil(x), {x}, grad);
   }
@@ -61,8 +61,8 @@ export class Ops {
   @doc({heading: 'Operations', subheading: 'Basic math'})
   @operation
   static floor(x: T): T {
-    const grad = (dy: T, y: T) => {
-      return {x: () => ops.zerosLike(y)};
+    const grad = (dy: T) => {
+      return {x: () => ops.zerosLike(dy)};
     };
     return ENV.engine.runKernel(backend => backend.floor(x), {x}, grad);
   }
@@ -74,10 +74,15 @@ export class Ops {
   @doc({heading: 'Operations', subheading: 'Basic math'})
   @operation
   static exp(x: T): T {
-    const grad = (dy: T, y: T) => {
-      return {x: () => dy.mulStrict(y)};
+    const forw: ForwardFunc = (backend, save) => {
+      const y = backend.exp(x);
+      save({y});
+      return y;
+    };
+    const bck = (dy: T, saved: {y: T}) => {
+      return {x: () => dy.mulStrict(saved.y)};
     };
-    return ENV.engine.runKernel(backend => backend.exp(x), {x}, grad);
+    return ENV.engine.runKernel(forw, {x}, bck);
   }

   /**
@@ -87,7 +92,7 @@ export class Ops {
   @doc({heading: 'Operations', subheading: 'Basic math'})
   @operation
   static log(x: T): T {
-    const grad = (dy: T, y: T) => {
+    const grad = (dy: T) => {
       return {x: () => dy.divStrict(x.toFloat())};
     };
     return ENV.engine.runKernel(backend => backend.log(x), {x}, grad);
@@ -100,7 +105,7 @@
   @doc({heading: 'Operations', subheading: 'Basic math'})
   @operation
   static sqrt(x: T): T {
-    const grad = (dy: T, y: T) => {
+    const grad = (dy: T) => {
       return {x: () => dy.divStrict(x.toFloat().sqrt().mul(ops.scalar(2)))};
     };
     return ENV.engine.runKernel(backend => backend.sqrt(x), {x}, grad);
@@ -114,7 +119,7 @@
   @doc({heading: 'Operations', subheading: 'Basic math'})
   @operation
   static square(x: T): T {
-    const grad = (dy: T, y: T) => {
+    const grad = (dy: T) => {
       return {x: () => dy.mulStrict(x.toFloat().mul(ops.scalar(2)))};
     };
     return ENV.engine.runKernel(backend => backend.square(x), {x}, grad);
@@ -127,7 +132,7 @@
   @doc({heading: 'Operations', subheading: 'Basic math'})
   @operation
   static abs(x: T): T {
-    const grad = (dy: T, y: T) => {
+    const grad = (dy: T) => {
       return {x: () => dy.mulStrict(x.toFloat().step(-1))};
     };
     return ENV.engine.runKernel(backend => backend.abs(x), {x}, grad);
@@ -146,7 +151,7 @@
         (min <= max),
         `Error in clip: min (${min}) must be ` +
             `less than or equal to max (${max}).`);
-    const grad = (dy: T, y: T) => {
+    const grad = (dy: T) => {
       return {
         // TODO(cais): Fix gradients for the case where x = min or x
         // = max.
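// The exp hunk above is the first end-to-end use of `save`. The same pattern
// fits any op whose derivative is cheapest to state in terms of its own
// output. As a sketch, for a hypothetical backend.softplus kernel (assumed
// for illustration only), where y = ln(1 + e^x) and dy/dx = 1 - e^(-y):
//
//   const forw: ForwardFunc<T> = (backend, save) => {
//     const y = backend.softplus(x);
//     save({y});  // stash the forward output for the backward pass
//     return y;
//   };
//   const bck = (dy: T, saved: {y: T}) =>
//       ({x: () => dy.mulStrict(ops.scalar(1).sub(saved.y.neg().exp()))});
//   return ENV.engine.runKernel(forw, {x}, bck);
//
// Saving y trades a little memory for skipping a kernel relaunch in backprop.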
@@ -167,7 +172,7 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static relu(x: T): T { - const grad = (dy: T, y: T) => { + const grad = (dy: T) => { const stepRes = x.step(); return {x: () => dy.mulStrict(stepRes.toFloat())}; }; @@ -194,7 +199,7 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static selu(x: T): T { - const grad = (dy: T, y: T) => { + const grad = (dy: T) => { return { x: () => { // Currently, Scalars are not supported by ops.where @@ -248,10 +253,16 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static sigmoid(x: T): T { - const grad = (dy: T, y: T) => { + const forw: ForwardFunc = (backend, save) => { + const y = backend.sigmoid(x); + save({y}); + return y; + }; + const grad = (dy: T, saved: {y: T}) => { + const {y} = saved; return {x: () => dy.mulStrict(y.mul(ops.scalar(1).sub(y)))}; }; - return ENV.engine.runKernel(backend => backend.sigmoid(x), {x}, grad); + return ENV.engine.runKernel(forw, {x}, grad); } /** @@ -263,7 +274,7 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static sin(x: T): T { - const grad = (dy: T, y: T) => { + const grad = (dy: T) => { return {x: () => x.toFloat().cos().mulStrict(dy)}; }; return ENV.engine.runKernel(backend => backend.sin(x), {x}, grad); @@ -276,7 +287,7 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static cos(x: T): T { - const grad = (dy: T, y: T) => { + const grad = (dy: T) => { return {x: () => x.toFloat().sin().neg().mulStrict(dy)}; }; return ENV.engine.runKernel(backend => backend.cos(x), {x}, grad); @@ -289,7 +300,7 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static tan(x: T): T { - const grad = (dy: T, y: T) => { + const grad = (dy: T) => { return {x: () => dy.divStrict(x.cos().square())}; }; return ENV.engine.runKernel(backend => backend.tan(x), {x}, grad); @@ -302,7 +313,7 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static asin(x: T): T { - const grad = (dy: T, y: T) => { + const grad = (dy: T) => { return { x: () => dy.divStrict(Ops.sqrt(ops.scalar(1).sub(x.toFloat().square()))) }; @@ -317,7 +328,7 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static acos(x: T): T { - const grad = (dy: T, y: T) => { + const grad = (dy: T) => { return { x: () => dy.divStrict(Ops.sqrt(ops.scalar(1).sub(x.toFloat().square()))) .neg() @@ -333,7 +344,7 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static atan(x: T): T { - const grad = (dy: T, y: T) => { + const grad = (dy: T) => { return {x: () => dy.divStrict(ops.scalar(1).add(x.toFloat().square()))}; }; return ENV.engine.runKernel(backend => backend.atan(x), {x}, grad); @@ -346,7 +357,7 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static sinh(x: T): T { - const grad = (dy: T, y: T) => { + const grad = (dy: T) => { return {x: () => x.toFloat().cosh().mulStrict(dy)}; }; return ENV.engine.runKernel(backend => backend.sinh(x), {x}, grad); @@ -359,7 +370,7 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static cosh(x: T): T { - const grad = (dy: T, y: T) => { + const grad = (dy: T) => { return {x: () => x.toFloat().sinh().mulStrict(dy)}; }; return ENV.engine.runKernel(backend => backend.cosh(x), {x}, grad); @@ -372,10 +383,16 @@ export 
class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static tanh(x: T): T { - const grad = (dy: T, y: T) => { + const forw: ForwardFunc = (backend, save) => { + const y = backend.tanh(x); + save({y}); + return y; + }; + const grad = (dy: T, saved: {y: T}) => { + const {y} = saved; return {x: () => ops.scalar(1).sub(y.square()).mulStrict(dy) as T}; }; - return ENV.engine.runKernel(backend => backend.tanh(x), {x}, grad); + return ENV.engine.runKernel(forw, {x}, grad); } /** From e6bfc70cbb559d1f3b2a354c6d257a467d0e6f7c Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sat, 17 Feb 2018 19:36:15 -0500 Subject: [PATCH 05/14] save --- src/kernels/webgl/mulmat_packed_gpu.ts | 2 +- src/ops/array_ops.ts | 2 ++ src/ops/unary_ops.ts | 4 +++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/kernels/webgl/mulmat_packed_gpu.ts b/src/kernels/webgl/mulmat_packed_gpu.ts index 293ace12a7..efb35cf711 100644 --- a/src/kernels/webgl/mulmat_packed_gpu.ts +++ b/src/kernels/webgl/mulmat_packed_gpu.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {MatrixOrientation} from '../../matmul'; +import {MatrixOrientation} from '../types/matmul'; import {GPGPUContext} from './gpgpu_context'; import * as webgl_util from './webgl_util'; diff --git a/src/ops/array_ops.ts b/src/ops/array_ops.ts index 3aef611459..c9105c705d 100644 --- a/src/ops/array_ops.ts +++ b/src/ops/array_ops.ts @@ -16,12 +16,14 @@ */ import {doc} from '../doc'; +import {ForwardFunc} from '../engine'; import {ENV} from '../environment'; // tslint:disable-next-line:max-line-length import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer} from '../tensor'; // tslint:disable-next-line:max-line-length import {ArrayData, DataType, DataTypeMap, Rank, ShapeMap, TensorLike, TensorLike1D, TensorLike2D, TensorLike3D, TensorLike4D, TypedArray} from '../types'; import * as util from '../util'; + import {Concat} from './concat'; import {operation} from './operation'; import {MPRandGauss} from './rand'; diff --git a/src/ops/unary_ops.ts b/src/ops/unary_ops.ts index 35b63a2db9..981d97d140 100644 --- a/src/ops/unary_ops.ts +++ b/src/ops/unary_ops.ts @@ -16,9 +16,11 @@ */ import {doc} from '../doc'; +import {ForwardFunc} from '../engine'; import {ENV} from '../environment'; import {Tensor} from '../tensor'; import * as util from '../util'; + import {operation} from './operation'; import * as ops from './ops'; import {zerosLike} from './ops'; @@ -60,7 +62,7 @@ export class Ops { static ceil(x: T): T { // TODO(manrajgrover): Return null for gradients when backprop supports it. 
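// ceil is piecewise constant, so its derivative is zero wherever it is
// defined. dy always has the shape of the op output, which matches x for a
// unary op, and zerosLike(dy) also preserves dy's dtype, where
// zeros(dy.shape) would default to float32: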
const grad = (dy: T) => { - return {x: () => ops.zeros(dy.shape)}; + return {x: () => ops.zerosLike(dy)}; }; return ENV.engine.runKernel(backend => backend.ceil(x), {x}, grad); } From 3dc3b7b3d29c909dd7ea5809d81006dd25644e79 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sat, 17 Feb 2018 21:13:08 -0500 Subject: [PATCH 06/14] save --- src/kernels/types/argminmax.ts | 35 ---------------- src/kernels/types/batchnorm.ts | 39 ----------------- src/kernels/types/binary.ts | 28 ------------- src/kernels/types/conv.ts | 63 ---------------------------- src/kernels/types/logical.ts | 54 ------------------------ src/kernels/types/lrn.ts | 37 ---------------- src/kernels/types/minmax.ts | 55 ------------------------ src/kernels/types/multinomial.ts | 29 ------------- src/kernels/types/onehot.ts | 30 ------------- src/kernels/types/pad.ts | 43 ------------------- src/kernels/types/pool.ts | 40 ------------------ src/kernels/types/resize_bilinear.ts | 30 ------------- src/kernels/types/slice.ts | 62 --------------------------- 13 files changed, 545 deletions(-) delete mode 100644 src/kernels/types/argminmax.ts delete mode 100644 src/kernels/types/batchnorm.ts delete mode 100644 src/kernels/types/binary.ts delete mode 100644 src/kernels/types/conv.ts delete mode 100644 src/kernels/types/logical.ts delete mode 100644 src/kernels/types/lrn.ts delete mode 100644 src/kernels/types/minmax.ts delete mode 100644 src/kernels/types/multinomial.ts delete mode 100644 src/kernels/types/onehot.ts delete mode 100644 src/kernels/types/pad.ts delete mode 100644 src/kernels/types/pool.ts delete mode 100644 src/kernels/types/resize_bilinear.ts delete mode 100644 src/kernels/types/slice.ts diff --git a/src/kernels/types/argminmax.ts b/src/kernels/types/argminmax.ts deleted file mode 100644 index 466858c183..0000000000 --- a/src/kernels/types/argminmax.ts +++ /dev/null @@ -1,35 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; - -export interface ArgMaxNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor;}; args: {axes: number[];};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor; - }; -} - -export interface ArgMinNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor;}; args: {axes: number[];};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor; - }; -} diff --git a/src/kernels/types/batchnorm.ts b/src/kernels/types/batchnorm.ts deleted file mode 100644 index ed9a5f149c..0000000000 --- a/src/kernels/types/batchnorm.ts +++ /dev/null @@ -1,39 +0,0 @@ - -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor1D, Tensor4D} from '../../tensor'; - -export interface BatchNorm4DNode extends KernelNode { - inputAndArgs: { - inputs: { - x: Tensor4D; mean: Tensor4D | Tensor1D; variance: Tensor4D | Tensor1D; - scale?: Tensor4D | Tensor1D; - offset?: Tensor4D | Tensor1D; - }; - args: {varianceEpsilon: number}; - }; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - mean: () => Tensor4D | Tensor1D; - variance: () => Tensor4D | Tensor1D; - scale?: () => Tensor4D | Tensor1D; - offset?: () => Tensor4D | Tensor1D; - }; -} diff --git a/src/kernels/types/binary.ts b/src/kernels/types/binary.ts deleted file mode 100644 index c0dbb16712..0000000000 --- a/src/kernels/types/binary.ts +++ /dev/null @@ -1,28 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; - -export interface BinaryNode extends KernelNode { - inputAndArgs: {inputs: {a: Tensor; b: Tensor;};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - a: () => Tensor; - b: () => Tensor; - }; -} diff --git a/src/kernels/types/conv.ts b/src/kernels/types/conv.ts deleted file mode 100644 index 31b3daeb00..0000000000 --- a/src/kernels/types/conv.ts +++ /dev/null @@ -1,63 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import {Conv2DInfo} from '../../ops/conv_util'; -import {KernelNode} from '../../tape_types'; -import {Tensor4D} from '../../tensor'; - -export interface Conv2DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor4D; filter: Tensor4D;}; args: {convInfo: Conv2DInfo;}; - }; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - filter: () => Tensor4D; - }; -} - -export interface Conv2DDerInputNode extends KernelNode { - inputAndArgs: { - inputs: {dy: Tensor4D; filter: Tensor4D;}; args: {convInfo: Conv2DInfo;}; - }; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - dy: () => Tensor4D; - filter: () => Tensor4D; - }; -} - -export interface Conv2DDerFilterNode extends KernelNode { - inputAndArgs: - {inputs: {x: Tensor4D; dy: Tensor4D;}; args: {convInfo: Conv2DInfo;};}; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - dy: () => Tensor4D; - }; -} - -export interface DepthwiseConv2DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor4D; filter: Tensor4D;}; args: {convInfo: Conv2DInfo;}; - }; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - filter: () => Tensor4D; - }; -} diff --git a/src/kernels/types/logical.ts b/src/kernels/types/logical.ts deleted file mode 100644 index f88d9d3cb6..0000000000 --- a/src/kernels/types/logical.ts +++ /dev/null @@ -1,54 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; -import {DataType} from '../../types'; - -// Equal/NotEqual/Less/LessEqual/Greater/GreaterEqual -export interface EqualNode extends KernelNode { - inputAndArgs: {inputs: {a: Tensor; b: Tensor;};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - a: () => Tensor; - b: () => Tensor; - }; -} - -// LogicalAnd/LogicalOr/LogicalXor -export interface LogicalNode extends KernelNode { - inputAndArgs: {inputs: {a: Tensor; b: Tensor;};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - a: () => Tensor; - b: () => Tensor; - }; -} - -// Where -export interface WhereNode extends KernelNode { - inputAndArgs: { - inputs: {condition: Tensor; a: Tensor; b: Tensor;}; - args: {dtype: DataType}; - }; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - condition: () => Tensor; - a: () => Tensor; - b: () => Tensor; - }; -} diff --git a/src/kernels/types/lrn.ts b/src/kernels/types/lrn.ts deleted file mode 100644 index c6cad957b2..0000000000 --- a/src/kernels/types/lrn.ts +++ /dev/null @@ -1,37 +0,0 @@ - -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor4D} from '../../tensor'; - -// 4D -export interface LRN4DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor4D;}; args: { - radius: number, - bias: number, - alpha: number, - beta: number, - normRegion: 'acrossChannels'|'withinChannel' - }; - }; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - }; -} diff --git a/src/kernels/types/minmax.ts b/src/kernels/types/minmax.ts deleted file mode 100644 index a14d38f732..0000000000 --- a/src/kernels/types/minmax.ts +++ /dev/null @@ -1,55 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; - -// Reduction min. -export interface MinNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor;}; args: {axes: number[]}}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor; - }; -} - -// Element-wise min. -export interface MinimumNode extends KernelNode { - inputAndArgs: {inputs: {a: Tensor, b: Tensor};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - a: () => Tensor, b: () => Tensor - }; -} - -// Reduction Max -export interface MaxNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor;}; args: {axes: number[]}}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor; - }; -} - -// Element-wise max. -export interface MaximumNode extends KernelNode { - inputAndArgs: {inputs: {a: Tensor, b: Tensor};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - a: () => Tensor, b: () => Tensor - }; -} diff --git a/src/kernels/types/multinomial.ts b/src/kernels/types/multinomial.ts deleted file mode 100644 index 640ae1978f..0000000000 --- a/src/kernels/types/multinomial.ts +++ /dev/null @@ -1,29 +0,0 @@ - -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor2D} from '../../tensor'; - -export interface MultinomialNode extends KernelNode { - inputAndArgs: - {inputs: {probs: Tensor2D;}; args: {numSamples: number; seed: number};}; - output: Tensor2D; - gradient: (dy: Tensor2D, y: Tensor2D) => { - probs: () => Tensor2D; - }; -} diff --git a/src/kernels/types/onehot.ts b/src/kernels/types/onehot.ts deleted file mode 100644 index df9b4d9734..0000000000 --- a/src/kernels/types/onehot.ts +++ /dev/null @@ -1,30 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor1D, Tensor2D} from '../../tensor'; - -export interface OneHotNode extends KernelNode { - inputAndArgs: { - inputs: {indices: Tensor1D;}; - args: {depth: number; onValue: number; offValue: number}; - }; - output: Tensor2D; - gradient: (dy: Tensor2D, y: Tensor2D) => { - indices: () => Tensor1D; - }; -} diff --git a/src/kernels/types/pad.ts b/src/kernels/types/pad.ts deleted file mode 100644 index c4347131f4..0000000000 --- a/src/kernels/types/pad.ts +++ /dev/null @@ -1,43 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor1D, Tensor2D} from '../../tensor'; - -export interface Pad1DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor1D;}; - args: {paddings: [number, number], constantValue: number}; - }; - output: Tensor1D; - gradient: (dy: Tensor1D, y: Tensor1D) => { - x: () => Tensor1D; - }; -} - -export interface Pad2DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor2D;}; args: { - paddings: [[number, number], [number, number]], - constantValue: number - }; - }; - output: Tensor2D; - gradient: (dy: Tensor2D, y: Tensor2D) => { - x: () => Tensor2D; - }; -} diff --git a/src/kernels/types/pool.ts b/src/kernels/types/pool.ts deleted file mode 100644 index fce01174e6..0000000000 --- a/src/kernels/types/pool.ts +++ /dev/null @@ -1,40 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {Conv2DInfo} from '../../ops/conv_util'; -import {KernelNode} from '../../tape_types'; -import {Tensor4D} from '../../tensor'; - -// Pool -export interface PoolNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor4D;}; args: {convInfo: Conv2DInfo;};}; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - }; -} - -// PoolBackprop -export interface PoolBackpropNode extends KernelNode { - inputAndArgs: - {inputs: {dy: Tensor4D; x: Tensor4D;}; args: {convInfo: Conv2DInfo;};}; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - dy: () => Tensor4D; - x: () => Tensor4D; - }; -} diff --git a/src/kernels/types/resize_bilinear.ts b/src/kernels/types/resize_bilinear.ts deleted file mode 100644 index 43f94f510e..0000000000 --- a/src/kernels/types/resize_bilinear.ts +++ /dev/null @@ -1,30 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor4D} from '../../tensor'; - -export interface ResizeBilinearNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor4D;}; - args: {newHeight: number; newWidth: number; alignCorners: boolean}; - }; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - }; -} diff --git a/src/kernels/types/slice.ts b/src/kernels/types/slice.ts deleted file mode 100644 index 87ef49dd30..0000000000 --- a/src/kernels/types/slice.ts +++ /dev/null @@ -1,62 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../../tensor'; - -export interface Slice1DNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor1D;}; args: {begin: number; size: number;};}; - output: Tensor1D; - gradient: (dy: Tensor1D, y: Tensor1D) => { - x: () => Tensor1D; - }; -} - -export interface Slice2DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor2D;}; - args: {begin: [number, number]; size: [number, number];}; - }; - output: Tensor2D; - gradient: (dy: Tensor2D, y: Tensor2D) => { - x: () => Tensor2D; - }; -} - -export interface Slice3DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor3D;}; - args: {begin: [number, number, number]; size: [number, number, number];}; - }; - output: Tensor3D; - gradient: (dy: Tensor3D, y: Tensor3D) => { - x: () => Tensor3D; - }; -} - -export interface Slice4DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor4D;}; args: { - begin: [number, number, number, number]; - size: [number, number, number, number]; - }; - }; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - }; -} From 8157fb59dbcc6511c97563938ab655c874789f32 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sat, 17 Feb 2018 21:13:12 -0500 Subject: [PATCH 07/14] save --- src/engine.ts | 19 +-- src/kernels/kernel_registry.ts | 211 +-------------------------------- src/ops/array_ops.ts | 1 - src/ops/batchnorm.ts | 18 +-- src/ops/binary_ops.ts | 29 +++-- src/ops/compare.ts | 12 +- src/ops/conv.ts | 19 +-- src/ops/image_ops.ts | 6 +- src/ops/logical_ops.ts | 13 +- src/ops/lrn.ts | 6 +- src/ops/pool.ts | 24 ++-- src/ops/reduction_ops.ts | 10 +- src/ops/slice.ts | 12 +- src/ops/unary_ops.ts | 39 ++---- src/profiler_test.ts | 3 +- src/tape_types.ts | 6 +- 16 files changed, 98 insertions(+), 330 deletions(-) diff --git a/src/engine.ts b/src/engine.ts index 9d1a43bdda..d6a61d1a91 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -19,7 +19,7 @@ import {ENV} from './environment'; import {tidy} from './globals'; import {BackendTimingInfo, KernelBackend} from './kernels/backend'; import * as kernel_registry from 
'./kernels/kernel_registry'; -import {Kernel, KernelConfigRegistry} from './kernels/kernel_registry'; +import {KernelConfigRegistry} from './kernels/kernel_registry'; import * as ops from './ops/ops'; import {Profiler} from './profiler'; // tslint:disable-next-line:max-line-length @@ -37,7 +37,7 @@ interface ScopeState { } export type ForwardFunc = - (backend: KernelBackend, save?: (map: NamedTensorMap) => void) => T; + (backend: KernelBackend, save?: (tensor: S) => S) => T; /** * @docalias (a: Tensor, b: Tensor,...) => { @@ -94,18 +94,21 @@ export class Engine implements TensorManager { runKernel( forwardFunc: ForwardFunc, inputs?: I, - backwardsFunc?: - (dy: T, saved: NamedTensorMap) => {[P in keyof I]: () => I[P]}, + backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]}, ): T { let result: T; // TODO(smilkov): Figure out kernel name. - const kernelName = '' as Kernel; - let saved: NamedTensorMap = null; + const kernelName = ''; + const saved: Tensor[] = []; + const saveFunc = (x: T): T => { + saved.push(x); + return x; + }; if (!ENV.get('DEBUG')) { - result = forwardFunc(this.backend, x => saved = x); + result = forwardFunc(this.backend, saveFunc); } else { result = this.profiler.profileKernel( - kernelName, () => forwardFunc(this.backend, x => saved = x)); + kernelName, () => forwardFunc(this.backend, saveFunc)); } const recordKernel = diff --git a/src/kernels/kernel_registry.ts b/src/kernels/kernel_registry.ts index f1d6e3cba3..ac51989304 100644 --- a/src/kernels/kernel_registry.ts +++ b/src/kernels/kernel_registry.ts @@ -17,26 +17,9 @@ import {Rank} from '../types'; import {KernelBackend} from './backend'; -import {ArgMaxNode, ArgMinNode} from './types/argminmax'; -import {BatchNorm4DNode} from './types/batchnorm'; -import {BinaryNode} from './types/binary'; -// tslint:disable-next-line:max-line-length import {ConcatNode} from './types/concat'; -// tslint:disable-next-line:max-line-length -import {Conv2DDerFilterNode, Conv2DDerInputNode, Conv2DNode, DepthwiseConv2DNode} from './types/conv'; import {GatherNode} from './types/gather'; -import {EqualNode, LogicalNode, WhereNode} from './types/logical'; -import {LRN4DNode} from './types/lrn'; -import {MaximumNode, MaxNode, MinimumNode, MinNode} from './types/minmax'; -import {MultinomialNode} from './types/multinomial'; -import {OneHotNode} from './types/onehot'; -import {Pad1DNode, Pad2DNode} from './types/pad'; -// tslint:disable-next-line:max-line-length -import {PoolBackpropNode, PoolNode} from './types/pool'; -import {ResizeBilinearNode} from './types/resize_bilinear'; import {Reverse4DNode} from './types/reverse'; -// tslint:disable-next-line:max-line-length -import {Slice1DNode, Slice2DNode, Slice3DNode, Slice4DNode} from './types/slice'; import {SumNode} from './types/sum'; import {TopKIndicesNode, TopKValuesNode} from './types/topk'; @@ -45,226 +28,34 @@ executeKernel, O extends KernelConfigRegistry[K]['output']>( backend: KernelBackend, kernelName: K, inputAndArgs: KernelConfigRegistry[K]['inputAndArgs']): O { - if (kernelName === 'Slice1D') { - const config = inputAndArgs as Slice1DNode['inputAndArgs']; - return backend.slice1D( - config.inputs.x, config.args.begin, config.args.size) as O; - } else if (kernelName === 'Slice2D') { - const config = inputAndArgs as Slice2DNode['inputAndArgs']; - return backend.slice2D( - config.inputs.x, config.args.begin, config.args.size) as O; - } else if (kernelName === 'Slice3D') { - const config = inputAndArgs as Slice3DNode['inputAndArgs']; - return backend.slice3D( - 
config.inputs.x, config.args.begin, config.args.size) as O; - } else if (kernelName === 'Slice4D') { - const config = inputAndArgs as Slice4DNode['inputAndArgs']; - return backend.slice4D( - config.inputs.x, config.args.begin, config.args.size) as O; - } else if (kernelName === 'Reverse4D') { + if (kernelName === 'Reverse4D') { const config = inputAndArgs as Reverse4DNode['inputAndArgs']; return backend.reverse4D(config.inputs.x, config.args.axis) as O; } else if (kernelName === 'Concat') { const config = inputAndArgs as ConcatNode['inputAndArgs']; return backend.concat(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Add') { - const config = inputAndArgs as BinaryNode['inputAndArgs']; - return backend.add(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Sub') { - const config = inputAndArgs as BinaryNode['inputAndArgs']; - return backend.subtract(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Mul') { - const config = inputAndArgs as BinaryNode['inputAndArgs']; - return backend.multiply(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Div') { - const config = inputAndArgs as BinaryNode['inputAndArgs']; - return backend.divide(config.inputs.a, config.inputs.b) as O; } else if (kernelName === 'Sum') { const config = inputAndArgs as SumNode['inputAndArgs']; return backend.sum(config.inputs.x, config.args.axes) as O; - } else if (kernelName === 'ArgMax') { - const config = inputAndArgs as ArgMaxNode['inputAndArgs']; - return backend.argMax(config.inputs.x, config.args.axes) as O; - } else if (kernelName === 'ArgMin') { - const config = inputAndArgs as ArgMinNode['inputAndArgs']; - return backend.argMin(config.inputs.x, config.args.axes) as O; - } else if (kernelName === 'Equal') { - const config = inputAndArgs as EqualNode['inputAndArgs']; - return backend.equal(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'NotEqual') { - const config = inputAndArgs as EqualNode['inputAndArgs']; - return backend.notEqual(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Less') { - const config = inputAndArgs as EqualNode['inputAndArgs']; - return backend.less(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'LessEqual') { - const config = inputAndArgs as EqualNode['inputAndArgs']; - return backend.lessEqual(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Greater') { - const config = inputAndArgs as EqualNode['inputAndArgs']; - return backend.greater(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'GreaterEqual') { - const config = inputAndArgs as EqualNode['inputAndArgs']; - return backend.greaterEqual(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'LogicalAnd') { - const config = inputAndArgs as LogicalNode['inputAndArgs']; - return backend.logicalAnd(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'LogicalOr') { - const config = inputAndArgs as LogicalNode['inputAndArgs']; - return backend.logicalOr(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'LogicalXor') { - const config = inputAndArgs as LogicalNode['inputAndArgs']; - return backend.logicalXor(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Where') { - const config = inputAndArgs as WhereNode['inputAndArgs']; - return backend.where( - config.inputs.condition, config.inputs.a, config.inputs.b, - config.args.dtype) as O; } else if (kernelName === 'TopKValues') { const config = inputAndArgs as 
TopKValuesNode['inputAndArgs']; return backend.topKValues(config.inputs.x, config.args.k) as O; } else if (kernelName === 'TopKIndices') { const config = inputAndArgs as TopKIndicesNode['inputAndArgs']; return backend.topKIndices(config.inputs.x, config.args.k) as O; - } else if (kernelName === 'Min') { - const config = inputAndArgs as MinNode['inputAndArgs']; - return backend.min(config.inputs.x, config.args.axes) as O; - } else if (kernelName === 'Minimum') { - const config = inputAndArgs as MinimumNode['inputAndArgs']; - return backend.minimum(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Max') { - const config = inputAndArgs as MaxNode['inputAndArgs']; - return backend.max(config.inputs.x, config.args.axes) as O; - } else if (kernelName === 'Maximum') { - const config = inputAndArgs as MaximumNode['inputAndArgs']; - return backend.maximum(config.inputs.a, config.inputs.b) as O; } else if (kernelName === 'Gather') { const config = inputAndArgs as GatherNode['inputAndArgs']; return backend.gather( config.inputs.x, config.inputs.indices, config.args.axis) as O; - } else if (kernelName === 'Pad1D') { - const config = inputAndArgs as Pad1DNode['inputAndArgs']; - return backend.pad1D( - config.inputs.x, config.args.paddings, - config.args.constantValue) as O; - } else if (kernelName === 'Pad2D') { - const config = inputAndArgs as Pad2DNode['inputAndArgs']; - return backend.pad2D( - config.inputs.x, config.args.paddings, - config.args.constantValue) as O; - } else if (kernelName === 'Conv2D') { - const config = inputAndArgs as Conv2DNode['inputAndArgs']; - return backend.conv2d( - config.inputs.x, config.inputs.filter, config.args.convInfo) as - O; - } else if (kernelName === 'Conv2DDerInput') { - const config = inputAndArgs as Conv2DDerInputNode['inputAndArgs']; - return backend.conv2dDerInput( - config.inputs.dy, config.inputs.filter, config.args.convInfo) as - O; - } else if (kernelName === 'Conv2DDerFilter') { - const config = inputAndArgs as Conv2DDerFilterNode['inputAndArgs']; - return backend.conv2dDerFilter( - config.inputs.x, config.inputs.dy, config.args.convInfo) as O; - } else if (kernelName === 'DepthwiseConv2D') { - const config = inputAndArgs as DepthwiseConv2DNode['inputAndArgs']; - return backend.depthwiseConv2D( - config.inputs.x, config.inputs.filter, config.args.convInfo) as - O; - } else if (kernelName === 'MaxPool') { - const config = inputAndArgs as PoolNode['inputAndArgs']; - return backend.maxPool(config.inputs.x, config.args.convInfo) as O; - } else if (kernelName === 'MaxPoolBackprop') { - const config = inputAndArgs as PoolBackpropNode['inputAndArgs']; - return backend.maxPoolBackprop( - config.inputs.dy, config.inputs.x, config.args.convInfo) as O; - } else if (kernelName === 'AvgPool') { - const config = inputAndArgs as PoolNode['inputAndArgs']; - return backend.avgPool(config.inputs.x, config.args.convInfo) as O; - } else if (kernelName === 'AvgPoolBackprop') { - const config = inputAndArgs as PoolBackpropNode['inputAndArgs']; - return backend.avgPoolBackprop( - config.inputs.dy, config.inputs.x, config.args.convInfo) as O; - } else if (kernelName === 'MinPool') { - const config = inputAndArgs as PoolNode['inputAndArgs']; - return backend.minPool(config.inputs.x, config.args.convInfo) as O; - } else if (kernelName === 'ResizeBilinear') { - const config = inputAndArgs as ResizeBilinearNode['inputAndArgs']; - return backend.resizeBilinear( - config.inputs.x, config.args.newHeight, config.args.newWidth, - config.args.alignCorners) as O; - } 
else if (kernelName === 'BatchNorm4D') { - const config = inputAndArgs as BatchNorm4DNode['inputAndArgs']; - return backend.batchNormalization4D( - config.inputs.x, config.inputs.mean, config.inputs.variance, - config.args.varianceEpsilon, config.inputs.scale, - config.inputs.offset) as O; - } else if (kernelName === 'LRN4D') { - const config = inputAndArgs as LRN4DNode['inputAndArgs']; - return backend.localResponseNormalization4D( - config.inputs.x, config.args.radius, config.args.bias, - config.args.alpha, config.args.beta, config.args.normRegion) as - O; - } else if (kernelName === 'Multinomial') { - const config = inputAndArgs as MultinomialNode['inputAndArgs']; - return backend.multinomial( - config.inputs.probs, config.args.numSamples, config.args.seed) as - O; - } else if (kernelName === 'OneHot') { - const config = inputAndArgs as OneHotNode['inputAndArgs']; - return backend.oneHot( - config.inputs.indices, config.args.depth, config.args.onValue, - config.args.offValue) as O; } throw new Error(`No backend method found for kernel ${kernelName}`); } export interface KernelConfigRegistry { - Slice1D: Slice1DNode; - Slice2D: Slice2DNode; - Slice3D: Slice3DNode; - Slice4D: Slice4DNode; Reverse4D: Reverse4DNode; Concat: ConcatNode; - Add: BinaryNode; - Sub: BinaryNode; - Mul: BinaryNode; - Div: BinaryNode; Sum: SumNode; - ArgMax: ArgMaxNode; - ArgMin: ArgMinNode; - Equal: EqualNode; - NotEqual: EqualNode; - Less: EqualNode; - LessEqual: EqualNode; - Greater: EqualNode; - GreaterEqual: EqualNode; - LogicalAnd: LogicalNode; - LogicalOr: LogicalNode; - LogicalXor: LogicalNode; - Where: WhereNode; TopKValues: TopKValuesNode; TopKIndices: TopKIndicesNode; - Min: MinNode; - Minimum: MinimumNode; - Max: MaxNode; - Maximum: MaximumNode; - Pad1D: Pad1DNode; - Pad2D: Pad2DNode; Gather: GatherNode; - Conv2D: Conv2DNode; - Conv2DDerInput: Conv2DDerInputNode; - Conv2DDerFilter: Conv2DDerFilterNode; - DepthwiseConv2D: Conv2DNode; - MaxPool: PoolNode; - MaxPoolBackprop: PoolBackpropNode; - AvgPool: PoolNode; - AvgPoolBackprop: PoolBackpropNode; - MinPool: PoolNode; - ResizeBilinear: ResizeBilinearNode; - BatchNorm4D: BatchNorm4DNode; - LRN4D: LRN4DNode; - Multinomial: MultinomialNode; - OneHot: OneHotNode; } -export type Kernel = keyof KernelConfigRegistry; diff --git a/src/ops/array_ops.ts b/src/ops/array_ops.ts index c9105c705d..fd685779d9 100644 --- a/src/ops/array_ops.ts +++ b/src/ops/array_ops.ts @@ -23,7 +23,6 @@ import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer} fr // tslint:disable-next-line:max-line-length import {ArrayData, DataType, DataTypeMap, Rank, ShapeMap, TensorLike, TensorLike1D, TensorLike2D, TensorLike3D, TensorLike4D, TypedArray} from '../types'; import * as util from '../util'; - import {Concat} from './concat'; import {operation} from './operation'; import {MPRandGauss} from './rand'; diff --git a/src/ops/batchnorm.ts b/src/ops/batchnorm.ts index 2847fb605d..3fdd008cae 100644 --- a/src/ops/batchnorm.ts +++ b/src/ops/batchnorm.ts @@ -193,18 +193,12 @@ export class Ops { x4D = x as Tensor4D; } - return ENV.engine - .executeKernel('BatchNorm4D', { - inputs: { - x: x4D, - mean: batchnormReshape4D(mean), - variance: batchnormReshape4D(variance), - scale: batchnormReshape4D(scale), - offset: batchnormReshape4D(offset) - }, - args: {varianceEpsilon} - }) - .reshape(x.shape) as Tensor; + const res = ENV.engine.runKernel( + backend => backend.batchNormalization4D( + x4D, batchnormReshape4D(mean), batchnormReshape4D(variance), + varianceEpsilon, 
batchnormReshape4D(scale), + batchnormReshape4D(offset))); + return res.reshape(x.shape); } } diff --git a/src/ops/binary_ops.ts b/src/ops/binary_ops.ts index d0975011d0..c580e0fba6 100644 --- a/src/ops/binary_ops.ts +++ b/src/ops/binary_ops.ts @@ -55,7 +55,7 @@ export class Ops { const outShape = broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - const der = (dy: Tensor, y: Tensor) => { + const der = (dy: Tensor) => { const derA = () => { let res = dy; const reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape); @@ -74,7 +74,7 @@ export class Ops { }; return {a: derA, b: derB}; }; - return ENV.engine.executeKernel('Add', {inputs: {a, b}}, der) as T; + return ENV.engine.runKernel(backend => backend.add(a, b), {a, b}, der) as T; } /** @@ -121,7 +121,7 @@ export class Ops { const outShape = broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - const der = (dy: Tensor, y: Tensor) => { + const der = (dy: Tensor) => { const derA = () => { let res = dy; const reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape); @@ -140,7 +140,8 @@ export class Ops { }; return {a: derA, b: derB}; }; - return ENV.engine.executeKernel('Sub', {inputs: {a, b}}, der) as T; + return ENV.engine.runKernel( + backend => backend.subtract(a, b), {a, b}, der) as T; } /** @@ -253,7 +254,7 @@ export class Ops { const outShape = broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - const der = (dy: Tensor, y: Tensor) => { + const der = (dy: Tensor) => { const derA = () => { const res = dy.mul(b.toFloat()); const reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape); @@ -272,7 +273,8 @@ export class Ops { }; return {a: derA, b: derB}; }; - return ENV.engine.executeKernel('Mul', {inputs: {a, b}}, der) as T; + return ENV.engine.runKernel( + backend => backend.multiply(a, b), {a, b}, der) as T; } /** @@ -317,7 +319,7 @@ export class Ops { static div(a: Tensor, b: Tensor): T { const outShape = broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - const der = (dy: Tensor, y: Tensor) => { + const der = (dy: Tensor) => { const derA = () => { const res = dy.div(b.toFloat()); const reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape); @@ -337,7 +339,8 @@ export class Ops { }; return {a: derA, b: derB}; }; - return ENV.engine.executeKernel('Div', {inputs: {a, b}}, der) as T; + return ENV.engine.runKernel(backend => backend.divide(a, b), {a, b}, der) as + T; } /** @@ -382,12 +385,13 @@ export class Ops { static minimum(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - const der = (dy: Tensor, y: Tensor) => { + const der = (dy: Tensor) => { const derA = () => dy.mul(a.lessEqual(b).toFloat()); const derB = () => dy.mul(a.greater(b).toFloat()); return {a: derA, b: derB}; }; - return ENV.engine.executeKernel('Minimum', {inputs: {a, b}}, der) as T; + return ENV.engine.runKernel( + backend => backend.minimum(a, b), {a, b}, der) as T; } /** @@ -432,12 +436,13 @@ export class Ops { static maximum(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - const der = (dy: Tensor, y: Tensor) => { + const der = (dy: Tensor) => { const derA = () => dy.mul(a.greaterEqual(b).toFloat()); const derB = () => dy.mul(a.less(b).toFloat()); return {a: derA, b: derB}; }; - return ENV.engine.executeKernel('Maximum', {inputs: {a, b}}, der) as T; + return ENV.engine.runKernel( + backend => backend.maximum(a, b), {a, b}, der) as T; } /** diff --git 
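Every gradient in the binary_ops hunks above follows one broadcasting recipe: differentiate at the output (broadcast) shape, then sum away the axes that broadcasting introduced and reshape back to the input's shape. Factored out as a sketch (the helper name `unbroadcast` is illustrative and not part of this patch):

  // Collapse a gradient from the broadcast output shape back to `shape`.
  function unbroadcast(grad: Tensor, shape: number[], outShape: number[]): Tensor {
    const reduceAxes = broadcast_util.getReductionAxes(shape, outShape);
    const res = reduceAxes.length > 0 ? grad.sum(reduceAxes) : grad;
    return res.reshape(shape);
  }

With it, add's derA is unbroadcast(dy, a.shape, outShape), and mul's derA is unbroadcast(dy.mul(b.toFloat()), a.shape, outShape).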
a/src/ops/compare.ts b/src/ops/compare.ts index 295cd101a9..7c78f70bfe 100644 --- a/src/ops/compare.ts +++ b/src/ops/compare.ts @@ -37,7 +37,7 @@ export class Ops { static notEqual(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('NotEqual', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.notEqual(a, b)) as T; } /** @@ -68,7 +68,7 @@ export class Ops { static less(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('Less', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.less(a, b)) as T; } /** @@ -99,7 +99,7 @@ export class Ops { static equal(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('Equal', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.equal(a, b)) as T; } @operation @@ -122,7 +122,7 @@ export class Ops { static lessEqual(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('LessEqual', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.lessEqual(a, b)) as T; } @operation @@ -145,7 +145,7 @@ export class Ops { static greater(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('Greater', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.greater(a, b)) as T; } @operation @@ -168,7 +168,7 @@ export class Ops { static greaterEqual(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('GreaterEqual', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.greaterEqual(a, b)) as T; } @operation diff --git a/src/ops/conv.ts b/src/ops/conv.ts index 72ce543af9..36698adddb 100644 --- a/src/ops/conv.ts +++ b/src/ops/conv.ts @@ -143,15 +143,16 @@ export class Ops { const convInfo = conv_util.computeConv2DInfo( x4D.shape, filter.shape, strides, pad, dimRoundingMode); - const gradients = (dy: Tensor4D, y: Tensor4D) => { + const grad = (dy: Tensor4D) => { return { x: () => Ops.conv2dDerInput(x4D.shape, dy, filter, strides, pad), filter: () => Ops.conv2dDerFilter(x4D, dy, filter.shape, strides, pad) }; }; - const res = ENV.engine.executeKernel( - 'Conv2D', {inputs: {x: x4D, filter}, args: {convInfo}}, gradients); + const res = ENV.engine.runKernel( + backend => backend.conv2d(x4D, filter, convInfo), {x: x4D, filter}, + grad); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } @@ -229,8 +230,8 @@ export class Ops { const convInfo = conv_util.computeConv2DInfo( xShape4D, filter.shape, strides, pad, dimRoundingMode); - const res = ENV.engine.executeKernel( - 'Conv2DDerInput', {inputs: {dy: dy4D, filter}, args: {convInfo}}); + const res = ENV.engine.runKernel( + backend => backend.conv2dDerInput(dy4D, filter, convInfo)); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } @@ -297,8 +298,8 @@ export class Ops { const convInfo = conv_util.computeConv2DInfo( x4D.shape, filterShape, strides, pad, dimRoundingMode); - return ENV.engine.executeKernel( - 'Conv2DDerFilter', {inputs: {x: x4D, dy: dy4D}, args: 
{convInfo}}); + return ENV.engine.runKernel( + backend => backend.conv2dDerFilter(x4D, dy4D, convInfo)); } /** @@ -412,8 +413,8 @@ export class Ops { const convInfo = conv_util.computeConv2DInfo( input4D.shape, filter.shape, strides, pad, dimRoundingMode, true /* depthwise */); - const res = ENV.engine.executeKernel( - 'DepthwiseConv2D', {inputs: {x: input4D, filter}, args: {convInfo}}); + const res = ENV.engine.runKernel( + backend => backend.depthwiseConv2D(input4D, filter, convInfo)); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } diff --git a/src/ops/image_ops.ts b/src/ops/image_ops.ts index 625bd23688..70da78b434 100644 --- a/src/ops/image_ops.ts +++ b/src/ops/image_ops.ts @@ -54,9 +54,9 @@ export class Ops { images.as4D(1, images.shape[0], images.shape[1], images.shape[2]); } const [newHeight, newWidth] = size; - const res = ENV.engine.executeKernel( - 'ResizeBilinear', - {inputs: {x: batchImages}, args: {newHeight, newWidth, alignCorners}}); + const res = ENV.engine.runKernel( + backend => backend.resizeBilinear( + batchImages, newHeight, newWidth, alignCorners)); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } diff --git a/src/ops/logical_ops.ts b/src/ops/logical_ops.ts index af45082f03..61522d7df9 100644 --- a/src/ops/logical_ops.ts +++ b/src/ops/logical_ops.ts @@ -19,7 +19,6 @@ import {doc} from '../doc'; import {ENV} from '../environment'; import {Tensor} from '../tensor'; import * as types from '../types'; -import {DataType} from '../types'; import * as util from '../util'; import * as broadcast_util from './broadcast_util'; import {operation} from './operation'; @@ -50,7 +49,7 @@ export class Ops { a.dtype === 'bool' && b.dtype === 'bool', 'Error: Array must be of type bool.'); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('LogicalAnd', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.logicalAnd(a, b)) as T; } /** @@ -66,7 +65,7 @@ export class Ops { a.dtype === 'bool' && b.dtype === 'bool', 'Error: Array must be of type bool.'); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('LogicalOr', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.logicalOr(a, b)) as T; } /** @@ -82,7 +81,7 @@ export class Ops { a.dtype === 'bool' && b.dtype === 'bool', 'Error: Array must be of type bool.'); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('LogicalXor', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.logicalXor(a, b)) as T; } /** @@ -117,9 +116,7 @@ export class Ops { // Default to highest precision of number: const dtype = types.upcastType(a.dtype, b.dtype); - return ENV.engine.executeKernel( - 'Where', - {inputs: {condition, a, b}, args: {dtype: dtype as DataType}}) as - T; + return ENV.engine.runKernel( + backend => backend.where(condition, a, b, dtype)) as T; } } diff --git a/src/ops/lrn.ts b/src/ops/lrn.ts index a8c7330974..461474a327 100644 --- a/src/ops/lrn.ts +++ b/src/ops/lrn.ts @@ -56,9 +56,9 @@ export class LRN { reshapedTo4D = true; x4D = x.as4D(1, x.shape[0], x.shape[1], x.shape[2]); } - const res = ENV.engine.executeKernel( - 'LRN4D', - {inputs: {x: x4D}, args: {radius, bias, alpha, beta, normRegion}}); + const res = ENV.engine.runKernel( + backend => backend.localResponseNormalization4D( + x4D, radius, bias, alpha, beta, normRegion)); if (reshapedTo4D) { return res.as3D(res.shape[1],
res.shape[2], res.shape[3]) as T; } else { diff --git a/src/ops/pool.ts b/src/ops/pool.ts index 305b320f73..8b70f892c5 100644 --- a/src/ops/pool.ts +++ b/src/ops/pool.ts @@ -66,11 +66,11 @@ export class Ops { const convInfo = conv_util.computePool2DInfo( x4D.shape, filterSize, strides, pad, dimRoundingMode); - const gradients = (dy: Tensor4D, y: Tensor4D) => { + const grad = (dy: Tensor4D) => { return {x: () => Ops.maxPoolBackprop(dy, x4D, filterSize, strides, pad)}; }; - const res = ENV.engine.executeKernel( - 'MaxPool', {inputs: {x: x4D}, args: {convInfo}}, gradients); + const res = ENV.engine.runKernel( + backend => backend.maxPool(x4D, convInfo), {x: x4D}, grad); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } @@ -130,8 +130,8 @@ export class Ops { const convInfo = conv_util.computePool2DInfo( input4D.shape, filterSize, strides, pad, dimRoundingMode); - const res = ENV.engine.executeKernel( - 'MaxPoolBackprop', {inputs: {dy: dy4D, x: input4D}, args: {convInfo}}); + const res = ENV.engine.runKernel( + backend => backend.maxPoolBackprop(dy4D, input4D, convInfo)); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } @@ -180,8 +180,8 @@ export class Ops { } const convInfo = conv_util.computePool2DInfo( input4D.shape, filterSize, strides, pad, dimRoundingMode); - const res = ENV.engine.executeKernel( - 'MinPool', {inputs: {x: input4D}, args: {convInfo}}); + const res = + ENV.engine.runKernel(backend => backend.minPool(input4D, convInfo)); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } @@ -232,11 +232,11 @@ export class Ops { const convInfo = conv_util.computePool2DInfo(x4D.shape, filterSize, strides, pad); - const gradients = (dy: Tensor4D, y: Tensor4D) => { + const grad = (dy: Tensor4D) => { return {x: () => Ops.avgPoolBackprop(dy, x4D, filterSize, strides, pad)}; }; - const res = ENV.engine.executeKernel( - 'AvgPool', {inputs: {x: x4D}, args: {convInfo}}, gradients); + const res = ENV.engine.runKernel( + backend => backend.avgPool(x4D, convInfo), {x: x4D}, grad); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } @@ -285,8 +285,8 @@ export class Ops { const convInfo = conv_util.computePool2DInfo(input4D.shape, filterSize, strides, pad); - const res = ENV.engine.executeKernel( - 'AvgPoolBackprop', {inputs: {dy: dy4D, x: input4D}, args: {convInfo}}); + const res = ENV.engine.runKernel( + backend => backend.avgPoolBackprop(dy4D, input4D, convInfo)); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } diff --git a/src/ops/reduction_ops.ts b/src/ops/reduction_ops.ts index dec6d23e66..f5fd1c1407 100644 --- a/src/ops/reduction_ops.ts +++ b/src/ops/reduction_ops.ts @@ -234,8 +234,7 @@ export class Ops { x = x.transpose(permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, x.rank); } - const res = - ENV.engine.executeKernel('Min', {inputs: {x}, args: {axes}}) as Tensor; + const res = ENV.engine.runKernel(backend => backend.min(x, axes)); if (keepDims) { const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); return res.reshape(newShape) as T; @@ -281,8 +280,7 @@ export class Ops { x = x.transpose(permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, x.rank); } - const res = - ENV.engine.executeKernel('Max', {inputs: {x}, args: {axes}}) as Tensor; + const res = ENV.engine.runKernel(backend => backend.max(x, axes)); if (keepDims) { const newShape = axis_util.expandShapeToKeepDim(res.shape, 
origAxes); return res.reshape(newShape) as T; @@ -323,7 +321,7 @@ export class Ops { x = x.transpose(permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, x.rank); } - return ENV.engine.executeKernel('ArgMin', {inputs: {x}, args: {axes}}) as T; + return ENV.engine.runKernel(backend => backend.argMin(x, axes)) as T; } /** @@ -359,7 +357,7 @@ export class Ops { axes = axis_util.getInnerMostAxes(axes.length, x.rank); } - return ENV.engine.executeKernel('ArgMax', {inputs: {x}, args: {axes}}) as T; + return ENV.engine.runKernel(backend => backend.argMax(x, axes)) as T; } /** diff --git a/src/ops/slice.ts b/src/ops/slice.ts index 16be1a4746..e3a7d76eb5 100644 --- a/src/ops/slice.ts +++ b/src/ops/slice.ts @@ -34,8 +34,7 @@ export class Ops { @operation static slice1d(x: Tensor1D, begin: number, size: number): Tensor1D { slice_util.assertParamsValid(x, [begin], [size]); - return ENV.engine.executeKernel( - 'Slice1D', {inputs: {x}, args: {begin, size}}) as Tensor1D; + return ENV.engine.runKernel(backend => backend.slice1D(x, begin, size)); } /** @@ -50,8 +49,7 @@ export class Ops { static slice2d(x: Tensor2D, begin: [number, number], size: [number, number]): Tensor2D { slice_util.assertParamsValid(x, begin, size); - return ENV.engine.executeKernel( - 'Slice2D', {inputs: {x}, args: {begin, size}}) as Tensor2D; + return ENV.engine.runKernel(backend => backend.slice2D(x, begin, size)); } /** @@ -67,8 +65,7 @@ export class Ops { number, number, number ]): Tensor3D { slice_util.assertParamsValid(x, begin, size); - return ENV.engine.executeKernel( - 'Slice3D', {inputs: {x}, args: {begin, size}}) as Tensor3D; + return ENV.engine.runKernel(backend => backend.slice3D(x, begin, size)); } /** @@ -85,8 +82,7 @@ export class Ops { number, number, number, number ]): Tensor4D { slice_util.assertParamsValid(x, begin, size); - return ENV.engine.executeKernel( - 'Slice4D', {inputs: {x}, args: {begin, size}}) as Tensor4D; + return ENV.engine.runKernel(backend => backend.slice4D(x, begin, size)); } /** diff --git a/src/ops/unary_ops.ts b/src/ops/unary_ops.ts index 981d97d140..1f5ad82762 100644 --- a/src/ops/unary_ops.ts +++ b/src/ops/unary_ops.ts @@ -16,11 +16,9 @@ */ import {doc} from '../doc'; -import {ForwardFunc} from '../engine'; import {ENV} from '../environment'; import {Tensor} from '../tensor'; import * as util from '../util'; - import {operation} from './operation'; import * as ops from './ops'; import {zerosLike} from './ops'; @@ -101,15 +99,12 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static exp(x: T): T { - const forw: ForwardFunc = (backend, save) => { - const y = backend.exp(x); - save({y}); - return y; - }; - const bck = (dy: T, saved: {y: T}) => { - return {x: () => dy.mulStrict(saved.y)}; + const bck = (dy: T, saved: Tensor[]) => { + const [y] = saved; + return {x: () => dy.mulStrict(y as T)}; }; - return ENV.engine.runKernel(forw, {x}, bck); + return ENV.engine.runKernel( + (backend, save) => save(backend.exp(x)), {x}, bck); } /** @@ -355,16 +350,12 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static sigmoid(x: T): T { - const forw: ForwardFunc = (backend, save) => { - const y = backend.sigmoid(x); - save({y}); - return y; - }; - const grad = (dy: T, saved: {y: T}) => { - const {y} = saved; + const grad = (dy: T, saved: Tensor[]) => { + const [y] = saved; return {x: () => dy.mulStrict(y.mul(ops.scalar(1).sub(y)))}; }; - return ENV.engine.runKernel(forw, {x}, grad); + return ENV.engine.runKernel( + 
(backend, save) => save(backend.sigmoid(x)), {x}, grad); } /** @@ -537,16 +528,12 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static tanh(x: T): T { - const forw: ForwardFunc = (backend, save) => { - const y = backend.tanh(x); - save({y}); - return y; - }; - const grad = (dy: T, saved: {y: T}) => { - const {y} = saved; + const grad = (dy: T, saved: Tensor[]) => { + const [y] = saved; return {x: () => ops.scalar(1).sub(y.square()).mulStrict(dy) as T}; }; - return ENV.engine.runKernel(forw, {x}, grad); + return ENV.engine.runKernel( + (backend, save) => save(backend.tanh(x)), {x}, grad); } /** diff --git a/src/profiler_test.ts b/src/profiler_test.ts index 50fa077e69..9092cc0f93 100644 --- a/src/profiler_test.ts +++ b/src/profiler_test.ts @@ -17,7 +17,6 @@ import * as dl from './index'; import {BackendTimer, BackendTimingInfo} from './kernels/backend'; -import {Kernel} from './kernels/kernel_registry'; import {TypedArray} from './kernels/webgl/tex_util'; import {Logger, Profiler} from './profiler'; import {Tensor} from './tensor'; @@ -37,7 +36,7 @@ class TestBackendTimer implements BackendTimer { class TestLogger extends Logger { logKernelProfile( - kernelName: Kernel, result: Tensor, vals: TypedArray, timeMs: number) {} + kernelName: string, result: Tensor, vals: TypedArray, timeMs: number) {} } describe('profiler.Profiler', () => { diff --git a/src/tape_types.ts b/src/tape_types.ts index 062f3b6376..506ce21cff 100644 --- a/src/tape_types.ts +++ b/src/tape_types.ts @@ -15,10 +15,8 @@ * ============================================================================= */ -import {NamedTensorMap} from './types'; import {Tensor} from './tensor'; -import {Rank} from './types'; -import {KernelConfigRegistry} from './kernels/kernel_registry'; +import {NamedTensorMap} from './types'; export type Tape = Array>; export type TapeNodeOutput = Tensor|NamedTensorMap; @@ -42,7 +40,7 @@ export type TapeNodeInputGradientTensors = { // Kernel nodes export interface KernelNode extends TapeNode { - kernel: keyof KernelConfigRegistry; + kernel: string; inputAndArgs: KernelInputConfig; output: Tensor; } From 73f1cf60e294cc21d4efd194609beaca985b1b4f Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sat, 17 Feb 2018 21:24:10 -0500 Subject: [PATCH 08/14] save --- src/engine.ts | 39 ----------- src/index.ts | 2 +- src/kernels/kernel_registry.ts | 61 ----------------- src/kernels/types/cast.ts | 28 -------- src/kernels/types/concat.ts | 28 -------- src/kernels/types/gather.ts | 30 --------- src/kernels/types/matmul.ts | 36 ---------- src/kernels/types/pow.ts | 30 --------- src/kernels/types/prelu.ts | 31 --------- src/kernels/types/reshape.ts | 27 -------- src/kernels/types/reverse.ts | 27 -------- src/kernels/types/sum.ts | 27 -------- src/kernels/types/topk.ts | 39 ----------- src/kernels/types/unary.ts | 73 --------------------- src/kernels/webgl/mulmat_packed_gpu.ts | 2 +- src/kernels/webgl/mulmat_packed_gpu_test.ts | 2 +- src/math.ts | 5 +- src/ops/concat.ts | 4 +- src/ops/matmul.ts | 7 +- src/ops/reduction_ops.ts | 4 +- src/ops/reverse.ts | 4 +- 21 files changed, 17 insertions(+), 489 deletions(-) delete mode 100644 src/kernels/kernel_registry.ts delete mode 100644 src/kernels/types/cast.ts delete mode 100644 src/kernels/types/concat.ts delete mode 100644 src/kernels/types/gather.ts delete mode 100644 src/kernels/types/matmul.ts delete mode 100644 src/kernels/types/pow.ts delete mode 100644 src/kernels/types/prelu.ts delete mode 100644 src/kernels/types/reshape.ts 
delete mode 100644 src/kernels/types/reverse.ts delete mode 100644 src/kernels/types/sum.ts delete mode 100644 src/kernels/types/topk.ts delete mode 100644 src/kernels/types/unary.ts diff --git a/src/engine.ts b/src/engine.ts index d6a61d1a91..ac3e3b6882 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -18,8 +18,6 @@ import {ENV} from './environment'; import {tidy} from './globals'; import {BackendTimingInfo, KernelBackend} from './kernels/backend'; -import * as kernel_registry from './kernels/kernel_registry'; -import {KernelConfigRegistry} from './kernels/kernel_registry'; import * as ops from './ops/ops'; import {Profiler} from './profiler'; // tslint:disable-next-line:max-line-length @@ -28,7 +26,6 @@ import * as tape_util from './tape_util'; import {ScopeResultImmediate} from './tape_util'; import {DataId, Tensor, Tensor3D, Variable} from './tensor'; import {NamedTensorMap, NamedVariableMap, TypedArray} from './types'; -import {Rank} from './types'; import * as util from './util'; interface ScopeState { @@ -128,42 +125,6 @@ export class Engine implements TensorManager { return result; } - executeKernel, C - extends KernelConfigRegistry[K]['inputAndArgs']>( - kernelName: K, config: C, grad?: KernelConfigRegistry[K]['gradient']): - KernelConfigRegistry[K]['output'] { - let result: KernelConfigRegistry[K]['output']; - if (!ENV.get('DEBUG')) { - // NOTE: This isn't pulled out into a separate function to so that we - // keep a shallow stack trace. - result = kernel_registry.executeKernel(this.backend, kernelName, config); - } else { - result = this.profiler.profileKernel( - kernelName, - () => - kernel_registry.executeKernel(this.backend, kernelName, config)); - } - - const recordKernel = - this.activeTape != null && this.customGradientDepth === 0; - if (recordKernel) { - config = tape_util.stripUndefinedInputsFromInputConfig(config) as C; - - const evaluatedNode: KernelNode = { - id: this.nextTapeNodeId++, - type: 'kernel', - name: `kernel: ${kernelName}`, - kernel: kernelName, - inputAndArgs: config, - output: result, - gradient: grad - }; - this.activeTape.push(evaluatedNode); - } - - return result; - } - // TensorManager implementation. registerTensor(a: Tensor|Variable): void { diff --git a/src/index.ts b/src/index.ts index 4144ea1c2f..e77c4f6e63 100644 --- a/src/index.ts +++ b/src/index.ts @@ -42,11 +42,11 @@ export {CostReduction, FeedEntry, Session} from './graph/session'; export {MathBackendCPU, NDArrayMathCPU} from './kernels/backend_cpu'; // tslint:disable-next-line:max-line-length export {MathBackendWebGL, NDArrayMathGPU, WebGLTimingInfo} from './kernels/backend_webgl'; -export {MatrixOrientation} from './kernels/types/matmul'; export {GPGPUContext} from './kernels/webgl/gpgpu_context'; export {NDArrayMath} from './math'; export {Model} from './model'; export {LSTMCell} from './ops/lstm'; +export {MatrixOrientation} from './ops/matmul'; export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer'; export {AdagradOptimizer} from './optimizers/adagrad_optimizer'; export {AdamOptimizer} from './optimizers/adam_optimizer'; diff --git a/src/kernels/kernel_registry.ts b/src/kernels/kernel_registry.ts deleted file mode 100644 index ac51989304..0000000000 --- a/src/kernels/kernel_registry.ts +++ /dev/null @@ -1,61 +0,0 @@ -/** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {Rank} from '../types'; -import {KernelBackend} from './backend'; -import {ConcatNode} from './types/concat'; -import {GatherNode} from './types/gather'; -import {Reverse4DNode} from './types/reverse'; -import {SumNode} from './types/sum'; -import {TopKIndicesNode, TopKValuesNode} from './types/topk'; - -export function -executeKernel, O extends - KernelConfigRegistry[K]['output']>( - backend: KernelBackend, kernelName: K, - inputAndArgs: KernelConfigRegistry[K]['inputAndArgs']): O { - if (kernelName === 'Reverse4D') { - const config = inputAndArgs as Reverse4DNode['inputAndArgs']; - return backend.reverse4D(config.inputs.x, config.args.axis) as O; - } else if (kernelName === 'Concat') { - const config = inputAndArgs as ConcatNode['inputAndArgs']; - return backend.concat(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Sum') { - const config = inputAndArgs as SumNode['inputAndArgs']; - return backend.sum(config.inputs.x, config.args.axes) as O; - } else if (kernelName === 'TopKValues') { - const config = inputAndArgs as TopKValuesNode['inputAndArgs']; - return backend.topKValues(config.inputs.x, config.args.k) as O; - } else if (kernelName === 'TopKIndices') { - const config = inputAndArgs as TopKIndicesNode['inputAndArgs']; - return backend.topKIndices(config.inputs.x, config.args.k) as O; - } else if (kernelName === 'Gather') { - const config = inputAndArgs as GatherNode['inputAndArgs']; - return backend.gather( - config.inputs.x, config.inputs.indices, config.args.axis) as O; - } - throw new Error(`No backend method found for kernel ${kernelName}`); -} - -export interface KernelConfigRegistry { - Reverse4D: Reverse4DNode; - Concat: ConcatNode; - Sum: SumNode; - TopKValues: TopKValuesNode; - TopKIndices: TopKIndicesNode; - Gather: GatherNode; -} diff --git a/src/kernels/types/cast.ts b/src/kernels/types/cast.ts deleted file mode 100644 index 4aad261456..0000000000 --- a/src/kernels/types/cast.ts +++ /dev/null @@ -1,28 +0,0 @@ -/** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; -import {DataType} from '../../types'; - -export interface CastNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor}; args: {newDType: DataType};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor - }; -} diff --git a/src/kernels/types/concat.ts b/src/kernels/types/concat.ts deleted file mode 100644 index 8b29ed3a62..0000000000 --- a/src/kernels/types/concat.ts +++ /dev/null @@ -1,28 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor2D} from '../../tensor'; - -export interface ConcatNode extends KernelNode { - inputAndArgs: {inputs: {a: Tensor2D; b: Tensor2D;};}; - output: Tensor2D; - gradient: (dy: Tensor2D, y: Tensor2D) => { - a: () => Tensor2D; - b: () => Tensor2D; - }; -} diff --git a/src/kernels/types/gather.ts b/src/kernels/types/gather.ts deleted file mode 100644 index bd5dc82fd3..0000000000 --- a/src/kernels/types/gather.ts +++ /dev/null @@ -1,30 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor, Tensor1D} from '../../tensor'; -import {Rank} from '../../types'; - -export interface GatherNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T; indices: Tensor1D;}; args: {axis: number};}; - output: T; - gradient: (dy: Tensor, y: T) => { - x: () => Tensor; - indices: () => Tensor1D; - }; -} diff --git a/src/kernels/types/matmul.ts b/src/kernels/types/matmul.ts deleted file mode 100644 index 833ae3a69d..0000000000 --- a/src/kernels/types/matmul.ts +++ /dev/null @@ -1,36 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor2D} from '../../tensor'; - -export interface MatMulNode extends KernelNode { - inputAndArgs: { - inputs: {a: Tensor2D; b: Tensor2D;}; - args: {transposeA: boolean; transposeB: boolean}; - }; - output: Tensor2D; - gradient: (dy: Tensor2D, y: Tensor2D) => { - a: () => Tensor2D; - b: () => Tensor2D; - }; -} - -export enum MatrixOrientation { - REGULAR, - TRANSPOSED -} diff --git a/src/kernels/types/pow.ts b/src/kernels/types/pow.ts deleted file mode 100644 index ed70227cf4..0000000000 --- a/src/kernels/types/pow.ts +++ /dev/null @@ -1,30 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; -import {Rank} from '../../types'; - -export interface PowNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {base: T; exp: Tensor;};}; - output: T; - gradient: (dy: Tensor, y: T) => { - base: () => Tensor; - exp: () => Tensor; - }; -} diff --git a/src/kernels/types/prelu.ts b/src/kernels/types/prelu.ts deleted file mode 100644 index f279f7e85c..0000000000 --- a/src/kernels/types/prelu.ts +++ /dev/null @@ -1,31 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; -import {Rank} from '../../types'; - -// PReLU -export interface PReLUNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T; alpha: T;};}; - output: T; - gradient: (dy: Tensor, y: T) => { - x: () => Tensor; - alpha: () => Tensor; - }; -} diff --git a/src/kernels/types/reshape.ts b/src/kernels/types/reshape.ts deleted file mode 100644 index 374c1f5470..0000000000 --- a/src/kernels/types/reshape.ts +++ /dev/null @@ -1,27 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; - -export interface ReshapeNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor}; args: {newShape: number[]};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor - }; -} diff --git a/src/kernels/types/reverse.ts b/src/kernels/types/reverse.ts deleted file mode 100644 index 05ebf6231e..0000000000 --- a/src/kernels/types/reverse.ts +++ /dev/null @@ -1,27 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor4D} from '../../tensor'; - -export interface Reverse4DNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor4D;}; args: {axis: number[];};}; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - }; -} diff --git a/src/kernels/types/sum.ts b/src/kernels/types/sum.ts deleted file mode 100644 index b5b9e557f0..0000000000 --- a/src/kernels/types/sum.ts +++ /dev/null @@ -1,27 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; - -export interface SumNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor;}; args: {axes: number[];};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor; - }; -} diff --git a/src/kernels/types/topk.ts b/src/kernels/types/topk.ts deleted file mode 100644 index 95acc7f81e..0000000000 --- a/src/kernels/types/topk.ts +++ /dev/null @@ -1,39 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor, Tensor1D} from '../../tensor'; -import {Rank} from '../../types'; - -// Values -export interface TopKValuesNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {k: number};}; - output: Tensor1D; - gradient: (dy: Tensor1D, y: Tensor1D) => { - x: () => T; - }; -} - -// Indices -export interface TopKIndicesNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor;}; args: {k: number};}; - output: Tensor1D; - gradient: (dy: Tensor1D, y: Tensor1D) => { - x: () => Tensor; - }; -} diff --git a/src/kernels/types/unary.ts b/src/kernels/types/unary.ts deleted file mode 100644 index 0b5e675c0e..0000000000 --- a/src/kernels/types/unary.ts +++ /dev/null @@ -1,73 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; -import {Rank} from '../../types'; - -export interface UnaryNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} - -export interface LeakyReluNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {alpha: number;};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} -export interface StepNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {alpha: number;};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} - -export interface ClipNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {min: number; max: number;};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} - -export interface TransposeNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {perm: number[];};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} - -export interface TileNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {reps: number[];};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} diff --git a/src/kernels/webgl/mulmat_packed_gpu.ts b/src/kernels/webgl/mulmat_packed_gpu.ts index efb35cf711..d114b12441 100644 --- a/src/kernels/webgl/mulmat_packed_gpu.ts +++ b/src/kernels/webgl/mulmat_packed_gpu.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {MatrixOrientation} from '../types/matmul'; +import {MatrixOrientation} from '../../ops/matmul'; import {GPGPUContext} from './gpgpu_context'; import * as webgl_util from './webgl_util'; diff --git a/src/kernels/webgl/mulmat_packed_gpu_test.ts b/src/kernels/webgl/mulmat_packed_gpu_test.ts index f71ff2231c..a090de1a45 100644 --- a/src/kernels/webgl/mulmat_packed_gpu_test.ts +++ b/src/kernels/webgl/mulmat_packed_gpu_test.ts @@ -16,8 +16,8 @@ */ // tslint:disable-next-line:max-line-length +import {MatrixOrientation} from '../../ops/matmul'; import {expectArraysClose, expectNumbersClose} from '../../test_util'; -import {MatrixOrientation} from '../types/matmul'; import {GPGPUContext} from './gpgpu_context'; import * as mulmat_packed_gpu from './mulmat_packed_gpu'; diff --git a/src/math.ts b/src/math.ts index 1df032069b..017a886740 100644 --- a/src/math.ts +++ b/src/math.ts @@ -227,9 +227,8 @@ export class NDArrayMath { let values: Tensor1D; let indices: Tensor1D; tidy('topK', () => { - values = ENV.engine.executeKernel('TopKValues', {inputs: {x}, args: {k}}); - indices = - ENV.engine.executeKernel('TopKIndices', {inputs: {x}, args: {k}}); + values = ENV.engine.runKernel(backend => backend.topKValues(x, k)); + indices = ENV.engine.runKernel(backend => backend.topKIndices(x, k)); return values; }); const result = {values, indices}; diff --git a/src/ops/concat.ts b/src/ops/concat.ts index 1e5427eb00..b856e1f2be 100644 --- a/src/ops/concat.ts +++ b/src/ops/concat.ts @@ -168,7 +168,7 @@ function concat2Tensors(a: T, b: T, axis: number): T { const der = (dy: Tensor2D) => { return {a: () => dy.slice(aBegin, aSize), b: () => dy.slice(bBegin, bSize)}; }; - const res = - ENV.engine.executeKernel('Concat', {inputs: {a: a2D, b: b2D}}, der); + const res = ENV.engine.runKernel( + backend => backend.concat(a2D, b2D), {a: a2D, b: b2D}, der); return res.reshape(outShape) as T; } 
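The recurring change in this series is visible above: call sites move from the string-keyed executeKernel(kernelName, config, grad) dispatch to runKernel(forwardFn, inputs, grad), where the forward closure receives the active backend directly. A minimal sketch of the new calling convention, modeled on the concat2Tensors change just above (the concatRows helper and its axis-0 slicing are illustrative assumptions, not part of the patch):

    // Sketch only: axis-0 concat of two 2D tensors via the new runKernel API.
    function concatRows(a: Tensor2D, b: Tensor2D): Tensor2D {
      // One gradient thunk per named input; each slices dy back to the
      // shape of its input.
      const der = (dy: Tensor2D) => ({
        a: () => dy.slice([0, 0], a.shape),
        b: () => dy.slice([a.shape[0], 0], b.shape)
      });
      return ENV.engine.runKernel(
          backend => backend.concat(a, b), {a, b}, der);
    }

The named-input map is what lets gradients flow on the tape; ops that need no backprop, such as the compare.ts changes, pass only the forward closure.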
diff --git a/src/ops/matmul.ts b/src/ops/matmul.ts index dc6a0b156b..992c4510af 100644 --- a/src/ops/matmul.ts +++ b/src/ops/matmul.ts @@ -17,11 +17,16 @@ import {doc} from '../doc'; import {ENV} from '../environment'; -import {MatrixOrientation} from '../kernels/types/matmul'; import {Scalar, Tensor1D, Tensor2D} from '../tensor'; import * as util from '../util'; import {operation} from './operation'; +/** @deprecated Use bools transposeA and transposeB when calling matmul() */ +export enum MatrixOrientation { + REGULAR, + TRANSPOSED +} + export class Ops { /** * Computes the dot product of two matrices, A * B. These must be matrices. diff --git a/src/ops/reduction_ops.ts b/src/ops/reduction_ops.ts index f5fd1c1407..26cbce2e46 100644 --- a/src/ops/reduction_ops.ts +++ b/src/ops/reduction_ops.ts @@ -115,8 +115,8 @@ export class Ops { reductionAxes = axis_util.getInnerMostAxes(reductionAxes.length, x.rank); } - let value = ENV.engine.executeKernel( - 'Sum', {inputs: {x: permutedX}, args: {axes: reductionAxes}}); + let value = ENV.engine.runKernel( + backend => backend.sum(permutedX, reductionAxes)); if (keepDims) { const newShape = axis_util.expandShapeToKeepDim(value.shape, axes); value = value.reshape(newShape); diff --git a/src/ops/reverse.ts b/src/ops/reverse.ts index 8a8cfdbf75..dd29b0cc52 100644 --- a/src/ops/reverse.ts +++ b/src/ops/reverse.ts @@ -109,8 +109,8 @@ export class Ops { } else { throw new Error(`Reverse for rank ${x.rank} is not yet implemented`); } - const res = ENV.engine.executeKernel( - 'Reverse4D', {inputs: {x: x4d}, args: {axis: axisCleaned}}); + const res = + ENV.engine.runKernel(backend => backend.reverse4D(x4d, axisCleaned)); return res.reshapeAs(x); } } From 3e3ac3aa9e9e6084ad2bc1b3f63744d6d013bb2e Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sat, 17 Feb 2018 21:51:11 -0500 Subject: [PATCH 09/14] save --- src/engine.ts | 20 ++- src/tape_types.ts | 52 ------- src/tape_util.ts | 96 ++++--------- src/tape_util_test.ts | 311 ++++++++++-------------------------------- 4 files changed, 107 insertions(+), 372 deletions(-) delete mode 100644 src/tape_types.ts diff --git a/src/engine.ts b/src/engine.ts index ac3e3b6882..1831cb9591 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -20,9 +20,8 @@ import {tidy} from './globals'; import {BackendTimingInfo, KernelBackend} from './kernels/backend'; import * as ops from './ops/ops'; import {Profiler} from './profiler'; -// tslint:disable-next-line:max-line-length -import {KernelNode, Tape, TapeNode, TapeNodeInputGradientTensors} from './tape_types'; import * as tape_util from './tape_util'; +import {NamedGradientMap, TapeNode} from './tape_util'; import {ScopeResultImmediate} from './tape_util'; import {DataId, Tensor, Tensor3D, Variable} from './tensor'; import {NamedTensorMap, NamedVariableMap, TypedArray} from './types'; @@ -70,7 +69,7 @@ export class Engine implements TensorManager { private numTensors = 0; private numDataBuffers = 0; - private activeTape: Tape; + private activeTape: TapeNode[]; private gradientScopeCount = 0; private customGradientDepth = 0; @@ -111,12 +110,10 @@ export class Engine implements TensorManager { const recordKernel = this.activeTape != null && this.customGradientDepth === 0; if (recordKernel) { - const evaluatedNode: KernelNode = { + const evaluatedNode: TapeNode = { id: this.nextTapeNodeId++, - type: 'kernel', name: kernelName, - kernel: kernelName, - inputAndArgs: {inputs}, + inputs, output: result, gradient: (dy: T) => backwardsFunc(dy, saved) }; @@ -192,18 +189,17 @@ export class 
Engine implements TensorManager { const gradient = (dy: Tensor) => { const res = gradientsFunc(dy); - const resMap: TapeNodeInputGradientTensors = {}; + const resMap: NamedGradientMap = {}; res.forEach((r, idx) => { resMap[idx] = () => r; }); return resMap; }; - const evaluatedNode: TapeNode = { + const evaluatedNode: TapeNode = { id: this.nextTapeNodeId++, - type: 'customGradient', - name, - inputAndArgs: {inputs: inputsMap}, + name: '', // TODO(smilkov): Figure out kernel name. + inputs: inputsMap, output: result, gradient }; diff --git a/src/tape_types.ts b/src/tape_types.ts deleted file mode 100644 index 506ce21cff..0000000000 --- a/src/tape_types.ts +++ /dev/null @@ -1,52 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {Tensor} from './tensor'; -import {NamedTensorMap} from './types'; - -export type Tape = Array>; -export type TapeNodeOutput = Tensor|NamedTensorMap; -export type TapeNodeType = 'kernel'|'customGradient'; - -export interface TapeNode { - id: number; - type: TapeNodeType; - name: string; - inputAndArgs: TapeNodeInputConfig; - - output: T; - gradient: (dy: Tensor|NamedTensorMap, y: T) => TapeNodeInputGradientTensors; -} - -export interface TapeNodeInputConfig { inputs: NamedTensorMap; } - -export type TapeNodeInputGradientTensors = { - [inputName: string]: () => Tensor; -}; - -// Kernel nodes -export interface KernelNode extends TapeNode { - kernel: string; - inputAndArgs: KernelInputConfig; - output: Tensor; -} - -export interface KernelInputConfig extends TapeNodeInputConfig { - inputs: NamedTensorMap; - // tslint:disable-next-line:no-any - args?: {[argName: string]: any}; -} diff --git a/src/tape_util.ts b/src/tape_util.ts index e35df3bc24..d6c130a83c 100644 --- a/src/tape_util.ts +++ b/src/tape_util.ts @@ -15,12 +15,21 @@ * ============================================================================= */ -import * as util from './util'; import {Tensor} from './tensor'; import {NamedTensorMap, RegularArray} from './types'; +import * as util from './util'; + +export interface TapeNode { + id: number; + name: string; + inputs: NamedTensorMap; + output: Tensor; + gradient: (dy: Tensor|NamedTensorMap) => NamedGradientMap; +} -// tslint:disable-next-line:max-line-length -import {Tape, TapeNode, TapeNodeInputConfig, TapeNodeOutput} from './tape_types'; +export type NamedGradientMap = { + [inputName: string]: () => Tensor; +}; /** * Computes a list of TapeNodes that connect x to y, filtering everything else @@ -30,7 +39,7 @@ import {Tape, TapeNode, TapeNodeInputConfig, TapeNodeOutput} from './tape_types' * @param y The output Tensor. */ export function getFilteredNodesXToY( - tape: Tape, xs: Tensor[], y: Tensor): Tape { + tape: TapeNode[], xs: Tensor[], y: Tensor): TapeNode[] { // Forward pass to compute all the nodes and Tensors that are transitively a // function of x. 
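  // Marking pass, per the loop below: tensorsFromX is seeded with the ids of
  // the requested xs; a node that consumes a marked tensor marks its output
  // tensor and records its own id in nodesFromX. The mirrored backward walk
  // further down fills nodesToY, and only nodes present in both sets survive
  // into the filtered tape.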
const tensorsFromX: {[tensorId: number]: boolean} = {}; @@ -41,7 +50,7 @@ export function getFilteredNodesXToY( for (let i = 0; i < tape.length; i++) { const node = tape[i]; - const nodeInputs = node.inputAndArgs.inputs; + const nodeInputs = node.inputs; for (const inputName in nodeInputs) { const input = nodeInputs[inputName]; @@ -49,14 +58,7 @@ export function getFilteredNodesXToY( let anyInputFromX = false; for (let j = 0; j < xs.length; j++) { if (tensorsFromX[input.id]) { - if (node.output instanceof Tensor) { - tensorsFromX[node.output.id] = true; - } else { - const keys = Object.keys(node.output); - for (const key of keys) { - tensorsFromX[node.output[key].id] = true; - } - } + tensorsFromX[node.output.id] = true; anyInputFromX = true; nodesFromX[node.id] = true; break; @@ -76,17 +78,10 @@ export function getFilteredNodesXToY( for (let i = tape.length - 1; i >= 0; i--) { const node = tape[i]; - const nodeInputs = node.inputAndArgs.inputs; + const nodeInputs = node.inputs; const outputs: Tensor[] = []; - if (node.output instanceof Tensor) { - outputs.push(node.output); - } else { - const keys = Object.keys(node.output); - for (const key of keys) { - outputs.push(node.output[key]); - } - } + outputs.push(node.output); // If any of the outputs lead to y, mark all of the inputs as leading to y. for (let j = 0; j < outputs.length; j++) { @@ -101,39 +96,28 @@ export function getFilteredNodesXToY( } // Return the paths that come from x and lead to y. - const filteredTape: Tape = []; + const filteredTape: TapeNode[] = []; for (let i = 0; i < tape.length; i++) { const node = tape[i]; if (nodesFromX[node.id] && nodesToY[node.id]) { // Prune the inputs from the node that aren't a function of x. const prunedInputs: {[inputName: string]: Tensor} = {}; - for (const inputName in node.inputAndArgs.inputs) { - const nodeInput = node.inputAndArgs.inputs[inputName]; + for (const inputName in node.inputs) { + const nodeInput = node.inputs[inputName]; if (tensorsFromX[nodeInput.id]) { prunedInputs[inputName] = nodeInput; } } let prunedOutputs: Tensor|{[outputName: string]: Tensor}; - if (node.output instanceof Tensor) { - // Nothing to prune if the output is just a single Tensor since the - // node would have been pruned. - prunedOutputs = node.output; - } else { - // Prune the outputs from the node that don't lead to y. - prunedOutputs = {}; - for (const outputName in node.output) { - const output = node.output[outputName]; - if (tensorsLeadToY[output.id]) { - prunedOutputs[outputName] = node.output[outputName]; - } - } - } + // Nothing to prune if the output is just a single Tensor since the + // node would have been pruned. + prunedOutputs = node.output; // Copy the node and overwrite inputsAndArgs to the pruned version. - const prunedNode = Object.assign({}, node) as TapeNode; - prunedNode.inputAndArgs = {inputs: prunedInputs}; + const prunedNode = Object.assign({}, node) as TapeNode; + prunedNode.inputs = prunedInputs; prunedNode.output = prunedOutputs; filteredTape.push(prunedNode); @@ -151,21 +135,12 @@ export function getFilteredNodesXToY( */ export function backpropagateGradients( tensorAccumulatedGradientMap: {[tensorId: number]: Tensor}, - filteredTape: Tape) { + filteredTape: TapeNode[]) { // Walk the tape backwards and keep a map of Tensor to its gradient. 
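  // Per iteration below: dy is the accumulated gradient of the node's output;
  // node.gradient(dy) returns one thunk per named input, and each computed dx
  // is shape-checked against its input before being accumulated into
  // tensorAccumulatedGradientMap.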
for (let i = filteredTape.length - 1; i >= 0; i--) { const node = filteredTape[i]; - let dy: Tensor|NamedTensorMap; - if (node.output instanceof Tensor) { - dy = tensorAccumulatedGradientMap[node.output.id]; - } else { - dy = {}; - const keys = Object.keys(node.output); - for (const key of keys) { - dy[key] = tensorAccumulatedGradientMap[node.output[key].id]; - } - } + const dy = tensorAccumulatedGradientMap[node.output.id]; if (node.gradient == null) { throw new Error( @@ -174,8 +149,8 @@ export function backpropagateGradients( } // Backprop dy through this node and accumulate gradients over the inputs. - const inputGradients = node.gradient(dy, node.output); - for (const inputName in node.inputAndArgs.inputs) { + const inputGradients = node.gradient(dy); + for (const inputName in node.inputs) { if (!(inputName in inputGradients)) { throw new Error( `Cannot backprop through input ${inputName}. ` + @@ -184,7 +159,7 @@ export function backpropagateGradients( // Call the gradient function. const dx = inputGradients[inputName](); - const x = node.inputAndArgs.inputs[inputName]; + const x = node.inputs[inputName]; if (!util.arraysEqual(dx.shape, x.shape)) { throw new Error( `Error in gradient for op ${node.name}. The gradient of input ` + @@ -228,14 +203,3 @@ export function extractTensorsFromScopeResult(result: ScopeResultImmediate): } return list; } - -export function stripUndefinedInputsFromInputConfig( - config: TapeNodeInputConfig): TapeNodeInputConfig { - const keys = Object.keys(config.inputs); - keys.forEach(key => { - if (config.inputs[key] == null) { - delete config.inputs[key]; - } - }); - return config; -} diff --git a/src/tape_util_test.ts b/src/tape_util_test.ts index 33194829ab..4f56948f36 100644 --- a/src/tape_util_test.ts +++ b/src/tape_util_test.ts @@ -17,12 +17,12 @@ */ import * as dl from './index'; -import {NamedTensorMap} from './types'; -import {CPU_ENVS, describeWithFlags, expectArraysClose} from './test_util'; -import {Scalar, Tensor} from './tensor'; -// tslint:disable-next-line:max-line-length -import {Tape, TapeNode, TapeNodeInputConfig, TapeNodeOutput} from './tape_types'; import * as tape_util from './tape_util'; +// tslint:disable-next-line:max-line-length +import {TapeNode} from './tape_util'; +import {Scalar, Tensor} from './tensor'; +import {CPU_ENVS, describeWithFlags, expectArraysClose} from './test_util'; +import {NamedTensorMap} from './types'; describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { it('getFilteredNodesXToY no paths from x to y', () => { @@ -32,24 +32,18 @@ describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { const intermediate2 = dl.scalar(0); const y = dl.scalar(2); - const tape: Tape = [ + const tape: TapeNode[] = [ { id: 0, - type: 'kernel', name: 'node0', - inputAndArgs: { - inputs: {x}, - }, + inputs: {x}, output: intermediate1, gradient: null }, { id: 1, - type: 'kernel', name: 'node1', - inputAndArgs: { - inputs: {intermediate2}, - }, + inputs: {intermediate2}, output: y, gradient: null } @@ -65,16 +59,8 @@ describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { const x = dl.scalar(1); const y = dl.scalar(2); - const tape: Tape = [{ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: y, - gradient: null - }]; + const tape: TapeNode[] = + [{id: 0, name: 'node0', inputs: {x}, output: y, gradient: null}]; const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); @@ -87,16 +73,8 @@ describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { const x1 = dl.scalar(1); 
const y = dl.scalar(2); - const tape: Tape = [{ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x0, x1}, - }, - output: y, - gradient: null - }]; + const tape: TapeNode[] = + [{id: 0, name: 'node0', inputs: {x0, x1}, output: y, gradient: null}]; const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x0, x1], y); @@ -110,31 +88,17 @@ describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { const x1 = dl.scalar(1); const y = dl.scalar(2); - const tape: Tape = [{ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x0, x1}, - }, - output: y, - gradient: null - }]; + const tape: TapeNode[] = [ + {id: 0, name: 'node0', inputs: {x0, x1}, output: y, gradient: null} + ]; const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x0], y); expect(filteredTapeNodes.length).toBe(1); // x1 input should be pruned, we don't ask for the gradient of x1. - expect(filteredTapeNodes[0]).toEqual({ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x0}, - }, - output: y, - gradient: null - }); + expect(filteredTapeNodes[0]) + .toEqual( + {id: 0, name: 'node0', inputs: {x0}, output: y, gradient: null}); }); it('getFilteredNodesXToY two operations x => intermediate => y', () => { @@ -142,24 +106,12 @@ describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { const intermediate = dl.scalar(0); const y = dl.scalar(2); - const tape: Tape = [ - { - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: intermediate, - gradient: null - }, + const tape: TapeNode[] = [ + {id: 0, name: 'node0', inputs: {x}, output: intermediate, gradient: null}, { id: 1, - type: 'kernel', name: 'node1', - inputAndArgs: { - inputs: {intermediate}, - }, + inputs: {intermediate}, output: y, gradient: null } @@ -180,24 +132,18 @@ describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { const intermediate = dl.scalar(4); const y = dl.scalar(2); - const tape: Tape = [ + const tape: TapeNode[] = [ { id: 0, - type: 'kernel', name: 'node0', - inputAndArgs: { - inputs: {x0, x1}, - }, + inputs: {x0, x1}, output: intermediate, gradient: null }, { id: 1, - type: 'kernel', name: 'node1', - inputAndArgs: { - inputs: {x2, intermediate}, - }, + inputs: {x2, intermediate}, output: y, gradient: null } @@ -215,27 +161,9 @@ describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { const orphan = dl.scalar(0); const y = dl.scalar(2); - const tape: Tape = [ - { - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: orphan, - gradient: null - }, - { - id: 1, - type: 'kernel', - name: 'node1', - inputAndArgs: { - inputs: {x}, - }, - output: y, - gradient: null - } + const tape: TapeNode[] = [ + {id: 0, name: 'node0', inputs: {x}, output: orphan, gradient: null}, + {id: 1, name: 'node1', inputs: {x}, output: y, gradient: null} ]; const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); @@ -250,31 +178,17 @@ describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { const orphan = dl.scalar(0); const y = dl.scalar(2); - const tape: Tape = [{ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x, orphan}, - }, - output: y, - gradient: null - }]; + const tape: TapeNode[] = [ + {id: 0, name: 'node0', inputs: {x, orphan}, output: y, gradient: null} + ]; const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); expect(filteredTapeNodes.length).toBe(1); // The orphan should be pruned from the node's input. 
- expect(filteredTapeNodes[0]).toEqual({ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: y, - gradient: null - }); + expect(filteredTapeNodes[0]) + .toEqual( + {id: 0, name: 'node0', inputs: {x}, output: y, gradient: null}); }); it('getFilteredNodesXToY x => {intermediate, orphan1} and ' + @@ -287,24 +201,18 @@ describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { const orphan3 = dl.scalar(3); const y = dl.scalar(2); - const tape: Array> = [ + const tape: TapeNode[] = [ { id: 0, - type: 'kernel', name: 'node0', - inputAndArgs: { - inputs: {x}, - }, + inputs: {x}, output: {orphan1, intermediate}, gradient: null }, { id: 1, - type: 'kernel', name: 'node1', - inputAndArgs: { - inputs: {intermediate, orphan2}, - }, + inputs: {intermediate, orphan2}, output: {y, orphan3}, gradient: null } @@ -316,22 +224,16 @@ describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { // The orphans should be pruned from inputs and outputs. expect(filteredTapeNodes[0]).toEqual({ id: 0, - type: 'kernel', name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: {intermediate}, + inputs: {x}, + output: intermediate, gradient: null }); expect(filteredTapeNodes[1]).toEqual({ id: 1, - type: 'kernel', name: 'node1', - inputAndArgs: { - inputs: {intermediate}, - }, - output: {y}, + inputs: {intermediate}, + output: y, gradient: null }); }); @@ -351,44 +253,26 @@ describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { const y = dl.scalar(7); const orphan2 = dl.scalar(8); - const tape: Tape = [ + const tape: TapeNode[] = [ { id: 0, - type: 'kernel', name: 'node0', - inputAndArgs: { - inputs: {x0}, - }, + inputs: {x0}, output: intermediate0, gradient: null }, { id: 1, - type: 'kernel', name: 'node1', - inputAndArgs: { - inputs: {x0}, - }, + inputs: {x0}, output: intermediate1, gradient: null }, - { - id: 2, - type: 'kernel', - name: 'node2', - inputAndArgs: { - inputs: {x0}, - }, - output: orphan0, - gradient: null - }, + {id: 2, name: 'node2', inputs: {x0}, output: orphan0, gradient: null}, { id: 3, - type: 'kernel', name: 'node3', - inputAndArgs: { - inputs: {intermediate0, intermediate1, x1, orphan1}, - }, + inputs: {intermediate0, intermediate1, x1, orphan1}, output: {y, orphan2}, gradient: null } @@ -404,12 +288,9 @@ describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { // inputs. 
expect(filteredTapeNodes[2]).toEqual({ id: 3, - type: 'kernel', name: 'node3', - inputAndArgs: { - inputs: {intermediate0, intermediate1, x1}, - }, - output: {y}, + inputs: {intermediate0, intermediate1, x1}, + output: y, gradient: null }); }); @@ -425,16 +306,8 @@ describeWithFlags('backpropagateGradients', CPU_ENVS, () => { const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; accumulatedGradientsMap[y.id] = dy; - const tape: Tape = [{ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: y, - gradient: null - }]; + const tape: TapeNode[] = + [{id: 0, name: 'node0', inputs: {x}, output: y, gradient: null}]; expect( () => tape_util.backpropagateGradients(accumulatedGradientsMap, tape)) @@ -450,15 +323,12 @@ describeWithFlags('backpropagateGradients', CPU_ENVS, () => { const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; accumulatedGradientsMap[y.id] = dy; - const tape: Tape = [{ + const tape: TapeNode[] = [{ id: 0, - type: 'kernel', name: 'node0', - inputAndArgs: { - inputs: {x}, - }, + inputs: {x}, output: y, - gradient: (dy: Scalar, y: Scalar) => { + gradient: (dy: Scalar) => { return {x: () => dy.add(dl.scalar(1))}; } }]; @@ -478,28 +348,22 @@ describeWithFlags('backpropagateGradients', CPU_ENVS, () => { const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; accumulatedGradientsMap[y.id] = dy; - const tape: Tape = [ + const tape: TapeNode[] = [ { id: 0, - type: 'kernel', name: 'node0', - inputAndArgs: { - inputs: {x}, - }, + inputs: {x}, output: intermediate, - gradient: (dy: Scalar, y: Scalar) => { + gradient: (dy: Scalar) => { return {x: () => dy.add(dl.scalar(1))}; } }, { id: 1, - type: 'kernel', name: 'node1', - inputAndArgs: { - inputs: {intermediate}, - }, + inputs: {intermediate}, output: y, - gradient: (dy: Scalar, y: Scalar) => { + gradient: (dy: Scalar) => { return {intermediate: () => dy.add(dl.scalar(1))}; } } @@ -522,40 +386,31 @@ describeWithFlags('backpropagateGradients', CPU_ENVS, () => { const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; accumulatedGradientsMap[y.id] = dy; - const tape: Tape = [ + const tape: TapeNode[] = [ { id: 0, - type: 'kernel', name: 'node0', - inputAndArgs: { - inputs: {x}, - }, + inputs: {x}, output: intermediate1, - gradient: (dy: Scalar, y: Scalar) => { + gradient: (dy: Scalar) => { return {x: () => dy.add(dl.scalar(1))}; } }, { id: 1, - type: 'kernel', name: 'node1', - inputAndArgs: { - inputs: {x}, - }, + inputs: {x}, output: intermediate2, - gradient: (dy: Scalar, y: Scalar) => { + gradient: (dy: Scalar) => { return {x: () => dy.add(dl.scalar(1))}; } }, { id: 2, - type: 'kernel', name: 'node2', - inputAndArgs: { - inputs: {intermediate1, intermediate2}, - }, + inputs: {intermediate1, intermediate2}, output: y, - gradient: (dy: Scalar, y: Scalar) => { + gradient: (dy: Scalar) => { return { intermediate1: () => dy.add(dl.scalar(1)), intermediate2: () => dy.add(dl.scalar(1)) @@ -582,28 +437,22 @@ describeWithFlags('backpropagateGradients', CPU_ENVS, () => { const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; accumulatedGradientsMap[y.id] = dy; - const tape: Array> = [ + const tape: TapeNode[] = [ { id: 0, - type: 'kernel', name: 'node0', - inputAndArgs: { - inputs: {x}, - }, + inputs: {x}, output: {intermediate1, intermediate2}, - gradient: (dy: NamedTensorMap, y: NamedTensorMap) => { + gradient: (dy: NamedTensorMap) => { return {x: () => dy['intermediate1'].mul(dy['intermediate2'])}; } }, { id: 1, - type: 'kernel', name: 'node1', - 
inputAndArgs: { - inputs: {intermediate1, intermediate2}, - }, + inputs: {intermediate1, intermediate2}, output: y, - gradient: (dy: Scalar, y: Scalar) => { + gradient: (dy: Scalar) => { return { intermediate1: () => dy.add(dl.scalar(2)), intermediate2: () => dy.add(dl.scalar(3)) @@ -641,26 +490,4 @@ describeWithFlags('extractTensorsFromScopeResult', CPU_ENVS, () => { expect(results).toEqual([x1, x2, x3]); }); - - it('pass through when all inputs are defined', () => { - const x1 = dl.scalar(1); - const x2 = dl.scalar(2); - const config: TapeNodeInputConfig = { - inputs: {x1, x2}, - }; - expect(tape_util.stripUndefinedInputsFromInputConfig(config)).toEqual({ - inputs: {x1, x2} - }); - }); - - it('strips undefined inputs', () => { - const x1 = dl.scalar(1); - const x4 = dl.scalar(2); - const config: TapeNodeInputConfig = { - inputs: {x1, x2: undefined, x3: undefined, x4}, - }; - expect(tape_util.stripUndefinedInputsFromInputConfig(config)).toEqual({ - inputs: {x1, x4} - }); - }); }); From ecee2a23961a0589ad0cd4a3e696bc253ec2e3e1 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sat, 17 Feb 2018 22:01:12 -0500 Subject: [PATCH 10/14] save --- src/tape_util_test.ts | 148 ------------------------------------------ 1 file changed, 148 deletions(-) diff --git a/src/tape_util_test.ts b/src/tape_util_test.ts index 4f56948f36..f8cbcf5616 100644 --- a/src/tape_util_test.ts +++ b/src/tape_util_test.ts @@ -18,11 +18,9 @@ import * as dl from './index'; import * as tape_util from './tape_util'; -// tslint:disable-next-line:max-line-length import {TapeNode} from './tape_util'; import {Scalar, Tensor} from './tensor'; import {CPU_ENVS, describeWithFlags, expectArraysClose} from './test_util'; -import {NamedTensorMap} from './types'; describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { it('getFilteredNodesXToY no paths from x to y', () => { @@ -190,110 +188,6 @@ describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { .toEqual( {id: 0, name: 'node0', inputs: {x}, output: y, gradient: null}); }); - - it('getFilteredNodesXToY x => {intermediate, orphan1} and ' + - '{orphan2, intermediate} => {y, orphan3}', - () => { - const x = dl.scalar(1); - const intermediate = dl.scalar(5); - const orphan1 = dl.scalar(1); - const orphan2 = dl.scalar(2); - const orphan3 = dl.scalar(3); - const y = dl.scalar(2); - - const tape: TapeNode[] = [ - { - id: 0, - name: 'node0', - inputs: {x}, - output: {orphan1, intermediate}, - gradient: null - }, - { - id: 1, - name: 'node1', - inputs: {intermediate, orphan2}, - output: {y, orphan3}, - gradient: null - } - ]; - - const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); - - expect(filteredTapeNodes.length).toBe(2); - // The orphans should be pruned from inputs and outputs. 
- expect(filteredTapeNodes[0]).toEqual({ - id: 0, - name: 'node0', - inputs: {x}, - output: intermediate, - gradient: null - }); - expect(filteredTapeNodes[1]).toEqual({ - id: 1, - name: 'node1', - inputs: {intermediate}, - output: y, - gradient: null - }); - }); - - it('getFilteredNodesXToY x0 => orphan0, ' + - 'x0 => intermediate0, x0 => intermediate1, ' + - '[intermediate0, intermediate1, x1, orphan1] => {y, orphan2}', - () => { - const x0 = dl.scalar(1); - const orphan0 = dl.scalar(2); - - const intermediate0 = dl.scalar(3); - const intermediate1 = dl.scalar(4); - - const x1 = dl.scalar(5); - const orphan1 = dl.scalar(6); - const y = dl.scalar(7); - const orphan2 = dl.scalar(8); - - const tape: TapeNode[] = [ - { - id: 0, - name: 'node0', - inputs: {x0}, - output: intermediate0, - gradient: null - }, - { - id: 1, - name: 'node1', - inputs: {x0}, - output: intermediate1, - gradient: null - }, - {id: 2, name: 'node2', inputs: {x0}, output: orphan0, gradient: null}, - { - id: 3, - name: 'node3', - inputs: {intermediate0, intermediate1, x1, orphan1}, - output: {y, orphan2}, - gradient: null - } - ]; - - const filteredTapeNodes = - tape_util.getFilteredNodesXToY(tape, [x0, x1], y); - - expect(filteredTapeNodes.length).toBe(3); - expect(filteredTapeNodes[0]).toEqual(tape[0]); - expect(filteredTapeNodes[1]).toEqual(tape[1]); - // The orphans should be removed and the orphan1 should be pruned from - // inputs. - expect(filteredTapeNodes[2]).toEqual({ - id: 3, - name: 'node3', - inputs: {intermediate0, intermediate1, x1}, - output: y, - gradient: null - }); - }); }); describeWithFlags('backpropagateGradients', CPU_ENVS, () => { @@ -424,48 +318,6 @@ describeWithFlags('backpropagateGradients', CPU_ENVS, () => { // dx = dy + 1 + 1 + 1 + 1 + 1 expectArraysClose(accumulatedGradientsMap[x.id], [dy.dataSync()[0] + 5]); }); - - it('basic backprop with a multi-output split node accumulates gradients', - () => { - const x = dl.scalar(0); - const intermediate1 = dl.scalar(1); - const intermediate2 = dl.scalar(2); - const y = dl.scalar(3); - - const dy = dl.scalar(1); - - const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; - accumulatedGradientsMap[y.id] = dy; - - const tape: TapeNode[] = [ - { - id: 0, - name: 'node0', - inputs: {x}, - output: {intermediate1, intermediate2}, - gradient: (dy: NamedTensorMap) => { - return {x: () => dy['intermediate1'].mul(dy['intermediate2'])}; - } - }, - { - id: 1, - name: 'node1', - inputs: {intermediate1, intermediate2}, - output: y, - gradient: (dy: Scalar) => { - return { - intermediate1: () => dy.add(dl.scalar(2)), - intermediate2: () => dy.add(dl.scalar(3)) - }; - } - } - ]; - - tape_util.backpropagateGradients(accumulatedGradientsMap, tape); - - expectArraysClose( - accumulatedGradientsMap[x.id], [(dy.get() + 2) * (dy.get() + 3)]); - }); }); describeWithFlags('extractTensorsFromScopeResult', CPU_ENVS, () => { From ef1a310bc8efcd282c72eae6d50881e9ca3f9fad Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sat, 17 Feb 2018 22:18:13 -0500 Subject: [PATCH 11/14] save --- src/engine.ts | 18 +++++++++++------- src/tracking.ts | 5 +++-- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/engine.ts b/src/engine.ts index 1831cb9591..646b0b236a 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -30,6 +30,7 @@ import * as util from './util'; interface ScopeState { keep: Tensor[]; track: Tensor[]; + name?: string; } export type ForwardFunc = @@ -93,13 +94,13 @@ export class Engine implements TensorManager { backwardsFunc?: (dy: T, saved: Tensor[]) 
=> {[P in keyof I]: () => I[P]}, ): T { let result: T; - // TODO(smilkov): Figure out kernel name. - const kernelName = ''; const saved: Tensor[] = []; const saveFunc = (x: T): T => { saved.push(x); return x; }; + const kernelName = this.activeScope.name; + if (!ENV.get('DEBUG')) { result = forwardFunc(this.backend, saveFunc); } else { @@ -198,7 +199,7 @@ export class Engine implements TensorManager { const evaluatedNode: TapeNode = { id: this.nextTapeNodeId++, - name: '', // TODO(smilkov): Figure out kernel name. + name: this.activeScope.name, inputs: inputsMap, output: result, gradient @@ -220,7 +221,7 @@ export class Engine implements TensorManager { * Start a scope. Use this with endScope() to achieve the same functionality * as scope() without the need for a function closure. */ - startScope(gradientsMode = false) { + startScope(name?: string, gradientsMode = false) { if (gradientsMode && this.gradientScopeCount === 0) { this.activeTape = []; } @@ -228,9 +229,12 @@ export class Engine implements TensorManager { this.gradientScopeCount++; } - const newScopeArrays: ScopeState = {keep: [], track: []}; - this.scopeStack.push(newScopeArrays); - this.activeScope = newScopeArrays; + const scopeInfo: ScopeState = {keep: [], track: []}; + if (name) { + scopeInfo.name = name; + } + this.scopeStack.push(scopeInfo); + this.activeScope = scopeInfo; } /** diff --git a/src/tracking.ts b/src/tracking.ts index fb383f7388..7428d148c3 100644 --- a/src/tracking.ts +++ b/src/tracking.ts @@ -64,13 +64,13 @@ export class Tracking { @doc({heading: 'Performance', subheading: 'Memory'}) static tidy( nameOrFn: string|ScopeFn, fn?: ScopeFn, gradMode = false): T { + let name = null; if (fn == null) { // Called with only 1 argument. if (typeof nameOrFn !== 'function') { throw new Error('Please provide a function to dl.tidy()'); } fn = nameOrFn; - nameOrFn = ''; } else { // Called with 2 arguments. if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) { @@ -83,9 +83,10 @@ export class Tracking { 'When calling with two arguments, the 2nd argument ' + 'to dl.tidy() must be a function'); } + name = nameOrFn as string; // TODO(nsthorat,smilkov): Do operation logging and performance profiling. 
} - ENV.engine.startScope(gradMode); + ENV.engine.startScope(name, gradMode); const result = fn(); if (result instanceof Promise) { From 680f4f1dc42a5fc2157c5bf9c20a6a2ad69bb540 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sat, 17 Feb 2018 22:22:31 -0500 Subject: [PATCH 12/14] save --- src/engine.ts | 6 +++--- src/gradients.ts | 2 +- src/math.ts | 2 +- src/{tape_util.ts => tape.ts} | 0 src/tape_util_test.ts | 4 ++-- src/tracking.ts | 2 +- 6 files changed, 8 insertions(+), 8 deletions(-) rename src/{tape_util.ts => tape.ts} (100%) diff --git a/src/engine.ts b/src/engine.ts index 646b0b236a..dd40bcbc0b 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -20,9 +20,9 @@ import {tidy} from './globals'; import {BackendTimingInfo, KernelBackend} from './kernels/backend'; import * as ops from './ops/ops'; import {Profiler} from './profiler'; -import * as tape_util from './tape_util'; -import {NamedGradientMap, TapeNode} from './tape_util'; -import {ScopeResultImmediate} from './tape_util'; +import * as tape_util from './tape'; +import {NamedGradientMap, TapeNode} from './tape'; +import {ScopeResultImmediate} from './tape'; import {DataId, Tensor, Tensor3D, Variable} from './tensor'; import {NamedTensorMap, NamedVariableMap, TypedArray} from './types'; import * as util from './util'; diff --git a/src/gradients.ts b/src/gradients.ts index 4840867b75..970cb73a77 100644 --- a/src/gradients.ts +++ b/src/gradients.ts @@ -19,7 +19,7 @@ import {doc} from './doc'; import {CustomGradientFunc} from './engine'; import {ENV} from './environment'; import {tidy} from './globals'; -import {ScopeFn, ScopeResult} from './tape_util'; +import {ScopeFn, ScopeResult} from './tape'; import {Scalar, Tensor, Variable} from './tensor'; import {NamedTensorMap} from './types'; import * as util from './util'; diff --git a/src/math.ts b/src/math.ts index 017a886740..ef1d5c9a51 100644 --- a/src/math.ts +++ b/src/math.ts @@ -36,7 +36,7 @@ import * as slice from './ops/slice'; import * as softmax_ops from './ops/softmax'; import * as transpose from './ops/transpose'; import * as unary_ops from './ops/unary_ops'; -import {ScopeResult} from './tape_util'; +import {ScopeResult} from './tape'; import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from './tensor'; import {Tracking} from './tracking'; import {Rank} from './types'; diff --git a/src/tape_util.ts b/src/tape.ts similarity index 100% rename from src/tape_util.ts rename to src/tape.ts diff --git a/src/tape_util_test.ts b/src/tape_util_test.ts index f8cbcf5616..c937eb2391 100644 --- a/src/tape_util_test.ts +++ b/src/tape_util_test.ts @@ -17,8 +17,8 @@ */ import * as dl from './index'; -import * as tape_util from './tape_util'; -import {TapeNode} from './tape_util'; +import * as tape_util from './tape'; +import {TapeNode} from './tape'; import {Scalar, Tensor} from './tensor'; import {CPU_ENVS, describeWithFlags, expectArraysClose} from './test_util'; diff --git a/src/tracking.ts b/src/tracking.ts index 7428d148c3..5ae889adc7 100644 --- a/src/tracking.ts +++ b/src/tracking.ts @@ -19,7 +19,7 @@ import {doc} from './doc'; import {TimingInfo} from './engine'; import {ENV} from './environment'; // tslint:disable-next-line:max-line-length -import {ScopeFn, ScopeResult, ScopeResultImmediate} from './tape_util'; +import {ScopeFn, ScopeResult, ScopeResultImmediate} from './tape'; import {Tensor} from './tensor'; export class Tracking { From f2eea0fc13817570427aa8e07ed4bcdb25cc80a9 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sun, 18 Feb 2018 15:30:38 -0500 
Subject: [PATCH 13/14] save --- src/engine.ts | 33 +++++++++++++++---------- src/tape.ts | 18 +++++++------- src/{tape_util_test.ts => tape_test.ts} | 1 - 3 files changed, 29 insertions(+), 23 deletions(-) rename src/{tape_util_test.ts => tape_test.ts} (99%) diff --git a/src/engine.ts b/src/engine.ts index dd40bcbc0b..dc0c02d0fe 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -20,7 +20,8 @@ import {tidy} from './globals'; import {BackendTimingInfo, KernelBackend} from './kernels/backend'; import * as ops from './ops/ops'; import {Profiler} from './profiler'; -import * as tape_util from './tape'; +// tslint:disable-next-line:max-line-length +import {backpropagateGradients, extractTensorsFromScopeResult, getFilteredNodesXToY} from './tape'; import {NamedGradientMap, TapeNode} from './tape'; import {ScopeResultImmediate} from './tape'; import {DataId, Tensor, Tensor3D, Variable} from './tensor'; @@ -33,6 +34,10 @@ interface ScopeState { name?: string; } +/** + * A function that computes an output. The save function is for saving tensors + * computed in the forward pass, that we need in the backwards pass. + */ export type ForwardFunc = (backend: KernelBackend, save?: (tensor: S) => S) => T; @@ -111,14 +116,18 @@ export class Engine implements TensorManager { const recordKernel = this.activeTape != null && this.customGradientDepth === 0; if (recordKernel) { - const evaluatedNode: TapeNode = { + const tapeNode: TapeNode = { id: this.nextTapeNodeId++, name: kernelName, - inputs, - output: result, - gradient: (dy: T) => backwardsFunc(dy, saved) + output: result }; - this.activeTape.push(evaluatedNode); + if (inputs != null) { + tapeNode.inputs = inputs; + } + if (backwardsFunc != null) { + tapeNode.gradient = (dy: T) => backwardsFunc(dy, saved); + } + this.activeTape.push(tapeNode); } return result; } @@ -197,14 +206,14 @@ export class Engine implements TensorManager { return resMap; }; - const evaluatedNode: TapeNode = { + const tapeNode: TapeNode = { id: this.nextTapeNodeId++, name: this.activeScope.name, inputs: inputsMap, output: result, gradient }; - this.activeTape.push(evaluatedNode); + this.activeTape.push(tapeNode); } keep(result: T): T { @@ -250,8 +259,7 @@ export class Engine implements TensorManager { } let tensorsToKeep = this.activeScope.keep; - const tensorsToTrackInParent = - tape_util.extractTensorsFromScopeResult(result); + const tensorsToTrackInParent = extractTensorsFromScopeResult(result); tensorsToKeep = tensorsToKeep.concat(tensorsToTrackInParent); // Dispose the arrays tracked in this scope. @@ -302,8 +310,7 @@ export class Engine implements TensorManager { y instanceof Tensor, 'The result y returned by f() must be a tensor.'); // Filter out the nodes that don't connect x => y. - const filteredTape = - tape_util.getFilteredNodesXToY(this.activeTape, xs, y); + const filteredTape = getFilteredNodesXToY(this.activeTape, xs, y); if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) { throw new Error( 'Cannot compute gradient of y=f(x) with respect to x. Make sure ' + @@ -315,7 +322,7 @@ export class Engine implements TensorManager { accumulatedGradientMap[y.id] = (dy == null) ? ops.onesLike(y) : dy; // Backprop gradients through the filtered nodes. 
- tape_util.backpropagateGradients(accumulatedGradientMap, filteredTape); + backpropagateGradients(accumulatedGradientMap, filteredTape); const grads = xs.map(x => accumulatedGradientMap[x.id]); return {value: y, grads}; diff --git a/src/tape.ts b/src/tape.ts index d6c130a83c..a83d2a1256 100644 --- a/src/tape.ts +++ b/src/tape.ts @@ -22,9 +22,10 @@ import * as util from './util'; export interface TapeNode { id: number; name: string; - inputs: NamedTensorMap; output: Tensor; - gradient: (dy: Tensor|NamedTensorMap) => NamedGradientMap; + // Optional fields, defined only for ops with a gradient impl. + inputs?: NamedTensorMap; + gradient?: (dy: Tensor|NamedTensorMap) => NamedGradientMap; } export type NamedGradientMap = { @@ -51,7 +52,11 @@ export function getFilteredNodesXToY( for (let i = 0; i < tape.length; i++) { const node = tape[i]; const nodeInputs = node.inputs; - + if (nodeInputs == null) { + throw new Error( + `${node.name} is missing a gradient implementation. ` + + `Failed to back-propagate.`); + } for (const inputName in nodeInputs) { const input = nodeInputs[inputName]; @@ -110,15 +115,10 @@ export function getFilteredNodesXToY( } } - let prunedOutputs: Tensor|{[outputName: string]: Tensor}; - // Nothing to prune if the output is just a single Tensor since the - // node would have been pruned. - prunedOutputs = node.output; - // Copy the node and overwrite inputs with the pruned version. const prunedNode = Object.assign({}, node) as TapeNode; prunedNode.inputs = prunedInputs; - prunedNode.output = prunedOutputs; + prunedNode.output = node.output; filteredTape.push(prunedNode); } diff --git a/src/tape_util_test.ts b/src/tape_test.ts similarity index 99% rename from src/tape_util_test.ts rename to src/tape_test.ts index c937eb2391..347fa13203 100644 --- a/src/tape_util_test.ts +++ b/src/tape_test.ts @@ -1,4 +1,3 @@ - /** * @license * Copyright 2017 Google Inc. All Rights Reserved.
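Note: to make the tape change in this patch concrete, below is a minimal sketch of a hand-built tape under the new TapeNode shape and a backprop pass over it. This is an illustration only, not part of the patch series: the op name 'square' and the literal values are invented, while TapeNode, backpropagateGradients, dl.scalar and the chainable mul/dataSync methods are the ones defined in src/tape.ts and exercised in tape_test.ts.

import * as dl from './index';
import {backpropagateGradients, TapeNode} from './tape';
import {Scalar, Tensor} from './tensor';

// One recorded node computing y = x * x. Under the new format, `name` is the
// enclosing tidy() scope name; `inputs` and `gradient` are optional and set
// only for ops that have a gradient implementation.
const x = dl.scalar(3);
const y = dl.scalar(9);
const tape: TapeNode[] = [{
  id: 0,
  name: 'square',
  inputs: {x},
  output: y,
  // d(x*x)/dx = 2x.
  gradient: (dy: Scalar) => ({x: () => dy.mul(x).mul(dl.scalar(2))})
}];

// Seed dL/dy = 1 and flow it back to x; dL/dx = 2 * 3 = 6.
const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {};
accumulatedGradientsMap[y.id] = dl.scalar(1);
backpropagateGradients(accumulatedGradientsMap, tape);
console.log(accumulatedGradientsMap[x.id].dataSync()); // Float32Array [6]

The same shape is what Engine.runKernel records: the forward result becomes `output`, and `gradient` closes over whatever tensors were saved during the forward pass, which is why both optional fields can be omitted for ops without a gradient implementation.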
From 89cab055a9f592ccd27dd8a29e01ee7db064767d Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Sun, 18 Feb 2018 20:52:23 -0500 Subject: [PATCH 14/14] save --- src/engine.ts | 6 +++--- src/profiler.ts | 10 +++++----- src/profiler_test.ts | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/engine.ts b/src/engine.ts index dc0c02d0fe..36bb788a01 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -104,13 +104,13 @@ export class Engine implements TensorManager { saved.push(x); return x; }; - const kernelName = this.activeScope.name; + const scopeName = this.activeScope.name; if (!ENV.get('DEBUG')) { result = forwardFunc(this.backend, saveFunc); } else { result = this.profiler.profileKernel( - kernelName, () => forwardFunc(this.backend, saveFunc)); + scopeName, () => forwardFunc(this.backend, saveFunc)); } const recordKernel = @@ -118,7 +118,7 @@ export class Engine implements TensorManager { if (recordKernel) { const tapeNode: TapeNode = { id: this.nextTapeNodeId++, - name: kernelName, + name: scopeName, output: result }; if (inputs != null) { diff --git a/src/profiler.ts b/src/profiler.ts index f1f9e6cfbe..6b78c1c779 100644 --- a/src/profiler.ts +++ b/src/profiler.ts @@ -27,7 +27,7 @@ export class Profiler { } } - profileKernel(kernelName: string, f: () => T): T { + profileKernel(name: string, f: () => T): T { let result: T; const holdResultWrapperFn = () => { result = f(); @@ -35,10 +35,10 @@ export class Profiler { const timer = this.backendTimer.time(holdResultWrapperFn); const vals = result.dataSync(); - util.checkForNaN(vals, result.dtype, kernelName); + util.checkForNaN(vals, result.dtype, name); timer.then(timing => { - this.logger.logKernelProfile(kernelName, result, vals, timing.kernelMs); + this.logger.logKernelProfile(name, result, vals, timing.kernelMs); }); return result as T; @@ -47,9 +47,9 @@ export class Profiler { export class Logger { logKernelProfile( - kernelName: string, result: Tensor, vals: TypedArray, timeMs: number) { + name: string, result: Tensor, vals: TypedArray, timeMs: number) { const time = util.rightPad(`${timeMs}ms`, 9); - const paddedName = util.rightPad(kernelName, 25); + const paddedName = util.rightPad(name, 25); const rank = result.rank; const size = result.size; const shape = util.rightPad(result.shape.toString(), 14); diff --git a/src/profiler_test.ts b/src/profiler_test.ts index 9092cc0f93..9d3d368810 100644 --- a/src/profiler_test.ts +++ b/src/profiler_test.ts @@ -36,7 +36,7 @@ class TestBackendTimer implements BackendTimer { class TestLogger extends Logger { logKernelProfile( - kernelName: string, result: Tensor, vals: TypedArray, timeMs: number) {} + name: string, result: Tensor, vals: TypedArray, timeMs: number) {} } describe('profiler.Profiler', () => {