Skip to content

Commit

Permalink
Updated variable names and fixed python tests
Browse files Browse the repository at this point in the history
  • Loading branch information
trueutkarsh committed Jul 28, 2023
1 parent db612bd commit 7b43a62
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 44 deletions.
12 changes: 7 additions & 5 deletions python/python/tests/test_lance.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ def test_nearest(tmp_path):
top10 = dataset.to_table(
nearest={"column": "emb", "q": arr[0].values, "k": 10, "nprobes": 10}
)
scores = l2sq(arr[0].values, npvals.reshape((100, 32)))
indices = np.argsort(scores)
distances = l2sq(arr[0].values, npvals.reshape((100, 32)))
indices = np.argsort(distances)
assert tbl.take(indices[:10]).to_pandas().equals(top10.to_pandas()[["emb"]])
assert np.allclose(scores[indices[:10]], top10.to_pandas().score.values)
assert np.allclose(distances[indices[:10]], top10.to_pandas()["_distance"].values)


def l2sq(vec, mat):
Expand All @@ -114,8 +114,10 @@ def test_nearest_cosine(tmp_path):
nearest={"column": "vector", "q": q, "k": 10, "metric": "cosine"}
).to_pandas()
for i in range(len(rs)):
assert rs.score[i] == pytest.approx(cosine_distance(rs.vector[i], q), abs=1e-6)
assert 0 <= rs.score[i] <= 1
assert rs["_distance"][i] == pytest.approx(
cosine_distance(rs.vector[i], q), abs=1e-6
)
assert 0 <= rs["_distance"][i] <= 1


def cosine_distance(vec1, vec2):
Expand Down
6 changes: 3 additions & 3 deletions python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def run(ds, q=None, assert_func=None):
expected_columns.extend(ds.schema.names)
else:
expected_columns.extend(columns)
for c in ["vector", "score"]:
for c in ["vector", "_distance"]:
if c not in expected_columns:
expected_columns.append(c)

Expand All @@ -96,8 +96,8 @@ def run(ds, q=None, assert_func=None):
assert len(inmem.to_table(filter=filter_)) == len(rs)
else:
assert len(rs) == 15
scores = rs["score"].to_numpy()
assert (scores.max() - scores.min()) > 1e-6
distances = rs["_distance"].to_numpy()
assert (distances.max() - distances.min()) > 1e-6
if assert_func is not None:
assert_func(rs)
return times
Expand Down
10 changes: 8 additions & 2 deletions rust/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1425,7 +1425,10 @@ mod test {

let take = &filter.children()[0];
let take = take.as_any().downcast_ref::<TakeExec>().unwrap();
assert_eq!(take.schema().field_names(), ["_distance", "_rowid", "vec", "i"]);
assert_eq!(
take.schema().field_names(),
["_distance", "_rowid", "vec", "i"]
);
assert_eq!(
take.extra_schema
.fields
Expand Down Expand Up @@ -1509,7 +1512,10 @@ mod test {

let take = &filter.children()[0];
let take = take.as_any().downcast_ref::<TakeExec>().unwrap();
assert_eq!(take.schema().field_names(), ["_distance", "_rowid", "vec", "i"]);
assert_eq!(
take.schema().field_names(),
["_distance", "_rowid", "vec", "i"]
);
assert_eq!(
take.extra_schema
.fields
Expand Down
11 changes: 7 additions & 4 deletions rust/src/index/vector/diskann/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,12 @@ impl VectorIndex for DiskANNIndex {
]));

let mut candidates = Vec::with_capacity(query.k);
for (score, row) in state.candidates {
for (distance, row) in state.candidates {
if candidates.len() == query.k {
break;
}
if !self.deletion_cache.as_ref().is_deleted(row as u64).await? {
candidates.push((score, row));
candidates.push((distance, row));
}
}

Expand All @@ -236,11 +236,14 @@ impl VectorIndex for DiskANNIndex {
.take(query.k)
.map(|(_, id)| *id as u64)
.collect();
let scores: Float32Array = candidates.iter().take(query.k).map(|(d, _)| **d).collect();
let distances: Float32Array = candidates.iter().take(query.k).map(|(d, _)| **d).collect();

let batch = RecordBatch::try_new(
schema,
vec![Arc::new(row_ids) as ArrayRef, Arc::new(scores) as ArrayRef],
vec![
Arc::new(row_ids) as ArrayRef,
Arc::new(distances) as ArrayRef,
],
)?;
Ok(batch)
}
Expand Down
18 changes: 10 additions & 8 deletions rust/src/index/vector/flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub async fn flat_search(
let k = query.key.clone();
let mut batch = batch?;
if batch.column_by_name(DIST_COL).is_some() {
// Ignore the score calculated from inner vector index.
// Ignore the distance calculated from inner vector index.
batch = batch.drop_column(DIST_COL)?;
}
let vectors = batch
Expand All @@ -52,7 +52,7 @@ pub async fn flat_search(
})?
.clone();
let flatten_vectors = as_fixed_size_list_array(vectors.as_ref()).values().clone();
let scores = tokio::task::spawn_blocking(move || {
let distances = tokio::task::spawn_blocking(move || {
mt.batch_func()(
k.values(),
as_primitive_array::<Float32Type>(flatten_vectors.as_ref()).values(),
Expand All @@ -62,19 +62,21 @@ pub async fn flat_search(
.await? as ArrayRef;

// TODO: use heap
let indices = sort_to_indices(&scores, None, Some(query.k))?;
let batch_with_score = batch
.try_with_column(ArrowField::new(DIST_COL, DataType::Float32, false), scores)?;
let struct_arr = StructArray::from(batch_with_score);
let indices = sort_to_indices(&distances, None, Some(query.k))?;
let batch_with_distance = batch.try_with_column(
ArrowField::new(DIST_COL, DataType::Float32, false),
distances,
)?;
let struct_arr = StructArray::from(batch_with_distance);
let selected_arr = take(&struct_arr, &indices, None)?;
Ok::<RecordBatch, Error>(as_struct_array(&selected_arr).into())
})
.buffer_unordered(16)
.try_collect::<Vec<_>>()
.await?;
let batch = concat_batches(&batches[0].schema(), &batches)?;
let scores = batch.column_by_name(DIST_COL).unwrap();
let indices = sort_to_indices(scores, None, Some(query.k))?;
let distances = batch.column_by_name(DIST_COL).unwrap();
let indices = sort_to_indices(distances, None, Some(query.k))?;

let struct_arr = StructArray::from(batch);
let selected_arr = take(&struct_arr, &indices, None)?;
Expand Down
9 changes: 6 additions & 3 deletions rust/src/index/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,18 @@ impl VectorIndex for IVFIndex {
let batch = concat_batches(&batches[0].schema(), &batches)?;

let dist_col = batch.column_by_name("_distance").ok_or_else(|| Error::IO {
message: format!("_distance column does not exist in batch: {}", batch.schema()),
message: format!(
"_distance column does not exist in batch: {}",
batch.schema()
),
})?;

// TODO: Use a heap sort to get the top-k.
let limit = query.k * query.refine_factor.unwrap_or(1) as usize;
let selection = sort_to_indices(dist_col, None, Some(limit))?;
let struct_arr = StructArray::from(batch);
let taken_scores = take(&struct_arr, &selection, None)?;
Ok(as_struct_array(&taken_scores).into())
let taken_distances = take(&struct_arr, &selection, None)?;
Ok(as_struct_array(&taken_distances).into())
}

fn is_loadable(&self) -> bool {
Expand Down
20 changes: 10 additions & 10 deletions rust/src/index/vector/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl PQIndex {
}
}

fn fast_l2_scores(&self, key: &Float32Array) -> Result<ArrayRef> {
fn fast_l2_distances(&self, key: &Float32Array) -> Result<ArrayRef> {
// Build distance table for each sub-centroid to the query key.
//
// Distance table: `[f32: num_sub_vectors(row) * num_centroids(column)]`.
Expand All @@ -110,7 +110,7 @@ impl PQIndex {
for i in 0..self.num_sub_vectors {
let from = key.slice(i * sub_vector_length, sub_vector_length);
let subvec_centroids = self.pq.centroids(i).ok_or_else(|| Error::Index {
message: "PQIndex::l2_scores: PQ is not initialized".to_string(),
message: "PQIndex::l2_distances: PQ is not initialized".to_string(),
})?;
let distances = l2_distance_batch(
as_primitive_array::<Float32Type>(&from).values(),
Expand Down Expand Up @@ -141,7 +141,7 @@ impl PQIndex {
}))
}

fn cosine_scores(&self, key: &Float32Array) -> Result<ArrayRef> {
fn cosine_distances(&self, key: &Float32Array) -> Result<ArrayRef> {
// Build two tables for cosine distance.
//
// xy table: `[f32: num_sub_vectors(row) * num_centroids(column)]`.
Expand All @@ -155,7 +155,7 @@ impl PQIndex {
for i in 0..self.num_sub_vectors {
let key_sub_vector: Float32Array = key.slice(i * sub_vector_length, sub_vector_length);
let sub_vector_centroids = self.pq.centroids(i).ok_or_else(|| Error::Index {
message: "PQIndex::cosine_scores: PQ is not initialized".to_string(),
message: "PQIndex::cosine_distances: PQ is not initialized".to_string(),
})?;
let xy = sub_vector_centroids
.as_ref()
Expand Down Expand Up @@ -231,22 +231,22 @@ impl VectorIndex for PQIndex {
let row_ids = self.row_ids.as_ref().unwrap();
assert_eq!(code.len() % self.num_sub_vectors, 0);

let scores = if self.metric_type == MetricType::L2 {
self.fast_l2_scores(&query.key)?
let distances = if self.metric_type == MetricType::L2 {
self.fast_l2_distances(&query.key)?
} else {
self.cosine_scores(&query.key)?
self.cosine_distances(&query.key)?
};

let limit = query.k * query.refine_factor.unwrap_or(1) as usize;
let indices = sort_to_indices(&scores, None, Some(limit))?;
let scores = take(&scores, &indices, None)?;
let indices = sort_to_indices(&distances, None, Some(limit))?;
let distances = take(&distances, &indices, None)?;
let row_ids = take(row_ids.as_ref(), &indices, None)?;

let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new(DIST_COL, DataType::Float32, false),
ArrowField::new(ROW_ID, DataType::UInt64, false),
]));
Ok(RecordBatch::try_new(schema, vec![scores, row_ids])?)
Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?)
}

fn is_loadable(&self) -> bool {
Expand Down
6 changes: 3 additions & 3 deletions rust/src/io/exec/knn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl KNNFlatStream {
Ok(b) => b,
Err(e) => {
tx.send(Err(DataFusionError::Execution(format!(
"Failed to compute scores: {e}"
"Failed to compute distances: {e}"
))))
.await
.expect("KNNFlat failed to send message");
Expand Down Expand Up @@ -180,7 +180,7 @@ impl ExecutionPlan for KNNFlatExec {
self
}

/// Flat KNN inherits the schema from input node, and add one score column.
/// Flat KNN inherits the schema from input node, and add one distance column.
fn schema(&self) -> arrow_schema::SchemaRef {
let input_schema = self.input.schema();
let mut fields = input_schema.fields().to_vec();
Expand Down Expand Up @@ -261,7 +261,7 @@ impl KNNIndexStream {
Ok(b) => b,
Err(e) => {
tx.send(Err(datafusion::error::DataFusionError::Execution(format!(
"Failed to compute scores: {e}"
"Failed to compute distances: {e}"
))))
.await
.expect("KNNIndex failed to send message");
Expand Down
12 changes: 6 additions & 6 deletions rust/src/linalg/l2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,14 +232,14 @@ mod tests {
8,
);
let point = Float32Array::from((2..10).map(|v| Some(v as f32)).collect::<Vec<_>>());
let scores = l2_distance_batch(
let distances = l2_distance_batch(
point.values(),
as_primitive_array::<Float32Type>(mat.values().as_ref()).values(),
8,
);

assert_eq!(
scores.as_ref(),
distances.as_ref(),
&Float32Array::from(vec![32.0, 8.0, 0.0, 8.0])
);
}
Expand All @@ -254,20 +254,20 @@ mod tests {
.map(|v| v as f32)
.collect::<Vec<_>>();
let point = Float32Array::from((0..10).map(|v| Some(v as f32)).collect::<Vec<_>>());
let scores = l2_distance_batch(&point.values()[2..], &mat[6..], 8);
let distances = l2_distance_batch(&point.values()[2..], &mat[6..], 8);

assert_eq!(
scores.as_ref(),
distances.as_ref(),
&Float32Array::from(vec![32.0, 8.0, 0.0, 8.0])
);
}
#[test]
fn test_odd_length_vector() {
let mat = Float32Array::from_iter((0..5).map(|v| Some(v as f32)));
let point = Float32Array::from((2..7).map(|v| Some(v as f32)).collect::<Vec<_>>());
let scores = l2_distance_batch(point.values(), mat.values(), 5);
let distances = l2_distance_batch(point.values(), mat.values(), 5);

assert_eq!(scores.as_ref(), &Float32Array::from(vec![20.0]));
assert_eq!(distances.as_ref(), &Float32Array::from(vec![20.0]));
}

#[test]
Expand Down

0 comments on commit 7b43a62

Please sign in to comment.