From 593073f7af1718f8773a5e90c86ccabe1aabd82b Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Sun, 11 Jun 2023 15:53:58 -0700 Subject: [PATCH 1/5] allow to pass pq and ivf centroids --- rust/src/index/vector/ivf.rs | 98 ++++++++++++++++++++++++++++-------- rust/src/index/vector/pq.rs | 18 +++++++ 2 files changed, 94 insertions(+), 22 deletions(-) diff --git a/rust/src/index/vector/ivf.rs b/rust/src/index/vector/ivf.rs index bab388579e..c924406976 100644 --- a/rust/src/index/vector/ivf.rs +++ b/rust/src/index/vector/ivf.rs @@ -43,6 +43,7 @@ use super::{ use crate::{ arrow::{linalg::MatrixView, *}, dataset::{Dataset, ROW_ID}, + datatypes::Field, index::{pb, vector::Transformer, Index}, }; use crate::{io::object_reader::ObjectReader, session::Session}; @@ -413,7 +414,7 @@ impl TryFrom<&pb::Ivf> for Ivf { } } -fn sanity_check(dataset: &Dataset, column: &str) -> Result<()> { +fn sanity_check<'a>(dataset: &'a Dataset, column: &str) -> Result<&'a Field> { let Some(field) = dataset.schema().field(column) else { return Err(Error::IO{message:format!( "Building index: column {} does not exist in dataset: {:?}", @@ -435,7 +436,7 @@ fn sanity_check(dataset: &Dataset, column: &str) -> Result<()> { ), }); } - Ok(()) + Ok(field) } /// Parameters to build IVF partitions @@ -447,6 +448,9 @@ pub struct IvfBuildParams { // ---- kmeans parameters /// Max number of iterations to train kmeans. pub max_iters: usize, + + /// Use provided IVF centroids. 
+ pub centroids: Option>, } impl Default for IvfBuildParams { @@ -454,6 +458,7 @@ impl Default for IvfBuildParams { Self { num_partitions: 32, max_iters: 50, + centroids: None, } } } @@ -522,7 +527,17 @@ pub async fn build_ivf_pq_index( metric_type, ); - sanity_check(dataset, column)?; + let field = sanity_check(dataset, column)?; + let dim = if let DataType::FixedSizeList(elem_type, d) = field.data_type() { + d as usize + } else { + return Err(Error::Index { + message: format!( + "VectorIndex requires the column data type to be fixed size list of floats, got {}", + field.data_type() + ), + }); + }; // Maximum to train 256 vectors per centroids, see Faiss. let sample_size_hint = std::cmp::max( @@ -530,30 +545,69 @@ pub async fn build_ivf_pq_index( ProductQuantizer::num_centroids(pq_params.num_bits as u32), ) * 256; - let mut training_data = maybe_sample_training_data(dataset, column, sample_size_hint).await?; - - // Pre-transforms let mut transforms: Vec> = vec![]; - if pq_params.use_opq { - let opq = train_opq(&training_data, pq_params).await?; - transforms.push(Box::new(opq)); - } + // Train IVF partitions. + let ivf_model = if let Some(centroids) = &ivf_params.centroids { + if centroids.len() != ivf_params.num_partitions * dim { + return Err(Error::Index { + message: format!( + "IVF centroids length mismatch: {} != {}", + centroids.len(), + ivf_params.num_partitions + ), + }); + } + Ivf::new(centroids.clone()) + } else { + let mut training_data = + maybe_sample_training_data(dataset, column, sample_size_hint).await?; - // Transform training data if necessary. - for transform in transforms.iter() { - training_data = transform.transform(&training_data).await?; - } + // Pre-transforms + if pq_params.use_opq { + let opq = train_opq(&training_data, pq_params).await?; + transforms.push(Box::new(opq)); + } - // Train IVF partitions. - let ivf_model = train_ivf_model(&training_data, metric_type, ivf_params).await?; + // Transform training data if necessary. 
+ for transform in transforms.iter() { + training_data = transform.transform(&training_data).await?; + } + + train_ivf_model(&training_data, metric_type, ivf_params).await? + }; + + let pq = if let Some(codebook) = &pq_params.codebook { + if codebook.len() != pq_params.num_sub_vectors * dim { + return Err(Error::Index { + message: format!( + "PQ codebook length mismatch: {} != {}", + codebook.len(), + pq_params.num_sub_vectors * dim + ), + }); + } + ProductQuantizer::new_with_codebook( + pq_params.num_sub_vectors, + pq_params.num_bits as u32, + dim, + codebook.clone(), + ) + } else { + let mut training_data = + maybe_sample_training_data(dataset, column, sample_size_hint).await?; + + // Transform training data if necessary. + for transform in transforms.iter() { + training_data = transform.transform(&training_data).await?; + } - // Compute the residual vector for training PQ - let ivf_centroids = ivf_model.centroids.as_ref().try_into()?; - let residual_data = compute_residual_matrix(&training_data, &ivf_centroids, metric_type)?; - let pq_training_data = MatrixView::new(residual_data, training_data.num_columns()); + // Compute the residual vector for training PQ + let ivf_centroids = ivf_model.centroids.as_ref().try_into()?; + let residual_data = compute_residual_matrix(&training_data, &ivf_centroids, metric_type)?; + let pq_training_data = MatrixView::new(residual_data, training_data.num_columns()); - // The final train of PQ sub-vectors - let pq = train_pq(&pq_training_data, pq_params).await?; + train_pq(&pq_training_data, pq_params).await? + }; // Transform data, compute residuals and sort by partition ids. 
let mut scanner = dataset.scan(); diff --git a/rust/src/index/vector/pq.rs b/rust/src/index/vector/pq.rs index d85fd013a3..59ec62df30 100644 --- a/rust/src/index/vector/pq.rs +++ b/rust/src/index/vector/pq.rs @@ -316,6 +316,21 @@ impl ProductQuantizer { } } + pub fn new_with_codebook( + m: usize, + nbits: u32, + dimension: usize, + codebook: Arc, + ) -> Self { + assert!(nbits == 8, "nbits can only be 8"); + Self { + num_bits: nbits, + num_sub_vectors: m, + dimension, + codebook: Some(codebook), + } + } + pub fn num_centroids(num_bits: u32) -> usize { 2_usize.pow(num_bits) } @@ -609,6 +624,8 @@ pub struct PQBuildParams { /// Max number of iterations to train OPQ, if `use_opq` is true. pub max_opq_iters: usize, + + pub codebook: Option>, } impl Default for PQBuildParams { @@ -620,6 +637,7 @@ impl Default for PQBuildParams { use_opq: false, max_iters: 50, max_opq_iters: 50, + codebook: None, } } } From a1b0f82287a9bd2fbbeac6797e2192d24c0353fe Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Sun, 11 Jun 2023 16:09:26 -0700 Subject: [PATCH 2/5] support pre-trained centroids --- rust/src/index/vector/ivf.rs | 17 ++++------------- rust/src/index/vector/pq.rs | 1 + 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/rust/src/index/vector/ivf.rs b/rust/src/index/vector/ivf.rs index c924406976..fab7fc4f1b 100644 --- a/rust/src/index/vector/ivf.rs +++ b/rust/src/index/vector/ivf.rs @@ -544,8 +544,10 @@ pub async fn build_ivf_pq_index( ivf_params.num_partitions, ProductQuantizer::num_centroids(pq_params.num_bits as u32), ) * 256; - + // TODO: only sample data if training is necessary. + let mut training_data = maybe_sample_training_data(dataset, column, sample_size_hint).await?; let mut transforms: Vec> = vec![]; + // Train IVF partitions. 
let ivf_model = if let Some(centroids) = &ivf_params.centroids { if centroids.len() != ivf_params.num_partitions * dim { @@ -559,9 +561,6 @@ pub async fn build_ivf_pq_index( } Ivf::new(centroids.clone()) } else { - let mut training_data = - maybe_sample_training_data(dataset, column, sample_size_hint).await?; - // Pre-transforms if pq_params.use_opq { let opq = train_opq(&training_data, pq_params).await?; @@ -577,7 +576,7 @@ pub async fn build_ivf_pq_index( }; let pq = if let Some(codebook) = &pq_params.codebook { - if codebook.len() != pq_params.num_sub_vectors * dim { + if codebook.len() != pq_params. * dim { return Err(Error::Index { message: format!( "PQ codebook length mismatch: {} != {}", @@ -593,14 +592,6 @@ pub async fn build_ivf_pq_index( codebook.clone(), ) } else { - let mut training_data = - maybe_sample_training_data(dataset, column, sample_size_hint).await?; - - // Transform training data if necessary. - for transform in transforms.iter() { - training_data = transform.transform(&training_data).await?; - } - // Compute the residual vector for training PQ let ivf_centroids = ivf_model.centroids.as_ref().try_into()?; let residual_data = compute_residual_matrix(&training_data, &ivf_centroids, metric_type)?; diff --git a/rust/src/index/vector/pq.rs b/rust/src/index/vector/pq.rs index 59ec62df30..0c243516ca 100644 --- a/rust/src/index/vector/pq.rs +++ b/rust/src/index/vector/pq.rs @@ -316,6 +316,7 @@ impl ProductQuantizer { } } + /// Create a [`ProductQuantizer`] with pre-trained codebook. 
pub fn new_with_codebook( m: usize, nbits: u32, From f3ad58bded76df692d0d1c23c0a1cd6ecfa4d612 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Sun, 11 Jun 2023 16:19:08 -0700 Subject: [PATCH 3/5] fix build --- rust/src/index/vector/ivf.rs | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/rust/src/index/vector/ivf.rs b/rust/src/index/vector/ivf.rs index fab7fc4f1b..7ca34c5b23 100644 --- a/rust/src/index/vector/ivf.rs +++ b/rust/src/index/vector/ivf.rs @@ -528,7 +528,7 @@ pub async fn build_ivf_pq_index( ); let field = sanity_check(dataset, column)?; - let dim = if let DataType::FixedSizeList(elem_type, d) = field.data_type() { + let dim = if let DataType::FixedSizeList(_, d) = field.data_type() { d as usize } else { return Err(Error::Index { @@ -576,15 +576,6 @@ pub async fn build_ivf_pq_index( }; let pq = if let Some(codebook) = &pq_params.codebook { - if codebook.len() != pq_params. * dim { - return Err(Error::Index { - message: format!( - "PQ codebook length mismatch: {} != {}", - codebook.len(), - pq_params.num_sub_vectors * dim - ), - }); - } ProductQuantizer::new_with_codebook( pq_params.num_sub_vectors, pq_params.num_bits as u32, From d0229eda465f9679f62d3e461a0cc97caf7b94f4 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Sun, 11 Jun 2023 16:55:20 -0700 Subject: [PATCH 4/5] factory method to create params with centrods --- rust/src/index/vector.rs | 1 + rust/src/index/vector/ivf.rs | 21 +++++++++++++++++++++ rust/src/index/vector/pq.rs | 24 ++++++++++++++++++++++++ 3 files changed, 46 insertions(+) diff --git a/rust/src/index/vector.rs b/rust/src/index/vector.rs index 9db978d18f..12aedf1879 100644 --- a/rust/src/index/vector.rs +++ b/rust/src/index/vector.rs @@ -214,6 +214,7 @@ impl VectorIndexParams { } } + /// Create index parameters with `IVF` and `PQ` parameters, respectively. 
pub fn with_ivf_pq_params( metric_type: MetricType, ivf: IvfBuildParams, diff --git a/rust/src/index/vector/ivf.rs b/rust/src/index/vector/ivf.rs index 7ca34c5b23..b1d14e587e 100644 --- a/rust/src/index/vector/ivf.rs +++ b/rust/src/index/vector/ivf.rs @@ -471,6 +471,27 @@ impl IvfBuildParams { ..Default::default() } } + + /// Create a new instance of [`IvfBuildParams`] with centroids. + pub fn try_with_centroids( + num_partitions: usize, + centroids: Arc, + ) -> Result { + if num_partitions != centroids.len() { + return Err(Error::Index { + message: format!( + "IvfBuildParams::try_with_centroids: num_partitions {} != centroids.len() {}", + num_partitions, + centroids.len() + ), + }); + } + Ok(Self { + num_partitions, + centroids: Some(centroids), + ..Default::default() + }) + } } /// Compute residual matrix. diff --git a/rust/src/index/vector/pq.rs b/rust/src/index/vector/pq.rs index 0c243516ca..995090d9ea 100644 --- a/rust/src/index/vector/pq.rs +++ b/rust/src/index/vector/pq.rs @@ -626,6 +626,7 @@ pub struct PQBuildParams { /// Max number of iterations to train OPQ, if `use_opq` is true. pub max_opq_iters: usize, + /// User provided codebook. pub codebook: Option>, } @@ -643,6 +644,29 @@ impl Default for PQBuildParams { } } +impl PQBuildParams { + pub fn new(num_sub_vectors: usize, num_bits: usize) -> Self { + Self { + num_sub_vectors, + num_bits, + ..Default::default() + } + } + + pub fn with_codebook( + num_sub_vectors: usize, + num_bits: usize, + codebook: Arc, + ) -> Self { + Self { + num_sub_vectors, + num_bits, + codebook: Some(codebook), + ..Default::default() + } + } +} + /// Train product quantization over (OPQ-rotated) residual vectors. 
pub(crate) async fn train_pq( data: &MatrixView, From 21c5fe2b5a7da04dbfdfc642327209a75fd72d30 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Sun, 11 Jun 2023 21:10:44 -0700 Subject: [PATCH 5/5] add tests --- rust/src/index/vector/ivf.rs | 71 +++++++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 2 deletions(-) diff --git a/rust/src/index/vector/ivf.rs b/rust/src/index/vector/ivf.rs index b1d14e587e..fe3f9d33cc 100644 --- a/rust/src/index/vector/ivf.rs +++ b/rust/src/index/vector/ivf.rs @@ -571,12 +571,12 @@ pub async fn build_ivf_pq_index( // Train IVF partitions. let ivf_model = if let Some(centroids) = &ivf_params.centroids { - if centroids.len() != ivf_params.num_partitions * dim { + if centroids.values().len() != ivf_params.num_partitions * dim { return Err(Error::Index { message: format!( "IVF centroids length mismatch: {} != {}", centroids.len(), - ivf_params.num_partitions + ivf_params.num_partitions * dim, ), }); } @@ -787,3 +787,70 @@ async fn train_ivf_model( data.num_columns() as i32, )?))) } + +#[cfg(test)] +mod tests { + use super::*; + + use arrow_array::cast::AsArray; + use arrow_schema::{DataType, Field, Schema}; + use tempfile::tempdir; + + use crate::{ + index::{vector::VectorIndexParams, DatasetIndexExt, IndexType}, + utils::testing::generate_random_array, + }; + + #[tokio::test] + async fn test_create_ivf_pq_with_centroids() { + const DIM: usize = 32; + let vectors = generate_random_array(1000 * DIM); + + let schema = Arc::new(Schema::new(vec![Field::new( + "vector", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + DIM as i32, + ), + true, + )])); + let array = Arc::new(FixedSizeListArray::try_new(&vectors, DIM as i32).unwrap()); + let batch = RecordBatch::try_new(schema, vec![array.clone()]).unwrap(); + + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + + let mut batches: Box = + Box::new(RecordBatchBuffer::new(vec![batch])); + let dataset = 
Dataset::write(&mut batches, test_uri, None).await.unwrap(); + + let centroids = generate_random_array(2 * DIM); + let ivf_centroids = FixedSizeListArray::try_new(&centroids, DIM as i32).unwrap(); + let ivf_params = IvfBuildParams::try_with_centroids(2, Arc::new(ivf_centroids)).unwrap(); + + let codebook = Arc::new(generate_random_array(256 * DIM)); + let pq_params = PQBuildParams::with_codebook(4, 8, codebook); + + let params = VectorIndexParams::with_ivf_pq_params(MetricType::L2, ivf_params, pq_params); + + let dataset = dataset + .create_index(&["vector"], IndexType::Vector, None, &params, false) + .await + .unwrap(); + + let elem = array.value(10); + let query = elem.as_primitive::(); + let results = dataset + .scan() + .nearest("vector", query, 5) + .unwrap() + .try_into_stream() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + assert_eq!(1, results.len()); + assert_eq!(5, results[0].num_rows()); + } +}