Skip to content

Commit

Permalink
[Wasm] Modularize Reshape and Cast (#2259)
Browse files Browse the repository at this point in the history
FEATURE

Modularize `Reshape` and `Cast`. Also add support for string tensors.
  • Loading branch information
dsmilkov authored Oct 22, 2019
1 parent 74b93f0 commit ad51fbd
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 39 deletions.
1 change: 1 addition & 0 deletions tfjs-backend-wasm/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ git_repository(
name = "xnnpack",
commit = "f6839e1355032ee6c78833e1693876d4cdb01436",
remote = "https://github.com/google/XNNPACK.git",
shallow_since = "1571339795 -0700",
)

# The libraries below are transitive dependencies of XNNPACK that we need to
Expand Down
57 changes: 24 additions & 33 deletions tfjs-backend-wasm/src/backend_wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {backend_util, DataStorage, DataType, engine, KernelBackend, Rank, registerBackend, ShapeMap, Tensor, TensorInfo, util} from '@tensorflow/tfjs-core';
import {backend_util, DataStorage, DataType, engine, KernelBackend, registerBackend, TensorInfo, util} from '@tensorflow/tfjs-core';

import wasmFactory from '../wasm-out/tfjs-backend-wasm';
import {BackendWasmModule} from '../wasm-out/tfjs-backend-wasm';
Expand All @@ -27,6 +27,8 @@ interface TensorData {
memoryOffset: number;
shape: number[];
dtype: DataType;
/** Only used for string tensors, storing encoded bytes. */
stringBytes?: Uint8Array[];
}

export type DataId = object; // object instead of {} to force non-primitive.
Expand All @@ -41,7 +43,7 @@ export class BackendWasm extends KernelBackend {
this.dataIdMap = new DataStorage(this, engine());
}

write(values: backend_util.TypedArray, shape: number[], dtype: DataType):
write(values: backend_util.BackendValues, shape: number[], dtype: DataType):
DataId {
const dataId = {};
this.move(dataId, values, shape, dtype);
Expand All @@ -53,17 +55,25 @@ export class BackendWasm extends KernelBackend {
}

move(
dataId: DataId, values: backend_util.TypedArray, shape: number[],
dataId: DataId, values: backend_util.BackendValues, shape: number[],
dtype: DataType): void {
const id = this.dataIdNextNumber++;
if (dtype === 'string') {
const stringBytes = values as Uint8Array[];
this.dataIdMap.set(
dataId, {id, stringBytes, shape, dtype, memoryOffset: null});
return;
}
const numBytes = util.sizeFromShape(shape) * util.bytesPerElement(dtype);
const memoryOffset = this.wasm._malloc(numBytes);
const id = this.dataIdNextNumber++;
this.dataIdMap.set(dataId, {id, memoryOffset, shape, dtype});
const shapeBytes = new Uint8Array(new Int32Array(shape).buffer);
this.wasm.tfjs.registerTensor(
id, shapeBytes, shape.length, dtypeToEnumValue(dtype), memoryOffset);
if (values != null) {
this.wasm.HEAPU8.set(new Uint8Array(values.buffer), memoryOffset);
this.wasm.HEAPU8.set(
new Uint8Array((values as backend_util.TypedArray).buffer),
memoryOffset);
}
}

Expand All @@ -72,7 +82,11 @@ export class BackendWasm extends KernelBackend {
}

readSync(dataId: DataId): backend_util.BackendValues {
const {memoryOffset, dtype, shape} = this.dataIdMap.get(dataId);
const {memoryOffset, dtype, shape, stringBytes} =
this.dataIdMap.get(dataId);
if (dtype === 'string') {
return stringBytes;
}
const bytes = this.wasm.HEAPU8.slice(
memoryOffset,
memoryOffset + util.sizeFromShape(shape) * util.bytesPerElement(dtype));
Expand Down Expand Up @@ -101,25 +115,12 @@ export class BackendWasm extends KernelBackend {
this.wasm = null;
}

// Kernels.

reshape<T extends Tensor, R extends Rank>(x: T, newShape: ShapeMap[R]):
Tensor<R> {
return engine().makeTensorFromDataId(x.dataId, newShape, x.dtype) as
Tensor<R>;
}

cast<T extends Tensor>(x: T, dtype: DataType): T {
const out = this.makeOutTensor(x.shape, dtype);
const {memoryOffset: inOffset} = this.dataIdMap.get(x.dataId);
const {memoryOffset: outOffset} = this.dataIdMap.get(out.dataId);
const inVals = this.typedArrayFromHeap(inOffset, x.dtype, x.size);
const outVals = this.typedArrayFromHeap(outOffset, dtype, out.size);
outVals.set(inVals);
return out as T;
makeOutput(shape: number[], dtype: DataType): TensorInfo {
const dataId = this.write(null /* values */, shape, dtype);
return {dataId, shape, dtype};
}

private typedArrayFromHeap(offset: number, dtype: DataType, size: number):
typedArrayFromHeap(offset: number, dtype: DataType, size: number):
backend_util.TypedArray {
const buffer = this.wasm.HEAPU8.buffer;
switch (dtype) {
Expand All @@ -133,16 +134,6 @@ export class BackendWasm extends KernelBackend {
throw new Error(`Uknown dtype ${dtype}`);
}
}

makeOutput(shape: number[], dtype: DataType): TensorInfo {
const dataId = this.write(null /* values */, shape, dtype);
return {dataId, shape, dtype};
}

private makeOutTensor(shape: number[], dtype: DataType): Tensor {
const dataId = this.write(null /* values */, shape, dtype);
return engine().makeTensorFromDataId(dataId, shape, dtype, this);
}
}

registerBackend('wasm', async () => {
Expand Down
48 changes: 48 additions & 0 deletions tfjs-backend-wasm/src/kernels/Cast.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {DataType, NamedAttrMap, NamedTensorInfoMap, registerKernel, util} from '@tensorflow/tfjs-core';
import {TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

interface CastInputs extends NamedTensorInfoMap {
x: TensorInfo;
}

interface CastAttrs extends NamedAttrMap {
dtype: DataType;
}

function cast(
args: {inputs: CastInputs, attrs: CastAttrs, backend: BackendWasm}) {
const {inputs: {x}, attrs: {dtype}, backend} = args;
const out = backend.makeOutput(x.shape, dtype);
const {memoryOffset: inOffset} = backend.dataIdMap.get(x.dataId);
const {memoryOffset: outOffset} = backend.dataIdMap.get(out.dataId);
const size = util.sizeFromShape(x.shape);
const inVals = backend.typedArrayFromHeap(inOffset, x.dtype, size);
const outVals = backend.typedArrayFromHeap(outOffset, dtype, size);
outVals.set(inVals);
return out;
}

registerKernel({
kernelName: 'Cast',
backendName: 'wasm',
kernelFunc: cast,
});
41 changes: 41 additions & 0 deletions tfjs-backend-wasm/src/kernels/Reshape.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {NamedAttrMap, NamedTensorInfoMap, registerKernel} from '@tensorflow/tfjs-core';
import {TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

interface ReshapeInputs extends NamedTensorInfoMap {
x: TensorInfo;
}

interface ReshapeAttrs extends NamedAttrMap {
shape: number[];
}

function reshape(
args: {inputs: ReshapeInputs, attrs: ReshapeAttrs, backend: BackendWasm}) {
const {inputs: {x}, attrs: {shape}} = args;
return {dataId: x.dataId, shape, dtype: x.dtype};
}

registerKernel({
kernelName: 'Reshape',
backendName: 'wasm',
kernelFunc: reshape,
});
2 changes: 2 additions & 0 deletions tfjs-backend-wasm/src/kernels/all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@
// the contents of this file and import only the kernels that are needed.
import './Add';
import './BatchMatMul';
import './Cast';
import './Prelu';
import './Reshape';
7 changes: 5 additions & 2 deletions tfjs-backend-wasm/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ const env = jasmine.getEnv();
const grepFilter = env.specFilter;

/** Tests that have these substrings in their name will be included. */
const INCLUDE_LIST: string[] = ['add ', 'matmul ', 'prelu '];
const INCLUDE_LIST: string[] = ['add ', 'matmul ', 'prelu ', ' cast'];
/** Tests that have these substrings in their name will be excluded. */
const EXCLUDE_LIST: string[] = [
'complex', // Complex numbers not yet implemented.
Expand All @@ -42,7 +42,10 @@ const EXCLUDE_LIST: string[] = [
'matmul followed by mul', // mul not supported yet

// prelu
'prelu test-wasm undefined derivative',
'prelu test-wasm undefined derivative', // Missing gradient.

// cast
'shallow slice an input that was cast', // Slice is not implemented.
];

/**
Expand Down
11 changes: 7 additions & 4 deletions tfjs-core/src/ops/array_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,11 @@ function reshape_<R2 extends Rank>(
() => 'new shape and old shape must have the same number of elements.');

const grad = (dy: Tensor<R2>) => {
return {$x: () => dy.reshape($x.shape)};
return {x: () => dy.reshape($x.shape)};
};
const attrs = {shape};
return ENGINE.runKernelFunc(
backend => backend.reshape($x, shape), {$x}, grad);
backend => backend.reshape($x, shape), {x: $x}, grad, 'Reshape', attrs);
}

/**
Expand Down Expand Up @@ -422,9 +423,11 @@ function cast_<T extends Tensor>(x: T|TensorLike, dtype: DataType): T {
}

const grad = (dy: T) => {
return {$x: () => dy.clone()};
return {x: () => dy.clone()};
};
return ENGINE.runKernelFunc(backend => backend.cast($x, dtype), {$x}, grad);
const attrs = {dtype};
return ENGINE.runKernelFunc(
backend => backend.cast($x, dtype), {x: $x}, grad, 'Cast', attrs);
}

/**
Expand Down

0 comments on commit ad51fbd

Please sign in to comment.