Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance for cudf::strings::count_re #15578

Merged
merged 3 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/benchmarks/string/contains.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ std::unique_ptr<cudf::column> build_input_column(cudf::size_type n_rows,
}

// longer pattern lengths demand more working memory per string
std::string patterns[] = {"^\\d+ [a-z]+", "[A-Z ]+\\d+ +\\d+[A-Z]+\\d+$"};
std::string patterns[] = {"^\\d+ [a-z]+", "[A-Z ]+\\d+ +\\d+[A-Z]+\\d+$", "5W43"};

static void bench_contains(nvbench::state& state)
{
Expand Down Expand Up @@ -114,4 +114,4 @@ NVBENCH_BENCH(bench_contains)
.add_int64_axis("row_width", {32, 64, 128, 256, 512})
.add_int64_axis("num_rows", {32768, 262144, 2097152, 16777216})
.add_int64_axis("hit_rate", {50, 100}) // percentage
.add_int64_axis("pattern", {0, 1});
.add_int64_axis("pattern", {0, 1, 2});
12 changes: 8 additions & 4 deletions cpp/benchmarks/string/count.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@

#include <nvbench/nvbench.cuh>

static std::string patterns[] = {"\\d+", "a"};

static void bench_count(nvbench::state& state)
{
auto const num_rows = static_cast<cudf::size_type>(state.get_int64("num_rows"));
auto const row_width = static_cast<cudf::size_type>(state.get_int64("row_width"));
auto const num_rows = static_cast<cudf::size_type>(state.get_int64("num_rows"));
auto const row_width = static_cast<cudf::size_type>(state.get_int64("row_width"));
auto const pattern_index = static_cast<cudf::size_type>(state.get_int64("pattern"));

if (static_cast<std::size_t>(num_rows) * static_cast<std::size_t>(row_width) >=
static_cast<std::size_t>(std::numeric_limits<cudf::size_type>::max())) {
Expand All @@ -41,7 +44,7 @@ static void bench_count(nvbench::state& state)
create_random_table({cudf::type_id::STRING}, row_count{num_rows}, table_profile);
cudf::strings_column_view input(table->view().column(0));

std::string pattern = "\\d+";
auto const pattern = patterns[pattern_index];

auto prog = cudf::strings::regex_program::create(pattern);

Expand All @@ -59,4 +62,5 @@ static void bench_count(nvbench::state& state)
NVBENCH_BENCH(bench_count)
.set_name("count")
.add_int64_axis("row_width", {32, 64, 128, 256, 512, 1024, 2048})
.add_int64_axis("num_rows", {4096, 32768, 262144, 2097152, 16777216});
.add_int64_axis("num_rows", {4096, 32768, 262144, 2097152, 16777216})
.add_int64_axis("pattern", {0, 1});
19 changes: 14 additions & 5 deletions cpp/src/strings/regex/regex.inl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,15 @@ __device__ __forceinline__ reprog_device reprog_device::load(reprog_device const
: reinterpret_cast<reprog_device*>(buffer)[0];
}

__device__ __forceinline__ static string_view::const_iterator find_char(
cudf::char_utf8 chr, string_view const d_str, string_view::const_iterator itr)
{
while (itr.byte_offset() < d_str.size_bytes() && *itr != chr) {
++itr;
}
return itr;
}

/**
* @brief Evaluate a specific string against regex pattern compiled to this instance.
*
Expand Down Expand Up @@ -253,16 +262,16 @@ __device__ __forceinline__ match_result reprog_device::regexec(string_view const
case BOL:
if (pos == 0) break;
if (jnk.startchar != '^') { return thrust::nullopt; }
--pos;
--itr;
startchar = static_cast<char_utf8>('\n');
case CHAR: {
auto const fidx = dstr.find(startchar, pos);
if (fidx == string_view::npos) { return thrust::nullopt; }
pos = fidx + (jnk.starttype == BOL);
auto const find_itr = find_char(startchar, dstr, itr);
if (find_itr.byte_offset() >= dstr.size_bytes()) { return thrust::nullopt; }
itr = find_itr + (jnk.starttype == BOL);
pos = itr.position();
break;
}
}
itr += (pos - itr.position()); // faster to increment position
}

if (((eos < 0) || (pos < eos)) && match == 0) {
Expand Down
Loading