Skip to content

Commit

Permalink
[js/web] WebAssembly profiling (#8932)
Browse files Browse the repository at this point in the history
* add p50 in test

* Preallocate WebAssembly worker threads to minimize worker creation time

* WebAssembly profiling

* merge master

* merge with proxy changes

* disable profiling tests from WebAssembly build

* fix e2e test failure

Co-authored-by: Yulong Wang <[email protected]>
  • Loading branch information
hanbitmyths and fs-eire authored Sep 8, 2021
1 parent 0193490 commit 4505243
Show file tree
Hide file tree
Showing 13 changed files with 108 additions and 11 deletions.
2 changes: 2 additions & 0 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ export interface OrtWasmModule extends EmscriptenModule {
_OrtCreateRunOptions(logSeverityLevel: number, logVerbosityLevel: number, terminate: boolean, tag: number): number;
_OrtAddRunConfigEntry(runOptionsHandle: number, configKey: number, configValue: number): number;
_OrtReleaseRunOptions(runOptionsHandle: number): void;

_OrtEndProfiling(sessionHandle: number): number;
//#endregion

//#region config
Expand Down
8 changes: 7 additions & 1 deletion js/web/lib/wasm/proxy-messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,10 @@ interface MessageRun extends MessageError {
out?: SerializableTensor[];
}

export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSession|MessageReleaseSession|MessageRun;
interface MesssageEndProfiling extends MessageError {
type: 'end-profiling';
in ?: number;
}

export type OrtWasmMessage =
MessageInitWasm|MessageInitOrt|MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling;
11 changes: 10 additions & 1 deletion js/web/lib/wasm/proxy-worker/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
/// <reference lib="webworker" />

import {OrtWasmMessage} from '../proxy-messages';
import {createSession, extractTransferableBuffers, initOrt, releaseSession, run} from '../wasm-core-impl';
import {createSession, endProfiling, extractTransferableBuffers, initOrt, releaseSession, run} from '../wasm-core-impl';
import {initializeWebAssembly} from '../wasm-factory';

self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
Expand Down Expand Up @@ -51,6 +51,15 @@ self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
postMessage({type: 'run', err} as OrtWasmMessage);
}
break;
case 'end-profiling':
try {
const handler = ev.data.in!;
endProfiling(handler);
postMessage({type: 'end-profiling'} as OrtWasmMessage);
} catch (err) {
postMessage({type: 'end-profiling', err} as OrtWasmMessage);
}
break;
default:
}
};
21 changes: 21 additions & 0 deletions js/web/lib/wasm/proxy-wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ let initOrtCallbacks: PromiseCallbacks;
const createSessionCallbacks: Array<PromiseCallbacks<SerializableSessionMetadata>> = [];
const releaseSessionCallbacks: Array<PromiseCallbacks<void>> = [];
const runCallbacks: Array<PromiseCallbacks<SerializableTensor[]>> = [];
const endProfilingCallbacks: Array<PromiseCallbacks<void>> = [];

const ensureWorker = (): void => {
if (initializing || !initialized || aborted || !proxyWorker) {
Expand Down Expand Up @@ -67,6 +68,13 @@ const onProxyWorkerMessage = (ev: MessageEvent<OrtWasmMessage>): void => {
runCallbacks.shift()![0](ev.data.out!);
}
break;
case 'end-profiling':
if (ev.data.err) {
endProfilingCallbacks.shift()![1](ev.data.err);
} else {
endProfilingCallbacks.shift()![0]();
}
break;
default:
}
};
Expand Down Expand Up @@ -163,3 +171,16 @@ export const run = async(
return core.run(sessionId, inputIndices, inputs, outputIndices, options);
}
};

export const endProfiling = async(sessionId: number): Promise<void> => {
if (isProxy()) {
ensureWorker();
return new Promise<void>((resolve, reject) => {
endProfilingCallbacks.push([resolve, reject]);
const message: OrtWasmMessage = {type: 'end-profiling', in : sessionId};
proxyWorker!.postMessage(message);
});
} else {
core.endProfiling(sessionId);
}
};
4 changes: 2 additions & 2 deletions js/web/lib/wasm/session-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import {env, InferenceSession, SessionHandler, Tensor} from 'onnxruntime-common';

import {createSession, initOrt, releaseSession, run} from './proxy-wrapper';
import {createSession, endProfiling, initOrt, releaseSession, run} from './proxy-wrapper';

let ortInit: boolean;

Expand Down Expand Up @@ -85,6 +85,6 @@ export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler {
}

endProfiling(): void {
// TODO: implement profiling
void endProfiling(this.sessionId);
}
}
7 changes: 4 additions & 3 deletions js/web/lib/wasm/session-options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,13 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n
throw new Error(`log verbosity level is not valid: ${options.logVerbosityLevel}`);
}

// TODO: Support profiling
sessionOptions.enableProfiling = false;
if (options?.enableProfiling === undefined) {
sessionOptions.enableProfiling = false;
}

sessionOptionsHandle = wasm._OrtCreateSessionOptions(
graphOptimizationLevel, !!sessionOptions.enableCpuMemArena!, !!sessionOptions.enableMemPattern!, executionMode,
sessionOptions.enableProfiling, 0, logIdDataOffset, sessionOptions.logSeverityLevel!,
!!sessionOptions.enableProfiling!, 0, logIdDataOffset, sessionOptions.logSeverityLevel!,
sessionOptions.logVerbosityLevel!);
if (sessionOptionsHandle === 0) {
throw new Error('Can\'t create session options');
Expand Down
19 changes: 19 additions & 0 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,25 @@ export const run =
}
};

/**
* end profiling
*/
export const endProfiling = (sessionId: number): void => {
const wasm = getInstance();
const session = activeSessions[sessionId];
if (!session) {
throw new Error('invalid session id');
}
const sessionHandle = session[0];

// profile file name is not used yet, but it must be freed.
const profileFileName = wasm._OrtEndProfiling(sessionHandle);
if (profileFileName === 0) {
throw new Error('Can\'t get an profile file name');
}
wasm._OrtFree(profileFileName);
};

export const extractTransferableBuffers = (tensors: readonly SerializableTensor[]): ArrayBufferLike[] => {
const buffers: ArrayBufferLike[] = [];
for (const tensor of tensors) {
Expand Down
2 changes: 1 addition & 1 deletion js/web/test/test-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async function initializeSession(
preloadModelData ? ` [preloaded(${preloadModelData.byteLength})]` : ''}`);

const profilerConfig = profile ? {maxNumberEvents: 65536} : undefined;
const sessionConfig = {executionProviders: [backendHint], profiler: profilerConfig};
const sessionConfig = {executionProviders: [backendHint], profiler: profilerConfig, enableProfiling: profile};
let session: ort.InferenceSession;

try {
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/common/profiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ void Profiler::StartProfiling(const logging::Logger* custom_logger) {
template <typename T>
void Profiler::StartProfiling(const std::basic_string<T>& file_name) {
enabled_ = true;
#if !defined(__wasm__)
profile_stream_.open(file_name, std::ios::out | std::ios::trunc);
#endif
profile_stream_file_ = ToMBString(file_name);
profiling_start_time_ = StartTime();
}
Expand Down Expand Up @@ -129,7 +131,9 @@ std::string Profiler::EndProfiling() {
}
}
profile_stream_ << "]\n";
#if !defined(__wasm__)
profile_stream_.close();
#endif
enabled_ = false; // will not collect profile after writing.
return profile_stream_file_;
}
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/core/common/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,16 @@ class Profiler {
// Mutex controlling access to profiler data
OrtMutex mutex_;
bool enabled_{false};
#if defined(__wasm__)
/*
* The simplest way to emit profiling data in WebAssembly is to print out to console,
* since browsers can't access to a file system directly.
* TODO: Consider MEMFS or IndexedDB instead of console.
*/
std::ostream& profile_stream_{std::cout};
#else
std::ofstream profile_stream_;
#endif
std::string profile_stream_file_;
const logging::Logger* session_logger_{nullptr};
const logging::Logger* custom_logger_{nullptr};
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/test/framework/inference_session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,8 @@ TEST(InferenceSessionTests, CheckRunLogger) {
#endif
}

// WebAssembly will emit profiling data into console
#if !defined(__wasm__)
TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions) {
SessionOptions so;

Expand Down Expand Up @@ -686,6 +688,7 @@ TEST(InferenceSessionTests, CheckRunProfilerWithStartProfile) {
count++;
}
}
#endif // __wasm__

TEST(InferenceSessionTests, CheckRunProfilerStartTime) {
// Test whether the InferenceSession can access the profiler's start time
Expand Down
19 changes: 17 additions & 2 deletions onnxruntime/wasm/api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ OrtSessionOptions* OrtCreateSessionOptions(size_t graph_optimization_level,
bool enable_cpu_mem_arena,
bool enable_mem_pattern,
size_t execution_mode,
bool /* enable_profiling */,
const char* /* profile_file_prefix */,
bool enable_profiling,
const char* /*profile_file_prefix*/,
const char* log_id,
size_t log_severity_level,
size_t log_verbosity_level) {
Expand Down Expand Up @@ -93,6 +93,11 @@ OrtSessionOptions* OrtCreateSessionOptions(size_t graph_optimization_level,
RETURN_NULLPTR_IF_ERROR(SetSessionExecutionMode, session_options, static_cast<ExecutionMode>(execution_mode));

// TODO: support profling
if (enable_profiling) {
RETURN_NULLPTR_IF_ERROR(EnableProfiling, session_options, "");
} else {
RETURN_NULLPTR_IF_ERROR(DisableProfiling, session_options);
}

if (log_id != nullptr) {
RETURN_NULLPTR_IF_ERROR(SetSessionLogId, session_options, log_id);
Expand Down Expand Up @@ -345,3 +350,13 @@ int OrtRun(OrtSession* session,
OrtRunOptions* run_options) {
return CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs);
}

char* OrtEndProfiling(ort_session_handle_t session) {
OrtAllocator* allocator = nullptr;
RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator);

char* file_name = nullptr;
return (CHECK_STATUS(SessionEndProfiling, session, allocator, &file_name) == ORT_OK)
? file_name
: nullptr;
}
10 changes: 9 additions & 1 deletion onnxruntime/wasm/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ int EMSCRIPTEN_KEEPALIVE OrtInit(int num_threads, int logging_level);
* @param enable_cpu_mem_arena enable or disable cpu memory arena
* @param enable_mem_pattern enable or disable memory pattern
* @param execution_mode sequential or parallel execution mode
* @param enable_profiling enable or disable profiling. it's a no-op and for a future use.
* @param enable_profiling enable or disable profiling.
* @param profile_file_prefix file prefix for profiling data. it's a no-op and for a future use.
* @param log_id logger id for session output
* @param log_severity_level verbose, info, warning, error or fatal
Expand Down Expand Up @@ -185,4 +185,12 @@ int EMSCRIPTEN_KEEPALIVE OrtRun(ort_session_handle_t session,
size_t output_count,
ort_tensor_handle_t* outputs,
ort_run_options_handle_t run_options);

/**
* end profiling.
* @param session handle of the specified session
* @returns a pointer to a buffer which contains C-style string of profile filename.
* Caller must release the C style string after use by calling OrtFree().
*/
char* EMSCRIPTEN_KEEPALIVE OrtEndProfiling(ort_session_handle_t session);
};

0 comments on commit 4505243

Please sign in to comment.