Patch: ProductQuantizer::fit_transform takes a &MatrixView instead of a
&FixedSizeListArray; its callers in ivf.rs and opq.rs are updated to match.

Review fixes folded into this patch:
  * opq.rs: keep `let dim = train.num_columns();` as context. `dim` is still
    passed to ProductQuantizer::new, so deleting the binding leaves the name
    undefined. Hunk header adjusted to +76,9.
  * opq.rs: pass `train` (already a `&MatrixView`) instead of `&train`.
  * ivf.rs: the residual column was previously cast with
    as_fixed_size_list_array, so it is a fixed-size list; take its flat
    values before downcasting to Float32Array. Hunk header adjusted to +790,12.

diff --git a/rust/src/index/vector/ivf.rs b/rust/src/index/vector/ivf.rs index 8343ca2d17..5279845d2b 100644 --- a/rust/src/index/vector/ivf.rs +++ b/rust/src/index/vector/ivf.rs @@ -790,9 +790,12 @@ impl IndexBuilder for IvfPqIndexBuilder { ProductQuantizer::new(self.num_sub_vectors as usize, self.nbits, self.dimension); let batch = concat_batches(&partitioned_batches[0].schema(), &partitioned_batches)?; let residual_vector = batch.column_by_name(RESIDUAL_COLUMN).unwrap(); + let residual_values = as_fixed_size_list_array(residual_vector).values(); + let data: &Float32Array = as_primitive_array(&residual_values); + let resid_mat = MatrixView::new(Arc::new(data.clone()), self.dimension); let pq_code = pq - .fit_transform(as_fixed_size_list_array(residual_vector), self.metric_type) + .fit_transform(&resid_mat, self.metric_type) .await?; const PQ_CODE_COLUMN: &str = "__pq_code"; diff --git a/rust/src/index/vector/opq.rs b/rust/src/index/vector/opq.rs index 0bea944a2d..b1ff8ed0f3 100644 --- a/rust/src/index/vector/opq.rs +++ b/rust/src/index/vector/opq.rs @@ -76,11 +76,9 @@ impl OptimizedProductQuantizer { train: &MatrixView, metric_type: MetricType, ) -> Result<(MatrixView, ProductQuantizer)> { let dim = train.num_columns(); - // TODO: make PQ::fit_transform work with MatrixView. 
- let fixed_list = FixedSizeListArray::try_new(train.data().as_ref(), dim as i32)?; let mut pq = ProductQuantizer::new(self.num_sub_vectors, self.num_bits, dim); - let pq_code = pq.fit_transform(&fixed_list, metric_type).await?; + let pq_code = pq.fit_transform(train, metric_type).await?; // Reconstruct Y let mut builder = Float32Builder::with_capacity(train.num_columns() * train.num_rows()); diff --git a/rust/src/index/vector/pq.rs b/rust/src/index/vector/pq.rs index 6f7ac5b2d5..2438ecc4c8 100644 --- a/rust/src/index/vector/pq.rs +++ b/rust/src/index/vector/pq.rs @@ -26,6 +26,7 @@ use futures::stream::{self, StreamExt, TryStreamExt}; use rand::SeedableRng; use crate::arrow::*; +use crate::arrow::linalg::MatrixView; use crate::index::pb; use crate::index::vector::kmeans::train_kmeans; use crate::io::object_reader::{read_fixed_stride_array, ObjectReader}; @@ -400,12 +401,13 @@ impl ProductQuantizer { /// Train a [ProductQuantizer] using an array of vectors. pub async fn fit_transform( &mut self, - data: &FixedSizeListArray, + mat: &MatrixView, metric_type: MetricType, ) -> Result { - self.train(data, metric_type, 50).await?; + let data = FixedSizeListArray::try_new(mat.data().as_ref(), mat.num_columns() as i32)?; + self.train(&data, metric_type, 50).await?; - let sub_vectors = divide_to_subvectors(data, self.num_sub_vectors as i32); + let sub_vectors = divide_to_subvectors(&data, self.num_sub_vectors as i32); self.transform(&sub_vectors, metric_type).await } } @@ -430,7 +432,6 @@ impl From<&ProductQuantizer> for pb::Pq { num_sub_vectors: pq.num_sub_vectors as u32, dimension: pq.dimension as u32, codebook: pq.codebook.as_ref().unwrap().values().to_vec(), - opq: None, } } }