[Fix] more restrictive memory order for EpochLfu
rhdong committed Aug 16, 2023
1 parent 9a038d5 commit 0816cf3
Showing 2 changed files with 325 additions and 4 deletions.
8 changes: 4 additions & 4 deletions include/merlin/core_kernels/kernel_utils.cuh
@@ -273,9 +273,9 @@ struct ScoreFunctor<K, V, S, EvictStrategyInternal::kEpochLru> {
 template <class K, class V, class S>
 struct ScoreFunctor<K, V, S, EvictStrategyInternal::kEpochLfu> {
   static constexpr cuda::std::memory_order LOCK_MEM_ORDER =
-      cuda::std::memory_order_relaxed;
+      cuda::std::memory_order_acquire;
   static constexpr cuda::std::memory_order UNLOCK_MEM_ORDER =
-      cuda::std::memory_order_relaxed;
+      cuda::std::memory_order_release;
 
   __forceinline__ __device__ static S desired_when_missed(
       const S* __restrict const input_scores, const int key_idx,
@@ -331,9 +331,9 @@ struct ScoreFunctor<K, V, S, EvictStrategyInternal::kEpochLfu> {
 template <class K, class V, class S>
 struct ScoreFunctor<K, V, S, EvictStrategyInternal::kCustomized> {
   static constexpr cuda::std::memory_order LOCK_MEM_ORDER =
-      cuda::std::memory_order_relaxed;
+      cuda::std::memory_order_acquire;
   static constexpr cuda::std::memory_order UNLOCK_MEM_ORDER =
-      cuda::std::memory_order_relaxed;
+      cuda::std::memory_order_release;
 
   __forceinline__ __device__ static S desired_when_missed(
       const S* __restrict const input_scores, const int key_idx,
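The switch from relaxed to acquire/release ordering matters because the keys, values, and scores written while a bucket lock is held must become visible to the next thread that takes the same lock. Below is a minimal sketch of that pattern using libcu++ atomics; it is not the library's actual bucket-lock code, and BucketLock, lock_bucket, and unlock_bucket are hypothetical names used only to show where LOCK_MEM_ORDER and UNLOCK_MEM_ORDER would apply:

#include <cuda/atomic>
#include <cuda/std/atomic>

// Hypothetical per-bucket lock flag; the real table keeps its own lock state.
using BucketLock = cuda::atomic<int, cuda::thread_scope_device>;

__device__ void lock_bucket(BucketLock& flag) {
  int expected = 0;
  // Acquire on success: reads of the bucket's keys/scores issued after this
  // point cannot be reordered before the lock is observed as taken, so this
  // thread sees whatever the previous lock holder published.
  while (!flag.compare_exchange_strong(expected, 1,
                                       cuda::std::memory_order_acquire,
                                       cuda::std::memory_order_relaxed)) {
    expected = 0;  // compare_exchange_strong overwrote it with the current value
  }
}

__device__ void unlock_bucket(BucketLock& flag) {
  // Release: every write made while holding the lock becomes visible to the
  // next acquirer. With memory_order_relaxed on both sides there is no such
  // guarantee, so a later reader could see the lock as free yet still read a
  // stale score -- the race this commit closes for kEpochLfu and kCustomized.
  flag.store(0, cuda::std::memory_order_release);
}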
321 changes: 321 additions & 0 deletions tests/insert_and_evict_test.cc.cu
@@ -1644,6 +1644,323 @@ void test_insert_and_evict_run_with_batch_find() {
  CUDA_CHECK(cudaStreamDestroy(find_stream));
}

__global__ void k_count_missing(int64_t total_keys, const bool* d_founds,
                                int* d_missing_count) {
  int idx = threadIdx.x + blockDim.x * blockIdx.x;
  if (idx < total_keys && d_founds[idx] == 0) {
    atomicAdd(d_missing_count, 1);
  }
}

void test_multi_stream_multi_threads() {
  using namespace nv::merlin;
  using namespace std;
  using K = uint64_t;
  using S = uint64_t;
  using V = float;
#define DIM 20
  TableOptions options;
  options.init_capacity = 1 * 1024 * 1024UL;
  options.max_capacity = 16 * 1024 * 1024UL;
  options.dim = DIM;
  options.max_hbm_for_vectors = nv::merlin::GB(40);
  // options.evict_strategy = EvictStrategy::kLru;

  int QUEUE_SIZE = 4;
  int LAG_STEP = QUEUE_SIZE - 1;
  int TEST_TIMES = 500;
  size_t hot_key_length = 8 * 1024 * 1024UL;
  size_t cold_key_length = 64 * 1024 * 1024UL;
  size_t input_key_length = 1 * 1024 * 1024UL;

  // MerlinHashTable table;

  using Table = nv::merlin::HashTable<K, V, S, EvictStrategy::kEpochLfu>;
  std::shared_ptr<Table> table = std::make_shared<Table>();
  table->init(options);

  std::random_device dev;
  std::mt19937 mt(dev());
  std::uniform_int_distribution<uint64_t> dist(0, 1ul << 60);

  // gen hot keys
  K* h_hot_keys;
  K* d_hot_keys;
  S* h_hot_scores;
  S* d_hot_scores;
  V* h_hot_values;
  V* d_hot_values;

  h_hot_keys = static_cast<K*>(std::malloc(hot_key_length * sizeof(K)));
  h_hot_scores = static_cast<S*>(std::malloc(hot_key_length * sizeof(S)));
  h_hot_values = static_cast<V*>(std::malloc(hot_key_length * sizeof(V) * DIM));
  CUDA_CHECK(cudaMalloc(&d_hot_keys, hot_key_length * sizeof(K)));
  CUDA_CHECK(cudaMalloc(&d_hot_scores, hot_key_length * sizeof(S)));
  CUDA_CHECK(cudaMalloc(&d_hot_values, hot_key_length * sizeof(V) * DIM));
  for (int i = 0; i < hot_key_length; i++) {
    h_hot_keys[i] = dist(mt) + (2ul << 60);
    h_hot_scores[i] = dist(mt) % 1000 + 1000;
    for (int j = 0; j < DIM; j++) {
      h_hot_values[i * DIM + j] =
          static_cast<float>(rand()) / static_cast<float>(RAND_MAX);
    }
  }
  cudaMemcpy(d_hot_keys, h_hot_keys, hot_key_length * sizeof(K),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_hot_scores, h_hot_scores, hot_key_length * sizeof(S),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_hot_values, h_hot_values, hot_key_length * sizeof(V) * DIM,
             cudaMemcpyHostToDevice);

  // gen cold keys
  K* h_cold_keys;
  K* d_cold_keys;
  S* h_cold_scores;
  S* d_cold_scores;
  V* h_cold_values;
  V* d_cold_values;

  h_cold_keys = static_cast<K*>(std::malloc(cold_key_length * sizeof(K)));
  h_cold_scores = static_cast<S*>(std::malloc(cold_key_length * sizeof(S)));
  h_cold_values =
      static_cast<V*>(std::malloc(cold_key_length * sizeof(V) * DIM));
  CUDA_CHECK(cudaMalloc(&d_cold_keys, cold_key_length * sizeof(K)));
  CUDA_CHECK(cudaMalloc(&d_cold_scores, cold_key_length * sizeof(S)));
  CUDA_CHECK(cudaMalloc(&d_cold_values, cold_key_length * sizeof(V) * DIM));
  for (int i = 0; i < cold_key_length; i++) {
    h_cold_keys[i] = dist(mt) + (3ul << 60);
    h_cold_scores[i] = dist(mt) % 1000 + 0;
    for (int j = 0; j < DIM; j++) {
      h_cold_values[i * DIM + j] =
          static_cast<float>(rand()) / static_cast<float>(RAND_MAX);
    }
  }
  cudaMemcpy(d_cold_keys, h_cold_keys, cold_key_length * sizeof(K),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_cold_scores, h_cold_scores, cold_key_length * sizeof(S),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_cold_values, h_cold_values, cold_key_length * sizeof(V) * DIM,
             cudaMemcpyHostToDevice);

  // gen input and output
  K* d_input_keys_list[QUEUE_SIZE];
  S* d_input_scores_list[QUEUE_SIZE];
  V* d_input_values_list[QUEUE_SIZE];
  S* d_output_scores_list[QUEUE_SIZE];
  V* d_output_values_list[QUEUE_SIZE];
  bool* d_founds_list[QUEUE_SIZE];

  K* d_evicted_keys_list[QUEUE_SIZE];
  S* d_evicted_scores_list[QUEUE_SIZE];
  V* d_evicted_values_list[QUEUE_SIZE];

  for (int i = 0; i < QUEUE_SIZE; i++) {
    CUDA_CHECK(cudaMalloc(&d_input_keys_list[i], input_key_length * sizeof(K)));
    CUDA_CHECK(
        cudaMalloc(&d_input_scores_list[i], input_key_length * sizeof(S)));
    CUDA_CHECK(cudaMalloc(&d_input_values_list[i],
                          input_key_length * DIM * sizeof(V)));
    CUDA_CHECK(
        cudaMalloc(&d_output_scores_list[i], input_key_length * sizeof(S)));
    CUDA_CHECK(cudaMalloc(&d_output_values_list[i],
                          input_key_length * DIM * sizeof(V)));
    CUDA_CHECK(cudaMalloc(&d_founds_list[i], input_key_length * sizeof(bool)));

    CUDA_CHECK(
        cudaMalloc(&d_evicted_keys_list[i], input_key_length * sizeof(K)));
    CUDA_CHECK(
        cudaMalloc(&d_evicted_scores_list[i], input_key_length * sizeof(S)));
    CUDA_CHECK(cudaMalloc(&d_evicted_values_list[i],
                          input_key_length * DIM * sizeof(V)));

    CUDA_CHECK(
        cudaMemset(d_founds_list[i], 0, input_key_length * sizeof(bool)));
    CUDA_CHECK(cudaMemset(d_input_values_list[i], 0,
                          input_key_length * DIM * sizeof(V)));
    CUDA_CHECK(cudaMemset(d_evicted_values_list[i], 0,
                          input_key_length * DIM * sizeof(V)));
  }
  cudaDeviceSynchronize();

  std::vector<std::atomic<int>> tokens(QUEUE_SIZE);
  for (auto& t : tokens) {
    t.store(0);
  }
  std::atomic<int> global_epoch(0);
  std::atomic<int> global_index(-1);
  std::mutex table_lock;

  auto lookup_func = [&]() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    cudaStreamSynchronize(stream);

    int* d_missing_count;
    int h_missing_count;
    CUDA_CHECK(cudaMalloc(&d_missing_count, sizeof(int)));
    int step = 0;
    int index = 0;
    cout << "lookup_func thread " << index << " start" << endl;
    while (true) {
      while (global_index.load() != -1) {
        std::this_thread::yield();
      }
      index = step % QUEUE_SIZE;
      auto* d_input_keys = d_input_keys_list[index];
      auto* d_input_values = d_input_values_list[index];
      auto* d_input_scores = d_input_scores_list[index];
      auto* d_output_values = d_input_values_list[index];
      auto* d_founds = d_founds_list[index];
      auto* d_evicted_keys = d_evicted_keys_list[index];
      auto* d_evicted_values = d_evicted_values_list[index];
      auto* d_evicted_scores = d_evicted_scores_list[index];

      global_epoch++;

      int hot_length = 900 * 1024;
      int cold_length = input_key_length - hot_length;
      size_t hot_start_index = dist(mt) % (hot_key_length - hot_length);
      cudaMemcpyAsync(d_input_keys, d_hot_keys + hot_start_index,
                      hot_length * sizeof(K), cudaMemcpyDeviceToDevice, stream);
      cudaMemcpyAsync(d_input_scores, d_hot_scores + hot_start_index,
                      hot_length * sizeof(S), cudaMemcpyDeviceToDevice, stream);
      cudaMemcpyAsync(d_input_values, d_hot_values + hot_start_index * DIM,
                      hot_length * sizeof(V) * DIM, cudaMemcpyDeviceToDevice,
                      stream);
      size_t cold_start_index = dist(mt) % (cold_key_length - hot_length);
      cudaMemcpyAsync(d_input_keys + hot_length, d_cold_keys + cold_start_index,
                      cold_length * sizeof(K), cudaMemcpyDeviceToDevice,
                      stream);
      cudaMemcpyAsync(d_input_scores + hot_length,
                      d_cold_scores + cold_start_index, cold_length * sizeof(S),
                      cudaMemcpyDeviceToDevice, stream);
      cudaMemcpyAsync(d_input_values + hot_length * DIM,
                      d_cold_values + cold_start_index * DIM,
                      cold_length * sizeof(V) * DIM, cudaMemcpyDeviceToDevice,
                      stream);

      cudaMemsetAsync(d_founds, 0, input_key_length * sizeof(bool));
      cudaMemsetAsync(d_missing_count, 0, sizeof(int));

      table_lock.lock();
      table->find(input_key_length, d_input_keys, d_output_values, d_founds,
                  nullptr, stream);
      size_t evict_count = table->insert_and_evict(
          input_key_length, d_input_keys, d_input_values, d_input_scores,
          d_evicted_keys, d_evicted_values, d_evicted_scores, stream,
          global_epoch, false);
      table_lock.unlock();

      k_count_missing<<<1024, 1024>>>(input_key_length, d_founds,
                                      d_missing_count);
      cudaMemcpyAsync(&h_missing_count, d_missing_count, sizeof(int),
                      cudaMemcpyDeviceToHost, stream);
      cudaStreamSynchronize(stream);

      cout << "lookup_func step " << step << " key_epoch: " << global_epoch
           << ", h_missing_count: " << h_missing_count
           << " evict_count: " << evict_count
           << ", table size: " << table->size(stream) << endl;

      if (step >= LAG_STEP) {
        while (global_index.load() != -1) {
          std::this_thread::yield();
        }
        global_index.store((step - LAG_STEP) % QUEUE_SIZE);
      }
      step++;
      if (step >= (TEST_TIMES + LAG_STEP)) break;
    }
    cudaStreamDestroy(stream);
  };

  auto assign_func = [&]() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    cudaStreamSynchronize(stream);

    int* d_missing_count;
    int h_missing_count;
    CUDA_CHECK(cudaMalloc(&d_missing_count, sizeof(int)));

    K* h_keys = static_cast<K*>(std::malloc(input_key_length * sizeof(K)));
    S* h_scores = static_cast<S*>(std::malloc(input_key_length * sizeof(S)));
    bool* h_founds =
        static_cast<bool*>(std::malloc(input_key_length * sizeof(bool)));
    int index = 0;
    int step = 0;
    cout << "assign_func thread " << index << " start" << endl;
    while (true) {
      while (global_index.load() == -1) {
        std::this_thread::yield();
      }
      index = global_index.load();
      auto* d_input_keys = d_input_keys_list[index];
      auto* d_input_values = d_input_values_list[index];
      auto* d_input_scores = d_input_scores_list[index];
      auto* d_output_values = d_input_values_list[index];
      auto* d_founds = d_founds_list[index];
      cudaMemsetAsync(d_founds, 0, input_key_length * sizeof(bool));
      cudaMemsetAsync(d_missing_count, 0, sizeof(int));
      table_lock.lock();
      table->find(input_key_length, d_input_keys, d_output_values, d_founds,
                  nullptr, stream);
      table->assign(input_key_length, d_input_keys, d_input_values,
                    d_input_scores, stream, global_epoch);
      table_lock.unlock();

      k_count_missing<<<1024, 1024>>>(input_key_length, d_founds,
                                      d_missing_count);
      cudaMemcpyAsync(&h_missing_count, d_missing_count, sizeof(int),
                      cudaMemcpyDeviceToHost, stream);

      cudaMemcpyAsync(h_keys, d_input_keys, input_key_length * sizeof(K),
                      cudaMemcpyDeviceToHost, stream);
      cudaMemcpyAsync(h_scores, d_input_scores, input_key_length * sizeof(S),
                      cudaMemcpyDeviceToHost, stream);
      cudaMemcpyAsync(h_founds, d_founds, input_key_length * sizeof(bool),
                      cudaMemcpyDeviceToHost, stream);

      cudaStreamSynchronize(stream);
      for (int i = 0; i < input_key_length; i++) {
        if (!h_founds[i]) {
          cout << "key:" << h_keys[i] << ", score:" << h_scores[i] << endl;
        }
      }
      ASSERT_EQ(h_missing_count, 0);

      cout << "assign_func step " << step << " key_epoch: " << global_epoch
           << " h_missing_count: " << h_missing_count
           << ", table size: " << table->size(stream) << endl;
      usleep(5e5);
      global_index.store(-1);
      step++;
      if (step >= TEST_TIMES) break;
    }
    free(h_keys);
    free(h_scores);
    free(h_founds);
    CUDA_CHECK(cudaFree(d_missing_count));
    cudaStreamDestroy(stream);
  };

cout << "start." << endl;
std::vector<std::thread> threads;
threads.emplace_back(lookup_func);
threads.emplace_back(assign_func);
for (int i = 0; i < threads.size(); i++) {
threads[i].join();
}

free(h_cold_keys);
free(h_cold_scores);
free(h_cold_values);
CUDA_CHECK(cudaFree(d_cold_keys));
CUDA_CHECK(cudaFree(d_cold_scores));
CUDA_CHECK(cudaFree(d_cold_values));
cout << "end." << endl;
}

TEST(InsertAndEvictTest, test_insert_and_evict_basic) {
  test_insert_and_evict_basic();
}
@@ -1675,3 +1992,7 @@ TEST(InsertAndEvictTest, test_insert_and_evict_with_export_batch) {
TEST(InsertAndEvictTest, test_insert_and_evict_run_with_batch_find) {
  test_insert_and_evict_run_with_batch_find();
}

TEST(InsertAndEvictTest, test_multi_stream_multi_threads) {
  test_multi_stream_multi_threads();
}
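For readers skimming the new 300-line test: its core is a double-buffered handshake in which lookup_func fills slot step % QUEUE_SIZE and publishes the slot it finished LAG_STEP iterations earlier through global_index, while assign_func waits for a published slot, re-finds and re-assigns its keys, checks nothing went missing, and hands the token back by resetting global_index to -1. The following host-only sketch distills just that handshake; producer/consumer and the k-prefixed constants are hypothetical names standing in for the find/insert_and_evict and find/assign work above:

#include <atomic>
#include <thread>

constexpr int kQueueSize = 4;             // mirrors QUEUE_SIZE
constexpr int kLagStep = kQueueSize - 1;  // mirrors LAG_STEP
constexpr int kTestTimes = 8;             // shortened for illustration
std::atomic<int> global_index{-1};        // -1 means "no slot ready"

void producer() {  // role of lookup_func
  for (int step = 0; step < kTestTimes + kLagStep; ++step) {
    while (global_index.load() != -1) std::this_thread::yield();
    int slot = step % kQueueSize;
    (void)slot;  // ... fill the buffers of `slot`, then find + insert_and_evict ...
    if (step >= kLagStep) {
      while (global_index.load() != -1) std::this_thread::yield();
      global_index.store((step - kLagStep) % kQueueSize);  // publish an older slot
    }
  }
}

void consumer() {  // role of assign_func
  for (int step = 0; step < kTestTimes; ++step) {
    while (global_index.load() == -1) std::this_thread::yield();
    int slot = global_index.load();
    (void)slot;  // ... find + assign on the buffers of `slot`, assert all keys found ...
    global_index.store(-1);  // return the token
  }
}

int main() {
  std::thread t1(producer), t2(consumer);
  t1.join();
  t2.join();
  return 0;
}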
