feat: support merge by row_id, row_addr (#3254)
chenkovsky authored Dec 18, 2024
1 parent ae36abe commit 95f98b3
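
With this change, `Dataset::merge` accepts the internal `_rowid` or `_rowaddr` column as the left-side key, so new columns can be joined back onto a dataset using row ids captured from an earlier scan. Below is a minimal usage sketch (not part of the diff; it roughly mirrors the new tests). It assumes an open, mutable `Dataset`; the `add_values_by_row_id` helper and the "rowid"/"new_value" column names are illustrative, and the import paths are approximate:

use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use lance::Dataset;
use lance_core::{Result, ROW_ID};

// Illustrative helper (hypothetical, not from this commit): attach one
// `new_value` per row, joined on the internal `_rowid` column.
async fn add_values_by_row_id(dataset: &mut Dataset, new_values: Int32Array) -> Result<()> {
    // Capture the row ids in the dataset's current scan order.
    let batch = dataset.scan().with_row_id().try_into_batch().await?;
    let schema = Arc::new(ArrowSchema::new(vec![
        ArrowField::new("rowid", DataType::UInt64, false),
        ArrowField::new("new_value", DataType::Int32, false),
    ]));
    let keyed = RecordBatch::try_new(
        schema.clone(),
        vec![batch[ROW_ID].clone(), Arc::new(new_values)],
    )
    .expect("`new_values` must have one entry per row");
    let reader = RecordBatchIterator::new(vec![Ok(keyed)], schema);
    // Left key is the internal `_rowid` column; right key is "rowid" in `reader`.
    dataset.merge(reader, ROW_ID, "rowid").await?;
    dataset.validate().await
}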
Showing 2 changed files with 143 additions and 5 deletions.
137 changes: 135 additions & 2 deletions rust/lance/src/dataset.rs
@@ -15,6 +15,7 @@ use itertools::Itertools;
use lance_core::traits::DatasetTakeRows;
use lance_core::utils::address::RowAddress;
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use lance_core::ROW_ADDR;
use lance_datafusion::projection::ProjectionPlan;
use lance_file::datatypes::populate_schema_dictionary;
use lance_file::version::LanceFileVersion;
@@ -1395,7 +1396,7 @@ impl Dataset {
right_on: &str,
) -> Result<()> {
// Sanity check.
if self.schema().field(left_on).is_none() {
if self.schema().field(left_on).is_none() && left_on != ROW_ID && left_on != ROW_ADDR {
return Err(Error::invalid_input(
format!("Column {} does not exist in the left side dataset", left_on),
location!(),
@@ -1661,7 +1662,7 @@ mod tests {
use crate::index::vector::VectorIndexParams;
use crate::utils::test::TestDatasetGenerator;

use arrow::array::as_struct_array;
use arrow::array::{as_struct_array, AsArray};
use arrow::compute::concat_batches;
use arrow_array::{
builder::StringDictionaryBuilder,
@@ -1691,6 +1692,7 @@ mod tests {
use lance_table::io::deletion::read_deletion_file;
use lance_testing::datagen::generate_random_array;
use pretty_assertions::assert_eq;
use rand::seq::SliceRandom;
use rstest::rstest;
use tempfile::{tempdir, TempDir};
use url::Url;
@@ -3131,6 +3133,137 @@ mod tests {
dataset.validate().await.unwrap();
}

#[rstest]
#[tokio::test]
async fn test_merge_on_row_id(
#[values(LanceFileVersion::Stable)] data_storage_version: LanceFileVersion,
#[values(false, true)] use_stable_row_id: bool,
) {
// Tests a merge on _rowid

let data = lance_datagen::gen()
.col("key", array::step::<Int32Type>())
.col("value", array::fill_utf8("value".to_string()))
.into_reader_rows(RowCount::from(1_000), BatchCount::from(10));

let write_params = WriteParams {
mode: WriteMode::Append,
data_storage_version: Some(data_storage_version),
max_rows_per_file: 1024,
max_rows_per_group: 150,
enable_move_stable_row_ids: use_stable_row_id,
..Default::default()
};
let mut dataset = Dataset::write(data, "memory://", Some(write_params.clone()))
.await
.unwrap();
assert_eq!(dataset.fragments().len(), 10);
assert_eq!(dataset.manifest.max_fragment_id(), Some(9));

let data = dataset.scan().with_row_id().try_into_batch().await.unwrap();
let row_ids: Arc<dyn Array> = data[ROW_ID].clone();
let key = data["key"].as_primitive::<Int32Type>();
let new_schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("rowid", DataType::UInt64, false),
ArrowField::new("new_value", DataType::Int32, false),
]));
let new_value = Arc::new(
key.into_iter()
.map(|v| v.unwrap() + 1)
.collect::<arrow_array::Int32Array>(),
);
let len = new_value.len() as u32;
let new_batch = RecordBatch::try_new(new_schema.clone(), vec![row_ids, new_value]).unwrap();
// shuffle new_batch
let mut rng = rand::thread_rng();
let mut indices: Vec<u32> = (0..len).collect();
indices.shuffle(&mut rng);
let indices = arrow_array::UInt32Array::from_iter_values(indices);
let new_batch = arrow::compute::take_record_batch(&new_batch, &indices).unwrap();
let new_data = RecordBatchIterator::new(vec![Ok(new_batch)], new_schema.clone());
dataset.merge(new_data, ROW_ID, "rowid").await.unwrap();
dataset.validate().await.unwrap();
assert_eq!(dataset.schema().fields.len(), 3);
assert!(dataset.schema().field("key").is_some());
assert!(dataset.schema().field("value").is_some());
assert!(dataset.schema().field("new_value").is_some());
let batch = dataset.scan().try_into_batch().await.unwrap();
let key = batch["key"].as_primitive::<Int32Type>();
let new_value = batch["new_value"].as_primitive::<Int32Type>();
for i in 0..key.len() {
assert_eq!(key.value(i) + 1, new_value.value(i));
}
}

#[rstest]
#[tokio::test]
async fn test_merge_on_row_addr(
#[values(LanceFileVersion::Stable)] data_storage_version: LanceFileVersion,
#[values(false, true)] use_stable_row_id: bool,
) {
// Tests a merge on _rowaddr

let data = lance_datagen::gen()
.col("key", array::step::<Int32Type>())
.col("value", array::fill_utf8("value".to_string()))
.into_reader_rows(RowCount::from(1_000), BatchCount::from(10));

let write_params = WriteParams {
mode: WriteMode::Append,
data_storage_version: Some(data_storage_version),
max_rows_per_file: 1024,
max_rows_per_group: 150,
enable_move_stable_row_ids: use_stable_row_id,
..Default::default()
};
let mut dataset = Dataset::write(data, "memory://", Some(write_params.clone()))
.await
.unwrap();

assert_eq!(dataset.fragments().len(), 10);
assert_eq!(dataset.manifest.max_fragment_id(), Some(9));

let data = dataset
.scan()
.with_row_address()
.try_into_batch()
.await
.unwrap();
let row_addrs = data[ROW_ADDR].clone();
let key = data["key"].as_primitive::<Int32Type>();
let new_schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("rowaddr", DataType::UInt64, false),
ArrowField::new("new_value", DataType::Int32, false),
]));
let new_value = Arc::new(
key.into_iter()
.map(|v| v.unwrap() + 1)
.collect::<arrow_array::Int32Array>(),
);
let len = new_value.len() as u32;
let new_batch =
RecordBatch::try_new(new_schema.clone(), vec![row_addrs, new_value]).unwrap();
// shuffle new_batch
let mut rng = rand::thread_rng();
let mut indices: Vec<u32> = (0..len).collect();
indices.shuffle(&mut rng);
let indices = arrow_array::UInt32Array::from_iter_values(indices);
let new_batch = arrow::compute::take_record_batch(&new_batch, &indices).unwrap();
let new_data = RecordBatchIterator::new(vec![Ok(new_batch)], new_schema.clone());
dataset.merge(new_data, ROW_ADDR, "rowaddr").await.unwrap();
dataset.validate().await.unwrap();
assert_eq!(dataset.schema().fields.len(), 3);
assert!(dataset.schema().field("key").is_some());
assert!(dataset.schema().field("value").is_some());
assert!(dataset.schema().field("new_value").is_some());
let batch = dataset.scan().try_into_batch().await.unwrap();
let key = batch["key"].as_primitive::<Int32Type>();
let new_value = batch["new_value"].as_primitive::<Int32Type>();
for i in 0..key.len() {
assert_eq!(key.value(i) + 1, new_value.value(i));
}
}

#[rstest]
#[tokio::test]
async fn test_delete(
11 changes: 8 additions & 3 deletions rust/lance/src/dataset/fragment.rs
@@ -22,7 +22,7 @@ use lance_core::datatypes::SchemaCompareOptions;
use lance_core::utils::deletion::DeletionVector;
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use lance_core::{datatypes::Schema, Error, Result};
use lance_core::{ROW_ADDR, ROW_ADDR_FIELD, ROW_ID_FIELD};
use lance_core::{ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD};
use lance_datafusion::utils::StreamingWriteSource;
use lance_encoding::decoder::DecoderPlugins;
use lance_file::reader::{read_batch, FileReader};
@@ -1285,11 +1285,14 @@ impl FileFragment {
let mut schema = self.dataset.schema().clone();

let mut with_row_addr = false;
let mut with_row_id = false;
if let Some(columns) = columns {
let mut projection = Vec::new();
for column in columns {
if column.as_ref() == ROW_ADDR {
with_row_addr = true;
} else if column.as_ref() == ROW_ID {
with_row_id = true;
} else {
projection.push(column.as_ref());
}
@@ -1305,11 +1308,13 @@
}

// If there is no projection and the row id was not requested, we at least need to read the row addresses
with_row_addr |= schema.fields.is_empty();
with_row_addr |= !with_row_id && schema.fields.is_empty();

let reader = self.open(
&schema,
FragReadConfig::default().with_row_address(with_row_addr),
FragReadConfig::default()
.with_row_address(with_row_addr)
.with_row_id(with_row_id),
None,
);
let deletion_vector = read_deletion_file(
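
A companion sketch with the same caveats as above (hypothetical helper and column names, approximate paths, not part of the commit): the physical row address can serve as the merge key instead of the row id; the fragment-level hunk above extends the column handling so that `_rowid`, like `_rowaddr` before it, can be requested during the merge read.

use std::sync::Arc;

use arrow_array::{Float64Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use lance::Dataset;
use lance_core::{Result, ROW_ADDR};

// Hypothetical helper: attach one `score` per row, joined on `_rowaddr`.
async fn add_scores_by_row_addr(dataset: &mut Dataset, scores: Float64Array) -> Result<()> {
    // Capture the physical row addresses in the dataset's current scan order.
    let batch = dataset.scan().with_row_address().try_into_batch().await?;
    let schema = Arc::new(ArrowSchema::new(vec![
        ArrowField::new("rowaddr", DataType::UInt64, false),
        ArrowField::new("score", DataType::Float64, false),
    ]));
    let keyed = RecordBatch::try_new(
        schema.clone(),
        vec![batch[ROW_ADDR].clone(), Arc::new(scores)],
    )
    .expect("`scores` must have one entry per row");
    let reader = RecordBatchIterator::new(vec![Ok(keyed)], schema);
    // Left key is the internal `_rowaddr` column; right key is "rowaddr" in `reader`.
    dataset.merge(reader, ROW_ADDR, "rowaddr").await
}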
