Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add static_map::insert_if. #118

Merged
merged 8 commits into from
Nov 19, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions include/cuco/detail/static_map.inl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,39 @@ void static_map<Key, Value, Scope, Allocator>::insert(InputIt first,
size_ += h_num_successes;
}

/**
 * @brief Inserts the key/value pairs in `[first, last)` whose corresponding stencil element
 * satisfies `pred`.
 *
 * The pair `*(first + i)` is inserted if `pred(*(stencil + i))` returns true. The call zeroes the
 * device-side success counter, launches `detail::insert_if`, copies the counter back to the host,
 * synchronizes the device, and adds the number of successful insertions to `size_`.
 *
 * @param first Beginning of the sequence of key/value pairs
 * @param last End of the sequence of key/value pairs
 * @param stencil Beginning of the stencil sequence
 * @param pred Unary predicate applied to each stencil element
 * @param hash The unary function to apply to hash each key
 * @param key_equal The binary function used to compare two keys for equality
 */
template <typename Key, typename Value, cuda::thread_scope Scope, typename Allocator>
template <typename InputIt, typename StencilIt, typename Predicate, typename Hash, typename KeyEqual>
void static_map<Key, Value, Scope, Allocator>::insert_if(InputIt first,
                                                         InputIt last,
                                                         StencilIt stencil,
                                                         Predicate pred,
                                                         Hash hash,
                                                         KeyEqual key_equal)
{
  auto const num_keys = std::distance(first, last);
  if (num_keys == 0) { return; }

  // block_size is a template argument of the kernel, so it must be a compile-time constant.
  auto constexpr block_size = 128;
  auto constexpr stride     = 1;
  // detail::insert_if processes one input element per thread (it is not a tiled/cooperative
  // kernel), so the grid only needs ceil(num_keys / (stride * block_size)) blocks. Scaling the
  // grid by a tile size would only launch threads that have no element to insert.
  auto const grid_size = (num_keys + stride * block_size - 1) / (stride * block_size);
  auto view            = get_device_mutable_view();

  // TODO: memset an atomic variable is unsafe
  static_assert(sizeof(std::size_t) == sizeof(atomic_ctr_type));
  CUCO_CUDA_TRY(cudaMemsetAsync(num_successes_, 0, sizeof(atomic_ctr_type)));
  std::size_t h_num_successes{0};

  detail::insert_if<block_size><<<grid_size, block_size>>>(
    first, first + num_keys, num_successes_, view, stencil, pred, hash, key_equal);
  // A kernel launch does not return a status; surface launch-configuration errors here instead
  // of at some later, unrelated API call.
  CUCO_CUDA_TRY(cudaGetLastError());
  CUCO_CUDA_TRY(cudaMemcpyAsync(
    &h_num_successes, num_successes_, sizeof(atomic_ctr_type), cudaMemcpyDeviceToHost));
  // h_num_successes is only valid once the asynchronous copy above has completed.
  CUCO_CUDA_TRY(cudaDeviceSynchronize());

  size_ += h_num_successes;
}

template <typename Key, typename Value, cuda::thread_scope Scope, typename Allocator>
template <typename InputIt, typename OutputIt, typename Hash, typename KeyEqual>
void static_map<Key, Value, Scope, Allocator>::find(
Expand Down
54 changes: 54 additions & 0 deletions include/cuco/detail/static_map_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,60 @@ __global__ void insert(
if (threadIdx.x == 0) { *num_successes += block_num_successes; }
}

/**
 * @brief Inserts key/value pairs in the range `[first, last)` if `pred` of the corresponding
 * stencil element returns true.
 *
 * The key/value pair `*(first + i)` is inserted if `pred(*(stencil + i))` returns true.
 *
 * If multiple keys in `[first, last)` compare equal, it is unspecified which
 * element is inserted.
 *
 * Each thread handles the elements at indices `tid, tid + gridDim.x * block_size, ...`, so the
 * kernel produces the same result for any grid size. The kernel must be launched with exactly
 * `block_size` threads per block, because `block_size` is used both for indexing and for the
 * block-wide reduction.
 *
 * @tparam block_size The number of threads in each thread block
 * @tparam InputIt Device accessible random access iterator whose `value_type` is
 * convertible to the map's `value_type`
 * @tparam atomicT Type of atomic storage
 * @tparam viewT Type of device view allowing access of hash map storage
 * @tparam StencilIt Device accessible random access iterator whose `value_type` is
 * convertible to Predicate's argument type
 * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool`
 * @tparam Hash Unary callable type
 * @tparam KeyEqual Binary callable type
 * @param first Beginning of the sequence of key/value pairs
 * @param last End of the sequence of key/value pairs
 * @param num_successes The number of successfully inserted key/value pairs
 * @param view Mutable device view used to access the hash map's slot storage
 * @param stencil Beginning of the stencil sequence
 * @param pred Predicate to test on every element in the stencil sequence
 * @param hash The unary function to apply to hash each key
 * @param key_equal The binary function used to compare two keys for equality
 */
template <std::size_t block_size,
          typename InputIt,
          typename atomicT,
          typename viewT,
          typename StencilIt,
          typename Predicate,
          typename Hash,
          typename KeyEqual>
__global__ void insert_if(InputIt first,
                          InputIt last,
                          atomicT* num_successes,
                          viewT view,
                          StencilIt stencil,
                          Predicate pred,
                          Hash hash,
                          KeyEqual key_equal)
{
  typedef cub::BlockReduce<std::size_t, block_size> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  std::size_t thread_num_successes = 0;

  // Bound the loop by element index rather than by comparing iterators: a thread whose index is
  // beyond the input never forms `first + idx`, so no iterator past `last` is ever created, no
  // matter how many surplus threads the launch contains.
  auto const num_keys    = static_cast<long long>(last - first);
  auto const grid_stride = static_cast<long long>(gridDim.x * block_size);
  auto idx               = static_cast<long long>(block_size * blockIdx.x + threadIdx.x);

  while (idx < num_keys) {
    // The predicate is evaluated on the stencil element, not on the key/value pair itself.
    if (pred(*(stencil + idx))) {
      typename viewT::value_type const insert_pair{*(first + idx)};
      if (view.insert(insert_pair, hash, key_equal)) { thread_num_successes++; }
    }
    idx += grid_stride;
  }

  // compute number of successfully inserted elements for each block
  // and atomically add to the grand total
  // (every thread of the block reaches the reduction, including threads that had no work)
  std::size_t block_num_successes = BlockReduce(temp_storage).Sum(thread_num_successes);
  if (threadIdx.x == 0) { *num_successes += block_num_successes; }
}

/**
* @brief Finds the values corresponding to all keys in the range `[first, last)`.
*
Expand Down
29 changes: 29 additions & 0 deletions include/cuco/static_map.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,35 @@ class static_map {
typename KeyEqual = thrust::equal_to<key_type>>
void insert(InputIt first, InputIt last, Hash hash = Hash{}, KeyEqual key_equal = KeyEqual{});

/**
* @brief Inserts key/value pairs in the range `[first, last)` if `pred`
* of the corresponding stencil returns true.
*
* The key/value pair `*(first + i)` is inserted if `pred( *(stencil + i) )` returns true.
*
* @tparam InputIt Device accessible random access iterator whose `value_type` is
* convertible to the map's `value_type`
* @tparam StencilIt Device accessible random access iterator whose value_type is
* convertible to Predicate's argument type
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
* argument type is convertible from `std::iterator_traits<StencilIt>::value_type`.
* @tparam Hash Unary callable type
* @tparam KeyEqual Binary callable type
* @param first Beginning of the sequence of key/value pairs
* @param last End of the sequence of key/value pairs
* @param stencil Beginning of the stencil sequence
* @param pred Predicate to test on every element in the range `[stencil, stencil +
* std::distance(first, last))`
* @param hash The unary function to apply to hash each key
* @param key_equal The binary function used to compare two keys for equality
*/
template <typename InputIt,
typename StencilIt,
typename Predicate,
typename Hash = cuco::detail::MurmurHash3_32<key_type>,
typename KeyEqual = thrust::equal_to<key_type>>
void insert_if(
InputIt first, InputIt last, StencilIt stencil, Predicate pred, Hash hash = Hash{}, KeyEqual key_equal = KeyEqual{});

/**
* @brief Finds the values corresponding to all keys in the range `[first, last)`.
*
Expand Down
15 changes: 15 additions & 0 deletions tests/static_map/static_map_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,21 @@ TEST_CASE("User defined key and value type", "")
REQUIRE(all_of(contained.begin(), contained.end(), [] __device__(bool const& b) { return b; }));
}

// Checks insert_if: only the pairs whose stencil value satisfies the predicate may end up in the
// map, and every such pair must be found afterwards.
SECTION("All conditionally inserted keys-value pairs should be contained")
{
thrust::device_vector<bool> contained(num_pairs);
// The stencil is the index sequence 0, 1, 2, ...; the predicate is applied to that index (the
// lambda parameter is named `key`, but it receives the stencil value), so only the pairs at
// even positions are inserted.
map.insert_if(insert_pairs, insert_pairs + num_pairs, thrust::counting_iterator<int>(0),
[] __device__(auto const& key) { return (key % 2) == 0; }, hash_key_pair{}, key_pair_equals{});
// Query every key, including the ones that were filtered out by the predicate.
map.contains(insert_keys.begin(),
insert_keys.end(),
contained.begin(),
hash_key_pair{},
key_pair_equals{});

// contained[idx] must be true exactly when idx is even.
// NOTE(review): presumably insert_keys[i] is the key of insert_pairs[i] and the keys are
// distinct — confirm against the fixture set up earlier in this TEST_CASE.
REQUIRE(thrust::equal(thrust::device, contained.begin(), contained.end(), thrust::counting_iterator<int>(0),
[] __device__(auto const& idx_contained, auto const& idx) { return ((idx % 2) == 0) == idx_contained; }));
}

SECTION("Non-inserted keys-value pairs should not be contained")
{
thrust::device_vector<bool> contained(num_pairs);
Expand Down