-
Notifications
You must be signed in to change notification settings - Fork 3k
/
Copy pathtensor-impl.ts
182 lines (168 loc) · 6.66 KB
/
tensor-impl.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {Tensor as TensorInterface} from './tensor';
// Local shorthands for the `Type` and `DataType` members of the imported Tensor namespace.
type TensorType = TensorInterface.Type;
type TensorDataType = TensorInterface.DataType;

// Union of every TypedArray constructor that may appear in the runtime lookup maps of this module.
type SupportedTypedArrayConstructors =
    | Float32ArrayConstructor
    | Float64ArrayConstructor
    | Int8ArrayConstructor
    | Uint8ArrayConstructor
    | Int16ArrayConstructor
    | Uint16ArrayConstructor
    | Int32ArrayConstructor
    | Uint32ArrayConstructor
    | BigInt64ArrayConstructor
    | BigUint64ArrayConstructor;

// Instance side of the constructors above, i.e. the concrete typed arrays themselves.
type SupportedTypedArray = InstanceType<SupportedTypedArrayConstructors>;
// Feature detection: a 64-bit typed array counts as usable only when the global exists and exposes a static
// `from()` function. The `typeof` guard comes first so that a missing global never raises a ReferenceError.
const isBigInt64ArrayAvailable =
    typeof BigInt64Array !== 'undefined' ? typeof BigInt64Array.from === 'function' : false;
const isBigUint64ArrayAvailable =
    typeof BigUint64Array !== 'undefined' ? typeof BigUint64Array.from === 'function' : false;
// Runtime lookup from a numeric tensor type name to the TypedArray constructor used to hold its data.
// Should match Tensor.DataTypeMap.
// Note that 'bool' is backed by Uint8Array, so it shares a constructor with 'uint8'.
const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP =
    new Map<string, SupportedTypedArrayConstructors>()
        .set('float32', Float32Array)
        .set('uint8', Uint8Array)
        .set('int8', Int8Array)
        .set('uint16', Uint16Array)
        .set('int16', Int16Array)
        .set('int32', Int32Array)
        .set('bool', Uint8Array)
        .set('float64', Float64Array)
        .set('uint32', Uint32Array);
// Runtime lookup in the reverse direction: from a TypedArray constructor to the tensor type name inferred for it.
// Should match Tensor.DataTypeMap.
// Uint8Array resolves to 'uint8'; no entry here yields 'bool'.
const NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP =
    new Map<SupportedTypedArrayConstructors, TensorType>()
        .set(Float32Array, 'float32')
        .set(Uint8Array, 'uint8')
        .set(Int8Array, 'int8')
        .set(Uint16Array, 'uint16')
        .set(Int16Array, 'int16')
        .set(Int32Array, 'int32')
        .set(Float64Array, 'float64')
        .set(Uint32Array, 'uint32');
// The 64-bit integer typed arrays are registered in both lookup maps only when the feature checks succeeded,
// so the BigInt64Array / BigUint64Array globals are never touched on a runtime that lacks them.
if (isBigInt64ArrayAvailable) {
  // signed 64-bit integers
  NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigInt64Array, 'int64');
  NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array);
}
if (isBigUint64ArrayAvailable) {
  // unsigned 64-bit integers
  NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64');
  NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array);
}
/**
 * Compute the total element count described by a dims array, validating every entry on the way.
 *
 * An empty dims array yields 1.
 *
 * @param dims the dims array. It is not trusted: entries may be of any type.
 * @returns the product of all entries.
 * @throws {TypeError} if an entry is not a number or is not a safe integer.
 * @throws {RangeError} if an entry is a negative integer.
 */
const calculateSize = (dims: readonly unknown[]): number => {
  let product = 1;
  // entries() also visits holes of a sparse array (as undefined), so they are rejected like any other non-number.
  for (const [axis, extent] of dims.entries()) {
    if (!(typeof extent === 'number' && Number.isSafeInteger(extent))) {
      throw new TypeError(`dims[${axis}] must be an integer, got: ${extent}`);
    }
    if (extent < 0) {
      throw new RangeError(`dims[${axis}] must be a non-negative integer, got: ${extent}`);
    }
    product *= extent;
  }
  return product;
};
/**
 * Concrete implementation of the Tensor interface.
 *
 * An instance bundles a type name, a flat data container and a dims array. The constructor validates that the
 * product of the dims equals the length of the data.
 */
export class Tensor implements TensorInterface {
  //#region constructors
  /**
   * Create a tensor from an explicit type name.
   *
   * @param type the tensor type name. `'string'` requires `data` to be an array; any other name must be present in
   *     the numeric type map ('int64' / 'uint64' are only present when the runtime provides the matching BigInt
   *     typed array).
   * @param data either an instance of the typed array registered for `type`, or a plain array that is copied into
   *     such a typed array. For 'int64' / 'uint64' the elements of a plain array are converted with `BigInt()`.
   * @param dims the shape. Defaults to a 1-D shape of `[data.length]`.
   * @throws {TypeError} if the type is unsupported, the data does not fit the type, or dims is not a valid array of
   *     integers.
   * @throws {RangeError} if a dim is negative.
   * @throws {Error} if the product of dims differs from the data length.
   */
  constructor(type: TensorType, data: TensorDataType|readonly number[]|readonly boolean[], dims?: readonly number[]);
  /**
   * Create a tensor whose type is inferred from the data.
   *
   * @param data a typed array (type looked up from its constructor), or a non-empty plain array whose first element
   *     is a string (type 'string') or a boolean (type 'bool', copied into a Uint8Array).
   * @param dims the shape. Defaults to a 1-D shape of `[data.length]`.
   * @throws {TypeError} if the type cannot be inferred or dims is not a valid array of integers.
   * @throws {RangeError} if a dim is negative.
   * @throws {Error} if the product of dims differs from the data length.
   */
  constructor(data: TensorDataType|readonly boolean[], dims?: readonly number[]);
  constructor(
      arg0: TensorType|TensorDataType|readonly boolean[], arg1?: TensorDataType|readonly number[]|readonly boolean[],
      arg2?: readonly number[]) {
    let type: TensorType;
    let data: TensorDataType;
    let dims: typeof arg1|typeof arg2;
    // check whether arg0 is type or data
    if (typeof arg0 === 'string') {
      //
      // Override: constructor(type, data, ...)
      //
      type = arg0;
      dims = arg2;
      if (arg0 === 'string') {
        // string tensor
        if (!Array.isArray(arg1)) {
          throw new TypeError('A string tensor\'s data must be a string array.');
        }
        // we don't check whether every element in the array is string; this is too slow. we assume it's correct and
        // error will be populated at inference
        data = arg1;
      } else {
        // numeric tensor
        const typedArrayConstructor = NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.get(arg0);
        if (typedArrayConstructor === undefined) {
          throw new TypeError(`Unsupported tensor type: ${arg0}.`);
        }
        if (Array.isArray(arg1)) {
          // use 'as any' here because TypeScript's check on type of 'SupportedTypedArrayConstructors.from()' produces
          // incorrect results.
          // 'typedArrayConstructor' should be one of the typed array prototype objects.
          // eslint-disable-next-line @typescript-eslint/no-explicit-any
          const typedArrayFactory = typedArrayConstructor as any;
          // BigInt64Array / BigUint64Array reject plain numbers, so elements of 64-bit tensors are converted with
          // BigInt() while copying. This makes the 'readonly number[]' accepted by the overload usable for 'int64'
          // and 'uint64'; elements that already are bigint pass through unchanged.
          data = (arg0 === 'int64' || arg0 === 'uint64') ? typedArrayFactory.from(arg1, BigInt) :
                                                           typedArrayFactory.from(arg1);
        } else if (arg1 instanceof typedArrayConstructor) {
          data = arg1;
        } else {
          throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`);
        }
      }
    } else {
      //
      // Override: constructor(data, ...)
      //
      dims = arg1;
      if (Array.isArray(arg0)) {
        // only boolean[] and string[] is supported
        if (arg0.length === 0) {
          throw new TypeError('Tensor type cannot be inferred from an empty array.');
        }
        // only the first element is inspected; the remaining elements are assumed to have the same type
        const firstElementType = typeof arg0[0];
        if (firstElementType === 'string') {
          type = 'string';
          data = arg0;
        } else if (firstElementType === 'boolean') {
          type = 'bool';
          // 'arg0' is of type 'boolean[]'. Uint8Array.from(boolean[]) actually works, but typescript thinks this is
          // wrong type. We use 'as any' to make it happy.
          // eslint-disable-next-line @typescript-eslint/no-explicit-any
          data = Uint8Array.from(arg0 as any[]);
        } else {
          throw new TypeError(`Invalid element type of data array: ${firstElementType}.`);
        }
      } else {
        // get tensor type from TypedArray
        const mappedType =
            NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.get(arg0.constructor as SupportedTypedArrayConstructors);
        if (mappedType === undefined) {
          throw new TypeError(`Unsupported type for tensor data: ${arg0.constructor}.`);
        }
        type = mappedType;
        data = arg0 as SupportedTypedArray;
      }
    }

    // type and data is processed, now processing dims
    if (dims === undefined) {
      // assume 1-D tensor if dims omitted
      dims = [data.length];
    } else if (!Array.isArray(dims)) {
      throw new TypeError('A tensor\'s dims must be a number array');
    }

    // perform check: calculateSize() validates every dim and returns their product
    const size = calculateSize(dims);
    if (size !== data.length) {
      throw new Error(`Tensor's size(${size}) does not match data length(${data.length}).`);
    }

    this.dims = dims as readonly number[];
    this.type = type;
    this.data = data;
    this.size = size;
  }
  //#endregion

  //#region fields
  /** The shape of the tensor, as passed in or defaulted to `[data.length]`. */
  readonly dims: readonly number[];
  /** The tensor type name, either given explicitly or inferred from the data. */
  readonly type: TensorType;
  /** The flat data container: a typed array for numeric tensors, a plain array for string tensors. */
  readonly data: TensorDataType;
  /** The number of elements, i.e. the product of `dims`. */
  readonly size: number;
  //#endregion

  //#region tensor utilities
  /**
   * Create a new tensor of the same type with the given dims, reusing this tensor's data container.
   *
   * @param dims the new shape; its product must equal the current data length.
   * @returns a new Tensor instance constructed from `this.type`, `this.data` and `dims`.
   * @throws whatever the constructor throws for invalid dims or a size mismatch.
   */
  reshape(dims: readonly number[]): Tensor {
    return new Tensor(this.type, this.data, dims);
  }
  //#endregion
}