Skip to content

Commit

Permalink
fix!: low recall with cosine/dot on v3 index types (#3141)
Browse files Browse the repository at this point in the history
  • Loading branch information
BubbleCal authored Nov 20, 2024
1 parent 73cf23b commit 71f323a
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 41 deletions.
3 changes: 2 additions & 1 deletion java/core/lance-jni/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use arrow::array::Float32Array;
use jni::objects::{JMap, JObject, JString};
use jni::JNIEnv;
use lance::dataset::{WriteMode, WriteParams};
use lance::index::vector::{StageParams, VectorIndexParams};
use lance::index::vector::{IndexFileVersion, StageParams, VectorIndexParams};
use lance::io::ObjectStoreParams;
use lance_encoding::version::LanceFileVersion;
use lance_index::vector::hnsw::builder::HnswBuildParams;
Expand Down Expand Up @@ -265,6 +265,7 @@ pub fn get_index_params(
Some(VectorIndexParams {
metric_type: distance_type,
stages,
version: IndexFileVersion::V3,
})
} else {
None
Expand Down
27 changes: 14 additions & 13 deletions rust/lance-index/src/vector/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ use arrow_array::{cast::AsArray, Array, FixedSizeListArray, UInt8Array};
use arrow_array::{ArrayRef, Float32Array, PrimitiveArray};
use arrow_schema::DataType;
use deepsize::DeepSizeOf;
use distance::build_distance_table_dot;
use lance_arrow::*;
use lance_core::{Error, Result};
use lance_linalg::distance::{dot_distance_batch, DistanceType, Dot, L2};
use lance_linalg::distance::{DistanceType, Dot, L2};
use lance_linalg::kmeans::compute_partition;
use num_traits::Float;
use prost::Message;
Expand Down Expand Up @@ -150,6 +151,12 @@ impl ProductQuantizer {
match self.distance_type {
DistanceType::L2 => self.l2_distances(query, code),
DistanceType::Cosine => {
// Cosine distance appears to have been implemented directly in an earlier version,
// but from now on cosine should be computed as L2 distance over normalized vectors.
debug_assert!(
false,
"cosine distance should be converted to normalized L2 distance"
);
// Squared L2 over normalized vectors: ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x, y> = 1 + 1 - 2 * <x, y> = 2 * (1 - <x, y>)
// Cosine distance: 1 - <x, y> / (||x|| * ||y||) = 1 - <x, y> / (1 * 1) = 1 - <x, y>
// Therefore, Cosine = (squared) L2 / 2
Expand Down Expand Up @@ -211,18 +218,12 @@ impl ProductQuantizer {
where
T::Native: Dot,
{
let capacity = self.num_sub_vectors * num_centroids(self.num_bits);
let mut distance_table = Vec::with_capacity(capacity);

let sub_vector_length = self.dimension / self.num_sub_vectors;
key.values()
.chunks_exact(sub_vector_length)
.enumerate()
.for_each(|(sub_vec_id, sub_vec)| {
let subvec_centroids = self.centroids::<T>(sub_vec_id);
let distances = dot_distance_batch(sub_vec, subvec_centroids, sub_vector_length);
distance_table.extend(distances);
});
let distance_table = build_distance_table_dot(
self.codebook.values().as_primitive::<T>().values(),
self.num_bits,
self.num_sub_vectors,
key.values(),
);

let num_vectors = code.len() / self.num_sub_vectors;
let mut distances = vec![0.0; num_vectors];
Expand Down
1 change: 1 addition & 0 deletions rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5285,6 +5285,7 @@ mod test {
..Default::default()
}),
],
version: crate::index::vector::IndexFileVersion::Legacy,
},
false,
)
Expand Down
60 changes: 50 additions & 10 deletions rust/lance/src/index/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,40 @@ pub enum StageParams {
SQ(SQBuildParams),
}

/// The version of the index file.
///
/// `Legacy` is used only for IVF_PQ indices, where it is the default choice.
/// All other index types use `V3`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexFileVersion {
    /// The legacy index file version (IVF_PQ only).
    Legacy,
    /// The V3 index file version.
    V3,
}

/// The parameters to build a vector index.
#[derive(Debug, Clone)]
pub struct VectorIndexParams {
/// Build parameters for each stage of the index.
pub stages: Vec<StageParams>,

/// Vector distance metric type.
pub metric_type: MetricType,

/// The version of the index file to build; see [`IndexFileVersion`].
pub version: IndexFileVersion,
}

impl VectorIndexParams {
/// Sets the index file version in place and returns `&mut Self` so that
/// further calls can be chained.
pub fn version(&mut self, version: IndexFileVersion) -> &mut Self {
self.version = version;
self
}

/// Creates parameters for an IVF_FLAT index.
///
/// The result has a single IVF stage built from `num_partitions`, uses the
/// given `metric_type`, and selects the `V3` index file version.
pub fn ivf_flat(num_partitions: usize, metric_type: MetricType) -> Self {
    Self {
        stages: vec![StageParams::Ivf(IvfBuildParams::new(num_partitions))],
        metric_type,
        version: IndexFileVersion::V3,
    }
}

Expand Down Expand Up @@ -106,6 +124,7 @@ impl VectorIndexParams {
Self {
stages,
metric_type,
version: IndexFileVersion::Legacy,
}
}

Expand All @@ -119,6 +138,7 @@ impl VectorIndexParams {
Self {
stages,
metric_type,
version: IndexFileVersion::Legacy,
}
}

Expand All @@ -138,6 +158,7 @@ impl VectorIndexParams {
Self {
stages,
metric_type,
version: IndexFileVersion::V3,
}
}

Expand All @@ -157,6 +178,7 @@ impl VectorIndexParams {
Self {
stages,
metric_type,
version: IndexFileVersion::V3,
}
}
}
Expand Down Expand Up @@ -252,16 +274,34 @@ pub(crate) async fn build_vector_index(
});
};

build_ivf_pq_index(
dataset,
column,
name,
uuid,
params.metric_type,
ivf_params,
pq_params,
)
.await?;
match params.version {
IndexFileVersion::Legacy => {
build_ivf_pq_index(
dataset,
column,
name,
uuid,
params.metric_type,
ivf_params,
pq_params,
)
.await?;
}
IndexFileVersion::V3 => {
IvfIndexBuilder::<FlatIndex, ProductQuantizer>::new(
dataset.clone(),
column.to_owned(),
dataset.indices_dir().child(uuid),
params.metric_type,
Box::new(shuffler),
Some(ivf_params.clone()),
Some(pq_params.clone()),
(),
)?
.build()
.await?;
}
}
} else if is_ivf_hnsw(stages) {
let len = stages.len();
let StageParams::Hnsw(hnsw_params) = &stages[1] else {
Expand Down
16 changes: 2 additions & 14 deletions rust/lance/src/index/vector/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + Clone + 'static> IvfIndexBuilde

info!("Start to train quantizer");
let start = std::time::Instant::now();
let quantizer = Q::build(&training_data, dt, quantizer_params)?;
let quantizer = Q::build(&training_data, DistanceType::L2, quantizer_params)?;
info!(
"Trained quantizer in {:02} seconds",
start.elapsed().as_secs_f32()
Expand Down Expand Up @@ -393,19 +393,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + Clone + 'static> IvfIndexBuilde
if num_rows == 0 {
continue;
}
let mut batch = arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?;
if self.distance_type == DistanceType::Cosine {
let vectors = batch
.column_by_name(&self.column)
.ok_or(Error::invalid_input(
format!("column {} not found", self.column).as_str(),
location!(),
))?
.as_fixed_size_list();
let vectors = lance_linalg::kernels::normalize_fsl(vectors)?;
batch = batch.replace_column_by_name(&self.column, Arc::new(vectors))?;
}

let batch = arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?;
let sizes = self.build_partition(partition, &batch).await?;
partition_sizes[partition] = sizes;
log::info!(
Expand Down
24 changes: 21 additions & 3 deletions rust/lance/src/index/vector/ivf/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ mod tests {
&vectors,
query.as_primitive::<Float32Type>().values(),
k,
DistanceType::L2,
params.metric_type,
);
let gt_set = gt.iter().map(|r| r.1).collect::<HashSet<_>>();

Expand Down Expand Up @@ -673,6 +673,24 @@ mod tests {
test_index(params, nlist, recall_requirement).await;
}

// Builds an IVF_PQ index with the V3 index file version for each distance
// type and hands the recall requirement to `test_index`.
#[rstest]
#[case(4, DistanceType::L2, 0.9)]
#[case(4, DistanceType::Cosine, 0.9)]
#[case(4, DistanceType::Dot, 0.9)]
#[tokio::test]
async fn test_build_ivf_pq_v3(
    #[case] nlist: usize,
    #[case] distance_type: DistanceType,
    #[case] recall_requirement: f32,
) {
    let ivf_params = IvfBuildParams::new(nlist);
    let pq_params = PQBuildParams::default();
    // Bind the params mutably and override the version chosen by the
    // constructor; this avoids cloning a temporary just to get an owned value
    // back out of the `&mut Self` setter.
    let mut params = VectorIndexParams::with_ivf_pq_params(distance_type, ivf_params, pq_params);
    params.version(crate::index::vector::IndexFileVersion::V3);
    test_index(params, nlist, recall_requirement).await;
}

#[rstest]
#[case(4, DistanceType::L2, 0.9)]
#[case(4, DistanceType::Cosine, 0.9)]
Expand All @@ -697,8 +715,8 @@ mod tests {

#[rstest]
#[case(4, DistanceType::L2, 0.9)]
#[case(4, DistanceType::Cosine, 0.6)]
#[case(4, DistanceType::Dot, 0.2)]
#[case(4, DistanceType::Cosine, 0.9)]
#[case(4, DistanceType::Dot, 0.9)]
#[tokio::test]
async fn test_create_ivf_hnsw_pq(
#[case] nlist: usize,
Expand Down

0 comments on commit 71f323a

Please sign in to comment.