From ad51fbd476ffae082f910a0e56a805eb85b590ee Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Tue, 22 Oct 2019 14:48:39 -0400 Subject: [PATCH] [Wasm] Modularize `Reshape` and `Cast` (#2259) FEATURE Modularize `Reshape` and `Cast`. Also add support for string tensors. --- tfjs-backend-wasm/WORKSPACE | 1 + tfjs-backend-wasm/src/backend_wasm.ts | 57 +++++++++----------- tfjs-backend-wasm/src/kernels/Cast.ts | 48 +++++++++++++++++ tfjs-backend-wasm/src/kernels/Reshape.ts | 41 ++++++++++++++ tfjs-backend-wasm/src/kernels/all_kernels.ts | 2 + tfjs-backend-wasm/src/setup_test.ts | 7 ++- tfjs-core/src/ops/array_ops.ts | 11 ++-- 7 files changed, 128 insertions(+), 39 deletions(-) create mode 100644 tfjs-backend-wasm/src/kernels/Cast.ts create mode 100644 tfjs-backend-wasm/src/kernels/Reshape.ts diff --git a/tfjs-backend-wasm/WORKSPACE b/tfjs-backend-wasm/WORKSPACE index 97c68a77333..81a2601f4ab 100644 --- a/tfjs-backend-wasm/WORKSPACE +++ b/tfjs-backend-wasm/WORKSPACE @@ -10,6 +10,7 @@ git_repository( name = "xnnpack", commit = "f6839e1355032ee6c78833e1693876d4cdb01436", remote = "https://github.com/google/XNNPACK.git", + shallow_since = "1571339795 -0700", ) # The libraries below are transitive dependencies of XNNPACK that we need to diff --git a/tfjs-backend-wasm/src/backend_wasm.ts b/tfjs-backend-wasm/src/backend_wasm.ts index a254fdd6799..212f077a756 100644 --- a/tfjs-backend-wasm/src/backend_wasm.ts +++ b/tfjs-backend-wasm/src/backend_wasm.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {backend_util, DataStorage, DataType, engine, KernelBackend, Rank, registerBackend, ShapeMap, Tensor, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {backend_util, DataStorage, DataType, engine, KernelBackend, registerBackend, TensorInfo, util} from '@tensorflow/tfjs-core'; import wasmFactory from '../wasm-out/tfjs-backend-wasm'; import {BackendWasmModule} from '../wasm-out/tfjs-backend-wasm'; @@ -27,6 +27,8 @@ interface TensorData { memoryOffset: number; shape: number[]; dtype: DataType; + /** Only used for string tensors, storing encoded bytes. */ + stringBytes?: Uint8Array[]; } export type DataId = object; // object instead of {} to force non-primitive. @@ -41,7 +43,7 @@ export class BackendWasm extends KernelBackend { this.dataIdMap = new DataStorage(this, engine()); } - write(values: backend_util.TypedArray, shape: number[], dtype: DataType): + write(values: backend_util.BackendValues, shape: number[], dtype: DataType): DataId { const dataId = {}; this.move(dataId, values, shape, dtype); @@ -53,17 +55,25 @@ export class BackendWasm extends KernelBackend { } move( - dataId: DataId, values: backend_util.TypedArray, shape: number[], + dataId: DataId, values: backend_util.BackendValues, shape: number[], dtype: DataType): void { + const id = this.dataIdNextNumber++; + if (dtype === 'string') { + const stringBytes = values as Uint8Array[]; + this.dataIdMap.set( + dataId, {id, stringBytes, shape, dtype, memoryOffset: null}); + return; + } const numBytes = util.sizeFromShape(shape) * util.bytesPerElement(dtype); const memoryOffset = this.wasm._malloc(numBytes); - const id = this.dataIdNextNumber++; this.dataIdMap.set(dataId, {id, memoryOffset, shape, dtype}); const shapeBytes = new Uint8Array(new Int32Array(shape).buffer); this.wasm.tfjs.registerTensor( id, shapeBytes, shape.length, dtypeToEnumValue(dtype), memoryOffset); if (values != null) { - this.wasm.HEAPU8.set(new Uint8Array(values.buffer), memoryOffset); + this.wasm.HEAPU8.set( + new Uint8Array((values as backend_util.TypedArray).buffer), + memoryOffset); } } @@ -72,7 +82,11 @@ export class BackendWasm extends KernelBackend { } readSync(dataId: DataId): backend_util.BackendValues { - const {memoryOffset, dtype, shape} = this.dataIdMap.get(dataId); + const {memoryOffset, dtype, shape, stringBytes} = + this.dataIdMap.get(dataId); + if (dtype === 'string') { + return stringBytes; + } const bytes = this.wasm.HEAPU8.slice( memoryOffset, memoryOffset + util.sizeFromShape(shape) * util.bytesPerElement(dtype)); @@ -101,25 +115,12 @@ export class BackendWasm extends KernelBackend { this.wasm = null; } - // Kernels. - - reshape(x: T, newShape: ShapeMap[R]): - Tensor { - return engine().makeTensorFromDataId(x.dataId, newShape, x.dtype) as - Tensor; - } - - cast(x: T, dtype: DataType): T { - const out = this.makeOutTensor(x.shape, dtype); - const {memoryOffset: inOffset} = this.dataIdMap.get(x.dataId); - const {memoryOffset: outOffset} = this.dataIdMap.get(out.dataId); - const inVals = this.typedArrayFromHeap(inOffset, x.dtype, x.size); - const outVals = this.typedArrayFromHeap(outOffset, dtype, out.size); - outVals.set(inVals); - return out as T; + makeOutput(shape: number[], dtype: DataType): TensorInfo { + const dataId = this.write(null /* values */, shape, dtype); + return {dataId, shape, dtype}; } - private typedArrayFromHeap(offset: number, dtype: DataType, size: number): + typedArrayFromHeap(offset: number, dtype: DataType, size: number): backend_util.TypedArray { const buffer = this.wasm.HEAPU8.buffer; switch (dtype) { @@ -133,16 +134,6 @@ export class BackendWasm extends KernelBackend { throw new Error(`Uknown dtype ${dtype}`); } } - - makeOutput(shape: number[], dtype: DataType): TensorInfo { - const dataId = this.write(null /* values */, shape, dtype); - return {dataId, shape, dtype}; - } - - private makeOutTensor(shape: number[], dtype: DataType): Tensor { - const dataId = this.write(null /* values */, shape, dtype); - return engine().makeTensorFromDataId(dataId, shape, dtype, this); - } } registerBackend('wasm', async () => { diff --git a/tfjs-backend-wasm/src/kernels/Cast.ts b/tfjs-backend-wasm/src/kernels/Cast.ts new file mode 100644 index 00000000000..798375ff37c --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Cast.ts @@ -0,0 +1,48 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {DataType, NamedAttrMap, NamedTensorInfoMap, registerKernel, util} from '@tensorflow/tfjs-core'; +import {TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +interface CastInputs extends NamedTensorInfoMap { + x: TensorInfo; +} + +interface CastAttrs extends NamedAttrMap { + dtype: DataType; +} + +function cast( + args: {inputs: CastInputs, attrs: CastAttrs, backend: BackendWasm}) { + const {inputs: {x}, attrs: {dtype}, backend} = args; + const out = backend.makeOutput(x.shape, dtype); + const {memoryOffset: inOffset} = backend.dataIdMap.get(x.dataId); + const {memoryOffset: outOffset} = backend.dataIdMap.get(out.dataId); + const size = util.sizeFromShape(x.shape); + const inVals = backend.typedArrayFromHeap(inOffset, x.dtype, size); + const outVals = backend.typedArrayFromHeap(outOffset, dtype, size); + outVals.set(inVals); + return out; +} + +registerKernel({ + kernelName: 'Cast', + backendName: 'wasm', + kernelFunc: cast, +}); diff --git a/tfjs-backend-wasm/src/kernels/Reshape.ts b/tfjs-backend-wasm/src/kernels/Reshape.ts new file mode 100644 index 00000000000..f1009c6ab38 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Reshape.ts @@ -0,0 +1,41 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel} from '@tensorflow/tfjs-core'; +import {TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +interface ReshapeInputs extends NamedTensorInfoMap { + x: TensorInfo; +} + +interface ReshapeAttrs extends NamedAttrMap { + shape: number[]; +} + +function reshape( + args: {inputs: ReshapeInputs, attrs: ReshapeAttrs, backend: BackendWasm}) { + const {inputs: {x}, attrs: {shape}} = args; + return {dataId: x.dataId, shape, dtype: x.dtype}; +} + +registerKernel({ + kernelName: 'Reshape', + backendName: 'wasm', + kernelFunc: reshape, +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 3cfca45d5eb..76a174c1501 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -20,4 +20,6 @@ // the contents of this file and import only the kernels that are needed. import './Add'; import './BatchMatMul'; +import './Cast'; import './Prelu'; +import './Reshape'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 5e8279b65b7..7a691c7911d 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -25,7 +25,7 @@ const env = jasmine.getEnv(); const grepFilter = env.specFilter; /** Tests that have these substrings in their name will be included. */ -const INCLUDE_LIST: string[] = ['add ', 'matmul ', 'prelu ']; +const INCLUDE_LIST: string[] = ['add ', 'matmul ', 'prelu ', ' cast']; /** Tests that have these substrings in their name will be excluded. */ const EXCLUDE_LIST: string[] = [ 'complex', // Complex numbers not yet implemented. @@ -42,7 +42,10 @@ const EXCLUDE_LIST: string[] = [ 'matmul followed by mul', // mul not supported yet // prelu - 'prelu test-wasm undefined derivative', + 'prelu test-wasm undefined derivative', // Missing gradient. + + // cast + 'shallow slice an input that was cast', // Slice is not implemented. ]; /** diff --git a/tfjs-core/src/ops/array_ops.ts b/tfjs-core/src/ops/array_ops.ts index 804e991b170..326cbf3fcb4 100644 --- a/tfjs-core/src/ops/array_ops.ts +++ b/tfjs-core/src/ops/array_ops.ts @@ -373,10 +373,11 @@ function reshape_( () => 'new shape and old shape must have the same number of elements.'); const grad = (dy: Tensor) => { - return {$x: () => dy.reshape($x.shape)}; + return {x: () => dy.reshape($x.shape)}; }; + const attrs = {shape}; return ENGINE.runKernelFunc( - backend => backend.reshape($x, shape), {$x}, grad); + backend => backend.reshape($x, shape), {x: $x}, grad, 'Reshape', attrs); } /** @@ -422,9 +423,11 @@ function cast_(x: T|TensorLike, dtype: DataType): T { } const grad = (dy: T) => { - return {$x: () => dy.clone()}; + return {x: () => dy.clone()}; }; - return ENGINE.runKernelFunc(backend => backend.cast($x, dtype), {$x}, grad); + const attrs = {dtype}; + return ENGINE.runKernelFunc( + backend => backend.cast($x, dtype), {x: $x}, grad, 'Cast', attrs); } /**