forked from tensorflow/tfjs-core
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathenvironment.ts
471 lines (409 loc) · 14.8 KB
/
environment.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as device_util from './device_util';
import {doc} from './doc';
import {Engine, MemoryInfo} from './engine';
import {KernelBackend} from './kernels/backend';
import * as util from './util';
/**
 * Value type of a feature flag. Used by URL_PROPERTIES / getFeaturesFromURL
 * to decide how the raw string from the URL is converted.
 */
export enum Type {
  NUMBER = 0,
  BOOLEAN = 1,
  STRING = 2
}
/**
 * Feature flags understood by the Environment. Every entry is optional:
 * a missing flag is evaluated lazily the first time it is read.
 */
export interface Features {
  /** Whether to enable debug mode. */
  'DEBUG'?: boolean;
  /** Whether we are in a browser (as opposed to, e.g., node.js). */
  'IS_BROWSER'?: boolean;
  /**
   * Version of the disjoint_query_timer extension in use:
   * 0 = disabled, 1 = EXT_disjoint_timer_query,
   * 2 = EXT_disjoint_timer_query_webgl2.
   * In Firefox with WebGL 2.0, EXT_disjoint_timer_query_webgl2 is not
   * available, so the WebGL 1.0 extension must be used instead.
   */
  'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'?: number;
  /**
   * Whether the timer object from the disjoint_query_timer extension gives
   * reliable timing information.
   */
  'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE'?: boolean;
  /** 0 = no WebGL, 1 = WebGL 1.0, 2 = WebGL 2.0. */
  'WEBGL_VERSION'?: number;
  /**
   * Whether writing and reading floating point textures is enabled. When
   * false, unsigned byte textures are used as a fallback.
   */
  'WEBGL_FLOAT_TEXTURE_ENABLED'?: boolean;
  /** Whether WEBGL_get_buffer_sub_data_async is enabled. */
  'WEBGL_GET_BUFFER_SUB_DATA_ASYNC_EXTENSION_ENABLED'?: boolean;
  /** Name of the registered backend to use. */
  'BACKEND'?: string;
}
/**
 * Feature flags that may be overridden through the `tfjsflags` URL query
 * parameter, together with the type each raw value is parsed as.
 */
export const URL_PROPERTIES: URLProperty[] = [
  {name: 'DEBUG', type: Type.BOOLEAN},
  {name: 'IS_BROWSER', type: Type.BOOLEAN},
  {name: 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', type: Type.NUMBER},
  {name: 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', type: Type.BOOLEAN},
  {name: 'WEBGL_VERSION', type: Type.NUMBER},
  {name: 'WEBGL_FLOAT_TEXTURE_ENABLED', type: Type.BOOLEAN},
  {
    name: 'WEBGL_GET_BUFFER_SUB_DATA_ASYNC_EXTENSION_ENABLED',
    type: Type.BOOLEAN,
  },
  {name: 'BACKEND', type: Type.STRING},
];
/** A feature flag that can be overridden from the page URL. */
export interface URLProperty {
  /** Feature name, as written in the `tfjsflags` query parameter. */
  name: keyof Features;
  /** How the raw string value taken from the URL is converted. */
  type: Type;
}
/**
 * Returns true when `gl.getExtension(extensionName)` yields a non-null,
 * non-undefined extension object.
 */
function hasExtension(gl: WebGLRenderingContext, extensionName: string) {
  return gl.getExtension(extensionName) != null;
}
/**
 * Creates a context on a fresh, detached canvas.
 *
 * @param webGLVersion 1 requests 'webgl' (falling back to
 *     'experimental-webgl'); any other non-zero value requests 'webgl2'.
 * @returns The context returned by `canvas.getContext`, which may be null.
 * @throws Error if `webGLVersion` is 0 (WebGL disabled).
 */
function getWebGLRenderingContext(webGLVersion: number): WebGLRenderingContext {
  if (webGLVersion === 0) {
    throw new Error('Cannot get WebGL rendering context, WebGL is disabled.');
  }

  const canvas = document.createElement('canvas');
  if (webGLVersion !== 1) {
    return canvas.getContext('webgl2') as WebGLRenderingContext;
  }
  const webgl1Context =
      canvas.getContext('webgl') || canvas.getContext('experimental-webgl');
  return webgl1Context as WebGLRenderingContext;
}
/**
 * Releases `gl` through the WEBGL_lose_context extension. Does nothing when
 * `gl` is null/undefined.
 *
 * @throws Error if the context does not expose WEBGL_lose_context.
 */
function loseContext(gl: WebGLRenderingContext) {
  if (gl == null) {
    return;
  }
  const extension = gl.getExtension('WEBGL_lose_context');
  if (extension == null) {
    throw new Error(
        'Extension WEBGL_lose_context not supported on this browser.');
  }
  extension.loseContext();
}
/**
 * Returns true if a context of the requested WebGL version can be created.
 * The probe context is released immediately afterwards.
 */
function isWebGLVersionEnabled(webGLVersion: 1|2) {
  const probe = getWebGLRenderingContext(webGLVersion);
  if (probe == null) {
    return false;
  }
  loseContext(probe);
  return true;
}
/**
 * Detects which disjoint-timer-query extension is available for the given
 * WebGL version.
 *
 * @param webGLVersion 0 (WebGL disabled), 1 or 2.
 * @returns 2 if EXT_disjoint_timer_query_webgl2 is available on a WebGL 2
 *     context, 1 if EXT_disjoint_timer_query is available, otherwise 0.
 *     Also returns 0 when WebGL is disabled or no context could be created.
 */
function getWebGLDisjointQueryTimerVersion(webGLVersion: number): number {
  if (webGLVersion === 0) {
    return 0;
  }

  const gl = getWebGLRenderingContext(webGLVersion);
  // getContext() may return null; hasExtension() would dereference it.
  if (gl == null) {
    return 0;
  }

  let queryTimerVersion: number;
  // Check the version first so the WebGL 2 extension is only probed on a
  // WebGL 2 context.
  if (webGLVersion === 2 &&
      hasExtension(gl, 'EXT_disjoint_timer_query_webgl2')) {
    queryTimerVersion = 2;
  } else if (hasExtension(gl, 'EXT_disjoint_timer_query')) {
    queryTimerVersion = 1;
  } else {
    queryTimerVersion = 0;
  }

  loseContext(gl);
  return queryTimerVersion;
}
/**
 * Probes whether a float RGBA texture can be attached to a framebuffer and
 * read back with readPixels(..., gl.FLOAT, ...) for the given WebGL version.
 *
 * The probe requires OES_texture_float on WebGL 1 and EXT_color_buffer_float
 * otherwise. It binds a 1x1 float texture to a framebuffer, checks
 * framebuffer completeness, and reads one pixel back.
 *
 * @param webGLVersion 0 (WebGL disabled), 1 or 2.
 * @returns true only if the framebuffer is complete and readPixels raised no
 *     GL error. Returns false if WebGL is disabled, no context could be
 *     created, or the required extension is missing.
 */
function isFloatTextureReadPixelsEnabled(webGLVersion: number): boolean {
  if (webGLVersion === 0) {
    return false;
  }

  const gl = getWebGLRenderingContext(webGLVersion);
  if (gl == null) {
    return false;
  }

  // Release the temporary context on every exit path, including the early
  // return when the required extension is missing, so no probe context is
  // left alive.
  try {
    const requiredExtension =
        webGLVersion === 1 ? 'OES_texture_float' : 'EXT_color_buffer_float';
    if (!hasExtension(gl, requiredExtension)) {
      return false;
    }

    const frameBuffer = gl.createFramebuffer();
    const texture = gl.createTexture();
    gl.bindTexture(gl.TEXTURE_2D, texture);
    // tslint:disable-next-line:no-any
    const internalFormat = webGLVersion === 2 ? (gl as any).RGBA32F : gl.RGBA;
    gl.texImage2D(
        gl.TEXTURE_2D, 0, internalFormat, 1, 1, 0, gl.RGBA, gl.FLOAT, null);

    gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
    gl.framebufferTexture2D(
        gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
    const frameBufferComplete =
        gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;

    gl.readPixels(0, 0, 1, 1, gl.RGBA, gl.FLOAT, new Float32Array(4));
    const readPixelsNoError = gl.getError() === gl.NO_ERROR;

    return frameBufferComplete && readPixelsNoError;
  } finally {
    loseContext(gl);
  }
}
/**
 * Reports whether the WEBGL_get_buffer_sub_data_async extension may be used.
 *
 * This currently always returns false. Every positive version is rejected by
 * the temporary guard, and every other value fails the WebGL 2 check, so the
 * extension probe at the bottom is never reached. The probe is kept so it can
 * be re-enabled by removing the guard.
 */
function isWebGLGetBufferSubDataAsyncExtensionEnabled(webGLVersion: number) {
  // TODO(nsthorat): Remove this once we fix
  // https://github.com/tensorflow/tfjs/issues/137
  const disabledPendingFix = webGLVersion > 0;
  // Only WebGL 2 contexts are probed for the extension.
  if (disabledPendingFix || webGLVersion !== 2) {
    return false;
  }

  const gl = getWebGLRenderingContext(webGLVersion);
  const enabled = hasExtension(gl, 'WEBGL_get_buffer_sub_data_async');
  loseContext(gl);
  return enabled;
}
/**
 * Holds the feature flags (set explicitly or evaluated lazily), the registry
 * of kernel backends, and the Engine created for the active backend.
 */
export class Environment {
  // Feature values that were set explicitly or already evaluated; `get()`
  // fills in missing entries on demand.
  private features: Features = {};
  // Created by `initBackend()`, either from `setBackend()` or lazily from
  // `initDefaultBackend()`.
  private globalEngine: Engine;
  // Registered backends keyed by name, with the priority consulted by
  // `getBestBackendType()`.
  private registry:
      {[id: string]: {backend: KernelBackend, priority: number}} = {};
  // Name passed to the most recent `initBackend()` call.
  private currentBackend: string;

  /**
   * @param features Optional initial feature values. The object is used
   *     directly (not copied). Features it does not contain are evaluated
   *     lazily the first time they are read.
   */
  constructor(features?: Features) {
    if (features != null) {
      this.features = features;
    }

    if (this.get('DEBUG')) {
      console.warn(
          'Debugging mode is ON. The output of every math call will ' +
          'be downloaded to CPU and checked for NaNs. ' +
          'This significantly impacts performance.');
    }
  }

  /**
   * Sets the backend (cpu, webgl, etc) responsible for creating tensors and
   * executing operations on those tensors.
   *
   * Note this disposes the current backend, if any, as well as any tensors
   * associated with it. A new backend is initialized, even if it is of the
   * same type as the previous one.
   *
   * @param backendType The backend type. Currently supports `'webgl'|'cpu'` in
   *     the browser, and `'tensorflow'` under node.js (requires tfjs-node).
   * @param safeMode Defaults to false. In safe mode, you are forced to
   *     construct tensors and call math operations inside a `tidy()` which
   *     will automatically clean up intermediate tensors.
   */
  @doc({heading: 'Environment'})
  static setBackend(backendType: string, safeMode = false) {
    if (!ENV.hasBackend(backendType)) {
      throw new Error(`Backend type '${backendType}' not found in registry`);
    }
    ENV.initBackend(backendType, safeMode);
  }

  /**
   * Returns the current backend (cpu, webgl, etc). The backend is responsible
   * for creating tensors and executing operations on those tensors.
   */
  @doc({heading: 'Environment'})
  static getBackend(): string {
    ENV.initDefaultBackend();
    return ENV.currentBackend;
  }

  /**
   * Dispose all variables kept in backend engine.
   */
  @doc({heading: 'Environment'})
  static disposeVariables(): void {
    ENV.engine.disposeVariables();
  }

  /**
   * Returns memory info at the current time in the program. The result is an
   * object with the following properties:
   *
   * - `numBytes`: Number of bytes allocated (undisposed) at this time.
   * - `numTensors`: Number of unique tensors allocated.
   * - `numDataBuffers`: Number of unique data buffers allocated
   *   (undisposed) at this time, which is ≤ the number of tensors
   *   (e.g. `a.reshape(newShape)` makes a new Tensor that shares the same
   *   data buffer with `a`).
   * - `unreliable`: `Optional` `boolean`:
   *    - On WebGL, not present (always reliable).
   *    - On CPU, true. Due to automatic garbage collection, these numbers
   *     represent undisposed tensors, i.e. not wrapped in `tidy()`, or
   *     lacking a call to `tensor.dispose()`.
   */
  @doc({heading: 'Performance', subheading: 'Memory'})
  static memory(): MemoryInfo {
    return ENV.engine.memory();
  }

  /**
   * Returns the value of `feature`. On first access it is computed by
   * `evaluateFeature()` and cached.
   */
  get<K extends keyof Features>(feature: K): Features[K] {
    if (feature in this.features) {
      return this.features[feature];
    }

    this.features[feature] = this.evaluateFeature(feature);

    return this.features[feature];
  }

  /** Overrides the value of a single feature. */
  set<K extends keyof Features>(feature: K, value: Features[K]): void {
    this.features[feature] = value;
  }

  /**
   * Returns the name of the registered backend with the highest priority.
   *
   * @throws Error if no backend has been registered.
   */
  getBestBackendType(): string {
    if (Object.keys(this.registry).length === 0) {
      throw new Error('No backend found in registry.');
    }
    const sortedBackends = Object.keys(this.registry)
                               .map(name => {
                                 return {name, entry: this.registry[name]};
                               })
                               .sort((a, b) => {
                                 // Highest priority comes first.
                                 return b.entry.priority - a.entry.priority;
                               });
    return sortedBackends[0].name;
  }

  /**
   * Computes the default value of `feature`. `get()` calls this for features
   * that are not yet set.
   *
   * @throws Error for a feature name with no evaluation rule.
   */
  private evaluateFeature<K extends keyof Features>(feature: K): Features[K] {
    if (feature === 'DEBUG') {
      return false;
    } else if (feature === 'IS_BROWSER') {
      return typeof window !== 'undefined';
    } else if (feature === 'BACKEND') {
      return this.getBestBackendType();
    } else if (feature === 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') {
      const webGLVersion = this.get('WEBGL_VERSION');

      if (webGLVersion === 0) {
        return 0;
      }
      return getWebGLDisjointQueryTimerVersion(webGLVersion);
    } else if (feature === 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') {
      return this.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 &&
          !device_util.isMobile();
    } else if (feature === 'WEBGL_VERSION') {
      if (isWebGLVersionEnabled(2)) {
        return 2;
      } else if (isWebGLVersionEnabled(1)) {
        return 1;
      }
      return 0;
    } else if (feature === 'WEBGL_FLOAT_TEXTURE_ENABLED') {
      return isFloatTextureReadPixelsEnabled(this.get('WEBGL_VERSION'));
    } else if (
        feature === 'WEBGL_GET_BUFFER_SUB_DATA_ASYNC_EXTENSION_ENABLED') {
      return isWebGLGetBufferSubDataAsyncExtensionEnabled(
          this.get('WEBGL_VERSION'));
    }
    throw new Error(`Unknown feature ${feature}.`);
  }

  /** Replaces all feature values with `features` (used directly, not copied). */
  setFeatures(features: Features) {
    this.features = features;
  }

  /**
   * Resets features to the overrides found in the page URL and drops the
   * engine, so the next `engine` access creates a new one. Registered
   * backends are kept.
   */
  reset() {
    this.features = getFeaturesFromURL();
    this.globalEngine = null;
  }

  /**
   * Creates a new Engine for the backend registered under `backendType` and
   * records that name as the current backend.
   */
  private initBackend(backendType?: string, safeMode = false) {
    this.currentBackend = backendType;
    // NOTE(review): the backend is looked up in the global ENV registry rather
    // than `this.registry`. These are the same only for the global instance.
    const backend = ENV.findBackend(backendType);
    // NOTE(review): the previous engine is replaced without any visible
    // dispose call. Confirm whether Engine/backends release resources
    // elsewhere, since setBackend's docs promise disposal.
    this.globalEngine = new Engine(backend, safeMode);
  }

  /**
   * Own-property membership test for the backend registry. Names inherited
   * from Object.prototype (e.g. 'toString', 'constructor') are never
   * treated as registered backends.
   */
  private hasBackend(name: string): boolean {
    return Object.prototype.hasOwnProperty.call(this.registry, name);
  }

  /**
   * Returns the backend instance registered under `name`, or null if no
   * backend has been registered with that name.
   */
  findBackend(name: string): KernelBackend {
    if (!this.hasBackend(name)) {
      return null;
    }
    return this.registry[name].backend;
  }

  /**
   * Registers a global backend. The registration should happen when importing
   * a module file (e.g. when importing `backend_webgl.ts`), and is used for
   * modular builds (e.g. custom tfjs bundle with only webgl support).
   *
   * @param name The name the backend is registered under. Registering an
   *     existing name logs a warning and replaces the previous entry.
   * @param factory The backend factory function. When called, it should
   *     return an instance of the backend.
   * @param priority The priority of the backend (higher = more important).
   *     In case multiple backends are registered, `getBestBackendType` uses
   *     priority to find the best backend. Defaults to 1.
   * @return False if the creation/registration failed (the factory threw).
   *     True otherwise.
   */
  registerBackend(name: string, factory: () => KernelBackend, priority = 1):
      boolean {
    if (this.hasBackend(name)) {
      console.warn(`${name} backend was already registered`);
    }
    try {
      const backend = factory();
      this.registry[name] = {backend, priority};
      return true;
    } catch (err) {
      console.warn(err.message);
      return false;
    }
  }

  /**
   * Disposes the backend registered under `name` and removes it from the
   * registry.
   *
   * @throws Error if no backend is registered under `name`.
   */
  removeBackend(name: string): void {
    if (!this.hasBackend(name)) {
      throw new Error(`${name} backend not found in registry`);
    }
    this.registry[name].backend.dispose();
    delete this.registry[name];
  }

  /** The engine for the current backend. It is created on first access. */
  get engine(): Engine {
    this.initDefaultBackend();
    return this.globalEngine;
  }

  /**
   * Creates the engine from the 'BACKEND' feature if none exists yet. By
   * default that feature is the highest-priority registered backend.
   */
  private initDefaultBackend() {
    if (this.globalEngine == null) {
      this.initBackend(ENV.get('BACKEND'), false /* safeMode */);
    }
  }
}
// Name of the query parameter holding feature overrides. Expected format:
// ?tfjsflags=FLAG1:1,FLAG2:true (comma-separated NAME:VALUE pairs).
const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';

/**
 * Reads feature overrides from the `tfjsflags` query parameter of the page
 * URL.
 *
 * Only names listed in URL_PROPERTIES are applied. Each value is converted
 * according to that entry's Type: NUMBER via unary `+`, BOOLEAN by comparing
 * to 'true', and STRING as-is. Returns an empty object when `window` or
 * `window.location` is unavailable, or when the parameter is absent.
 */
function getFeaturesFromURL(): Features {
  const features: Features = {};

  if (typeof window === 'undefined' || typeof window.location === 'undefined') {
    return features;
  }

  const urlParams = util.getQueryParams(window.location.search);
  if (!(TENSORFLOWJS_FLAGS_PREFIX in urlParams)) {
    return features;
  }

  // Split "A:x,B:y" into {A: 'x', B: 'y'}.
  const urlFlags: {[key: string]: string} = {};
  for (const pair of urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',')) {
    const [flagName, flagValue] = pair.split(':') as [string, string];
    urlFlags[flagName] = flagValue;
  }

  for (const urlProperty of URL_PROPERTIES) {
    const name = urlProperty.name;
    if (!(name in urlFlags)) {
      continue;
    }
    const rawValue = urlFlags[name];
    console.log(
        `Setting feature override from URL ${name}: ` +
        `${rawValue}`);
    switch (urlProperty.type) {
      case Type.NUMBER:
        features[name] = +rawValue;
        break;
      case Type.BOOLEAN:
        features[name] = rawValue === 'true';
        break;
      case Type.STRING:
        // tslint:disable-next-line:no-any
        features[name] = rawValue as any;
        break;
      default:
        console.warn(`Unknown URL param: ${name}.`);
    }
  }

  return features;
}
/**
 * Returns the host's global object: `window` when it exists, otherwise
 * node's `global`.
 *
 * @throws Error if neither global object is defined.
 */
function getGlobalNamespace(): {ENV: Environment} {
  if (typeof (window) !== 'undefined') {
    // tslint:disable-next-line:no-any
    return window as any;
  }
  if (typeof (global) !== 'undefined') {
    // tslint:disable-next-line:no-any
    return global as any;
  }
  throw new Error('Could not find a global object');
}
/**
 * Returns the Environment stored as `ENV` on the global object. If none is
 * there yet, a new one is created, seeded with the URL feature overrides,
 * and stored on the global object.
 */
function getOrMakeEnvironment(): Environment {
  const globalNamespace = getGlobalNamespace();
  const existing = globalNamespace.ENV;
  globalNamespace.ENV = existing || new Environment(getFeaturesFromURL());
  return globalNamespace.ENV;
}

/** The Environment shared through the global object. */
export let ENV = getOrMakeEnvironment();