diff --git a/src/engine.ts b/src/engine.ts index 1d2fa86eb8..36bb788a01 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -18,24 +18,29 @@ import {ENV} from './environment'; import {tidy} from './globals'; import {BackendTimingInfo, KernelBackend} from './kernels/backend'; -import * as kernel_registry from './kernels/kernel_registry'; -import {KernelConfigRegistry} from './kernels/kernel_registry'; import * as ops from './ops/ops'; import {Profiler} from './profiler'; // tslint:disable-next-line:max-line-length -import {KernelNode, Tape, TapeNode, TapeNodeInputGradientTensors} from './tape_types'; -import * as tape_util from './tape_util'; -import {ScopeResultImmediate} from './tape_util'; +import {backpropagateGradients, extractTensorsFromScopeResult, getFilteredNodesXToY} from './tape'; +import {NamedGradientMap, TapeNode} from './tape'; +import {ScopeResultImmediate} from './tape'; import {DataId, Tensor, Tensor3D, Variable} from './tensor'; import {NamedTensorMap, NamedVariableMap, TypedArray} from './types'; -import {Rank} from './types'; import * as util from './util'; interface ScopeState { keep: Tensor[]; track: Tensor[]; + name?: string; } +/** + * A function that computes an output. The save function is for saving tensors + * computed in the forward pass, that we need in the backwards pass. + */ +export type ForwardFunc = + (backend: KernelBackend, save?: (tensor: S) => S) => T; + /** * @docalias (a: Tensor, b: Tensor,...) => { * value: Tensor, @@ -70,7 +75,7 @@ export class Engine implements TensorManager { private numTensors = 0; private numDataBuffers = 0; - private activeTape: Tape; + private activeTape: TapeNode[]; private gradientScopeCount = 0; private customGradientDepth = 0; @@ -88,39 +93,42 @@ export class Engine implements TensorManager { this.profiler = new Profiler(backend); } - executeKernel, C - extends KernelConfigRegistry[K]['inputAndArgs']>( - kernelName: K, config: C, grad?: KernelConfigRegistry[K]['gradient']): - KernelConfigRegistry[K]['output'] { - let result: KernelConfigRegistry[K]['output']; + runKernel( + forwardFunc: ForwardFunc, + inputs?: I, + backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]}, + ): T { + let result: T; + const saved: Tensor[] = []; + const saveFunc = (x: T): T => { + saved.push(x); + return x; + }; + const scopeName = this.activeScope.name; + if (!ENV.get('DEBUG')) { - // NOTE: This isn't pulled out into a separate function to so that we - // keep a shallow stack trace. - result = kernel_registry.executeKernel(this.backend, kernelName, config); + result = forwardFunc(this.backend, saveFunc); } else { result = this.profiler.profileKernel( - kernelName, - () => - kernel_registry.executeKernel(this.backend, kernelName, config)); + scopeName, () => forwardFunc(this.backend, saveFunc)); } const recordKernel = this.activeTape != null && this.customGradientDepth === 0; if (recordKernel) { - config = tape_util.stripUndefinedInputsFromInputConfig(config) as C; - - const evaluatedNode: KernelNode = { + const tapeNode: TapeNode = { id: this.nextTapeNodeId++, - type: 'kernel', - name: `kernel: ${kernelName}`, - kernel: kernelName, - inputAndArgs: config, - output: result, - gradient: grad + name: scopeName, + output: result }; - this.activeTape.push(evaluatedNode); + if (inputs != null) { + tapeNode.inputs = inputs; + } + if (backwardsFunc != null) { + tapeNode.gradient = (dy: T) => backwardsFunc(dy, saved); + } + this.activeTape.push(tapeNode); } - return result; } @@ -191,22 +199,21 @@ export class Engine implements TensorManager { const gradient = (dy: Tensor) => { const res = gradientsFunc(dy); - const resMap: TapeNodeInputGradientTensors = {}; + const resMap: NamedGradientMap = {}; res.forEach((r, idx) => { resMap[idx] = () => r; }); return resMap; }; - const evaluatedNode: TapeNode = { + const tapeNode: TapeNode = { id: this.nextTapeNodeId++, - type: 'customGradient', - name, - inputAndArgs: {inputs: inputsMap}, + name: this.activeScope.name, + inputs: inputsMap, output: result, gradient }; - this.activeTape.push(evaluatedNode); + this.activeTape.push(tapeNode); } keep(result: T): T { @@ -223,7 +230,7 @@ export class Engine implements TensorManager { * Start a scope. Use this with endScope() to achieve the same functionality * as scope() without the need for a function closure. */ - startScope(gradientsMode = false) { + startScope(name?: string, gradientsMode = false) { if (gradientsMode && this.gradientScopeCount === 0) { this.activeTape = []; } @@ -231,9 +238,12 @@ export class Engine implements TensorManager { this.gradientScopeCount++; } - const newScopeArrays: ScopeState = {keep: [], track: []}; - this.scopeStack.push(newScopeArrays); - this.activeScope = newScopeArrays; + const scopeInfo: ScopeState = {keep: [], track: []}; + if (name) { + scopeInfo.name = name; + } + this.scopeStack.push(scopeInfo); + this.activeScope = scopeInfo; } /** @@ -249,8 +259,7 @@ export class Engine implements TensorManager { } let tensorsToKeep = this.activeScope.keep; - const tensorsToTrackInParent = - tape_util.extractTensorsFromScopeResult(result); + const tensorsToTrackInParent = extractTensorsFromScopeResult(result); tensorsToKeep = tensorsToKeep.concat(tensorsToTrackInParent); // Dispose the arrays tracked in this scope. @@ -301,8 +310,7 @@ export class Engine implements TensorManager { y instanceof Tensor, 'The result y returned by f() must be a tensor.'); // Filter out the nodes that don't connect x => y. - const filteredTape = - tape_util.getFilteredNodesXToY(this.activeTape, xs, y); + const filteredTape = getFilteredNodesXToY(this.activeTape, xs, y); if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) { throw new Error( 'Cannot compute gradient of y=f(x) with respect to x. Make sure ' + @@ -314,7 +322,7 @@ export class Engine implements TensorManager { accumulatedGradientMap[y.id] = (dy == null) ? ops.onesLike(y) : dy; // Backprop gradients through the filtered nodes. - tape_util.backpropagateGradients(accumulatedGradientMap, filteredTape); + backpropagateGradients(accumulatedGradientMap, filteredTape); const grads = xs.map(x => accumulatedGradientMap[x.id]); return {value: y, grads}; diff --git a/src/gradients.ts b/src/gradients.ts index 4840867b75..970cb73a77 100644 --- a/src/gradients.ts +++ b/src/gradients.ts @@ -19,7 +19,7 @@ import {doc} from './doc'; import {CustomGradientFunc} from './engine'; import {ENV} from './environment'; import {tidy} from './globals'; -import {ScopeFn, ScopeResult} from './tape_util'; +import {ScopeFn, ScopeResult} from './tape'; import {Scalar, Tensor, Variable} from './tensor'; import {NamedTensorMap} from './types'; import * as util from './util'; diff --git a/src/index.ts b/src/index.ts index 4144ea1c2f..e77c4f6e63 100644 --- a/src/index.ts +++ b/src/index.ts @@ -42,11 +42,11 @@ export {CostReduction, FeedEntry, Session} from './graph/session'; export {MathBackendCPU, NDArrayMathCPU} from './kernels/backend_cpu'; // tslint:disable-next-line:max-line-length export {MathBackendWebGL, NDArrayMathGPU, WebGLTimingInfo} from './kernels/backend_webgl'; -export {MatrixOrientation} from './kernels/types/matmul'; export {GPGPUContext} from './kernels/webgl/gpgpu_context'; export {NDArrayMath} from './math'; export {Model} from './model'; export {LSTMCell} from './ops/lstm'; +export {MatrixOrientation} from './ops/matmul'; export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer'; export {AdagradOptimizer} from './optimizers/adagrad_optimizer'; export {AdamOptimizer} from './optimizers/adam_optimizer'; diff --git a/src/kernels/backend.ts b/src/kernels/backend.ts index 84c6c80953..ea208d2175 100644 --- a/src/kernels/backend.ts +++ b/src/kernels/backend.ts @@ -19,7 +19,7 @@ import {Conv2DInfo} from '../ops/conv_util'; // tslint:disable-next-line:max-line-length import {DataId, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'; -import {DataType, Rank, TypedArray} from '../types'; +import {DataType, TypedArray} from '../types'; // Required information for all backends. export interface BackendTimingInfo { kernelMs: number; } @@ -95,7 +95,7 @@ export interface KernelBackend extends TensorStorage, BackendTimer { greater(a: Tensor, b: Tensor): Tensor; greaterEqual(a: Tensor, b: Tensor): Tensor; - logicalNot(a: Tensor): Tensor; + logicalNot(a: T): T; logicalAnd(a: Tensor, b: Tensor): Tensor; logicalOr(a: Tensor, b: Tensor): Tensor; logicalXor(a: Tensor, b: Tensor): Tensor; @@ -128,7 +128,7 @@ export interface KernelBackend extends TensorStorage, BackendTimer { leakyRelu(x: T, alpha: number): T; prelu(x: T, alpha: T): T; preluDer(x: T, alpha: T): T; - int(x: Tensor): Tensor; + int(x: T): T; clip(x: T, min: number, max: number): T; @@ -183,8 +183,8 @@ export interface KernelBackend extends TensorStorage, BackendTimer { batchNormalization4D( x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D, - varianceEpsilon: number, scale?: Tensor4D|Tensor1D, - offset?: Tensor4D|Tensor1D): Tensor4D; + varianceEpsilon: number, scale: Tensor4D|Tensor1D, + offset: Tensor4D|Tensor1D): Tensor4D; localResponseNormalization4D( x: Tensor4D, radius: number, bias: number, alpha: number, beta: number, diff --git a/src/kernels/backend_cpu.ts b/src/kernels/backend_cpu.ts index 9f85e7b322..ee4f02a011 100644 --- a/src/kernels/backend_cpu.ts +++ b/src/kernels/backend_cpu.ts @@ -790,13 +790,13 @@ export class MathBackendCPU implements KernelBackend { return Tensor.make(x.shape, {values: resultValues}) as T; } - int(x: Tensor): Tensor { + int(x: T): T { const resultValues = new Int32Array(x.size); const values = x.dataSync(); for (let i = 0; i < values.length; ++i) { resultValues[i] = values[i]; } - return Tensor.make(x.shape, {values: resultValues}, 'int32') as Tensor; + return Tensor.make(x.shape, {values: resultValues}, 'int32'); } sigmoid(x: T): T { diff --git a/src/kernels/backend_webgl.ts b/src/kernels/backend_webgl.ts index 4d3c5e413b..892d6e0f77 100644 --- a/src/kernels/backend_webgl.ts +++ b/src/kernels/backend_webgl.ts @@ -24,10 +24,8 @@ import * as reduce_util from '../ops/reduce_util'; // tslint:disable-next-line:max-line-length import {DataId, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'; import * as types from '../types'; -// tslint:disable-next-line:max-line-length -import {DataType, DataTypeMap, Rank, RecursiveArray, TypedArray} from '../types'; +import {DataType, DataTypeMap, RecursiveArray, TypedArray} from '../types'; import * as util from '../util'; - import {KernelBackend} from './backend'; import {ArgMinMaxProgram} from './webgl/argminmax_gpu'; import {AvgPool2DBackpropProgram} from './webgl/avg_pool_backprop_gpu'; @@ -705,10 +703,10 @@ export class MathBackendWebGL implements KernelBackend { return this.compileAndRun(program, [a, b]) as T; } - int(x: Tensor): Tensor { + int(x: T): T { const program = new UnaryOpProgram(x.shape, unary_op.TO_INT); const output = this.makeOutputArray(program.outputShape, 'int32'); - return this.compileAndRun(program, [x], output) as Tensor; + return this.compileAndRun(program, [x], output) as T; } clip(x: T, min: number, max: number): T { diff --git a/src/kernels/kernel_registry.ts b/src/kernels/kernel_registry.ts deleted file mode 100644 index 0383cb067e..0000000000 --- a/src/kernels/kernel_registry.ts +++ /dev/null @@ -1,435 +0,0 @@ -/** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import * as ops from '../ops/ops'; -import {Tensor} from '../tensor'; -import {Rank} from '../types'; -import * as util from '../util'; - -import {KernelBackend} from './backend'; -import {ArgMaxNode, ArgMinNode} from './types/argminmax'; -import {BatchNorm4DNode} from './types/batchnorm'; -import {BinaryNode} from './types/binary'; -import {CastNode} from './types/cast'; -// tslint:disable-next-line:max-line-length -import {ConcatNode} from './types/concat'; -// tslint:disable-next-line:max-line-length -import {Conv2DDerFilterNode, Conv2DDerInputNode, Conv2DNode, DepthwiseConv2DNode} from './types/conv'; -import {GatherNode} from './types/gather'; -import {EqualNode, LogicalNode, WhereNode} from './types/logical'; -import {LRN4DNode} from './types/lrn'; -import {MatMulNode} from './types/matmul'; -import {MaximumNode, MaxNode, MinimumNode, MinNode} from './types/minmax'; -import {MultinomialNode} from './types/multinomial'; -import {OneHotNode} from './types/onehot'; -import {Pad1DNode, Pad2DNode} from './types/pad'; -// tslint:disable-next-line:max-line-length -import {PoolBackpropNode, PoolNode} from './types/pool'; -import {PowNode} from './types/pow'; -import {PReLUNode} from './types/prelu'; -import {ReshapeNode} from './types/reshape'; -import {ResizeBilinearNode} from './types/resize_bilinear'; -import {Reverse4DNode} from './types/reverse'; -// tslint:disable-next-line:max-line-length -import {Slice1DNode, Slice2DNode, Slice3DNode, Slice4DNode} from './types/slice'; -import {SumNode} from './types/sum'; -import {TopKIndicesNode, TopKValuesNode} from './types/topk'; -// tslint:disable-next-line:max-line-length -import {ClipNode, LeakyReluNode, StepNode, TileNode, TransposeNode, UnaryNode} from './types/unary'; - -export function -executeKernel, O extends - KernelConfigRegistry[K]['output']>( - backend: KernelBackend, kernelName: K, - inputAndArgs: KernelConfigRegistry[K]['inputAndArgs']): O { - if (kernelName === 'MatMul') { - const config = inputAndArgs as MatMulNode['inputAndArgs']; - return backend.matMul( - config.inputs.a, config.inputs.b, config.args.transposeA, - config.args.transposeB) as O; - } else if (kernelName === 'Slice1D') { - const config = inputAndArgs as Slice1DNode['inputAndArgs']; - return backend.slice1D( - config.inputs.x, config.args.begin, config.args.size) as O; - } else if (kernelName === 'Slice2D') { - const config = inputAndArgs as Slice2DNode['inputAndArgs']; - return backend.slice2D( - config.inputs.x, config.args.begin, config.args.size) as O; - } else if (kernelName === 'Slice3D') { - const config = inputAndArgs as Slice3DNode['inputAndArgs']; - return backend.slice3D( - config.inputs.x, config.args.begin, config.args.size) as O; - } else if (kernelName === 'Slice4D') { - const config = inputAndArgs as Slice4DNode['inputAndArgs']; - return backend.slice4D( - config.inputs.x, config.args.begin, config.args.size) as O; - } else if (kernelName === 'Reverse4D') { - const config = inputAndArgs as Reverse4DNode['inputAndArgs']; - return backend.reverse4D(config.inputs.x, config.args.axis) as O; - } else if (kernelName === 'Concat') { - const config = inputAndArgs as ConcatNode['inputAndArgs']; - return backend.concat(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Neg') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.neg(config.inputs.x) as O; - } else if (kernelName === 'Add') { - const config = inputAndArgs as BinaryNode['inputAndArgs']; - return backend.add(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Sub') { - const config = inputAndArgs as BinaryNode['inputAndArgs']; - return backend.subtract(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Mul') { - const config = inputAndArgs as BinaryNode['inputAndArgs']; - return backend.multiply(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Div') { - const config = inputAndArgs as BinaryNode['inputAndArgs']; - return backend.divide(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Sum') { - const config = inputAndArgs as SumNode['inputAndArgs']; - return backend.sum(config.inputs.x, config.args.axes) as O; - } else if (kernelName === 'ArgMax') { - const config = inputAndArgs as ArgMaxNode['inputAndArgs']; - return backend.argMax(config.inputs.x, config.args.axes) as O; - } else if (kernelName === 'ArgMin') { - const config = inputAndArgs as ArgMinNode['inputAndArgs']; - return backend.argMin(config.inputs.x, config.args.axes) as O; - } else if (kernelName === 'Equal') { - const config = inputAndArgs as EqualNode['inputAndArgs']; - return backend.equal(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'NotEqual') { - const config = inputAndArgs as EqualNode['inputAndArgs']; - return backend.notEqual(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Less') { - const config = inputAndArgs as EqualNode['inputAndArgs']; - return backend.less(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'LessEqual') { - const config = inputAndArgs as EqualNode['inputAndArgs']; - return backend.lessEqual(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Greater') { - const config = inputAndArgs as EqualNode['inputAndArgs']; - return backend.greater(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'GreaterEqual') { - const config = inputAndArgs as EqualNode['inputAndArgs']; - return backend.greaterEqual(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'LogicalNot') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.logicalNot(config.inputs.x) as O; - } else if (kernelName === 'LogicalAnd') { - const config = inputAndArgs as LogicalNode['inputAndArgs']; - return backend.logicalAnd(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'LogicalOr') { - const config = inputAndArgs as LogicalNode['inputAndArgs']; - return backend.logicalOr(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'LogicalXor') { - const config = inputAndArgs as LogicalNode['inputAndArgs']; - return backend.logicalXor(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Where') { - const config = inputAndArgs as WhereNode['inputAndArgs']; - return backend.where( - config.inputs.condition, config.inputs.a, config.inputs.b, - config.args.dtype) as O; - } else if (kernelName === 'TopKValues') { - const config = inputAndArgs as TopKValuesNode['inputAndArgs']; - return backend.topKValues(config.inputs.x, config.args.k) as O; - } else if (kernelName === 'TopKIndices') { - const config = inputAndArgs as TopKIndicesNode['inputAndArgs']; - return backend.topKIndices(config.inputs.x, config.args.k) as O; - } else if (kernelName === 'Min') { - const config = inputAndArgs as MinNode['inputAndArgs']; - return backend.min(config.inputs.x, config.args.axes) as O; - } else if (kernelName === 'Minimum') { - const config = inputAndArgs as MinimumNode['inputAndArgs']; - return backend.minimum(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Max') { - const config = inputAndArgs as MaxNode['inputAndArgs']; - return backend.max(config.inputs.x, config.args.axes) as O; - } else if (kernelName === 'Maximum') { - const config = inputAndArgs as MaximumNode['inputAndArgs']; - return backend.maximum(config.inputs.a, config.inputs.b) as O; - } else if (kernelName === 'Ceil') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.ceil(config.inputs.x) as O; - } else if (kernelName === 'Floor') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.floor(config.inputs.x) as O; - } else if (kernelName === 'Pow') { - const config = inputAndArgs as PowNode['inputAndArgs']; - return backend.pow(config.inputs.base, config.inputs.exp) as O; - } else if (kernelName === 'Exp') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.exp(config.inputs.x) as O; - } else if (kernelName === 'Log') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.log(config.inputs.x) as O; - } else if (kernelName === 'Sqrt') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.sqrt(config.inputs.x) as O; - } else if (kernelName === 'Square') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.square(config.inputs.x) as O; - } else if (kernelName === 'Relu') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.relu(config.inputs.x) as O; - } else if (kernelName === 'Reshape') { - const config = inputAndArgs as ReshapeNode['inputAndArgs']; - const x = config.inputs.x; - const newShape = config.args.newShape; - return Tensor.make(newShape, {dataId: x.dataId}, x.dtype) as O; - } else if (kernelName === 'Cast') { - const config = inputAndArgs as CastNode['inputAndArgs']; - const x = config.inputs.x; - const newDType = config.args.newDType; - - if (!util.hasEncodingLoss(x.dtype, newDType)) { - // We don't change the underlying data, since we cast to higher - // precision. - return Tensor.make(x.shape, {dataId: x.dataId}, newDType) as O; - } - if (newDType === 'int32') { - return backend.int(x) as O; - } else if (newDType === 'bool') { - return backend.notEqual(x, ops.scalar(0, x.dtype)) as O; - } else { - throw new Error(`Error in Cast: unknown dtype argument (${newDType})`); - } - } else if (kernelName === 'LeakyRelu') { - const config = inputAndArgs as LeakyReluNode['inputAndArgs']; - return backend.leakyRelu(config.inputs.x, config.args.alpha) as O; - } else if (kernelName === 'PReLU') { - const config = inputAndArgs as PReLUNode['inputAndArgs']; - return backend.prelu(config.inputs.x, config.inputs.alpha) as O; - } else if (kernelName === 'PReLUDer') { - const config = inputAndArgs as PReLUNode['inputAndArgs']; - return backend.preluDer(config.inputs.x, config.inputs.alpha) as O; - } else if (kernelName === 'Elu') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.elu(config.inputs.x) as O; - } else if (kernelName === 'EluDer') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.eluDer(config.inputs.x) as O; - } else if (kernelName === 'Selu') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.selu(config.inputs.x) as O; - } else if (kernelName === 'Abs') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.abs(config.inputs.x) as O; - } else if (kernelName === 'Sigmoid') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.sigmoid(config.inputs.x) as O; - } else if (kernelName === 'Step') { - const config = inputAndArgs as StepNode['inputAndArgs']; - return backend.step(config.inputs.x, config.args.alpha) as O; - } else if (kernelName === 'Sin') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.sin(config.inputs.x) as O; - } else if (kernelName === 'Cos') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.cos(config.inputs.x) as O; - } else if (kernelName === 'Tan') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.tan(config.inputs.x) as O; - } else if (kernelName === 'Asin') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.asin(config.inputs.x) as O; - } else if (kernelName === 'Acos') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.acos(config.inputs.x) as O; - } else if (kernelName === 'Atan') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.atan(config.inputs.x) as O; - } else if (kernelName === 'Sinh') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.sinh(config.inputs.x) as O; - } else if (kernelName === 'Cosh') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.cosh(config.inputs.x) as O; - } else if (kernelName === 'Tanh') { - const config = inputAndArgs as UnaryNode['inputAndArgs']; - return backend.tanh(config.inputs.x) as O; - } else if (kernelName === 'Clip') { - const config = inputAndArgs as ClipNode['inputAndArgs']; - return backend.clip(config.inputs.x, config.args.min, config.args.max) as O; - } else if (kernelName === 'Tile') { - const config = inputAndArgs as TileNode['inputAndArgs']; - return backend.tile(config.inputs.x, config.args.reps) as O; - } else if (kernelName === 'Gather') { - const config = inputAndArgs as GatherNode['inputAndArgs']; - return backend.gather( - config.inputs.x, config.inputs.indices, config.args.axis) as O; - } else if (kernelName === 'Pad1D') { - const config = inputAndArgs as Pad1DNode['inputAndArgs']; - return backend.pad1D( - config.inputs.x, config.args.paddings, - config.args.constantValue) as O; - } else if (kernelName === 'Pad2D') { - const config = inputAndArgs as Pad2DNode['inputAndArgs']; - return backend.pad2D( - config.inputs.x, config.args.paddings, - config.args.constantValue) as O; - } else if (kernelName === 'Transpose') { - const config = inputAndArgs as TransposeNode['inputAndArgs']; - return backend.transpose(config.inputs.x, config.args.perm) as O; - } else if (kernelName === 'Conv2D') { - const config = inputAndArgs as Conv2DNode['inputAndArgs']; - return backend.conv2d( - config.inputs.x, config.inputs.filter, config.args.convInfo) as - O; - } else if (kernelName === 'Conv2DDerInput') { - const config = inputAndArgs as Conv2DDerInputNode['inputAndArgs']; - return backend.conv2dDerInput( - config.inputs.dy, config.inputs.filter, config.args.convInfo) as - O; - } else if (kernelName === 'Conv2DDerFilter') { - const config = inputAndArgs as Conv2DDerFilterNode['inputAndArgs']; - return backend.conv2dDerFilter( - config.inputs.x, config.inputs.dy, config.args.convInfo) as O; - } else if (kernelName === 'DepthwiseConv2D') { - const config = inputAndArgs as DepthwiseConv2DNode['inputAndArgs']; - return backend.depthwiseConv2D( - config.inputs.x, config.inputs.filter, config.args.convInfo) as - O; - } else if (kernelName === 'MaxPool') { - const config = inputAndArgs as PoolNode['inputAndArgs']; - return backend.maxPool(config.inputs.x, config.args.convInfo) as O; - } else if (kernelName === 'MaxPoolBackprop') { - const config = inputAndArgs as PoolBackpropNode['inputAndArgs']; - return backend.maxPoolBackprop( - config.inputs.dy, config.inputs.x, config.args.convInfo) as O; - } else if (kernelName === 'AvgPool') { - const config = inputAndArgs as PoolNode['inputAndArgs']; - return backend.avgPool(config.inputs.x, config.args.convInfo) as O; - } else if (kernelName === 'AvgPoolBackprop') { - const config = inputAndArgs as PoolBackpropNode['inputAndArgs']; - return backend.avgPoolBackprop( - config.inputs.dy, config.inputs.x, config.args.convInfo) as O; - } else if (kernelName === 'MinPool') { - const config = inputAndArgs as PoolNode['inputAndArgs']; - return backend.minPool(config.inputs.x, config.args.convInfo) as O; - } else if (kernelName === 'ResizeBilinear') { - const config = inputAndArgs as ResizeBilinearNode['inputAndArgs']; - return backend.resizeBilinear( - config.inputs.x, config.args.newHeight, config.args.newWidth, - config.args.alignCorners) as O; - } else if (kernelName === 'BatchNorm4D') { - const config = inputAndArgs as BatchNorm4DNode['inputAndArgs']; - return backend.batchNormalization4D( - config.inputs.x, config.inputs.mean, config.inputs.variance, - config.args.varianceEpsilon, config.inputs.scale, - config.inputs.offset) as O; - } else if (kernelName === 'LRN4D') { - const config = inputAndArgs as LRN4DNode['inputAndArgs']; - return backend.localResponseNormalization4D( - config.inputs.x, config.args.radius, config.args.bias, - config.args.alpha, config.args.beta, config.args.normRegion) as - O; - } else if (kernelName === 'Multinomial') { - const config = inputAndArgs as MultinomialNode['inputAndArgs']; - return backend.multinomial( - config.inputs.probs, config.args.numSamples, config.args.seed) as - O; - } else if (kernelName === 'OneHot') { - const config = inputAndArgs as OneHotNode['inputAndArgs']; - return backend.oneHot( - config.inputs.indices, config.args.depth, config.args.onValue, - config.args.offValue) as O; - } - throw new Error(`No backend method found for kernel ${kernelName}`); -} - -export interface KernelConfigRegistry { - MatMul: MatMulNode; - Slice1D: Slice1DNode; - Slice2D: Slice2DNode; - Slice3D: Slice3DNode; - Slice4D: Slice4DNode; - Reverse4D: Reverse4DNode; - Concat: ConcatNode; - Neg: UnaryNode; - Add: BinaryNode; - Sub: BinaryNode; - Mul: BinaryNode; - Div: BinaryNode; - Sum: SumNode; - ArgMax: ArgMaxNode; - ArgMin: ArgMinNode; - Equal: EqualNode; - NotEqual: EqualNode; - Less: EqualNode; - LessEqual: EqualNode; - Greater: EqualNode; - GreaterEqual: EqualNode; - LogicalNot: UnaryNode; - LogicalAnd: LogicalNode; - LogicalOr: LogicalNode; - LogicalXor: LogicalNode; - Where: WhereNode; - TopKValues: TopKValuesNode; - TopKIndices: TopKIndicesNode; - Min: MinNode; - Minimum: MinimumNode; - Max: MaxNode; - Maximum: MaximumNode; - Ceil: UnaryNode; - Floor: UnaryNode; - Pow: PowNode; - Exp: UnaryNode; - Log: UnaryNode; - Sqrt: UnaryNode; - Square: UnaryNode; - Relu: UnaryNode; - LeakyRelu: LeakyReluNode; - PReLU: PReLUNode; - PReLUDer: PReLUNode; - Reshape: ReshapeNode; - Cast: CastNode; - Elu: UnaryNode; - EluDer: UnaryNode; - Selu: UnaryNode; - Abs: UnaryNode; - Sigmoid: UnaryNode; - Step: StepNode; - Sin: UnaryNode; - Cos: UnaryNode; - Tan: UnaryNode; - Asin: UnaryNode; - Acos: UnaryNode; - Atan: UnaryNode; - Sinh: UnaryNode; - Cosh: UnaryNode; - Tanh: UnaryNode; - Clip: ClipNode; - Transpose: TransposeNode; - Pad1D: Pad1DNode; - Pad2D: Pad2DNode; - Tile: TileNode; - Gather: GatherNode; - Conv2D: Conv2DNode; - Conv2DDerInput: Conv2DDerInputNode; - Conv2DDerFilter: Conv2DDerFilterNode; - DepthwiseConv2D: Conv2DNode; - MaxPool: PoolNode; - MaxPoolBackprop: PoolBackpropNode; - AvgPool: PoolNode; - AvgPoolBackprop: PoolBackpropNode; - MinPool: PoolNode; - ResizeBilinear: ResizeBilinearNode; - BatchNorm4D: BatchNorm4DNode; - LRN4D: LRN4DNode; - Multinomial: MultinomialNode; - OneHot: OneHotNode; -} -export type Kernel = keyof KernelConfigRegistry; diff --git a/src/kernels/types/argminmax.ts b/src/kernels/types/argminmax.ts deleted file mode 100644 index 466858c183..0000000000 --- a/src/kernels/types/argminmax.ts +++ /dev/null @@ -1,35 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; - -export interface ArgMaxNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor;}; args: {axes: number[];};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor; - }; -} - -export interface ArgMinNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor;}; args: {axes: number[];};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor; - }; -} diff --git a/src/kernels/types/batchnorm.ts b/src/kernels/types/batchnorm.ts deleted file mode 100644 index ed9a5f149c..0000000000 --- a/src/kernels/types/batchnorm.ts +++ /dev/null @@ -1,39 +0,0 @@ - -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor1D, Tensor4D} from '../../tensor'; - -export interface BatchNorm4DNode extends KernelNode { - inputAndArgs: { - inputs: { - x: Tensor4D; mean: Tensor4D | Tensor1D; variance: Tensor4D | Tensor1D; - scale?: Tensor4D | Tensor1D; - offset?: Tensor4D | Tensor1D; - }; - args: {varianceEpsilon: number}; - }; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - mean: () => Tensor4D | Tensor1D; - variance: () => Tensor4D | Tensor1D; - scale?: () => Tensor4D | Tensor1D; - offset?: () => Tensor4D | Tensor1D; - }; -} diff --git a/src/kernels/types/binary.ts b/src/kernels/types/binary.ts deleted file mode 100644 index c0dbb16712..0000000000 --- a/src/kernels/types/binary.ts +++ /dev/null @@ -1,28 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; - -export interface BinaryNode extends KernelNode { - inputAndArgs: {inputs: {a: Tensor; b: Tensor;};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - a: () => Tensor; - b: () => Tensor; - }; -} diff --git a/src/kernels/types/cast.ts b/src/kernels/types/cast.ts deleted file mode 100644 index 4aad261456..0000000000 --- a/src/kernels/types/cast.ts +++ /dev/null @@ -1,28 +0,0 @@ -/** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; -import {DataType} from '../../types'; - -export interface CastNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor}; args: {newDType: DataType};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor - }; -} diff --git a/src/kernels/types/concat.ts b/src/kernels/types/concat.ts deleted file mode 100644 index 8b29ed3a62..0000000000 --- a/src/kernels/types/concat.ts +++ /dev/null @@ -1,28 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor2D} from '../../tensor'; - -export interface ConcatNode extends KernelNode { - inputAndArgs: {inputs: {a: Tensor2D; b: Tensor2D;};}; - output: Tensor2D; - gradient: (dy: Tensor2D, y: Tensor2D) => { - a: () => Tensor2D; - b: () => Tensor2D; - }; -} diff --git a/src/kernels/types/conv.ts b/src/kernels/types/conv.ts deleted file mode 100644 index 31b3daeb00..0000000000 --- a/src/kernels/types/conv.ts +++ /dev/null @@ -1,63 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {Conv2DInfo} from '../../ops/conv_util'; -import {KernelNode} from '../../tape_types'; -import {Tensor4D} from '../../tensor'; - -export interface Conv2DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor4D; filter: Tensor4D;}; args: {convInfo: Conv2DInfo;}; - }; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - filter: () => Tensor4D; - }; -} - -export interface Conv2DDerInputNode extends KernelNode { - inputAndArgs: { - inputs: {dy: Tensor4D; filter: Tensor4D;}; args: {convInfo: Conv2DInfo;}; - }; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - dy: () => Tensor4D; - filter: () => Tensor4D; - }; -} - -export interface Conv2DDerFilterNode extends KernelNode { - inputAndArgs: - {inputs: {x: Tensor4D; dy: Tensor4D;}; args: {convInfo: Conv2DInfo;};}; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - dy: () => Tensor4D; - }; -} - -export interface DepthwiseConv2DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor4D; filter: Tensor4D;}; args: {convInfo: Conv2DInfo;}; - }; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - filter: () => Tensor4D; - }; -} diff --git a/src/kernels/types/gather.ts b/src/kernels/types/gather.ts deleted file mode 100644 index bd5dc82fd3..0000000000 --- a/src/kernels/types/gather.ts +++ /dev/null @@ -1,30 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor, Tensor1D} from '../../tensor'; -import {Rank} from '../../types'; - -export interface GatherNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T; indices: Tensor1D;}; args: {axis: number};}; - output: T; - gradient: (dy: Tensor, y: T) => { - x: () => Tensor; - indices: () => Tensor1D; - }; -} diff --git a/src/kernels/types/logical.ts b/src/kernels/types/logical.ts deleted file mode 100644 index f88d9d3cb6..0000000000 --- a/src/kernels/types/logical.ts +++ /dev/null @@ -1,54 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; -import {DataType} from '../../types'; - -// Equal/NotEqual/Less/LessEqual/Greater/GreaterEqual -export interface EqualNode extends KernelNode { - inputAndArgs: {inputs: {a: Tensor; b: Tensor;};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - a: () => Tensor; - b: () => Tensor; - }; -} - -// LogicalAnd/LogicalOr/LogicalXor -export interface LogicalNode extends KernelNode { - inputAndArgs: {inputs: {a: Tensor; b: Tensor;};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - a: () => Tensor; - b: () => Tensor; - }; -} - -// Where -export interface WhereNode extends KernelNode { - inputAndArgs: { - inputs: {condition: Tensor; a: Tensor; b: Tensor;}; - args: {dtype: DataType}; - }; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - condition: () => Tensor; - a: () => Tensor; - b: () => Tensor; - }; -} diff --git a/src/kernels/types/lrn.ts b/src/kernels/types/lrn.ts deleted file mode 100644 index c6cad957b2..0000000000 --- a/src/kernels/types/lrn.ts +++ /dev/null @@ -1,37 +0,0 @@ - -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor4D} from '../../tensor'; - -// 4D -export interface LRN4DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor4D;}; args: { - radius: number, - bias: number, - alpha: number, - beta: number, - normRegion: 'acrossChannels'|'withinChannel' - }; - }; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - }; -} diff --git a/src/kernels/types/matmul.ts b/src/kernels/types/matmul.ts deleted file mode 100644 index 833ae3a69d..0000000000 --- a/src/kernels/types/matmul.ts +++ /dev/null @@ -1,36 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor2D} from '../../tensor'; - -export interface MatMulNode extends KernelNode { - inputAndArgs: { - inputs: {a: Tensor2D; b: Tensor2D;}; - args: {transposeA: boolean; transposeB: boolean}; - }; - output: Tensor2D; - gradient: (dy: Tensor2D, y: Tensor2D) => { - a: () => Tensor2D; - b: () => Tensor2D; - }; -} - -export enum MatrixOrientation { - REGULAR, - TRANSPOSED -} diff --git a/src/kernels/types/minmax.ts b/src/kernels/types/minmax.ts deleted file mode 100644 index a14d38f732..0000000000 --- a/src/kernels/types/minmax.ts +++ /dev/null @@ -1,55 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; - -// Reduction min. -export interface MinNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor;}; args: {axes: number[]}}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor; - }; -} - -// Element-wise min. -export interface MinimumNode extends KernelNode { - inputAndArgs: {inputs: {a: Tensor, b: Tensor};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - a: () => Tensor, b: () => Tensor - }; -} - -// Reduction Max -export interface MaxNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor;}; args: {axes: number[]}}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor; - }; -} - -// Element-wise max. -export interface MaximumNode extends KernelNode { - inputAndArgs: {inputs: {a: Tensor, b: Tensor};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - a: () => Tensor, b: () => Tensor - }; -} diff --git a/src/kernels/types/multinomial.ts b/src/kernels/types/multinomial.ts deleted file mode 100644 index 640ae1978f..0000000000 --- a/src/kernels/types/multinomial.ts +++ /dev/null @@ -1,29 +0,0 @@ - -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor2D} from '../../tensor'; - -export interface MultinomialNode extends KernelNode { - inputAndArgs: - {inputs: {probs: Tensor2D;}; args: {numSamples: number; seed: number};}; - output: Tensor2D; - gradient: (dy: Tensor2D, y: Tensor2D) => { - probs: () => Tensor2D; - }; -} diff --git a/src/kernels/types/onehot.ts b/src/kernels/types/onehot.ts deleted file mode 100644 index df9b4d9734..0000000000 --- a/src/kernels/types/onehot.ts +++ /dev/null @@ -1,30 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor1D, Tensor2D} from '../../tensor'; - -export interface OneHotNode extends KernelNode { - inputAndArgs: { - inputs: {indices: Tensor1D;}; - args: {depth: number; onValue: number; offValue: number}; - }; - output: Tensor2D; - gradient: (dy: Tensor2D, y: Tensor2D) => { - indices: () => Tensor1D; - }; -} diff --git a/src/kernels/types/pad.ts b/src/kernels/types/pad.ts deleted file mode 100644 index c4347131f4..0000000000 --- a/src/kernels/types/pad.ts +++ /dev/null @@ -1,43 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor1D, Tensor2D} from '../../tensor'; - -export interface Pad1DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor1D;}; - args: {paddings: [number, number], constantValue: number}; - }; - output: Tensor1D; - gradient: (dy: Tensor1D, y: Tensor1D) => { - x: () => Tensor1D; - }; -} - -export interface Pad2DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor2D;}; args: { - paddings: [[number, number], [number, number]], - constantValue: number - }; - }; - output: Tensor2D; - gradient: (dy: Tensor2D, y: Tensor2D) => { - x: () => Tensor2D; - }; -} diff --git a/src/kernels/types/pool.ts b/src/kernels/types/pool.ts deleted file mode 100644 index fce01174e6..0000000000 --- a/src/kernels/types/pool.ts +++ /dev/null @@ -1,40 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {Conv2DInfo} from '../../ops/conv_util'; -import {KernelNode} from '../../tape_types'; -import {Tensor4D} from '../../tensor'; - -// Pool -export interface PoolNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor4D;}; args: {convInfo: Conv2DInfo;};}; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - }; -} - -// PoolBackprop -export interface PoolBackpropNode extends KernelNode { - inputAndArgs: - {inputs: {dy: Tensor4D; x: Tensor4D;}; args: {convInfo: Conv2DInfo;};}; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - dy: () => Tensor4D; - x: () => Tensor4D; - }; -} diff --git a/src/kernels/types/pow.ts b/src/kernels/types/pow.ts deleted file mode 100644 index ed70227cf4..0000000000 --- a/src/kernels/types/pow.ts +++ /dev/null @@ -1,30 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; -import {Rank} from '../../types'; - -export interface PowNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {base: T; exp: Tensor;};}; - output: T; - gradient: (dy: Tensor, y: T) => { - base: () => Tensor; - exp: () => Tensor; - }; -} diff --git a/src/kernels/types/prelu.ts b/src/kernels/types/prelu.ts deleted file mode 100644 index f279f7e85c..0000000000 --- a/src/kernels/types/prelu.ts +++ /dev/null @@ -1,31 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; -import {Rank} from '../../types'; - -// PReLU -export interface PReLUNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T; alpha: T;};}; - output: T; - gradient: (dy: Tensor, y: T) => { - x: () => Tensor; - alpha: () => Tensor; - }; -} diff --git a/src/kernels/types/reshape.ts b/src/kernels/types/reshape.ts deleted file mode 100644 index 374c1f5470..0000000000 --- a/src/kernels/types/reshape.ts +++ /dev/null @@ -1,27 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; - -export interface ReshapeNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor}; args: {newShape: number[]};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor - }; -} diff --git a/src/kernels/types/resize_bilinear.ts b/src/kernels/types/resize_bilinear.ts deleted file mode 100644 index 43f94f510e..0000000000 --- a/src/kernels/types/resize_bilinear.ts +++ /dev/null @@ -1,30 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor4D} from '../../tensor'; - -export interface ResizeBilinearNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor4D;}; - args: {newHeight: number; newWidth: number; alignCorners: boolean}; - }; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - }; -} diff --git a/src/kernels/types/reverse.ts b/src/kernels/types/reverse.ts deleted file mode 100644 index 05ebf6231e..0000000000 --- a/src/kernels/types/reverse.ts +++ /dev/null @@ -1,27 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor4D} from '../../tensor'; - -export interface Reverse4DNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor4D;}; args: {axis: number[];};}; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - }; -} diff --git a/src/kernels/types/slice.ts b/src/kernels/types/slice.ts deleted file mode 100644 index 87ef49dd30..0000000000 --- a/src/kernels/types/slice.ts +++ /dev/null @@ -1,62 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../../tensor'; - -export interface Slice1DNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor1D;}; args: {begin: number; size: number;};}; - output: Tensor1D; - gradient: (dy: Tensor1D, y: Tensor1D) => { - x: () => Tensor1D; - }; -} - -export interface Slice2DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor2D;}; - args: {begin: [number, number]; size: [number, number];}; - }; - output: Tensor2D; - gradient: (dy: Tensor2D, y: Tensor2D) => { - x: () => Tensor2D; - }; -} - -export interface Slice3DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor3D;}; - args: {begin: [number, number, number]; size: [number, number, number];}; - }; - output: Tensor3D; - gradient: (dy: Tensor3D, y: Tensor3D) => { - x: () => Tensor3D; - }; -} - -export interface Slice4DNode extends KernelNode { - inputAndArgs: { - inputs: {x: Tensor4D;}; args: { - begin: [number, number, number, number]; - size: [number, number, number, number]; - }; - }; - output: Tensor4D; - gradient: (dy: Tensor4D, y: Tensor4D) => { - x: () => Tensor4D; - }; -} diff --git a/src/kernels/types/sum.ts b/src/kernels/types/sum.ts deleted file mode 100644 index b5b9e557f0..0000000000 --- a/src/kernels/types/sum.ts +++ /dev/null @@ -1,27 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; - -export interface SumNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor;}; args: {axes: number[];};}; - output: Tensor; - gradient: (dy: Tensor, y: Tensor) => { - x: () => Tensor; - }; -} diff --git a/src/kernels/types/topk.ts b/src/kernels/types/topk.ts deleted file mode 100644 index 95acc7f81e..0000000000 --- a/src/kernels/types/topk.ts +++ /dev/null @@ -1,39 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor, Tensor1D} from '../../tensor'; -import {Rank} from '../../types'; - -// Values -export interface TopKValuesNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {k: number};}; - output: Tensor1D; - gradient: (dy: Tensor1D, y: Tensor1D) => { - x: () => T; - }; -} - -// Indices -export interface TopKIndicesNode extends KernelNode { - inputAndArgs: {inputs: {x: Tensor;}; args: {k: number};}; - output: Tensor1D; - gradient: (dy: Tensor1D, y: Tensor1D) => { - x: () => Tensor; - }; -} diff --git a/src/kernels/types/unary.ts b/src/kernels/types/unary.ts deleted file mode 100644 index 0b5e675c0e..0000000000 --- a/src/kernels/types/unary.ts +++ /dev/null @@ -1,73 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelNode} from '../../tape_types'; -import {Tensor} from '../../tensor'; -import {Rank} from '../../types'; - -export interface UnaryNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} - -export interface LeakyReluNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {alpha: number;};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} -export interface StepNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {alpha: number;};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} - -export interface ClipNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {min: number; max: number;};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} - -export interface TransposeNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {perm: number[];};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} - -export interface TileNode = Tensor> - extends KernelNode { - inputAndArgs: {inputs: {x: T;}; args: {reps: number[];};}; - output: T; - gradient: (dy: T, y: T) => { - x: () => T; - }; -} diff --git a/src/kernels/webgl/mulmat_packed_gpu.ts b/src/kernels/webgl/mulmat_packed_gpu.ts index df0998cab7..d114b12441 100644 --- a/src/kernels/webgl/mulmat_packed_gpu.ts +++ b/src/kernels/webgl/mulmat_packed_gpu.ts @@ -15,8 +15,7 @@ * ============================================================================= */ -import {MatrixOrientation} from '../types/matmul'; - +import {MatrixOrientation} from '../../ops/matmul'; import {GPGPUContext} from './gpgpu_context'; import * as webgl_util from './webgl_util'; diff --git a/src/kernels/webgl/mulmat_packed_gpu_test.ts b/src/kernels/webgl/mulmat_packed_gpu_test.ts index 029df69716..a090de1a45 100644 --- a/src/kernels/webgl/mulmat_packed_gpu_test.ts +++ b/src/kernels/webgl/mulmat_packed_gpu_test.ts @@ -16,9 +16,8 @@ */ // tslint:disable-next-line:max-line-length +import {MatrixOrientation} from '../../ops/matmul'; import {expectArraysClose, expectNumbersClose} from '../../test_util'; -import {MatrixOrientation} from '../types/matmul'; - import {GPGPUContext} from './gpgpu_context'; import * as mulmat_packed_gpu from './mulmat_packed_gpu'; diff --git a/src/math.ts b/src/math.ts index bb22a5c29e..ef1d5c9a51 100644 --- a/src/math.ts +++ b/src/math.ts @@ -36,7 +36,7 @@ import * as slice from './ops/slice'; import * as softmax_ops from './ops/softmax'; import * as transpose from './ops/transpose'; import * as unary_ops from './ops/unary_ops'; -import {ScopeResult} from './tape_util'; +import {ScopeResult} from './tape'; import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from './tensor'; import {Tracking} from './tracking'; import {Rank} from './types'; @@ -227,9 +227,8 @@ export class NDArrayMath { let values: Tensor1D; let indices: Tensor1D; tidy('topK', () => { - values = ENV.engine.executeKernel('TopKValues', {inputs: {x}, args: {k}}); - indices = - ENV.engine.executeKernel('TopKIndices', {inputs: {x}, args: {k}}); + values = ENV.engine.runKernel(backend => backend.topKValues(x, k)); + indices = ENV.engine.runKernel(backend => backend.topKIndices(x, k)); return values; }); const result = {values, indices}; @@ -261,7 +260,7 @@ export class NDArrayMath { /** @deprecated Use dl.transpose() instead. */ switchDim(x: Tensor, perm?: number[]): Tensor { - return ops.transpose(x, perm); + return ops.transpose(x, perm); } /** @deprecated Use dl.add(c, A) instead. */ diff --git a/src/ops/array_ops.ts b/src/ops/array_ops.ts index 926bcb51f9..47443bdac2 100644 --- a/src/ops/array_ops.ts +++ b/src/ops/array_ops.ts @@ -16,6 +16,7 @@ */ import {doc} from '../doc'; +import {ForwardFunc} from '../engine'; import {ENV} from '../environment'; // tslint:disable-next-line:max-line-length import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer} from '../tensor'; @@ -473,29 +474,24 @@ export class Ops { probabilities: Tensor1D|Tensor2D, numSamples: number, seed?: number): Tensor1D|Tensor2D { const numOutcomes = probabilities.size; + const origRank = probabilities.rank; if (numOutcomes < 2) { throw new Error( `Error in multinomial: you need at least 2 outcomes, but got ` + `${numOutcomes}.`); } - if (probabilities.rank > 2) { + if (origRank > 2) { throw new Error( - `Rank of probabilities must be 1 or 2, but is ${probabilities.rank}`); + `Rank of probabilities must be 1 or 2, but is ${origRank}`); } seed = seed || Math.random(); - const origRank = probabilities.rank; - if (probabilities.rank === 1) { - probabilities = probabilities.as2D(1, -1); - } - const res = ENV.engine.executeKernel('Multinomial', { - inputs: {probs: (probabilities as Tensor2D)}, - args: {numSamples, seed} - }); - if (origRank === 1) { - return res.as1D(); - } - return res; + const prob2D = + origRank === 1 ? probabilities.as2D(1, -1) : probabilities as Tensor2D; + const res = ENV.engine.runKernel( + backend => backend.multinomial(prob2D, numSamples, seed)); + + return origRank === 1 ? res.as1D() : res; } /** @@ -521,8 +517,8 @@ export class Ops { if (depth < 2) { throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`); } - return ENV.engine.executeKernel( - 'OneHot', {inputs: {indices}, args: {depth, onValue, offValue}}); + return ENV.engine.runKernel( + backend => backend.oneHot(indices, depth, onValue, offValue)); } /** @@ -589,12 +585,11 @@ export class Ops { x.size === util.sizeFromShape(shape), 'new shape and old shape must have the same number of elements.'); - const grad = (dy: Tensor, y: Tensor) => { + const grad = (dy: Tensor) => { return {x: () => dy.reshape(x.shape)}; }; - return ENV.engine.executeKernel( - 'Reshape', {inputs: {x}, args: {newShape: shape}}, grad) as - Tensor; + return ENV.engine.runKernel( + backend => Tensor.make(shape, {dataId: x.dataId}, x.dtype), {x}, grad); } /** @@ -626,11 +621,24 @@ export class Ops { @doc({heading: 'Tensors', subheading: 'Transformations'}) @operation static cast(x: T, dtype: DataType): T { - const grad = (dy: T, y: T) => { - return {x: () => dy.reshape(dy.shape)}; + const forw: ForwardFunc = backend => { + if (!util.hasEncodingLoss(x.dtype, dtype)) { + // We don't change the underlying data, since we cast to higher + // precision. + return Tensor.make(x.shape, {dataId: x.dataId}, dtype) as T; + } + if (dtype === 'int32') { + return backend.int(x); + } else if (dtype === 'bool') { + return backend.notEqual(x, Ops.scalar(0, x.dtype)) as T; + } else { + throw new Error(`Error in Cast: unknown dtype argument (${dtype})`); + } + }; + const grad = (dy: T) => { + return {x: () => dy.clone()}; }; - return ENV.engine.executeKernel( - 'Cast', {inputs: {x}, args: {newDType: dtype}}, grad) as T; + return ENV.engine.runKernel(forw, {x}, grad) as T; } /** @@ -663,7 +671,7 @@ export class Ops { x.rank === reps.length, `Error in transpose: rank of input ${x.rank} ` + `must match length of reps ${reps}.`); - return ENV.engine.executeKernel('Tile', {inputs: {x}, args: {reps}}) as T; + return ENV.engine.runKernel(backend => backend.tile(x, reps)); } /** @@ -689,8 +697,7 @@ export class Ops { @doc({heading: 'Tensors', subheading: 'Slicing and Joining'}) @operation static gather(x: T, indices: Tensor1D, axis = 0): T { - return ENV.engine.executeKernel( - 'Gather', {inputs: {x, indices}, args: {axis}}) as T; + return ENV.engine.runKernel(backend => backend.gather(x, indices, axis)); } /** @@ -711,8 +718,8 @@ export class Ops { util.assert( paddings.length === 2, 'Invalid number of paddings. Must be length of 2.'); - return ENV.engine.executeKernel( - 'Pad1D', {inputs: {x}, args: {paddings, constantValue}}); + return ENV.engine.runKernel( + backend => backend.pad1D(x, paddings, constantValue)); } /** @@ -734,8 +741,8 @@ export class Ops { paddings.length === 2 && paddings[0].length === 2 && paddings[1].length === 2, 'Invalid number of paddings. Must be length of 2 each.'); - return ENV.engine.executeKernel( - 'Pad2D', {inputs: {x}, args: {paddings, constantValue}}); + return ENV.engine.runKernel( + backend => backend.pad2D(x, paddings, constantValue)); } /** diff --git a/src/ops/batchnorm.ts b/src/ops/batchnorm.ts index 2847fb605d..3fdd008cae 100644 --- a/src/ops/batchnorm.ts +++ b/src/ops/batchnorm.ts @@ -193,18 +193,12 @@ export class Ops { x4D = x as Tensor4D; } - return ENV.engine - .executeKernel('BatchNorm4D', { - inputs: { - x: x4D, - mean: batchnormReshape4D(mean), - variance: batchnormReshape4D(variance), - scale: batchnormReshape4D(scale), - offset: batchnormReshape4D(offset) - }, - args: {varianceEpsilon} - }) - .reshape(x.shape) as Tensor; + const res = ENV.engine.runKernel( + backend => backend.batchNormalization4D( + x4D, batchnormReshape4D(mean), batchnormReshape4D(variance), + varianceEpsilon, batchnormReshape4D(scale), + batchnormReshape4D(offset))); + return res.reshape(x.shape); } } diff --git a/src/ops/binary_ops.ts b/src/ops/binary_ops.ts index 31f5e3eb52..c580e0fba6 100644 --- a/src/ops/binary_ops.ts +++ b/src/ops/binary_ops.ts @@ -55,7 +55,7 @@ export class Ops { const outShape = broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - const der = (dy: Tensor, y: Tensor) => { + const der = (dy: Tensor) => { const derA = () => { let res = dy; const reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape); @@ -74,7 +74,7 @@ export class Ops { }; return {a: derA, b: derB}; }; - return ENV.engine.executeKernel('Add', {inputs: {a, b}}, der) as T; + return ENV.engine.runKernel(backend => backend.add(a, b), {a, b}, der) as T; } /** @@ -121,7 +121,7 @@ export class Ops { const outShape = broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - const der = (dy: Tensor, y: Tensor) => { + const der = (dy: Tensor) => { const derA = () => { let res = dy; const reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape); @@ -140,7 +140,8 @@ export class Ops { }; return {a: derA, b: derB}; }; - return ENV.engine.executeKernel('Sub', {inputs: {a, b}}, der) as T; + return ENV.engine.runKernel( + backend => backend.subtract(a, b), {a, b}, der) as T; } /** @@ -185,31 +186,27 @@ export class Ops { */ @doc({heading: 'Operations', subheading: 'Arithmetic'}) @operation - static pow(base: Tensor, exp: Tensor): T { + static pow(base: T, exp: Tensor): T { util.assert( exp.dtype === 'int32', 'only supports int32 data type for the exponent parameter.'); broadcast_util.assertAndGetBroadcastShape(base.shape, exp.shape); - const gradient = (dy: Tensor, y: Tensor) => { + const grad = (dy: Tensor) => { if (!util.arraysEqual(base.shape, exp.shape) && !util.isScalarShape(exp.shape)) { throw new Error( `Gradient of pow not yet supported for broadcasted shapes.`); } const derBase = () => { - const dx = - exp.toFloat().mul(base.pow(exp.sub(scalar(1, 'int32'))).toFloat()); - return dy.mul(dx); + const dx = exp.toFloat().mul( + base.pow(exp.sub(scalar(1, 'int32'))).toFloat()) as T; + return dy.mulStrict(dx) as T; }; - const derExp = () => { - throw new Error(`Backprop through exponent not implemented yet.`); - }; - return {base: derBase, exp: derExp}; + return {base: derBase}; }; - - return ENV.engine.executeKernel('Pow', {inputs: {base, exp}}, gradient) as - T; + return ENV.engine.runKernel( + backend => backend.pow(base, exp), {base}, grad) as T; } /** @@ -257,7 +254,7 @@ export class Ops { const outShape = broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - const der = (dy: Tensor, y: Tensor) => { + const der = (dy: Tensor) => { const derA = () => { const res = dy.mul(b.toFloat()); const reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape); @@ -276,7 +273,8 @@ export class Ops { }; return {a: derA, b: derB}; }; - return ENV.engine.executeKernel('Mul', {inputs: {a, b}}, der) as T; + return ENV.engine.runKernel( + backend => backend.multiply(a, b), {a, b}, der) as T; } /** @@ -321,7 +319,7 @@ export class Ops { static div(a: Tensor, b: Tensor): T { const outShape = broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - const der = (dy: Tensor, y: Tensor) => { + const der = (dy: Tensor) => { const derA = () => { const res = dy.div(b.toFloat()); const reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape); @@ -341,7 +339,8 @@ export class Ops { }; return {a: derA, b: derB}; }; - return ENV.engine.executeKernel('Div', {inputs: {a, b}}, der) as T; + return ENV.engine.runKernel(backend => backend.divide(a, b), {a, b}, der) as + T; } /** @@ -386,12 +385,13 @@ export class Ops { static minimum(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - const der = (dy: Tensor, y: Tensor) => { + const der = (dy: Tensor) => { const derA = () => dy.mul(a.lessEqual(b).toFloat()); const derB = () => dy.mul(a.greater(b).toFloat()); return {a: derA, b: derB}; }; - return ENV.engine.executeKernel('Minimum', {inputs: {a, b}}, der) as T; + return ENV.engine.runKernel( + backend => backend.minimum(a, b), {a, b}, der) as T; } /** @@ -436,12 +436,13 @@ export class Ops { static maximum(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - const der = (dy: Tensor, y: Tensor) => { + const der = (dy: Tensor) => { const derA = () => dy.mul(a.greaterEqual(b).toFloat()); const derB = () => dy.mul(a.less(b).toFloat()); return {a: derA, b: derB}; }; - return ENV.engine.executeKernel('Maximum', {inputs: {a, b}}, der) as T; + return ENV.engine.runKernel( + backend => backend.maximum(a, b), {a, b}, der) as T; } /** diff --git a/src/ops/compare.ts b/src/ops/compare.ts index 295cd101a9..7c78f70bfe 100644 --- a/src/ops/compare.ts +++ b/src/ops/compare.ts @@ -37,7 +37,7 @@ export class Ops { static notEqual(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('NotEqual', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.notEqual(a, b)) as T; } /** @@ -68,7 +68,7 @@ export class Ops { static less(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('Less', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.less(a, b)) as T; } /** @@ -99,7 +99,7 @@ export class Ops { static equal(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('Equal', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.equal(a, b)) as T; } @operation @@ -122,7 +122,7 @@ export class Ops { static lessEqual(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('LessEqual', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.lessEqual(a, b)) as T; } @operation @@ -145,7 +145,7 @@ export class Ops { static greater(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('Greater', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.greater(a, b)) as T; } @operation @@ -168,7 +168,7 @@ export class Ops { static greaterEqual(a: Tensor, b: Tensor): T { util.assertTypesMatch(a, b); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('GreaterEqual', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.greaterEqual(a, b)) as T; } @operation diff --git a/src/ops/concat.ts b/src/ops/concat.ts index 1e5427eb00..b856e1f2be 100644 --- a/src/ops/concat.ts +++ b/src/ops/concat.ts @@ -168,7 +168,7 @@ function concat2Tensors(a: T, b: T, axis: number): T { const der = (dy: Tensor2D) => { return {a: () => dy.slice(aBegin, aSize), b: () => dy.slice(bBegin, bSize)}; }; - const res = - ENV.engine.executeKernel('Concat', {inputs: {a: a2D, b: b2D}}, der); + const res = ENV.engine.runKernel( + backend => backend.concat(a2D, b2D), {a: a2D, b: b2D}, der); return res.reshape(outShape) as T; } diff --git a/src/ops/conv.ts b/src/ops/conv.ts index 72ce543af9..36698adddb 100644 --- a/src/ops/conv.ts +++ b/src/ops/conv.ts @@ -143,15 +143,16 @@ export class Ops { const convInfo = conv_util.computeConv2DInfo( x4D.shape, filter.shape, strides, pad, dimRoundingMode); - const gradients = (dy: Tensor4D, y: Tensor4D) => { + const grad = (dy: Tensor4D) => { return { x: () => Ops.conv2dDerInput(x4D.shape, dy, filter, strides, pad), filter: () => Ops.conv2dDerFilter(x4D, dy, filter.shape, strides, pad) }; }; - const res = ENV.engine.executeKernel( - 'Conv2D', {inputs: {x: x4D, filter}, args: {convInfo}}, gradients); + const res = ENV.engine.runKernel( + backend => backend.conv2d(x4D, filter, convInfo), {x: x4D, filter}, + grad); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } @@ -229,8 +230,8 @@ export class Ops { const convInfo = conv_util.computeConv2DInfo( xShape4D, filter.shape, strides, pad, dimRoundingMode); - const res = ENV.engine.executeKernel( - 'Conv2DDerInput', {inputs: {dy: dy4D, filter}, args: {convInfo}}); + const res = ENV.engine.runKernel( + backend => backend.conv2dDerInput(dy4D, filter, convInfo)); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } @@ -297,8 +298,8 @@ export class Ops { const convInfo = conv_util.computeConv2DInfo( x4D.shape, filterShape, strides, pad, dimRoundingMode); - return ENV.engine.executeKernel( - 'Conv2DDerFilter', {inputs: {x: x4D, dy: dy4D}, args: {convInfo}}); + return ENV.engine.runKernel( + backend => backend.conv2dDerFilter(x4D, dy4D, convInfo)); } /** @@ -412,8 +413,8 @@ export class Ops { const convInfo = conv_util.computeConv2DInfo( input4D.shape, filter.shape, strides, pad, dimRoundingMode, true /* depthwise */); - const res = ENV.engine.executeKernel( - 'DepthwiseConv2D', {inputs: {x: input4D, filter}, args: {convInfo}}); + const res = ENV.engine.runKernel( + backend => backend.depthwiseConv2D(input4D, filter, convInfo)); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } diff --git a/src/ops/image_ops.ts b/src/ops/image_ops.ts index 625bd23688..70da78b434 100644 --- a/src/ops/image_ops.ts +++ b/src/ops/image_ops.ts @@ -54,9 +54,9 @@ export class Ops { images.as4D(1, images.shape[0], images.shape[1], images.shape[2]); } const [newHeight, newWidth] = size; - const res = ENV.engine.executeKernel( - 'ResizeBilinear', - {inputs: {x: batchImages}, args: {newHeight, newWidth, alignCorners}}); + const res = ENV.engine.runKernel( + backend => backend.resizeBilinear( + batchImages, newHeight, newWidth, alignCorners)); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } diff --git a/src/ops/logical_ops.ts b/src/ops/logical_ops.ts index c4a4a95fa1..61522d7df9 100644 --- a/src/ops/logical_ops.ts +++ b/src/ops/logical_ops.ts @@ -19,7 +19,6 @@ import {doc} from '../doc'; import {ENV} from '../environment'; import {Tensor} from '../tensor'; import * as types from '../types'; -import {DataType} from '../types'; import * as util from '../util'; import * as broadcast_util from './broadcast_util'; import {operation} from './operation'; @@ -32,9 +31,9 @@ export class Ops { */ @doc({heading: 'Operations', subheading: 'Logical'}) @operation - static logicalNot(x: Tensor): T { + static logicalNot(x: T): T { util.assert(x.dtype === 'bool', 'Error Array must be of type bool.'); - return ENV.engine.executeKernel('LogicalNot', {inputs: {x}}) as T; + return ENV.engine.runKernel(backend => backend.logicalNot(x)); } /** @@ -50,7 +49,7 @@ export class Ops { a.dtype === 'bool' && b.dtype === 'bool', 'Error Array must be of type bool.'); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('LogicalAnd', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.logicalAnd(a, b)) as T; } /** @@ -66,7 +65,7 @@ export class Ops { a.dtype === 'bool' && b.dtype === 'bool', 'Error Array must be of type bool.'); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('LogicalOr', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.logicalOr(a, b)) as T; } /** @@ -82,7 +81,7 @@ export class Ops { a.dtype === 'bool' && b.dtype === 'bool', 'Error Array must be of type bool.'); broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); - return ENV.engine.executeKernel('LogicalXor', {inputs: {a, b}}) as T; + return ENV.engine.runKernel(backend => backend.logicalXor(a, b)) as T; } /** @@ -117,9 +116,7 @@ export class Ops { // Default to highest percision of number: const dtype = types.upcastType(a.dtype, b.dtype); - return ENV.engine.executeKernel( - 'Where', - {inputs: {condition, a, b}, args: {dtype: dtype as DataType}}) as - T; + return ENV.engine.runKernel( + backend => backend.where(condition, a, b, dtype)) as T; } } diff --git a/src/ops/lrn.ts b/src/ops/lrn.ts index a8c7330974..461474a327 100644 --- a/src/ops/lrn.ts +++ b/src/ops/lrn.ts @@ -56,9 +56,9 @@ export class LRN { reshapedTo4D = true; x4D = x.as4D(1, x.shape[0], x.shape[1], x.shape[2]); } - const res = ENV.engine.executeKernel( - 'LRN4D', - {inputs: {x: x4D}, args: {radius, bias, alpha, beta, normRegion}}); + const res = ENV.engine.runKernel( + backend => backend.localResponseNormalization4D( + x4D, radius, bias, alpha, beta, normRegion)); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } else { diff --git a/src/ops/matmul.ts b/src/ops/matmul.ts index 2ebf869be1..992c4510af 100644 --- a/src/ops/matmul.ts +++ b/src/ops/matmul.ts @@ -17,11 +17,16 @@ import {doc} from '../doc'; import {ENV} from '../environment'; -import {MatrixOrientation} from '../kernels/types/matmul'; import {Scalar, Tensor1D, Tensor2D} from '../tensor'; import * as util from '../util'; import {operation} from './operation'; +/** @deprecated Use bools transposeA and transposeB when calling matmul() */ +export enum MatrixOrientation { + REGULAR, + TRANSPOSED +} + export class Ops { /** * Computes the dot product of two matrices, A * B. These must be matrices. @@ -60,18 +65,17 @@ export class Ops { `${b.shape} and transposeA=${transposeA}` + ` and transposeB=${transposeB} must match.`); - return ENV.engine.executeKernel( - 'MatMul', {inputs: {a, b}, args: {transposeA, transposeB}}, - (dy: Tensor2D, y: Tensor2D) => { - if (transposeA || transposeB) { - throw new Error( - `Backprop for transposed MatMul not yet implemented.`); - } - return { - a: () => dy.matMul(b.toFloat(), false, true) as Tensor2D, - b: () => a.toFloat().matMul(dy, true, false) as Tensor2D - }; - }) as Tensor2D; + const grad = (dy: Tensor2D) => { + if (transposeA || transposeB) { + throw new Error(`Backprop for transposed MatMul not yet implemented.`); + } + return { + a: () => dy.matMul(b.toFloat(), false, true), + b: () => a.toFloat().matMul(dy, true, false) + }; + }; + return ENV.engine.runKernel( + backend => backend.matMul(a, b, transposeA, transposeB), {a, b}, grad); } /** diff --git a/src/ops/matmul_test.ts b/src/ops/matmul_test.ts index 78205efd35..02955b558e 100644 --- a/src/ops/matmul_test.ts +++ b/src/ops/matmul_test.ts @@ -19,7 +19,6 @@ import * as dl from '../index'; // tslint:disable-next-line:max-line-length import {ALL_ENVS, describeWithFlags, expectArraysClose, expectNumbersClose, WEBGL_ENVS} from '../test_util'; import {Rank} from '../types'; - import {Ops as MatmulOps} from './matmul'; describeWithFlags('matmul', ALL_ENVS, () => { diff --git a/src/ops/pool.ts b/src/ops/pool.ts index 305b320f73..8b70f892c5 100644 --- a/src/ops/pool.ts +++ b/src/ops/pool.ts @@ -66,11 +66,11 @@ export class Ops { const convInfo = conv_util.computePool2DInfo( x4D.shape, filterSize, strides, pad, dimRoundingMode); - const gradients = (dy: Tensor4D, y: Tensor4D) => { + const grad = (dy: Tensor4D) => { return {x: () => Ops.maxPoolBackprop(dy, x4D, filterSize, strides, pad)}; }; - const res = ENV.engine.executeKernel( - 'MaxPool', {inputs: {x: x4D}, args: {convInfo}}, gradients); + const res = ENV.engine.runKernel( + backend => backend.maxPool(x4D, convInfo), {x: x4D}, grad); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } @@ -130,8 +130,8 @@ export class Ops { const convInfo = conv_util.computePool2DInfo( input4D.shape, filterSize, strides, pad, dimRoundingMode); - const res = ENV.engine.executeKernel( - 'MaxPoolBackprop', {inputs: {dy: dy4D, x: input4D}, args: {convInfo}}); + const res = ENV.engine.runKernel( + backend => backend.maxPoolBackprop(dy4D, input4D, convInfo)); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } @@ -180,8 +180,8 @@ export class Ops { } const convInfo = conv_util.computePool2DInfo( input4D.shape, filterSize, strides, pad, dimRoundingMode); - const res = ENV.engine.executeKernel( - 'MinPool', {inputs: {x: input4D}, args: {convInfo}}); + const res = + ENV.engine.runKernel(backend => backend.minPool(input4D, convInfo)); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } @@ -232,11 +232,11 @@ export class Ops { const convInfo = conv_util.computePool2DInfo(x4D.shape, filterSize, strides, pad); - const gradients = (dy: Tensor4D, y: Tensor4D) => { + const grad = (dy: Tensor4D) => { return {x: () => Ops.avgPoolBackprop(dy, x4D, filterSize, strides, pad)}; }; - const res = ENV.engine.executeKernel( - 'AvgPool', {inputs: {x: x4D}, args: {convInfo}}, gradients); + const res = ENV.engine.runKernel( + backend => backend.avgPool(x4D, convInfo), {x: x4D}, grad); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } @@ -285,8 +285,8 @@ export class Ops { const convInfo = conv_util.computePool2DInfo(input4D.shape, filterSize, strides, pad); - const res = ENV.engine.executeKernel( - 'AvgPoolBackprop', {inputs: {dy: dy4D, x: input4D}, args: {convInfo}}); + const res = ENV.engine.runKernel( + backend => backend.avgPoolBackprop(dy4D, input4D, convInfo)); if (reshapedTo4D) { return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; } diff --git a/src/ops/reduction_ops.ts b/src/ops/reduction_ops.ts index dec6d23e66..26cbce2e46 100644 --- a/src/ops/reduction_ops.ts +++ b/src/ops/reduction_ops.ts @@ -115,8 +115,8 @@ export class Ops { reductionAxes = axis_util.getInnerMostAxes(reductionAxes.length, x.rank); } - let value = ENV.engine.executeKernel( - 'Sum', {inputs: {x: permutedX}, args: {axes: reductionAxes}}); + let value = ENV.engine.runKernel( + backend => backend.sum(permutedX, reductionAxes)); if (keepDims) { const newShape = axis_util.expandShapeToKeepDim(value.shape, axes); value = value.reshape(newShape); @@ -234,8 +234,7 @@ export class Ops { x = x.transpose(permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, x.rank); } - const res = - ENV.engine.executeKernel('Min', {inputs: {x}, args: {axes}}) as Tensor; + const res = ENV.engine.runKernel(backend => backend.min(x, axes)); if (keepDims) { const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); return res.reshape(newShape) as T; @@ -281,8 +280,7 @@ export class Ops { x = x.transpose(permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, x.rank); } - const res = - ENV.engine.executeKernel('Max', {inputs: {x}, args: {axes}}) as Tensor; + const res = ENV.engine.runKernel(backend => backend.max(x, axes)); if (keepDims) { const newShape = axis_util.expandShapeToKeepDim(res.shape, origAxes); return res.reshape(newShape) as T; @@ -323,7 +321,7 @@ export class Ops { x = x.transpose(permutedAxes); axes = axis_util.getInnerMostAxes(axes.length, x.rank); } - return ENV.engine.executeKernel('ArgMin', {inputs: {x}, args: {axes}}) as T; + return ENV.engine.runKernel(backend => backend.argMin(x, axes)) as T; } /** @@ -359,7 +357,7 @@ export class Ops { axes = axis_util.getInnerMostAxes(axes.length, x.rank); } - return ENV.engine.executeKernel('ArgMax', {inputs: {x}, args: {axes}}) as T; + return ENV.engine.runKernel(backend => backend.argMax(x, axes)) as T; } /** diff --git a/src/ops/reverse.ts b/src/ops/reverse.ts index 8a8cfdbf75..dd29b0cc52 100644 --- a/src/ops/reverse.ts +++ b/src/ops/reverse.ts @@ -109,8 +109,8 @@ export class Ops { } else { throw new Error(`Reverse for rank ${x.rank} is not yet implemented`); } - const res = ENV.engine.executeKernel( - 'Reverse4D', {inputs: {x: x4d}, args: {axis: axisCleaned}}); + const res = + ENV.engine.runKernel(backend => backend.reverse4D(x4d, axisCleaned)); return res.reshapeAs(x); } } diff --git a/src/ops/slice.ts b/src/ops/slice.ts index 16be1a4746..e3a7d76eb5 100644 --- a/src/ops/slice.ts +++ b/src/ops/slice.ts @@ -34,8 +34,7 @@ export class Ops { @operation static slice1d(x: Tensor1D, begin: number, size: number): Tensor1D { slice_util.assertParamsValid(x, [begin], [size]); - return ENV.engine.executeKernel( - 'Slice1D', {inputs: {x}, args: {begin, size}}) as Tensor1D; + return ENV.engine.runKernel(backend => backend.slice1D(x, begin, size)); } /** @@ -50,8 +49,7 @@ export class Ops { static slice2d(x: Tensor2D, begin: [number, number], size: [number, number]): Tensor2D { slice_util.assertParamsValid(x, begin, size); - return ENV.engine.executeKernel( - 'Slice2D', {inputs: {x}, args: {begin, size}}) as Tensor2D; + return ENV.engine.runKernel(backend => backend.slice2D(x, begin, size)); } /** @@ -67,8 +65,7 @@ export class Ops { number, number, number ]): Tensor3D { slice_util.assertParamsValid(x, begin, size); - return ENV.engine.executeKernel( - 'Slice3D', {inputs: {x}, args: {begin, size}}) as Tensor3D; + return ENV.engine.runKernel(backend => backend.slice3D(x, begin, size)); } /** @@ -85,8 +82,7 @@ export class Ops { number, number, number, number ]): Tensor4D { slice_util.assertParamsValid(x, begin, size); - return ENV.engine.executeKernel( - 'Slice4D', {inputs: {x}, args: {begin, size}}) as Tensor4D; + return ENV.engine.runKernel(backend => backend.slice4D(x, begin, size)); } /** diff --git a/src/ops/transpose.ts b/src/ops/transpose.ts index eec05d093f..bba3d97af1 100644 --- a/src/ops/transpose.ts +++ b/src/ops/transpose.ts @@ -18,7 +18,6 @@ import {doc} from '../doc'; import {ENV} from '../environment'; import {Tensor} from '../tensor'; -import {Rank} from '../types'; import * as util from '../util'; import * as axis_util from './axis_util'; import {operation} from './operation'; @@ -43,20 +42,19 @@ export class Ops { */ @doc({heading: 'Operations', subheading: 'Matrices'}) @operation - static transpose(x: Tensor, perm?: number[]): Tensor { + static transpose(x: T, perm?: number[]): T { if (perm == null) { perm = x.shape.map((s, i) => i).reverse(); } - const der = (dy: Tensor) => { + const der = (dy: T) => { const undoPerm = axis_util.getUndoAxesPermutation(perm); - const derX = () => dy.transpose(undoPerm); - return {x: derX}; + return {x: () => dy.transpose(undoPerm)}; }; util.assert( x.rank === perm.length, `Error in transpose: rank of input ${x.rank} ` + `must match length of perm ${perm}.`); - return ENV.engine.executeKernel( - 'Transpose', {inputs: {x}, args: {perm}}, der) as Tensor; + return ENV.engine.runKernel( + backend => backend.transpose(x, perm), {x}, der); } } diff --git a/src/ops/unary_ops.ts b/src/ops/unary_ops.ts index 9f3c396612..1f5ad82762 100644 --- a/src/ops/unary_ops.ts +++ b/src/ops/unary_ops.ts @@ -39,9 +39,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static neg(x: T): T { - return ENV.engine.executeKernel('Neg', {inputs: {x}}, (dy: T, y: T) => { + const grad = (dy: T) => { return {x: () => dy.neg()}; - }) as T; + }; + return ENV.engine.runKernel(backend => backend.neg(x), {x}, grad); } /** @@ -58,10 +59,10 @@ export class Ops { @operation static ceil(x: T): T { // TODO(manrajgrover): Return null for gradients when backprop supports it. - const gradient = (dy: T, y: T) => { - return {x: () => ops.zeros(y.shape)}; + const grad = (dy: T) => { + return {x: () => ops.zerosLike(dy)}; }; - return ENV.engine.executeKernel('Ceil', {inputs: {x}}, gradient) as T; + return ENV.engine.runKernel(backend => backend.ceil(x), {x}, grad); } /** @@ -79,10 +80,10 @@ export class Ops { static floor(x: T): T { // TODO(nsthorat): Let gradients be null for cases where we want to stop // backpropgation. - const gradient = (dy: T, y: T) => { - return {x: () => ops.zeros(y.shape)}; + const grad = (dy: T) => { + return {x: () => ops.zerosLike(dy)}; }; - return ENV.engine.executeKernel('Floor', {inputs: {x}}, gradient) as T; + return ENV.engine.runKernel(backend => backend.floor(x), {x}, grad); } /** @@ -98,9 +99,12 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static exp(x: T): T { - return ENV.engine.executeKernel('Exp', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.mul(y)}; - }) as T; + const bck = (dy: T, saved: Tensor[]) => { + const [y] = saved; + return {x: () => dy.mulStrict(y as T)}; + }; + return ENV.engine.runKernel( + (backend, save) => save(backend.exp(x)), {x}, bck); } /** @@ -116,9 +120,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static log(x: T): T { - return ENV.engine.executeKernel('Log', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.div(x.toFloat())}; - }) as T; + const grad = (dy: T) => { + return {x: () => dy.divStrict(x.toFloat())}; + }; + return ENV.engine.runKernel(backend => backend.log(x), {x}, grad); } /** @@ -134,9 +139,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static sqrt(x: T): T { - return ENV.engine.executeKernel('Sqrt', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.div(x.toFloat().sqrt().mul(ops.scalar(2)))}; - }) as T; + const grad = (dy: T) => { + return {x: () => dy.divStrict(x.toFloat().sqrt().mul(ops.scalar(2)))}; + }; + return ENV.engine.runKernel(backend => backend.sqrt(x), {x}, grad); } /** @@ -152,9 +158,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static square(x: T): T { - return ENV.engine.executeKernel('Square', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.mul(x.toFloat().mul(ops.scalar(2)))}; - }) as T; + const grad = (dy: T) => { + return {x: () => dy.mulStrict(x.toFloat().mul(ops.scalar(2)))}; + }; + return ENV.engine.runKernel(backend => backend.square(x), {x}, grad); } /** @@ -170,9 +177,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static abs(x: T): T { - return ENV.engine.executeKernel('Abs', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.mul(x.toFloat().step(-1))}; - }) as T; + const grad = (dy: T) => { + return {x: () => dy.mulStrict(x.toFloat().step(-1))}; + }; + return ENV.engine.runKernel(backend => backend.abs(x), {x}, grad); } /** @@ -195,19 +203,18 @@ export class Ops { (clipValueMin <= clipValueMax), `Error in clip: min (${clipValueMin}) must be` + `less than or equal to max (${clipValueMax}).`); - return ENV.engine.executeKernel( - 'Clip', - {inputs: {x}, args: {min: clipValueMin, max: clipValueMax}}, - (dy: T, y: T) => { - return { - // TODO(cais): Fix gradients for the case where x = min or x - // = max. - x: () => dy.where( - x.greater(ops.scalar(clipValueMin)) - .logicalAnd(x.less(ops.scalar(clipValueMax))), - zerosLike(dy)), - }; - }) as T; + const grad = (dy: T) => { + return { + // TODO(cais): Fix gradients for the case where x = min or x + // = max. + x: () => dy.where( + x.greater(ops.scalar(clipValueMin)) + .logicalAnd(x.less(ops.scalar(clipValueMax))), + zerosLike(dy)) as T, + }; + }; + return ENV.engine.runKernel( + backend => backend.clip(x, clipValueMin, clipValueMax), {x}, grad); } /** @@ -223,10 +230,11 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static relu(x: T): T { - return ENV.engine.executeKernel('Relu', {inputs: {x}}, (dy: T, y: T) => { - const stepRes = x.step() as Tensor; - return {x: () => dy.mul(stepRes.toFloat())}; - }) as T; + const grad = (dy: T) => { + const stepRes = x.step(); + return {x: () => dy.mulStrict(stepRes.toFloat())}; + }; + return ENV.engine.runKernel(backend => backend.relu(x), {x}, grad); } /** @@ -242,17 +250,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static elu(x: T): T { - const der = (dy: Tensor) => { - return { - x: () => dy.mul(eluDer(x)), - alpha: () => { - throw new Error( - 'Derivative of prelu with respect to alpha is ' + - 'not implemented yet'); - } - }; + const grad = (dy: T) => { + return {x: () => dy.mulStrict(eluDer(x))}; }; - return ENV.engine.executeKernel('Elu', {inputs: {x}}, der) as T; + return ENV.engine.runKernel(backend => backend.elu(x), {x}, grad); } /** @@ -270,7 +271,7 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static selu(x: T): T { - const gradient = (dy: T, y: T) => { + const grad = (dy: T) => { return { x: () => { const mask = x.greater(ops.scalar(0)); @@ -281,13 +282,11 @@ export class Ops { const greaterThanZeroDer = dy.mul(scale); const lessEqualZeroDer = dy.mul(scaleAlpha).mul(x.toFloat().exp()); - const res = ops.where(mask, greaterThanZeroDer, lessEqualZeroDer); - - return res; + return ops.where(mask, greaterThanZeroDer, lessEqualZeroDer) as T; } }; }; - return ENV.engine.executeKernel('Selu', {inputs: {x}}, gradient) as T; + return ENV.engine.runKernel(backend => backend.selu(x), {x}, grad); } /** @@ -308,11 +307,11 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static leakyRelu(x: T, alpha = 0.2): T { - const gradient = (dy: T, y: T) => { - return {x: () => dy.mul(x.step(alpha))}; + const grad = (dy: T) => { + return {x: () => dy.mulStrict(x.step(alpha))}; }; - return ENV.engine.executeKernel( - 'LeakyRelu', {inputs: {x}, args: {alpha}}, gradient) as T; + return ENV.engine.runKernel( + backend => backend.leakyRelu(x, alpha), {x}, grad); } /** @@ -332,17 +331,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static prelu(x: T, alpha: T): T { - const der = (dy: Tensor) => { - return { - x: () => dy.mul(preluDer(x, alpha)), - alpha: () => { - throw new Error( - 'Derivative of prelu with respect to alpha is ' + - 'not implemented yet'); - } - }; + const grad = (dy: T) => { + return {x: () => dy.mulStrict(preluDer(x, alpha))}; }; - return ENV.engine.executeKernel('PReLU', {inputs: {x, alpha}}, der) as T; + return ENV.engine.runKernel(backend => backend.prelu(x, alpha), {x}, grad); } /** @@ -358,9 +350,12 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static sigmoid(x: T): T { - return ENV.engine.executeKernel('Sigmoid', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.mul(y.mul(ops.scalar(1).sub(y)))}; - }) as T; + const grad = (dy: T, saved: Tensor[]) => { + const [y] = saved; + return {x: () => dy.mulStrict(y.mul(ops.scalar(1).sub(y)))}; + }; + return ENV.engine.runKernel( + (backend, save) => save(backend.sigmoid(x)), {x}, grad); } /** @@ -376,9 +371,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static sin(x: T): T { - return ENV.engine.executeKernel('Sin', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => x.toFloat().cos().mul(dy)}; - }) as T; + const grad = (dy: T) => { + return {x: () => x.toFloat().cos().mulStrict(dy)}; + }; + return ENV.engine.runKernel(backend => backend.sin(x), {x}, grad); } /** @@ -394,9 +390,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static cos(x: T): T { - return ENV.engine.executeKernel('Cos', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => x.toFloat().sin().neg().mul(dy)}; - }) as T; + const grad = (dy: T) => { + return {x: () => x.toFloat().sin().neg().mulStrict(dy)}; + }; + return ENV.engine.runKernel(backend => backend.cos(x), {x}, grad); } /** @@ -412,9 +409,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static tan(x: T): T { - return ENV.engine.executeKernel('Tan', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.div(x.cos().square())}; - }) as T; + const grad = (dy: T) => { + return {x: () => dy.divStrict(x.cos().square())}; + }; + return ENV.engine.runKernel(backend => backend.tan(x), {x}, grad); } /** @@ -430,11 +428,12 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static asin(x: T): T { - return ENV.engine.executeKernel('Asin', {inputs: {x}}, (dy: T, y: T) => { + const grad = (dy: T) => { return { - x: () => dy.div(Ops.sqrt(ops.scalar(1).sub(x.toFloat().square()))) + x: () => dy.divStrict(Ops.sqrt(ops.scalar(1).sub(x.toFloat().square()))) }; - }) as T; + }; + return ENV.engine.runKernel(backend => backend.asin(x), {x}, grad); } /** @@ -450,11 +449,13 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static acos(x: T): T { - return ENV.engine.executeKernel('Acos', {inputs: {x}}, (dy: T, y: T) => { + const grad = (dy: T) => { return { - x: () => dy.div(Ops.sqrt(ops.scalar(1).sub(x.toFloat().square()))).neg() + x: () => dy.divStrict(Ops.sqrt(ops.scalar(1).sub(x.toFloat().square()))) + .neg() }; - }) as T; + }; + return ENV.engine.runKernel(backend => backend.acos(x), {x}, grad); } /** @@ -470,9 +471,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static atan(x: T): T { - return ENV.engine.executeKernel('Atan', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => dy.div(ops.scalar(1).add(x.toFloat().square()))}; - }) as T; + const grad = (dy: T) => { + return {x: () => dy.divStrict(ops.scalar(1).add(x.toFloat().square()))}; + }; + return ENV.engine.runKernel(backend => backend.atan(x), {x}, grad); } /** @@ -488,9 +490,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static sinh(x: T): T { - return ENV.engine.executeKernel('Sinh', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => x.toFloat().cosh().mul(dy)}; - }) as T; + const grad = (dy: T) => { + return {x: () => x.toFloat().cosh().mulStrict(dy)}; + }; + return ENV.engine.runKernel(backend => backend.sinh(x), {x}, grad); } /** @@ -506,9 +509,10 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static cosh(x: T): T { - return ENV.engine.executeKernel('Cosh', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => x.toFloat().sinh().mul(dy)}; - }) as T; + const grad = (dy: T) => { + return {x: () => x.toFloat().sinh().mulStrict(dy)}; + }; + return ENV.engine.runKernel(backend => backend.cosh(x), {x}, grad); } /** @@ -524,9 +528,12 @@ export class Ops { @doc({heading: 'Operations', subheading: 'Basic math'}) @operation static tanh(x: T): T { - return ENV.engine.executeKernel('Tanh', {inputs: {x}}, (dy: T, y: T) => { - return {x: () => ops.scalar(1).sub(y.square()).mul(dy)}; - }) as T; + const grad = (dy: T, saved: Tensor[]) => { + const [y] = saved; + return {x: () => ops.scalar(1).sub(y.square()).mulStrict(dy) as T}; + }; + return ENV.engine.runKernel( + (backend, save) => save(backend.tanh(x)), {x}, grad); } /** @@ -544,17 +551,17 @@ export class Ops { @operation static step(x: T, alpha = 0.0): T { // TODO(manrajgrover): Return null for gradients when backprop supports it. - return ENV.engine.executeKernel( - 'Step', {inputs: {x}, args: {alpha}}, (dy: T, y: T) => { - return {x: () => ops.zeros(y.shape)}; - }) as T; + const grad = (dy: T) => { + return {x: () => ops.zerosLike(dy)}; + }; + return ENV.engine.runKernel(backend => backend.step(x, alpha), {x}, grad); } } -function preluDer(x: Tensor, alpha: Tensor): Tensor { - return ENV.engine.executeKernel('PReLUDer', {inputs: {x, alpha}}) as Tensor; +function preluDer(x: T, alpha: T): T { + return ENV.engine.runKernel(backend => backend.preluDer(x, alpha)); } -function eluDer(x: Tensor): Tensor { - return ENV.engine.executeKernel('EluDer', {inputs: {x}}) as Tensor; +function eluDer(x: T): T { + return ENV.engine.runKernel(backend => backend.eluDer(x)); } diff --git a/src/profiler.ts b/src/profiler.ts index d6cb69a3c1..6b78c1c779 100644 --- a/src/profiler.ts +++ b/src/profiler.ts @@ -16,7 +16,6 @@ */ import {BackendTimer} from './kernels/backend'; -import {Kernel} from './kernels/kernel_registry'; import {Tensor} from './tensor'; import {TypedArray} from './types'; import * as util from './util'; @@ -28,7 +27,7 @@ export class Profiler { } } - profileKernel(kernelName: Kernel, f: () => T): T { + profileKernel(name: string, f: () => T): T { let result: T; const holdResultWrapperFn = () => { result = f(); @@ -36,10 +35,10 @@ export class Profiler { const timer = this.backendTimer.time(holdResultWrapperFn); const vals = result.dataSync(); - util.checkForNaN(vals, result.dtype, kernelName); + util.checkForNaN(vals, result.dtype, name); timer.then(timing => { - this.logger.logKernelProfile(kernelName, result, vals, timing.kernelMs); + this.logger.logKernelProfile(name, result, vals, timing.kernelMs); }); return result as T; @@ -48,9 +47,9 @@ export class Profiler { export class Logger { logKernelProfile( - kernelName: Kernel, result: Tensor, vals: TypedArray, timeMs: number) { + name: string, result: Tensor, vals: TypedArray, timeMs: number) { const time = util.rightPad(`${timeMs}ms`, 9); - const paddedName = util.rightPad(kernelName, 25); + const paddedName = util.rightPad(name, 25); const rank = result.rank; const size = result.size; const shape = util.rightPad(result.shape.toString(), 14); diff --git a/src/profiler_test.ts b/src/profiler_test.ts index 50fa077e69..9d3d368810 100644 --- a/src/profiler_test.ts +++ b/src/profiler_test.ts @@ -17,7 +17,6 @@ import * as dl from './index'; import {BackendTimer, BackendTimingInfo} from './kernels/backend'; -import {Kernel} from './kernels/kernel_registry'; import {TypedArray} from './kernels/webgl/tex_util'; import {Logger, Profiler} from './profiler'; import {Tensor} from './tensor'; @@ -37,7 +36,7 @@ class TestBackendTimer implements BackendTimer { class TestLogger extends Logger { logKernelProfile( - kernelName: Kernel, result: Tensor, vals: TypedArray, timeMs: number) {} + name: string, result: Tensor, vals: TypedArray, timeMs: number) {} } describe('profiler.Profiler', () => { diff --git a/src/tape_util.ts b/src/tape.ts similarity index 69% rename from src/tape_util.ts rename to src/tape.ts index e35df3bc24..a83d2a1256 100644 --- a/src/tape_util.ts +++ b/src/tape.ts @@ -15,12 +15,22 @@ * ============================================================================= */ -import * as util from './util'; import {Tensor} from './tensor'; import {NamedTensorMap, RegularArray} from './types'; +import * as util from './util'; -// tslint:disable-next-line:max-line-length -import {Tape, TapeNode, TapeNodeInputConfig, TapeNodeOutput} from './tape_types'; +export interface TapeNode { + id: number; + name: string; + output: Tensor; + // Optional params, defined only for ops with gradient impl. + inputs?: NamedTensorMap; + gradient?: (dy: Tensor|NamedTensorMap) => NamedGradientMap; +} + +export type NamedGradientMap = { + [inputName: string]: () => Tensor; +}; /** * Computes a list of TapeNodes that connect x to y, filtering everything else @@ -30,7 +40,7 @@ import {Tape, TapeNode, TapeNodeInputConfig, TapeNodeOutput} from './tape_types' * @param y The output Tensor. */ export function getFilteredNodesXToY( - tape: Tape, xs: Tensor[], y: Tensor): Tape { + tape: TapeNode[], xs: Tensor[], y: Tensor): TapeNode[] { // Forward pass to compute all the nodes and Tensors that are transitively a // function of x. const tensorsFromX: {[tensorId: number]: boolean} = {}; @@ -41,22 +51,19 @@ export function getFilteredNodesXToY( for (let i = 0; i < tape.length; i++) { const node = tape[i]; - const nodeInputs = node.inputAndArgs.inputs; - + const nodeInputs = node.inputs; + if (nodeInputs == null) { + throw new Error( + `${node.name} is missing gradient implementation. ` + + `Failed to back-propagate.`); + } for (const inputName in nodeInputs) { const input = nodeInputs[inputName]; let anyInputFromX = false; for (let j = 0; j < xs.length; j++) { if (tensorsFromX[input.id]) { - if (node.output instanceof Tensor) { - tensorsFromX[node.output.id] = true; - } else { - const keys = Object.keys(node.output); - for (const key of keys) { - tensorsFromX[node.output[key].id] = true; - } - } + tensorsFromX[node.output.id] = true; anyInputFromX = true; nodesFromX[node.id] = true; break; @@ -76,17 +83,10 @@ export function getFilteredNodesXToY( for (let i = tape.length - 1; i >= 0; i--) { const node = tape[i]; - const nodeInputs = node.inputAndArgs.inputs; + const nodeInputs = node.inputs; const outputs: Tensor[] = []; - if (node.output instanceof Tensor) { - outputs.push(node.output); - } else { - const keys = Object.keys(node.output); - for (const key of keys) { - outputs.push(node.output[key]); - } - } + outputs.push(node.output); // If any of the outputs lead to y, mark all of the inputs as leading to y. for (let j = 0; j < outputs.length; j++) { @@ -101,40 +101,24 @@ export function getFilteredNodesXToY( } // Return the paths that come from x and lead to y. - const filteredTape: Tape = []; + const filteredTape: TapeNode[] = []; for (let i = 0; i < tape.length; i++) { const node = tape[i]; if (nodesFromX[node.id] && nodesToY[node.id]) { // Prune the inputs from the node that aren't a function of x. const prunedInputs: {[inputName: string]: Tensor} = {}; - for (const inputName in node.inputAndArgs.inputs) { - const nodeInput = node.inputAndArgs.inputs[inputName]; + for (const inputName in node.inputs) { + const nodeInput = node.inputs[inputName]; if (tensorsFromX[nodeInput.id]) { prunedInputs[inputName] = nodeInput; } } - let prunedOutputs: Tensor|{[outputName: string]: Tensor}; - if (node.output instanceof Tensor) { - // Nothing to prune if the output is just a single Tensor since the - // node would have been pruned. - prunedOutputs = node.output; - } else { - // Prune the outputs from the node that don't lead to y. - prunedOutputs = {}; - for (const outputName in node.output) { - const output = node.output[outputName]; - if (tensorsLeadToY[output.id]) { - prunedOutputs[outputName] = node.output[outputName]; - } - } - } - // Copy the node and overwrite inputsAndArgs to the pruned version. - const prunedNode = Object.assign({}, node) as TapeNode; - prunedNode.inputAndArgs = {inputs: prunedInputs}; - prunedNode.output = prunedOutputs; + const prunedNode = Object.assign({}, node) as TapeNode; + prunedNode.inputs = prunedInputs; + prunedNode.output = node.output; filteredTape.push(prunedNode); } @@ -151,21 +135,12 @@ export function getFilteredNodesXToY( */ export function backpropagateGradients( tensorAccumulatedGradientMap: {[tensorId: number]: Tensor}, - filteredTape: Tape) { + filteredTape: TapeNode[]) { // Walk the tape backwards and keep a map of Tensor to its gradient. for (let i = filteredTape.length - 1; i >= 0; i--) { const node = filteredTape[i]; - let dy: Tensor|NamedTensorMap; - if (node.output instanceof Tensor) { - dy = tensorAccumulatedGradientMap[node.output.id]; - } else { - dy = {}; - const keys = Object.keys(node.output); - for (const key of keys) { - dy[key] = tensorAccumulatedGradientMap[node.output[key].id]; - } - } + const dy = tensorAccumulatedGradientMap[node.output.id]; if (node.gradient == null) { throw new Error( @@ -174,8 +149,8 @@ export function backpropagateGradients( } // Backprop dy through this node and accumulate gradients over the inputs. - const inputGradients = node.gradient(dy, node.output); - for (const inputName in node.inputAndArgs.inputs) { + const inputGradients = node.gradient(dy); + for (const inputName in node.inputs) { if (!(inputName in inputGradients)) { throw new Error( `Cannot backprop through input ${inputName}. ` + @@ -184,7 +159,7 @@ export function backpropagateGradients( // Call the gradient function. const dx = inputGradients[inputName](); - const x = node.inputAndArgs.inputs[inputName]; + const x = node.inputs[inputName]; if (!util.arraysEqual(dx.shape, x.shape)) { throw new Error( `Error in gradient for op ${node.name}. The gradient of input ` + @@ -228,14 +203,3 @@ export function extractTensorsFromScopeResult(result: ScopeResultImmediate): } return list; } - -export function stripUndefinedInputsFromInputConfig( - config: TapeNodeInputConfig): TapeNodeInputConfig { - const keys = Object.keys(config.inputs); - keys.forEach(key => { - if (config.inputs[key] == null) { - delete config.inputs[key]; - } - }); - return config; -} diff --git a/src/tape_test.ts b/src/tape_test.ts new file mode 100644 index 0000000000..347fa13203 --- /dev/null +++ b/src/tape_test.ts @@ -0,0 +1,344 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as dl from './index'; +import * as tape_util from './tape'; +import {TapeNode} from './tape'; +import {Scalar, Tensor} from './tensor'; +import {CPU_ENVS, describeWithFlags, expectArraysClose} from './test_util'; + +describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { + it('getFilteredNodesXToY no paths from x to y', () => { + const x = dl.scalar(1); + const intermediate1 = dl.scalar(0); + + const intermediate2 = dl.scalar(0); + const y = dl.scalar(2); + + const tape: TapeNode[] = [ + { + id: 0, + name: 'node0', + inputs: {x}, + output: intermediate1, + gradient: null + }, + { + id: 1, + name: 'node1', + inputs: {intermediate2}, + output: y, + gradient: null + } + ]; + + const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); + + expect(filteredTapeNodes.length).toBe(0); + expect(filteredTapeNodes).toEqual([]); + }); + + it('getFilteredNodesXToY one operation x => y', () => { + const x = dl.scalar(1); + const y = dl.scalar(2); + + const tape: TapeNode[] = + [{id: 0, name: 'node0', inputs: {x}, output: y, gradient: null}]; + + const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); + + expect(filteredTapeNodes.length).toBe(1); + expect(filteredTapeNodes).toEqual(tape); + }); + + it('getFilteredNodesXToY 1 operation [x0, x1] => y, all input paths', () => { + const x0 = dl.scalar(0); + const x1 = dl.scalar(1); + const y = dl.scalar(2); + + const tape: TapeNode[] = + [{id: 0, name: 'node0', inputs: {x0, x1}, output: y, gradient: null}]; + + const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x0, x1], y); + + expect(filteredTapeNodes.length).toBe(1); + expect(filteredTapeNodes).toEqual(tape); + }); + + it('getFilteredNodesXToY one operation [x0, x1] => y, one input paths', + () => { + const x0 = dl.scalar(0); + const x1 = dl.scalar(1); + const y = dl.scalar(2); + + const tape: TapeNode[] = [ + {id: 0, name: 'node0', inputs: {x0, x1}, output: y, gradient: null} + ]; + + const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x0], y); + + expect(filteredTapeNodes.length).toBe(1); + // x1 input should be pruned, we don't ask for the gradient of x1. + expect(filteredTapeNodes[0]) + .toEqual( + {id: 0, name: 'node0', inputs: {x0}, output: y, gradient: null}); + }); + + it('getFilteredNodesXToY two operations x => intermediate => y', () => { + const x = dl.scalar(1); + const intermediate = dl.scalar(0); + const y = dl.scalar(2); + + const tape: TapeNode[] = [ + {id: 0, name: 'node0', inputs: {x}, output: intermediate, gradient: null}, + { + id: 1, + name: 'node1', + inputs: {intermediate}, + output: y, + gradient: null + } + ]; + + const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); + + expect(filteredTapeNodes.length).toBe(2); + expect(filteredTapeNodes).toEqual(tape); + }); + + it('getFilteredNodesXToY two operations [x0, x1], [x2] => ' + + 'intermediate => y', + () => { + const x0 = dl.scalar(1); + const x1 = dl.scalar(2); + const x2 = dl.scalar(3); + const intermediate = dl.scalar(4); + const y = dl.scalar(2); + + const tape: TapeNode[] = [ + { + id: 0, + name: 'node0', + inputs: {x0, x1}, + output: intermediate, + gradient: null + }, + { + id: 1, + name: 'node1', + inputs: {x2, intermediate}, + output: y, + gradient: null + } + ]; + + const filteredTapeNodes = + tape_util.getFilteredNodesXToY(tape, [x0, x1, x2], y); + + expect(filteredTapeNodes.length).toBe(2); + expect(filteredTapeNodes).toEqual(tape); + }); + + it('getFilteredNodesXToY x => y and x => orphan', () => { + const x = dl.scalar(1); + const orphan = dl.scalar(0); + const y = dl.scalar(2); + + const tape: TapeNode[] = [ + {id: 0, name: 'node0', inputs: {x}, output: orphan, gradient: null}, + {id: 1, name: 'node1', inputs: {x}, output: y, gradient: null} + ]; + + const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); + + expect(filteredTapeNodes.length).toBe(1); + // The orphan should be removed. + expect(filteredTapeNodes[0]).toEqual(tape[1]); + }); + + it('getFilteredNodesXToY x => y and orphan => y', () => { + const x = dl.scalar(1); + const orphan = dl.scalar(0); + const y = dl.scalar(2); + + const tape: TapeNode[] = [ + {id: 0, name: 'node0', inputs: {x, orphan}, output: y, gradient: null} + ]; + + const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); + + expect(filteredTapeNodes.length).toBe(1); + // The orphan should be pruned from the node's input. + expect(filteredTapeNodes[0]) + .toEqual( + {id: 0, name: 'node0', inputs: {x}, output: y, gradient: null}); + }); +}); + +describeWithFlags('backpropagateGradients', CPU_ENVS, () => { + it('Throws if gradient is not defined', () => { + const x = dl.scalar(0); + const y = dl.scalar(1); + + const dy = dl.scalar(1); + + const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; + accumulatedGradientsMap[y.id] = dy; + + const tape: TapeNode[] = + [{id: 0, name: 'node0', inputs: {x}, output: y, gradient: null}]; + + expect( + () => tape_util.backpropagateGradients(accumulatedGradientsMap, tape)) + .toThrowError(); + }); + + it('basic backprop with 1 node', () => { + const x = dl.scalar(0); + const y = dl.scalar(1); + + const dy = dl.scalar(1); + + const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; + accumulatedGradientsMap[y.id] = dy; + + const tape: TapeNode[] = [{ + id: 0, + name: 'node0', + inputs: {x}, + output: y, + gradient: (dy: Scalar) => { + return {x: () => dy.add(dl.scalar(1))}; + } + }]; + + tape_util.backpropagateGradients(accumulatedGradientsMap, tape); + + expectArraysClose(accumulatedGradientsMap[x.id], [2]); + }); + + it('basic backprop with 2 nodes', () => { + const x = dl.scalar(0); + const intermediate = dl.scalar(1); + const y = dl.scalar(2); + + const dy = dl.scalar(1); + + const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; + accumulatedGradientsMap[y.id] = dy; + + const tape: TapeNode[] = [ + { + id: 0, + name: 'node0', + inputs: {x}, + output: intermediate, + gradient: (dy: Scalar) => { + return {x: () => dy.add(dl.scalar(1))}; + } + }, + { + id: 1, + name: 'node1', + inputs: {intermediate}, + output: y, + gradient: (dy: Scalar) => { + return {intermediate: () => dy.add(dl.scalar(1))}; + } + } + ]; + + tape_util.backpropagateGradients(accumulatedGradientsMap, tape); + + // dx = dy + 1 + 1 + expectArraysClose(accumulatedGradientsMap[x.id], [3]); + }); + + it('basic backprop with a split node accumulates gradients', () => { + const x = dl.scalar(0); + const intermediate1 = dl.scalar(1); + const intermediate2 = dl.scalar(2); + const y = dl.scalar(3); + + const dy = dl.scalar(1); + + const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; + accumulatedGradientsMap[y.id] = dy; + + const tape: TapeNode[] = [ + { + id: 0, + name: 'node0', + inputs: {x}, + output: intermediate1, + gradient: (dy: Scalar) => { + return {x: () => dy.add(dl.scalar(1))}; + } + }, + { + id: 1, + name: 'node1', + inputs: {x}, + output: intermediate2, + gradient: (dy: Scalar) => { + return {x: () => dy.add(dl.scalar(1))}; + } + }, + { + id: 2, + name: 'node2', + inputs: {intermediate1, intermediate2}, + output: y, + gradient: (dy: Scalar) => { + return { + intermediate1: () => dy.add(dl.scalar(1)), + intermediate2: () => dy.add(dl.scalar(1)) + }; + } + } + ]; + + tape_util.backpropagateGradients(accumulatedGradientsMap, tape); + + // dx = dy + 1 + 1 + 1 + 1 + 1 + expectArraysClose(accumulatedGradientsMap[x.id], [dy.dataSync()[0] + 5]); + }); +}); + +describeWithFlags('extractTensorsFromScopeResult', CPU_ENVS, () => { + it('null input returns empty tensor', () => { + const results = tape_util.extractTensorsFromScopeResult(null); + + expect(results).toEqual([]); + }); + + it('tensor input returns one element tensor', () => { + const x = dl.scalar(1); + const results = tape_util.extractTensorsFromScopeResult(x); + + expect(results).toEqual([x]); + }); + + it('name tensor map returns flattened tensor', () => { + const x1 = dl.scalar(1); + const x2 = dl.scalar(3); + const x3 = dl.scalar(4); + const results = tape_util.extractTensorsFromScopeResult({x1, x2, x3}); + + expect(results).toEqual([x1, x2, x3]); + }); +}); diff --git a/src/tape_types.ts b/src/tape_types.ts deleted file mode 100644 index 062f3b6376..0000000000 --- a/src/tape_types.ts +++ /dev/null @@ -1,54 +0,0 @@ -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {NamedTensorMap} from './types'; -import {Tensor} from './tensor'; -import {Rank} from './types'; -import {KernelConfigRegistry} from './kernels/kernel_registry'; - -export type Tape = Array>; -export type TapeNodeOutput = Tensor|NamedTensorMap; -export type TapeNodeType = 'kernel'|'customGradient'; - -export interface TapeNode { - id: number; - type: TapeNodeType; - name: string; - inputAndArgs: TapeNodeInputConfig; - - output: T; - gradient: (dy: Tensor|NamedTensorMap, y: T) => TapeNodeInputGradientTensors; -} - -export interface TapeNodeInputConfig { inputs: NamedTensorMap; } - -export type TapeNodeInputGradientTensors = { - [inputName: string]: () => Tensor; -}; - -// Kernel nodes -export interface KernelNode extends TapeNode { - kernel: keyof KernelConfigRegistry; - inputAndArgs: KernelInputConfig; - output: Tensor; -} - -export interface KernelInputConfig extends TapeNodeInputConfig { - inputs: NamedTensorMap; - // tslint:disable-next-line:no-any - args?: {[argName: string]: any}; -} diff --git a/src/tape_util_test.ts b/src/tape_util_test.ts deleted file mode 100644 index 33194829ab..0000000000 --- a/src/tape_util_test.ts +++ /dev/null @@ -1,666 +0,0 @@ - -/** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import * as dl from './index'; -import {NamedTensorMap} from './types'; -import {CPU_ENVS, describeWithFlags, expectArraysClose} from './test_util'; -import {Scalar, Tensor} from './tensor'; -// tslint:disable-next-line:max-line-length -import {Tape, TapeNode, TapeNodeInputConfig, TapeNodeOutput} from './tape_types'; -import * as tape_util from './tape_util'; - -describeWithFlags('getFilteredNodesXToY', CPU_ENVS, () => { - it('getFilteredNodesXToY no paths from x to y', () => { - const x = dl.scalar(1); - const intermediate1 = dl.scalar(0); - - const intermediate2 = dl.scalar(0); - const y = dl.scalar(2); - - const tape: Tape = [ - { - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: intermediate1, - gradient: null - }, - { - id: 1, - type: 'kernel', - name: 'node1', - inputAndArgs: { - inputs: {intermediate2}, - }, - output: y, - gradient: null - } - ]; - - const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); - - expect(filteredTapeNodes.length).toBe(0); - expect(filteredTapeNodes).toEqual([]); - }); - - it('getFilteredNodesXToY one operation x => y', () => { - const x = dl.scalar(1); - const y = dl.scalar(2); - - const tape: Tape = [{ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: y, - gradient: null - }]; - - const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); - - expect(filteredTapeNodes.length).toBe(1); - expect(filteredTapeNodes).toEqual(tape); - }); - - it('getFilteredNodesXToY 1 operation [x0, x1] => y, all input paths', () => { - const x0 = dl.scalar(0); - const x1 = dl.scalar(1); - const y = dl.scalar(2); - - const tape: Tape = [{ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x0, x1}, - }, - output: y, - gradient: null - }]; - - const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x0, x1], y); - - expect(filteredTapeNodes.length).toBe(1); - expect(filteredTapeNodes).toEqual(tape); - }); - - it('getFilteredNodesXToY one operation [x0, x1] => y, one input paths', - () => { - const x0 = dl.scalar(0); - const x1 = dl.scalar(1); - const y = dl.scalar(2); - - const tape: Tape = [{ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x0, x1}, - }, - output: y, - gradient: null - }]; - - const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x0], y); - - expect(filteredTapeNodes.length).toBe(1); - // x1 input should be pruned, we don't ask for the gradient of x1. - expect(filteredTapeNodes[0]).toEqual({ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x0}, - }, - output: y, - gradient: null - }); - }); - - it('getFilteredNodesXToY two operations x => intermediate => y', () => { - const x = dl.scalar(1); - const intermediate = dl.scalar(0); - const y = dl.scalar(2); - - const tape: Tape = [ - { - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: intermediate, - gradient: null - }, - { - id: 1, - type: 'kernel', - name: 'node1', - inputAndArgs: { - inputs: {intermediate}, - }, - output: y, - gradient: null - } - ]; - - const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); - - expect(filteredTapeNodes.length).toBe(2); - expect(filteredTapeNodes).toEqual(tape); - }); - - it('getFilteredNodesXToY two operations [x0, x1], [x2] => ' + - 'intermediate => y', - () => { - const x0 = dl.scalar(1); - const x1 = dl.scalar(2); - const x2 = dl.scalar(3); - const intermediate = dl.scalar(4); - const y = dl.scalar(2); - - const tape: Tape = [ - { - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x0, x1}, - }, - output: intermediate, - gradient: null - }, - { - id: 1, - type: 'kernel', - name: 'node1', - inputAndArgs: { - inputs: {x2, intermediate}, - }, - output: y, - gradient: null - } - ]; - - const filteredTapeNodes = - tape_util.getFilteredNodesXToY(tape, [x0, x1, x2], y); - - expect(filteredTapeNodes.length).toBe(2); - expect(filteredTapeNodes).toEqual(tape); - }); - - it('getFilteredNodesXToY x => y and x => orphan', () => { - const x = dl.scalar(1); - const orphan = dl.scalar(0); - const y = dl.scalar(2); - - const tape: Tape = [ - { - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: orphan, - gradient: null - }, - { - id: 1, - type: 'kernel', - name: 'node1', - inputAndArgs: { - inputs: {x}, - }, - output: y, - gradient: null - } - ]; - - const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); - - expect(filteredTapeNodes.length).toBe(1); - // The orphan should be removed. - expect(filteredTapeNodes[0]).toEqual(tape[1]); - }); - - it('getFilteredNodesXToY x => y and orphan => y', () => { - const x = dl.scalar(1); - const orphan = dl.scalar(0); - const y = dl.scalar(2); - - const tape: Tape = [{ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x, orphan}, - }, - output: y, - gradient: null - }]; - - const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); - - expect(filteredTapeNodes.length).toBe(1); - // The orphan should be pruned from the node's input. - expect(filteredTapeNodes[0]).toEqual({ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: y, - gradient: null - }); - }); - - it('getFilteredNodesXToY x => {intermediate, orphan1} and ' + - '{orphan2, intermediate} => {y, orphan3}', - () => { - const x = dl.scalar(1); - const intermediate = dl.scalar(5); - const orphan1 = dl.scalar(1); - const orphan2 = dl.scalar(2); - const orphan3 = dl.scalar(3); - const y = dl.scalar(2); - - const tape: Array> = [ - { - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: {orphan1, intermediate}, - gradient: null - }, - { - id: 1, - type: 'kernel', - name: 'node1', - inputAndArgs: { - inputs: {intermediate, orphan2}, - }, - output: {y, orphan3}, - gradient: null - } - ]; - - const filteredTapeNodes = tape_util.getFilteredNodesXToY(tape, [x], y); - - expect(filteredTapeNodes.length).toBe(2); - // The orphans should be pruned from inputs and outputs. - expect(filteredTapeNodes[0]).toEqual({ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: {intermediate}, - gradient: null - }); - expect(filteredTapeNodes[1]).toEqual({ - id: 1, - type: 'kernel', - name: 'node1', - inputAndArgs: { - inputs: {intermediate}, - }, - output: {y}, - gradient: null - }); - }); - - it('getFilteredNodesXToY x0 => orphan0, ' + - 'x0 => intermediate0, x0 => intermediate1, ' + - '[intermediate0, intermediate1, x1, orphan1] => {y, orphan2}', - () => { - const x0 = dl.scalar(1); - const orphan0 = dl.scalar(2); - - const intermediate0 = dl.scalar(3); - const intermediate1 = dl.scalar(4); - - const x1 = dl.scalar(5); - const orphan1 = dl.scalar(6); - const y = dl.scalar(7); - const orphan2 = dl.scalar(8); - - const tape: Tape = [ - { - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x0}, - }, - output: intermediate0, - gradient: null - }, - { - id: 1, - type: 'kernel', - name: 'node1', - inputAndArgs: { - inputs: {x0}, - }, - output: intermediate1, - gradient: null - }, - { - id: 2, - type: 'kernel', - name: 'node2', - inputAndArgs: { - inputs: {x0}, - }, - output: orphan0, - gradient: null - }, - { - id: 3, - type: 'kernel', - name: 'node3', - inputAndArgs: { - inputs: {intermediate0, intermediate1, x1, orphan1}, - }, - output: {y, orphan2}, - gradient: null - } - ]; - - const filteredTapeNodes = - tape_util.getFilteredNodesXToY(tape, [x0, x1], y); - - expect(filteredTapeNodes.length).toBe(3); - expect(filteredTapeNodes[0]).toEqual(tape[0]); - expect(filteredTapeNodes[1]).toEqual(tape[1]); - // The orphans should be removed and the orphan1 should be pruned from - // inputs. - expect(filteredTapeNodes[2]).toEqual({ - id: 3, - type: 'kernel', - name: 'node3', - inputAndArgs: { - inputs: {intermediate0, intermediate1, x1}, - }, - output: {y}, - gradient: null - }); - }); -}); - -describeWithFlags('backpropagateGradients', CPU_ENVS, () => { - it('Throws if gradient is not defined', () => { - const x = dl.scalar(0); - const y = dl.scalar(1); - - const dy = dl.scalar(1); - - const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; - accumulatedGradientsMap[y.id] = dy; - - const tape: Tape = [{ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: y, - gradient: null - }]; - - expect( - () => tape_util.backpropagateGradients(accumulatedGradientsMap, tape)) - .toThrowError(); - }); - - it('basic backprop with 1 node', () => { - const x = dl.scalar(0); - const y = dl.scalar(1); - - const dy = dl.scalar(1); - - const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; - accumulatedGradientsMap[y.id] = dy; - - const tape: Tape = [{ - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: y, - gradient: (dy: Scalar, y: Scalar) => { - return {x: () => dy.add(dl.scalar(1))}; - } - }]; - - tape_util.backpropagateGradients(accumulatedGradientsMap, tape); - - expectArraysClose(accumulatedGradientsMap[x.id], [2]); - }); - - it('basic backprop with 2 nodes', () => { - const x = dl.scalar(0); - const intermediate = dl.scalar(1); - const y = dl.scalar(2); - - const dy = dl.scalar(1); - - const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; - accumulatedGradientsMap[y.id] = dy; - - const tape: Tape = [ - { - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: intermediate, - gradient: (dy: Scalar, y: Scalar) => { - return {x: () => dy.add(dl.scalar(1))}; - } - }, - { - id: 1, - type: 'kernel', - name: 'node1', - inputAndArgs: { - inputs: {intermediate}, - }, - output: y, - gradient: (dy: Scalar, y: Scalar) => { - return {intermediate: () => dy.add(dl.scalar(1))}; - } - } - ]; - - tape_util.backpropagateGradients(accumulatedGradientsMap, tape); - - // dx = dy + 1 + 1 - expectArraysClose(accumulatedGradientsMap[x.id], [3]); - }); - - it('basic backprop with a split node accumulates gradients', () => { - const x = dl.scalar(0); - const intermediate1 = dl.scalar(1); - const intermediate2 = dl.scalar(2); - const y = dl.scalar(3); - - const dy = dl.scalar(1); - - const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; - accumulatedGradientsMap[y.id] = dy; - - const tape: Tape = [ - { - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: intermediate1, - gradient: (dy: Scalar, y: Scalar) => { - return {x: () => dy.add(dl.scalar(1))}; - } - }, - { - id: 1, - type: 'kernel', - name: 'node1', - inputAndArgs: { - inputs: {x}, - }, - output: intermediate2, - gradient: (dy: Scalar, y: Scalar) => { - return {x: () => dy.add(dl.scalar(1))}; - } - }, - { - id: 2, - type: 'kernel', - name: 'node2', - inputAndArgs: { - inputs: {intermediate1, intermediate2}, - }, - output: y, - gradient: (dy: Scalar, y: Scalar) => { - return { - intermediate1: () => dy.add(dl.scalar(1)), - intermediate2: () => dy.add(dl.scalar(1)) - }; - } - } - ]; - - tape_util.backpropagateGradients(accumulatedGradientsMap, tape); - - // dx = dy + 1 + 1 + 1 + 1 + 1 - expectArraysClose(accumulatedGradientsMap[x.id], [dy.dataSync()[0] + 5]); - }); - - it('basic backprop with a multi-output split node accumulates gradients', - () => { - const x = dl.scalar(0); - const intermediate1 = dl.scalar(1); - const intermediate2 = dl.scalar(2); - const y = dl.scalar(3); - - const dy = dl.scalar(1); - - const accumulatedGradientsMap: {[tensorId: number]: Tensor} = {}; - accumulatedGradientsMap[y.id] = dy; - - const tape: Array> = [ - { - id: 0, - type: 'kernel', - name: 'node0', - inputAndArgs: { - inputs: {x}, - }, - output: {intermediate1, intermediate2}, - gradient: (dy: NamedTensorMap, y: NamedTensorMap) => { - return {x: () => dy['intermediate1'].mul(dy['intermediate2'])}; - } - }, - { - id: 1, - type: 'kernel', - name: 'node1', - inputAndArgs: { - inputs: {intermediate1, intermediate2}, - }, - output: y, - gradient: (dy: Scalar, y: Scalar) => { - return { - intermediate1: () => dy.add(dl.scalar(2)), - intermediate2: () => dy.add(dl.scalar(3)) - }; - } - } - ]; - - tape_util.backpropagateGradients(accumulatedGradientsMap, tape); - - expectArraysClose( - accumulatedGradientsMap[x.id], [(dy.get() + 2) * (dy.get() + 3)]); - }); -}); - -describeWithFlags('extractTensorsFromScopeResult', CPU_ENVS, () => { - it('null input returns empty tensor', () => { - const results = tape_util.extractTensorsFromScopeResult(null); - - expect(results).toEqual([]); - }); - - it('tensor input returns one element tensor', () => { - const x = dl.scalar(1); - const results = tape_util.extractTensorsFromScopeResult(x); - - expect(results).toEqual([x]); - }); - - it('name tensor map returns flattened tensor', () => { - const x1 = dl.scalar(1); - const x2 = dl.scalar(3); - const x3 = dl.scalar(4); - const results = tape_util.extractTensorsFromScopeResult({x1, x2, x3}); - - expect(results).toEqual([x1, x2, x3]); - }); - - it('pass through when all inputs are defined', () => { - const x1 = dl.scalar(1); - const x2 = dl.scalar(2); - const config: TapeNodeInputConfig = { - inputs: {x1, x2}, - }; - expect(tape_util.stripUndefinedInputsFromInputConfig(config)).toEqual({ - inputs: {x1, x2} - }); - }); - - it('strips undefined inputs', () => { - const x1 = dl.scalar(1); - const x4 = dl.scalar(2); - const config: TapeNodeInputConfig = { - inputs: {x1, x2: undefined, x3: undefined, x4}, - }; - expect(tape_util.stripUndefinedInputsFromInputConfig(config)).toEqual({ - inputs: {x1, x4} - }); - }); -}); diff --git a/src/tensor.ts b/src/tensor.ts index ce1c6d0ddd..46647724b4 100644 --- a/src/tensor.ts +++ b/src/tensor.ts @@ -602,7 +602,7 @@ export class Tensor { this.throwIfDisposed(); return ops.subStrict(this, x); } - pow(exp: Tensor): T { + pow(this: T, exp: Tensor): T { this.throwIfDisposed(); return ops.pow(this, exp); } @@ -642,7 +642,7 @@ export class Tensor { this.throwIfDisposed(); return ops.maximumStrict(this, x); } - transpose(perm?: number[]): Tensor { + transpose(this: T, perm?: number[]): T { this.throwIfDisposed(); return ops.transpose(this, perm); } @@ -717,35 +717,35 @@ export class Tensor { } // Unary ops. - neg(): Tensor { + neg(this: T): T { this.throwIfDisposed(); return ops.neg(this); } - ceil(): Tensor { + ceil(this: T): T { this.throwIfDisposed(); return ops.ceil(this); } - floor(): Tensor { + floor(this: T): T { this.throwIfDisposed(); return ops.floor(this); } - exp(): Tensor { + exp(this: T): T { this.throwIfDisposed(); return ops.exp(this); } - log(): Tensor { + log(this: T): T { this.throwIfDisposed(); return ops.log(this); } - sqrt(): Tensor { + sqrt(this: T): T { this.throwIfDisposed(); return ops.sqrt(this); } - square(): Tensor { + square(this: T): T { this.throwIfDisposed(); return ops.square(this); } - abs(): Tensor { + abs(this: T): T { this.throwIfDisposed(); return ops.abs(this); } @@ -753,15 +753,15 @@ export class Tensor { this.throwIfDisposed(); return ops.clipByValue(this, min, max); } - relu(): Tensor { + relu(this: T): T { this.throwIfDisposed(); return ops.relu(this); } - elu(): Tensor { + elu(this: T): T { this.throwIfDisposed(); return ops.elu(this); } - selu(): Tensor { + selu(this: T): T { this.throwIfDisposed(); return ops.selu(this); } @@ -773,47 +773,47 @@ export class Tensor { this.throwIfDisposed(); return ops.prelu(this, alpha); } - sigmoid(): Tensor { + sigmoid(this: T): T { this.throwIfDisposed(); return ops.sigmoid(this); } - sin(): Tensor { + sin(this: T): T { this.throwIfDisposed(); return ops.sin(this); } - cos(): Tensor { + cos(this: T): T { this.throwIfDisposed(); return ops.cos(this); } - tan(): Tensor { + tan(this: T): T { this.throwIfDisposed(); return ops.tan(this); } - asin(): Tensor { + asin(this: T): T { this.throwIfDisposed(); return ops.asin(this); } - acos(): Tensor { + acos(this: T): T { this.throwIfDisposed(); return ops.acos(this); } - atan(): Tensor { + atan(this: T): T { this.throwIfDisposed(); return ops.atan(this); } - sinh(): Tensor { + sinh(this: T): T { this.throwIfDisposed(); return ops.sinh(this); } - cosh(): Tensor { + cosh(this: T): T { this.throwIfDisposed(); return ops.cosh(this); } - tanh(): Tensor { + tanh(this: T): T { this.throwIfDisposed(); return ops.tanh(this); } - step(alpha = 0.0): Tensor { + step(this: T, alpha = 0.0): T { this.throwIfDisposed(); return ops.step(this, alpha); } diff --git a/src/tracking.ts b/src/tracking.ts index fb383f7388..5ae889adc7 100644 --- a/src/tracking.ts +++ b/src/tracking.ts @@ -19,7 +19,7 @@ import {doc} from './doc'; import {TimingInfo} from './engine'; import {ENV} from './environment'; // tslint:disable-next-line:max-line-length -import {ScopeFn, ScopeResult, ScopeResultImmediate} from './tape_util'; +import {ScopeFn, ScopeResult, ScopeResultImmediate} from './tape'; import {Tensor} from './tensor'; export class Tracking { @@ -64,13 +64,13 @@ export class Tracking { @doc({heading: 'Performance', subheading: 'Memory'}) static tidy( nameOrFn: string|ScopeFn, fn?: ScopeFn, gradMode = false): T { + let name = null; if (fn == null) { // Called with only 1 argument. if (typeof nameOrFn !== 'function') { throw new Error('Please provide a function to dl.tidy()'); } fn = nameOrFn; - nameOrFn = ''; } else { // Called with 2 arguments. if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) { @@ -83,9 +83,10 @@ export class Tracking { 'When calling with two arguments, the 2nd argument ' + 'to dl.tidy() must be a function'); } + name = nameOrFn as string; // TODO(nsthorat,smilkov): Do operation logging and performance profiling. } - ENV.engine.startScope(gradMode); + ENV.engine.startScope(name, gradMode); const result = fn(); if (result instanceof Promise) {