Improve indexing performance (#699)
eddyxu authored Mar 18, 2023
1 parent 4a688f6 commit 32d343e
Showing 2 changed files with 43 additions and 47 deletions.
18 changes: 14 additions & 4 deletions rust/src/index/vector/ivf.rs
@@ -322,7 +322,7 @@ impl TryFrom<&pb::Index> for IvfPQIndexMetadata {
 }
 
 /// Ivf Model
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Ivf {
     /// Centroids of each partition.
     ///
@@ -649,7 +649,11 @@ pub async fn build_ivf_pq_index(
                 .column_by_name(column)
                 .ok_or_else(|| Error::IO(format!("Dataset does not have column {column}")))?;
             let vectors: MatrixView = as_fixed_size_list_array(arr).try_into()?;
-            let part_id_and_residual = ivf.compute_partition_and_residual(&vectors, metric_type)?;
+            let i = ivf.clone();
+            let part_id_and_residual = tokio::task::spawn_blocking(move || {
+                i.compute_partition_and_residual(&vectors, metric_type)
+            })
+            .await??;
 
             let residual_col = part_id_and_residual
                 .column_by_name(RESIDUAL_COLUMN)
@@ -659,7 +663,10 @@
                 .transform(&residual_data.try_into()?, metric_type)
                 .await?;
 
-            let row_ids = batch.column_by_name(ROW_ID).expect("Expect row id").clone();
+            let row_ids = batch
+                .column_by_name(ROW_ID)
+                .expect("Expect row id column")
+                .clone();
             let part_ids = part_id_and_residual
                 .column_by_name(PARTITION_ID_COLUMN)
                 .expect("Expect partition ids column")
@@ -677,7 +684,10 @@
                     false,
                 ),
             ]));
-            RecordBatch::try_new(schema.clone(), vec![row_ids, part_ids, Arc::new(pq_code)])
+            Ok::<RecordBatch, Error>(RecordBatch::try_new(
+                schema.clone(),
+                vec![row_ids, part_ids, Arc::new(pq_code)],
+            )?)
         })
         .buffered(num_cpus::get())
         .try_collect::<Vec<_>>()
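
The ivf.rs half of the change is a standard Tokio pattern: `compute_partition_and_residual` is pure CPU work, and running it inline pins an async executor thread for the whole computation. Moving it into `tokio::task::spawn_blocking` hands it to the blocking thread pool, which is also why `Ivf` gains the `Clone` derive above: the `move` closure must own its own handle (likely a cheap clone, since Arrow arrays are reference-counted internally). Below is a minimal, self-contained sketch of the same pattern; `Model` and `assign` are hypothetical stand-ins for `Ivf` and its method, not Lance APIs.

```rust
use std::sync::Arc;

// Hypothetical stand-in for `Ivf`: cheap to clone because the heavy state
// sits behind an `Arc`, mirroring the diff's `ivf.clone()`.
#[derive(Clone)]
struct Model {
    centroids: Arc<Vec<f32>>,
}

impl Model {
    // Stand-in for `compute_partition_and_residual`: synchronous, CPU-bound.
    fn assign(&self, vectors: &[f32]) -> Vec<u32> {
        vectors.chunks(self.centroids.len()).map(|_| 0).collect()
    }
}

#[tokio::main]
async fn main() -> Result<(), tokio::task::JoinError> {
    let model = Model { centroids: Arc::new(vec![0.0; 8]) };
    let vectors = vec![0.0f32; 64];

    // Clone a handle and move it into the blocking pool, as the diff does
    // with `ivf.clone()`. `.await` yields a Result<_, JoinError>; the diff
    // needs `.await??` because its closure returns a Result of its own.
    let m = model.clone();
    let part_ids = tokio::task::spawn_blocking(move || m.assign(&vectors)).await?;

    assert_eq!(part_ids.len(), 8);
    Ok(())
}
```
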
72 changes: 29 additions & 43 deletions rust/src/index/vector/pq.rs
@@ -21,7 +21,6 @@ use arrow_array::{
 use arrow_ord::sort::sort_to_indices;
 use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
 use arrow_select::take::take;
-use futures::stream::{self, StreamExt, TryStreamExt};
 use rand::SeedableRng;
 
 use super::MetricType;
@@ -31,7 +30,7 @@ use crate::index::{pb, vector::kmeans::train_kmeans};
 use crate::io::object_reader::{read_fixed_stride_array, ObjectReader};
 use crate::utils::distance::compute::normalize;
 use crate::utils::distance::l2::l2_distance;
-use crate::Result;
+use crate::{Error, Result};
 
 /// Product Quantization Index.
 ///
@@ -305,51 +304,38 @@ impl ProductQuantizer {
         data: &MatrixView,
         metric_type: MetricType,
     ) -> Result<FixedSizeListArray> {
-        let sub_vectors = divide_to_subvectors(&data, self.num_sub_vectors);
-
-        assert_eq!(sub_vectors.len(), self.num_sub_vectors);
-
-        let vectors = sub_vectors.to_vec();
-        let all_centroids = (0..sub_vectors.len())
+        let all_centroids = (0..self.num_sub_vectors)
             .map(|idx| self.centroids(idx))
             .collect::<Vec<_>>();
-        let pq_code = stream::iter(vectors)
-            .zip(stream::iter(all_centroids))
-            .map(|(vec, centroid)| async move {
-                tokio::task::spawn_blocking(move || {
-                    let dist_func = metric_type.func();
-                    // TODO Use tiling to improve cache efficiency.
-                    (0..vec.len())
-                        .map(|i| {
-                            let value = vec.value(i);
-                            let vector: &Float32Array = as_primitive_array(value.as_ref());
-                            let id = argmin(
-                                dist_func(vector, centroid.as_ref(), vector.len())
-                                    .unwrap()
-                                    .as_ref(),
-                            )
-                            .unwrap() as u8;
-                            id
-                        })
-                        .collect::<Vec<_>>()
-                })
-                .await
-            })
-            .buffered(num_cpus::get())
-            .try_collect::<Vec<_>>()
-            .await?;
-
-        // Need to transpose pq_code to column oriented.
-        let capacity = sub_vectors.len() * sub_vectors[0].len();
-        let mut pq_codebook_builder: Vec<u8> = vec![0; capacity];
-        for i in 0..pq_code.len() {
-            let vec = pq_code[i].as_slice();
-            for j in 0..vec.len() {
-                pq_codebook_builder[j * self.num_sub_vectors + i] = vec[j];
+        let dist_func = metric_type.func();
+
+        let flatten_data = data.data();
+        let num_sub_vectors = self.num_sub_vectors;
+        let dim = self.dimension;
+        let num_rows = data.num_rows();
+        let values = tokio::task::spawn_blocking(move || {
+            let capacity = num_sub_vectors * num_rows;
+            let mut builder: Vec<u8> = vec![0; capacity];
+            // Dimension of each sub-vector.
+            let sub_dim = dim / num_sub_vectors;
+            for i in 0..num_rows {
+                let row_offset = i * dim;
+                for sub_idx in 0..num_sub_vectors {
+                    let offset = row_offset + sub_idx * sub_dim;
+                    let sub_vector = flatten_data.slice(offset, sub_dim);
+                    let centroids = all_centroids[sub_idx].as_ref();
+                    let code = argmin(
+                        dist_func(as_primitive_array(sub_vector.as_ref()), centroids, sub_dim)?
+                            .as_ref(),
+                    )
+                    .unwrap();
+                    builder[i * num_sub_vectors + sub_idx] = code as u8;
+                }
             }
-        }
+            Ok::<UInt8Array, Error>(UInt8Array::from_iter_values(builder))
+        })
+        .await??;
 
-        let values = UInt8Array::from_iter_values(pq_codebook_builder);
         FixedSizeListArray::try_new(values, self.num_sub_vectors as i32)
     }
 
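
The pq.rs rewrite carries most of the performance win. The old code split each batch into `num_sub_vectors` separate sub-vector arrays, spawned one blocking task per sub-vector, collected the codes grouped by sub-vector, and then transposed them into row-major order. The new code runs a single blocking task that walks the flattened row-major buffer once and writes every code straight into its final slot at `i * num_sub_vectors + sub_idx`, eliminating the per-sub-vector allocations, the per-sub-vector task overhead, and the transpose pass. A minimal sketch of that loop over plain slices, assuming `dim` divides evenly by `num_sub_vectors` (`encode_pq` and `argmin_l2` are illustrative names, not the crate's API):

```rust
// Encode `num_rows` vectors of length `dim` into PQ codes, one u8 per
// sub-vector, written row-major as the new loop in the diff does.
fn encode_pq(data: &[f32], dim: usize, num_sub_vectors: usize, centroids: &[Vec<f32>]) -> Vec<u8> {
    let sub_dim = dim / num_sub_vectors;
    let num_rows = data.len() / dim;
    let mut codes = vec![0u8; num_rows * num_sub_vectors];
    for i in 0..num_rows {
        let row = &data[i * dim..(i + 1) * dim];
        for (sub_idx, sub_vector) in row.chunks_exact(sub_dim).enumerate() {
            // Row-major write: each code lands in its final position, so no
            // transpose pass is needed afterwards.
            codes[i * num_sub_vectors + sub_idx] =
                argmin_l2(sub_vector, &centroids[sub_idx], sub_dim) as u8;
        }
    }
    codes
}

// Index of the nearest centroid by squared L2 distance; stands in for the
// `argmin(dist_func(...))` pair in the diff.
fn argmin_l2(v: &[f32], centroids: &[f32], sub_dim: usize) -> usize {
    centroids
        .chunks_exact(sub_dim)
        .map(|c| v.iter().zip(c).map(|(a, b)| (a - b) * (a - b)).sum::<f32>())
        .enumerate()
        .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
        .map(|(idx, _)| idx)
        .unwrap_or(0)
}
```

A single sequential pass with in-place writes is also friendlier to the cache than building a `Vec<Vec<u8>>` and transposing, which plausibly accounts for much of the indexing speedup.
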
