From 4ada624122d19cd3ec58d9eadb3fd42f3b74c744 Mon Sep 17 00:00:00 2001 From: rmeng Date: Wed, 31 May 2023 11:25:06 -0400 Subject: [PATCH] argmin benchmark --- rust/Cargo.toml | 4 +++ rust/benches/argmin.rs | 68 +++++++++++++++++++++++++++++++++++++++ rust/src/utils/testing.rs | 13 +++++++- 3 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 rust/benches/argmin.rs diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 9274298d0c..7b566542a4 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -112,6 +112,10 @@ harness = false name = "kmeans" harness = false +[[bench]] +name = "argmin" +harness = false + [profile.release] strip = true opt-level = "s" diff --git a/rust/benches/argmin.rs b/rust/benches/argmin.rs new file mode 100644 index 0000000000..35f2286b3c --- /dev/null +++ b/rust/benches/argmin.rs @@ -0,0 +1,68 @@ +// Copyright 2023 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{sync::Arc, time::Duration}; + +use arrow_array::{Float32Array, UInt32Array}; +use criterion::{criterion_group, criterion_main, Criterion}; + +use lance::utils::testing::generate_random_array_with_seed; +#[cfg(target_os = "linux")] +use pprof::criterion::{Output, PProfProfiler}; + +use lance::arrow::argmin; + +#[inline] +fn argmin_arrow(x: &Float32Array) -> u32 { + argmin(x).unwrap() +} + +fn argmin_arrow_batch(x: &Float32Array, dimension: usize) -> Arc { + assert_eq!(x.len() % dimension, 0); + + let idxs = unsafe { + UInt32Array::from_trusted_len_iter( + (0..x.len()) + .step_by(dimension) + .into_iter() + .map(|start| Some(argmin_arrow(&x.slice(start, dimension)))), + ) + }; + Arc::new(idxs) +} + +fn bench_argmin(c: &mut Criterion) { + const DIMENSION: usize = 1024 * 8; + const TOTAL: usize = 1024; + const SEED: [u8; 32] = [42; 32]; + + let target = generate_random_array_with_seed(TOTAL * DIMENSION, SEED); + + c.bench_function("argmin(arrow)", |b| { + b.iter(|| { + argmin_arrow_batch(&target, DIMENSION); + }) + }); +} + +#[cfg(target_os = "linux")] +criterion_group!( + name=benches; + config = Criterion::default() + .measurement_time(Duration::from_secs(10)) + .sample_size(32) + .with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); + targets = bench_argmin); + +criterion_main!(benches); diff --git a/rust/src/utils/testing.rs b/rust/src/utils/testing.rs index 982eb5ab3c..68dbfd1e9a 100644 --- a/rust/src/utils/testing.rs +++ b/rust/src/utils/testing.rs @@ -17,11 +17,22 @@ //! Testing utilities -use rand::Rng; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; use std::iter::repeat_with; use arrow_array::Float32Array; +/// Create a random float32 array. +pub fn generate_random_array_with_seed(n: usize, seed: [u8; 32]) -> Float32Array { + let mut rng = StdRng::from_seed(seed); + Float32Array::from( + repeat_with(|| rng.gen::()) + .take(n) + .collect::>(), + ) +} + /// Create a random float32 array. pub fn generate_random_array(n: usize) -> Float32Array { let mut rng = rand::thread_rng();