diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index e01125b8e6..ceccfe9ea1 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -356,6 +356,39 @@ def join(
         Not implemented (just override pyarrow dataset to prevent segfault)
         """
         raise NotImplementedError("Versioning not yet supported in Rust")
+
+    def merge(
+        self,
+        data_obj: ReaderLike,
+        left_on: str,
+        right_on: Optional[str] = None,
+    ):
+        """
+        Merge another dataset into this one.
+
+        Performs a left join, where the dataset is the left side and data_obj
+        is the right side. Rows existing in the dataset but not matched in
+        data_obj will be filled with null values, unless Lance doesn't support
+        null values for some types, in which case an error will be raised.
+
+        Parameters
+        ----------
+        data_obj: Reader-like
+            The data to be merged. Acceptable types are:
+            - Pandas DataFrame, Pyarrow Table, Dataset, Scanner, or RecordBatchReader
+        left_on: str
+            The name of the column in the dataset to join on.
+        right_on: str or None
+            The name of the column in data_obj to join on. If None, defaults to
+            left_on.
+        """
+        if right_on is None:
+            right_on = left_on
+
+        reader = _coerce_reader(data_obj)
+
+        self._ds.merge(reader, left_on, right_on)
+
     def versions(self):
         """
@@ -808,18 +841,7 @@ def write_dataset(
         The max number of rows before starting a new group (in the same file)
     """
 
-    if isinstance(data_obj, pd.DataFrame):
-        reader = pa.Table.from_pandas(data_obj, schema=schema).to_reader()
-    elif isinstance(data_obj, pa.Table):
-        reader = data_obj.to_reader()
-    elif isinstance(data_obj, pa.dataset.Dataset):
-        reader = pa.dataset.Scanner.from_dataset(data_obj).to_reader()
-    elif isinstance(data_obj, pa.dataset.Scanner):
-        reader = data_obj.to_reader()
-    elif isinstance(data_obj, pa.RecordBatchReader):
-        reader = data_obj
-    else:
-        raise TypeError(f"Unknown data_obj type {type(data_obj)}")
+    reader = _coerce_reader(data_obj)
 
     # TODO add support for passing in LanceDataset and LanceScanner here
     params = {
@@ -831,3 +853,18 @@ def write_dataset(
     uri = os.fspath(uri) if isinstance(uri, Path) else uri
     _write_dataset(reader, uri, params)
     return LanceDataset(uri)
+
+
+def _coerce_reader(data_obj: ReaderLike, schema: Optional[pa.Schema] = None) -> pa.RecordBatchReader:
+    if isinstance(data_obj, pd.DataFrame):
+        return pa.Table.from_pandas(data_obj, schema=schema).to_reader()
+    elif isinstance(data_obj, pa.Table):
+        return data_obj.to_reader()
+    elif isinstance(data_obj, pa.dataset.Dataset):
+        return pa.dataset.Scanner.from_dataset(data_obj).to_reader()
+    elif isinstance(data_obj, pa.dataset.Scanner):
+        return data_obj.to_reader()
+    elif isinstance(data_obj, pa.RecordBatchReader):
+        return data_obj
+    else:
+        raise TypeError(f"Unknown data_obj type {type(data_obj)}")
diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py
index 88a5dfd68d..70d2442e31 100644
--- a/python/python/tests/test_dataset.py
+++ b/python/python/tests/test_dataset.py
@@ -357,3 +357,33 @@ def test_load_scanner_from_fragments(tmp_path: Path):
     # Accepts an iterator
     scanner = dataset.scanner(fragments=iter(fragments[0:2]), scan_in_order=False)
     assert scanner.to_table().num_rows == 2 * 100
+
+
+def test_merge_data(tmp_path: Path):
+    tab = pa.table({"a": range(100), "b": range(100)})
+    lance.write_dataset(tab, tmp_path / "dataset", mode="append")
+
+    dataset = lance.dataset(tmp_path / "dataset")
+
+    # rejects partial data for non-nullable types
+    new_tab = pa.table({"a": range(40), "c": range(40)})
+    # TODO: this should be ValueError
+    with pytest.raises(OSError, match=".+Lance does not yet support nulls for type Int64."):
+        dataset.merge(new_tab, "a")
+
+    # accepts a full merge
+    new_tab = pa.table({"a": range(100), "c": range(100)})
+    dataset.merge(new_tab, "a")
+    assert dataset.version == 2
+    assert dataset.to_table() == pa.table({"a": range(100), "b": range(100), "c": range(100)})
+
+    # accepts a partial for string
+    new_tab = pa.table({"a2": range(5), "d": ["a", "b", "c", "d", "e"]})
+    dataset.merge(new_tab, left_on="a", right_on="a2")
+    assert dataset.version == 3
+    assert dataset.to_table() == pa.table({
+        "a": range(100),
+        "b": range(100),
+        "c": range(100),
+        "d": ["a", "b", "c", "d", "e"] + [None] * 95
+    })
diff --git a/python/src/dataset.rs b/python/src/dataset.rs
index aa0d5962cc..09c37a0414 100644
--- a/python/src/dataset.rs
+++ b/python/src/dataset.rs
@@ -281,6 +281,22 @@ impl Dataset {
         batch.to_pyarrow(self_.py())
     }
 
+    fn merge(
+        &mut self,
+        reader: PyArrowType<ArrowArrayStreamReader>,
+        left_on: &str,
+        right_on: &str,
+    ) -> PyResult<()> {
+        let mut reader: Box<dyn RecordBatchReader> = Box::new(reader.0);
+        let mut new_self = self.ds.as_ref().clone();
+        let fut = new_self.merge(&mut reader, left_on, right_on);
+        self.rt.block_on(
+            async move { fut.await.map_err(|err| PyIOError::new_err(err.to_string())) },
+        )?;
+        self.ds = Arc::new(new_self);
+        Ok(())
+    }
+
     fn versions(self_: PyRef<'_, Self>) -> PyResult<Vec<PyObject>> {
         let versions = self_
             .list_versions()
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index f795141065..5f2a7f7201 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -34,6 +34,7 @@ arrow-cast = "37.0.0"
 arrow-data = "37.0"
 arrow-ipc = { version = "37.0", features = ["zstd"] }
 arrow-ord = "37.0"
+arrow-row = "37.0"
 arrow-schema = "37.0"
 arrow-select = "37.0"
 async-recursion = "1.0"
@@ -41,6 +42,8 @@ async-trait = "0.1.60"
 byteorder = "1.4.3"
 chrono = "0.4.23"
 clap = { version = "4.1.1", features = ["derive"], optional = true }
+# This is already used by datafusion
+dashmap = "5"
 object_store = { version = "0.5.6", features = ["aws_profile", "gcp"] }
 reqwest = { version = "0.11.16" }
 aws-config = "0.54"
diff --git a/rust/src/arrow/record_batch.rs b/rust/src/arrow/record_batch.rs
index 683082515e..64d9f301fd 100644
--- a/rust/src/arrow/record_batch.rs
+++ b/rust/src/arrow/record_batch.rs
@@ -1,19 +1,16 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
+// Copyright 2023 Lance Developers.
 //
-//   http://www.apache.org/licenses/LICENSE-2.0
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
 //
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
 
 //! Additional utility for [`RecordBatch`]
 //!
@@ -23,6 +20,9 @@ use arrow_schema::{ArrowError, SchemaRef};
 
 use crate::Result;
 
+/// RecordBatchBuffer is an in-memory buffer for multiple [`RecordBatch`]s.
+///
+///
 #[derive(Debug)]
 pub struct RecordBatchBuffer {
     pub batches: Vec<RecordBatch>,
@@ -69,3 +69,10 @@ impl Iterator for RecordBatchBuffer {
         }
     }
 }
+
+impl FromIterator<RecordBatch> for RecordBatchBuffer {
+    fn from_iter<T: IntoIterator<Item = RecordBatch>>(iter: T) -> Self {
+        let batches = iter.into_iter().collect::<Vec<_>>();
+        Self::new(batches)
+    }
+}
diff --git a/rust/src/dataset.rs b/rust/src/dataset.rs
index a496b24e49..707ddcac58 100644
--- a/rust/src/dataset.rs
+++ b/rust/src/dataset.rs
@@ -30,6 +30,7 @@ use object_store::path::Path;
 use uuid::Uuid;
 
 pub mod fragment;
+mod hash_joiner;
 pub mod scanner;
 pub mod updater;
 mod write;
@@ -45,6 +46,7 @@ use crate::io::{
 };
 use crate::session::Session;
 use crate::{Error, Result};
+use hash_joiner::HashJoiner;
 pub use scanner::ROW_ID;
 pub use write::*;
 
@@ -480,6 +482,91 @@ impl Dataset {
         })
     }
 
+    /// Merge this dataset with another Arrow Table / Dataset, creating a new version of the dataset.
+    ///
+    /// Parameters:
+    ///
+    /// - `stream`: the stream of [`RecordBatch`] to merge.
+    /// - `left_on`: the column name to join on the left side (self).
+    /// - `right_on`: the column name to join on the right side (stream).
+    ///
+    /// On success, `self` is updated to point at the new version.
+    ///
+    /// It performs a left join on the two datasets.
+    pub async fn merge(
+        &mut self,
+        stream: &mut Box<dyn RecordBatchReader>,
+        left_on: &str,
+        right_on: &str,
+    ) -> Result<()> {
+        // Sanity check.
+        if self.schema().field(left_on).is_none() {
+            return Err(Error::invalid_input(format!(
+                "Column {} does not exist in the left side dataset",
+                left_on
+            )));
+        };
+        let right_schema = stream.schema();
+        if right_schema.field_with_name(right_on).is_err() {
+            return Err(Error::invalid_input(format!(
+                "Column {} does not exist in the right side dataset",
+                right_on
+            )));
+        };
+        for field in right_schema.fields() {
+            if field.name() == right_on {
+                // right_on is allowed to exist in the dataset, since it may be
+                // the same as left_on.
+                continue;
+            }
+            if self.schema().field(field.name()).is_some() {
+                return Err(Error::invalid_input(format!(
+                    "Column {} exists in both sides of the dataset",
+                    field.name()
+                )));
+            }
+        }
+
+        // Hash join
+        let joiner = Arc::new(HashJoiner::try_new(stream, right_on).await?);
+        // Final schema is union of current schema, plus the RHS schema without
+        // the right_on key.
+        let new_schema: Schema = self.schema().merge(joiner.out_schema().as_ref())?;
+
+        // Write new data file to each fragment. Parallelism is done over columns,
+        // so no parallelism is done at this level.
+        let updated_fragments: Vec<Fragment> = stream::iter(self.get_fragments())
+            .then(|f| {
+                let joiner = joiner.clone();
+                let full_schema = new_schema.clone();
+                async move {
+                    f.merge(left_on, &joiner, &full_schema)
+                        .await
+                        .map(|f| f.metadata)
+                }
+            })
+            .try_collect::<Vec<_>>()
+            .await?;
+
+        // Inherit the index, since we are just adding columns.
+        let indices = self.load_indices().await?;
+
+        let mut manifest = Manifest::new(&self.schema(), Arc::new(updated_fragments));
+        manifest.version = self
+            .latest_manifest()
+            .await
+            .map(|m| m.version + 1)
+            .unwrap_or(1);
+        manifest.set_timestamp(None);
+        manifest.schema = new_schema;
+
+        write_manifest_file(&self.object_store, &mut manifest, Some(indices)).await?;
+
+        self.manifest = Arc::new(manifest);
+
+        Ok(())
+    }
+
     /// Create a Scanner to scan the dataset.
     pub fn scan(&self) -> Scanner {
         Scanner::new(Arc::new(self.clone()))
@@ -747,6 +834,7 @@ mod tests {
     use crate::{datatypes::Schema, utils::testing::generate_random_array};
 
     use crate::dataset::WriteMode::Overwrite;
+    use arrow_array::Float32Array;
     use arrow_array::{
         cast::{as_string_array, as_struct_array},
         DictionaryArray, FixedSizeListArray, Int32Array, RecordBatch, StringArray, UInt16Array,
@@ -1270,4 +1358,102 @@ mod tests {
         let result = Dataset::open(".").await;
         assert!(matches!(result.unwrap_err(), Error::DatasetNotFound { .. }));
     }
+
+    #[tokio::test]
+    async fn test_merge() {
+        let schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("i", DataType::Int32, false),
+            Field::new("x", DataType::Float32, false),
+        ]));
+        let batch1 = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2])),
+                Arc::new(Float32Array::from(vec![1.0, 2.0])),
+            ],
+        )
+        .unwrap();
+        let batch2 = RecordBatch::try_new(
+            schema,
+            vec![
+                Arc::new(Int32Array::from(vec![3, 2])),
+                Arc::new(Float32Array::from(vec![3.0, 4.0])),
+            ],
+        )
+        .unwrap();
+
+        let test_dir = tempdir().unwrap();
+        let test_uri = test_dir.path().to_str().unwrap();
+
+        let mut write_params = WriteParams::default();
+        write_params.mode = WriteMode::Append;
+
+        let mut batches: Box<dyn RecordBatchReader> =
+            Box::new(RecordBatchBuffer::from_iter(vec![batch1]));
+        Dataset::write(&mut batches, test_uri, Some(write_params))
+            .await
+            .unwrap();
+
+        let mut batches: Box<dyn RecordBatchReader> =
+            Box::new(RecordBatchBuffer::from_iter(vec![batch2]));
+        Dataset::write(&mut batches, test_uri, Some(write_params))
+            .await
+            .unwrap();
+
+        let dataset = Dataset::open(test_uri).await.unwrap();
+        assert_eq!(dataset.fragments().len(), 2);
+
+        let right_schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("i2", DataType::Int32, false),
+            Field::new("y", DataType::Utf8, true),
+        ]));
+        let right_batch1 = RecordBatch::try_new(
+            right_schema.clone(),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2])),
+                Arc::new(StringArray::from(vec!["a", "b"])),
+            ],
+        )
+        .unwrap();
+
+        let mut batches: Box<dyn RecordBatchReader> =
+            Box::new(RecordBatchBuffer::from_iter(vec![right_batch1]));
+        let mut dataset = Dataset::open(test_uri).await.unwrap();
+        dataset.merge(&mut batches, "i", "i2").await.unwrap();
+
+        assert_eq!(dataset.version().version, 3);
+        assert_eq!(dataset.fragments().len(), 2);
+        assert_eq!(dataset.fragments()[0].files.len(), 2);
+        assert_eq!(dataset.fragments()[1].files.len(), 2);
+
+        let actual_batches = dataset
+            .scan()
+            .try_into_stream()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+        let actual = concat_batches(&actual_batches[0].schema(), &actual_batches).unwrap();
+        let expected = RecordBatch::try_new(
+            Arc::new(ArrowSchema::new(vec![
+                Field::new("i", DataType::Int32, false),
+                Field::new("x", DataType::Float32, false),
+                Field::new("y", DataType::Utf8, true),
+            ])),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2, 3, 2])),
+                Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0])),
+                Arc::new(StringArray::from(vec![
+                    Some("a"),
+                    Some("b"),
+                    None,
+                    Some("b"),
+                ])),
+            ],
+        )
+        .unwrap();
+
+        assert_eq!(actual, expected);
+    }
 }
diff --git a/rust/src/dataset/fragment.rs b/rust/src/dataset/fragment.rs
index adba67f920..62c2f64845 100644
--- a/rust/src/dataset/fragment.rs
+++ b/rust/src/dataset/fragment.rs
@@ -18,8 +18,12 @@ use std::ops::Range;
 use std::sync::Arc;
 
 use arrow_array::{RecordBatch, RecordBatchReader};
+use futures::{StreamExt, TryStreamExt};
 use uuid::Uuid;
 
+use super::hash_joiner::HashJoiner;
+use super::scanner::Scanner;
+use super::updater::Updater;
 use crate::arrow::*;
 use crate::dataset::{Dataset, DATA_DIR};
 use crate::datatypes::Schema;
@@ -27,13 +31,11 @@ use crate::format::Fragment;
 use crate::io::{FileReader, FileWriter, ObjectStore, ReadBatchParams};
 use crate::{Error, Result};
 
-use super::scanner::Scanner;
-use super::updater::Updater;
 use super::WriteParams;
 
 /// A Fragment of a Lance [`Dataset`].
 ///
-/// The interface is similar to `pyarrow.dataset.Fragment`.
+/// The interface is modeled after `pyarrow.dataset.Fragment`.
 #[derive(Debug, Clone)]
 pub struct FileFragment {
     dataset: Arc<Dataset>,
@@ -199,6 +201,46 @@ impl FileFragment {
 
         Ok(Updater::new(self.clone(), reader))
     }
+
+    /// Merge columns from joiner.
+    pub(crate) async fn merge(
+        mut self,
+        join_column: &str,
+        joiner: &HashJoiner,
+        full_schema: &Schema,
+    ) -> Result<Self> {
+        let mut scanner = self.scan();
+        scanner.project(&[join_column])?;
+
+        let mut batch_stream = scanner
+            .try_into_stream()
+            .await?
+            .and_then(|batch| joiner.collect(batch.column(0).clone()))
+            .boxed();
+
+        let file_schema = full_schema.project_by_schema(joiner.out_schema().as_ref())?;
+
+        let filename = format!("{}.lance", Uuid::new_v4());
+        let full_path = self
+            .dataset
+            .object_store
+            .base_path()
+            .child(DATA_DIR)
+            .child(filename.clone());
+        let mut writer =
+            FileWriter::try_new(&self.dataset.object_store, &full_path, file_schema.clone())
+                .await?;
+
+        while let Some(batch) = batch_stream.try_next().await? {
+            writer.write(&[batch]).await?;
+        }
+
+        writer.finish().await?;
+
+        self.metadata.add_file(&filename, &file_schema);
+
+        Ok(self)
+    }
 }
 
 impl From<FileFragment> for Fragment {
diff --git a/rust/src/dataset/hash_joiner.rs b/rust/src/dataset/hash_joiner.rs
new file mode 100644
index 0000000000..1cbe6e70b0
--- /dev/null
+++ b/rust/src/dataset/hash_joiner.rs
@@ -0,0 +1,314 @@
+// Copyright 2023 Lance Developers.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! HashJoiner
+
+use std::sync::Arc;
+
+use arrow_array::ArrayRef;
+use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchReader};
+use arrow_row::{OwnedRow, RowConverter, Rows, SortField};
+use arrow_schema::{DataType as ArrowDataType, SchemaRef};
+use arrow_select::interleave::interleave;
+use dashmap::{DashMap, ReadOnlyView};
+use futures::{StreamExt, TryStreamExt};
+use tokio::task;
+
+use crate::datatypes::lance_supports_nulls;
+use crate::{Error, Result};
+
+/// `HashJoiner` performs a hash join between two datasets.
+pub(crate) struct HashJoiner {
+    index_map: ReadOnlyView<OwnedRow, (usize, usize)>,
+
+    index_type: ArrowDataType,
+
+    batches: Vec<RecordBatch>,
+
+    out_schema: SchemaRef,
+}
+
+fn column_to_rows(column: ArrayRef) -> Result<Rows> {
+    let mut row_converter = RowConverter::new(vec![SortField::new(column.data_type().clone())])?;
+    let rows = row_converter.convert_columns(&[column])?;
+    Ok(rows)
+}
+
+impl HashJoiner {
+    /// Create a new `HashJoiner`, building the hash index.
+    ///
+    /// Will run in parallel over batches using all available cores.
+    pub async fn try_new(reader: &mut Box<dyn RecordBatchReader>, on: &str) -> Result<Self> {
+        // Check that the column exists
+        reader.schema().field_with_name(on)?;
+
+        // Hold all data in memory for simple implementation. Can do external sort later.
+        let batches = reader.collect::<std::result::Result<Vec<_>, _>>()?;
+        if batches.is_empty() {
+            return Err(Error::IO {
+                message: "HashJoiner: No data".to_string(),
+            });
+        };
+
+        let map = DashMap::new();
+
+        let schema = reader.schema();
+
+        let keep_indices: Vec<usize> = schema
+            .fields()
+            .iter()
+            .enumerate()
+            .filter_map(|(i, field)| if field.name() == on { None } else { Some(i) })
+            .collect();
+        let out_schema: Arc<arrow_schema::Schema> = Arc::new(schema.project(&keep_indices)?);
+        let right_batches = batches
+            .iter()
+            .map(|batch| {
+                let mut columns = Vec::with_capacity(keep_indices.len());
+                for i in &keep_indices {
+                    columns.push(batch.column(*i).clone());
+                }
+                RecordBatch::try_new(out_schema.clone(), columns).unwrap()
+            })
+            .collect::<Vec<_>>();
+
+        let map = Arc::new(map);
+
+        futures::stream::iter(batches.iter().enumerate().map(Ok::<_, Error>))
+            .try_for_each_concurrent(num_cpus::get(), |(batch_i, batch)| {
+                // A clone of map we can send to a new thread
+                let map = map.clone();
+                async move {
+                    let column = batch[on].clone();
+                    let task_result = task::spawn_blocking(move || {
+                        let rows = column_to_rows(column)?;
+                        for (row_i, row) in rows.iter().enumerate() {
+                            map.insert(row.owned(), (batch_i, row_i));
+                        }
+                        Ok(())
+                    })
+                    .await;
+                    match task_result {
+                        Ok(Ok(_)) => Ok(()),
+                        Ok(Err(err)) => Err(err),
+                        Err(err) => Err(Error::IO {
+                            message: format!("HashJoiner: {}", err),
+                        }),
+                    }
+                }
+            })
+            .await?;
+
+        let map = Arc::try_unwrap(map)
+            .expect("HashJoiner: No remaining tasks should still be referencing map.");
+        let index_type = batches[0]
+            .schema()
+            .field_with_name(on)
+            .unwrap()
+            .data_type()
+            .clone();
+
+        Ok(Self {
+            index_map: map.into_read_only(),
+            index_type,
+            batches: right_batches,
+            out_schema,
+        })
+    }
+
+    /// Returns the schema of data yielded by `collect()`.
+    ///
+    /// This excludes the index column on the right-hand side.
+    pub fn out_schema(&self) -> &SchemaRef {
+        &self.out_schema
+    }
+
+    /// Collect the data using the index column from the left table.
+    ///
+    /// Will run in parallel over columns using all available cores.
+    pub(super) async fn collect(&self, index_column: ArrayRef) -> Result<RecordBatch> {
+        if index_column.data_type() != &self.index_type {
+            return Err(Error::invalid_input(format!(
+                "Index column type mismatch: expected {}, got {}",
+                self.index_type,
+                index_column.data_type()
+            )));
+        }
+
+        // Index to use for null values
+        let null_index = self.batches.len();
+
+        // Indices are a pair of (batch_i, row_i). We'll add a null batch at the
+        // end with one null element, and that's what we resolve when no match is
+        // found.
+        let indices = column_to_rows(index_column)?
+            .into_iter()
+            .map(|row| {
+                self.index_map
+                    .get(&row.owned())
+                    .map(|(batch_i, row_i)| (*batch_i, *row_i))
+                    .unwrap_or((null_index, 0))
+            })
+            .collect::<Vec<_>>();
+        let indices = Arc::new(indices);
+
+        // Do this in parallel over the columns
+        let columns = futures::stream::iter(0..self.batches[0].num_columns())
+            .map(|column_i| {
+                // Use interleave to get the data
+                // https://docs.rs/arrow/40.0.0/arrow/compute/kernels/interleave/fn.interleave.html
+                // To handle nulls, we'll add an extra null array at the end
+                let mut arrays = Vec::with_capacity(self.batches.len() + 1);
+                for batch in &self.batches {
+                    arrays.push(batch.column(column_i).clone());
+                }
+                arrays.push(Arc::new(new_null_array(&arrays[0].data_type(), 1)));
+
+                // Clone of indices we can send to a new thread
+                let indices = indices.clone();
+
+                async move {
+                    let task_result = task::spawn_blocking(move || {
+                        let array_refs = arrays.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
+                        interleave(array_refs.as_ref(), indices.as_ref())
+                            .map_err(|err| Error::IO {
+                                message: format!("HashJoiner: {}", err),
+                            })
+                            .map(|x| x.clone())
+                    })
+                    .await;
+                    match task_result {
+                        Ok(Ok(array)) => {
+                            if array.null_count() > 0 && !lance_supports_nulls(array.data_type()) {
+                                return Err(Error::invalid_input(format!(
+                                    "Found rows on LHS that do not match any rows on RHS. Lance would need to write \
+                                    nulls on the RHS, but Lance does not yet support nulls for type {:?}.",
+                                    array.data_type()
+                                )));
+                            }
+                            Ok(array)
+                        },
+                        Ok(Err(err)) => Err(err),
+                        Err(err) => Err(Error::IO {
+                            message: format!("HashJoiner: {}", err),
+                        }),
+                    }
+                }
+            })
+            .buffered(num_cpus::get())
+            .try_collect::<Vec<_>>()
+            .await?;
+
+        Ok(RecordBatch::try_new(
+            self.batches[0].schema().clone(),
+            columns,
+        )?)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+
+    use std::sync::Arc;
+
+    use arrow_array::{Int32Array, StringArray, UInt32Array};
+    use arrow_schema::{DataType, Field, Schema};
+
+    use crate::arrow::RecordBatchBuffer;
+
+    #[tokio::test]
+    async fn test_joiner_collect() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("i", DataType::Int32, true),
+            Field::new("s", DataType::Utf8, true),
+        ]));
+
+        let batch_buffer: RecordBatchBuffer = (0..5)
+            .map(|v| {
+                let values = (v * 10..v * 10 + 10).collect::<Vec<i32>>();
+                RecordBatch::try_new(
+                    schema.clone(),
+                    vec![
+                        Arc::new(Int32Array::from_iter(values.iter().copied())),
+                        Arc::new(StringArray::from_iter_values(
+                            values.iter().map(|v| format!("str_{}", v)),
+                        )),
+                    ],
+                )
+                .unwrap()
+            })
+            .collect();
+        let mut batch_buffer: Box<dyn RecordBatchReader> = Box::new(batch_buffer);
+        let joiner = HashJoiner::try_new(&mut batch_buffer, "i").await.unwrap();
+
+        let indices = Arc::new(Int32Array::from_iter(&[
+            Some(15),
+            None,
+            Some(10),
+            Some(0),
+            None,
+            None,
+            Some(22),
+            Some(11111), // not found
+        ]));
+        let results = joiner.collect(indices).await.unwrap();
+
+        assert_eq!(
+            results.column_by_name("s").unwrap().as_ref(),
+            &StringArray::from(vec![
+                Some("str_15"),
+                None,
+                Some("str_10"),
+                Some("str_0"),
+                None,
+                None,
+                Some("str_22"),
+                None // 11111 not found
+            ])
+        );
+
+        assert_eq!(results.num_columns(), 1);
+    }
+
+    #[tokio::test]
+    async fn test_reject_invalid() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("i", DataType::Int32, true),
+            Field::new("s", DataType::Utf8, true),
+        ]));
+
+        let batch = RecordBatch::try_new(
+            schema,
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2])),
+                Arc::new(StringArray::from(vec!["a", "b"])),
+            ],
+        )
+        .unwrap();
+        let mut batch_buffer: Box<dyn RecordBatchReader> =
+            Box::new(RecordBatchBuffer::from_iter(vec![batch]));
+
+        let joiner = HashJoiner::try_new(&mut batch_buffer, "i").await.unwrap();
+
+        // Wrong type: was Int32, passing UInt32.
+        let indices = Arc::new(UInt32Array::from_iter(&[Some(15)]));
+        let result = joiner.collect(indices).await;
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("Index column type mismatch: expected Int32, got UInt32"));
+    }
+}
diff --git a/rust/src/datatypes.rs b/rust/src/datatypes.rs
index bfb5e794f3..3fa531171c 100644
--- a/rust/src/datatypes.rs
+++ b/rust/src/datatypes.rs
@@ -310,3 +310,16 @@ impl From<&Dictionary> for pb::Dictionary {
         }
     }
 }
+
+/// Returns true if Lance supports writing this datatype with nulls.
+pub(crate) fn lance_supports_nulls(datatype: &DataType) -> bool {
+    match datatype {
+        DataType::Utf8
+        | DataType::LargeUtf8
+        | DataType::Binary
+        | DataType::List(_)
+        | DataType::FixedSizeBinary(_)
+        | DataType::FixedSizeList(_, _) => true,
+        _ => false,
+    }
+}
diff --git a/rust/src/error.rs b/rust/src/error.rs
index 4cd3a65669..7a07ad40be 100644
--- a/rust/src/error.rs
+++ b/rust/src/error.rs
@@ -29,6 +29,8 @@ pub(crate) fn box_error(e: impl std::error::Error + Send + Sync + 'static) -> Bo
 #[derive(Debug, Snafu)]
 #[snafu(visibility(pub(crate)))]
 pub enum Error {
+    #[snafu(display("Invalid user input: {source}"))]
+    InvalidInput { source: BoxedError },
     #[snafu(display("Attempt to write empty record batches"))]
     EmptyDataset,
     #[snafu(display("Dataset already exists: {uri}"))]
@@ -63,6 +65,13 @@ impl Error {
             source: message.into(),
         }
     }
+
+    pub fn invalid_input(message: impl Into<String>) -> Self {
+        let message: String = message.into();
+        Self::InvalidInput {
+            source: message.into(),
+        }
+    }
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
diff --git a/rust/src/format/manifest.rs b/rust/src/format/manifest.rs
index 2924dade02..3a8ac07fc3 100644
--- a/rust/src/format/manifest.rs
+++ b/rust/src/format/manifest.rs
@@ -20,6 +20,7 @@ use std::sync::Arc;
 
 use chrono::prelude::*;
 use prost_types::Timestamp;
+use std::time::SystemTime;
 
 use super::Fragment;
 use crate::datatypes::Schema;
@@ -79,6 +80,16 @@ impl Manifest {
         )
     }
 
+    /// Set the `timestamp_nanos` value from a [`SystemTime`], defaulting to the current time.
+    pub fn set_timestamp(&mut self, timestamp: Option<SystemTime>) {
+        let timestamp = timestamp.unwrap_or_else(SystemTime::now);
+        let nanos = timestamp
+            .duration_since(SystemTime::UNIX_EPOCH)
+            .unwrap()
+            .as_nanos();
+        self.timestamp_nanos = nanos as u128;
+    }
+
     /// Return the max fragment id.
     /// Note this does not support recycling of fragment ids.
     pub fn max_fragment_id(&self) -> Option<u64> {
 }
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 9b8f54c1d9..745cc5767f 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -31,4 +31,5 @@ pub mod linalg;
 pub mod session;
 pub mod utils;
 
+pub use dataset::Dataset;
 pub use error::{Error, Result};
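
Usage note (not part of the patch): a minimal sketch of how the new Python `Dataset.merge` API introduced above is expected to be used, mirroring `test_merge_data`. The output path and column names here are purely illustrative assumptions.

```python
import lance
import pyarrow as pa

# Write a small dataset, then merge new columns into it via a left join on "id".
table = pa.table({"id": range(10), "value": range(10)})
dataset = lance.write_dataset(table, "/tmp/merge_example.lance")

# The right side carries the join key plus the new column; string columns may
# be partial because Lance can write nulls for Utf8.
new_columns = pa.table({"id": range(10), "label": [f"item-{i}" for i in range(10)]})
dataset.merge(new_columns, left_on="id")  # right_on defaults to left_on

print(dataset.to_table().column_names)  # ['id', 'value', 'label']
```

Each `merge` call writes a new data file per fragment and commits a new manifest, so the dataset version increments by one, as the Rust `test_merge` above asserts.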