Skip to content

Commit

Permalink
[Rust] Benchmark for dot product (#1142)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu authored Aug 17, 2023
1 parent 0a12b97 commit da41e49
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 25 deletions.
4 changes: 4 additions & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ required-features = ["cli"]
name = "l2"
harness = false

[[bench]]
name = "dot"
harness = false

[[bench]]
name = "scan"
harness = false
Expand Down
9 changes: 5 additions & 4 deletions rust/benches/cosine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use arrow_array::Float32Array;
use criterion::{criterion_group, criterion_main, Criterion};

use lance::linalg::cosine::cosine_distance_batch;
Expand All @@ -24,19 +25,19 @@ fn bench_distance(c: &mut Criterion) {
const DIMENSION: usize = 1024;
const TOTAL: usize = 1024 * 1024; // 1M vectors

let key = generate_random_array_with_seed(DIMENSION, [0; 32]);
let key: Float32Array = generate_random_array_with_seed(DIMENSION, [0; 32]);
// 1M of 1024 D vectors. 4GB in memory.
let target = generate_random_array_with_seed(TOTAL * DIMENSION, [42; 32]);
let target: Float32Array = generate_random_array_with_seed(TOTAL * DIMENSION, [42; 32]);

c.bench_function("Cosine(simd)", |b| {
b.iter(|| {
cosine_distance_batch(key.values(), target.values(), DIMENSION);
})
});

let key = generate_random_array_with_seed(DIMENSION, [5; 32]);
let key: Float32Array = generate_random_array_with_seed(DIMENSION, [5; 32]);
// 1M of 1024 D vectors. 4GB in memory.
let target = generate_random_array_with_seed(TOTAL * DIMENSION, [7; 32]);
let target: Float32Array = generate_random_array_with_seed(TOTAL * DIMENSION, [7; 32]);

c.bench_function("Cosine(simd) second rng seed", |b| {
b.iter(|| {
Expand Down
121 changes: 121 additions & 0 deletions rust/benches/dot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::iter::Sum;

use arrow_arith::{aggregate::sum, arithmetic::multiply};
use arrow_array::{
types::{Float16Type, Float32Type, Float64Type},
ArrowNumericType, Float16Array, Float32Array, NativeAdapter, PrimitiveArray,
};
use criterion::{criterion_group, criterion_main, Criterion};
use num_traits::{real::Real, FromPrimitive};

#[cfg(target_os = "linux")]
use pprof::criterion::{Output, PProfProfiler};

use lance::linalg::dot::{dot, Dot};
use lance::utils::testing::generate_random_array_with_seed;

/// Dot product of two Arrow primitive arrays, composed from Arrow compute kernels.
///
/// The element-wise product of `x` and `y` is materialized with `multiply` and
/// then reduced with `sum`. This serves as the baseline the hand-written dot
/// products are benchmarked against.
///
/// # Panics
///
/// Panics if `multiply` returns an error or if `sum` yields no value.
#[inline]
fn dot_arrow_artiy<T: ArrowNumericType>(x: &PrimitiveArray<T>, y: &PrimitiveArray<T>) -> T::Native {
    let elementwise = multiply(x, y).unwrap();
    let total = sum(&elementwise);
    total.unwrap()
}

/// Registers the type-generic dot product benchmarks for the Arrow numeric type `T`.
///
/// Two benchmarks are registered, both computing one dot product of `key` against
/// every `DIMENSION`-wide window of `target` and collecting the results into a
/// `PrimitiveArray<T>`:
///
/// * `Dot(<type>, arrow_artiy)` goes through the Arrow kernels (`dot_arrow_artiy`).
/// * `Dot(<type>)` calls the generic `dot` directly on the raw value slices.
///
/// `target` holds `TOTAL * DIMENSION` = 2^30 elements, so its size in memory is
/// 2^30 times the size of `T::Native`.
fn run_bench<T: ArrowNumericType>(c: &mut Criterion)
where
T::Native: Real + FromPrimitive + Sum,
NativeAdapter<T>: From<T::Native>,
{
const DIMENSION: usize = 1024;
const TOTAL: usize = 1024 * 1024; // 1M vectors

// Fixed seeds keep the generated inputs identical from run to run.
let key: PrimitiveArray<T> = generate_random_array_with_seed(DIMENSION, [0; 32]);
// 1M of 1024 D vectors
let target: PrimitiveArray<T> = generate_random_array_with_seed(TOTAL * DIMENSION, [42; 32]);

// Name of the native element type, used to label the benchmarks.
let type_name = std::any::type_name::<T::Native>();

c.bench_function(format!("Dot({type_name}, arrow_artiy)").as_str(), |b| {
// SAFETY: `from_trusted_len_iter` relies on the iterator reporting its exact
// length; a `map` over a `Range<usize>` does.
b.iter(|| unsafe {
PrimitiveArray::<T>::from_trusted_len_iter((0..target.len() / DIMENSION).map(|idx| {
// The `idx`-th vector of `target`, as an Arrow array.
let arr = target.slice(idx * DIMENSION, DIMENSION);
Some(dot_arrow_artiy(&key, &arr))
}))
});
});

c.bench_function(format!("Dot({type_name})").as_str(), |b| {
let x = key.values();
// SAFETY: same reasoning as above, the mapped range has an exact length.
b.iter(|| unsafe {
PrimitiveArray::<T>::from_trusted_len_iter((0..target.len() / DIMENSION).map(|idx| {
// The `idx`-th vector of `target`, as a plain slice of native values.
let y = target.values()[idx * DIMENSION..(idx + 1) * DIMENSION].as_ref();
Some(dot(x, y))
}))
});
});

// TODO: SIMD needs generic specialization
}

/// Criterion entry point: runs the dot product benchmarks for f16, f32 and f64.
///
/// For each of `Float16Type`, `Float32Type` and `Float64Type` the generic
/// benchmarks from `run_bench` are registered. For f16 and f32 an additional
/// benchmark calls the `Dot` trait method (`x.dot(y)`) on the raw value slices.
///
/// The inputs of the extra benchmarks are generated inside the `bench_function`
/// closure but outside `b.iter`, with the same seeds and sizes as in `run_bench`.
fn bench_distance(c: &mut Criterion) {
const DIMENSION: usize = 1024;
const TOTAL: usize = 1024 * 1024; // 1M vectors

run_bench::<Float16Type>(c);
// NOTE(review): the label says "SIMD", but this only shows a call to `Dot::dot`;
// whether the f16 implementation is vectorized is not visible here — confirm.
c.bench_function("Dot(f16, SIMD)", |b| {
let key: Float16Array = generate_random_array_with_seed(DIMENSION, [0; 32]);
// 1M of 1024 D vectors
let target: Float16Array = generate_random_array_with_seed(TOTAL * DIMENSION, [42; 32]);
// SAFETY: `from_trusted_len_iter` relies on the iterator reporting its exact
// length; a `map` over a `Range<usize>` does.
b.iter(|| unsafe {
let x = key.values().as_ref();
Float16Array::from_trusted_len_iter((0..target.len() / DIMENSION).map(|idx| {
// The `idx`-th vector of `target`.
let y = target.values()[idx * DIMENSION..(idx + 1) * DIMENSION].as_ref();
Some(x.dot(y))
}))
});
});
run_bench::<Float32Type>(c);

// Exercises `Dot for [f32]`, i.e. the implementation with runtime SIMD dispatch.
c.bench_function("Dot(f32, SIMD)", |b| {
let key: Float32Array = generate_random_array_with_seed(DIMENSION, [0; 32]);
// 1M of 1024 D vectors
let target: Float32Array = generate_random_array_with_seed(TOTAL * DIMENSION, [42; 32]);
// SAFETY: same reasoning as above, the mapped range has an exact length.
b.iter(|| unsafe {
let x = key.values().as_ref();
Float32Array::from_trusted_len_iter((0..target.len() / DIMENSION).map(|idx| {
// The `idx`-th vector of `target`.
let y = target.values()[idx * DIMENSION..(idx + 1) * DIMENSION].as_ref();
Some(x.dot(y))
}))
});
});
// No trait-based benchmark is registered for f64; only the generic ones run.
run_bench::<Float64Type>(c);
}

// Benchmark group `benches`, Linux flavour: significance level 0.1, 10 samples
// per benchmark, and a pprof profiler configured to emit a flamegraph.
// NOTE(review): the `100` passed to `PProfProfiler::new` is presumably the
// sampling frequency — confirm against the pprof crate documentation.
#[cfg(target_os = "linux")]
criterion_group!(
name=benches;
config = Criterion::default().significance_level(0.1).sample_size(10)
.with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
targets = bench_distance);

// Non-linux version does not support pprof.
// Same significance level and sample size as above, without the profiler.
#[cfg(not(target_os = "linux"))]
criterion_group!(
name=benches;
config = Criterion::default().significance_level(0.1).sample_size(10);
targets = bench_distance);

// Generates the benchmark binary's `main`, which runs the `benches` group.
criterion_main!(benches);
61 changes: 61 additions & 0 deletions rust/src/linalg/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ use arrow_array::Float32Array;
use half::{bf16, f16};
use num_traits::real::Real;

/// Naive implementation of the dot product.
///
/// Walks `from` and `to` in lock-step, multiplies each pair of elements and
/// adds up the products. Pairing stops at the end of the shorter slice, so any
/// surplus elements of the longer one are ignored.
#[inline]
pub fn dot<T: Real + Sum>(from: &[T], to: &[T]) -> T {
    from.iter()
        .zip(to)
        .map(|(&lhs, &rhs)| lhs * rhs)
        .sum()
}

/// Dot product
pub trait Dot {
type Output;

Expand Down Expand Up @@ -53,6 +55,13 @@ impl Dot for [f32] {
type Output = f32;

/// Dot product of two `f32` slices.
///
/// On x86_64 the AVX/FMA kernel is used when the running CPU supports it;
/// otherwise (and on every other architecture) the naive `dot` is used.
fn dot(&self, other: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        // The kernel executes AVX instructions (256-bit loads, zeroing) as well
        // as FMA ones, so both features must be present before dispatching.
        if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
            return x86_64::avx::dot_f32(self, other);
        }
    }

    // Portable fallback.
    dot(self, other)
}
}
Expand All @@ -78,3 +87,55 @@ pub fn dot_distance_batch(from: &[f32], to: &[f32], dimension: usize) -> Arc<Flo
pub fn dot_distance(from: &[f32], to: &[f32]) -> f32 {
from.dot(to)
}

/// x86_64-specific dot product kernels.
#[cfg(target_arch = "x86_64")]
mod x86_64 {

    pub mod avx {
        use crate::linalg::x86_64::avx::*;
        use std::arch::x86_64::*;

        /// Dot product of two `f32` slices using 256-bit loads and fused multiply-add.
        ///
        /// Full groups of 8 lanes are accumulated in a vector register and reduced
        /// with `add_f32_register`; the remaining (fewer than 8) elements are
        /// handled with scalar code. Like the naive `dot`, only the common prefix
        /// of `x` and `y` is used when their lengths differ.
        ///
        /// The caller must have verified that the CPU supports the AVX and FMA
        /// instructions this function executes.
        #[inline]
        pub fn dot_f32(x: &[f32], y: &[f32]) -> f32 {
            // Clamp to the common length so that a shorter `y` cannot cause an
            // out-of-range slice when the tail is taken.
            let n = x.len().min(y.len());
            let x_chunks = x[..n].chunks_exact(8);
            let y_chunks = y[..n].chunks_exact(8);

            // Scalar part: the elements left over after the last full 8-lane chunk.
            let tail = x_chunks
                .remainder()
                .iter()
                .zip(y_chunks.remainder())
                .map(|(a, b)| a * b)
                .sum::<f32>();

            // SAFETY: every slice produced by `chunks_exact(8)` holds exactly 8
            // `f32` values, so each unaligned 256-bit load stays in bounds. CPU
            // support for the instructions is the caller's responsibility.
            let vectorized = unsafe {
                let mut sums = _mm256_setzero_ps();
                for (a, b) in x_chunks.zip(y_chunks) {
                    let lhs = _mm256_loadu_ps(a.as_ptr());
                    let rhs = _mm256_loadu_ps(b.as_ptr());
                    sums = _mm256_fmadd_ps(lhs, rhs, sums);
                }
                add_f32_register(sums)
            };

            vectorized + tail
        }
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use num_traits::FromPrimitive;

    /// The `Dot` trait method must agree with the naive `dot` for f32, f16 and
    /// f64 inputs of 20 elements each.
    #[test]
    fn test_dot() {
        // f32: 20 elements do not divide evenly into groups of 8.
        let lhs = (0..20).map(|i| i as f32).collect::<Vec<f32>>();
        let rhs = (100..120).map(|i| i as f32).collect::<Vec<f32>>();
        assert_eq!(lhs.dot(&rhs), dot(&lhs, &rhs));

        // f16
        let lhs = (0..20)
            .map(|i| f16::from_i32(i).unwrap())
            .collect::<Vec<f16>>();
        let rhs = (100..120)
            .map(|i| f16::from_i32(i).unwrap())
            .collect::<Vec<f16>>();
        assert_eq!(lhs.dot(&rhs), dot(&lhs, &rhs));

        // f64
        let lhs = (20..40)
            .map(|i| f64::from_i32(i).unwrap())
            .collect::<Vec<f64>>();
        let rhs = (120..140)
            .map(|i| f64::from_i32(i).unwrap())
            .collect::<Vec<f64>>();
        assert_eq!(lhs.dot(&rhs), dot(&lhs, &rhs));
    }
}
46 changes: 25 additions & 21 deletions rust/src/utils/testing.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,40 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
// Copyright 2023 Lance Developers.
//
// http://www.apache.org/licenses/LICENSE-2.0
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Testing utilities
use num_traits::real::Real;
use num_traits::FromPrimitive;

use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::iter::repeat_with;

use arrow_array::Float32Array;
use arrow_array::{ArrowNumericType, Float32Array, NativeAdapter, PrimitiveArray};

/// Create a random array of Arrow numeric type `T` from a fixed seed; values are drawn as `f32` and converted to `T::Native`.
pub fn generate_random_array_with_seed(n: usize, seed: [u8; 32]) -> Float32Array {
pub fn generate_random_array_with_seed<T: ArrowNumericType>(
n: usize,
seed: [u8; 32],
) -> PrimitiveArray<T>
where
T::Native: Real + FromPrimitive,
NativeAdapter<T>: From<T::Native>,
{
let mut rng = StdRng::from_seed(seed);
Float32Array::from(
repeat_with(|| rng.gen::<f32>())
.take(n)
.collect::<Vec<f32>>(),
)

PrimitiveArray::<T>::from_iter(repeat_with(|| T::Native::from_f32(rng.gen::<f32>())).take(n))
}

/// Create a random float32 array.
Expand Down

0 comments on commit da41e49

Please sign in to comment.