Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: vector search with distance range #3326

Merged
merged 5 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions java/core/lance-jni/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ pub fn get_query(env: &mut JNIEnv, query_obj: JObject) -> Result<Option<Query>>
column,
key,
k,
lower_bound: None,
upper_bound: None,
nprobes,
ef,
refine_factor,
Expand Down
6 changes: 6 additions & 0 deletions rust/lance-index/src/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ pub struct Query {
/// Top k results to return.
pub k: usize,

/// The lower bound (inclusive) of the distance to be searched.
pub lower_bound: Option<f32>,

/// The upper bound (exclusive) of the distance to be searched.
pub upper_bound: Option<f32>,

/// The number of probes to load and search.
pub nprobes: usize,

Expand Down
80 changes: 46 additions & 34 deletions rust/lance-index/src/vector/flat/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use arrow::array::AsArray;
use arrow_array::{Array, ArrayRef, Float32Array, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use deepsize::DeepSizeOf;
use itertools::Itertools;
use lance_core::{Error, Result, ROW_ID_FIELD};
use lance_file::reader::FileReader;
use lance_linalg::distance::DistanceType;
Expand Down Expand Up @@ -44,11 +43,17 @@ lazy_static::lazy_static! {
}

#[derive(Default)]
pub struct FlatQueryParams {}
pub struct FlatQueryParams {
lower_bound: Option<f32>,
upper_bound: Option<f32>,
}

impl From<&Query> for FlatQueryParams {
fn from(_: &Query) -> Self {
Self {}
fn from(q: &Query) -> Self {
Self {
lower_bound: q.lower_bound,
upper_bound: q.upper_bound,
}
}
}

Expand All @@ -72,50 +77,57 @@ impl IvfSubIndex for FlatIndex {
&self,
query: ArrayRef,
k: usize,
_params: Self::QueryParams,
params: Self::QueryParams,
storage: &impl VectorStore,
prefilter: Arc<dyn PreFilter>,
) -> Result<RecordBatch> {
let dist_calc = storage.dist_calculator(query);

let (row_ids, dists): (Vec<u64>, Vec<f32>) = match prefilter.is_empty() {
true => dist_calc
.distance_all()
.into_iter()
.zip(0..storage.len() as u32)
.map(|(dist, id)| OrderedNode {
id,
dist: OrderedFloat(dist),
})
.sorted_unstable()
.take(k)
.map(
|OrderedNode {
id,
dist: OrderedFloat(dist),
}| (storage.row_id(id), dist),
)
.unzip(),
let mut res: Vec<_> = match prefilter.is_empty() {
true => {
let iter = dist_calc
.distance_all()
.into_iter()
.zip(0..storage.len() as u32)
.map(|(dist, id)| OrderedNode {
id,
dist: OrderedFloat(dist),
});

if params.lower_bound.is_some() || params.upper_bound.is_some() {
let lower_bound = params.lower_bound.unwrap_or(f32::MIN);
let upper_bound = params.upper_bound.unwrap_or(f32::MAX);
iter.filter(|r| lower_bound <= r.dist.0 && r.dist.0 < upper_bound)
.collect()
} else {
iter.collect()
}
}
false => {
let row_id_mask = prefilter.mask();
(0..storage.len())
let iter = (0..storage.len())
.filter(|&id| row_id_mask.selected(storage.row_id(id as u32)))
.map(|id| OrderedNode {
id: id as u32,
dist: OrderedFloat(dist_calc.distance(id as u32)),
})
.sorted_unstable()
.take(k)
.map(
|OrderedNode {
id,
dist: OrderedFloat(dist),
}| (storage.row_id(id), dist),
)
.unzip()
});
if params.lower_bound.is_some() || params.upper_bound.is_some() {
let lower_bound = params.lower_bound.unwrap_or(f32::MIN);
let upper_bound = params.upper_bound.unwrap_or(f32::MAX);
iter.filter(|r| lower_bound <= r.dist.0 && r.dist.0 < upper_bound)
.collect()
} else {
iter.collect()
}
}
};
res.sort_unstable();

let (row_ids, dists): (Vec<_>, Vec<_>) = res
.into_iter()
.take(k)
.map(|r| (storage.row_id(r.id), r.dist.0))
.unzip();
let (row_ids, dists) = (UInt64Array::from(row_ids), Float32Array::from(dists));

Ok(RecordBatch::try_new(
Expand Down
63 changes: 61 additions & 2 deletions rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use datafusion::physical_plan::{
ExecutionPlan, SendableRecordBatchStream,
};
use datafusion::scalar::ScalarValue;
use datafusion_expr::Operator;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::{Partitioning, PhysicalExpr};
use futures::future::BoxFuture;
Expand Down Expand Up @@ -705,6 +706,8 @@ impl Scanner {
column: column.to_string(),
key: key.into(),
k,
lower_bound: None,
upper_bound: None,
nprobes: 1,
ef: None,
refine_factor: None,
Expand All @@ -714,6 +717,19 @@ impl Scanner {
Ok(self)
}

/// Set the distance thresholds for the nearest neighbor search.
pub fn distance_range(
&mut self,
lower_bound: Option<f32>,
upper_bound: Option<f32>,
) -> &mut Self {
if let Some(q) = self.nearest.as_mut() {
q.lower_bound = lower_bound;
q.upper_bound = upper_bound;
}
self
}

pub fn nprobs(&mut self, n: usize) -> &mut Self {
if let Some(q) = self.nearest.as_mut() {
q.nprobes = n;
Expand Down Expand Up @@ -1994,16 +2010,59 @@ impl Scanner {
q.metric_type,
)?);

// filter out elements out of distance range
let lower_bound_expr = q
.lower_bound
.map(|v| {
let lower_bound = expressions::lit(v);
expressions::binary(
expressions::col(DIST_COL, flat_dist.schema().as_ref())?,
Operator::GtEq,
lower_bound,
flat_dist.schema().as_ref(),
)
})
.transpose()?;
let upper_bound_expr = q
.upper_bound
.map(|v| {
let upper_bound = expressions::lit(v);
expressions::binary(
expressions::col(DIST_COL, flat_dist.schema().as_ref())?,
Operator::Lt,
upper_bound,
flat_dist.schema().as_ref(),
)
})
.transpose()?;
let filter_expr = match (lower_bound_expr, upper_bound_expr) {
(Some(lower), Some(upper)) => Some(expressions::binary(
lower,
Operator::And,
upper,
flat_dist.schema().as_ref(),
)?),
(Some(lower), None) => Some(lower),
(None, Some(upper)) => Some(upper),
(None, None) => None,
};

let knn_plan: Arc<dyn ExecutionPlan> = if let Some(filter_expr) = filter_expr {
Arc::new(FilterExec::try_new(filter_expr, flat_dist)?)
} else {
flat_dist
};

// Use DataFusion's [SortExec] for Top-K search
let sort = SortExec::new(
vec![PhysicalSortExpr {
expr: expressions::col(DIST_COL, flat_dist.schema().as_ref())?,
expr: expressions::col(DIST_COL, knn_plan.schema().as_ref())?,
options: SortOptions {
descending: false,
nulls_first: false,
},
}],
flat_dist,
knn_plan,
)
.with_fetch(Some(q.k));

Expand Down
2 changes: 2 additions & 0 deletions rust/lance/src/index/vector/fixture_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ mod test {
column: "test".to_string(),
key: Arc::new(Float32Array::from(query)),
k: 1,
lower_bound: None,
upper_bound: None,
nprobes: 1,
ef: None,
refine_factor: None,
Expand Down
2 changes: 2 additions & 0 deletions rust/lance/src/index/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1978,6 +1978,8 @@ mod tests {
column: Self::COLUMN.to_string(),
key: Arc::new(row),
k: 5,
lower_bound: None,
upper_bound: None,
nprobes: 1,
ef: None,
refine_factor: None,
Expand Down
Loading
Loading