template the erase_if and export_batch_if API #132

Merged 1 commit on May 26, 2023
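This PR replaces the runtime `Pred` function-pointer argument of `erase_if` and `export_batch_if` with a template-template parameter `PredFunctor`, so the predicate becomes a compile-time type instead of a `__device__` function pointer copied from a device symbol. A minimal before/after sketch of the call sites, using the functor names from the diff below:

```
// Before: a __device__ function pointer had to be declared at namespace
// scope and fetched from a device symbol before each call.
size_t erased = table->erase_if(EraseIfPred<K, M>, pattern, threshold, stream);

// After: the predicate is a template argument; no symbol lookup is needed.
size_t erased =
    table->template erase_if<EraseIfPredFunctor>(pattern, threshold, stream);
```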
11 changes: 7 additions & 4 deletions include/merlin/core_kernels.cuh
@@ -1774,16 +1774,18 @@ __global__ void remove_kernel(const Table<K, V, M>* __restrict table,
}

/* Remove specified keys which match the Predict. */
-template <class K, class V, class M, uint32_t TILE_SIZE = 1>
+template <class K, class V, class M,
+template <typename, typename> class PredFunctor,
+uint32_t TILE_SIZE = 1>
__global__ void remove_kernel(const Table<K, V, M>* __restrict table,
-const EraseIfPredictInternal<K, M> pred,
const K pattern, const M threshold,
size_t* __restrict count,
Bucket<K, V, M>* __restrict buckets,
int* __restrict buckets_size,
const size_t bucket_max_size,
const size_t buckets_num, size_t N) {
auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
+PredFunctor<K, M> pred;

for (size_t t = (blockIdx.x * blockDim.x) + threadIdx.x; t < N;
t += blockDim.x * gridDim.x) {
@@ -1891,9 +1893,9 @@ __global__ void dump_kernel(const Table<K, V, M>* __restrict table, K* d_key,
}

/* Dump with meta. */
-template <class K, class V, class M>
+template <class K, class V, class M,
+template <typename, typename> class PredFunctor>
__global__ void dump_kernel(const Table<K, V, M>* __restrict table,
-const EraseIfPredictInternal<K, M> pred,
const K pattern, const M threshold, K* d_key,
V* __restrict d_val, M* __restrict d_meta,
const size_t offset, const size_t search_length,
@@ -1907,6 +1909,7 @@ __global__ void dump_kernel(const Table<K, V, M>* __restrict table,
M* block_result_meta = (M*)&(block_result_val[blockDim.x * dim]);
__shared__ size_t block_acc;
__shared__ size_t global_acc;
+PredFunctor<K, M> pred;

const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;

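The key kernel-side change: instead of receiving `pred` as a parameter, each kernel now default-constructs the predicate on the device via `PredFunctor<K, M> pred;`, which lets the compiler see the concrete `operator()` and inline it. A standalone sketch of the pattern (the kernel and predicate names here are hypothetical, for illustration only):

```
template <typename K, typename M>
struct MetaBelowThreshold {  // hypothetical predicate for illustration
  __forceinline__ __device__ bool operator()(const K& key, M& meta,
                                             const K& pattern,
                                             const M& threshold) {
    return meta < threshold;
  }
};

// The kernel takes the predicate as a template-template parameter and
// instantiates it locally, mirroring the updated remove_kernel/dump_kernel.
template <class K, class M, template <typename, typename> class PredFunctor>
__global__ void filter_kernel(const K* keys, M* metas, const K pattern,
                              const M threshold, bool* out, size_t n) {
  PredFunctor<K, M> pred;  // constructed per-thread; trivially cheap
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = pred(keys[i], metas[i], pattern, threshold);
  }
}

// Usage: filter_kernel<uint64_t, uint64_t, MetaBelowThreshold>
//            <<<grid, block, 0, stream>>>(d_keys, d_metas, 0, 10, d_out, n);
```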
101 changes: 52 additions & 49 deletions include/merlin_hashtable.cuh
@@ -87,24 +87,28 @@ struct HashTableOptions {
*
* ```
* template <class K, class M>
-* __forceinline__ __device__ bool erase_if_pred(const K& key,
-* M& meta,
-* const K& pattern,
-* const M& threshold) {
-* return ((key & 0xFFFF000000000000 == pattern) &&
-* (meta < threshold));
-* }
+* struct EraseIfPredFunctor {
+* __forceinline__ __device__ bool operator()(const K& key,
+* M& meta,
+* const K& pattern,
+* const M& threshold) {
+* return ((key & 0xFFFF000000000000 == pattern) &&
+* (meta < threshold));
+* }
+* };
* ```
*
* Example for export_batch_if:
* ```
* template <class K, class M>
-* __forceinline__ __device__ bool export_if_pred(const K& key,
-* M& meta,
-* const K& pattern,
-* const M& threshold) {
-* return meta >= threshold;
-* }
+* struct ExportIfPredFunctor {
+* __forceinline__ __device__ bool operator()(const K& key,
+* M& meta,
+* const K& pattern,
+* const M& threshold) {
+* return meta >= threshold;
+* }
+* };
* ```
*/
template <class K, class M>
@@ -1023,21 +1027,24 @@ class HashTable {
* @brief Erases all elements that satisfy the predicate @p pred from the
* hash table.
*
-* The value for @p pred should be a function with type `Pred` defined like
-* the following example:
+* @tparam PredFunctor The predicate functor, templated on <typename K,
+* typename M>, whose `bool operator()(const K& key, M& meta, const K&
+* pattern, const M& threshold)` returns `true` if the element should be
+* erased. The predicate should be defined like the
+* following example:
*
* ```
* template <class K, class M>
-* __forceinline__ __device__ bool erase_if_pred(const K& key,
-* const M& meta,
-* const K& pattern,
-* const M& threshold) {
-* return ((key & 0x1 == pattern) && (meta < threshold));
-* }
+* struct EraseIfPredFunctor {
+* __forceinline__ __device__ bool operator()(const K& key,
+* M& meta,
+* const K& pattern,
+* const M& threshold) {
+* return ((key & 0x1 == pattern) && (meta < threshold));
+* }
+* };
* ```
*
-* @param pred The predicate function with type Pred that returns `true` if
-* the element should be erased.
* @param pattern The third user-defined argument to @p pred with key_type
* type.
* @param threshold The fourth user-defined argument to @p pred with meta_type
@@ -1047,27 +1054,24 @@
* @return The number of elements removed.
*
*/
-size_type erase_if(const Pred& pred, const key_type& pattern,
-const meta_type& threshold, cudaStream_t stream = 0) {
+template <template <typename, typename> class PredFunctor>
+size_type erase_if(const key_type& pattern, const meta_type& threshold,
+cudaStream_t stream = 0) {
write_read_lock lock(mutex_);

auto dev_ws{dev_mem_pool_->get_workspace<1>(sizeof(size_type), stream)};
auto d_count{dev_ws.get<size_type*>(0)};

CUDA_CHECK(cudaMemsetAsync(d_count, 0, sizeof(size_type), stream));

-Pred h_pred;
-CUDA_CHECK(cudaMemcpyFromSymbolAsync(&h_pred, pred, sizeof(Pred), 0,
-cudaMemcpyDeviceToHost, stream));
-
{
const size_t block_size = options_.block_size;
const size_t N = table_->buckets_num;
const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);

-remove_kernel<key_type, value_type, meta_type>
+remove_kernel<key_type, value_type, meta_type, PredFunctor>
<<<grid_size, block_size, 0, stream>>>(
-table_, h_pred, pattern, threshold, d_count, table_->buckets,
+table_, pattern, threshold, d_count, table_->buckets,
table_->buckets_size, table_->bucket_max_size,
table_->buckets_num, N);
}
@@ -1169,6 +1173,9 @@

/**
* @brief Exports a certain number of the key-value-meta tuples which match
+*
+* @tparam PredFunctor A functor templated on <K, M> that defines an operator
+* with signature: __device__ bool operator()(const K&, M&, const K&, const M&).
* specified condition from the hash table.
*
* @param n The maximum number of exported pairs.
@@ -1177,17 +1184,16 @@
*
* ```
* template <class K, class M>
-* __forceinline__ __device__ bool export_if_pred(const K& key,
-* M& meta,
-* const K& pattern,
-* const M& threshold) {
-*
-* return meta > threshold;
-* }
+* struct ExportIfPredFunctor {
+* __forceinline__ __device__ bool operator()(const K& key,
+* M& meta,
+* const K& pattern,
+* const M& threshold) {
+* return meta >= threshold;
+* }
+* };
* ```
*
-* @param pred The predicate function with type Pred that returns `true` if
-* the element should be exported.
* @param pattern The third user-defined argument to @p pred with key_type
* type.
* @param threshold The fourth user-defined argument to @p pred with meta_type
@@ -1209,9 +1215,10 @@
* memory. Reducing the value for @p n is currently required if this exception
* occurs.
*/
-void export_batch_if(Pred& pred, const key_type& pattern,
-const meta_type& threshold, size_type n,
-const size_type offset, size_type* d_counter,
+template <template <typename, typename> class PredFunctor>
+void export_batch_if(const key_type& pattern, const meta_type& threshold,
+size_type n, const size_type offset,
+size_type* d_counter,
key_type* keys, // (n)
value_type* values, // (n, DIM)
meta_type* metas = nullptr, // (n)
@@ -1235,13 +1242,9 @@
const size_t shared_size = kvm_size * block_size;
const size_t grid_size = SAFE_GET_GRID_SIZE(n, block_size);

-Pred h_pred;
-CUDA_CHECK(cudaMemcpyFromSymbolAsync(&h_pred, pred, sizeof(Pred), 0,
-cudaMemcpyDeviceToHost, stream));
-
-dump_kernel<key_type, value_type, meta_type>
+dump_kernel<key_type, value_type, meta_type, PredFunctor>
<<<grid_size, block_size, shared_size, stream>>>(
-table_, h_pred, pattern, threshold, keys, values, metas, offset, n,
+table_, pattern, threshold, keys, values, metas, offset, n,
d_counter);

CudaCheckError();
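With the header changes in place, a caller supplies its own functor type as an explicit template argument. A hedged usage sketch (the functor name and filter rule are hypothetical; the call shape follows the updated tests below):

```
template <class K, class M>
struct MetaAtLeastThreshold {  // hypothetical user-defined predicate
  __forceinline__ __device__ bool operator()(const K& key, M& meta,
                                             const K& pattern,
                                             const M& threshold) {
    return meta >= threshold;
  }
};

// Erase matching entries, then export the tuples whose meta qualifies.
size_t erased = table->template erase_if<MetaAtLeastThreshold>(
    pattern, threshold, stream);
table->template export_batch_if<MetaAtLeastThreshold>(
    pattern, threshold, table->capacity(), /*offset=*/0, d_dump_counter,
    d_keys, d_vectors, d_metas, stream);
```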
42 changes: 20 additions & 22 deletions tests/find_or_insert_ptr_test.cc.cu
@@ -39,24 +39,22 @@ using Table = nv::merlin::HashTable<K, V, M>;
using TableOptions = nv::merlin::HashTableOptions;

template <class K, class M>
-__forceinline__ __device__ bool erase_if_pred(const K& key, M& meta,
-const K& pattern,
-const M& threshold) {
-return ((key & 0x7f > pattern) && (meta > threshold));
-}
-
-template <class K, class M>
-__device__ Table::Pred EraseIfPred = erase_if_pred<K, M>;
-
-template <class K, class M>
-__forceinline__ __device__ bool export_if_pred(const K& key, M& meta,
-const K& pattern,
-const M& threshold) {
-return meta > threshold;
-}
+struct EraseIfPredFunctor {
+__forceinline__ __device__ bool operator()(const K& key, M& meta,
+const K& pattern,
+const M& threshold) {
+return ((key & 0x7f > pattern) && (meta > threshold));
+}
+};

template <class K, class M>
-__device__ Table::Pred ExportIfPred = export_if_pred<K, M>;
+struct ExportIfPredFunctor {
+__forceinline__ __device__ bool operator()(const K& key, M& meta,
+const K& pattern,
+const M& threshold) {
+return meta > threshold;
+}
+};

void test_basic(size_t max_hbm_for_vectors) {
constexpr uint64_t INIT_CAPACITY = 64 * 1024 * 1024UL;
@@ -546,8 +544,8 @@ void test_erase_if_pred(size_t max_hbm_for_vectors) {

K pattern = 100;
M threshold = 0;
-size_t erase_num =
-table->erase_if(EraseIfPred<K, M>, pattern, threshold, stream);
+size_t erase_num = table->template erase_if<EraseIfPredFunctor>(
+pattern, threshold, stream);
total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
ASSERT_EQ((erase_num + total_size), BUCKET_MAX_SIZE);
@@ -1212,9 +1210,9 @@ void test_export_batch_if(size_t max_hbm_for_vectors) {
K pattern = 100;
M threshold = h_metas[size_t(KEY_NUM / 2)];

-table->export_batch_if(ExportIfPred<K, M>, pattern, threshold,
-table->capacity(), 0, d_dump_counter, d_keys,
-d_vectors, d_metas, stream);
+table->template export_batch_if<ExportIfPredFunctor>(
+pattern, threshold, table->capacity(), 0, d_dump_counter, d_keys,
+d_vectors, d_metas, stream);

CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_CHECK(cudaMemcpy(&h_dump_counter, d_dump_counter, sizeof(size_t),
@@ -2822,4 +2820,4 @@ TEST(FindOrInsertPtrTest, test_find_or_insert_values_check) {
test_find_or_insert_values_check(16);
// TODO(rhdong): Add back when diff error issue fixed in hybrid mode.
// test_insert_or_assign_values_check(0);
-}
+}
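Note the `table->template erase_if<EraseIfPredFunctor>(...)` spelling in the updated tests: the `template` keyword disambiguates the call when the table's type depends on a template parameter; with a concrete table type it is optional but harmless. A minimal sketch of the distinction (the wrapper function here is hypothetical):

```
// Dependent context: `template` is required so the parser treats erase_if
// as a member template rather than a less-than comparison.
template <class Table, class K, class M>
size_t erase_all(Table* table, const K& pattern, const M& threshold,
                 cudaStream_t stream) {
  return table->template erase_if<EraseIfPredFunctor>(pattern, threshold,
                                                      stream);
}

// Concrete context: the keyword may be dropped.
// size_t n = table->erase_if<EraseIfPredFunctor>(pattern, threshold, stream);
```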
42 changes: 20 additions & 22 deletions tests/find_or_insert_test.cc.cu
@@ -39,24 +39,22 @@ using Table = nv::merlin::HashTable<K, V, M>;
using TableOptions = nv::merlin::HashTableOptions;

template <class K, class M>
-__forceinline__ __device__ bool erase_if_pred(const K& key, M& meta,
-const K& pattern,
-const M& threshold) {
-return ((key & 0x7f > pattern) && (meta > threshold));
-}
-
-template <class K, class M>
-__device__ Table::Pred EraseIfPred = erase_if_pred<K, M>;
-
-template <class K, class M>
-__forceinline__ __device__ bool export_if_pred(const K& key, M& meta,
-const K& pattern,
-const M& threshold) {
-return meta > threshold;
-}
+struct EraseIfPredFunctor {
+__forceinline__ __device__ bool operator()(const K& key, M& meta,
+const K& pattern,
+const M& threshold) {
+return ((key & 0x7f > pattern) && (meta > threshold));
+}
+};

template <class K, class M>
-__device__ Table::Pred ExportIfPred = export_if_pred<K, M>;
+struct ExportIfPredFunctor {
+__forceinline__ __device__ bool operator()(const K& key, M& meta,
+const K& pattern,
+const M& threshold) {
+return meta > threshold;
+}
+};

void test_basic(size_t max_hbm_for_vectors) {
constexpr uint64_t INIT_CAPACITY = 64 * 1024 * 1024UL;
@@ -470,8 +468,8 @@ void test_erase_if_pred(size_t max_hbm_for_vectors) {

K pattern = 100;
M threshold = 0;
-size_t erase_num =
-table->erase_if(EraseIfPred<K, M>, pattern, threshold, stream);
+size_t erase_num = table->template erase_if<EraseIfPredFunctor>(
+pattern, threshold, stream);
total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
ASSERT_EQ((erase_num + total_size), BUCKET_MAX_SIZE);
@@ -1061,9 +1059,9 @@ void test_export_batch_if(size_t max_hbm_for_vectors) {
K pattern = 100;
M threshold = h_metas[size_t(KEY_NUM / 2)];

-table->export_batch_if(ExportIfPred<K, M>, pattern, threshold,
-table->capacity(), 0, d_dump_counter, d_keys,
-d_vectors, d_metas, stream);
+table->template export_batch_if<ExportIfPredFunctor>(
+pattern, threshold, table->capacity(), 0, d_dump_counter, d_keys,
+d_vectors, d_metas, stream);

CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_CHECK(cudaMemcpy(&h_dump_counter, d_dump_counter, sizeof(size_t),
@@ -2499,4 +2497,4 @@ TEST(FindOrInsertTest, test_find_or_insert_values_check) {
test_find_or_insert_values_check(16);
// TODO(rhdong): Add back when diff error issue fixed in hybrid mode.
// test_insert_or_assign_values_check(0);
-}
+}