// Source: profiler_test.ts, forked from tensorflow/tfjs-core.
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from './index';
import {BackendTimer, BackendTimingInfo} from './kernels/backend';
import {TypedArray} from './kernels/webgl/tex_util';
import {Logger, Profiler} from './profiler';
import {Tensor} from './tensor';
/**
 * Fake BackendTimer for tests.
 *
 * `time(query)` runs `query` synchronously, waits `delayMs`, then resolves
 * with `kernelMs = queryTimeMs * n`. Here `n` counts, starting at 1, how many
 * queries on this instance have finished. A `time()` call made from inside
 * another call's `query` therefore finishes first and gets the smaller `n`.
 */
class TestBackendTimer implements BackendTimer {
  private counter = 1;

  constructor(
      private readonly delayMs: number, private readonly queryTimeMs: number) {}

  async time(query: () => void): Promise<BackendTimingInfo> {
    query();
    // Take the reading as soon as the query finishes, so numbering follows
    // query completion order rather than timer scheduling.
    const kernelMs = this.queryTimeMs * this.counter++;
    // setTimeout must receive a callback. Calling resolve(...) inline would
    // resolve immediately and pass `undefined` as the callback, so no delay
    // would be simulated.
    await new Promise<void>(resolve => setTimeout(() => resolve(), this.delayMs));
    return {kernelMs};
  }
}
/**
 * Logger stub for tests: `logKernelProfile` deliberately does nothing.
 * Specs wrap it in a spy (with callThrough) to inspect the arguments it gets.
 */
class TestLogger extends Logger {
  logKernelProfile(
      name: string, result: Tensor, vals: TypedArray, timeMs: number): void {
    // Intentionally empty: only the recorded call arguments matter.
  }
}
describe('profiler.Profiler', () => {
  /**
   * Builds a Profiler backed by a TestBackendTimer and a TestLogger, with
   * call-through spies on `timer.time` and `logger.logKernelProfile`.
   * It uses `spyOn`, so call it from inside a spec.
   */
  function createSpiedProfiler(delayMs: number, queryTimeMs: number) {
    const timer = new TestBackendTimer(delayMs, queryTimeMs);
    const logger = new TestLogger();
    const profiler = new Profiler(timer, logger);
    spyOn(timer, 'time').and.callThrough();
    spyOn(logger, 'logKernelProfile').and.callThrough();
    return {
      profiler,
      timeSpy: timer.time as jasmine.Spy,
      logKernelProfileSpy: logger.logKernelProfile as jasmine.Spy
    };
  }

  it('profiles simple function', doneFn => {
    const delayMs = 5;
    const queryTimeMs = 10;
    const {profiler, timeSpy, logKernelProfileSpy} =
        createSpiedProfiler(delayMs, queryTimeMs);

    let kernelCalled = false;
    const result = 1;
    const resultScalar = tf.scalar(result);

    profiler.profileKernel('MatMul', () => {
      kernelCalled = true;
      return resultScalar;
    });

    // The timer reports asynchronously, so assertions are deferred past its
    // delay.
    setTimeout(() => {
      expect(timeSpy.calls.count()).toBe(1);
      expect(logKernelProfileSpy.calls.count()).toBe(1);
      expect(logKernelProfileSpy.calls.first().args).toEqual([
        'MatMul', resultScalar, new Float32Array([result]), queryTimeMs
      ]);
      expect(kernelCalled).toBe(true);
      doneFn();
    }, delayMs * 2);
  });

  it('profiles nested kernel', doneFn => {
    const delayMs = 5;
    const queryTimeMs = 10;
    const {profiler, timeSpy, logKernelProfileSpy} =
        createSpiedProfiler(delayMs, queryTimeMs);

    let matmulKernelCalled = false;
    let maxKernelCalled = false;
    const result = 1;
    const resultScalar = tf.scalar(result);

    profiler.profileKernel('MatMul', () => {
      // Named distinctly so it does not shadow the numeric `result` above.
      const maxResult = profiler.profileKernel('Max', () => {
        maxKernelCalled = true;
        return resultScalar;
      });
      matmulKernelCalled = true;
      return maxResult;
    });

    setTimeout(() => {
      expect(timeSpy.calls.count()).toBe(2);
      expect(logKernelProfileSpy.calls.count()).toBe(2);
      // The inner 'Max' query finishes first, so it gets the first timer
      // reading (1x); the outer 'MatMul' gets the second (2x).
      expect(logKernelProfileSpy.calls.first().args).toEqual([
        'Max', resultScalar, new Float32Array([result]), queryTimeMs
      ]);
      expect(logKernelProfileSpy.calls.argsFor(1)).toEqual([
        'MatMul', resultScalar, new Float32Array([result]), queryTimeMs * 2
      ]);
      expect(matmulKernelCalled).toBe(true);
      expect(maxKernelCalled).toBe(true);
      doneFn();
    }, delayMs * 2);
  });
});