Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rust] Allow providing pre-existing IVF_PQ centroids during index creation #963

Merged
merged 5 commits into from
Jun 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/src/index/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ impl VectorIndexParams {
}
}

/// Create index parameters with `IVF` and `PQ` parameters, respectively.
pub fn with_ivf_pq_params(
metric_type: MetricType,
ivf: IvfBuildParams,
Expand Down
168 changes: 146 additions & 22 deletions rust/src/index/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ use super::{
use crate::{
arrow::{linalg::MatrixView, *},
dataset::{Dataset, ROW_ID},
datatypes::Field,
index::{pb, vector::Transformer, Index},
};
use crate::{io::object_reader::ObjectReader, session::Session};
Expand Down Expand Up @@ -413,7 +414,7 @@ impl TryFrom<&pb::Ivf> for Ivf {
}
}

fn sanity_check(dataset: &Dataset, column: &str) -> Result<()> {
fn sanity_check<'a>(dataset: &'a Dataset, column: &str) -> Result<&'a Field> {
let Some(field) = dataset.schema().field(column) else {
return Err(Error::IO{message:format!(
"Building index: column {} does not exist in dataset: {:?}",
Expand All @@ -435,7 +436,7 @@ fn sanity_check(dataset: &Dataset, column: &str) -> Result<()> {
),
});
}
Ok(())
Ok(field)
}

/// Parameters to build IVF partitions
Expand All @@ -447,13 +448,17 @@ pub struct IvfBuildParams {
// ---- kmeans parameters
/// Max number of iterations to train kmeans.
pub max_iters: usize,

/// Use provided IVF centroids.
pub centroids: Option<Arc<FixedSizeListArray>>,
}

impl Default for IvfBuildParams {
fn default() -> Self {
Self {
num_partitions: 32,
max_iters: 50,
centroids: None,
}
}
}
Expand All @@ -466,6 +471,27 @@ impl IvfBuildParams {
..Default::default()
}
}

/// Create a new instance of [`IvfBuildParams`] with centroids.
///
/// Returns an error if `num_partitions` does not match the number of
/// provided centroid vectors.
pub fn try_with_centroids(
    num_partitions: usize,
    centroids: Arc<FixedSizeListArray>,
) -> Result<Self> {
    let provided = centroids.len();
    if provided != num_partitions {
        // Each partition needs exactly one centroid vector.
        return Err(Error::Index {
            message: format!(
                "IvfBuildParams::try_with_centroids: num_partitions {} != centroids.len() {}",
                num_partitions, provided
            ),
        });
    }
    Ok(Self {
        num_partitions,
        centroids: Some(centroids),
        ..Default::default()
    })
}
}

/// Compute residual matrix.
Expand Down Expand Up @@ -522,38 +548,69 @@ pub async fn build_ivf_pq_index(
metric_type,
);

sanity_check(dataset, column)?;
let field = sanity_check(dataset, column)?;
let dim = if let DataType::FixedSizeList(_, d) = field.data_type() {
d as usize
} else {
return Err(Error::Index {
message: format!(
"VectorIndex requires the column data type to be fixed size list of floats, got {}",
field.data_type()
),
});
};

// Maximum to train 256 vectors per centroids, see Faiss.
let sample_size_hint = std::cmp::max(
ivf_params.num_partitions,
ProductQuantizer::num_centroids(pq_params.num_bits as u32),
) * 256;

// TODO: only sample data if training is necessary.
let mut training_data = maybe_sample_training_data(dataset, column, sample_size_hint).await?;

// Pre-transforms
let mut transforms: Vec<Box<dyn Transformer>> = vec![];
if pq_params.use_opq {
let opq = train_opq(&training_data, pq_params).await?;
transforms.push(Box::new(opq));
}

// Transform training data if necessary.
for transform in transforms.iter() {
training_data = transform.transform(&training_data).await?;
}

// Train IVF partitions.
let ivf_model = train_ivf_model(&training_data, metric_type, ivf_params).await?;
let ivf_model = if let Some(centroids) = &ivf_params.centroids {
if centroids.values().len() != ivf_params.num_partitions * dim {
return Err(Error::Index {
message: format!(
"IVF centroids length mismatch: {} != {}",
centroids.len(),
ivf_params.num_partitions * dim,
),
});
}
Ivf::new(centroids.clone())
} else {
// Pre-transforms
if pq_params.use_opq {
let opq = train_opq(&training_data, pq_params).await?;
transforms.push(Box::new(opq));
}

// Transform training data if necessary.
for transform in transforms.iter() {
training_data = transform.transform(&training_data).await?;
}

train_ivf_model(&training_data, metric_type, ivf_params).await?
};

// Compute the residual vector for training PQ
let ivf_centroids = ivf_model.centroids.as_ref().try_into()?;
let residual_data = compute_residual_matrix(&training_data, &ivf_centroids, metric_type)?;
let pq_training_data = MatrixView::new(residual_data, training_data.num_columns());
let pq = if let Some(codebook) = &pq_params.codebook {
ProductQuantizer::new_with_codebook(
pq_params.num_sub_vectors,
pq_params.num_bits as u32,
dim,
codebook.clone(),
)
} else {
// Compute the residual vector for training PQ
let ivf_centroids = ivf_model.centroids.as_ref().try_into()?;
let residual_data = compute_residual_matrix(&training_data, &ivf_centroids, metric_type)?;
let pq_training_data = MatrixView::new(residual_data, training_data.num_columns());

// The final train of PQ sub-vectors
let pq = train_pq(&pq_training_data, pq_params).await?;
train_pq(&pq_training_data, pq_params).await?
};

// Transform data, compute residuals and sort by partition ids.
let mut scanner = dataset.scan();
Expand Down Expand Up @@ -730,3 +787,70 @@ async fn train_ivf_model(
data.num_columns() as i32,
)?)))
}

#[cfg(test)]
mod tests {
use super::*;

use arrow_array::cast::AsArray;
use arrow_schema::{DataType, Field, Schema};
use tempfile::tempdir;

use crate::{
index::{vector::VectorIndexParams, DatasetIndexExt, IndexType},
utils::testing::generate_random_array,
};

// End-to-end check that an IVF_PQ index can be built from user-provided
// IVF centroids and a user-provided PQ codebook (i.e. no training), and
// that the resulting index serves nearest-neighbor queries.
#[tokio::test]
async fn test_create_ivf_pq_with_centroids() {
const DIM: usize = 32;
// 1000 random vectors of dimension DIM, stored flattened.
let vectors = generate_random_array(1000 * DIM);

// Single-column schema: "vector" is a FixedSizeList<Float32>[DIM].
let schema = Arc::new(Schema::new(vec![Field::new(
"vector",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
DIM as i32,
),
true,
)]));
let array = Arc::new(FixedSizeListArray::try_new(&vectors, DIM as i32).unwrap());
let batch = RecordBatch::try_new(schema, vec![array.clone()]).unwrap();

// Write the batch out as a dataset in a temporary directory.
let test_dir = tempdir().unwrap();
let test_uri = test_dir.path().to_str().unwrap();

let mut batches: Box<dyn arrow_array::RecordBatchReader> =
Box::new(RecordBatchBuffer::new(vec![batch]));
let dataset = Dataset::write(&mut batches, test_uri, None).await.unwrap();

// Pre-existing IVF centroids: 2 partitions, so 2 vectors of dimension DIM.
let centroids = generate_random_array(2 * DIM);
let ivf_centroids = FixedSizeListArray::try_new(&centroids, DIM as i32).unwrap();
let ivf_params = IvfBuildParams::try_with_centroids(2, Arc::new(ivf_centroids)).unwrap();

// Pre-existing PQ codebook: 256 centroids (8 bits) of dimension DIM,
// with 4 sub-vectors.
let codebook = Arc::new(generate_random_array(256 * DIM));
let pq_params = PQBuildParams::with_codebook(4, 8, codebook);

let params = VectorIndexParams::with_ivf_pq_params(MetricType::L2, ivf_params, pq_params);

// Building the index must succeed without any training step.
let dataset = dataset
.create_index(&["vector"], IndexType::Vector, None, &params, false)
.await
.unwrap();

// Query with one of the indexed vectors; expect one result batch of 5 rows.
let elem = array.value(10);
let query = elem.as_primitive::<Float32Type>();
let results = dataset
.scan()
.nearest("vector", query, 5)
.unwrap()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(1, results.len());
assert_eq!(5, results[0].num_rows());
}
}
43 changes: 43 additions & 0 deletions rust/src/index/vector/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,22 @@ impl ProductQuantizer {
}
}

/// Create a [`ProductQuantizer`] with a pre-trained codebook.
///
/// * `m` - number of sub-vectors.
/// * `nbits` - number of bits per PQ code.
/// * `dimension` - dimension of the (full) input vectors.
/// * `codebook` - pre-trained codebook values.
///
/// # Panics
///
/// Panics if `nbits` is not 8; only 8-bit PQ codes are supported here.
pub fn new_with_codebook(
    m: usize,
    nbits: u32,
    dimension: usize,
    codebook: Arc<Float32Array>,
) -> Self {
    // assert_eq! reports both values on failure, unlike a bare assert!.
    assert_eq!(nbits, 8, "nbits can only be 8");
    Self {
        num_bits: nbits,
        num_sub_vectors: m,
        dimension,
        codebook: Some(codebook),
    }
}

/// Number of PQ centroids for a given code width: `2^num_bits`.
pub fn num_centroids(num_bits: u32) -> usize {
    // 2^num_bits, computed as a left shift.
    1_usize << num_bits
}
Expand Down Expand Up @@ -609,6 +625,9 @@ pub struct PQBuildParams {

/// Max number of iterations to train OPQ, if `use_opq` is true.
pub max_opq_iters: usize,

/// User provided codebook.
pub codebook: Option<Arc<Float32Array>>,
}

impl Default for PQBuildParams {
Expand All @@ -620,6 +639,30 @@ impl Default for PQBuildParams {
use_opq: false,
max_iters: 50,
max_opq_iters: 50,
codebook: None,
}
}
}

impl PQBuildParams {
    /// Build parameters with the given number of sub-vectors and bits per
    /// code; every other field takes its default value.
    pub fn new(num_sub_vectors: usize, num_bits: usize) -> Self {
        Self {
            num_sub_vectors,
            num_bits,
            ..Self::default()
        }
    }

    /// Build parameters that reuse a pre-trained `codebook` instead of
    /// training PQ from scratch.
    pub fn with_codebook(
        num_sub_vectors: usize,
        num_bits: usize,
        codebook: Arc<Float32Array>,
    ) -> Self {
        Self {
            codebook: Some(codebook),
            ..Self::new(num_sub_vectors, num_bits)
        }
    }
}
Expand Down