From 7405aaca69b6cc06b2114ad8e0ea39e8c185b414 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Mon, 28 Mar 2022 06:03:56 -0700 Subject: [PATCH 01/19] Squashed with initial test set --- cpp/src/io/fst/logical_stack.cuh | 492 +++++++++++++++++++++++++ cpp/tests/CMakeLists.txt | 1 + cpp/tests/io/fst/logical_stack_test.cu | 275 ++++++++++++++ 3 files changed, 768 insertions(+) create mode 100644 cpp/src/io/fst/logical_stack.cuh create mode 100644 cpp/tests/io/fst/logical_stack_test.cu diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh new file mode 100644 index 00000000000..7069ee3b404 --- /dev/null +++ b/cpp/src/io/fst/logical_stack.cuh @@ -0,0 +1,492 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace cudf { +namespace io { +namespace fst { + +/** + * @brief Describes the kind of stack operation. + */ +enum class stack_op_type : int32_t { + READ = 0, ///< Operation reading what is currently on top of the stack + PUSH = 1, ///< Operation pushing a new item on top of the stack + POP = 2 ///< Operation popping the item currently on top of the stack +}; + +namespace detail { + +/** + * @brief A convenience struct that represents a stack opepration as a key-value pair, where the key + * represents the stack's level and the value represents the stack symbol. + * + * @tparam KeyT The key type sufficient to cover all stack levels. Must be signed type as any + * subsequence of stack operations must be able to be covered. E.g., consider the first 10 + * operations are all push and the last 10 operations are all pop operations, we need to be able to + * represent a partial aggregate of the first ten items, which is '+10', just as well as a partial + * aggregate of the last ten items, which is '-10'. + * @tparam ValueT The value type that corresponds to the stack symbols (i.e., covers the stack + * alphabet). + */ +template +struct KeyValueOp { + KeyT key; + ValueT value; +}; + +/** + * @brief Helper class to assist with radix sorting KeyValueOp instances by key. + * + * @tparam BYTE_SIZE The size of the KeyValueOp. + */ +template +struct KeyValueOpToUnsigned { +}; + +template <> +struct KeyValueOpToUnsigned<1U> { + using UnsignedT = uint8_t; +}; + +template <> +struct KeyValueOpToUnsigned<2U> { + using UnsignedT = uint16_t; +}; + +template <> +struct KeyValueOpToUnsigned<4U> { + using UnsignedT = uint32_t; +}; + +template <> +struct KeyValueOpToUnsigned<8U> { + using UnsignedT = uint64_t; +}; + +/** + * @brief Alias template to retrieve an unsigned bit-representation that can be used for radix + * sorting the key of a KeyValueOp. 
+ * + * @tparam KeyValueOpT The KeyValueOp class template instance for which to get an unsigned + * bit-representation + */ +template +using UnsignedKeyValueOpType = typename KeyValueOpToUnsigned::UnsignedT; + +/** + * @brief Function object class template used for converting a stack operation to a key-value store + * operation, where the key corresponds to the stack level being accessed. + * + * @tparam KeyValueOpT + * @tparam StackSymbolToStackOpTypeT + */ +template +struct StackSymbolToKVOp { + template + __host__ __device__ __forceinline__ KeyValueOpT operator()(StackSymbolT const& stack_symbol) const + { + stack_op_type stack_op = symbol_to_stack_op_type(stack_symbol); + // PUSH => +1, POP => -1, READ => 0 + int32_t level_delta = stack_op == stack_op_type::PUSH ? 1 + : stack_op == stack_op_type::POP ? -1 + : 0; + return KeyValueOpT{static_cast(level_delta), stack_symbol}; + } + + /// Function object returning a stack operation type for a given stack symbol + StackSymbolToStackOpTypeT symbol_to_stack_op_type; +}; + +/** + * @brief Binary reduction operator to compute the absolute stack level from relative stack levels + * (i.e., +1 for a PUSH, -1 for a POP operation). + */ +struct AddStackLevelFromKVOp { + template + __host__ __device__ __forceinline__ KeyValueOp operator()( + KeyValueOp const& lhs, KeyValueOp const& rhs) const + { + KeyT new_level = lhs.key + rhs.key; + return KeyValueOp{new_level, rhs.value}; + } +}; + +/** + * @brief Binary reduction operator that propagates a write operation for a specific key to all + * reads of that same key. That is, if the key of LHS compares equal to the key of the RHS and if + * the RHS is a read and the LHS is a write operation type, then we return LHS, otherwise we return + * the RHS. + */ +template +struct PopulatePopWithPush { + template + __host__ __device__ __forceinline__ KeyValueOp operator()( + KeyValueOp const& lhs, KeyValueOp const& rhs) const + { + // If RHS is a read, then we need to figure out whether we can propagate the value from the LHS + bool is_rhs_read = symbol_to_stack_op_type(rhs.value) != stack_op_type::PUSH; + + // Whether LHS is a matching write (i.e., the push operation that is on top of the stack for the + // RHS's read) + bool is_lhs_matching_write = + (lhs.key == rhs.key) && symbol_to_stack_op_type(lhs.value) == stack_op_type::PUSH; + + return (is_rhs_read && is_lhs_matching_write) ? lhs : rhs; + } + + /// Function object returning a stack operation type for a given stack symbol + StackSymbolToStackOpTypeT symbol_to_stack_op_type; +}; + +/** + * @brief Binary reduction operator that is used to replace each read_symbol occurance with the last + * non-read_symbol that precedes such read_symbol. + */ +template +struct PropagateLastWrite { + __host__ __device__ __forceinline__ StackSymbolT operator()(StackSymbolT const& lhs, + StackSymbolT const& rhs) const + { + // If RHS is a yet-to-be-propagated, then we need to check whether we can use the LHS to fill + bool is_rhs_read = (rhs == read_symbol); + + // We propagate the write from the LHS if it's a write + bool is_lhs_write = (lhs != read_symbol); + + return (is_rhs_read && is_lhs_write) ? lhs : rhs; + } + + /// The read_symbol that is supposed to be replaced + StackSymbolT read_symbol; +}; + +/** + * @brief Helper function object class to convert a KeyValueOp to the stack symbol of that + * KeyValueOp. 
+ */ +struct KVOpToStackSymbol { + template + __host__ __device__ __forceinline__ ValueT operator()(KeyValueOp const& kv_op) const + { + return kv_op.value; + } +}; + +/** + * @brief Replaces all operations that apply to stack level '0' with the empty stack symbol + */ +template +struct RemapEmptyStack { + __host__ __device__ __forceinline__ KeyValueOpT operator()(KeyValueOpT const& kv_op) const + { + return kv_op.key == 0 ? empty_stack_symbol : kv_op; + } + KeyValueOpT empty_stack_symbol; +}; + +} // namespace detail + +/** + * @brief Takes a sparse representation of a sequence of stack operations that either push something + * onto the stack or pop something from the stack and resolves the symbol that is on top of the + * stack. + * + * @tparam StackLevelT Signed integer type that must be sufficient to cover [-max_stack_level, + * max_stack_level] for the given sequence of stack operations. Must be signed as it needs to cover + * the stack level of any arbitrary subsequence of stack operations. + * @tparam StackSymbolItT An input iterator type that provides the sequence of symbols that + * represent stack operations + * @tparam SymbolPositionT The index that this stack operation is supposed to apply to + * @tparam StackSymbolToStackOpT Function object class to transform items from StackSymbolItT to + * stack_op_type + * @tparam TopOfStackOutItT Output iterator type to which StackSymbolT are being assigned + * @tparam StackSymbolT The internal type being used (usually corresponding to StackSymbolItT's + * value_type) + * @tparam OffsetT Signed or unsigned integer type large enough to index into both the sparse input + * sequence and the top-of-stack output sequence + * @param[in] d_symbols Sequence of symbols that represent stack operations. Memory may alias with + * \p d_top_of_stack + * @param[in,out] d_symbol_positions Sequence of symbol positions (for a sparse representation), + * sequence must be ordered in ascending order. Note, the memory of this array is repurposed for + * double-buffering. 
+ * @param[in] symbol_to_stack_op Function object that returns a stack operation type (push, pop, or + * read) for a given symbol from \p d_symbols + * @param[out] d_top_of_stack A random access output iterator that will be populated with + * what-is-on-top-of-the-stack for the given sequence of stack operations \p d_symbols + * @param[in] empty_stack_symbol The symbol that will be written to top_of_stack whenever the stack + * was empty + * @param[in] read_symbol A symbol that may not be confused for a symbol that would push to the + * stack + * @param[in] num_symbols_in The number of symbols in the sparse representation + * @param[in] num_symbols_out The number of symbols that are supposed to be filled with + * what-is-on-top-of-the-stack + * @param[in] stream The cuda stream to which to dispatch the work + */ +template +void SparseStackOpToTopOfStack(void* d_temp_storage, + size_t& temp_storage_bytes, + StackSymbolItT d_symbols, + SymbolPositionT* d_symbol_positions, + StackSymbolToStackOpT symbol_to_stack_op, + TopOfStackOutItT d_top_of_stack, + StackSymbolT empty_stack_symbol, + StackSymbolT read_symbol, + OffsetT num_symbols_in, + OffsetT num_symbols_out, + cudaStream_t stream = nullptr) +{ + // Type used to hold key-value pairs (key being the stack level and the value being the stack + // symbol) + using KeyValueOpT = detail::KeyValueOp; + + // The unsigned integer type that we use for radix sorting items of type KeyValueOpT + using KVOpUnsignedT = detail::UnsignedKeyValueOpType; + + // Transforming sequence of stack symbols to key-value store operations, where the key corresponds + // to the stack level of a given stack operation and the value corresponds to the stack symbol of + // that operation + using StackSymbolToKVOpT = detail::StackSymbolToKVOp; + + // TransformInputIterator converting stack symbols to key-value store operations + using TransformInputItT = + cub::TransformInputIterator; + + // Converting a stack symbol that may either push or pop to a key-value store operation: + // stack_symbol -> ([+1,0,-1], stack_symbol) + StackSymbolToKVOpT stack_sym_to_kv_op{symbol_to_stack_op}; + TransformInputItT stack_symbols_in(d_symbols, stack_sym_to_kv_op); + + // Double-buffer for sorting along the given sequence of symbol positions (the sparse + // representation) + cub::DoubleBuffer d_symbol_positions_db{nullptr, nullptr}; + + // Double-buffer for sorting the key-value store operations + cub::DoubleBuffer d_kv_operations{nullptr, nullptr}; + + // A double-buffer that aliases memory from d_kv_operations but offset by one item (to discard the + // exclusive scans first item) + cub::DoubleBuffer d_kv_operations_offset{nullptr, nullptr}; + + // A double-buffer that aliases memory from d_kv_operations_offset with unsigned types in order to + // be able to perform a radix sort + cub::DoubleBuffer d_kv_operations_unsigned{nullptr, nullptr}; + + constexpr std::size_t bits_per_byte = 8; + constexpr std::size_t begin_bit = offsetof(KeyValueOpT, key) * bits_per_byte; + constexpr std::size_t end_bit = begin_bit + (sizeof(KeyValueOpT::key) * bits_per_byte); + + // The key-value store operation that makes sure that reads for stack level '0' will be populated + // with the empty_stack_symbol + KeyValueOpT const empty_stack{0, empty_stack_symbol}; + + cub::TransformInputIterator, KeyValueOpT*> + kv_ops_scan_in(nullptr, detail::RemapEmptyStack{empty_stack}); + KeyValueOpT* kv_ops_scan_out = nullptr; + + //------------------------------------------------------------------------------ + // MEMORY 
REQUIREMENTS + //------------------------------------------------------------------------------ + enum mem_alloc_id { + temp_storage = 0, + symbol_position_alt, + kv_ops_current, + kv_ops_alt, + num_allocations + }; + + void* allocations[mem_alloc_id::num_allocations] = {nullptr}; + std::size_t allocation_sizes[mem_alloc_id::num_allocations] = {0}; + + std::size_t stack_level_scan_bytes = 0; + std::size_t stack_level_sort_bytes = 0; + std::size_t match_level_scan_bytes = 0; + std::size_t propagate_writes_scan_bytes = 0; + + // Getting temporary storage requirements for the prefix sum of the stack level after each + // operation + CUDA_TRY(cub::DeviceScan::InclusiveScan(nullptr, + stack_level_scan_bytes, + stack_symbols_in, + d_kv_operations_offset.Current(), + detail::AddStackLevelFromKVOp{}, + num_symbols_in, + stream)); + + // Getting temporary storage requirements for the stable radix sort (sorting by stack level of the + // operations) + CUDA_TRY(cub::DeviceRadixSort::SortPairs(nullptr, + stack_level_sort_bytes, + d_kv_operations_unsigned, + d_symbol_positions_db, + num_symbols_in, + begin_bit, + end_bit, + stream)); + + // Getting temporary storage requirements for the scan to match pop operations with the latest + // push of the same level + CUDA_TRY(cub::DeviceScan::InclusiveScan( + nullptr, + match_level_scan_bytes, + kv_ops_scan_in, + kv_ops_scan_out, + detail::PopulatePopWithPush{symbol_to_stack_op}, + num_symbols_in, + stream)); + + // Getting temporary storage requirements for the scan to propagate top-of-stack for spots that + // didn't push or pop + CUDA_TRY(cub::DeviceScan::ExclusiveScan(nullptr, + propagate_writes_scan_bytes, + d_top_of_stack, + d_top_of_stack, + detail::PropagateLastWrite{read_symbol}, + empty_stack_symbol, + num_symbols_out, + stream)); + + // Scratch memory required by the algorithms + allocation_sizes[mem_alloc_id::temp_storage] = std::max({stack_level_scan_bytes, + stack_level_sort_bytes, + match_level_scan_bytes, + propagate_writes_scan_bytes}); + + // Memory requirements by auxiliary buffers + constexpr std::size_t extra_overlap_bytes = 2U; + allocation_sizes[mem_alloc_id::symbol_position_alt] = num_symbols_in * sizeof(SymbolPositionT); + allocation_sizes[mem_alloc_id::kv_ops_current] = + (num_symbols_in + extra_overlap_bytes) * sizeof(KeyValueOpT); + allocation_sizes[mem_alloc_id::kv_ops_alt] = + (num_symbols_in + extra_overlap_bytes) * sizeof(KeyValueOpT); + + // Try to alias into the user-provided temporary storage memory blob + CUDA_TRY(cub::AliasTemporaries( + d_temp_storage, temp_storage_bytes, allocations, allocation_sizes)); + + // If this call was just to retrieve auxiliary memory requirements or not sufficient memory was + // provided + if (!d_temp_storage) { return; } + + //------------------------------------------------------------------------------ + // ALGORITHM + //------------------------------------------------------------------------------ + // Amount of temp storage available to CUB algorithms + std::size_t cub_temp_storage_bytes = allocation_sizes[mem_alloc_id::temp_storage]; + + // Temp storage for CUB algorithms + void* d_cub_temp_storage = allocations[mem_alloc_id::temp_storage]; + + // Initialize double-buffer for sorting the indexes of the sequence of sparse stack operations + d_symbol_positions_db = cub::DoubleBuffer{ + d_symbol_positions, + reinterpret_cast(allocations[mem_alloc_id::symbol_position_alt])}; + + // Initialize double-buffer for sorting the indexes of the sequence of sparse stack operations + d_kv_operations = 
cub::DoubleBuffer{ + reinterpret_cast(allocations[mem_alloc_id::kv_ops_current]), + reinterpret_cast(allocations[mem_alloc_id::kv_ops_alt])}; + + d_kv_operations_offset = + cub::DoubleBuffer{d_kv_operations.Current(), d_kv_operations.Alternate()}; + + // Compute prefix sum of the stack level after each operation + CUDA_TRY(cub::DeviceScan::InclusiveScan(d_cub_temp_storage, + cub_temp_storage_bytes, + stack_symbols_in, + d_kv_operations_offset.Current(), + detail::AddStackLevelFromKVOp{}, + num_symbols_in, + stream)); + + // Stable radix sort, sorting by stack level of the operations + d_kv_operations_unsigned = cub::DoubleBuffer{ + reinterpret_cast(d_kv_operations_offset.Current()), + reinterpret_cast(d_kv_operations_offset.Alternate())}; + CUDA_TRY(cub::DeviceRadixSort::SortPairs(d_cub_temp_storage, + cub_temp_storage_bytes, + d_kv_operations_unsigned, + d_symbol_positions_db, + num_symbols_in, + begin_bit, + end_bit, + stream)); + + // TransformInputIterator that remaps all operations on stack level 0 to the empty stack symbol + kv_ops_scan_in = {reinterpret_cast(d_kv_operations_unsigned.Current()), + detail::RemapEmptyStack{empty_stack}}; + kv_ops_scan_out = reinterpret_cast(d_kv_operations_unsigned.Alternate()); + + // Exclusive scan to match pop operations with the latest push operation of that level + CUDA_TRY(cub::DeviceScan::InclusiveScan( + d_cub_temp_storage, + cub_temp_storage_bytes, + kv_ops_scan_in, + kv_ops_scan_out, + detail::PopulatePopWithPush{symbol_to_stack_op}, + num_symbols_in, + stream)); + + // Fill the output tape with read-symbol + thrust::fill(thrust::cuda::par.on(stream), + thrust::device_ptr{d_top_of_stack}, + thrust::device_ptr{d_top_of_stack + num_symbols_out}, + read_symbol); + + // Transform the key-value operations to the stack symbol they represent + cub::TransformInputIterator + kv_op_to_stack_sym_it(kv_ops_scan_out, detail::KVOpToStackSymbol{}); + + // Scatter the stack symbols to the output tape (spots that are not scattered to have been + // pre-filled with the read-symbol) + thrust::scatter(thrust::cuda::par.on(stream), + kv_op_to_stack_sym_it, + kv_op_to_stack_sym_it + num_symbols_in, + d_symbol_positions_db.Current(), + d_top_of_stack); + + // We perform an exclusive scan in order to fill the items at the very left that may + // be reading the empty stack before there's the first push occurance in the sequence. + // Also, we're interested in the top-of-the-stack symbol before the operation was applied. 
+ CUDA_TRY(cub::DeviceScan::ExclusiveScan(d_cub_temp_storage, + cub_temp_storage_bytes, + d_top_of_stack, + d_top_of_stack, + detail::PropagateLastWrite{read_symbol}, + empty_stack_symbol, + num_symbols_out, + stream)); +} + +} // namespace fst +} // namespace io +} // namespace cudf diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 816c5a1c59c..f140413157a 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -224,6 +224,7 @@ ConfigureTest(PARQUET_TEST io/parquet_test.cpp) ConfigureTest(JSON_TEST io/json_test.cpp) ConfigureTest(ARROW_IO_SOURCE_TEST io/arrow_io_source_test.cpp) ConfigureTest(MULTIBYTE_SPLIT_TEST io/text/multibyte_split_test.cpp) +ConfigureTest(LOGICAL_STACK_TEST io/fst/logical_stack_test.cu) if(CUDF_ENABLE_ARROW_S3) target_compile_definitions(ARROW_IO_SOURCE_TEST PRIVATE "S3_ENABLED") endif() diff --git a/cpp/tests/io/fst/logical_stack_test.cu b/cpp/tests/io/fst/logical_stack_test.cu new file mode 100644 index 00000000000..d2144226457 --- /dev/null +++ b/cpp/tests/io/fst/logical_stack_test.cu @@ -0,0 +1,275 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +namespace { +namespace fst = cudf::io::fst; + +/** + * @brief Generates the sparse representation of stack operations to feed into the logical + * stack + * + * @param begin Forward input iterator to the first item of symbols that are checked for whether + * they push or pop + * @param end Forward input iterator to one one past the last item of symbols that are checked for + * whether they push or pop + * @param to_stack_op A function object that takes an instance of InputItT's value type and + * returns the kind of stack operation such item represents (i.e., of type stack_op_type) + * @param stack_symbol_out Forward output iterator to which symbols that either push or pop are + * assigned + * @param stack_op_index_out Forward output iterator to which the indexes of symbols that either + * push or pop are assigned + * @return Pair of iterators to one past the last item of the items written to \p stack_symbol_out + * and \p stack_op_index_out, respectively + */ +template +std::pair to_sparse_stack_symbols( + InputItT begin, + InputItT end, + ToStackOpTypeT to_stack_op, + StackSymbolOutItT stack_symbol_out, + StackOpIndexOutItT stack_op_index_out) +{ + std::size_t index = 0; + for (auto it = begin; it < end; it++) { + fst::stack_op_type op_type = to_stack_op(*it); + if (op_type == fst::stack_op_type::PUSH || op_type == fst::stack_op_type::POP) { + *stack_symbol_out = *it; + *stack_op_index_out = index; + stack_symbol_out++; + stack_op_index_out++; + } + index++; + } + return std::make_pair(stack_symbol_out, stack_op_index_out); +} + +/** + * @brief Reads in a sequence of items that represent stack operations, applies these operations to + * a stack, and, for every oepration being read in, outputs what was 
the symbol on top of the stack + * before the operations was applied. In case the stack is empty before any operation, + * \p empty_stack will be output instead. + * + * @tparam InputItT Forward input iterator type to items representing stack operations + * @tparam ToStackOpTypeT A transform function object class that maps an item representing a stack + * oepration to the stack_op_type of such item + * @tparam StackSymbolT Type representing items being pushed onto the stack + * @tparam TopOfStackOutItT A forward output iterator type being assigned items of StackSymbolT + * @param[in] begin Forward iterator to the beginning of the items representing stack operations + * @param[in] end Iterator to one past the last item representing the stack operation + * @param[in] to_stack_op A function object that takes an instance of InputItT's value type and + * returns the kind of stack operation such item represents (i.e., of type stack_op_type) + * @param[in] empty_stack A symbol that will be written to top_of_stack whenever the stack was empty + * @param[out] top_of_stack The output iterator to which the item will be written to + * @return TopOfStackOutItT Iterators to one past the last element that was written + */ +template +TopOfStackOutItT to_top_of_stack(InputItT begin, + InputItT end, + ToStackOpTypeT to_stack_op, + StackSymbolT empty_stack, + TopOfStackOutItT top_of_stack) +{ + std::stack stack; + for (auto it = begin; it < end; it++) { + // Write what is currently on top of the stack when reading in the current symbol + *top_of_stack = stack.empty() ? empty_stack : stack.top(); + top_of_stack++; + + auto const& current = *it; + fst::stack_op_type op_type = to_stack_op(current); + + // Check whether this symbol corresponds to a push or pop operation and modify the stack + // accordingly + if (op_type == fst::stack_op_type::PUSH) { + stack.push(current); + } else if (op_type == fst::stack_op_type::POP) { + stack.pop(); + } + } + return top_of_stack; +} + +/** + * @brief Funciton object used to filter for brackets and braces that represent push and pop + * operations + * + */ +struct JSONToStackOp { + template + __host__ __device__ __forceinline__ fst::stack_op_type operator()( + StackSymbolT const& stack_symbol) const + { + return (stack_symbol == '{' || stack_symbol == '[') ? fst::stack_op_type::PUSH + : (stack_symbol == '}' || stack_symbol == ']') ? 
fst::stack_op_type::POP + : fst::stack_op_type::READ; + } +}; +} // namespace + + +// Base test fixture for tests +struct LogicalStackTest : public cudf::test::BaseFixture { +}; + +TEST_F(LogicalStackTest, GroundTruth) +{ + // Type sufficient to cover any stack level (must be a signed type) + using StackLevelT = int8_t; + using SymbolT = char; + using SymbolOffsetT = uint32_t; + + // The stack symbol that we'll fill everywhere where there's nothing on the stack + constexpr SymbolT empty_stack_symbol = '_'; + + // This just has to be a stack symbol that may not be confused with a symbol that would push or + // pop + constexpr SymbolT read_symbol = 'x'; + + // Prepare cuda stream for data transfers & kernels + cudaStream_t stream = nullptr; + cudaStreamCreate(&stream); + rmm::cuda_stream_view stream_view(stream); + + // Test input, + std::string input = R"( { +"category": "reference", +"index:" [4,12,42], +"author": "Nigel Rees", +"title": "Sayings of the Century", +"price": 8.95 +} +{ +"category": "reference", +"index:" [4,{},null,{"a":[]}], +"author": "Nigel Rees", +"title": "Sayings of the Century", +"price": 8.95 +} )"; + + // Repeat input sample 1024x + for (std::size_t i = 0; i < 10; i++) + input += input; + + // Getting the symbols that actually modify the stack (i.e., symbols that push or pop) + std::string stack_symbols = ""; + std::vector stack_op_indexes; + stack_op_indexes.reserve(input.size()); + + // Get the sparse representation of stack operations + to_sparse_stack_symbols(std::cbegin(input), + std::cend(input), + JSONToStackOp{}, + std::back_inserter(stack_symbols), + std::back_inserter(stack_op_indexes)); + + // Prepare sparse stack ops + std::size_t num_stack_ops = stack_symbols.size(); + + rmm::device_uvector d_stack_ops(stack_symbols.size(), stream_view); + rmm::device_uvector d_stack_op_indexes(stack_op_indexes.size(), stream_view); + auto top_of_stack_gpu = hostdevice_vector(input.size(), stream_view); + + cudaMemcpyAsync(d_stack_ops.data(), + stack_symbols.data(), + stack_symbols.size() * sizeof(SymbolT), + cudaMemcpyHostToDevice, + stream); + + cudaMemcpyAsync(d_stack_op_indexes.data(), + stack_op_indexes.data(), + stack_op_indexes.size() * sizeof(SymbolOffsetT), + cudaMemcpyHostToDevice, + stream); + + // Prepare output + std::size_t string_size = input.size(); + SymbolT* d_top_of_stack = nullptr; + cudaMalloc(&d_top_of_stack, string_size + 1); + + // Request temporary storage requirements + std::size_t temp_storage_bytes = 0; + fst::SparseStackOpToTopOfStack(nullptr, + temp_storage_bytes, + d_stack_ops.data(), + d_stack_op_indexes.data(), + JSONToStackOp{}, + d_top_of_stack, + empty_stack_symbol, + read_symbol, + num_stack_ops, + string_size, + stream); + + // Allocate temporary storage required by the get-top-of-the-stack algorithm + rmm::device_buffer d_temp_storage(temp_storage_bytes, stream_view); + + // Run algorithm + fst::SparseStackOpToTopOfStack(d_temp_storage.data(), + temp_storage_bytes, + d_stack_ops.data(), + d_stack_op_indexes.data(), + JSONToStackOp{}, + top_of_stack_gpu.device_ptr(), + empty_stack_symbol, + read_symbol, + num_stack_ops, + string_size, + stream); + + // Async copy results from device to host + top_of_stack_gpu.device_to_host(stream_view); + + // Get CPU-side results for verification + std::string top_of_stack_cpu{}; + top_of_stack_cpu.reserve(input.size()); + to_top_of_stack(std::cbegin(input), + std::cend(input), + JSONToStackOp{}, + empty_stack_symbol, + std::back_inserter(top_of_stack_cpu)); + + // Make sure results have been 
copied back to host + cudaStreamSynchronize(stream); + + // Verify results + ASSERT_EQ(input.size(), top_of_stack_cpu.size()); + for (size_t i = 0; i < input.size() && i < top_of_stack_cpu.size(); i++) { + ASSERT_EQ(top_of_stack_gpu.host_ptr()[i], top_of_stack_cpu[i]) << "Mismatch at index #" << i; + } +} + +CUDF_TEST_PROGRAM_MAIN() From 2fc22d03206a68e5c4817508590619b934e41639 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Tue, 29 Mar 2022 07:31:44 -0700 Subject: [PATCH 02/19] style fix & additional test scenario --- cpp/tests/io/fst/logical_stack_test.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/tests/io/fst/logical_stack_test.cu b/cpp/tests/io/fst/logical_stack_test.cu index d2144226457..389ac73e533 100644 --- a/cpp/tests/io/fst/logical_stack_test.cu +++ b/cpp/tests/io/fst/logical_stack_test.cu @@ -140,7 +140,6 @@ struct JSONToStackOp { }; } // namespace - // Base test fixture for tests struct LogicalStackTest : public cudf::test::BaseFixture { }; @@ -178,9 +177,9 @@ TEST_F(LogicalStackTest, GroundTruth) "author": "Nigel Rees", "title": "Sayings of the Century", "price": 8.95 -} )"; +} {} [] [ ])"; - // Repeat input sample 1024x + // Repeat input sample 1024x for (std::size_t i = 0; i < 10; i++) input += input; From 34df49241515cb837c4c728507317cf0ea77f6c3 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Tue, 29 Mar 2022 11:25:20 -0700 Subject: [PATCH 03/19] removed forceinline --- cpp/src/io/fst/logical_stack.cuh | 18 +++++++++--------- cpp/tests/io/fst/logical_stack_test.cu | 3 +-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index 7069ee3b404..b68d53d742d 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -107,7 +107,7 @@ using UnsignedKeyValueOpType = typename KeyValueOpToUnsigned struct StackSymbolToKVOp { template - __host__ __device__ __forceinline__ KeyValueOpT operator()(StackSymbolT const& stack_symbol) const + __host__ __device__ KeyValueOpT operator()(StackSymbolT const& stack_symbol) const { stack_op_type stack_op = symbol_to_stack_op_type(stack_symbol); // PUSH => +1, POP => -1, READ => 0 @@ -127,8 +127,8 @@ struct StackSymbolToKVOp { */ struct AddStackLevelFromKVOp { template - __host__ __device__ __forceinline__ KeyValueOp operator()( - KeyValueOp const& lhs, KeyValueOp const& rhs) const + __host__ __device__ KeyValueOp operator()(KeyValueOp const& lhs, + KeyValueOp const& rhs) const { KeyT new_level = lhs.key + rhs.key; return KeyValueOp{new_level, rhs.value}; @@ -144,8 +144,8 @@ struct AddStackLevelFromKVOp { template struct PopulatePopWithPush { template - __host__ __device__ __forceinline__ KeyValueOp operator()( - KeyValueOp const& lhs, KeyValueOp const& rhs) const + __host__ __device__ KeyValueOp operator()(KeyValueOp const& lhs, + KeyValueOp const& rhs) const { // If RHS is a read, then we need to figure out whether we can propagate the value from the LHS bool is_rhs_read = symbol_to_stack_op_type(rhs.value) != stack_op_type::PUSH; @@ -168,8 +168,8 @@ struct PopulatePopWithPush { */ template struct PropagateLastWrite { - __host__ __device__ __forceinline__ StackSymbolT operator()(StackSymbolT const& lhs, - StackSymbolT const& rhs) const + __host__ __device__ StackSymbolT operator()(StackSymbolT const& lhs, + StackSymbolT const& rhs) const { // If RHS is a yet-to-be-propagated, then we need to check whether we can use 
the LHS to fill bool is_rhs_read = (rhs == read_symbol); @@ -190,7 +190,7 @@ struct PropagateLastWrite { */ struct KVOpToStackSymbol { template - __host__ __device__ __forceinline__ ValueT operator()(KeyValueOp const& kv_op) const + __host__ __device__ ValueT operator()(KeyValueOp const& kv_op) const { return kv_op.value; } @@ -201,7 +201,7 @@ struct KVOpToStackSymbol { */ template struct RemapEmptyStack { - __host__ __device__ __forceinline__ KeyValueOpT operator()(KeyValueOpT const& kv_op) const + __host__ __device__ KeyValueOpT operator()(KeyValueOpT const& kv_op) const { return kv_op.key == 0 ? empty_stack_symbol : kv_op; } diff --git a/cpp/tests/io/fst/logical_stack_test.cu b/cpp/tests/io/fst/logical_stack_test.cu index 389ac73e533..9eb1a90c5fe 100644 --- a/cpp/tests/io/fst/logical_stack_test.cu +++ b/cpp/tests/io/fst/logical_stack_test.cu @@ -130,8 +130,7 @@ TopOfStackOutItT to_top_of_stack(InputItT begin, */ struct JSONToStackOp { template - __host__ __device__ __forceinline__ fst::stack_op_type operator()( - StackSymbolT const& stack_symbol) const + __host__ __device__ fst::stack_op_type operator()(StackSymbolT const& stack_symbol) const { return (stack_symbol == '{' || stack_symbol == '[') ? fst::stack_op_type::PUSH : (stack_symbol == '}' || stack_symbol == ']') ? fst::stack_op_type::POP From 4b5b91728577a6dda01a33690433ec3e438a470c Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Wed, 30 Mar 2022 03:23:48 -0700 Subject: [PATCH 04/19] tagging host device function --- cpp/src/io/fst/logical_stack.cuh | 13 +++++++------ cpp/tests/io/fst/logical_stack_test.cu | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index b68d53d742d..f3595504245 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -24,6 +24,7 @@ #include #include +#include namespace cudf { namespace io { @@ -107,7 +108,7 @@ using UnsignedKeyValueOpType = typename KeyValueOpToUnsigned struct StackSymbolToKVOp { template - __host__ __device__ KeyValueOpT operator()(StackSymbolT const& stack_symbol) const + constexpr CUDF_HOST_DEVICE KeyValueOpT operator()(StackSymbolT const& stack_symbol) const { stack_op_type stack_op = symbol_to_stack_op_type(stack_symbol); // PUSH => +1, POP => -1, READ => 0 @@ -127,7 +128,7 @@ struct StackSymbolToKVOp { */ struct AddStackLevelFromKVOp { template - __host__ __device__ KeyValueOp operator()(KeyValueOp const& lhs, + constexpr CUDF_HOST_DEVICE KeyValueOp operator()(KeyValueOp const& lhs, KeyValueOp const& rhs) const { KeyT new_level = lhs.key + rhs.key; @@ -144,7 +145,7 @@ struct AddStackLevelFromKVOp { template struct PopulatePopWithPush { template - __host__ __device__ KeyValueOp operator()(KeyValueOp const& lhs, + constexpr CUDF_HOST_DEVICE KeyValueOp operator()(KeyValueOp const& lhs, KeyValueOp const& rhs) const { // If RHS is a read, then we need to figure out whether we can propagate the value from the LHS @@ -168,7 +169,7 @@ struct PopulatePopWithPush { */ template struct PropagateLastWrite { - __host__ __device__ StackSymbolT operator()(StackSymbolT const& lhs, + constexpr CUDF_HOST_DEVICE StackSymbolT operator()(StackSymbolT const& lhs, StackSymbolT const& rhs) const { // If RHS is a yet-to-be-propagated, then we need to check whether we can use the LHS to fill @@ -190,7 +191,7 @@ struct PropagateLastWrite { */ struct KVOpToStackSymbol { template - __host__ __device__ ValueT operator()(KeyValueOp const& kv_op) const + 
constexpr CUDF_HOST_DEVICE ValueT operator()(KeyValueOp const& kv_op) const { return kv_op.value; } @@ -201,7 +202,7 @@ struct KVOpToStackSymbol { */ template struct RemapEmptyStack { - __host__ __device__ KeyValueOpT operator()(KeyValueOpT const& kv_op) const + constexpr CUDF_HOST_DEVICE KeyValueOpT operator()(KeyValueOpT const& kv_op) const { return kv_op.key == 0 ? empty_stack_symbol : kv_op; } diff --git a/cpp/tests/io/fst/logical_stack_test.cu b/cpp/tests/io/fst/logical_stack_test.cu index 9eb1a90c5fe..7d4564d3204 100644 --- a/cpp/tests/io/fst/logical_stack_test.cu +++ b/cpp/tests/io/fst/logical_stack_test.cu @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -130,7 +131,7 @@ TopOfStackOutItT to_top_of_stack(InputItT begin, */ struct JSONToStackOp { template - __host__ __device__ fst::stack_op_type operator()(StackSymbolT const& stack_symbol) const + constexpr CUDF_HOST_DEVICE fst::stack_op_type operator()(StackSymbolT const& stack_symbol) const { return (stack_symbol == '{' || stack_symbol == '[') ? fst::stack_op_type::PUSH : (stack_symbol == '}' || stack_symbol == ']') ? fst::stack_op_type::POP From ac1e48ca823f41ce5a0f4292c21509b90166c07c Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Thu, 31 Mar 2022 04:11:44 -0700 Subject: [PATCH 05/19] Added utility to debug print & instrumented code to use it --- cpp/include/cudf_test/print_utilities.cuh | 129 ++++++++++++++++++++++ cpp/src/io/fst/logical_stack.cuh | 103 +++++++++++++---- 2 files changed, 213 insertions(+), 19 deletions(-) create mode 100644 cpp/include/cudf_test/print_utilities.cuh diff --git a/cpp/include/cudf_test/print_utilities.cuh b/cpp/include/cudf_test/print_utilities.cuh new file mode 100644 index 00000000000..04a8d8c9bea --- /dev/null +++ b/cpp/include/cudf_test/print_utilities.cuh @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "cub/util_type.cuh" +#include +#include +#include + +namespace cudf { +namespace test { +namespace print { + +constexpr int32_t hex_tag = 0; + +template +struct TaggedType { + T v; +}; + +template +using hex_t = TaggedType; + +template +struct ToTaggedType { + template + CUDF_HOST_DEVICE WrappedTypeT operator()(T const& v) const + { + return WrappedTypeT{v}; + } +}; + +template +auto hex(InItT it) +{ + using value_t = typename std::iterator_traits::value_type; + using tagged_t = hex_t; + return cub::TransformInputIterator, InItT>( + it, ToTaggedType{}); +} + +template && std::is_signed_v>* = nullptr> +CUDF_HOST_DEVICE void print_value(int32_t width, T arg) +{ + printf("%*d", width, arg); +} + +template && std::is_unsigned_v>* = nullptr> +CUDF_HOST_DEVICE void print_value(int32_t width, T arg) +{ + printf("%*d", width, arg); +} + +CUDF_HOST_DEVICE void print_value(int32_t width, char arg) { printf("%*c", width, arg); } + +template +CUDF_HOST_DEVICE void print_value(int32_t width, hex_t arg) +{ + printf("%*X", width, arg.v); +} + +namespace detail +{ +template +CUDF_HOST_DEVICE void print_line(int32_t width, char delimiter, T arg) +{ + print_value(width, arg); +} + +template +CUDF_HOST_DEVICE void print_line(int32_t width, char delimiter, T arg, Ts... args) +{ + print_value(width, arg); + if (delimiter) printf("%c", delimiter); + print_line(width, delimiter, args...); +} + +template +__global__ void print_array_kernel(std::size_t count, int32_t width, char delimiter, Ts... args) +{ + if (threadIdx.x == 0 && blockIdx.x == 0) { + for (std::size_t i = 0; i < count; i++) { + printf("%6lu: ", i); + print_line(width, delimiter, args[i]...); + printf("\n"); + } + } +} +} + +/** + * @brief Prints \p count elements from each of the given device-accessible iterators. + * + * @param count The number of items to print from each device-accessible iterator + * @param stream The cuda stream to which the printing kernel shall be dispatched + * @param args List of iterators to be printed + */ +template +void print_array(std::size_t count, cudaStream_t stream, Ts... 
args) +{ + // The width to pad printed numbers to + constexpr int32_t width = 6; + + // Delimiter used for separating values from subsequent iterators + constexpr char delimiter = ','; + + // TODO we want this to compile to nothing dependnig on compiler flag, rather than runtime + if (std::getenv("CUDA_DBG_DUMP") != nullptr) { + detail::print_array_kernel<<<1, 1, 0, stream>>>(count, width, delimiter, args...); + } +} + +} // namespace print +} // namespace test +} // namespace cudf diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index f3595504245..b725f8fed3f 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -23,8 +23,9 @@ #include #include -#include #include +#include +#include namespace cudf { namespace io { @@ -128,8 +129,8 @@ struct StackSymbolToKVOp { */ struct AddStackLevelFromKVOp { template - constexpr CUDF_HOST_DEVICE KeyValueOp operator()(KeyValueOp const& lhs, - KeyValueOp const& rhs) const + constexpr CUDF_HOST_DEVICE KeyValueOp operator()( + KeyValueOp const& lhs, KeyValueOp const& rhs) const { KeyT new_level = lhs.key + rhs.key; return KeyValueOp{new_level, rhs.value}; @@ -145,8 +146,8 @@ struct AddStackLevelFromKVOp { template struct PopulatePopWithPush { template - constexpr CUDF_HOST_DEVICE KeyValueOp operator()(KeyValueOp const& lhs, - KeyValueOp const& rhs) const + constexpr CUDF_HOST_DEVICE KeyValueOp operator()( + KeyValueOp const& lhs, KeyValueOp const& rhs) const { // If RHS is a read, then we need to figure out whether we can propagate the value from the LHS bool is_rhs_read = symbol_to_stack_op_type(rhs.value) != stack_op_type::PUSH; @@ -170,7 +171,7 @@ struct PopulatePopWithPush { template struct PropagateLastWrite { constexpr CUDF_HOST_DEVICE StackSymbolT operator()(StackSymbolT const& lhs, - StackSymbolT const& rhs) const + StackSymbolT const& rhs) const { // If RHS is a yet-to-be-propagated, then we need to check whether we can use the LHS to fill bool is_rhs_read = (rhs == read_symbol); @@ -209,6 +210,46 @@ struct RemapEmptyStack { KeyValueOpT empty_stack_symbol; }; +/** + * @brief Function object to return only the key part from a KeyValueOp instance. + */ +struct KVOpToKey { + template + constexpr CUDF_HOST_DEVICE KeyT operator()(KeyValueOp const& kv_op) const + { + return kv_op.key; + } +}; + +/** + * @brief Function object to return only the value part from a KeyValueOp instance. + */ +struct KVOpToValue { + template + constexpr CUDF_HOST_DEVICE ValueT operator()(KeyValueOp const& kv_op) const + { + return kv_op.value; + } +}; + +/** + * @brief Retrieves an iterator that returns only the `key` part from a KeyValueOp iterator. + */ +template +auto get_key_it(KeyValueOpItT it) +{ + return thrust::make_transform_iterator(it, KVOpToKey{}); +} + +/** + * @brief Retrieves an iterator that returns only the `value` part from a KeyValueOp iterator. 
+ */ +template +auto get_value_it(KeyValueOpItT it) +{ + return thrust::make_transform_iterator(it, KVOpToValue{}); +} + } // namespace detail /** @@ -294,11 +335,7 @@ void SparseStackOpToTopOfStack(void* d_temp_storage, // Double-buffer for sorting the key-value store operations cub::DoubleBuffer d_kv_operations{nullptr, nullptr}; - // A double-buffer that aliases memory from d_kv_operations but offset by one item (to discard the - // exclusive scans first item) - cub::DoubleBuffer d_kv_operations_offset{nullptr, nullptr}; - - // A double-buffer that aliases memory from d_kv_operations_offset with unsigned types in order to + // A double-buffer that aliases memory from d_kv_operations with unsigned types in order to // be able to perform a radix sort cub::DoubleBuffer d_kv_operations_unsigned{nullptr, nullptr}; @@ -338,7 +375,7 @@ void SparseStackOpToTopOfStack(void* d_temp_storage, CUDA_TRY(cub::DeviceScan::InclusiveScan(nullptr, stack_level_scan_bytes, stack_symbols_in, - d_kv_operations_offset.Current(), + d_kv_operations.Current(), detail::AddStackLevelFromKVOp{}, num_symbols_in, stream)); @@ -417,22 +454,27 @@ void SparseStackOpToTopOfStack(void* d_temp_storage, reinterpret_cast(allocations[mem_alloc_id::kv_ops_current]), reinterpret_cast(allocations[mem_alloc_id::kv_ops_alt])}; - d_kv_operations_offset = - cub::DoubleBuffer{d_kv_operations.Current(), d_kv_operations.Alternate()}; - // Compute prefix sum of the stack level after each operation CUDA_TRY(cub::DeviceScan::InclusiveScan(d_cub_temp_storage, cub_temp_storage_bytes, stack_symbols_in, - d_kv_operations_offset.Current(), + d_kv_operations.Current(), detail::AddStackLevelFromKVOp{}, num_symbols_in, stream)); + // Dump info on stack operations: (stack level change + symbol) -> (absolute stack level + symbol) + test::print::print_array(num_symbols_in, + stream, + get_key_it(stack_symbols_in), + get_value_it(stack_symbols_in), + get_key_it(d_kv_operations.Current()), + get_value_it(d_kv_operations.Current())); + // Stable radix sort, sorting by stack level of the operations - d_kv_operations_unsigned = cub::DoubleBuffer{ - reinterpret_cast(d_kv_operations_offset.Current()), - reinterpret_cast(d_kv_operations_offset.Alternate())}; + d_kv_operations_unsigned = + cub::DoubleBuffer{reinterpret_cast(d_kv_operations.Current()), + reinterpret_cast(d_kv_operations.Alternate())}; CUDA_TRY(cub::DeviceRadixSort::SortPairs(d_cub_temp_storage, cub_temp_storage_bytes, d_kv_operations_unsigned, @@ -447,6 +489,11 @@ void SparseStackOpToTopOfStack(void* d_temp_storage, detail::RemapEmptyStack{empty_stack}}; kv_ops_scan_out = reinterpret_cast(d_kv_operations_unsigned.Alternate()); + // Dump info on stack operations sorted by their stack level (i.e. stack level after applying + // operation) + test::print::print_array( + num_symbols_in, stream, get_key_it(kv_ops_scan_in), get_value_it(kv_ops_scan_in)); + // Exclusive scan to match pop operations with the latest push operation of that level CUDA_TRY(cub::DeviceScan::InclusiveScan( d_cub_temp_storage, @@ -457,6 +504,15 @@ void SparseStackOpToTopOfStack(void* d_temp_storage, num_symbols_in, stream)); + // Dump info on stack operations sorted by their stack level (i.e. 
stack level after applying + // operation) + test::print::print_array(num_symbols_in, + stream, + get_key_it(kv_ops_scan_in), + get_value_it(kv_ops_scan_in), + get_key_it(kv_ops_scan_out), + get_value_it(kv_ops_scan_out)); + // Fill the output tape with read-symbol thrust::fill(thrust::cuda::par.on(stream), thrust::device_ptr{d_top_of_stack}, @@ -475,6 +531,11 @@ void SparseStackOpToTopOfStack(void* d_temp_storage, d_symbol_positions_db.Current(), d_top_of_stack); + // Dump the output tape that has many yet-to-be-filled spots (i.e., all spots that were not given + // in the sparse representation) + test::print::print_array( + std::min(num_symbols_in, static_cast(10000)), stream, d_top_of_stack); + // We perform an exclusive scan in order to fill the items at the very left that may // be reading the empty stack before there's the first push occurance in the sequence. // Also, we're interested in the top-of-the-stack symbol before the operation was applied. @@ -486,6 +547,10 @@ void SparseStackOpToTopOfStack(void* d_temp_storage, empty_stack_symbol, num_symbols_out, stream)); + + // Dump the final output + test::print::print_array( + std::min(num_symbols_in, static_cast(10000)), stream, d_top_of_stack); } } // namespace fst From 053bb31fc4e3d9b933d03ddf2d725c9c674f75ff Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Thu, 31 Mar 2022 05:28:17 -0700 Subject: [PATCH 06/19] switched to using rmm also inside algorithm --- cpp/src/io/fst/logical_stack.cuh | 85 +++++++++----------------- cpp/tests/io/fst/logical_stack_test.cu | 22 ++----- 2 files changed, 33 insertions(+), 74 deletions(-) diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index b725f8fed3f..9550798aeaf 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -27,6 +27,9 @@ #include #include +#include +#include + namespace cudf { namespace io { namespace fst { @@ -295,8 +298,7 @@ template -void SparseStackOpToTopOfStack(void* d_temp_storage, - size_t& temp_storage_bytes, +void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, StackSymbolItT d_symbols, SymbolPositionT* d_symbol_positions, StackSymbolToStackOpT symbol_to_stack_op, @@ -351,20 +353,6 @@ void SparseStackOpToTopOfStack(void* d_temp_storage, kv_ops_scan_in(nullptr, detail::RemapEmptyStack{empty_stack}); KeyValueOpT* kv_ops_scan_out = nullptr; - //------------------------------------------------------------------------------ - // MEMORY REQUIREMENTS - //------------------------------------------------------------------------------ - enum mem_alloc_id { - temp_storage = 0, - symbol_position_alt, - kv_ops_current, - kv_ops_alt, - num_allocations - }; - - void* allocations[mem_alloc_id::num_allocations] = {nullptr}; - std::size_t allocation_sizes[mem_alloc_id::num_allocations] = {0}; - std::size_t stack_level_scan_bytes = 0; std::size_t stack_level_sort_bytes = 0; std::size_t match_level_scan_bytes = 0; @@ -414,49 +402,34 @@ void SparseStackOpToTopOfStack(void* d_temp_storage, stream)); // Scratch memory required by the algorithms - allocation_sizes[mem_alloc_id::temp_storage] = std::max({stack_level_scan_bytes, - stack_level_sort_bytes, - match_level_scan_bytes, - propagate_writes_scan_bytes}); - - // Memory requirements by auxiliary buffers - constexpr std::size_t extra_overlap_bytes = 2U; - allocation_sizes[mem_alloc_id::symbol_position_alt] = num_symbols_in * sizeof(SymbolPositionT); - allocation_sizes[mem_alloc_id::kv_ops_current] = - (num_symbols_in + 
extra_overlap_bytes) * sizeof(KeyValueOpT); - allocation_sizes[mem_alloc_id::kv_ops_alt] = - (num_symbols_in + extra_overlap_bytes) * sizeof(KeyValueOpT); - - // Try to alias into the user-provided temporary storage memory blob - CUDA_TRY(cub::AliasTemporaries( - d_temp_storage, temp_storage_bytes, allocations, allocation_sizes)); - - // If this call was just to retrieve auxiliary memory requirements or not sufficient memory was - // provided - if (!d_temp_storage) { return; } + auto total_temp_storage_bytes = std::max({stack_level_scan_bytes, + stack_level_sort_bytes, + match_level_scan_bytes, + propagate_writes_scan_bytes}); + + if (temp_storage.size() < total_temp_storage_bytes) { + temp_storage.resize(total_temp_storage_bytes, stream); + } + // Actual device buffer size, as we need to pass in an lvalue-ref to cub algorithms as temp_storage_bytes + total_temp_storage_bytes = temp_storage.size(); + + rmm::device_uvector d_symbol_position_alt{num_symbols_in, stream}; + rmm::device_uvector d_kv_ops_current{num_symbols_in, stream}; + rmm::device_uvector d_kv_ops_alt{num_symbols_in, stream}; //------------------------------------------------------------------------------ // ALGORITHM //------------------------------------------------------------------------------ - // Amount of temp storage available to CUB algorithms - std::size_t cub_temp_storage_bytes = allocation_sizes[mem_alloc_id::temp_storage]; - - // Temp storage for CUB algorithms - void* d_cub_temp_storage = allocations[mem_alloc_id::temp_storage]; - // Initialize double-buffer for sorting the indexes of the sequence of sparse stack operations - d_symbol_positions_db = cub::DoubleBuffer{ - d_symbol_positions, - reinterpret_cast(allocations[mem_alloc_id::symbol_position_alt])}; + d_symbol_positions_db = + cub::DoubleBuffer{d_symbol_positions, d_symbol_position_alt.data()}; // Initialize double-buffer for sorting the indexes of the sequence of sparse stack operations - d_kv_operations = cub::DoubleBuffer{ - reinterpret_cast(allocations[mem_alloc_id::kv_ops_current]), - reinterpret_cast(allocations[mem_alloc_id::kv_ops_alt])}; + d_kv_operations = cub::DoubleBuffer{d_kv_ops_current.data(), d_kv_ops_alt.data()}; // Compute prefix sum of the stack level after each operation - CUDA_TRY(cub::DeviceScan::InclusiveScan(d_cub_temp_storage, - cub_temp_storage_bytes, + CUDA_TRY(cub::DeviceScan::InclusiveScan(temp_storage.data(), + total_temp_storage_bytes, stack_symbols_in, d_kv_operations.Current(), detail::AddStackLevelFromKVOp{}, @@ -475,8 +448,8 @@ void SparseStackOpToTopOfStack(void* d_temp_storage, d_kv_operations_unsigned = cub::DoubleBuffer{reinterpret_cast(d_kv_operations.Current()), reinterpret_cast(d_kv_operations.Alternate())}; - CUDA_TRY(cub::DeviceRadixSort::SortPairs(d_cub_temp_storage, - cub_temp_storage_bytes, + CUDA_TRY(cub::DeviceRadixSort::SortPairs(temp_storage.data(), + total_temp_storage_bytes, d_kv_operations_unsigned, d_symbol_positions_db, num_symbols_in, @@ -496,8 +469,8 @@ void SparseStackOpToTopOfStack(void* d_temp_storage, // Exclusive scan to match pop operations with the latest push operation of that level CUDA_TRY(cub::DeviceScan::InclusiveScan( - d_cub_temp_storage, - cub_temp_storage_bytes, + temp_storage.data(), + total_temp_storage_bytes, kv_ops_scan_in, kv_ops_scan_out, detail::PopulatePopWithPush{symbol_to_stack_op}, @@ -539,8 +512,8 @@ void SparseStackOpToTopOfStack(void* d_temp_storage, // We perform an exclusive scan in order to fill the items at the very left that may // be reading the empty stack 
before there's the first push occurance in the sequence. // Also, we're interested in the top-of-the-stack symbol before the operation was applied. - CUDA_TRY(cub::DeviceScan::ExclusiveScan(d_cub_temp_storage, - cub_temp_storage_bytes, + CUDA_TRY(cub::DeviceScan::ExclusiveScan(temp_storage.data(), + total_temp_storage_bytes, d_top_of_stack, d_top_of_stack, detail::PropagateLastWrite{read_symbol}, diff --git a/cpp/tests/io/fst/logical_stack_test.cu b/cpp/tests/io/fst/logical_stack_test.cu index 7d4564d3204..3b860867cf2 100644 --- a/cpp/tests/io/fst/logical_stack_test.cu +++ b/cpp/tests/io/fst/logical_stack_test.cu @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -217,28 +218,13 @@ TEST_F(LogicalStackTest, GroundTruth) // Prepare output std::size_t string_size = input.size(); SymbolT* d_top_of_stack = nullptr; - cudaMalloc(&d_top_of_stack, string_size + 1); - - // Request temporary storage requirements - std::size_t temp_storage_bytes = 0; - fst::SparseStackOpToTopOfStack(nullptr, - temp_storage_bytes, - d_stack_ops.data(), - d_stack_op_indexes.data(), - JSONToStackOp{}, - d_top_of_stack, - empty_stack_symbol, - read_symbol, - num_stack_ops, - string_size, - stream); + cudaMalloc(&d_top_of_stack, string_size); // Allocate temporary storage required by the get-top-of-the-stack algorithm - rmm::device_buffer d_temp_storage(temp_storage_bytes, stream_view); + rmm::device_buffer d_temp_storage{}; // Run algorithm - fst::SparseStackOpToTopOfStack(d_temp_storage.data(), - temp_storage_bytes, + fst::SparseStackOpToTopOfStack(d_temp_storage, d_stack_ops.data(), d_stack_op_indexes.data(), JSONToStackOp{}, From 6bbcd32c1f00b69e5c990f3ad8bc6d5ff634de16 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Thu, 31 Mar 2022 09:34:41 -0700 Subject: [PATCH 07/19] header include order & SFINAE macro --- cpp/include/cudf_test/print_utilities.cuh | 18 ++++++++++-------- cpp/src/io/fst/logical_stack.cuh | 20 +++++++++++--------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/cpp/include/cudf_test/print_utilities.cuh b/cpp/include/cudf_test/print_utilities.cuh index 04a8d8c9bea..6e6fff84cec 100644 --- a/cpp/include/cudf_test/print_utilities.cuh +++ b/cpp/include/cudf_test/print_utilities.cuh @@ -16,9 +16,12 @@ #pragma once +#include +#include + #include "cub/util_type.cuh" #include -#include + #include namespace cudf { @@ -53,13 +56,13 @@ auto hex(InItT it) it, ToTaggedType{}); } -template && std::is_signed_v>* = nullptr> +template && std::is_signed_v)> CUDF_HOST_DEVICE void print_value(int32_t width, T arg) { printf("%*d", width, arg); } -template && std::is_unsigned_v>* = nullptr> +template && std::is_unsigned_v)> CUDF_HOST_DEVICE void print_value(int32_t width, T arg) { printf("%*d", width, arg); @@ -73,8 +76,7 @@ CUDF_HOST_DEVICE void print_value(int32_t width, hex_t arg) printf("%*X", width, arg.v); } -namespace detail -{ +namespace detail { template CUDF_HOST_DEVICE void print_line(int32_t width, char delimiter, T arg) { @@ -100,11 +102,11 @@ __global__ void print_array_kernel(std::size_t count, int32_t width, char delimi } } } -} +} // namespace detail /** * @brief Prints \p count elements from each of the given device-accessible iterators. 
- * + * * @param count The number of items to print from each device-accessible iterator * @param stream The cuda stream to which the printing kernel shall be dispatched * @param args List of iterators to be printed @@ -113,7 +115,7 @@ template void print_array(std::size_t count, cudaStream_t stream, Ts... args) { // The width to pad printed numbers to - constexpr int32_t width = 6; + constexpr int32_t width = 6; // Delimiter used for separating values from subsequent iterators constexpr char delimiter = ','; diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index 9550798aeaf..ced1a712d6a 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -15,14 +15,6 @@ */ #pragma once -#include -#include -#include -#include -#include -#include -#include - #include #include #include @@ -30,6 +22,16 @@ #include #include +#include +#include +#include +#include + +#include + +#include +#include + namespace cudf { namespace io { namespace fst { @@ -307,7 +309,7 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, StackSymbolT read_symbol, OffsetT num_symbols_in, OffsetT num_symbols_out, - cudaStream_t stream = nullptr) + rmm::cuda_stream_view stream = rmm::cuda_stream_default) { // Type used to hold key-value pairs (key being the stack level and the value being the stack // symbol) From 9af2138bd0e3d7c16fcd2426bdcd2b457b8429b5 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Sun, 3 Apr 2022 23:46:00 -0700 Subject: [PATCH 08/19] debug print cleanups --- cpp/include/cudf_test/print_utilities.cuh | 31 ++++++++++++++++------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/cpp/include/cudf_test/print_utilities.cuh b/cpp/include/cudf_test/print_utilities.cuh index 6e6fff84cec..1da7b9836b1 100644 --- a/cpp/include/cudf_test/print_utilities.cuh +++ b/cpp/include/cudf_test/print_utilities.cuh @@ -38,15 +38,28 @@ struct TaggedType { template using hex_t = TaggedType; -template +/** + * @brief Function object to transform a built-in type to a tagged type (e.g., in order to print + * values from an iterator returning uint32_t as hex values) + * + * @tparam TaggedTypeT A TaggedType template specialisation + */ +template struct ToTaggedType { template - CUDF_HOST_DEVICE WrappedTypeT operator()(T const& v) const + CUDF_HOST_DEVICE TaggedTypeT operator()(T const& v) const { - return WrappedTypeT{v}; + return TaggedTypeT{v}; } }; +/** + * @brief Returns an iterator that causes the values from \p it to be printed as hex values. + * + * @tparam InItT A random-access input iterator type + * @param it A random-access input iterator t + * @return + */ template auto hex(InItT it) { @@ -56,13 +69,13 @@ auto hex(InItT it) it, ToTaggedType{}); } -template && std::is_signed_v)> +template && std::is_signed_v)> CUDF_HOST_DEVICE void print_value(int32_t width, T arg) { printf("%*d", width, arg); } -template && std::is_unsigned_v)> +template && std::is_unsigned_v)> CUDF_HOST_DEVICE void print_value(int32_t width, T arg) { printf("%*d", width, arg); @@ -78,17 +91,17 @@ CUDF_HOST_DEVICE void print_value(int32_t width, hex_t arg) namespace detail { template -CUDF_HOST_DEVICE void print_line(int32_t width, char delimiter, T arg) +CUDF_HOST_DEVICE void print_values(int32_t width, char delimiter, T arg) { print_value(width, arg); } template -CUDF_HOST_DEVICE void print_line(int32_t width, char delimiter, T arg, Ts... args) +CUDF_HOST_DEVICE void print_values(int32_t width, char delimiter, T arg, Ts... 
args) { print_value(width, arg); if (delimiter) printf("%c", delimiter); - print_line(width, delimiter, args...); + print_values(width, delimiter, args...); } template @@ -97,7 +110,7 @@ __global__ void print_array_kernel(std::size_t count, int32_t width, char delimi if (threadIdx.x == 0 && blockIdx.x == 0) { for (std::size_t i = 0; i < count; i++) { printf("%6lu: ", i); - print_line(width, delimiter, args[i]...); + print_values(width, delimiter, args[i]...); printf("\n"); } } From 1df7dcf4d9f67ccc45804ac14ea1209a5dd8b5c0 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Mon, 4 Apr 2022 02:28:30 -0700 Subject: [PATCH 09/19] renaming key-value store op to stack_op --- cpp/src/io/fst/logical_stack.cuh | 216 +++++++++++++++---------------- 1 file changed, 101 insertions(+), 115 deletions(-) diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index ced1a712d6a..bce362beff9 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -48,10 +48,10 @@ enum class stack_op_type : int32_t { namespace detail { /** - * @brief A convenience struct that represents a stack opepration as a key-value pair, where the key + * @brief A convenience struct that represents a stack opepration as a pair, where the stack_level * represents the stack's level and the value represents the stack symbol. * - * @tparam KeyT The key type sufficient to cover all stack levels. Must be signed type as any + * @tparam StackLevelT The stack level type sufficient to cover all stack levels. Must be signed type as any * subsequence of stack operations must be able to be covered. E.g., consider the first 10 * operations are all push and the last 10 operations are all pop operations, we need to be able to * represent a partial aggregate of the first ten items, which is '+10', just as well as a partial @@ -59,69 +59,69 @@ namespace detail { * @tparam ValueT The value type that corresponds to the stack symbols (i.e., covers the stack * alphabet). */ -template -struct KeyValueOp { - KeyT key; +template +struct StackOp { + StackLevelT stack_level; ValueT value; }; /** - * @brief Helper class to assist with radix sorting KeyValueOp instances by key. + * @brief Helper class to assist with radix sorting StackOp instances by stack level. * - * @tparam BYTE_SIZE The size of the KeyValueOp. + * @tparam BYTE_SIZE The size of the StackOp. */ template -struct KeyValueOpToUnsigned { +struct StackOpToUnsigned { }; template <> -struct KeyValueOpToUnsigned<1U> { +struct StackOpToUnsigned<1U> { using UnsignedT = uint8_t; }; template <> -struct KeyValueOpToUnsigned<2U> { +struct StackOpToUnsigned<2U> { using UnsignedT = uint16_t; }; template <> -struct KeyValueOpToUnsigned<4U> { +struct StackOpToUnsigned<4U> { using UnsignedT = uint32_t; }; template <> -struct KeyValueOpToUnsigned<8U> { +struct StackOpToUnsigned<8U> { using UnsignedT = uint64_t; }; /** * @brief Alias template to retrieve an unsigned bit-representation that can be used for radix - * sorting the key of a KeyValueOp. + * sorting the stack level of a StackOp. 
* - * @tparam KeyValueOpT The KeyValueOp class template instance for which to get an unsigned + * @tparam StackOpT The StackOp class template instance for which to get an unsigned * bit-representation */ -template <typename KeyValueOpT> -using UnsignedKeyValueOpType = typename KeyValueOpToUnsigned<sizeof(KeyValueOpT)>::UnsignedT; +template <typename StackOpT> +using UnsignedStackOpType = typename StackOpToUnsigned<sizeof(StackOpT)>::UnsignedT; /** - * @brief Function object class template used for converting a stack operation to a key-value store - * operation, where the key corresponds to the stack level being accessed. + * @brief Function object class template used for converting a stack symbol to a stack + * operation that has a stack level to which an operation applies. * - * @tparam KeyValueOpT + * @tparam StackOpT * @tparam StackSymbolToStackOpTypeT */ -template <typename KeyValueOpT, typename StackSymbolToStackOpTypeT> -struct StackSymbolToKVOp { +template <typename StackOpT, typename StackSymbolToStackOpTypeT> +struct StackSymbolToStackOp { template <typename StackSymbolT> - constexpr CUDF_HOST_DEVICE KeyValueOpT operator()(StackSymbolT const& stack_symbol) const + constexpr CUDF_HOST_DEVICE StackOpT operator()(StackSymbolT const& stack_symbol) const { stack_op_type stack_op = symbol_to_stack_op_type(stack_symbol); // PUSH => +1, POP => -1, READ => 0 int32_t level_delta = stack_op == stack_op_type::PUSH ? 1 : stack_op == stack_op_type::POP ? -1 : 0; - return KeyValueOpT{static_cast(level_delta), stack_symbol}; + return StackOpT{static_cast(level_delta), stack_symbol}; } /// Function object returning a stack operation type for a given stack symbol StackSymbolToStackOpTypeT symbol_to_stack_op_type; }; /** * @brief Binary reduction operator to compute the absolute stack level from relative stack levels * (i.e., +1 for a PUSH, -1 for a POP operation). */ -struct AddStackLevelFromKVOp { - template <typename KeyT, typename ValueT> - constexpr CUDF_HOST_DEVICE KeyValueOp<KeyT, ValueT> operator()( - KeyValueOp<KeyT, ValueT> const& lhs, KeyValueOp<KeyT, ValueT> const& rhs) const +struct AddStackLevelFromStackOp { + template <typename StackLevelT, typename ValueT> + constexpr CUDF_HOST_DEVICE StackOp<StackLevelT, ValueT> operator()( + StackOp<StackLevelT, ValueT> const& lhs, StackOp<StackLevelT, ValueT> const& rhs) const { - KeyT new_level = lhs.key + rhs.key; - return KeyValueOp<KeyT, ValueT>{new_level, rhs.value}; + StackLevelT new_level = lhs.stack_level + rhs.stack_level; + return StackOp<StackLevelT, ValueT>{new_level, rhs.value}; } }; /** - * @brief Binary reduction operator that propagates a write operation for a specific key to all - * reads of that same key. That is, if the key of LHS compares equal to the key of the RHS and if + * @brief Binary reduction operator that propagates a write operation for a specific stack level to all + * reads of that same stack level. That is, if the stack level of LHS compares equal to the stack level of the RHS and if * the RHS is a read and the LHS is a write operation type, then we return LHS, otherwise we return * the RHS. */ template <typename StackSymbolToStackOpTypeT> struct PopulatePopWithPush { - template <typename KeyT, typename ValueT> - constexpr CUDF_HOST_DEVICE KeyValueOp<KeyT, ValueT> operator()( - KeyValueOp<KeyT, ValueT> const& lhs, KeyValueOp<KeyT, ValueT> const& rhs) const + template <typename StackLevelT, typename ValueT> + constexpr CUDF_HOST_DEVICE StackOp<StackLevelT, ValueT> operator()( + StackOp<StackLevelT, ValueT> const& lhs, StackOp<StackLevelT, ValueT> const& rhs) const { // If RHS is a read, then we need to figure out whether we can propagate the value from the LHS bool is_rhs_read = symbol_to_stack_op_type(rhs.value) != stack_op_type::PUSH; @@ -160,7 +160,7 @@ struct PopulatePopWithPush { // Whether LHS is a matching write (i.e., the push operation that is on top of the stack for the // RHS's read) bool is_lhs_matching_write = - (lhs.key == rhs.key) && symbol_to_stack_op_type(lhs.value) == stack_op_type::PUSH; + (lhs.stack_level == rhs.stack_level) && symbol_to_stack_op_type(lhs.value) == stack_op_type::PUSH; return (is_rhs_read && is_lhs_matching_write) ?
lhs : rhs; } @@ -192,12 +192,12 @@ struct PropagateLastWrite { }; /** - * @brief Helper function object class to convert a KeyValueOp to the stack symbol of that - * KeyValueOp. + * @brief Helper function object class to convert a StackOp to the stack symbol of that + * StackOp. */ -struct KVOpToStackSymbol { - template - constexpr CUDF_HOST_DEVICE ValueT operator()(KeyValueOp const& kv_op) const +struct StackOpToStackSymbol { + template + constexpr CUDF_HOST_DEVICE ValueT operator()(StackOp const& kv_op) const { return kv_op.value; } @@ -206,53 +206,42 @@ struct KVOpToStackSymbol { /** * @brief Replaces all operations that apply to stack level '0' with the empty stack symbol */ -template +template struct RemapEmptyStack { - constexpr CUDF_HOST_DEVICE KeyValueOpT operator()(KeyValueOpT const& kv_op) const + constexpr CUDF_HOST_DEVICE StackOpT operator()(StackOpT const& kv_op) const { - return kv_op.key == 0 ? empty_stack_symbol : kv_op; + return kv_op.stack_level == 0 ? empty_stack_symbol : kv_op; } - KeyValueOpT empty_stack_symbol; + StackOpT empty_stack_symbol; }; /** - * @brief Function object to return only the key part from a KeyValueOp instance. + * @brief Function object to return only the stack_level part from a StackOp instance. */ -struct KVOpToKey { - template - constexpr CUDF_HOST_DEVICE KeyT operator()(KeyValueOp const& kv_op) const +struct StackOpToStackLevel { + template + constexpr CUDF_HOST_DEVICE StackLevelT operator()(StackOp const& kv_op) const { - return kv_op.key; + return kv_op.stack_level; } }; /** - * @brief Function object to return only the value part from a KeyValueOp instance. + * @brief Retrieves an iterator that returns only the `stack_level` part from a StackOp iterator. */ -struct KVOpToValue { - template - constexpr CUDF_HOST_DEVICE ValueT operator()(KeyValueOp const& kv_op) const - { - return kv_op.value; - } -}; - -/** - * @brief Retrieves an iterator that returns only the `key` part from a KeyValueOp iterator. - */ -template -auto get_key_it(KeyValueOpItT it) +template +auto get_stack_level_it(StackOpItT it) { - return thrust::make_transform_iterator(it, KVOpToKey{}); + return thrust::make_transform_iterator(it, StackOpToStackLevel{}); } /** - * @brief Retrieves an iterator that returns only the `value` part from a KeyValueOp iterator. + * @brief Retrieves an iterator that returns only the `value` part from a StackOp iterator. 
*/ -template -auto get_value_it(KeyValueOpItT it) +template +auto get_value_it(StackOpItT it) { - return thrust::make_transform_iterator(it, KVOpToValue{}); + return thrust::make_transform_iterator(it, StackOpToStackSymbol{}); } } // namespace detail @@ -268,7 +257,7 @@ auto get_value_it(KeyValueOpItT it) * @tparam StackSymbolItT An input iterator type that provides the sequence of symbols that * represent stack operations * @tparam SymbolPositionT The index that this stack operation is supposed to apply to - * @tparam StackSymbolToStackOpT Function object class to transform items from StackSymbolItT to + * @tparam StackSymbolToStackOpTypeT Function object class to transform items from StackSymbolItT to * stack_op_type * @tparam TopOfStackOutItT Output iterator type to which StackSymbolT are being assigned * @tparam StackSymbolT The internal type being used (usually corresponding to StackSymbolItT's @@ -296,14 +285,14 @@ auto get_value_it(KeyValueOpItT it) template void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, StackSymbolItT d_symbols, SymbolPositionT* d_symbol_positions, - StackSymbolToStackOpT symbol_to_stack_op, + StackSymbolToStackOpTypeT symbol_to_stack_op, TopOfStackOutItT d_top_of_stack, StackSymbolT empty_stack_symbol, StackSymbolT read_symbol, @@ -311,49 +300,46 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, OffsetT num_symbols_out, rmm::cuda_stream_view stream = rmm::cuda_stream_default) { - // Type used to hold key-value pairs (key being the stack level and the value being the stack - // symbol) - using KeyValueOpT = detail::KeyValueOp; + // Type used to hold pairs of (stack_level, value) pairs + using StackOpT = detail::StackOp; - // The unsigned integer type that we use for radix sorting items of type KeyValueOpT - using KVOpUnsignedT = detail::UnsignedKeyValueOpType; + // The unsigned integer type that we use for radix sorting items of type StackOpT + using StackOpUnsignedT = detail::UnsignedStackOpType; - // Transforming sequence of stack symbols to key-value store operations, where the key corresponds - // to the stack level of a given stack operation and the value corresponds to the stack symbol of - // that operation - using StackSymbolToKVOpT = detail::StackSymbolToKVOp; + // Transforming sequence of stack symbols to stack operations + using StackSymbolToStackOpT = detail::StackSymbolToStackOp; - // TransformInputIterator converting stack symbols to key-value store operations + // TransformInputIterator converting stack symbols to stack operations using TransformInputItT = - cub::TransformInputIterator; + cub::TransformInputIterator; - // Converting a stack symbol that may either push or pop to a key-value store operation: + // Converting a stack symbol that may either push or pop to a stack operation: // stack_symbol -> ([+1,0,-1], stack_symbol) - StackSymbolToKVOpT stack_sym_to_kv_op{symbol_to_stack_op}; + StackSymbolToStackOpT stack_sym_to_kv_op{symbol_to_stack_op}; TransformInputItT stack_symbols_in(d_symbols, stack_sym_to_kv_op); // Double-buffer for sorting along the given sequence of symbol positions (the sparse // representation) cub::DoubleBuffer d_symbol_positions_db{nullptr, nullptr}; - // Double-buffer for sorting the key-value store operations - cub::DoubleBuffer d_kv_operations{nullptr, nullptr}; + // Double-buffer for sorting the stack operations by the stack level to which such operation applies + cub::DoubleBuffer d_kv_operations{nullptr, nullptr}; // A double-buffer that aliases memory from d_kv_operations with 
unsigned types in order to // be able to perform a radix sort - cub::DoubleBuffer d_kv_operations_unsigned{nullptr, nullptr}; + cub::DoubleBuffer d_kv_operations_unsigned{nullptr, nullptr}; constexpr std::size_t bits_per_byte = 8; - constexpr std::size_t begin_bit = offsetof(KeyValueOpT, key) * bits_per_byte; - constexpr std::size_t end_bit = begin_bit + (sizeof(KeyValueOpT::key) * bits_per_byte); + constexpr std::size_t begin_bit = offsetof(StackOpT, stack_level) * bits_per_byte; + constexpr std::size_t end_bit = begin_bit + (sizeof(StackOpT::stack_level) * bits_per_byte); - // The key-value store operation that makes sure that reads for stack level '0' will be populated + // The stack operation that makes sure that reads for stack level '0' will be populated // with the empty_stack_symbol - KeyValueOpT const empty_stack{0, empty_stack_symbol}; + StackOpT const empty_stack{0, empty_stack_symbol}; - cub::TransformInputIterator, KeyValueOpT*> - kv_ops_scan_in(nullptr, detail::RemapEmptyStack{empty_stack}); - KeyValueOpT* kv_ops_scan_out = nullptr; + cub::TransformInputIterator, StackOpT*> + kv_ops_scan_in(nullptr, detail::RemapEmptyStack{empty_stack}); + StackOpT* kv_ops_scan_out = nullptr; std::size_t stack_level_scan_bytes = 0; std::size_t stack_level_sort_bytes = 0; @@ -366,7 +352,7 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, stack_level_scan_bytes, stack_symbols_in, d_kv_operations.Current(), - detail::AddStackLevelFromKVOp{}, + detail::AddStackLevelFromStackOp{}, num_symbols_in, stream)); @@ -388,7 +374,7 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, match_level_scan_bytes, kv_ops_scan_in, kv_ops_scan_out, - detail::PopulatePopWithPush{symbol_to_stack_op}, + detail::PopulatePopWithPush{symbol_to_stack_op}, num_symbols_in, stream)); @@ -416,8 +402,8 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, total_temp_storage_bytes = temp_storage.size(); rmm::device_uvector d_symbol_position_alt{num_symbols_in, stream}; - rmm::device_uvector d_kv_ops_current{num_symbols_in, stream}; - rmm::device_uvector d_kv_ops_alt{num_symbols_in, stream}; + rmm::device_uvector d_kv_ops_current{num_symbols_in, stream}; + rmm::device_uvector d_kv_ops_alt{num_symbols_in, stream}; //------------------------------------------------------------------------------ // ALGORITHM @@ -427,29 +413,29 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, cub::DoubleBuffer{d_symbol_positions, d_symbol_position_alt.data()}; // Initialize double-buffer for sorting the indexes of the sequence of sparse stack operations - d_kv_operations = cub::DoubleBuffer{d_kv_ops_current.data(), d_kv_ops_alt.data()}; + d_kv_operations = cub::DoubleBuffer{d_kv_ops_current.data(), d_kv_ops_alt.data()}; // Compute prefix sum of the stack level after each operation CUDA_TRY(cub::DeviceScan::InclusiveScan(temp_storage.data(), total_temp_storage_bytes, stack_symbols_in, d_kv_operations.Current(), - detail::AddStackLevelFromKVOp{}, + detail::AddStackLevelFromStackOp{}, num_symbols_in, stream)); // Dump info on stack operations: (stack level change + symbol) -> (absolute stack level + symbol) test::print::print_array(num_symbols_in, stream, - get_key_it(stack_symbols_in), + get_stack_level_it(stack_symbols_in), get_value_it(stack_symbols_in), - get_key_it(d_kv_operations.Current()), + get_stack_level_it(d_kv_operations.Current()), get_value_it(d_kv_operations.Current())); // Stable radix sort, sorting by stack level of the operations d_kv_operations_unsigned = - 
cub::DoubleBuffer{reinterpret_cast(d_kv_operations.Current()), - reinterpret_cast(d_kv_operations.Alternate())}; + cub::DoubleBuffer{reinterpret_cast(d_kv_operations.Current()), + reinterpret_cast(d_kv_operations.Alternate())}; CUDA_TRY(cub::DeviceRadixSort::SortPairs(temp_storage.data(), total_temp_storage_bytes, d_kv_operations_unsigned, @@ -460,22 +446,22 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, stream)); // TransformInputIterator that remaps all operations on stack level 0 to the empty stack symbol - kv_ops_scan_in = {reinterpret_cast(d_kv_operations_unsigned.Current()), - detail::RemapEmptyStack{empty_stack}}; - kv_ops_scan_out = reinterpret_cast(d_kv_operations_unsigned.Alternate()); + kv_ops_scan_in = {reinterpret_cast(d_kv_operations_unsigned.Current()), + detail::RemapEmptyStack{empty_stack}}; + kv_ops_scan_out = reinterpret_cast(d_kv_operations_unsigned.Alternate()); // Dump info on stack operations sorted by their stack level (i.e. stack level after applying // operation) test::print::print_array( - num_symbols_in, stream, get_key_it(kv_ops_scan_in), get_value_it(kv_ops_scan_in)); + num_symbols_in, stream, get_stack_level_it(kv_ops_scan_in), get_value_it(kv_ops_scan_in)); - // Exclusive scan to match pop operations with the latest push operation of that level + // Inclusive scan to match pop operations with the latest push operation of that level CUDA_TRY(cub::DeviceScan::InclusiveScan( temp_storage.data(), total_temp_storage_bytes, kv_ops_scan_in, kv_ops_scan_out, - detail::PopulatePopWithPush{symbol_to_stack_op}, + detail::PopulatePopWithPush{symbol_to_stack_op}, num_symbols_in, stream)); @@ -483,9 +469,9 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, // operation) test::print::print_array(num_symbols_in, stream, - get_key_it(kv_ops_scan_in), + get_stack_level_it(kv_ops_scan_in), get_value_it(kv_ops_scan_in), - get_key_it(kv_ops_scan_out), + get_stack_level_it(kv_ops_scan_out), get_value_it(kv_ops_scan_out)); // Fill the output tape with read-symbol @@ -494,9 +480,9 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, thrust::device_ptr{d_top_of_stack + num_symbols_out}, read_symbol); - // Transform the key-value operations to the stack symbol they represent - cub::TransformInputIterator - kv_op_to_stack_sym_it(kv_ops_scan_out, detail::KVOpToStackSymbol{}); + // Transform the stack operations to the stack symbol they represent + cub::TransformInputIterator + kv_op_to_stack_sym_it(kv_ops_scan_out, detail::StackOpToStackSymbol{}); // Scatter the stack symbols to the output tape (spots that are not scattered to have been // pre-filled with the read-symbol) From edb14e2e9c59e83cbf783300d9ca65a441a6c45e Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Mon, 4 Apr 2022 07:35:33 -0700 Subject: [PATCH 10/19] device_span --- cpp/include/cudf_test/print_utilities.cuh | 6 +-- cpp/src/io/fst/logical_stack.cuh | 56 ++++++++++++----------- cpp/tests/io/fst/logical_stack_test.cu | 23 +++++----- 3 files changed, 44 insertions(+), 41 deletions(-) diff --git a/cpp/include/cudf_test/print_utilities.cuh b/cpp/include/cudf_test/print_utilities.cuh index 1da7b9836b1..5c5a42249ac 100644 --- a/cpp/include/cudf_test/print_utilities.cuh +++ b/cpp/include/cudf_test/print_utilities.cuh @@ -42,7 +42,7 @@ using hex_t = TaggedType; * @brief Function object to transform a built-in type to a tagged type (e.g., in order to print * values from an iterator returning uint32_t as hex values) * - * @tparam 
TaggedTypeT A TaggedType template specialisation + * @tparam TaggedTypeT A TaggedType template specialisation */ template struct ToTaggedType { @@ -55,10 +55,10 @@ struct ToTaggedType { /** * @brief Returns an iterator that causes the values from \p it to be printed as hex values. - * + * * @tparam InItT A random-access input iterator type * @param it A random-access input iterator t - * @return + * @return */ template auto hex(InItT it) diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index bce362beff9..3584f6665c4 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -17,10 +17,11 @@ #include #include +#include #include -#include #include +#include #include #include @@ -51,11 +52,11 @@ namespace detail { * @brief A convenience struct that represents a stack opepration as a pair, where the stack_level * represents the stack's level and the value represents the stack symbol. * - * @tparam StackLevelT The stack level type sufficient to cover all stack levels. Must be signed type as any - * subsequence of stack operations must be able to be covered. E.g., consider the first 10 - * operations are all push and the last 10 operations are all pop operations, we need to be able to - * represent a partial aggregate of the first ten items, which is '+10', just as well as a partial - * aggregate of the last ten items, which is '-10'. + * @tparam StackLevelT The stack level type sufficient to cover all stack levels. Must be signed + * type as any subsequence of stack operations must be able to be covered. E.g., consider the first + * 10 operations are all push and the last 10 operations are all pop operations, we need to be able + * to represent a partial aggregate of the first ten items, which is '+10', just as well as a + * partial aggregate of the last ten items, which is '-10'. * @tparam ValueT The value type that corresponds to the stack symbols (i.e., covers the stack * alphabet). */ @@ -143,10 +144,10 @@ struct AddStackLevelFromStackOp { }; /** - * @brief Binary reduction operator that propagates a write operation for a specific stack level to all - * reads of that same stack level. That is, if the stack level of LHS compares equal to the stack level of the RHS and if - * the RHS is a read and the LHS is a write operation type, then we return LHS, otherwise we return - * the RHS. + * @brief Binary reduction operator that propagates a write operation for a specific stack level to + * all reads of that same stack level. That is, if the stack level of LHS compares equal to the + * stack level of the RHS and if the RHS is a read and the LHS is a write operation type, then we + * return LHS, otherwise we return the RHS. */ template struct PopulatePopWithPush { @@ -159,8 +160,8 @@ struct PopulatePopWithPush { // Whether LHS is a matching write (i.e., the push operation that is on top of the stack for the // RHS's read) - bool is_lhs_matching_write = - (lhs.stack_level == rhs.stack_level) && symbol_to_stack_op_type(lhs.value) == stack_op_type::PUSH; + bool is_lhs_matching_write = (lhs.stack_level == rhs.stack_level) && + symbol_to_stack_op_type(lhs.value) == stack_op_type::PUSH; return (is_rhs_read && is_lhs_matching_write) ? 
lhs : rhs; } @@ -277,7 +278,6 @@ auto get_value_it(StackOpItT it) * was empty * @param[in] read_symbol A symbol that may not be confused for a symbol that would push to the * stack - * @param[in] num_symbols_in The number of symbols in the sparse representation * @param[in] num_symbols_out The number of symbols that are supposed to be filled with * what-is-on-top-of-the-stack * @param[in] stream The cuda stream to which to dispatch the work @@ -287,17 +287,15 @@ template + typename StackSymbolT> void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, StackSymbolItT d_symbols, - SymbolPositionT* d_symbol_positions, + device_span d_symbol_positions, StackSymbolToStackOpTypeT symbol_to_stack_op, TopOfStackOutItT d_top_of_stack, StackSymbolT empty_stack_symbol, StackSymbolT read_symbol, - OffsetT num_symbols_in, - OffsetT num_symbols_out, + std::size_t num_symbols_out, rmm::cuda_stream_view stream = rmm::cuda_stream_default) { // Type used to hold pairs of (stack_level, value) pairs @@ -313,6 +311,8 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, using TransformInputItT = cub::TransformInputIterator; + auto const num_symbols_in = d_symbol_positions.size(); + // Converting a stack symbol that may either push or pop to a stack operation: // stack_symbol -> ([+1,0,-1], stack_symbol) StackSymbolToStackOpT stack_sym_to_kv_op{symbol_to_stack_op}; @@ -322,7 +322,8 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, // representation) cub::DoubleBuffer d_symbol_positions_db{nullptr, nullptr}; - // Double-buffer for sorting the stack operations by the stack level to which such operation applies + // Double-buffer for sorting the stack operations by the stack level to which such operation + // applies cub::DoubleBuffer d_kv_operations{nullptr, nullptr}; // A double-buffer that aliases memory from d_kv_operations with unsigned types in order to @@ -391,14 +392,15 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, // Scratch memory required by the algorithms auto total_temp_storage_bytes = std::max({stack_level_scan_bytes, - stack_level_sort_bytes, - match_level_scan_bytes, - propagate_writes_scan_bytes}); + stack_level_sort_bytes, + match_level_scan_bytes, + propagate_writes_scan_bytes}); if (temp_storage.size() < total_temp_storage_bytes) { temp_storage.resize(total_temp_storage_bytes, stream); } - // Actual device buffer size, as we need to pass in an lvalue-ref to cub algorithms as temp_storage_bytes + // Actual device buffer size, as we need to pass in an lvalue-ref to cub algorithms as + // temp_storage_bytes total_temp_storage_bytes = temp_storage.size(); rmm::device_uvector d_symbol_position_alt{num_symbols_in, stream}; @@ -410,7 +412,7 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, //------------------------------------------------------------------------------ // Initialize double-buffer for sorting the indexes of the sequence of sparse stack operations d_symbol_positions_db = - cub::DoubleBuffer{d_symbol_positions, d_symbol_position_alt.data()}; + cub::DoubleBuffer{d_symbol_positions.data(), d_symbol_position_alt.data()}; // Initialize double-buffer for sorting the indexes of the sequence of sparse stack operations d_kv_operations = cub::DoubleBuffer{d_kv_ops_current.data(), d_kv_ops_alt.data()}; @@ -433,9 +435,9 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, get_value_it(d_kv_operations.Current())); // Stable radix sort, sorting by stack level of the operations - d_kv_operations_unsigned = - 
cub::DoubleBuffer{reinterpret_cast(d_kv_operations.Current()), - reinterpret_cast(d_kv_operations.Alternate())}; + d_kv_operations_unsigned = cub::DoubleBuffer{ + reinterpret_cast(d_kv_operations.Current()), + reinterpret_cast(d_kv_operations.Alternate())}; CUDA_TRY(cub::DeviceRadixSort::SortPairs(temp_storage.data(), total_temp_storage_bytes, d_kv_operations_unsigned, diff --git a/cpp/tests/io/fst/logical_stack_test.cu b/cpp/tests/io/fst/logical_stack_test.cu index 3b860867cf2..7f7d72c0db3 100644 --- a/cpp/tests/io/fst/logical_stack_test.cu +++ b/cpp/tests/io/fst/logical_stack_test.cu @@ -22,8 +22,8 @@ #include #include -#include #include +#include #include #include @@ -202,6 +202,7 @@ TEST_F(LogicalStackTest, GroundTruth) rmm::device_uvector d_stack_ops(stack_symbols.size(), stream_view); rmm::device_uvector d_stack_op_indexes(stack_op_indexes.size(), stream_view); auto top_of_stack_gpu = hostdevice_vector(input.size(), stream_view); + cudf::device_span d_stack_op_idx_span{d_stack_op_indexes.data(), d_stack_op_indexes.size()}; cudaMemcpyAsync(d_stack_ops.data(), stack_symbols.data(), @@ -224,16 +225,16 @@ TEST_F(LogicalStackTest, GroundTruth) rmm::device_buffer d_temp_storage{}; // Run algorithm - fst::SparseStackOpToTopOfStack(d_temp_storage, - d_stack_ops.data(), - d_stack_op_indexes.data(), - JSONToStackOp{}, - top_of_stack_gpu.device_ptr(), - empty_stack_symbol, - read_symbol, - num_stack_ops, - string_size, - stream); + fst::SparseStackOpToTopOfStack( + d_temp_storage, + d_stack_ops.data(), + d_stack_op_idx_span, + JSONToStackOp{}, + top_of_stack_gpu.device_ptr(), + empty_stack_symbol, + read_symbol, + string_size, + stream); // Async copy results from device to host top_of_stack_gpu.device_to_host(stream_view); From 56f64cbf482fcb9cf4461f244eadc03763b1ca35 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Wed, 6 Apr 2022 07:44:44 -0700 Subject: [PATCH 11/19] addressing review comments & minor cleanups --- cpp/src/io/fst/logical_stack.cuh | 5 --- cpp/tests/io/fst/logical_stack_test.cu | 58 +++++++++++++------------- 2 files changed, 29 insertions(+), 34 deletions(-) diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index 3584f6665c4..93f1a9ac09f 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -75,11 +75,6 @@ template struct StackOpToUnsigned { }; -template <> -struct StackOpToUnsigned<1U> { - using UnsignedT = uint8_t; -}; - template <> struct StackOpToUnsigned<2U> { using UnsignedT = uint16_t; diff --git a/cpp/tests/io/fst/logical_stack_test.cu b/cpp/tests/io/fst/logical_stack_test.cu index 7f7d72c0db3..6f0535a6c77 100644 --- a/cpp/tests/io/fst/logical_stack_test.cu +++ b/cpp/tests/io/fst/logical_stack_test.cu @@ -78,7 +78,7 @@ std::pair to_sparse_stack_symbols( /** * @brief Reads in a sequence of items that represent stack operations, applies these operations to - * a stack, and, for every oepration being read in, outputs what was the symbol on top of the stack + * a stack, and, for every operation being read in, outputs what was the symbol on top of the stack * before the operations was applied. In case the stack is empty before any operation, * \p empty_stack will be output instead. 
* @@ -91,7 +91,8 @@ std::pair to_sparse_stack_symbols( * @param[in] end Iterator to one past the last item representing the stack operation * @param[in] to_stack_op A function object that takes an instance of InputItT's value type and * returns the kind of stack operation such item represents (i.e., of type stack_op_type) - * @param[in] empty_stack A symbol that will be written to top_of_stack whenever the stack was empty + * @param[in] empty_stack A symbol that will be written to top_of_stack_out_it whenever the stack + * was empty * @param[out] top_of_stack The output iterator to which the item will be written to * @return TopOfStackOutItT Iterators to one past the last element that was written */ @@ -103,13 +104,15 @@ TopOfStackOutItT to_top_of_stack(InputItT begin, InputItT end, ToStackOpTypeT to_stack_op, StackSymbolT empty_stack, - TopOfStackOutItT top_of_stack) + TopOfStackOutItT top_of_stack_out_it) { - std::stack stack; + // This is the data structure that keeps track of the full stack state for each input symbol + std::stack stack_state; + for (auto it = begin; it < end; it++) { // Write what is currently on top of the stack when reading in the current symbol - *top_of_stack = stack.empty() ? empty_stack : stack.top(); - top_of_stack++; + *top_of_stack_out_it = stack_state.empty() ? empty_stack : stack_state.top(); + top_of_stack_out_it++; auto const& current = *it; fst::stack_op_type op_type = to_stack_op(current); @@ -117,12 +120,12 @@ TopOfStackOutItT to_top_of_stack(InputItT begin, // Check whether this symbol corresponds to a push or pop operation and modify the stack // accordingly if (op_type == fst::stack_op_type::PUSH) { - stack.push(current); + stack_state.push(current); } else if (op_type == fst::stack_op_type::POP) { - stack.pop(); + stack_state.pop(); } } - return top_of_stack; + return top_of_stack_out_it; } /** @@ -155,8 +158,7 @@ TEST_F(LogicalStackTest, GroundTruth) // The stack symbol that we'll fill everywhere where there's nothing on the stack constexpr SymbolT empty_stack_symbol = '_'; - // This just has to be a stack symbol that may not be confused with a symbol that would push or - // pop + // This just has to be a stack symbol that may not be confused with a symbol that would push constexpr SymbolT read_symbol = 'x'; // Prepare cuda stream for data transfers & kernels @@ -185,7 +187,7 @@ TEST_F(LogicalStackTest, GroundTruth) input += input; // Getting the symbols that actually modify the stack (i.e., symbols that push or pop) - std::string stack_symbols = ""; + std::string stack_symbols{}; std::vector stack_op_indexes; stack_op_indexes.reserve(input.size()); @@ -196,13 +198,11 @@ TEST_F(LogicalStackTest, GroundTruth) std::back_inserter(stack_symbols), std::back_inserter(stack_op_indexes)); - // Prepare sparse stack ops - std::size_t num_stack_ops = stack_symbols.size(); - - rmm::device_uvector d_stack_ops(stack_symbols.size(), stream_view); - rmm::device_uvector d_stack_op_indexes(stack_op_indexes.size(), stream_view); - auto top_of_stack_gpu = hostdevice_vector(input.size(), stream_view); - cudf::device_span d_stack_op_idx_span{d_stack_op_indexes.data(), d_stack_op_indexes.size()}; + rmm::device_uvector d_stack_ops{stack_symbols.size(), stream_view}; + rmm::device_uvector d_stack_op_indexes{stack_op_indexes.size(), stream_view}; + hostdevice_vector top_of_stack_gpu{input.size(), stream_view}; + cudf::device_span d_stack_op_idx_span{d_stack_op_indexes.data(), + d_stack_op_indexes.size()}; cudaMemcpyAsync(d_stack_ops.data(), stack_symbols.data(), @@ -225,16 
+225,15 @@ TEST_F(LogicalStackTest, GroundTruth) rmm::device_buffer d_temp_storage{}; // Run algorithm - fst::SparseStackOpToTopOfStack( - d_temp_storage, - d_stack_ops.data(), - d_stack_op_idx_span, - JSONToStackOp{}, - top_of_stack_gpu.device_ptr(), - empty_stack_symbol, - read_symbol, - string_size, - stream); + fst::SparseStackOpToTopOfStack(d_temp_storage, + d_stack_ops.data(), + d_stack_op_idx_span, + JSONToStackOp{}, + top_of_stack_gpu.device_ptr(), + empty_stack_symbol, + read_symbol, + string_size, + stream); // Async copy results from device to host top_of_stack_gpu.device_to_host(stream_view); @@ -253,6 +252,7 @@ TEST_F(LogicalStackTest, GroundTruth) // Verify results ASSERT_EQ(input.size(), top_of_stack_cpu.size()); + ASSERT_EQ(top_of_stack_gpu.size(), top_of_stack_cpu.size()); for (size_t i = 0; i < input.size() && i < top_of_stack_cpu.size(); i++) { ASSERT_EQ(top_of_stack_gpu.host_ptr()[i], top_of_stack_cpu[i]) << "Mismatch at index #" << i; } From 829ee1b0f709d4177fed3ea79ee9cb25e1e36556 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Wed, 6 Apr 2022 22:19:03 -0700 Subject: [PATCH 12/19] error on unsupported unsigned_t and fixed typos --- cpp/src/io/fst/logical_stack.cuh | 10 ++++++---- cpp/tests/io/fst/logical_stack_test.cu | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index 93f1a9ac09f..d84a8c8fc80 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -49,7 +49,7 @@ enum class stack_op_type : int32_t { namespace detail { /** - * @brief A convenience struct that represents a stack opepration as a pair, where the stack_level + * @brief A convenience struct that represents a stack operation as a pair, where the stack_level * represents the stack's level and the value represents the stack symbol. * * @tparam StackLevelT The stack level type sufficient to cover all stack levels. Must be signed @@ -73,6 +73,7 @@ struct StackOp { */ template struct StackOpToUnsigned { + using UnsignedT = void; }; template <> @@ -166,8 +167,8 @@ struct PopulatePopWithPush { }; /** - * @brief Binary reduction operator that is used to replace each read_symbol occurance with the last - * non-read_symbol that precedes such read_symbol. + * @brief Binary reduction operator that is used to replace each read_symbol occurrence with the + * last non-read_symbol that precedes such read_symbol. */ template struct PropagateLastWrite { @@ -298,6 +299,7 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, // The unsigned integer type that we use for radix sorting items of type StackOpT using StackOpUnsignedT = detail::UnsignedStackOpType; + static_assert(!std::is_void(), "unsupported StackOpT size"); // Transforming sequence of stack symbols to stack operations using StackSymbolToStackOpT = detail::StackSymbolToStackOp; @@ -495,7 +497,7 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, std::min(num_symbols_in, static_cast(10000)), stream, d_top_of_stack); // We perform an exclusive scan in order to fill the items at the very left that may - // be reading the empty stack before there's the first push occurance in the sequence. + // be reading the empty stack before there's the first push occurrence in the sequence. // Also, we're interested in the top-of-the-stack symbol before the operation was applied. 
CUDA_TRY(cub::DeviceScan::ExclusiveScan(temp_storage.data(), total_temp_storage_bytes, diff --git a/cpp/tests/io/fst/logical_stack_test.cu b/cpp/tests/io/fst/logical_stack_test.cu index 6f0535a6c77..f690a8497df 100644 --- a/cpp/tests/io/fst/logical_stack_test.cu +++ b/cpp/tests/io/fst/logical_stack_test.cu @@ -84,7 +84,7 @@ std::pair to_sparse_stack_symbols( * * @tparam InputItT Forward input iterator type to items representing stack operations * @tparam ToStackOpTypeT A transform function object class that maps an item representing a stack - * oepration to the stack_op_type of such item + * operation to the stack_op_type of such item * @tparam StackSymbolT Type representing items being pushed onto the stack * @tparam TopOfStackOutItT A forward output iterator type being assigned items of StackSymbolT * @param[in] begin Forward iterator to the beginning of the items representing stack operations @@ -129,7 +129,7 @@ TopOfStackOutItT to_top_of_stack(InputItT begin, } /** - * @brief Funciton object used to filter for brackets and braces that represent push and pop + * @brief Function object used to filter for brackets and braces that represent push and pop * operations * */ From 508497430675d81cde12e0b57f30c1f4975a0d8f Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Tue, 12 Apr 2022 22:55:00 -0700 Subject: [PATCH 13/19] minor style changes addressing review comments --- cpp/src/io/fst/logical_stack.cuh | 143 +++++++++++++------------ cpp/tests/io/fst/logical_stack_test.cu | 48 ++++----- 2 files changed, 96 insertions(+), 95 deletions(-) diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index d84a8c8fc80..412e85204fe 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -227,7 +227,7 @@ struct StackOpToStackLevel { * @brief Retrieves an iterator that returns only the `stack_level` part from a StackOp iterator. */ template -auto get_stack_level_it(StackOpItT it) +auto get_stack_level_iterator(StackOpItT it) { return thrust::make_transform_iterator(it, StackOpToStackLevel{}); } @@ -236,7 +236,7 @@ auto get_stack_level_it(StackOpItT it) * @brief Retrieves an iterator that returns only the `value` part from a StackOp iterator. 
*/ template -auto get_value_it(StackOpItT it) +auto get_value_iterator(StackOpItT it) { return thrust::make_transform_iterator(it, StackOpToStackSymbol{}); } @@ -284,16 +284,17 @@ template -void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, - StackSymbolItT d_symbols, - device_span d_symbol_positions, - StackSymbolToStackOpTypeT symbol_to_stack_op, - TopOfStackOutItT d_top_of_stack, - StackSymbolT empty_stack_symbol, - StackSymbolT read_symbol, - std::size_t num_symbols_out, - rmm::cuda_stream_view stream = rmm::cuda_stream_default) +void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols, + device_span d_symbol_positions, + StackSymbolToStackOpTypeT symbol_to_stack_op, + TopOfStackOutItT d_top_of_stack, + StackSymbolT const empty_stack_symbol, + StackSymbolT const read_symbol, + std::size_t const num_symbols_out, + rmm::cuda_stream_view stream = rmm::cuda_stream_default) { + rmm::device_buffer temp_storage{}; + // Type used to hold pairs of (stack_level, value) pairs using StackOpT = detail::StackOp; @@ -346,28 +347,28 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, // Getting temporary storage requirements for the prefix sum of the stack level after each // operation - CUDA_TRY(cub::DeviceScan::InclusiveScan(nullptr, - stack_level_scan_bytes, - stack_symbols_in, - d_kv_operations.Current(), - detail::AddStackLevelFromStackOp{}, - num_symbols_in, - stream)); + CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScan(nullptr, + stack_level_scan_bytes, + stack_symbols_in, + d_kv_operations.Current(), + detail::AddStackLevelFromStackOp{}, + num_symbols_in, + stream)); // Getting temporary storage requirements for the stable radix sort (sorting by stack level of the // operations) - CUDA_TRY(cub::DeviceRadixSort::SortPairs(nullptr, - stack_level_sort_bytes, - d_kv_operations_unsigned, - d_symbol_positions_db, - num_symbols_in, - begin_bit, - end_bit, - stream)); + CUDF_CUDA_TRY(cub::DeviceRadixSort::SortPairs(nullptr, + stack_level_sort_bytes, + d_kv_operations_unsigned, + d_symbol_positions_db, + num_symbols_in, + begin_bit, + end_bit, + stream)); // Getting temporary storage requirements for the scan to match pop operations with the latest // push of the same level - CUDA_TRY(cub::DeviceScan::InclusiveScan( + CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScan( nullptr, match_level_scan_bytes, kv_ops_scan_in, @@ -378,14 +379,15 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, // Getting temporary storage requirements for the scan to propagate top-of-stack for spots that // didn't push or pop - CUDA_TRY(cub::DeviceScan::ExclusiveScan(nullptr, - propagate_writes_scan_bytes, - d_top_of_stack, - d_top_of_stack, - detail::PropagateLastWrite{read_symbol}, - empty_stack_symbol, - num_symbols_out, - stream)); + CUDF_CUDA_TRY( + cub::DeviceScan::ExclusiveScan(nullptr, + propagate_writes_scan_bytes, + d_top_of_stack, + d_top_of_stack, + detail::PropagateLastWrite{read_symbol}, + empty_stack_symbol, + num_symbols_out, + stream)); // Scratch memory required by the algorithms auto total_temp_storage_bytes = std::max({stack_level_scan_bytes, @@ -415,34 +417,34 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, d_kv_operations = cub::DoubleBuffer{d_kv_ops_current.data(), d_kv_ops_alt.data()}; // Compute prefix sum of the stack level after each operation - CUDA_TRY(cub::DeviceScan::InclusiveScan(temp_storage.data(), - total_temp_storage_bytes, - stack_symbols_in, - d_kv_operations.Current(), - detail::AddStackLevelFromStackOp{}, - num_symbols_in, - 
stream)); + CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScan(temp_storage.data(), + total_temp_storage_bytes, + stack_symbols_in, + d_kv_operations.Current(), + detail::AddStackLevelFromStackOp{}, + num_symbols_in, + stream)); // Dump info on stack operations: (stack level change + symbol) -> (absolute stack level + symbol) test::print::print_array(num_symbols_in, stream, - get_stack_level_it(stack_symbols_in), - get_value_it(stack_symbols_in), - get_stack_level_it(d_kv_operations.Current()), - get_value_it(d_kv_operations.Current())); + get_stack_level_iterator(stack_symbols_in), + get_value_iterator(stack_symbols_in), + get_stack_level_iterator(d_kv_operations.Current()), + get_value_iterator(d_kv_operations.Current())); // Stable radix sort, sorting by stack level of the operations d_kv_operations_unsigned = cub::DoubleBuffer{ reinterpret_cast(d_kv_operations.Current()), reinterpret_cast(d_kv_operations.Alternate())}; - CUDA_TRY(cub::DeviceRadixSort::SortPairs(temp_storage.data(), - total_temp_storage_bytes, - d_kv_operations_unsigned, - d_symbol_positions_db, - num_symbols_in, - begin_bit, - end_bit, - stream)); + CUDF_CUDA_TRY(cub::DeviceRadixSort::SortPairs(temp_storage.data(), + total_temp_storage_bytes, + d_kv_operations_unsigned, + d_symbol_positions_db, + num_symbols_in, + begin_bit, + end_bit, + stream)); // TransformInputIterator that remaps all operations on stack level 0 to the empty stack symbol kv_ops_scan_in = {reinterpret_cast(d_kv_operations_unsigned.Current()), @@ -451,11 +453,13 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, // Dump info on stack operations sorted by their stack level (i.e. stack level after applying // operation) - test::print::print_array( - num_symbols_in, stream, get_stack_level_it(kv_ops_scan_in), get_value_it(kv_ops_scan_in)); + test::print::print_array(num_symbols_in, + stream, + get_stack_level_iterator(kv_ops_scan_in), + get_value_iterator(kv_ops_scan_in)); // Inclusive scan to match pop operations with the latest push operation of that level - CUDA_TRY(cub::DeviceScan::InclusiveScan( + CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScan( temp_storage.data(), total_temp_storage_bytes, kv_ops_scan_in, @@ -468,10 +472,10 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, // operation) test::print::print_array(num_symbols_in, stream, - get_stack_level_it(kv_ops_scan_in), - get_value_it(kv_ops_scan_in), - get_stack_level_it(kv_ops_scan_out), - get_value_it(kv_ops_scan_out)); + get_stack_level_iterator(kv_ops_scan_in), + get_value_iterator(kv_ops_scan_in), + get_stack_level_iterator(kv_ops_scan_out), + get_value_iterator(kv_ops_scan_out)); // Fill the output tape with read-symbol thrust::fill(thrust::cuda::par.on(stream), @@ -499,14 +503,15 @@ void SparseStackOpToTopOfStack(rmm::device_buffer& temp_storage, // We perform an exclusive scan in order to fill the items at the very left that may // be reading the empty stack before there's the first push occurrence in the sequence. // Also, we're interested in the top-of-the-stack symbol before the operation was applied. 
- CUDA_TRY(cub::DeviceScan::ExclusiveScan(temp_storage.data(), - total_temp_storage_bytes, - d_top_of_stack, - d_top_of_stack, - detail::PropagateLastWrite{read_symbol}, - empty_stack_symbol, - num_symbols_out, - stream)); + CUDF_CUDA_TRY( + cub::DeviceScan::ExclusiveScan(temp_storage.data(), + total_temp_storage_bytes, + d_top_of_stack, + d_top_of_stack, + detail::PropagateLastWrite{read_symbol}, + empty_stack_symbol, + num_symbols_out, + stream)); // Dump the final output test::print::print_array( diff --git a/cpp/tests/io/fst/logical_stack_test.cu b/cpp/tests/io/fst/logical_stack_test.cu index f690a8497df..87a5bb69b2c 100644 --- a/cpp/tests/io/fst/logical_stack_test.cu +++ b/cpp/tests/io/fst/logical_stack_test.cu @@ -167,20 +167,20 @@ TEST_F(LogicalStackTest, GroundTruth) rmm::cuda_stream_view stream_view(stream); // Test input, - std::string input = R"( { -"category": "reference", -"index:" [4,12,42], -"author": "Nigel Rees", -"title": "Sayings of the Century", -"price": 8.95 -} -{ -"category": "reference", -"index:" [4,{},null,{"a":[]}], -"author": "Nigel Rees", -"title": "Sayings of the Century", -"price": 8.95 -} {} [] [ ])"; + std::string input = R"( {)" + R"(category": "reference",)" + R"("index:" [4,12,42],)" + R"("author": "Nigel Rees",)" + R"("title": "Sayings of the Century",)" + R"("price": 8.95)" + R"(} )" + R"({)" + R"("category": "reference",)" + R"("index:" [4,{},null,{"a":[]}],)" + R"("author": "Nigel Rees",)" + R"("title": "Sayings of the Century",)" + R"("price": 8.95)" + R"(} {} [] [ ])"; // Repeat input sample 1024x for (std::size_t i = 0; i < 10; i++) @@ -221,19 +221,15 @@ TEST_F(LogicalStackTest, GroundTruth) SymbolT* d_top_of_stack = nullptr; cudaMalloc(&d_top_of_stack, string_size); - // Allocate temporary storage required by the get-top-of-the-stack algorithm - rmm::device_buffer d_temp_storage{}; - // Run algorithm - fst::SparseStackOpToTopOfStack(d_temp_storage, - d_stack_ops.data(), - d_stack_op_idx_span, - JSONToStackOp{}, - top_of_stack_gpu.device_ptr(), - empty_stack_symbol, - read_symbol, - string_size, - stream); + fst::sparse_stack_op_to_top_of_stack(d_stack_ops.data(), + d_stack_op_idx_span, + JSONToStackOp{}, + top_of_stack_gpu.device_ptr(), + empty_stack_symbol, + read_symbol, + string_size, + stream); // Async copy results from device to host top_of_stack_gpu.device_to_host(stream_view); From 8d986056ee9457c6fe3596b39ea0439c96d33dbf Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Wed, 8 Jun 2022 13:48:19 -0700 Subject: [PATCH 14/19] addresses review comments on print utils --- cpp/include/cudf_test/print_utilities.cuh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/include/cudf_test/print_utilities.cuh b/cpp/include/cudf_test/print_utilities.cuh index 5c5a42249ac..2dc9bdbc3a3 100644 --- a/cpp/include/cudf_test/print_utilities.cuh +++ b/cpp/include/cudf_test/print_utilities.cuh @@ -19,8 +19,9 @@ #include #include -#include "cub/util_type.cuh" -#include +#include + +#include #include @@ -65,8 +66,7 @@ auto hex(InItT it) { using value_t = typename std::iterator_traits::value_type; using tagged_t = hex_t; - return cub::TransformInputIterator, InItT>( - it, ToTaggedType{}); + return thrust::make_transform_iterator(it, ToTaggedType{}); } template && std::is_signed_v)> @@ -125,7 +125,7 @@ __global__ void print_array_kernel(std::size_t count, int32_t width, char delimi * @param args List of iterators to be printed */ template -void print_array(std::size_t count, cudaStream_t 
stream, Ts... args) +void print_array(std::size_t count, rmm::cuda_stream_view stream, Ts... args) { // The width to pad printed numbers to constexpr int32_t width = 6; @@ -135,7 +135,7 @@ void print_array(std::size_t count, cudaStream_t stream, Ts... args) // TODO we want this to compile to nothing dependnig on compiler flag, rather than runtime if (std::getenv("CUDA_DBG_DUMP") != nullptr) { - detail::print_array_kernel<<<1, 1, 0, stream>>>(count, width, delimiter, args...); + detail::print_array_kernel<<<1, 1, 0, stream.value()>>>(count, width, delimiter, args...); } } From f721b8c3dff1514594af4954b235d3598ac514c5 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Wed, 8 Jun 2022 13:49:19 -0700 Subject: [PATCH 15/19] addresses review comments on logical stack --- cpp/src/io/fst/logical_stack.cuh | 36 +++++++++++++++++++------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index 412e85204fe..145441afea8 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -16,12 +16,14 @@ #pragma once #include +#include #include #include #include #include #include +#include #include #include @@ -32,6 +34,7 @@ #include #include +#include namespace cudf { namespace io { @@ -62,6 +65,9 @@ namespace detail { */ template struct StackOp { + // Must be signed type as any subsequence of stack operations must be able to be covered. + static_assert(std::is_signed_v, "StackLevelT has to be a signed type"); + StackLevelT stack_level; ValueT value; }; @@ -227,7 +233,7 @@ struct StackOpToStackLevel { * @brief Retrieves an iterator that returns only the `stack_level` part from a StackOp iterator. */ template -auto get_stack_level_iterator(StackOpItT it) +auto make_stack_level_iterator(StackOpItT it) { return thrust::make_transform_iterator(it, StackOpToStackLevel{}); } @@ -236,7 +242,7 @@ auto get_stack_level_iterator(StackOpItT it) * @brief Retrieves an iterator that returns only the `value` part from a StackOp iterator. 
*/ template -auto get_value_iterator(StackOpItT it) +auto make_value_iterator(StackOpItT it) { return thrust::make_transform_iterator(it, StackOpToStackSymbol{}); } @@ -291,7 +297,7 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols, StackSymbolT const empty_stack_symbol, StackSymbolT const read_symbol, std::size_t const num_symbols_out, - rmm::cuda_stream_view stream = rmm::cuda_stream_default) + rmm::cuda_stream_view stream = cudf::default_stream_value) { rmm::device_buffer temp_storage{}; @@ -428,10 +434,10 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols, // Dump info on stack operations: (stack level change + symbol) -> (absolute stack level + symbol) test::print::print_array(num_symbols_in, stream, - get_stack_level_iterator(stack_symbols_in), - get_value_iterator(stack_symbols_in), - get_stack_level_iterator(d_kv_operations.Current()), - get_value_iterator(d_kv_operations.Current())); + make_stack_level_iterator(stack_symbols_in), + make_value_iterator(stack_symbols_in), + make_stack_level_iterator(d_kv_operations.Current()), + make_value_iterator(d_kv_operations.Current())); // Stable radix sort, sorting by stack level of the operations d_kv_operations_unsigned = cub::DoubleBuffer{ @@ -455,8 +461,8 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols, // operation) test::print::print_array(num_symbols_in, stream, - get_stack_level_iterator(kv_ops_scan_in), - get_value_iterator(kv_ops_scan_in)); + make_stack_level_iterator(kv_ops_scan_in), + make_value_iterator(kv_ops_scan_in)); // Inclusive scan to match pop operations with the latest push operation of that level CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScan( @@ -472,13 +478,13 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols, // operation) test::print::print_array(num_symbols_in, stream, - get_stack_level_iterator(kv_ops_scan_in), - get_value_iterator(kv_ops_scan_in), - get_stack_level_iterator(kv_ops_scan_out), - get_value_iterator(kv_ops_scan_out)); + make_stack_level_iterator(kv_ops_scan_in), + make_value_iterator(kv_ops_scan_in), + make_stack_level_iterator(kv_ops_scan_out), + make_value_iterator(kv_ops_scan_out)); // Fill the output tape with read-symbol - thrust::fill(thrust::cuda::par.on(stream), + thrust::fill(rmm::exec_policy(stream), thrust::device_ptr{d_top_of_stack}, thrust::device_ptr{d_top_of_stack + num_symbols_out}, read_symbol); @@ -489,7 +495,7 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols, // Scatter the stack symbols to the output tape (spots that are not scattered to have been // pre-filled with the read-symbol) - thrust::scatter(thrust::cuda::par.on(stream), + thrust::scatter(rmm::exec_policy(stream), kv_op_to_stack_sym_it, kv_op_to_stack_sym_it + num_symbols_in, d_symbol_positions_db.Current(), From 461660ba4f463bbf75339e8ddefe755d97327983 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Wed, 8 Jun 2022 14:00:20 -0700 Subject: [PATCH 16/19] adds empty line in doxygen after tparam --- cpp/src/io/fst/logical_stack.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index 145441afea8..cac407bf0e6 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -267,6 +267,7 @@ auto make_value_iterator(StackOpItT it) * value_type) * @tparam OffsetT Signed or unsigned integer type large enough to index into both the sparse input * sequence and the top-of-stack output sequence + * * @param[in] d_symbols 
Sequence of symbols that represent stack operations. Memory may alias with * \p d_top_of_stack * @param[in,out] d_symbol_positions Sequence of symbol positions (for a sparse representation), From bd6fa967fb6eed8819b00a68cb721e48830bfc88 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Wed, 15 Jun 2022 01:52:30 -0700 Subject: [PATCH 17/19] addresses gh review comments --- cpp/src/io/fst/logical_stack.cuh | 2 +- cpp/tests/io/fst/logical_stack_test.cu | 34 ++++++++++++-------------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index cac407bf0e6..f61b8cbbecf 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -43,7 +43,7 @@ namespace fst { /** * @brief Describes the kind of stack operation. */ -enum class stack_op_type : int32_t { +enum class stack_op_type : int8_t { READ = 0, ///< Operation reading what is currently on top of the stack PUSH = 1, ///< Operation pushing a new item on top of the stack POP = 2 ///< Operation popping the item currently on top of the stack diff --git a/cpp/tests/io/fst/logical_stack_test.cu b/cpp/tests/io/fst/logical_stack_test.cu index 87a5bb69b2c..3c2cdd7fb5c 100644 --- a/cpp/tests/io/fst/logical_stack_test.cu +++ b/cpp/tests/io/fst/logical_stack_test.cu @@ -15,12 +15,12 @@ */ #include -#include #include #include #include +#include #include #include #include @@ -162,8 +162,7 @@ TEST_F(LogicalStackTest, GroundTruth) constexpr SymbolT read_symbol = 'x'; // Prepare cuda stream for data transfers & kernels - cudaStream_t stream = nullptr; - cudaStreamCreate(&stream); + rmm::cuda_stream stream{}; rmm::cuda_stream_view stream_view(stream); // Test input, @@ -186,10 +185,13 @@ TEST_F(LogicalStackTest, GroundTruth) for (std::size_t i = 0; i < 10; i++) input += input; + // Input's size + std::size_t string_size = input.size(); + // Getting the symbols that actually modify the stack (i.e., symbols that push or pop) std::string stack_symbols{}; std::vector stack_op_indexes; - stack_op_indexes.reserve(input.size()); + stack_op_indexes.reserve(string_size); // Get the sparse representation of stack operations to_sparse_stack_symbols(std::cbegin(input), @@ -200,26 +202,20 @@ TEST_F(LogicalStackTest, GroundTruth) rmm::device_uvector d_stack_ops{stack_symbols.size(), stream_view}; rmm::device_uvector d_stack_op_indexes{stack_op_indexes.size(), stream_view}; - hostdevice_vector top_of_stack_gpu{input.size(), stream_view}; - cudf::device_span d_stack_op_idx_span{d_stack_op_indexes.data(), - d_stack_op_indexes.size()}; + hostdevice_vector top_of_stack_gpu{string_size, stream_view}; + cudf::device_span d_stack_op_idx_span{d_stack_op_indexes}; cudaMemcpyAsync(d_stack_ops.data(), stack_symbols.data(), stack_symbols.size() * sizeof(SymbolT), cudaMemcpyHostToDevice, - stream); + stream.value()); cudaMemcpyAsync(d_stack_op_indexes.data(), stack_op_indexes.data(), stack_op_indexes.size() * sizeof(SymbolOffsetT), cudaMemcpyHostToDevice, - stream); - - // Prepare output - std::size_t string_size = input.size(); - SymbolT* d_top_of_stack = nullptr; - cudaMalloc(&d_top_of_stack, string_size); + stream.value()); // Run algorithm fst::sparse_stack_op_to_top_of_stack(d_stack_ops.data(), @@ -229,14 +225,14 @@ TEST_F(LogicalStackTest, GroundTruth) empty_stack_symbol, read_symbol, string_size, - stream); + stream.value()); // Async copy results from device to host top_of_stack_gpu.device_to_host(stream_view); // Get CPU-side results 
for verification std::string top_of_stack_cpu{}; - top_of_stack_cpu.reserve(input.size()); + top_of_stack_cpu.reserve(string_size); to_top_of_stack(std::cbegin(input), std::cend(input), JSONToStackOp{}, @@ -244,12 +240,12 @@ TEST_F(LogicalStackTest, GroundTruth) std::back_inserter(top_of_stack_cpu)); // Make sure results have been copied back to host - cudaStreamSynchronize(stream); + stream.synchronize(); // Verify results - ASSERT_EQ(input.size(), top_of_stack_cpu.size()); + ASSERT_EQ(string_size, top_of_stack_cpu.size()); ASSERT_EQ(top_of_stack_gpu.size(), top_of_stack_cpu.size()); - for (size_t i = 0; i < input.size() && i < top_of_stack_cpu.size(); i++) { + for (size_t i = 0; i < string_size && i < top_of_stack_cpu.size(); i++) { ASSERT_EQ(top_of_stack_gpu.host_ptr()[i], top_of_stack_cpu[i]) << "Mismatch at index #" << i; } } From f5fb111760a6f8493f5d4ca1a5bf93bf5e40106e Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Tue, 5 Jul 2022 10:01:50 -0700 Subject: [PATCH 18/19] removes debug print and uses c++17 namespace style --- cpp/include/cudf_test/print_utilities.cuh | 8 ++--- cpp/src/io/fst/logical_stack.cuh | 41 ++--------------------- 2 files changed, 4 insertions(+), 45 deletions(-) diff --git a/cpp/include/cudf_test/print_utilities.cuh b/cpp/include/cudf_test/print_utilities.cuh index 2dc9bdbc3a3..37ffcd401fc 100644 --- a/cpp/include/cudf_test/print_utilities.cuh +++ b/cpp/include/cudf_test/print_utilities.cuh @@ -25,9 +25,7 @@ #include -namespace cudf { -namespace test { -namespace print { +namespace cudf::test::print { constexpr int32_t hex_tag = 0; @@ -139,6 +137,4 @@ void print_array(std::size_t count, rmm::cuda_stream_view stream, Ts... args) } } -} // namespace print -} // namespace test -} // namespace cudf +} // namespace cudf::test::print diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index f61b8cbbecf..64e08d9edfa 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -36,9 +36,7 @@ #include #include -namespace cudf { -namespace io { -namespace fst { +namespace cudf::io::fst { /** * @brief Describes the kind of stack operation. @@ -432,14 +430,6 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols, num_symbols_in, stream)); - // Dump info on stack operations: (stack level change + symbol) -> (absolute stack level + symbol) - test::print::print_array(num_symbols_in, - stream, - make_stack_level_iterator(stack_symbols_in), - make_value_iterator(stack_symbols_in), - make_stack_level_iterator(d_kv_operations.Current()), - make_value_iterator(d_kv_operations.Current())); - // Stable radix sort, sorting by stack level of the operations d_kv_operations_unsigned = cub::DoubleBuffer{ reinterpret_cast(d_kv_operations.Current()), @@ -458,13 +448,6 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols, detail::RemapEmptyStack{empty_stack}}; kv_ops_scan_out = reinterpret_cast(d_kv_operations_unsigned.Alternate()); - // Dump info on stack operations sorted by their stack level (i.e. 
stack level after applying - // operation) - test::print::print_array(num_symbols_in, - stream, - make_stack_level_iterator(kv_ops_scan_in), - make_value_iterator(kv_ops_scan_in)); - // Inclusive scan to match pop operations with the latest push operation of that level CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScan( temp_storage.data(), @@ -475,15 +458,6 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols, num_symbols_in, stream)); - // Dump info on stack operations sorted by their stack level (i.e. stack level after applying - // operation) - test::print::print_array(num_symbols_in, - stream, - make_stack_level_iterator(kv_ops_scan_in), - make_value_iterator(kv_ops_scan_in), - make_stack_level_iterator(kv_ops_scan_out), - make_value_iterator(kv_ops_scan_out)); - // Fill the output tape with read-symbol thrust::fill(rmm::exec_policy(stream), thrust::device_ptr{d_top_of_stack}, @@ -502,11 +476,6 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols, d_symbol_positions_db.Current(), d_top_of_stack); - // Dump the output tape that has many yet-to-be-filled spots (i.e., all spots that were not given - // in the sparse representation) - test::print::print_array( - std::min(num_symbols_in, static_cast(10000)), stream, d_top_of_stack); - // We perform an exclusive scan in order to fill the items at the very left that may // be reading the empty stack before there's the first push occurrence in the sequence. // Also, we're interested in the top-of-the-stack symbol before the operation was applied. @@ -519,12 +488,6 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols, empty_stack_symbol, num_symbols_out, stream)); - - // Dump the final output - test::print::print_array( - std::min(num_symbols_in, static_cast(10000)), stream, d_top_of_stack); } -} // namespace fst -} // namespace io -} // namespace cudf +} // namespace cudf::io::fst From ac768b6ba37ecacfa365d48f2bce6fc4d23ecc45 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Mon, 11 Jul 2022 12:06:36 -0700 Subject: [PATCH 19/19] removes unused functions --- cpp/src/io/fst/logical_stack.cuh | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/cpp/src/io/fst/logical_stack.cuh b/cpp/src/io/fst/logical_stack.cuh index 64e08d9edfa..9502922a379 100644 --- a/cpp/src/io/fst/logical_stack.cuh +++ b/cpp/src/io/fst/logical_stack.cuh @@ -216,35 +216,6 @@ struct RemapEmptyStack { StackOpT empty_stack_symbol; }; -/** - * @brief Function object to return only the stack_level part from a StackOp instance. - */ -struct StackOpToStackLevel { - template - constexpr CUDF_HOST_DEVICE StackLevelT operator()(StackOp const& kv_op) const - { - return kv_op.stack_level; - } -}; - -/** - * @brief Retrieves an iterator that returns only the `stack_level` part from a StackOp iterator. - */ -template -auto make_stack_level_iterator(StackOpItT it) -{ - return thrust::make_transform_iterator(it, StackOpToStackLevel{}); -} - -/** - * @brief Retrieves an iterator that returns only the `value` part from a StackOp iterator. - */ -template -auto make_value_iterator(StackOpItT it) -{ - return thrust::make_transform_iterator(it, StackOpToStackSymbol{}); -} - } // namespace detail /**
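
Note (not part of the patch series): the recurring theme of PATCH 17 is replacing manually managed CUDA resources in the test with RAII, stream-ordered equivalents — rmm::cuda_stream instead of cudaStreamCreate, rmm::device_uvector/hostdevice_vector instead of a raw cudaMalloc, and a cudf::device_span constructed directly from the device_uvector. Below is a minimal, self-contained sketch of that ownership pattern. SymbolT/SymbolOffsetT mirror the test's aliases, the host input values are placeholders, and the algorithm launch itself is elided because the diff context only shows part of its argument list.

  #include <cudf/utilities/span.hpp>

  #include <rmm/cuda_stream.hpp>
  #include <rmm/cuda_stream_view.hpp>
  #include <rmm/device_uvector.hpp>

  #include <cuda_runtime.h>

  #include <cstdint>
  #include <string>
  #include <vector>

  void stream_ownership_sketch()
  {
    using SymbolT       = char;
    using SymbolOffsetT = uint32_t;

    // RAII stream: no cudaStreamCreate/cudaStreamDestroy pair to keep in sync
    rmm::cuda_stream stream{};
    rmm::cuda_stream_view stream_view(stream);

    // Host-side sparse stack operations (placeholder values, not the test input)
    std::string stack_symbols{"{[]}"};
    std::vector<SymbolOffsetT> stack_op_indexes{0, 2, 5, 7};

    // Stream-ordered device allocations; freed automatically at scope exit
    rmm::device_uvector<SymbolT> d_stack_ops{stack_symbols.size(), stream_view};
    rmm::device_uvector<SymbolOffsetT> d_stack_op_indexes{stack_op_indexes.size(), stream_view};

    // device_span is constructible directly from the device_uvector,
    // instead of passing a .data()/.size() pair
    cudf::device_span<SymbolOffsetT> d_stack_op_idx_span{d_stack_op_indexes};

    // Async H2D copies are issued on the owned stream's underlying cudaStream_t
    cudaMemcpyAsync(d_stack_ops.data(),
                    stack_symbols.data(),
                    stack_symbols.size() * sizeof(SymbolT),
                    cudaMemcpyHostToDevice,
                    stream.value());
    cudaMemcpyAsync(d_stack_op_indexes.data(),
                    stack_op_indexes.data(),
                    stack_op_indexes.size() * sizeof(SymbolOffsetT),
                    cudaMemcpyHostToDevice,
                    stream.value());

    // ... sparse_stack_op_to_top_of_stack would be launched here, consuming
    //     d_stack_ops.data(), d_stack_op_idx_span and stream.value() ...

    // Block until all work queued on the owned stream has finished
    stream.synchronize();
  }

With this pattern the explicit cudaFree/cudaStreamDestroy cleanup that the pre-PATCH-17 test would have needed simply disappears at scope exit, which is what the review comments were asking for.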