diff --git a/rust/lance-index/src/vector/ivf/transform.rs b/rust/lance-index/src/vector/ivf/transform.rs index e3294b5f54..ccd615c4e8 100644 --- a/rust/lance-index/src/vector/ivf/transform.rs +++ b/rust/lance-index/src/vector/ivf/transform.rs @@ -10,6 +10,7 @@ use arrow_array::{ cast::AsArray, types::UInt32Type, Array, FixedSizeListArray, RecordBatch, UInt32Array, }; use arrow_schema::Field; +use lance_table::utils::LanceIteratorExtension; use snafu::{location, Location}; use tracing::instrument; @@ -122,6 +123,8 @@ impl PartitionFilter { None } }) + // in most cases, no partition will be filtered out. + .exact_size(partition_ids.len()) .collect() } } diff --git a/rust/lance-index/src/vector/pq.rs b/rust/lance-index/src/vector/pq.rs index 7e325c1397..9b3a50cd67 100644 --- a/rust/lance-index/src/vector/pq.rs +++ b/rust/lance-index/src/vector/pq.rs @@ -16,6 +16,7 @@ use lance_arrow::*; use lance_core::{Error, Result}; use lance_linalg::distance::{DistanceType, Dot, L2}; use lance_linalg::kmeans::compute_partition; +use lance_table::utils::LanceIteratorExtension; use num_traits::Float; use prost::Message; use snafu::{location, Location}; @@ -143,6 +144,7 @@ impl ProductQuantizer { let flatten_data = fsl.values().as_primitive::(); let sub_dim = dim / num_sub_vectors; + let total_code_length = fsl.len() * num_sub_vectors / (8 / NUM_BITS as usize); let values = flatten_data .values() .chunks_exact(dim) @@ -169,6 +171,7 @@ impl ProductQuantizer { sub_vec_code } }) + .exact_size(total_code_length) .collect::>(); let num_sub_vectors_in_byte = if NUM_BITS == 4 { diff --git a/rust/lance-index/src/vector/residual.rs b/rust/lance-index/src/vector/residual.rs index 90730529b4..5afa168fbb 100644 --- a/rust/lance-index/src/vector/residual.rs +++ b/rust/lance-index/src/vector/residual.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::iter; use std::ops::{AddAssign, DivAssign}; use std::sync::Arc; @@ -15,6 
+16,7 @@ use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt}; use lance_core::{Error, Result}; use lance_linalg::distance::{DistanceType, Dot, L2}; use lance_linalg::kmeans::{compute_partitions, KMeansAlgoFloat}; +use lance_table::utils::LanceIteratorExtension; use num_traits::{Float, FromPrimitive, Num}; use snafu::{location, Location}; use tracing::instrument; @@ -77,6 +79,7 @@ where ) .into() }); + let part_ids = part_ids.values(); let vectors_slice = vectors.values(); let centroids_slice = centroids.values(); @@ -84,10 +87,11 @@ where .chunks_exact(dimension) .enumerate() .flat_map(|(idx, vector)| { - let part_id = part_ids.value(idx) as usize; + let part_id = part_ids[idx] as usize; let c = &centroids_slice[part_id * dimension..(part_id + 1) * dimension]; - vector.iter().zip(c.iter()).map(|(v, cent)| *v - *cent) + iter::zip(vector, c).map(|(v, cent)| *v - *cent) }) + .exact_size(vectors.len() * dimension) .collect::>(); let residual_arr = PrimitiveArray::::from_iter_values(residuals); Ok(FixedSizeListArray::try_new_from_values( diff --git a/rust/lance-index/src/vector/transform.rs b/rust/lance-index/src/vector/transform.rs index c3f5dd46fc..01e1fa4f81 100644 --- a/rust/lance-index/src/vector/transform.rs +++ b/rust/lance-index/src/vector/transform.rs @@ -132,25 +132,20 @@ impl Transformer for KeepFiniteVectors { } }; - let valid = data - .iter() - .enumerate() - .filter_map(|(idx, arr)| { - arr.and_then(|data| { - let is_valid = match data.data_type() { - DataType::Float16 => is_all_finite::(&data), - DataType::Float32 => is_all_finite::(&data), - DataType::Float64 => is_all_finite::(&data), - _ => false, - }; - if is_valid { - Some(idx as u32) - } else { - None - } - }) - }) - .collect::>(); + let mut valid = Vec::with_capacity(batch.num_rows()); + data.iter().enumerate().for_each(|(idx, arr)| { + if let Some(data) = arr { + let is_valid = match data.data_type() { + DataType::Float16 => is_all_finite::(&data), + DataType::Float32 => is_all_finite::(&data), +
DataType::Float64 => is_all_finite::(&data), + _ => false, + }; + if is_valid { + valid.push(idx as u32); + } + }; + }); if valid.len() < batch.num_rows() { let indices = UInt32Array::from(valid); Ok(batch.take(&indices)?)