From 1945e0ca63f145b7f252eb0c341fc46bd2a992b3 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Wed, 13 Mar 2024 12:48:14 -0400
Subject: [PATCH] [Fix] PagedKVCache fetching compute stream when copy stream is needed

This PR fixes an issue in PagedKVCache where the compute stream was
always fetched, regardless of the device type. Backends such as WebGPU
do not implement the `GetCurrentStream` function, so fetching the
compute stream on them led to an error. The compute stream is now
fetched only for CUDA/ROCm, the same devices for which the copy stream
is created.
---
 src/runtime/relax_vm/paged_kv_cache.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index fb22d20fcfc7..651fd4964c47 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -439,12 +439,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       free_page_ids_.push_back(page_id);
     }
 
-    // The compute stream is the default stream.
     // If the device is CUDA/ROCm, we create a standalone copy stream, in
     // purpose to hide the latency of auxiliary stream copy.
-    compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);
     if (device.device_type == DLDeviceType::kDLCUDA ||
         device.device_type == DLDeviceType::kDLROCM) {
+      // The compute stream is the default stream.
+      compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);
       copy_stream_ = DeviceAPI::Get(device)->CreateStream(device);
     }
   }