Skip to content

Commit

Permalink
[Feat] find_or_insert with return values' address
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed May 10, 2023
1 parent 6aba056 commit 03be26a
Show file tree
Hide file tree
Showing 6 changed files with 3,072 additions and 14 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,8 @@ add_executable(group_lock_test tests/group_lock_test.cc)
target_compile_features(group_lock_test PUBLIC cxx_std_14)
set_target_properties(group_lock_test PROPERTIES CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(group_lock_test gtest_main)

# Unit test for the find_or_insert overload that returns the in-table
# addresses of values (V**) instead of copying them out.
add_executable(find_or_insert_ptr_test tests/find_or_insert_ptr_test.cc.cu)
target_compile_features(find_or_insert_ptr_test PUBLIC cxx_std_14)
set_target_properties(find_or_insert_ptr_test PROPERTIES CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(find_or_insert_ptr_test gtest_main)
38 changes: 31 additions & 7 deletions benchmark/merlin_hashtable_benchmark.cc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ enum class API_Select {
assign = 3,
insert_and_evict = 4,
find_ptr = 5,
find_or_insert_ptr = 6,
};

enum class Hit_Mode {
Expand Down Expand Up @@ -329,6 +330,23 @@ float test_one_api(const API_Select api, const size_t dim,
CUDA_CHECK(cudaFree(d_vectors_ptr));
break;
}
case API_Select::find_or_insert_ptr: {
V** d_vectors_ptr = nullptr;
bool* d_found;
CUDA_CHECK(cudaMalloc(&d_found, key_num_per_op * sizeof(bool)));
CUDA_CHECK(cudaMalloc(&d_vectors_ptr, key_num_per_op * sizeof(V*)));
benchmark::array2ptr(d_vectors_ptr, d_vectors, options.dim,
key_num_per_op, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
timer.start();
table->find_or_insert(key_num_per_op, d_keys, d_vectors_ptr, d_found,
d_metas, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
timer.end();
CUDA_CHECK(cudaFree(d_vectors_ptr));
CUDA_CHECK(cudaFree(d_found));
break;
}
default: {
std::cout << "[Unsupport API]\n";
}
Expand Down Expand Up @@ -365,7 +383,8 @@ void print_title() {
<< "| find "
<< "| find_or_insert "
<< "| assign "
<< "| find* ";
<< "| find* "
<< "| find_or_insert* ";
if (Test_Mode::pure_hbm == test_mode) {
cout << "| insert_and_evict ";
}
Expand All @@ -374,18 +393,20 @@ void print_title() {
//<< "| load_factor "
cout << "|------------:"
//<< "| insert_or_assign "
<< "|:----------------:"
<< "|-----------------:"
//<< "| find "
<< "|-------:"
//<< "| find_or_insert "
<< "|:--------------:"
<< "|---------------:"
//<< "| assign "
<< "|-------:"
//<< "| find* "
<< "|-------:";
<< "|-------:"
//<< "| find_or_insert* "
<< "|----------------:";
if (Test_Mode::pure_hbm == test_mode) {
//<< "| insert_and_evict "
cout << "|:----------------:";
cout << "|-----------------:";
}
cout << "|\n";
}
Expand All @@ -398,8 +419,7 @@ void test_main(const size_t dim,
std::vector<API_Select> apis{
API_Select::insert_or_assign, API_Select::find,
API_Select::find_or_insert, API_Select::assign,
API_Select::find_ptr,
};
API_Select::find_ptr, API_Select::find_or_insert_ptr};
if (Test_Mode::pure_hbm == test_mode) {
apis.push_back(API_Select::insert_and_evict);
}
Expand Down Expand Up @@ -437,6 +457,10 @@ void test_main(const size_t dim,
std::cout << rep(2);
break;
}
case API_Select::find_or_insert_ptr: {
std::cout << rep(11);
break;
}
default: {
std::cout << "[Unsupport API]";
}
Expand Down
130 changes: 125 additions & 5 deletions include/merlin/core_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1585,11 +1585,12 @@ __global__ void lookup_kernel(const Table<K, V, M>* __restrict table,
* usually used for the pure HBM mode for better performance.
*/
template <class K, class V, class M, uint32_t TILE_SIZE = 4>
__global__ void lookup_ptr_kernel_with_io(
const Table<K, V, M>* __restrict table, const size_t bucket_max_size,
const size_t buckets_num, const size_t dim, const K* __restrict keys,
V** __restrict values, M* __restrict metas, bool* __restrict found,
size_t N) {
__global__ void lookup_ptr_kernel(const Table<K, V, M>* __restrict table,
const size_t bucket_max_size,
const size_t buckets_num, const size_t dim,
const K* __restrict keys,
V** __restrict values, M* __restrict metas,
bool* __restrict found, size_t N) {
int* buckets_size = table->buckets_size;
Bucket<K, V, M>* buckets = table->buckets;

Expand Down Expand Up @@ -2186,6 +2187,125 @@ __global__ void find_or_insert_kernel(
}
}

/* Find-or-insert that publishes the in-bucket value ADDRESS for each key.
 *
 * For every key, a tile of TILE_SIZE threads cooperatively probes its bucket:
 *  - if the key exists, `vectors[i]` receives a pointer to the stored value,
 *    `found[i]` is set true, and (when `metas` != nullptr) the stored meta is
 *    read back into `metas[i]`;
 *  - otherwise a slot is claimed, `vectors[i]` points at the (new) slot,
 *    `found[i]` is set false, and the slot's meta is updated from `metas`
 *    (or MAX_META when `metas` is nullptr).
 *
 * Launch expectation: N == n_keys * TILE_SIZE total logical threads; the
 * grid-stride loop tolerates any grid size. Reserved keys are skipped.
 */
template <class K, class V, class M, uint32_t TILE_SIZE = 4>
__global__ void find_ptr_or_insert_kernel(
    const Table<K, V, M>* __restrict table, const size_t bucket_max_size,
    const size_t buckets_num, const size_t dim, const K* __restrict keys,
    V** __restrict vectors, M* __restrict metas, bool* __restrict found,
    const size_t N) {
  // Each tile (TILE_SIZE lanes) handles one key cooperatively.
  auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
  int* buckets_size = table->buckets_size;

  for (size_t t = (blockIdx.x * blockDim.x) + threadIdx.x; t < N;
       t += blockDim.x * gridDim.x) {
    int key_pos = -1;
    // All TILE_SIZE threads of a tile map to the same key index.
    size_t key_idx = t / TILE_SIZE;

    const K find_or_insert_key = keys[key_idx];

    // Reserved (sentinel) keys are never stored; skip them outright.
    if (IS_RESERVED_KEY(find_or_insert_key)) continue;

    // Default meta is MAX_META when the caller supplies none.
    const M find_or_insert_meta =
        metas != nullptr ? metas[key_idx] : static_cast<M>(MAX_META);

    size_t bkt_idx = 0;
    size_t start_idx = 0;
    int src_lane = -1;
    K evicted_key;

    // Hash the key to its bucket and probe start position.
    Bucket<K, V, M>* bucket =
        get_key_position<K>(table->buckets, find_or_insert_key, bkt_idx,
                            start_idx, buckets_num, bucket_max_size);

    // Retry until the tile either locks a slot (duplicate/empty/reclaimed/
    // evicted) or the insert is refused.
    OccupyResult occupy_result{OccupyResult::INITIAL};
    const int bucket_size = buckets_size[bkt_idx];
    do {
      if (bucket_size < bucket_max_size) {
        // Bucket has free slots: look for the key or claim a vacant slot.
        occupy_result = find_and_lock_when_vacant<K, V, M, TILE_SIZE>(
            g, bucket, find_or_insert_key, find_or_insert_meta, evicted_key,
            start_idx, key_pos, src_lane, bucket_max_size);
      } else {
        // Full bucket: align the probe start to the tile width, then find the
        // key or evict a lower-priority entry.
        start_idx = (start_idx / TILE_SIZE) * TILE_SIZE;
        occupy_result = find_and_lock_when_full<K, V, M, TILE_SIZE>(
            g, bucket, find_or_insert_key, find_or_insert_meta, evicted_key,
            start_idx, key_pos, src_lane, bucket_max_size);
      }

      // Broadcast the winning lane's result to the whole tile.
      occupy_result = g.shfl(occupy_result, src_lane);
    } while (occupy_result == OccupyResult::CONTINUE);

    // Insert refused (e.g. key outranked by all residents): leave outputs as-is.
    if (occupy_result == OccupyResult::REFUSED) continue;

    // A previously unused or reclaimed slot was taken: bump the bucket size.
    if ((occupy_result == OccupyResult::OCCUPIED_EMPTY ||
         occupy_result == OccupyResult::OCCUPIED_RECLAIMED) &&
        g.thread_rank() == src_lane) {
      atomicAdd(&(buckets_size[bkt_idx]), 1);
    }

    if (occupy_result == OccupyResult::DUPLICATE) {
      // Key already present: report found and echo back its stored meta.
      if (g.thread_rank() == src_lane) {
        *(vectors + key_idx) = (bucket->vectors + key_pos * dim);
        *(found + key_idx) = true;
        if (metas != nullptr) {
          *(metas + key_idx) =
              bucket->metas(key_pos)->load(cuda::std::memory_order_relaxed);
        }
      }
    } else {
      // Fresh insert: report not-found and write the caller's meta.
      if (g.thread_rank() == src_lane) {
        *(vectors + key_idx) = (bucket->vectors + key_pos * dim);
        *(found + key_idx) = false;
        update_meta(bucket, key_pos, metas, key_idx);
      }
    }

    // Publishing the key also releases the per-slot lock taken by
    // find_and_lock_when_*; it must come after the value/meta updates above.
    if (g.thread_rank() == src_lane) {
      (bucket->keys(key_pos))
          ->store(find_or_insert_key, cuda::std::memory_order_relaxed);
    }
  }
}

/* Kernel selector for find_ptr_or_insert_kernel.
 *
 * Picks the cooperative tile width from the table's current load factor:
 * fuller tables get wider tiles so more slots are probed per key in parallel.
 * Thresholds (0.5 -> 4, 0.875 -> 8, else 32) mirror the other Select*Kernel
 * dispatchers in this file.
 */
template <typename K, typename V, typename M>
struct SelectFindOrInsertPtrKernel {
  /* Launch the kernel with a compile-time tile size.
   * One tile (TILE_SIZE lanes) serves one key, so the logical thread count
   * is n * TILE_SIZE; the grid size is derived from that total.
   */
  template <uint32_t TILE_SIZE>
  static void launch_kernel(const int& block_size, const size_t bucket_max_size,
                            const size_t buckets_num, const size_t dim,
                            cudaStream_t& stream, const size_t& n,
                            const Table<K, V, M>* __restrict table,
                            const K* __restrict keys, V** __restrict values,
                            M* __restrict metas, bool* __restrict found) {
    const size_t N = n * TILE_SIZE;
    const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
    find_ptr_or_insert_kernel<K, V, M, TILE_SIZE>
        <<<grid_size, block_size, 0, stream>>>(table, bucket_max_size,
                                               buckets_num, dim, keys, values,
                                               metas, found, N);
  }

  /* Dispatch on load_factor; all other arguments are forwarded unchanged. */
  static void execute_kernel(const float& load_factor, const int& block_size,
                             const size_t bucket_max_size,
                             const size_t buckets_num, const size_t dim,
                             cudaStream_t& stream, const size_t& n,
                             const Table<K, V, M>* __restrict table,
                             const K* __restrict keys, V** __restrict values,
                             M* __restrict metas, bool* __restrict found) {
    if (load_factor <= 0.5) {
      launch_kernel<4>(block_size, bucket_max_size, buckets_num, dim, stream,
                       n, table, keys, values, metas, found);
    } else if (load_factor <= 0.875) {
      launch_kernel<8>(block_size, bucket_max_size, buckets_num, dim, stream,
                       n, table, keys, values, metas, found);
    } else {
      launch_kernel<32>(block_size, bucket_max_size, buckets_num, dim, stream,
                        n, table, keys, values, metas, found);
    }
    return;
  }
};

/* Read the data from address of table_value_addrs to corresponding position
in param_value if mask[i] is true, otherwise write data to table_value_addrs
form param_value,
Expand Down
59 changes: 57 additions & 2 deletions include/merlin_hashtable.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,62 @@ class HashTable {
CudaCheckError();
}

  /**
   * @brief Searches the hash table for the specified keys and returns the
   * addresses of the values. When a key is missing, the key-value-meta tuple
   * from @p values and @p metas will be inserted.
   *
   * @warning This API returns internal addresses for high performance but is
   * thread-unsafe. The caller is responsible for guaranteeing data consistency.
   *
   * @param n The number of key-value-meta tuples to search or insert.
   * @param keys The keys to search on GPU-accessible memory with shape (n).
   * @param values The addresses of the values on GPU-accessible memory with
   * shape (n); each entry is set to the in-table value address for its key.
   * @param founds The status array on GPU-accessible memory with shape (n)
   * that indicates whether each key was already present (`true`) or was newly
   * inserted (`false`).
   * @param metas The metas to search on GPU-accessible memory with shape (n).
   * @parblock
   * If @p metas is `nullptr`, the meta for each key will not be returned.
   * @endparblock
   * @param stream The CUDA stream that is used to execute the operation.
   * @param ignore_evict_strategy A boolean option indicating whether the
   * check of the table's evict strategy against @p metas is skipped.
   *
   */
  void find_or_insert(const size_type n, const key_type* keys,  // (n)
                      value_type** values,                      // (n)
                      bool* founds,                             // (n)
                      meta_type* metas = nullptr,               // (n)
                      cudaStream_t stream = 0,
                      bool ignore_evict_strategy = false) {
    if (n == 0) {
      return;
    }

    // Grow the table until the projected load factor fits, unless the
    // maximum capacity has already been reached.
    while (!reach_max_capacity_ &&
           fast_load_factor(n, stream) > options_.max_load_factor) {
      reserve(capacity() * 2, stream);
    }

    if (!ignore_evict_strategy) {
      check_evict_strategy(metas);
    }

    writer_shared_lock lock(mutex_);

    using Selector =
        SelectFindOrInsertPtrKernel<key_type, value_type, meta_type>;
    // Re-sample the load factor every kernel_select_interval_ calls; the
    // selector uses it to choose the probing tile width.
    static thread_local int step_counter = 0;
    static thread_local float load_factor = 0.0;

    if (((step_counter++) % kernel_select_interval_) == 0) {
      load_factor = fast_load_factor(0, stream, false);
    }
    Selector::execute_kernel(load_factor, options_.block_size,
                             options_.max_bucket_size, table_->buckets_num,
                             options_.dim, stream, n, d_table_, keys, values,
                             metas, founds);

    CudaCheckError();
  }
/**
* @brief Assign new key-value-meta tuples into the hash table.
* If the key doesn't exist, the operation on the key will be ignored.
Expand Down Expand Up @@ -918,8 +974,7 @@ class HashTable {

reader_shared_lock lock(mutex_);

using Selector =
SelectLookupPtrKernel<key_type, value_type, meta_type>;
using Selector = SelectLookupPtrKernel<key_type, value_type, meta_type>;
static thread_local int step_counter = 0;
static thread_local float load_factor = 0.0;

Expand Down
Loading

0 comments on commit 03be26a

Please sign in to comment.