Skip to content

Commit

Permalink
[NPU] add support for launching host callback (#777)
Browse files Browse the repository at this point in the history
  • Loading branch information
ronny1996 authored Sep 6, 2023
1 parent 9cd6f47 commit 0a9d29c
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 0 deletions.
4 changes: 4 additions & 0 deletions backends/npu/runtime/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ void SecondaryStream::Create(aclrtStream aicore_stream) {

// Tears down the secondary stream registered in `aicpu_streams` under the key
// `aicore_stream`.
//
// Order of operations:
//   1. check that an entry for `aicore_stream` exists,
//   2. release the host-callback worker attached to the mapped stream,
//   3. destroy the mapped stream through the ACL runtime,
//   4. drop the map entry.
void SecondaryStream::Destroy(aclrtStream aicore_stream) {
// The key must already be registered in the map.
RUN_CHECK(aicpu_streams.find(aicore_stream) != aicpu_streams.cend());
// Release the callback worker before the stream it is keyed by is destroyed.
HostCallbackManager::Instance().ReleaseProcessWorker(
aicpu_streams[aicore_stream]);
ACL_CHECK(aclrtDestroyStream(aicpu_streams[aicore_stream]));
aicpu_streams.erase(aicore_stream);
}
Expand Down Expand Up @@ -246,6 +248,7 @@ C_Status ReleaseDevice(const C_Device device) {
}

C_Status Finalize() {
HostCallbackManager::Instance().ReleaseAllProcessWorkers();
if (global_allocator_list) {
delete global_allocator_list;
global_allocator_list = nullptr;
Expand Down Expand Up @@ -380,6 +383,7 @@ C_Status CreateStream(const C_Device device, C_Stream *stream) {
}

C_Status DestroyStream(const C_Device device, C_Stream stream) {
HostCallbackManager::Instance().ReleaseProcessWorker(stream);
ACL_CHECK(aclrtDestroyStream(reinterpret_cast<aclrtStream>(stream)));
SecondaryStream::Instance().Destroy(reinterpret_cast<aclrtStream>(stream));
return C_SUCCESS;
Expand Down
98 changes: 98 additions & 0 deletions backends/npu/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
#include <hccl/hccl.h>
#include <hccl/hccl_types.h>

#include <atomic>
#include <functional>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>

#include "paddle/phi/extension.h"

#define RUNTIME_CHECK(func, success) \
Expand Down Expand Up @@ -208,3 +211,98 @@ struct SecondaryStream {
SecondaryStream() = default;
std::unordered_map<aclrtStream, aclrtStream> aicpu_streams;
};

struct HostCallbackManager {
static HostCallbackManager &Instance() {
static HostCallbackManager ins;
return ins;
}

void ReleaseProcessWorker(aclrtStream stream) {
std::lock_guard<std::mutex> lock(g_stream_thread_mutex);
if (g_stream_thread_map.find(stream) != g_stream_thread_map.end()) {
std::ostringstream oss;
oss << g_stream_thread_map[stream].get_id();
uint64_t tid = std::stoull(oss.str());
g_stream_thread_is_running_map[stream] = false;
g_stream_thread_map[stream].join();
ACL_CHECK(aclrtUnSubscribeReport(tid, stream));
g_stream_thread_is_running_map.erase(stream);
g_stream_thread_map.erase(stream);
}
}

void ReleaseAllProcessWorkers() {
std::lock_guard<std::mutex> lock(g_stream_thread_mutex);
for (auto it = g_stream_thread_map.begin();
it != g_stream_thread_map.end();) {
std::ostringstream oss;
oss << it->second.get_id();
uint64_t tid = std::stoull(oss.str());
g_stream_thread_is_running_map[it->first] = false;
it->second.join();
ACL_CHECK(aclrtUnSubscribeReport(tid, it->first));
it = g_stream_thread_map.erase(it);
}
g_stream_thread_map.clear();
}

void Launch(
aclrtStream stream,
std::function<void()> callback,
size_t timeout_time = 100 /* ms */,
aclrtCallbackBlockType block_device =
ACL_CALLBACK_BLOCK /* 310 dose not support ACL_CALLBACK_NO_BLOCK */) {
InitProcessWorker(stream, timeout_time);
ACL_CHECK(aclrtLaunchCallback(
HostCallbackManager::CallbackWrapper,
reinterpret_cast<void *>(
new std::function<void()>([callback] { callback(); })),
block_device,
stream));
}

void LaunchNonBlockingDevice(aclrtStream stream,
std::function<void()> callback,
size_t timeout_time = 100) {
Launch(stream, callback, timeout_time, ACL_CALLBACK_NO_BLOCK);
}

private:
static void ProcessCallbackWorker(bool *is_running,
aclrtContext context,
size_t timeout_time) {
ACL_CHECK(aclrtSetCurrentContext(context));
while (*is_running) {
(void)aclrtProcessReport(timeout_time);
}
}

static void CallbackWrapper(void *user_func) {
std::unique_ptr<std::function<void()>> callback(
reinterpret_cast<std::function<void()> *>(user_func));
(*callback)();
}

void InitProcessWorker(aclrtStream stream, size_t timeout_time) {
std::lock_guard<std::mutex> lock(g_stream_thread_mutex);
if (g_stream_thread_map.find(stream) == g_stream_thread_map.end()) {
aclrtContext context;
g_stream_thread_is_running_map[stream] = true;
ACL_CHECK(aclrtGetCurrentContext(&context));
std::thread cb_thread(ProcessCallbackWorker,
&g_stream_thread_is_running_map[stream],
context,
timeout_time);
g_stream_thread_map[stream] = std::move(cb_thread);
std::ostringstream oss;
oss << g_stream_thread_map[stream].get_id();
uint64_t tid = std::stoull(oss.str());
ACL_CHECK(aclrtSubscribeReport(tid, stream));
}
}

std::mutex g_stream_thread_mutex;
std::unordered_map<aclrtStream, std::thread> g_stream_thread_map;
std::unordered_map<aclrtStream, bool> g_stream_thread_is_running_map;
};

0 comments on commit 0a9d29c

Please sign in to comment.