Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Generalize the tensor container to arbitrary depth #1083

Merged
merged 4 commits into from
Jun 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ export class Engine implements TensorManager {

const tensorsToKeep = new Set(this.keepTensors);

const tensorsToTrackInParent = util.extractTensorsFromContainer(result);
const tensorsToTrackInParent = util.getTensorsInContainer(result);
tensorsToTrackInParent.forEach(tensor => tensorsToKeep.add(tensor.id));

// Dispose the arrays tracked in this scope.
Expand Down
10 changes: 5 additions & 5 deletions src/tracking.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {ScopeFn, TimingInfo} from './engine';
import {ENV} from './environment';
import {Tensor} from './tensor';
import {TensorContainer} from './types';
import {extractTensorsFromAny} from './util';
import {getTensorsInContainer} from './util';

export class Tracking {
/**
Expand Down Expand Up @@ -101,17 +101,17 @@ export class Tracking {
}

/**
* Disposes any `Tensor`s found within the provided object up to depth 1.
* Disposes any `Tensor`s found within the provided object.
*
* @param container an object that may be a `Tensor` or may directly contain
* `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. If the
* object is not a `Tensor` or does not contain `Tensors`, nothing
* happens. In general it is safe to pass any object here, except that
* `Promise`s are not supported.
*/
// tslint:disable-next-line:no-any
static dispose(container: any) {
const tensors = extractTensorsFromAny(container);
@doc({heading: 'Performance', subheading: 'Memory'})
static dispose(container: TensorContainer) {
const tensors = getTensorsInContainer(container);
tensors.forEach(tensor => tensor.dispose());
}

Expand Down
20 changes: 19 additions & 1 deletion src/tracking_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

import * as tf from './index';
import {describeWithFlags} from './jasmine_util';
import {ALL_ENVS, CPU_ENVS, expectArraysClose, WEBGL_ENVS} from './test_util';
// tslint:disable-next-line:max-line-length
import {ALL_ENVS, CPU_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from './test_util';

describeWithFlags('time webgl', WEBGL_ENVS, () => {
it('upload + compute', async () => {
Expand Down Expand Up @@ -262,4 +263,21 @@ describeWithFlags('tidy', ALL_ENVS, () => {
tf.tidy('name', 'another name' as any);
}).toThrowError();
});

it('works with arbitrary depth of result', () => {
tf.tidy(() => {
const res = tf.tidy(() => {
return [tf.scalar(1), [[tf.scalar(2)]], {list: [tf.scalar(3)]}];
});
expectArraysEqual(res[0] as tf.Tensor, [1]);
// tslint:disable-next-line:no-any
expectArraysEqual((res[1] as any)[0][0], [2]);
// tslint:disable-next-line:no-any
expectArraysEqual((res[2] as any).list[0], [3]);
expect(tf.memory().numTensors).toBe(3);
return res[0];
});
// Everything but scalar(1) got disposed.
expect(tf.memory().numTensors).toBe(1);
});
});
52 changes: 26 additions & 26 deletions src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
import {Tensor} from './tensor';
// tslint:disable-next-line:max-line-length
import {DataType, DataTypeMap, FlatVector, NamedTensorMap, RecursiveArray, RegularArray, TensorContainer, TypedArray} from './types';
import {DataType, DataTypeMap, FlatVector, NamedTensorMap, RecursiveArray, RegularArray, TensorContainer, TensorContainerArray, TypedArray} from './types';

function assertArgumentIsTensor(
x: Tensor, argName: string, functionName: string) {
Expand Down Expand Up @@ -440,46 +440,46 @@ export function isFunction(f: Function) {
return !!(f && f.constructor && f.call && f.apply);
}

export function extractTensorsFromContainer(result: TensorContainer): Tensor[] {
return extractTensorsFromAny(result);
}

/**
* Extracts any `Tensor`s found within the provided object up to depth 1.
* Extracts any `Tensor`s found within the provided object.
*
* @param container an object that may be a `Tensor` or may directly contain
* `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it
* is safe to pass any object here, except that `Promise`s are not
* supported.
* @returns An array of `Tensors` found within the passed object. If the
* argument is simply a `Tensor', a list containing that `Tensor` is
* returned. If the argument directly contains `Tensor`s, a list of them
* will be returned. `Tensor`s nested more deeply within the argument will
* however not be found. If the object is not a `Tensor` or does not
* returned. If the object is not a `Tensor` or does not
* contain `Tensors`, an empty list is returned.
*/
// tslint:disable-next-line:no-any
export function extractTensorsFromAny(result: any): Tensor[] {
if (result == null) {
return [];
export function getTensorsInContainer(result: TensorContainer): Tensor[] {
const list: Tensor[] = [];
const seen = new Set<{}|void>();
walkTensorContainer(result, list, seen);
return list;
}

function walkTensorContainer(
container: TensorContainer, list: Tensor[], seen: Set<{}|void>): void {
if (container == null) {
return;
}
if (result instanceof Tensor) {
return [result];
if (container instanceof Tensor) {
list.push(container);
return;
}

const list: Tensor[] = [];
// tslint:disable-next-line:no-any
const resultObj = result as {[key: string]: any};
if (!isIterable(resultObj)) {
return [];
if (!isIterable(container)) {
return;
}

// Iteration over keys works also for arrays.
for (const k in resultObj) {
const sublist = flatten(resultObj[k]).filter(x => x instanceof Tensor);
list.push(...sublist);
const iterable = container as TensorContainerArray;
for (const k in iterable) {
const val = iterable[k];
if (!seen.has(val)) {
seen.add(val);
walkTensorContainer(val, list, seen);
}
}
return list;
}

// tslint:disable-next-line:no-any
Expand Down
28 changes: 23 additions & 5 deletions src/util_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
*/

import * as tf from './index';
import {describeWithFlags} from './jasmine_util';
import {Tensor} from './tensor';
import {CPU_ENVS} from './test_util';
import {describeWithFlags} from './jasmine_util';
import {NamedTensorMap} from './types';
import * as util from './util';

Expand Down Expand Up @@ -297,16 +297,16 @@ describe('util.hasEncodingLoss', () => {
});
});

describeWithFlags('extractTensorsFromAny', CPU_ENVS, () => {
describeWithFlags('getTensorsInContainer', CPU_ENVS, () => {
it('null input returns empty tensor', () => {
const results = util.extractTensorsFromAny(null);
const results = util.getTensorsInContainer(null);

expect(results).toEqual([]);
});

it('tensor input returns one element tensor', () => {
const x = tf.scalar(1);
const results = util.extractTensorsFromAny(x);
const results = util.getTensorsInContainer(x);

expect(results).toEqual([x]);
});
Expand All @@ -315,8 +315,26 @@ describeWithFlags('extractTensorsFromAny', CPU_ENVS, () => {
const x1 = tf.scalar(1);
const x2 = tf.scalar(3);
const x3 = tf.scalar(4);
const results = util.extractTensorsFromAny({x1, x2, x3});
const results = util.getTensorsInContainer({x1, x2, x3});

expect(results).toEqual([x1, x2, x3]);
});

it('can extract from arbitrary depth', () => {
const container = [
{x: tf.scalar(1), y: tf.scalar(2)},
[[[tf.scalar(3)]], {z: tf.scalar(4)}]
];
const results = util.getTensorsInContainer(container);
expect(results.length).toBe(4);
});

it('works with loops in container', () => {
const container = [tf.scalar(1), tf.scalar(2), [tf.scalar(3)]];
const innerContainer = [container];
// tslint:disable-next-line:no-any
container.push(innerContainer as any);
const results = util.getTensorsInContainer(container);
expect(results.length).toBe(3);
});
});