Skip to content

Commit

Permalink
[Fix&Opt] Bind the group lock operations to a CUDA stream.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Aug 24, 2023
1 parent 86d9305 commit 9799a9f
Show file tree
Hide file tree
Showing 7 changed files with 550 additions and 257 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ target_compile_features(dynamic_max_capacity_test PUBLIC cxx_std_14)
set_target_properties(dynamic_max_capacity_test PROPERTIES CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(dynamic_max_capacity_test gtest_main)

add_executable(group_lock_test tests/group_lock_test.cc)
add_executable(group_lock_test tests/group_lock_test.cc.cu)
target_compile_features(group_lock_test PUBLIC cxx_std_14)
set_target_properties(group_lock_test PROPERTIES CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(group_lock_test gtest_main)
Expand Down
2 changes: 1 addition & 1 deletion include/merlin/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ cuda_library(
"core_kernels.cuh",
"debug.hpp",
"flexible_buffer.cuh",
"group_lock.hpp",
"group_lock.cuh",
"initializers.cuh",
"memory_pool.cuh",
"optimizers.cuh",
Expand Down
125 changes: 125 additions & 0 deletions include/merlin/core_kernels/group_lock_kernels.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
 *     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <cuda/std/semaphore>

namespace nv {
namespace merlin {
namespace group_lock {

/**
 * Construct the three pieces of lock state in raw device memory.
 *
 * Placement-new is required because the backing storage comes from a raw
 * device allocation, so the atomics must be formally constructed before
 * any other kernel touches them. Counters start at zero (no readers, no
 * writers) and the uniqueness flag starts clear.
 *
 * NOTE(review): presumably launched with a single thread (<<<1, 1>>>);
 * multiple threads would re-run the constructors — confirm against the
 * host-side launcher.
 */
__global__ void init_kernel(
    cuda::atomic<int, cuda::thread_scope_device>* writer_count,
    cuda::atomic<int, cuda::thread_scope_device>* reader_count,
    cuda::atomic<bool, cuda::thread_scope_device>* unique_flag) {
  using count_atomic = cuda::atomic<int, cuda::thread_scope_device>;
  using flag_atomic = cuda::atomic<bool, cuda::thread_scope_device>;
  new (unique_flag) flag_atomic{false};
  new (reader_count) count_atomic{0};
  new (writer_count) count_atomic{0};
}
/**
 * Acquire the group lock in shared (reader) mode.
 *
 * Optimistic spin acquire: wait until no writer is registered, register
 * this reader, then re-check; if a writer slipped in between the check and
 * the registration, back out and retry from the top.
 *
 * NOTE(review): all accesses use memory_order_relaxed — cross-kernel
 * visibility presumably relies on kernel-boundary/stream synchronization;
 * confirm before reusing outside stream-ordered launches. Presumably
 * launched <<<1, 1>>>; more threads would each take a reader slot — TODO
 * confirm against the host-side wrapper.
 */
__global__ void lock_read_kernel(
    cuda::atomic<int, cuda::thread_scope_device>* writer_count,
    cuda::atomic<int, cuda::thread_scope_device>* reader_count) {
  for (;;) {
    // Spin until there is no active or pending writer.
    while (writer_count->load(cuda::std::memory_order_relaxed)) {
    }
    // Optimistically announce this reader.
    reader_count->fetch_add(1, cuda::std::memory_order_relaxed);
    // Re-check: success only if no writer registered concurrently.
    if (writer_count->load(cuda::std::memory_order_relaxed) == 0) {
      break;
    }
    // A writer raced in — withdraw the reader registration and retry.
    reader_count->fetch_sub(1, cuda::std::memory_order_relaxed);
  }
}

/**
 * Release one shared (reader) hold previously taken by lock_read_kernel:
 * simply withdraws this reader's registration from the shared counter.
 */
__global__ void unlock_read_kernel(
    cuda::atomic<int, cuda::thread_scope_device>* reader_count) {
  constexpr auto relaxed = cuda::std::memory_order_relaxed;
  reader_count->fetch_sub(1, relaxed);
}

/**
 * Acquire the group lock in exclusive (writer) mode.
 *
 * Mirror image of lock_read_kernel: wait until no reader is registered,
 * register this writer, then re-check; if a reader slipped in between the
 * check and the registration, back out and retry from the top.
 *
 * NOTE(review): relaxed memory ordering throughout — ordering presumably
 * comes from kernel-boundary/stream synchronization; confirm before reuse.
 * Presumably launched <<<1, 1>>> — TODO confirm against the host-side
 * wrapper.
 */
__global__ void lock_write_kernel(
    cuda::atomic<int, cuda::thread_scope_device>* writer_count,
    cuda::atomic<int, cuda::thread_scope_device>* reader_count) {
  for (;;) {
    // Spin until there is no active reader.
    while (reader_count->load(cuda::std::memory_order_relaxed)) {
    }
    // Optimistically announce this writer.
    writer_count->fetch_add(1, cuda::std::memory_order_relaxed);
    // Re-check: success only if no reader registered concurrently.
    if (reader_count->load(cuda::std::memory_order_relaxed) == 0) {
      break;
    }
    // A reader raced in — withdraw the writer registration and retry.
    writer_count->fetch_sub(1, cuda::std::memory_order_relaxed);
  }
}

/**
 * Release one exclusive (writer) hold previously taken by
 * lock_write_kernel: withdraws this writer's registration.
 */
__global__ void unlock_write_kernel(
    cuda::atomic<int, cuda::thread_scope_device>* writer_count) {
  constexpr auto relaxed = cuda::std::memory_order_relaxed;
  writer_count->fetch_sub(1, relaxed);
}

/**
 * Acquire the group lock in combined write-read mode, in three phases:
 *
 *   1. Win the uniqueness flag via CAS so at most one write-read holder
 *      exists at a time.
 *   2. Register as a reader (same optimistic protocol as
 *      lock_read_kernel) — this bans new writers.
 *   3. Wait until this thread is the sole remaining reader
 *      (reader_count == 1), then register as a writer — this bans other
 *      readers.
 *
 * On success the holder owns one slot in BOTH counters plus the flag;
 * unlock_write_read_kernel releases all three.
 *
 * NOTE(review): relaxed ordering throughout — presumably relies on
 * kernel-boundary/stream synchronization; confirm before reuse.
 * Presumably launched <<<1, 1>>>; the reader_count == 1 test assumes
 * exactly one acquiring thread — TODO confirm launch configuration.
 */
__global__ void lock_write_read_kernel(
    cuda::atomic<int, cuda::thread_scope_device>* writer_count,
    cuda::atomic<int, cuda::thread_scope_device>* reader_count,
    cuda::atomic<bool, cuda::thread_scope_device>* unique_flag) {
  /* Lock unique flag */
  bool expected = false;
  while (!unique_flag->compare_exchange_weak(expected, true,
                                             cuda::std::memory_order_relaxed)) {
    // compare_exchange_weak overwrites `expected` on failure — reset it.
    expected = false;
  }

  /* Ban writer */
  for (;;) {
    while (writer_count->load(cuda::std::memory_order_relaxed)) {
    }
    reader_count->fetch_add(1, cuda::std::memory_order_relaxed);
    if (writer_count->load(cuda::std::memory_order_relaxed) == 0) {
      break;
    }
    reader_count->fetch_sub(1, cuda::std::memory_order_relaxed);
  }

  /* Ban reader */
  for (;;) {
    // Wait until this thread's own registration is the only reader left.
    while (reader_count->load(cuda::std::memory_order_relaxed) > 1) {
    }
    writer_count->fetch_add(1, cuda::std::memory_order_relaxed);
    if (reader_count->load(cuda::std::memory_order_relaxed) == 1) {
      break;
    }
    writer_count->fetch_sub(1, cuda::std::memory_order_relaxed);
  }
}

/**
 * Release a combined write-read hold taken by lock_write_read_kernel:
 * withdraws the reader and writer registrations, then clears the
 * uniqueness flag so the next write-read acquirer can proceed.
 */
__global__ void unlock_write_read_kernel(
    cuda::atomic<int, cuda::thread_scope_device>* writer_count,
    cuda::atomic<int, cuda::thread_scope_device>* reader_count,
    cuda::atomic<bool, cuda::thread_scope_device>* unique_flag) {
  constexpr auto relaxed = cuda::std::memory_order_relaxed;
  reader_count->fetch_sub(1, relaxed);
  writer_count->fetch_sub(1, relaxed);
  unique_flag->store(false, relaxed);
}

/**
 * Snapshot the current writer registration count into plain device memory
 * (`counter`), presumably so the host can copy it back — confirm against
 * the host-side wrapper.
 */
__global__ void writer_count_kernel(
    int* counter, cuda::atomic<int, cuda::thread_scope_device>* writer_count) {
  const int snapshot =
      writer_count->load(cuda::std::memory_order_relaxed);
  *counter = snapshot;
}

/**
 * Snapshot the current reader registration count into plain device memory
 * (`counter`), presumably so the host can copy it back — confirm against
 * the host-side wrapper.
 */
__global__ void reader_count_kernel(
    int* counter, cuda::atomic<int, cuda::thread_scope_device>* reader_count) {
  const int snapshot =
      reader_count->load(cuda::std::memory_order_relaxed);
  *counter = snapshot;
}

} // namespace group_lock
} // namespace merlin
} // namespace nv
Loading

0 comments on commit 9799a9f

Please sign in to comment.