Skip to content

Commit

Permalink
perf: avoid re-alloc on assigning PQ (#3399)
Browse files Browse the repository at this point in the history
fix #2837
fix #2838

Signed-off-by: BubbleCal <[email protected]>
  • Loading branch information
BubbleCal authored Jan 21, 2025
1 parent 2b784b3 commit 7f60aa0
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 21 deletions.
3 changes: 3 additions & 0 deletions rust/lance-index/src/vector/ivf/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use arrow_array::{
cast::AsArray, types::UInt32Type, Array, FixedSizeListArray, RecordBatch, UInt32Array,
};
use arrow_schema::Field;
use lance_table::utils::LanceIteratorExtension;
use snafu::{location, Location};
use tracing::instrument;

Expand Down Expand Up @@ -122,6 +123,8 @@ impl PartitionFilter {
None
}
})
// in most cases, no partition will be filtered out.
.exact_size(partition_ids.len())
.collect()
}
}
Expand Down
3 changes: 3 additions & 0 deletions rust/lance-index/src/vector/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use lance_arrow::*;
use lance_core::{Error, Result};
use lance_linalg::distance::{DistanceType, Dot, L2};
use lance_linalg::kmeans::compute_partition;
use lance_table::utils::LanceIteratorExtension;
use num_traits::Float;
use prost::Message;
use snafu::{location, Location};
Expand Down Expand Up @@ -143,6 +144,7 @@ impl ProductQuantizer {

let flatten_data = fsl.values().as_primitive::<T>();
let sub_dim = dim / num_sub_vectors;
let total_code_length = fsl.len() * num_sub_vectors / (8 / NUM_BITS as usize);
let values = flatten_data
.values()
.chunks_exact(dim)
Expand All @@ -169,6 +171,7 @@ impl ProductQuantizer {
sub_vec_code
}
})
.exact_size(total_code_length)
.collect::<Vec<_>>();

let num_sub_vectors_in_byte = if NUM_BITS == 4 {
Expand Down
8 changes: 6 additions & 2 deletions rust/lance-index/src/vector/residual.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::iter;
use std::ops::{AddAssign, DivAssign};
use std::sync::Arc;

Expand All @@ -15,6 +16,7 @@ use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt};
use lance_core::{Error, Result};
use lance_linalg::distance::{DistanceType, Dot, L2};
use lance_linalg::kmeans::{compute_partitions, KMeansAlgoFloat};
use lance_table::utils::LanceIteratorExtension;
use num_traits::{Float, FromPrimitive, Num};
use snafu::{location, Location};
use tracing::instrument;
Expand Down Expand Up @@ -77,17 +79,19 @@ where
)
.into()
});
let part_ids = part_ids.values();

let vectors_slice = vectors.values();
let centroids_slice = centroids.values();
let residuals = vectors_slice
.chunks_exact(dimension)
.enumerate()
.flat_map(|(idx, vector)| {
let part_id = part_ids.value(idx) as usize;
let part_id = part_ids[idx] as usize;
let c = &centroids_slice[part_id * dimension..(part_id + 1) * dimension];
vector.iter().zip(c.iter()).map(|(v, cent)| *v - *cent)
iter::zip(vector, c).map(|(v, cent)| *v - *cent)
})
.exact_size(vectors.len() * dimension)
.collect::<Vec<_>>();
let residual_arr = PrimitiveArray::<T>::from_iter_values(residuals);
Ok(FixedSizeListArray::try_new_from_values(
Expand Down
33 changes: 14 additions & 19 deletions rust/lance-index/src/vector/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,25 +132,20 @@ impl Transformer for KeepFiniteVectors {
}
};

let valid = data
.iter()
.enumerate()
.filter_map(|(idx, arr)| {
arr.and_then(|data| {
let is_valid = match data.data_type() {
DataType::Float16 => is_all_finite::<Float16Type>(&data),
DataType::Float32 => is_all_finite::<Float32Type>(&data),
DataType::Float64 => is_all_finite::<Float64Type>(&data),
_ => false,
};
if is_valid {
Some(idx as u32)
} else {
None
}
})
})
.collect::<Vec<_>>();
let mut valid = Vec::with_capacity(batch.num_rows());
data.iter().enumerate().for_each(|(idx, arr)| {
if let Some(data) = arr {
let is_valid = match data.data_type() {
DataType::Float16 => is_all_finite::<Float16Type>(&data),
DataType::Float32 => is_all_finite::<Float32Type>(&data),
DataType::Float64 => is_all_finite::<Float64Type>(&data),
_ => false,
};
if is_valid {
valid.push(idx as u32);
}
};
});
if valid.len() < batch.num_rows() {
let indices = UInt32Array::from(valid);
Ok(batch.take(&indices)?)
Expand Down

0 comments on commit 7f60aa0

Please sign in to comment.