Skip to content

Commit

Permalink
[WebGPU] Use a per-context staging buffer
Browse files Browse the repository at this point in the history
This fixes the generator_aot_gpu_multi_context_threaded tests.
  • Loading branch information
jrprice committed Mar 1, 2023
1 parent 26ef6d7 commit 2e88fe2
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 54 deletions.
86 changes: 45 additions & 41 deletions src/runtime/webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@ WEAK WGPUDevice global_device = nullptr;
// Lock to synchronize access to the global WebGPU context.
volatile ScopedSpinLock::AtomicFlag WEAK context_lock = 0;

// Size of the staging buffer used for host<->device copies.
constexpr int kWebGpuStagingBufferSize = 4 * 1024 * 1024;
// A staging buffer used for host<->device copies.
WGPUBuffer WEAK staging_buffer = nullptr;
WEAK WGPUBuffer global_staging_buffer = nullptr;

// A flag to signify that the WebGPU device was lost.
bool device_was_lost = false;
Expand Down Expand Up @@ -72,6 +70,7 @@ WEAK int halide_webgpu_acquire_context(void *user_context,
WGPUInstance *instance_ret,
WGPUAdapter *adapter_ret,
WGPUDevice *device_ret,
WGPUBuffer *staging_buffer_ret,
bool create = true) {
halide_abort_if_false(user_context, &context_lock != nullptr);
while (__atomic_test_and_set(&context_lock, __ATOMIC_ACQUIRE)) {
Expand All @@ -92,6 +91,7 @@ WEAK int halide_webgpu_acquire_context(void *user_context,
*instance_ret = global_instance;
*adapter_ret = global_adapter;
*device_ret = global_device;
*staging_buffer_ret = global_staging_buffer;

return halide_error_code_success;
}
Expand All @@ -117,12 +117,16 @@ class WgpuContext {
WGPUAdapter adapter = nullptr;
WGPUDevice device = nullptr;
WGPUQueue queue = nullptr;

// A staging buffer used for host<->device copies.
WGPUBuffer staging_buffer = nullptr;

int error_code = 0;

ALWAYS_INLINE WgpuContext(void *user_context)
: user_context(user_context) {
error_code = halide_webgpu_acquire_context(
user_context, &instance, &adapter, &device);
user_context, &instance, &adapter, &device, &staging_buffer);
if (error_code == halide_error_code_success) {
queue = wgpuDeviceGetQueue(device);
}
Expand Down Expand Up @@ -157,7 +161,7 @@ class ErrorScope {

// Wait for all error callbacks in this scope to fire.
// Returns the error code (or success).
int wait() {
halide_error_code_t wait() {
if (callbacks_remaining == 0) {
error(user_context) << "no outstanding error scopes\n";
return halide_error_code_internal_error;
Expand All @@ -180,7 +184,7 @@ class ErrorScope {
WGPUDevice device;

// The error code reported by the callback functions.
volatile int error_code;
volatile halide_error_code_t error_code;

// Used to track outstanding error callbacks.
volatile int callbacks_remaining = 0;
Expand Down Expand Up @@ -249,6 +253,24 @@ void request_device_callback(WGPURequestDeviceStatus status,
device_was_lost = false;
wgpuDeviceSetDeviceLostCallback(device, device_lost_callback, user_context);
global_device = device;

// Create a staging buffer for transfers.
constexpr int kStagingBufferSize = 4 * 1024 * 1024;
WGPUBufferDescriptor desc{};
desc.nextInChain = nullptr;
desc.label = nullptr;
desc.usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead;
desc.size = kStagingBufferSize;
desc.mappedAtCreation = false;

ErrorScope error_scope(user_context, global_device);
global_staging_buffer = wgpuDeviceCreateBuffer(device, &desc);

halide_error_code_t error_code = error_scope.wait();
if (error_code != halide_error_code_success) {
global_staging_buffer = nullptr;
init_error_code = error_code;
}
}

void request_adapter_callback(WGPURequestAdapterStatus status,
Expand Down Expand Up @@ -375,26 +397,6 @@ WEAK int halide_webgpu_device_malloc(void *user_context, halide_buffer_t *buf) {
return error_code;
}

if (staging_buffer == nullptr) {
ErrorScope error_scope(user_context, context.device);

// Create a staging buffer for transfers if we haven't already.
WGPUBufferDescriptor desc{};
desc.nextInChain = nullptr;
desc.label = nullptr;
desc.usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead;
desc.size = kWebGpuStagingBufferSize;
desc.mappedAtCreation = false;

staging_buffer = wgpuDeviceCreateBuffer(context.device, &desc);

int error_code = error_scope.wait();
if (error_code != halide_error_code_success) {
staging_buffer = nullptr;
return error_code;
}
}

buf->device = (uint64_t)device_handle;
buf->device_interface = &webgpu_device_interface;
buf->device_interface->impl->use_module();
Expand Down Expand Up @@ -482,8 +484,9 @@ WEAK int halide_webgpu_device_release(void *user_context) {
WGPUInstance instance;
WGPUAdapter adapter;
WGPUDevice device;
WGPUBuffer staging_buffer;
err = halide_webgpu_acquire_context(user_context,
&instance, &adapter, &device, false);
&instance, &adapter, &device, &staging_buffer, false);
if (err != halide_error_code_success) {
return err;
}
Expand All @@ -492,13 +495,13 @@ WEAK int halide_webgpu_device_release(void *user_context) {
shader_cache.delete_context(user_context, device,
wgpuShaderModuleRelease);

if (staging_buffer) {
wgpuBufferRelease(staging_buffer);
staging_buffer = nullptr;
}

// Release the device/adapter/instance, if we created them.
// Release the device/adapter/instance/staging_buffer, if we created them.
if (device == global_device) {
if (staging_buffer) {
wgpuBufferRelease(staging_buffer);
global_staging_buffer = nullptr;
}

wgpuDeviceSetDeviceLostCallback(device, nullptr, nullptr);
wgpuDeviceRelease(device);
global_device = nullptr;
Expand Down Expand Up @@ -538,9 +541,9 @@ namespace {
int do_copy_to_host(void *user_context, WgpuContext *context, uint8_t *dst,
WGPUBuffer src, int64_t src_offset, int64_t size) {
// Copy chunks via the staging buffer.
for (int64_t offset = 0; offset < size;
offset += kWebGpuStagingBufferSize) {
int64_t num_bytes = kWebGpuStagingBufferSize;
int64_t staging_buffer_size = wgpuBufferGetSize(context->staging_buffer);
for (int64_t offset = 0; offset < size; offset += staging_buffer_size) {
int64_t num_bytes = staging_buffer_size;
if (offset + num_bytes > size) {
num_bytes = size - offset;
}
Expand All @@ -549,7 +552,8 @@ int do_copy_to_host(void *user_context, WgpuContext *context, uint8_t *dst,
WGPUCommandEncoder encoder =
wgpuDeviceCreateCommandEncoder(context->device, nullptr);
wgpuCommandEncoderCopyBufferToBuffer(encoder, src, src_offset + offset,
staging_buffer, 0, num_bytes);
context->staging_buffer,
0, num_bytes);
WGPUCommandBuffer command_buffer =
wgpuCommandEncoderFinish(encoder, nullptr);
wgpuQueueSubmit(context->queue, 1, &command_buffer);
Expand All @@ -563,7 +567,7 @@ int do_copy_to_host(void *user_context, WgpuContext *context, uint8_t *dst,
// Map the staging buffer for reading.
__atomic_test_and_set(&result.map_complete, __ATOMIC_RELAXED);
wgpuBufferMapAsync(
staging_buffer, WGPUMapMode_Read, 0, num_bytes,
context->staging_buffer, WGPUMapMode_Read, 0, num_bytes,
[](WGPUBufferMapAsyncStatus status, void *userdata) {
BufferMapResult *result = (BufferMapResult *)userdata;
result->map_status = status;
Expand All @@ -581,10 +585,10 @@ int do_copy_to_host(void *user_context, WgpuContext *context, uint8_t *dst,
}

// Copy the data from the mapped staging buffer to the host allocation.
const void *src = wgpuBufferGetConstMappedRange(staging_buffer, 0,
num_bytes);
const void *src = wgpuBufferGetConstMappedRange(context->staging_buffer,
0, num_bytes);
memcpy(dst + offset, src, num_bytes);
wgpuBufferUnmap(staging_buffer);
wgpuBufferUnmap(context->staging_buffer);
}

return halide_error_code_success;
Expand Down
21 changes: 19 additions & 2 deletions test/common/gpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,12 @@ void emscripten_sleep(unsigned int ms);
#endif
}

inline bool create_webgpu_context(WGPUInstance *instance_out, WGPUAdapter *adapter_out, WGPUDevice *device_out) {
inline bool create_webgpu_context(WGPUInstance *instance_out, WGPUAdapter *adapter_out, WGPUDevice *device_out, WGPUBuffer *staging_buffer_out) {
struct Results {
WGPUInstance instance = nullptr;
WGPUAdapter adapter = nullptr;
WGPUDevice device = nullptr;
WGPUBuffer staging_buffer = nullptr;
bool success = true;
} results;

Expand Down Expand Up @@ -246,6 +247,20 @@ inline bool create_webgpu_context(WGPUInstance *instance_out, WGPUAdapter *adapt
abort();
};
wgpuDeviceSetDeviceLostCallback(device, device_lost_callback, userdata);

// Create a staging buffer for transfers.
constexpr int kStagingBufferSize = 4 * 1024 * 1024;
WGPUBufferDescriptor desc{};
desc.nextInChain = nullptr;
desc.label = nullptr;
desc.usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead;
desc.size = kStagingBufferSize;
desc.mappedAtCreation = false;
results->staging_buffer = wgpuDeviceCreateBuffer(device, &desc);
if (results->staging_buffer == nullptr) {
results->success = false;
return;
}
};

wgpuAdapterRequestDevice(adapter, &desc, request_device_callback, userdata);
Expand All @@ -267,11 +282,13 @@ inline bool create_webgpu_context(WGPUInstance *instance_out, WGPUAdapter *adapt
*instance_out = results.instance;
*adapter_out = results.adapter;
*device_out = results.device;
*staging_buffer_out = results.staging_buffer;
return results.success;
}

inline void destroy_webgpu_context(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device) {
inline void destroy_webgpu_context(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, WGPUBuffer staging_buffer) {
wgpuDeviceSetDeviceLostCallback(device, nullptr, nullptr);
wgpuBufferRelease(staging_buffer);
wgpuDeviceRelease(device);
wgpuAdapterRelease(adapter);

Expand Down
8 changes: 6 additions & 2 deletions test/generator/acquire_release_aottest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,27 +110,31 @@ struct gpu_context {
WGPUInstance instance = nullptr;
WGPUAdapter adapter = nullptr;
WGPUDevice device = nullptr;
WGPUBuffer staging_buffer = nullptr;
} webgpu_context;

bool init_context() {
return create_webgpu_context(&webgpu_context.instance, &webgpu_context.adapter, &webgpu_context.device);
return create_webgpu_context(&webgpu_context.instance, &webgpu_context.adapter, &webgpu_context.device, &webgpu_context.staging_buffer);
}

void destroy_context() {
destroy_webgpu_context(webgpu_context.instance, webgpu_context.adapter, webgpu_context.device);
destroy_webgpu_context(webgpu_context.instance, webgpu_context.adapter, webgpu_context.device, webgpu_context.staging_buffer);
webgpu_context.instance = nullptr;
webgpu_context.adapter = nullptr;
webgpu_context.device = nullptr;
webgpu_context.staging_buffer = nullptr;
}

extern "C" int halide_webgpu_acquire_context(void *user_context,
WGPUInstance *instance_ret,
WGPUAdapter *adapter_ret,
WGPUDevice *device_ret,
WGPUBuffer* staging_buffer_ret,
bool create) {
*instance_ret = webgpu_context.instance;
*adapter_ret = webgpu_context.adapter;
*device_ret = webgpu_context.device;
*staging_buffer_ret = webgpu_context.staging_buffer;
return 0;
}

Expand Down
16 changes: 7 additions & 9 deletions test/generator/gpu_multi_context_threaded_aottest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,50 +141,48 @@ struct gpu_context {
WGPUInstance instance = nullptr;
WGPUAdapter adapter = nullptr;
WGPUDevice device = nullptr;
WGPUBuffer staging_buffer = nullptr;
};

bool init_context(gpu_context &ctx) {
return create_webgpu_context(&ctx.instance, &ctx.adapter, &ctx.device);
return create_webgpu_context(&ctx.instance, &ctx.adapter, &ctx.device, &ctx.staging_buffer);
}

void destroy_context(gpu_context &ctx) {
destroy_webgpu_context(ctx.instance, ctx.adapter, ctx.device);
destroy_webgpu_context(ctx.instance, ctx.adapter, ctx.device, ctx.staging_buffer);
ctx.instance = nullptr;
ctx.adapter = nullptr;
ctx.device = nullptr;
ctx.staging_buffer = nullptr;
}

char context_lock = 0;

// These functions replace the acquire/release implementation in src/runtime/webgpu.cpp.
// Since we don't parallelize access to the GPU in the schedule, we don't need synchronization
// in our implementation of these functions.
extern "C" int halide_webgpu_acquire_context(void *user_context,
WGPUInstance *instance_ret,
WGPUAdapter *adapter_ret,
WGPUDevice *device_ret,
WGPUBuffer *staging_buffer_ret,
bool create) {
while (__atomic_test_and_set(&context_lock, __ATOMIC_ACQUIRE)) {
// nothing
}

if (user_context == nullptr) {
assert(!create);
*instance_ret = nullptr;
*adapter_ret = nullptr;
*device_ret = nullptr;
*staging_buffer_ret = nullptr;
return -1;
} else {
gpu_context *context = (gpu_context *)user_context;
*instance_ret = context->instance;
*adapter_ret = context->adapter;
*device_ret = context->device;
*staging_buffer_ret = context->staging_buffer;
}
return 0;
}

extern "C" int halide_webgpu_release_context(void *user_context) {
__atomic_clear(&context_lock, __ATOMIC_RELEASE);
return 0;
}

Expand Down

0 comments on commit 2e88fe2

Please sign in to comment.