From 7cc43a7da96d7ce27040742d4bf1563148d461db Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 15 Jan 2025 21:06:10 +0800 Subject: [PATCH] feat: support float16/float64 for multivector (#3387) Signed-off-by: BubbleCal --- rust/lance-linalg/src/distance.rs | 77 +++++++++++++++++++++++-------- 1 file changed, 59 insertions(+), 18 deletions(-) diff --git a/rust/lance-linalg/src/distance.rs b/rust/lance-linalg/src/distance.rs index 90f4486676..a5575ec061 100644 --- a/rust/lance-linalg/src/distance.rs +++ b/rust/lance-linalg/src/distance.rs @@ -12,8 +12,8 @@ use std::sync::Arc; use arrow_array::cast::AsArray; -use arrow_array::types::{Float32Type, UInt8Type}; -use arrow_array::{Array, FixedSizeListArray, Float32Array, ListArray}; +use arrow_array::types::{Float16Type, Float32Type, Float64Type, UInt8Type}; +use arrow_array::{Array, ArrowPrimitiveType, FixedSizeListArray, Float32Array, ListArray}; use arrow_schema::{ArrowError, DataType}; pub mod cosine; @@ -117,6 +117,17 @@ pub fn multivec_distance( )); }; + // check the query vectors type first + // because we don't want to check the vectors type for each vector + match query.data_type() { + DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::UInt8 => {} + _ => { + return Err(ArrowError::InvalidArgumentError( + "query must be a float array or binary array".to_string(), + )); + } + } + let dists = vectors .iter() .map(|v| { @@ -139,22 +150,27 @@ pub fn multivec_distance( }) .sum() } - _ => { - let query = query.as_primitive::().values(); - query - .chunks_exact(dim) - .map(|q| { - multivector - .values() - .as_primitive::() - .values() - .chunks_exact(dim) - .map(|v| distance_type.func()(q, v)) - .min_by(|a, b| a.partial_cmp(b).unwrap()) - .unwrap() - }) - .sum() - } + _ => match query.data_type() { + DataType::Float16 => multivec_distance_impl::( + query, + multivector, + dim, + distance_type, + ), + DataType::Float32 => multivec_distance_impl::( + query, + multivector, + dim, + distance_type, + ), + DataType::Float64 => multivec_distance_impl::( + query, + multivector, + dim, + distance_type, + ), + _ => unreachable!("missed to check query type"), + }, } }) .unwrap_or(f32::NAN) @@ -162,3 +178,28 @@ pub fn multivec_distance( .collect(); Ok(dists) } + +fn multivec_distance_impl( + query: &dyn Array, + multivector: &FixedSizeListArray, + dim: usize, + distance_type: DistanceType, +) -> f32 +where + T::Native: L2 + Cosine + Dot, +{ + let query = query.as_primitive::().values(); + query + .chunks_exact(dim) + .map(|q| { + multivector + .values() + .as_primitive::() + .values() + .chunks_exact(dim) + .map(|v| distance_type.func()(q, v)) + .min_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap() + }) + .sum() +}