diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs
index 9b43efd7f9..fcd5959d71 100644
--- a/rust/lance/src/dataset.rs
+++ b/rust/lance/src/dataset.rs
@@ -15,6 +15,7 @@ use itertools::Itertools;
 use lance_core::traits::DatasetTakeRows;
 use lance_core::utils::address::RowAddress;
 use lance_core::utils::tokio::get_num_compute_intensive_cpus;
+use lance_core::ROW_ADDR;
 use lance_datafusion::projection::ProjectionPlan;
 use lance_file::datatypes::populate_schema_dictionary;
 use lance_file::version::LanceFileVersion;
@@ -1395,7 +1396,7 @@ impl Dataset {
         right_on: &str,
     ) -> Result<()> {
         // Sanity check.
-        if self.schema().field(left_on).is_none() {
+        if self.schema().field(left_on).is_none() && left_on != ROW_ID && left_on != ROW_ADDR {
             return Err(Error::invalid_input(
                 format!("Column {} does not exist in the left side dataset", left_on),
                 location!(),
@@ -1661,7 +1662,7 @@ mod tests {
     use crate::index::vector::VectorIndexParams;
     use crate::utils::test::TestDatasetGenerator;
 
-    use arrow::array::as_struct_array;
+    use arrow::array::{as_struct_array, AsArray};
     use arrow::compute::concat_batches;
     use arrow_array::{
         builder::StringDictionaryBuilder,
@@ -1691,6 +1692,7 @@ mod tests {
     use lance_table::io::deletion::read_deletion_file;
     use lance_testing::datagen::generate_random_array;
     use pretty_assertions::assert_eq;
+    use rand::seq::SliceRandom;
     use rstest::rstest;
     use tempfile::{tempdir, TempDir};
     use url::Url;
@@ -3131,6 +3133,137 @@ mod tests {
         dataset.validate().await.unwrap();
     }
 
+    #[rstest]
+    #[tokio::test]
+    async fn test_merge_on_row_id(
+        #[values(LanceFileVersion::Stable)] data_storage_version: LanceFileVersion,
+        #[values(false, true)] use_stable_row_id: bool,
+    ) {
+        // Tests a merge on _rowid
+
+        let data = lance_datagen::gen()
+            .col("key", array::step::<Int32Type>())
+            .col("value", array::fill_utf8("value".to_string()))
+            .into_reader_rows(RowCount::from(1_000), BatchCount::from(10));
+
+        let write_params = WriteParams {
+            mode: WriteMode::Append,
+            data_storage_version: Some(data_storage_version),
+            max_rows_per_file: 1024,
+            max_rows_per_group: 150,
+            enable_move_stable_row_ids: use_stable_row_id,
+            ..Default::default()
+        };
+        let mut dataset = Dataset::write(data, "memory://", Some(write_params.clone()))
+            .await
+            .unwrap();
+        assert_eq!(dataset.fragments().len(), 10);
+        assert_eq!(dataset.manifest.max_fragment_id(), Some(9));
+
+        let data = dataset.scan().with_row_id().try_into_batch().await.unwrap();
+        let row_ids: Arc<dyn Array> = data[ROW_ID].clone();
+        let key = data["key"].as_primitive::<Int32Type>();
+        let new_schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("rowid", DataType::UInt64, false),
+            ArrowField::new("new_value", DataType::Int32, false),
+        ]));
+        let new_value = Arc::new(
+            key.into_iter()
+                .map(|v| v.unwrap() + 1)
+                .collect::<Int32Array>(),
+        );
+        let len = new_value.len() as u32;
+        let new_batch = RecordBatch::try_new(new_schema.clone(), vec![row_ids, new_value]).unwrap();
+        // shuffle new_batch
+        let mut rng = rand::thread_rng();
+        let mut indices: Vec<u32> = (0..len).collect();
+        indices.shuffle(&mut rng);
+        let indices = arrow_array::UInt32Array::from_iter_values(indices);
+        let new_batch = arrow::compute::take_record_batch(&new_batch, &indices).unwrap();
+        let new_data = RecordBatchIterator::new(vec![Ok(new_batch)], new_schema.clone());
+        dataset.merge(new_data, ROW_ID, "rowid").await.unwrap();
+        dataset.validate().await.unwrap();
+        assert_eq!(dataset.schema().fields.len(), 3);
+        assert!(dataset.schema().field("key").is_some());
+        assert!(dataset.schema().field("value").is_some());
+        assert!(dataset.schema().field("new_value").is_some());
+        let batch = dataset.scan().try_into_batch().await.unwrap();
+        let key = batch["key"].as_primitive::<Int32Type>();
+        let new_value = batch["new_value"].as_primitive::<Int32Type>();
+        for i in 0..key.len() {
+            assert_eq!(key.value(i) + 1, new_value.value(i));
+        }
+    }
+
+    #[rstest]
+    #[tokio::test]
+    async fn test_merge_on_row_addr(
+        #[values(LanceFileVersion::Stable)] data_storage_version: LanceFileVersion,
+        #[values(false, true)] use_stable_row_id: bool,
+    ) {
+        // Tests a merge on _rowaddr
+
+        let data = lance_datagen::gen()
+            .col("key", array::step::<Int32Type>())
+            .col("value", array::fill_utf8("value".to_string()))
+            .into_reader_rows(RowCount::from(1_000), BatchCount::from(10));
+
+        let write_params = WriteParams {
+            mode: WriteMode::Append,
+            data_storage_version: Some(data_storage_version),
+            max_rows_per_file: 1024,
+            max_rows_per_group: 150,
+            enable_move_stable_row_ids: use_stable_row_id,
+            ..Default::default()
+        };
+        let mut dataset = Dataset::write(data, "memory://", Some(write_params.clone()))
+            .await
+            .unwrap();
+
+        assert_eq!(dataset.fragments().len(), 10);
+        assert_eq!(dataset.manifest.max_fragment_id(), Some(9));
+
+        let data = dataset
+            .scan()
+            .with_row_address()
+            .try_into_batch()
+            .await
+            .unwrap();
+        let row_addrs = data[ROW_ADDR].clone();
+        let key = data["key"].as_primitive::<Int32Type>();
+        let new_schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("rowaddr", DataType::UInt64, false),
+            ArrowField::new("new_value", DataType::Int32, false),
+        ]));
+        let new_value = Arc::new(
+            key.into_iter()
+                .map(|v| v.unwrap() + 1)
+                .collect::<Int32Array>(),
+        );
+        let len = new_value.len() as u32;
+        let new_batch =
+            RecordBatch::try_new(new_schema.clone(), vec![row_addrs, new_value]).unwrap();
+        // shuffle new_batch
+        let mut rng = rand::thread_rng();
+        let mut indices: Vec<u32> = (0..len).collect();
+        indices.shuffle(&mut rng);
+        let indices = arrow_array::UInt32Array::from_iter_values(indices);
+        let new_batch = arrow::compute::take_record_batch(&new_batch, &indices).unwrap();
+        let new_data = RecordBatchIterator::new(vec![Ok(new_batch)], new_schema.clone());
+        dataset.merge(new_data, ROW_ADDR, "rowaddr").await.unwrap();
+        dataset.validate().await.unwrap();
+        assert_eq!(dataset.schema().fields.len(), 3);
+        assert!(dataset.schema().field("key").is_some());
+        assert!(dataset.schema().field("value").is_some());
+        assert!(dataset.schema().field("new_value").is_some());
+        let batch = dataset.scan().try_into_batch().await.unwrap();
+        let key = batch["key"].as_primitive::<Int32Type>();
+        let new_value = batch["new_value"].as_primitive::<Int32Type>();
+        for i in 0..key.len() {
+            assert_eq!(key.value(i) + 1, new_value.value(i));
+        }
+    }
+
     #[rstest]
     #[tokio::test]
     async fn test_delete(
diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs
index d1d1d790ad..938ff646ab 100644
--- a/rust/lance/src/dataset/fragment.rs
+++ b/rust/lance/src/dataset/fragment.rs
@@ -22,7 +22,7 @@ use lance_core::datatypes::SchemaCompareOptions;
 use lance_core::utils::deletion::DeletionVector;
 use lance_core::utils::tokio::get_num_compute_intensive_cpus;
 use lance_core::{datatypes::Schema, Error, Result};
-use lance_core::{ROW_ADDR, ROW_ADDR_FIELD, ROW_ID_FIELD};
+use lance_core::{ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD};
 use lance_datafusion::utils::StreamingWriteSource;
 use lance_encoding::decoder::DecoderPlugins;
 use lance_file::reader::{read_batch, FileReader};
@@ -1285,11 +1285,14 @@ impl FileFragment {
         let mut schema = self.dataset.schema().clone();
         let mut with_row_addr = false;
+        let mut with_row_id = false;
         if let Some(columns) = columns {
             let mut projection = Vec::new();
             for column in columns {
                 if column.as_ref() == ROW_ADDR {
                     with_row_addr = true;
+                } else if column.as_ref() == ROW_ID {
+                    with_row_id = true;
                 } else {
                     projection.push(column.as_ref());
                 }
@@ -1305,11 +1308,13 @@ impl FileFragment {
         }
 
         // If there is no projection, we at least need to read the row addresses
-        with_row_addr |= schema.fields.is_empty();
+        with_row_addr |= !with_row_id && schema.fields.is_empty();
 
         let reader = self.open(
             &schema,
-            FragReadConfig::default().with_row_address(with_row_addr),
+            FragReadConfig::default()
+                .with_row_address(with_row_addr)
+                .with_row_id(with_row_id),
             None,
         );
         let deletion_vector = read_deletion_file(
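
For context, a minimal sketch of the usage this change enables, mirroring `test_merge_on_row_id` above. This is not part of the diff: it assumes an already-open mutable `dataset`, the same imports as the tests, and an async context returning `Result<()>`; the column names `rowid` and `extra` are hypothetical.

```rust
// Capture the _rowid values of the rows we want to extend.
let scanned = dataset.scan().with_row_id().try_into_batch().await?;

// Build the right-side batch: "rowid" carries the captured _rowid values,
// "extra" is a hypothetical new column to merge in.
let schema = Arc::new(ArrowSchema::new(vec![
    ArrowField::new("rowid", DataType::UInt64, false),
    ArrowField::new("extra", DataType::Int32, false),
]));
let row_ids = scanned[ROW_ID].clone();
let extra = Arc::new(Int32Array::from_iter_values(0..scanned.num_rows() as i32));
let batch = RecordBatch::try_new(schema.clone(), vec![row_ids, extra])?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());

// The relaxed sanity check now accepts ROW_ID ("_rowid") or ROW_ADDR
// ("_rowaddr") as the left-side key, even though neither is a schema field.
dataset.merge(reader, ROW_ID, "rowid").await?;
```

As the shuffled-indices assertions in both tests verify, the right-side batch does not need to arrive in scan order.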