Fix memory management bugs with tidy(). (tensorflow#1080)
BUG
Nikhil Thorat authored Jun 5, 2018
1 parent 32cface commit 4b72d94
Showing 5 changed files with 222 additions and 194 deletions.
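
From the hunks below, the fix replaces the per-scope keep arrays in the engine with a single keepTensors set of tensor ids, and endScope() now only hands a returned tensor to the parent scope when it was allocated in the scope being closed and is not globally kept. The following sketch illustrates the tidy()/keep() contract these changes protect; it is not part of the commit, and it assumes the public tfjs-core API (tf.tidy, tf.keep, tf.memory) exercised by the tests in engine_test.ts as well as the package name '@tensorflow/tfjs-core'.

import * as tf from '@tensorflow/tfjs-core';  // assumed package name

const a = tf.tensor1d([1, 2, 3]);
const result = tf.tidy(() => {
  const b = tf.tensor1d([4, 5, 6]);  // intermediate: disposed when the scope ends
  tf.keep(tf.add(a, b));             // kept: survives scope cleanup (caller must dispose it later)
  return tf.mul(a, b);               // returned: handed to the parent scope
});
console.log(tf.memory().numTensors); // 3: a, the kept sum, and result
result.print();                      // [4, 10, 18]
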
30 changes: 17 additions & 13 deletions src/engine.ts
@@ -28,7 +28,6 @@ import {NamedTensorMap, NamedVariableMap, TensorContainer, TypedArray} from './t
import * as util from './util';

interface ScopeState {
keep: Tensor[];
track: Tensor[];
name?: string;
}
@@ -62,7 +61,9 @@ export type MemoryInfo = {
unreliable?: boolean;
};

export interface TimingInfo extends BackendTimingInfo { wallMs: number; }
export interface TimingInfo extends BackendTimingInfo {
wallMs: number;
}

export class Engine implements TensorManager {
// Public since optimizers will use it.
@@ -81,11 +82,12 @@ export class Engine implements TensorManager {
// Keep Tensors that parallel the tapes.
private activeScope: ScopeState;
private scopeStack: ScopeState[];
private keepTensors: Set<number> = new Set();
private profiler: Profiler;

constructor(private backend: KernelBackend, public safeMode: boolean) {
// Create a default outer scope.
this.activeScope = {keep: [], track: []};
this.activeScope = {track: []};
this.scopeStack = [this.activeScope];
this.profiler = new Profiler(backend);
}
@@ -228,7 +230,7 @@ export class Engine implements TensorManager {
'Safe mode is ON. Enclose all tensor operations inside tf.tidy(): ' +
'tf.tidy(() => {...}) to avoid memory leaks.');
}
this.activeScope.keep.push(result);
this.keepTensors.add(result.id);
return result;
}

@@ -244,7 +246,7 @@ export class Engine implements TensorManager {
this.gradientScopeCount++;
}

const scopeInfo: ScopeState = {keep: [], track: []};
const scopeInfo: ScopeState = {track: []};
if (name) {
scopeInfo.name = name;
}
@@ -264,14 +266,15 @@ export class Engine implements TensorManager {
}
}

let tensorsToKeep = this.activeScope.keep;
const tensorsToKeep = new Set(this.keepTensors);

const tensorsToTrackInParent = util.extractTensorsFromContainer(result);
tensorsToKeep = tensorsToKeep.concat(tensorsToTrackInParent);
tensorsToTrackInParent.forEach(tensor => tensorsToKeep.add(tensor.id));

// Dispose the arrays tracked in this scope.
for (let i = 0; i < this.activeScope.track.length; i++) {
const tensor = this.activeScope.track[i];
if (util.isTensorInList(tensor, tensorsToKeep)) {
if (tensorsToKeep.has(tensor.id)) {
continue;
}

@@ -282,21 +285,22 @@ export class Engine implements TensorManager {
}
}

this.scopeStack.pop();
const oldScope = this.scopeStack.pop();
this.activeScope = this.scopeStack.length === 0 ?
{keep: [], track: []} :
{track: []} :
this.scopeStack[this.scopeStack.length - 1];

// Track the current result in the parent scope.
tensorsToTrackInParent.forEach(tensor => {
if (!util.isTensorInList(tensor, this.activeScope.keep)) {
// Only track the tensor if was allocated in the inner scope and is not
// globally kept.
if (!this.keepTensors.has(tensor.id) &&
util.isTensorInList(tensor, oldScope.track)) {
this.track(tensor);
}
});
}

dispose() {}

/**
* Returns gradients of `f` with respect to each of the `xs`. The gradients
* returned are of the same length as `xs`, but some might be null if `f` was
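
Condensing the interleaved old/new lines above, the new end-of-scope bookkeeping works roughly as follows. This is a simplified, self-contained sketch with stand-in types; gradient-scope handling and error checks are omitted, so it is illustrative rather than the actual implementation.

// Stand-in types for the sketch; the real engine uses Tensor and ScopeState.
interface SketchTensor { id: number; dispose(): void; }
interface SketchScope { track: SketchTensor[]; name?: string; }

function endScopeSketch(
    keepTensors: Set<number>,                 // ids registered via keep(), engine-global
    scopeStack: SketchScope[],                // innermost scope is last
    returned: SketchTensor[],                 // tensors extracted from the tidy() result
    trackInParent: (t: SketchTensor) => void) {
  const oldScope = scopeStack.pop()!;         // the scope being closed

  // Anything globally kept or returned from this scope must not be disposed.
  const tensorsToKeep = new Set(keepTensors);
  returned.forEach(t => tensorsToKeep.add(t.id));

  // Dispose every tensor tracked by the closing scope that is not kept.
  for (const tensor of oldScope.track) {
    if (!tensorsToKeep.has(tensor.id)) {
      tensor.dispose();
    }
  }

  // Only hand a returned tensor to the parent scope if it was allocated in
  // the scope that just closed and is not globally kept.
  returned.forEach(tensor => {
    if (!keepTensors.has(tensor.id) &&
        oldScope.track.some(t => t.id === tensor.id)) {
      trackInParent(tensor);
    }
  });
}
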
179 changes: 5 additions & 174 deletions src/engine_test.ts
@@ -16,178 +16,9 @@
*/

import * as tf from './index';
import {describeWithFlags} from './jasmine_util';
// tslint:disable-next-line:max-line-length
import {ALL_ENVS, expectArraysClose, expectArraysEqual, expectNumbersClose, WEBGL_ENVS} from './test_util';
import {describeWithFlags} from './jasmine_util';

describeWithFlags('tidy', ALL_ENVS, () => {
it('returns Tensor', () => {
tf.tidy(() => {
const a = tf.tensor1d([1, 2, 3]);
let b = tf.tensor1d([0, 0, 0]);

expect(tf.memory().numTensors).toBe(2);
tf.tidy(() => {
const result = tf.tidy(() => {
b = tf.addStrict(a, b);
b = tf.addStrict(a, b);
b = tf.addStrict(a, b);
return tf.add(a, b);
});

// result is new. All intermediates should be disposed.
expect(tf.memory().numTensors).toBe(2 + 1);
expectArraysClose(result, [4, 8, 12]);
});

// a, b are still here, result should be disposed.
expect(tf.memory().numTensors).toBe(2);
});

expect(tf.memory().numTensors).toBe(0);
});

it('multiple disposes does not affect num arrays', () => {
expect(tf.memory().numTensors).toBe(0);
const a = tf.tensor1d([1, 2, 3]);
const b = tf.tensor1d([1, 2, 3]);
expect(tf.memory().numTensors).toBe(2);
a.dispose();
a.dispose();
expect(tf.memory().numTensors).toBe(1);
b.dispose();
expect(tf.memory().numTensors).toBe(0);
});

it('allows primitive types', () => {
const a = tf.tidy(() => 5);
expect(a).toBe(5);

const b = tf.tidy(() => 'hello');
expect(b).toBe('hello');
});

it('allows complex types', () => {
const res = tf.tidy(() => {
return {a: tf.scalar(1), b: 'hello', c: [tf.scalar(2), 'world']};
});
expectArraysClose(res.a, [1]);
expectArraysClose(res.c[0] as tf.Scalar, [2]);
});

it('returns Tensor[]', () => {
const a = tf.tensor1d([1, 2, 3]);
const b = tf.tensor1d([0, -1, 1]);
expect(tf.memory().numTensors).toBe(2);

tf.tidy(() => {
const result = tf.tidy(() => {
tf.add(a, b);
return [tf.add(a, b), tf.sub(a, b)];
});

// the 2 results are new. All intermediates should be disposed.
expect(tf.memory().numTensors).toBe(4);
expectArraysClose(result[0], [1, 1, 4]);
expectArraysClose(result[1], [1, 3, 2]);
expect(tf.memory().numTensors).toBe(4);
});

// the 2 results should be disposed.
expect(tf.memory().numTensors).toBe(2);
a.dispose();
b.dispose();
expect(tf.memory().numTensors).toBe(0);
});

it('basic usage without return', () => {
const a = tf.tensor1d([1, 2, 3]);
let b = tf.tensor1d([0, 0, 0]);

expect(tf.memory().numTensors).toBe(2);

tf.tidy(() => {
b = tf.addStrict(a, b);
b = tf.addStrict(a, b);
b = tf.addStrict(a, b);
tf.add(a, b);
});

// all intermediates should be disposed.
expect(tf.memory().numTensors).toBe(2);
});

it('nested usage', () => {
const a = tf.tensor1d([1, 2, 3]);
let b = tf.tensor1d([0, 0, 0]);

expect(tf.memory().numTensors).toBe(2);

tf.tidy(() => {
const result = tf.tidy(() => {
b = tf.addStrict(a, b);
b = tf.tidy(() => {
b = tf.tidy(() => {
return tf.addStrict(a, b);
});
// original a, b, and two intermediates.
expect(tf.memory().numTensors).toBe(4);

tf.tidy(() => {
tf.addStrict(a, b);
});
// All the intermediates should be cleaned up.
expect(tf.memory().numTensors).toBe(4);

return tf.addStrict(a, b);
});
expect(tf.memory().numTensors).toBe(4);

return tf.addStrict(a, b);
});

expect(tf.memory().numTensors).toBe(3);
expectArraysClose(result, [4, 8, 12]);
});
expect(tf.memory().numTensors).toBe(2);
});

it('single argument', () => {
let hasRan = false;
tf.tidy(() => {
hasRan = true;
});
expect(hasRan).toBe(true);
});

it('single argument, but not a function throws error', () => {
expect(() => {
tf.tidy('asdf');
}).toThrowError();
});

it('2 arguments, first is string', () => {
let hasRan = false;
tf.tidy('name', () => {
hasRan = true;
});
expect(hasRan).toBe(true);
});

it('2 arguments, but first is not string throws error', () => {
expect(() => {
// tslint:disable-next-line:no-any
tf.tidy(4 as any, () => {});
}).toThrowError();
});

it('2 arguments, but second is not a function throws error', () => {
expect(() => {
// tslint:disable-next-line:no-any
tf.tidy('name', 'another name' as any);
}).toThrowError();
});
});

describeWithFlags('fromPixels + regular math op', WEBGL_ENVS, () => {
it('debug mode does not error when no nans', () => {
@@ -532,13 +363,13 @@ describeWithFlags('memory', ALL_ENVS, () => {

describeWithFlags('disposeVariables', ALL_ENVS, () => {
it('reuse same name variable', () => {
tf.tensor1d([1,2,3]).variable(true, 'v1');
tf.tensor1d([1,2,3]).variable(true, 'v2');
tf.tensor1d([1, 2, 3]).variable(true, 'v1');
tf.tensor1d([1, 2, 3]).variable(true, 'v2');
expect(() => {
tf.tensor1d([1, 2, 3]).variable(true, 'v1');
}).toThrowError();
tf.disposeVariables();
tf.tensor1d([1,2,3]).variable(true, 'v1');
tf.tensor1d([1,2,3]).variable(true, 'v2');
tf.tensor1d([1, 2, 3]).variable(true, 'v1');
tf.tensor1d([1, 2, 3]).variable(true, 'v2');
});
});
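
The disposeVariables test above shows that a variable name stays reserved until the variable registry is cleared. A short usage sketch mirroring it (illustrative only; assumes the package name '@tensorflow/tfjs-core'):

import * as tf from '@tensorflow/tfjs-core';    // assumed package name

tf.tensor1d([1, 2, 3]).variable(true, 'v1');    // registers variable 'v1'
// tf.tensor1d([4, 5, 6]).variable(true, 'v1'); // would throw: name already in use
tf.disposeVariables();                          // disposes and unregisters all variables
tf.tensor1d([4, 5, 6]).variable(true, 'v1');    // the name can be reused now
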
4 changes: 0 additions & 4 deletions src/environment.ts
@@ -344,16 +344,12 @@ export class Environment {
reset() {
this.features = getFeaturesFromURL();
if (this.globalEngine != null) {
this.globalEngine.dispose();
this.globalEngine = null;
}
}

private initBackend(backendType?: string, safeMode = false) {
this.currentBackend = backendType;
if (this.globalEngine != null) {
this.globalEngine.dispose();
}
const backend = ENV.findBackend(backendType);
this.globalEngine = new Engine(backend, safeMode);
}
8 changes: 6 additions & 2 deletions src/tensor.ts
@@ -341,11 +341,15 @@ export class Tensor<R extends Rank = Rank> {
if (this.isDisposed) {
return;
}
this.isDisposed = true;
this.isDisposedInternal = true;
ENV.engine.disposeTensor(this);
}

private isDisposed = false;
private isDisposedInternal = false;
get isDisposed(): boolean {
return this.isDisposedInternal;
}

private throwIfDisposed() {
if (this.isDisposed) {
throw new Error(`Tensor is disposed.`);
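
The tensor.ts hunk above turns isDisposed into a read-only getter backed by a private isDisposedInternal flag. A brief illustration of what callers can rely on after this change (not from the commit; assumes the same tf import as in the sketches above):

const t = tf.scalar(2);
console.log(t.isDisposed);  // false
t.dispose();
console.log(t.isDisposed);  // true, and the flag can no longer be set from outside
// Any further op on t throws 'Tensor is disposed.' via throwIfDisposed():
// t.square();              // would throw
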
