Skip to content

Commit

Permalink
Improve: Read benchmark duration from CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Mar 19, 2024
1 parent 8f262a0 commit dc243be
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 21 deletions.
35 changes: 20 additions & 15 deletions scripts/bench.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,6 @@

#include "test.hpp" // `read_file`

#if SZ_DEBUG // Make debugging faster
#define default_seconds_m 10
#else
#define default_seconds_m 30
#endif

namespace sz = ashvardanian::stringzilla;

namespace ashvardanian {
Expand Down Expand Up @@ -162,6 +156,8 @@ inline std::vector<result_string_type> filter_by_length(std::vector<from_string_
return result;
}

inline static std::size_t seconds_per_benchmark = 5;

struct dataset_t {
std::string text;
std::vector<std::string_view> tokens;
Expand Down Expand Up @@ -206,9 +202,20 @@ inline dataset_t make_dataset_from_path(std::string path) {
/**
* @brief Loads a dataset, depending on the passed CLI arguments.
*/
inline dataset_t make_dataset(int argc, char const *argv[]) {
if (argc != 2) { throw std::runtime_error("Usage: " + std::string(argv[0]) + " <path>"); }
return make_dataset_from_path(argv[1]);
inline dataset_t prepare_benchmark_environment(int argc, char const *argv[]) {
if (argc < 2 || argc > 3)
throw std::runtime_error("Usage: " + std::string(argv[0]) + " <path> [seconds_per_benchmark]");

dataset_t data = make_dataset_from_path(argv[1]);

// If the seconds_per_benchmark argument is provided, update the value in the dataset
if (argc == 3) {
seconds_per_benchmark = std::stoi(argv[2]);
if (seconds_per_benchmark == 0)
throw std::invalid_argument("The number of seconds per task must be greater than 0.");
}

return data;
}

inline sz_string_view_t to_c(std::string_view str) noexcept { return {str.data(), str.size()}; }
Expand All @@ -224,8 +231,7 @@ inline sz_string_view_t to_c(sz_string_view_t str) noexcept { return str; }
* @return Number of seconds per iteration.
*/
template <typename strings_type, typename function_type>
benchmark_result_t bench_on_tokens(strings_type &&strings, function_type &&function,
seconds_t max_time = default_seconds_m) {
benchmark_result_t bench_on_tokens(strings_type &&strings, function_type &&function) {

namespace stdc = std::chrono;
using stdcc = stdc::high_resolution_clock;
Expand All @@ -245,7 +251,7 @@ benchmark_result_t bench_on_tokens(strings_type &&strings, function_type &&funct

stdcc::time_point t2 = stdcc::now();
result.seconds = stdc::duration_cast<stdc::nanoseconds>(t2 - t1).count() / 1.e9;
if (result.seconds > max_time) break;
if (result.seconds > seconds_per_benchmark) break;
}

return result;
Expand All @@ -259,8 +265,7 @@ benchmark_result_t bench_on_tokens(strings_type &&strings, function_type &&funct
* @return Number of seconds per iteration.
*/
template <typename strings_type, typename function_type>
benchmark_result_t bench_on_token_pairs(strings_type &&strings, function_type &&function,
seconds_t max_time = default_seconds_m) {
benchmark_result_t bench_on_token_pairs(strings_type &&strings, function_type &&function) {

namespace stdc = std::chrono;
using stdcc = stdc::high_resolution_clock;
Expand All @@ -282,7 +287,7 @@ benchmark_result_t bench_on_token_pairs(strings_type &&strings, function_type &&

stdcc::time_point t2 = stdcc::now();
result.seconds = stdc::duration_cast<stdc::nanoseconds>(t2 - t1).count() / 1.e9;
if (result.seconds > max_time) break;
if (result.seconds > seconds_per_benchmark) break;
}

return result;
Expand Down
2 changes: 1 addition & 1 deletion scripts/bench_container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ void bench_tokens(strings_type const &strings) {
int main(int argc, char const **argv) {
std::printf("StringZilla. Starting search benchmarks.\n");

dataset_t dataset = make_dataset(argc, argv);
dataset_t dataset = prepare_benchmark_environment(argc, argv);

// Baseline benchmarks for real words, coming in all lengths
std::printf("Benchmarking on real words:\n");
Expand Down
8 changes: 6 additions & 2 deletions scripts/bench_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,19 +292,23 @@ void bench_search(std::string const &haystack, std::vector<std::string> const &s
int main(int argc, char const **argv) {
std::printf("StringZilla. Starting search benchmarks.\n");

dataset_t dataset = make_dataset(argc, argv);
dataset_t dataset = prepare_benchmark_environment(argc, argv);

// Splitting by new lines
std::printf("Benchmarking for a newline symbol:\n");
bench_finds(dataset.text, {"\n"}, find_functions());
bench_rfinds(dataset.text, {"\n"}, rfind_functions());

std::printf("Benchmarking for one whitespace:\n");
bench_finds(dataset.text, {" "}, find_functions());
bench_rfinds(dataset.text, {" "}, rfind_functions());

std::printf("Benchmarking for an [\\n\\r\\v\\f] RegEx:\n");
bench_finds(dataset.text, {"\n\r\v\f"}, find_charset_functions());
bench_rfinds(dataset.text, {"\n\r\v\f"}, rfind_charset_functions());

// Typical ASCII tokenization and validation benchmarks
std::printf("Benchmarking for whitespaces:\n");
std::printf("Benchmarking for all whitespaces:\n");
bench_finds(dataset.text, {{sz::whitespaces(), sizeof(sz::whitespaces())}}, find_charset_functions());
bench_rfinds(dataset.text, {{sz::whitespaces(), sizeof(sz::whitespaces())}}, rfind_charset_functions());

Expand Down
2 changes: 1 addition & 1 deletion scripts/bench_similarity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ void bench_similarity_on_bio_data() {

void bench_similarity_on_input_data(int argc, char const **argv) {

dataset_t dataset = make_dataset(argc, argv);
dataset_t dataset = prepare_benchmark_environment(argc, argv);

// Baseline benchmarks for real words, coming in all lengths
std::printf("Benchmarking on real words:\n");
Expand Down
2 changes: 1 addition & 1 deletion scripts/bench_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void bench_permute(char const *name, strings_t &strings, permute_t &permute, alg

int main(int argc, char const **argv) {
std::printf("StringZilla. Starting sorting benchmarks.\n");
dataset_t dataset = make_dataset(argc, argv);
dataset_t dataset = prepare_benchmark_environment(argc, argv);
strings_t strings {dataset.tokens.begin(), dataset.tokens.end()};

permute_t permute_base, permute_new;
Expand Down
2 changes: 1 addition & 1 deletion scripts/bench_token.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ void bench(strings_type &&strings) {
}

void bench_on_input_data(int argc, char const **argv) {
dataset_t dataset = make_dataset(argc, argv);
dataset_t dataset = prepare_benchmark_environment(argc, argv);

std::printf("Benchmarking on the entire dataset:\n");
bench_unary_functions(dataset.tokens, random_generation_functions(100));
Expand Down

0 comments on commit dc243be

Please sign in to comment.