Skip to content

Commit

Permalink
Generalize the tensor container to arbitrary depth (tensorflow#1083)
Browse files Browse the repository at this point in the history
FEATURE

Currently, calling `tf.dispose(obj)` or `tf.tidy(() => obj)` results in walking the `obj` for depth of 1 and disposing of any tensors found.

With this change `obj` is walked to arbitrary depth with cycle-detection to defend against cyclical objects.

This makes `tf.tidy()` and `tf.dispose()` more useful in practice.

Also rename `util.extractTensorsFromContainer` to `util.getTensorsFromContainer` since we are not modifying (extracting from) the container.
  • Loading branch information
dsmilkov authored Jun 6, 2018
1 parent abcd97b commit e12aebd
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ export class Engine implements TensorManager {

const tensorsToKeep = new Set(this.keepTensors);

const tensorsToTrackInParent = util.extractTensorsFromContainer(result);
const tensorsToTrackInParent = util.getTensorsInContainer(result);
tensorsToTrackInParent.forEach(tensor => tensorsToKeep.add(tensor.id));

// Dispose the arrays tracked in this scope.
Expand Down
10 changes: 5 additions & 5 deletions src/tracking.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {ScopeFn, TimingInfo} from './engine';
import {ENV} from './environment';
import {Tensor} from './tensor';
import {TensorContainer} from './types';
import {extractTensorsFromAny} from './util';
import {getTensorsInContainer} from './util';

export class Tracking {
/**
Expand Down Expand Up @@ -101,17 +101,17 @@ export class Tracking {
}

/**
* Disposes any `Tensor`s found within the provided object up to depth 1.
* Disposes any `Tensor`s found within the provided object.
*
* @param container an object that may be a `Tensor` or may directly contain
* `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. If the
* object is not a `Tensor` or does not contain `Tensors`, nothing
* happens. In general it is safe to pass any object here, except that
* `Promise`s are not supported.
*/
// tslint:disable-next-line:no-any
static dispose(container: any) {
const tensors = extractTensorsFromAny(container);
@doc({heading: 'Performance', subheading: 'Memory'})
static dispose(container: TensorContainer) {
const tensors = getTensorsInContainer(container);
tensors.forEach(tensor => tensor.dispose());
}

Expand Down
20 changes: 19 additions & 1 deletion src/tracking_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

import * as tf from './index';
import {describeWithFlags} from './jasmine_util';
import {ALL_ENVS, CPU_ENVS, expectArraysClose, WEBGL_ENVS} from './test_util';
// tslint:disable-next-line:max-line-length
import {ALL_ENVS, CPU_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from './test_util';

describeWithFlags('time webgl', WEBGL_ENVS, () => {
it('upload + compute', async () => {
Expand Down Expand Up @@ -262,4 +263,21 @@ describeWithFlags('tidy', ALL_ENVS, () => {
tf.tidy('name', 'another name' as any);
}).toThrowError();
});

it('works with arbitrary depth of result', () => {
tf.tidy(() => {
const res = tf.tidy(() => {
return [tf.scalar(1), [[tf.scalar(2)]], {list: [tf.scalar(3)]}];
});
expectArraysEqual(res[0] as tf.Tensor, [1]);
// tslint:disable-next-line:no-any
expectArraysEqual((res[1] as any)[0][0], [2]);
// tslint:disable-next-line:no-any
expectArraysEqual((res[2] as any).list[0], [3]);
expect(tf.memory().numTensors).toBe(3);
return res[0];
});
// Everything but scalar(1) got disposed.
expect(tf.memory().numTensors).toBe(1);
});
});
52 changes: 26 additions & 26 deletions src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
import {Tensor} from './tensor';
// tslint:disable-next-line:max-line-length
import {DataType, DataTypeMap, FlatVector, NamedTensorMap, RecursiveArray, RegularArray, TensorContainer, TypedArray} from './types';
import {DataType, DataTypeMap, FlatVector, NamedTensorMap, RecursiveArray, RegularArray, TensorContainer, TensorContainerArray, TypedArray} from './types';

function assertArgumentIsTensor(
x: Tensor, argName: string, functionName: string) {
Expand Down Expand Up @@ -440,46 +440,46 @@ export function isFunction(f: Function) {
return !!(f && f.constructor && f.call && f.apply);
}

export function extractTensorsFromContainer(result: TensorContainer): Tensor[] {
return extractTensorsFromAny(result);
}

/**
* Extracts any `Tensor`s found within the provided object up to depth 1.
* Extracts any `Tensor`s found within the provided object.
*
* @param container an object that may be a `Tensor` or may directly contain
* `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it
* is safe to pass any object here, except that `Promise`s are not
* supported.
* @returns An array of `Tensors` found within the passed object. If the
* argument is simply a `Tensor', a list containing that `Tensor` is
* returned. If the argument directly contains `Tensor`s, a list of them
* will be returned. `Tensor`s nested more deeply within the argument will
* however not be found. If the object is not a `Tensor` or does not
* returned. If the object is not a `Tensor` or does not
* contain `Tensors`, an empty list is returned.
*/
// tslint:disable-next-line:no-any
export function extractTensorsFromAny(result: any): Tensor[] {
if (result == null) {
return [];
export function getTensorsInContainer(result: TensorContainer): Tensor[] {
const list: Tensor[] = [];
const seen = new Set<{}|void>();
walkTensorContainer(result, list, seen);
return list;
}

function walkTensorContainer(
container: TensorContainer, list: Tensor[], seen: Set<{}|void>): void {
if (container == null) {
return;
}
if (result instanceof Tensor) {
return [result];
if (container instanceof Tensor) {
list.push(container);
return;
}

const list: Tensor[] = [];
// tslint:disable-next-line:no-any
const resultObj = result as {[key: string]: any};
if (!isIterable(resultObj)) {
return [];
if (!isIterable(container)) {
return;
}

// Iteration over keys works also for arrays.
for (const k in resultObj) {
const sublist = flatten(resultObj[k]).filter(x => x instanceof Tensor);
list.push(...sublist);
const iterable = container as TensorContainerArray;
for (const k in iterable) {
const val = iterable[k];
if (!seen.has(val)) {
seen.add(val);
walkTensorContainer(val, list, seen);
}
}
return list;
}

// tslint:disable-next-line:no-any
Expand Down
28 changes: 23 additions & 5 deletions src/util_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
*/

import * as tf from './index';
import {describeWithFlags} from './jasmine_util';
import {Tensor} from './tensor';
import {CPU_ENVS} from './test_util';
import {describeWithFlags} from './jasmine_util';
import {NamedTensorMap} from './types';
import * as util from './util';

Expand Down Expand Up @@ -297,16 +297,16 @@ describe('util.hasEncodingLoss', () => {
});
});

describeWithFlags('extractTensorsFromAny', CPU_ENVS, () => {
describeWithFlags('getTensorsInContainer', CPU_ENVS, () => {
it('null input returns empty tensor', () => {
const results = util.extractTensorsFromAny(null);
const results = util.getTensorsInContainer(null);

expect(results).toEqual([]);
});

it('tensor input returns one element tensor', () => {
const x = tf.scalar(1);
const results = util.extractTensorsFromAny(x);
const results = util.getTensorsInContainer(x);

expect(results).toEqual([x]);
});
Expand All @@ -315,8 +315,26 @@ describeWithFlags('extractTensorsFromAny', CPU_ENVS, () => {
const x1 = tf.scalar(1);
const x2 = tf.scalar(3);
const x3 = tf.scalar(4);
const results = util.extractTensorsFromAny({x1, x2, x3});
const results = util.getTensorsInContainer({x1, x2, x3});

expect(results).toEqual([x1, x2, x3]);
});

it('can extract from arbitrary depth', () => {
const container = [
{x: tf.scalar(1), y: tf.scalar(2)},
[[[tf.scalar(3)]], {z: tf.scalar(4)}]
];
const results = util.getTensorsInContainer(container);
expect(results.length).toBe(4);
});

it('works with loops in container', () => {
const container = [tf.scalar(1), tf.scalar(2), [tf.scalar(3)]];
const innerContainer = [container];
// tslint:disable-next-line:no-any
container.push(innerContainer as any);
const results = util.getTensorsInContainer(container);
expect(results.length).toBe(3);
});
});

0 comments on commit e12aebd

Please sign in to comment.