[Feat] Support three new evict strategies (lfu, epoch_lfu, epoch_lru)
rhdong committed Jul 5, 2023
1 parent e69234b commit 931e10f
Showing 16 changed files with 1,291 additions and 234 deletions.
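
The benchmark change below shows the API shift that goes with the new strategies: the eviction policy moves from a runtime option (options.evict_strategy = EvictStrategy::kLru) to a compile-time template parameter of HashTable. A minimal sketch of selecting the three new strategies in the same way follows; the enum names kLfu, kEpochLfu, and kEpochLru are assumed by analogy with the kLru usage visible in this diff, and the header name is taken from the repository layout rather than this commit.

    // Sketch only: choosing an eviction strategy at table-type definition time,
    // mirroring the kLru usage in merlin_hashtable_benchmark.cc.cu below.
    // kLfu, kEpochLfu, and kEpochLru are assumed identifiers.
    #include <cstdint>
    #include <memory>
    #include "merlin_hashtable.cuh"

    using K = uint64_t;
    using S = uint64_t;
    using V = float;
    using EvictStrategy = nv::merlin::EvictStrategy;

    // Score interpreted as an access frequency.
    using LfuTable = nv::merlin::HashTable<K, V, S, EvictStrategy::kLfu>;
    // Epoch variants fold a caller-advanced global epoch into the score.
    using EpochLruTable = nv::merlin::HashTable<K, V, S, EvictStrategy::kEpochLru>;
    using EpochLfuTable = nv::merlin::HashTable<K, V, S, EvictStrategy::kEpochLfu>;

    void build_table() {
      nv::merlin::HashTableOptions options;
      options.dim = 64;
      auto table = std::make_shared<EpochLruTable>();
      table->init(options);
    }
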
23 changes: 12 additions & 11 deletions benchmark/merlin_hashtable_benchmark.cc.cu
@@ -53,13 +53,14 @@ std::string rep(int n) { return std::string(n, ' '); }
using K = uint64_t;
using S = uint64_t;
using V = float;
using MerlinHashTable = nv::merlin::HashTable<K, V, S>;
using EvictStrategy = nv::merlin::EvictStrategy;
using TableOptions = nv::merlin::HashTableOptions;

float test_one_api(std::shared_ptr<MerlinHashTable>& table,
const API_Select api, const size_t dim,
const size_t init_capacity, const size_t key_num_per_op,
const float load_factor, const float hitrate = 0.6f) {
template <class Table>
float test_one_api(std::shared_ptr<Table>& table, const API_Select api,
const size_t dim, const size_t init_capacity,
const size_t key_num_per_op, const float load_factor,
const float hitrate = 0.6f) {
K* h_keys;
S* h_scores;
V* h_vectors;
@@ -458,9 +459,9 @@ void test_main(std::vector<API_Select>& apis, const size_t dim,
options.dim = dim;
options.max_hbm_for_vectors = nv::merlin::GB(hbm4values);
options.io_by_cpu = io_by_cpu;
options.evict_strategy = EvictStrategy::kLru;
using Table = nv::merlin::HashTable<K, V, S, EvictStrategy::kLru>;

std::shared_ptr<MerlinHashTable> table = std::make_shared<MerlinHashTable>();
std::shared_ptr<Table> table = std::make_shared<Table>();
table->init(options);

for (float load_factor : load_factors) {
@@ -472,10 +473,10 @@ void test_main(std::vector<API_Select>& apis, const size_t dim,
CUDA_CHECK(cudaDeviceSynchronize());
// There is a sampling of load_factor after several times call to target
// API. Two consecutive calls can avoid the impact of sampling.
auto res1 = test_one_api(table, api, dim, init_capacity, key_num_per_op,
load_factor);
auto res2 = test_one_api(table, api, dim, init_capacity, key_num_per_op,
load_factor);
auto res1 = test_one_api<Table>(table, api, dim, init_capacity,
key_num_per_op, load_factor);
auto res2 = test_one_api<Table>(table, api, dim, init_capacity,
key_num_per_op, load_factor);
auto res = std::max(res1, res2);
std::cout << "|";
switch (api) {
12 changes: 8 additions & 4 deletions include/merlin/core_kernels.cuh
@@ -656,19 +656,21 @@ __global__ void read_kernel(const V** __restrict src, V* __restrict dst,

/* Accum kernel with customized scores.
*/
template <class K, class V, class S, uint32_t TILE_SIZE = 4>
template <class K, class V, class S, int Strategy, uint32_t TILE_SIZE = 4>
__global__ void accum_kernel(
const Table<K, V, S>* __restrict table, const K* __restrict keys,
V** __restrict vectors, const S* __restrict scores,
const bool* __restrict existed, Bucket<K, V, S>* __restrict buckets,
int* __restrict buckets_size, const size_t bucket_max_size,
const size_t buckets_num, int* __restrict src_offset,
bool* __restrict status, size_t N) {
bool* __restrict status, const S global_epoch, size_t N) {
const size_t dim = table->dim;
size_t tid = (blockIdx.x * blockDim.x) + threadIdx.x;
auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
int rank = g.thread_rank();

using ScoreFunctor = ScoreFunctor<K, V, S, Strategy>;

for (size_t t = tid; t < N; t += blockDim.x * gridDim.x) {
int key_pos = -1;
int local_size = 0;
@@ -719,7 +721,8 @@ __global__ void accum_kernel(
local_size++;
}
*(vectors + key_idx) = (bucket->vectors + key_pos * dim);
update_score(bucket, key_pos, scores, key_idx);
ScoreFunctor::update_without_missed(bucket, key_pos, scores,
key_idx, global_epoch);
}
}
local_size = g.shfl(local_size, src_lane);
@@ -736,7 +739,8 @@
(bucket->keys(key_pos))
->store(insert_key, cuda::std::memory_order_relaxed);
*(vectors + key_idx) = (bucket->vectors + key_pos * dim);
update_score(bucket, key_pos, scores, key_idx);
ScoreFunctor::update_without_missed(bucket, key_pos, scores, key_idx,
global_epoch);
}
refresh_bucket_score<K, V, S, TILE_SIZE>(g, bucket, bucket_max_size);
}
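
Every kernel in core_kernels.cuh now receives a global_epoch argument and maintains scores through ScoreFunctor<K, V, S, Strategy> instead of the old free update_score helper. The functor itself is defined elsewhere in the commit; the sketch below only illustrates the general idea behind the epoch strategies, namely that a caller-advanced epoch can occupy the high bits of the 64-bit score so that entries from older epochs always lose eviction comparisons. The actual bit layout used by the epoch_lru/epoch_lfu strategies may differ.

    // Illustrative, host-side sketch of epoch-aware score packing; not the
    // library's actual layout. High bits: global epoch; low bits: a recency or
    // frequency counter, depending on the strategy.
    #include <cassert>
    #include <cstdint>

    using S = uint64_t;
    constexpr int kLowBits = 32;

    inline S make_epoch_score(S global_epoch, S low_part) {
      return (global_epoch << kLowBits) | (low_part & ((S{1} << kLowBits) - 1));
    }

    int main() {
      // A heavily used key from epoch 7 still scores lower than a key touched
      // once in epoch 8, so older-epoch entries are evicted first.
      S old_epoch_hot = make_epoch_score(7, 0xFFFFFFFFu);
      S new_epoch_cold = make_epoch_score(8, 1);
      assert(old_epoch_hot < new_epoch_cold);
      return 0;
    }
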
37 changes: 23 additions & 14 deletions include/merlin/core_kernels/find_or_insert.cuh
@@ -25,14 +25,17 @@ namespace merlin {
* find or insert with IO operation. This kernel is
* usually used for the pure HBM mode for better performance.
*/
template <class K, class V, class S, uint32_t TILE_SIZE = 4>
template <class K, class V, class S, int Strategy, uint32_t TILE_SIZE = 4>
__global__ void find_or_insert_kernel_with_io(
const Table<K, V, S>* __restrict table, const size_t bucket_max_size,
const size_t buckets_num, const size_t dim, const K* __restrict keys,
V* __restrict values, S* __restrict scores, const size_t N) {
V* __restrict values, S* __restrict scores, const S global_epoch,
const size_t N) {
auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
int* buckets_size = table->buckets_size;

using ScoreFunctor = ScoreFunctor<K, V, S, Strategy>;

for (size_t t = (blockIdx.x * blockDim.x) + threadIdx.x; t < N;
t += blockDim.x * gridDim.x) {
int key_pos = -1;
@@ -43,7 +46,7 @@ __global__ void find_or_insert_kernel_with_io(
if (IS_RESERVED_KEY(find_or_insert_key)) continue;

const S find_or_insert_score =
scores != nullptr ? scores[key_idx] : static_cast<S>(MAX_SCORE);
ScoreFunctor::desired_when_missed(scores, key_idx, global_epoch);
V* find_or_insert_value = values + key_idx * dim;

size_t bkt_idx = 0;
@@ -91,7 +94,9 @@ __global__ void find_or_insert_kernel_with_io(
copy_vector<V, TILE_SIZE>(g, find_or_insert_value,
bucket->vectors + key_pos * dim, dim);
if (g.thread_rank() == src_lane) {
update_score(bucket, key_pos, scores, key_idx);
ScoreFunctor::update(bucket, key_pos, scores, key_idx,
find_or_insert_score,
(occupy_result != OccupyResult::DUPLICATE));
}
}

@@ -103,47 +108,49 @@
}
}

template <typename K, typename V, typename S>
template <typename K, typename V, typename S, int Strategy>
struct SelectFindOrInsertKernelWithIO {
static void execute_kernel(const float& load_factor, const int& block_size,
const size_t bucket_max_size,
const size_t buckets_num, const size_t dim,
cudaStream_t& stream, const size_t& n,
const Table<K, V, S>* __restrict table,
const K* __restrict keys, V* __restrict values,
S* __restrict scores) {
S* __restrict scores, const S global_epoch) {
if (load_factor <= 0.75) {
const unsigned int tile_size = 4;
const size_t N = n * tile_size;
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
find_or_insert_kernel_with_io<K, V, S, tile_size>
find_or_insert_kernel_with_io<K, V, S, Strategy, tile_size>
<<<grid_size, block_size, 0, stream>>>(table, bucket_max_size,
buckets_num, dim, keys, values,
scores, N);
scores, global_epoch, N);
} else {
const unsigned int tile_size = 32;
const size_t N = n * tile_size;
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
find_or_insert_kernel_with_io<K, V, S, tile_size>
find_or_insert_kernel_with_io<K, V, S, Strategy, tile_size>
<<<grid_size, block_size, 0, stream>>>(table, bucket_max_size,
buckets_num, dim, keys, values,
scores, N);
scores, global_epoch, N);
}
return;
}
};

/* find or insert with the end-user specified score.
*/
template <class K, class V, class S, uint32_t TILE_SIZE = 4>
template <class K, class V, class S, int Strategy, uint32_t TILE_SIZE = 4>
__global__ void find_or_insert_kernel(
const Table<K, V, S>* __restrict table, const size_t bucket_max_size,
const size_t buckets_num, const size_t dim, const K* __restrict keys,
V** __restrict vectors, S* __restrict scores, bool* __restrict found,
int* __restrict keys_index, const size_t N) {
int* __restrict keys_index, const S global_epoch, const size_t N) {
auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
int* buckets_size = table->buckets_size;

using ScoreFunctor = ScoreFunctor<K, V, S, Strategy>;

for (size_t t = (blockIdx.x * blockDim.x) + threadIdx.x; t < N;
t += blockDim.x * gridDim.x) {
int key_pos = -1;
@@ -154,7 +161,7 @@ __global__ void find_or_insert_kernel(
if (IS_RESERVED_KEY(find_or_insert_key)) continue;

const S find_or_insert_score =
scores != nullptr ? scores[key_idx] : static_cast<S>(MAX_SCORE);
ScoreFunctor::desired_when_missed(scores, key_idx, global_epoch);

size_t bkt_idx = 0;
size_t start_idx = 0;
@@ -210,7 +217,9 @@
} else {
if (g.thread_rank() == src_lane) {
*(vectors + key_idx) = (bucket->vectors + key_pos * dim);
update_score(bucket, key_pos, scores, key_idx);
ScoreFunctor::update(bucket, key_pos, scores, key_idx,
find_or_insert_score,
(occupy_result != OccupyResult::DUPLICATE));
}
}

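
The find-or-insert kernels above call ScoreFunctor through two entry points visible in this diff: desired_when_missed(scores, key_idx, global_epoch), which picks the score to request when the key is absent, and update(bucket, key_pos, scores, key_idx, desired, is_insert), which writes the score back once a slot is occupied (the last argument is occupy_result != OccupyResult::DUPLICATE). The toy reconstruction below mirrors only those call-site shapes with stand-in host types; it is not the library's Bucket or functor implementation.

    // Toy, host-side reconstruction of the ScoreFunctor call-site contract.
    // ToyBucket stands in for the library's Bucket; the LRU-like behavior
    // (a monotonic counter as "time") is purely illustrative.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    using S = uint64_t;

    struct ToyBucket {
      std::vector<S> scores;
    };

    struct ToyLruScoreFunctor {
      // LRU ignores caller-provided scores and the epoch; it stamps a "now".
      static S desired_when_missed(const S* scores, size_t key_idx, S global_epoch) {
        (void)scores; (void)key_idx; (void)global_epoch;
        return next_tick();
      }
      // Write the final score after the slot at key_pos has been occupied.
      static void update(ToyBucket* bucket, int key_pos, const S* scores,
                         size_t key_idx, S desired, bool is_insert) {
        (void)scores; (void)key_idx; (void)is_insert;
        bucket->scores[key_pos] = desired;
      }

     private:
      static S next_tick() { static S tick = 0; return ++tick; }
    };

    int main() {
      ToyBucket bucket{std::vector<S>(4, 0)};
      S desired = ToyLruScoreFunctor::desired_when_missed(nullptr, 0, /*global_epoch=*/1);
      ToyLruScoreFunctor::update(&bucket, /*key_pos=*/2, nullptr, 0, desired, true);
      std::cout << bucket.scores[2] << "\n";  // the most recently touched slot
      return 0;
    }
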
41 changes: 23 additions & 18 deletions include/merlin/core_kernels/find_ptr_or_insert.cuh
@@ -23,15 +23,17 @@ namespace merlin {

/* find or insert with the end-user specified score.
*/
template <class K, class V, class S, uint32_t TILE_SIZE = 4>
template <class K, class V, class S, int Strategy, uint32_t TILE_SIZE = 4>
__global__ void find_ptr_or_insert_kernel(
const Table<K, V, S>* __restrict table, const size_t bucket_max_size,
const size_t buckets_num, const size_t dim, const K* __restrict keys,
V** __restrict vectors, S* __restrict scores, bool* __restrict found,
const size_t N) {
const S global_epoch, const size_t N) {
auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
int* buckets_size = table->buckets_size;

using ScoreFunctor = ScoreFunctor<K, V, S, Strategy>;

for (size_t t = (blockIdx.x * blockDim.x) + threadIdx.x; t < N;
t += blockDim.x * gridDim.x) {
int key_pos = -1;
@@ -42,7 +44,7 @@ __global__ void find_ptr_or_insert_kernel(
if (IS_RESERVED_KEY(find_or_insert_key)) continue;

const S find_or_insert_score =
scores != nullptr ? scores[key_idx] : static_cast<S>(MAX_SCORE);
ScoreFunctor::desired_when_missed(scores, key_idx, global_epoch);

size_t bkt_idx = 0;
size_t start_idx = 0;
@@ -91,7 +93,9 @@ __global__ void find_ptr_or_insert_kernel(
if (g.thread_rank() == src_lane) {
*(vectors + key_idx) = (bucket->vectors + key_pos * dim);
*(found + key_idx) = false;
update_score(bucket, key_pos, scores, key_idx);
ScoreFunctor::update(bucket, key_pos, scores, key_idx,
find_or_insert_score,
(occupy_result != OccupyResult::DUPLICATE));
}
}

@@ -103,39 +107,40 @@
}
}

template <typename K, typename V, typename S>
template <typename K, typename V, typename S, int Strategy>
struct SelectFindOrInsertPtrKernel {
static void execute_kernel(const float& load_factor, const int& block_size,
const size_t bucket_max_size,
const size_t buckets_num, const size_t dim,
cudaStream_t& stream, const size_t& n,
const Table<K, V, S>* __restrict table,
const K* __restrict keys, V** __restrict values,
S* __restrict scores, bool* __restrict found) {
S* __restrict scores, bool* __restrict found,
const S global_epoch) {
if (load_factor <= 0.5) {
const unsigned int tile_size = 4;
const size_t N = n * tile_size;
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
find_ptr_or_insert_kernel<K, V, S, tile_size>
<<<grid_size, block_size, 0, stream>>>(table, bucket_max_size,
buckets_num, dim, keys, values,
scores, found, N);
find_ptr_or_insert_kernel<K, V, S, Strategy, tile_size>
<<<grid_size, block_size, 0, stream>>>(
table, bucket_max_size, buckets_num, dim, keys, values, scores,
found, global_epoch, N);
} else if (load_factor <= 0.875) {
const unsigned int tile_size = 8;
const size_t N = n * tile_size;
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
find_ptr_or_insert_kernel<K, V, S, tile_size>
<<<grid_size, block_size, 0, stream>>>(table, bucket_max_size,
buckets_num, dim, keys, values,
scores, found, N);
find_ptr_or_insert_kernel<K, V, S, Strategy, tile_size>
<<<grid_size, block_size, 0, stream>>>(
table, bucket_max_size, buckets_num, dim, keys, values, scores,
found, global_epoch, N);
} else {
const unsigned int tile_size = 32;
const size_t N = n * tile_size;
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
find_ptr_or_insert_kernel<K, V, S, tile_size>
<<<grid_size, block_size, 0, stream>>>(table, bucket_max_size,
buckets_num, dim, keys, values,
scores, found, N);
find_ptr_or_insert_kernel<K, V, S, Strategy, tile_size>
<<<grid_size, block_size, 0, stream>>>(
table, bucket_max_size, buckets_num, dim, keys, values, scores,
found, global_epoch, N);
}
return;
}
