Remove kernel_registry and call backends directly. (tensorflow#683)
- Remove kernel_registry and call backends directly.
- Allow the forward function to save tensors that the backward function will need. We pass `dy` and the `saved` tensors to the backprop function (see the sketch below).
- Remove the logic in the tape that handled multiple outputs. We can bring it back if needed.
- Consolidate `tape_types.ts` and `tape_util.ts` into a single file, `tape.ts`.
- In debug mode, log the name of the `activeScope`, which should point to the last operation that called the kernel. Aggregating these names into a hierarchical debug view is still a TODO.
dsmilkov authored Feb 19, 2018
1 parent dc7984a · commit 461d4b9
Showing 59 changed files with 742 additions and 2,506 deletions.
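
For orientation, the sketch below shows the new calling convention end to end. It is illustrative only: the `square` op, its gradient math, and helper names (`ops.scalar`, the chained `.mul`) are assumptions rather than code from this commit; the `runKernel` shape comes from the src/engine.ts diff below.

    // A hypothetical op built on the new API (assumes the usual engine.ts
    // imports: ENV from './environment', ops from './ops/ops', Tensor types).
    function square<T extends Tensor>(x: T): T {
      return ENV.engine.runKernel(
          // Forward pass: call the backend directly; `save` stashes tensors
          // that the backward function will need.
          (backend, save) => backend.multiply(save(x), x) as T,
          {x},  // named inputs, recorded on the tape for autograd
          // Backward pass: receives `dy` and the tensors saved above.
          (dy: T, saved: Tensor[]) => {
            const [xSaved] = saved;
            // d(x*x)/dx = 2x, so dx = dy * 2 * xSaved.
            return {x: () => dy.mul(xSaved.mul(ops.scalar(2))) as T};
          });
    }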
96 changes: 52 additions & 44 deletions src/engine.ts
@@ -18,24 +18,29 @@
 import {ENV} from './environment';
 import {tidy} from './globals';
 import {BackendTimingInfo, KernelBackend} from './kernels/backend';
-import * as kernel_registry from './kernels/kernel_registry';
-import {KernelConfigRegistry} from './kernels/kernel_registry';
 import * as ops from './ops/ops';
 import {Profiler} from './profiler';
 // tslint:disable-next-line:max-line-length
-import {KernelNode, Tape, TapeNode, TapeNodeInputGradientTensors} from './tape_types';
-import * as tape_util from './tape_util';
-import {ScopeResultImmediate} from './tape_util';
+import {backpropagateGradients, extractTensorsFromScopeResult, getFilteredNodesXToY} from './tape';
+import {NamedGradientMap, TapeNode} from './tape';
+import {ScopeResultImmediate} from './tape';
 import {DataId, Tensor, Tensor3D, Variable} from './tensor';
 import {NamedTensorMap, NamedVariableMap, TypedArray} from './types';
-import {Rank} from './types';
 import * as util from './util';
 
 interface ScopeState {
   keep: Tensor[];
   track: Tensor[];
+  name?: string;
 }
 
+/**
+ * A function that computes an output. The save function is for saving tensors
+ * computed in the forward pass, that we need in the backwards pass.
+ */
+export type ForwardFunc<T extends Tensor> =
+    (backend: KernelBackend, save?: <S extends Tensor>(tensor: S) => S) => T;
+
 /**
  * @docalias (a: Tensor, b: Tensor,...) => {
  *   value: Tensor,
@@ -70,7 +75,7 @@ export class Engine implements TensorManager {
   private numTensors = 0;
   private numDataBuffers = 0;
 
-  private activeTape: Tape;
+  private activeTape: TapeNode[];
   private gradientScopeCount = 0;
   private customGradientDepth = 0;
 
@@ -88,39 +93,42 @@ export class Engine implements TensorManager {
     this.profiler = new Profiler(backend);
   }
 
-  executeKernel<R extends Rank, K extends keyof KernelConfigRegistry<R>, C
-                    extends KernelConfigRegistry<R>[K]['inputAndArgs']>(
-      kernelName: K, config: C, grad?: KernelConfigRegistry<R>[K]['gradient']):
-      KernelConfigRegistry<R>[K]['output'] {
-    let result: KernelConfigRegistry<R>[K]['output'];
+  runKernel<T extends Tensor, I extends NamedTensorMap>(
+      forwardFunc: ForwardFunc<T>,
+      inputs?: I,
+      backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]},
+      ): T {
+    let result: T;
+    const saved: Tensor[] = [];
+    const saveFunc = <T extends Tensor>(x: T): T => {
+      saved.push(x);
+      return x;
+    };
+    const scopeName = this.activeScope.name;
 
     if (!ENV.get('DEBUG')) {
      // NOTE: This isn't pulled out into a separate function to so that we
      // keep a shallow stack trace.
-      result = kernel_registry.executeKernel(this.backend, kernelName, config);
+      result = forwardFunc(this.backend, saveFunc);
     } else {
       result = this.profiler.profileKernel(
-          kernelName,
-          () =>
-              kernel_registry.executeKernel(this.backend, kernelName, config));
+          scopeName, () => forwardFunc(this.backend, saveFunc));
     }
 
     const recordKernel =
         this.activeTape != null && this.customGradientDepth === 0;
     if (recordKernel) {
-      config = tape_util.stripUndefinedInputsFromInputConfig(config) as C;
-
-      const evaluatedNode: KernelNode = {
+      const tapeNode: TapeNode = {
         id: this.nextTapeNodeId++,
-        type: 'kernel',
-        name: `kernel: ${kernelName}`,
-        kernel: kernelName,
-        inputAndArgs: config,
-        output: result,
-        gradient: grad
+        name: scopeName,
+        output: result
       };
-      this.activeTape.push(evaluatedNode);
+      if (inputs != null) {
+        tapeNode.inputs = inputs;
+      }
+      if (backwardsFunc != null) {
+        tapeNode.gradient = (dy: T) => backwardsFunc(dy, saved);
+      }
+      this.activeTape.push(tapeNode);
     }
 
     return result;
   }

@@ -191,22 +199,21 @@
 
     const gradient = (dy: Tensor) => {
       const res = gradientsFunc(dy);
-      const resMap: TapeNodeInputGradientTensors = {};
+      const resMap: NamedGradientMap = {};
       res.forEach((r, idx) => {
         resMap[idx] = () => r;
       });
       return resMap;
     };
 
-    const evaluatedNode: TapeNode<Tensor> = {
+    const tapeNode: TapeNode = {
       id: this.nextTapeNodeId++,
-      type: 'customGradient',
-      name,
-      inputAndArgs: {inputs: inputsMap},
+      name: this.activeScope.name,
+      inputs: inputsMap,
       output: result,
       gradient
     };
-    this.activeTape.push(evaluatedNode);
+    this.activeTape.push(tapeNode);
   }
 
   keep<T extends Tensor>(result: T): T {
@@ -223,17 +230,20 @@
    * Start a scope. Use this with endScope() to achieve the same functionality
    * as scope() without the need for a function closure.
    */
-  startScope(gradientsMode = false) {
+  startScope(name?: string, gradientsMode = false) {
     if (gradientsMode && this.gradientScopeCount === 0) {
       this.activeTape = [];
     }
     if (gradientsMode) {
       this.gradientScopeCount++;
     }
 
-    const newScopeArrays: ScopeState = {keep: [], track: []};
-    this.scopeStack.push(newScopeArrays);
-    this.activeScope = newScopeArrays;
+    const scopeInfo: ScopeState = {keep: [], track: []};
+    if (name) {
+      scopeInfo.name = name;
+    }
+    this.scopeStack.push(scopeInfo);
+    this.activeScope = scopeInfo;
   }
 
   /**
@@ -249,8 +259,7 @@
     }
 
     let tensorsToKeep = this.activeScope.keep;
-    const tensorsToTrackInParent =
-        tape_util.extractTensorsFromScopeResult(result);
+    const tensorsToTrackInParent = extractTensorsFromScopeResult(result);
     tensorsToKeep = tensorsToKeep.concat(tensorsToTrackInParent);
 
     // Dispose the arrays tracked in this scope.
@@ -301,8 +310,7 @@
         y instanceof Tensor,
         'The result y returned by f() must be a tensor.');
     // Filter out the nodes that don't connect x => y.
-    const filteredTape =
-        tape_util.getFilteredNodesXToY(this.activeTape, xs, y);
+    const filteredTape = getFilteredNodesXToY(this.activeTape, xs, y);
     if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
       throw new Error(
           'Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
@@ -314,7 +322,7 @@
     accumulatedGradientMap[y.id] = (dy == null) ? ops.onesLike(y) : dy;
 
     // Backprop gradients through the filtered nodes.
-    tape_util.backpropagateGradients(accumulatedGradientMap, filteredTape);
+    backpropagateGradients(accumulatedGradientMap, filteredTape);
 
     const grads = xs.map(x => accumulatedGradientMap[x.id]);
     return {value: y, grads};
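
Taken together, the `startScope`/`runKernel` changes above let the engine label every kernel with the op that launched it. A rough usage sketch, assuming direct engine access (hypothetical; in practice the ops layer and `tidy()` manage scopes for you):

    // Manual scope management around a kernel call (illustrative only).
    declare const x: Tensor;
    ENV.engine.startScope('relu');  // the name shows up in DEBUG-mode profiles
    const out = ENV.engine.runKernel(
        (backend, save) => backend.relu(x), {x});
    ENV.engine.endScope(out);  // intermediates are disposed; `out` survives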
2 changes: 1 addition & 1 deletion src/gradients.ts
@@ -19,7 +19,7 @@ import {doc} from './doc';
 import {CustomGradientFunc} from './engine';
 import {ENV} from './environment';
 import {tidy} from './globals';
-import {ScopeFn, ScopeResult} from './tape_util';
+import {ScopeFn, ScopeResult} from './tape';
 import {Scalar, Tensor, Variable} from './tensor';
 import {NamedTensorMap} from './types';
 import * as util from './util';
2 changes: 1 addition & 1 deletion src/index.ts
@@ -42,11 +42,11 @@ export {CostReduction, FeedEntry, Session} from './graph/session';
 export {MathBackendCPU, NDArrayMathCPU} from './kernels/backend_cpu';
 // tslint:disable-next-line:max-line-length
 export {MathBackendWebGL, NDArrayMathGPU, WebGLTimingInfo} from './kernels/backend_webgl';
-export {MatrixOrientation} from './kernels/types/matmul';
 export {GPGPUContext} from './kernels/webgl/gpgpu_context';
 export {NDArrayMath} from './math';
 export {Model} from './model';
 export {LSTMCell} from './ops/lstm';
+export {MatrixOrientation} from './ops/matmul';
 export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer';
 export {AdagradOptimizer} from './optimizers/adagrad_optimizer';
 export {AdamOptimizer} from './optimizers/adam_optimizer';
10 changes: 5 additions & 5 deletions src/kernels/backend.ts
@@ -19,7 +19,7 @@
 import {Conv2DInfo} from '../ops/conv_util';
 // tslint:disable-next-line:max-line-length
 import {DataId, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor';
-import {DataType, Rank, TypedArray} from '../types';
+import {DataType, TypedArray} from '../types';
 
 // Required information for all backends.
 export interface BackendTimingInfo { kernelMs: number; }
@@ -95,7 +95,7 @@ export interface KernelBackend extends TensorStorage, BackendTimer {
   greater(a: Tensor, b: Tensor): Tensor;
   greaterEqual(a: Tensor, b: Tensor): Tensor;
 
-  logicalNot(a: Tensor): Tensor;
+  logicalNot<T extends Tensor>(a: T): T;
   logicalAnd(a: Tensor, b: Tensor): Tensor;
   logicalOr(a: Tensor, b: Tensor): Tensor;
   logicalXor(a: Tensor, b: Tensor): Tensor;
@@ -128,7 +128,7 @@ export interface KernelBackend extends TensorStorage, BackendTimer {
   leakyRelu<T extends Tensor>(x: T, alpha: number): T;
   prelu<T extends Tensor>(x: T, alpha: T): T;
   preluDer<T extends Tensor>(x: T, alpha: T): T;
-  int<R extends Rank>(x: Tensor<R>): Tensor<R>;
+  int<T extends Tensor>(x: T): T;
 
   clip<T extends Tensor>(x: T, min: number, max: number): T;
 
@@ -183,8 +183,8 @@ export interface KernelBackend extends TensorStorage, BackendTimer {
 
   batchNormalization4D(
       x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D,
-      varianceEpsilon: number, scale?: Tensor4D|Tensor1D,
-      offset?: Tensor4D|Tensor1D): Tensor4D;
+      varianceEpsilon: number, scale: Tensor4D|Tensor1D,
+      offset: Tensor4D|Tensor1D): Tensor4D;
 
   localResponseNormalization4D(
       x: Tensor4D, radius: number, bias: number, alpha: number, beta: number,
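
A note on the backend signature changes above: replacing the rank parameter (`<R extends Rank>`) with a tensor-subtype parameter (`<T extends Tensor>`) lets the concrete tensor type flow through backend calls. A toy illustration under that assumption, not code from this commit:

    // Old signature int<R extends Rank>(x: Tensor<R>): Tensor<R>: a Tensor3D
    // argument came back as a generic rank-3 Tensor, typically forcing a cast.
    // New signature int<T extends Tensor>(x: T): T preserves the subtype:
    declare const backend: KernelBackend;
    declare const img: Tensor3D;
    const intImg: Tensor3D = backend.int(img);  // typechecks without a cast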
4 changes: 2 additions & 2 deletions src/kernels/backend_cpu.ts
@@ -790,13 +790,13 @@ export class MathBackendCPU implements KernelBackend {
     return Tensor.make(x.shape, {values: resultValues}) as T;
   }
 
-  int<R extends Rank>(x: Tensor<R>): Tensor<R> {
+  int<T extends Tensor>(x: T): T {
     const resultValues = new Int32Array(x.size);
     const values = x.dataSync();
     for (let i = 0; i < values.length; ++i) {
       resultValues[i] = values[i];
     }
-    return Tensor.make(x.shape, {values: resultValues}, 'int32') as Tensor<R>;
+    return Tensor.make(x.shape, {values: resultValues}, 'int32');
   }
 
   sigmoid<T extends Tensor>(x: T): T {
8 changes: 3 additions & 5 deletions src/kernels/backend_webgl.ts
@@ -24,10 +24,8 @@ import * as reduce_util from '../ops/reduce_util';
 // tslint:disable-next-line:max-line-length
 import {DataId, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor';
 import * as types from '../types';
-// tslint:disable-next-line:max-line-length
-import {DataType, DataTypeMap, Rank, RecursiveArray, TypedArray} from '../types';
+import {DataType, DataTypeMap, RecursiveArray, TypedArray} from '../types';
 import * as util from '../util';
-
 import {KernelBackend} from './backend';
 import {ArgMinMaxProgram} from './webgl/argminmax_gpu';
 import {AvgPool2DBackpropProgram} from './webgl/avg_pool_backprop_gpu';
@@ -705,10 +703,10 @@ export class MathBackendWebGL implements KernelBackend {
     return this.compileAndRun(program, [a, b]) as T;
   }
 
-  int<R extends Rank>(x: Tensor<R>): Tensor<R> {
+  int<T extends Tensor>(x: T): T {
     const program = new UnaryOpProgram(x.shape, unary_op.TO_INT);
     const output = this.makeOutputArray(program.outputShape, 'int32');
-    return this.compileAndRun(program, [x], output) as Tensor<R>;
+    return this.compileAndRun(program, [x], output) as T;
   }
 
   clip<T extends Tensor>(x: T, min: number, max: number): T {
(Diff truncated: the remaining changed files are not shown here.)