From 072ef8eb887ec982b89c8ab99d40ddacdeb3ec63 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 13 Apr 2023 12:31:07 -0700 Subject: [PATCH 01/16] hash join interface --- rust/src/dataset.rs | 50 ++++++++++++++++++++++++++++++++++ rust/src/dataset/hash_join.rs | 51 +++++++++++++++++++++++++++++++++++ rust/src/lib.rs | 1 + 3 files changed, 102 insertions(+) create mode 100644 rust/src/dataset/hash_join.rs diff --git a/rust/src/dataset.rs b/rust/src/dataset.rs index a496b24e49..111c25fedd 100644 --- a/rust/src/dataset.rs +++ b/rust/src/dataset.rs @@ -33,6 +33,7 @@ pub mod fragment; pub mod scanner; pub mod updater; mod write; +mod hash_join; use self::fragment::FileFragment; use self::scanner::Scanner; @@ -479,6 +480,55 @@ impl Dataset { session: Arc::new(Session::default()), }) } + + /// Merge this dataset with another arrow Table / Dataset, and returns a new version of dataset. + /// + /// Parameters: + /// + /// - `stream`: the stream of [`RecordBatch`] to merge. + /// - `left_on`: the column name to join on the left side (self). + /// - `right_on`: the column name to join on the right side (stream). + /// + /// Returns: a new version of dataset. + /// + /// It performs a left-join on the two datasets. + pub fn merge( + &self, + stream: &dyn RecordBatchReader, + left_on: &str, + right_on: &str, + ) -> Result { + // Sanity check. + if self.schema().field(left_on).is_none() { + return Err(Error::IO(format!( + "Column {} does not exist in the left side dataset", + left_on + ))); + }; + let right_schema = stream.schema(); + if right_schema.field_with_name(right_on).is_err() { + return Err(Error::IO(format!( + "Column {} does not exist in the right side dataset", + right_on + ))); + }; + for field in right_schema.fields() { + if field.name() == right_on { + continue; + } + if self.schema().field(field.name()).is_some() { + return Err(Error::IO(format!( + "Column {} exists in both sides of the dataset", + field.name() + ))); + } + } + + // Hash join + + + todo!() + } /// Create a Scanner to scan the dataset. pub fn scan(&self) -> Scanner { diff --git a/rust/src/dataset/hash_join.rs b/rust/src/dataset/hash_join.rs new file mode 100644 index 0000000000..97df676faa --- /dev/null +++ b/rust/src/dataset/hash_join.rs @@ -0,0 +1,51 @@ +// Copyright 2023 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! HashJoiner + +use std::collections::HashMap; + +use arrow_array::RecordBatch; + +/// `HashJoiner` does hash join on two datasets. +pub(super) struct HashJoiner { + /// Hash value to row index map. + index_map: HashMap>, +} + +impl HashJoiner { + /// Create a new `HashJoiner`. + pub fn new() -> Self { + Self { + index_map: HashMap::new(), + } + } + + /// Append a batch to the hash joiner. 
+ pub fn append_batch(&mut self, batch: &RecordBatch, on: &str, start_idx: usize) { + let hash_column = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..hash_column.len() { + let hash = hash_column.value(i); + let index = batch.index(); + self.index_map + .entry(hash) + .and_modify(|v| v.push(index)) + .or_insert(vec![index]); + } + } +} \ No newline at end of file diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 9b8f54c1d9..745cc5767f 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -31,4 +31,5 @@ pub mod linalg; pub mod session; pub mod utils; +pub use dataset::Dataset; pub use error::{Error, Result}; From efaa4ffe51553bba330361e1df8fde88aedb2912 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 13 Apr 2023 12:58:03 -0700 Subject: [PATCH 02/16] build hash joiner --- rust/src/arrow/record_batch.rs | 35 +++++++++++-------- rust/src/dataset.rs | 3 +- rust/src/dataset/hash_join.rs | 62 +++++++++++++++++++++++----------- 3 files changed, 65 insertions(+), 35 deletions(-) diff --git a/rust/src/arrow/record_batch.rs b/rust/src/arrow/record_batch.rs index 683082515e..64d9f301fd 100644 --- a/rust/src/arrow/record_batch.rs +++ b/rust/src/arrow/record_batch.rs @@ -1,19 +1,16 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at +// Copyright 2023 Lance Developers. // -// http://www.apache.org/licenses/LICENSE-2.0 +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. //! Additional utility for [`RecordBatch`] //! @@ -23,6 +20,9 @@ use arrow_schema::{ArrowError, SchemaRef}; use crate::Result; +/// RecordBatchBuffer is a in-memory buffer for multiple [`RecordBatch`]s. 
+/// +/// #[derive(Debug)] pub struct RecordBatchBuffer { pub batches: Vec, @@ -69,3 +69,10 @@ impl Iterator for RecordBatchBuffer { } } } + +impl FromIterator for RecordBatchBuffer { + fn from_iter>(iter: T) -> Self { + let batches = iter.into_iter().collect::>(); + Self::new(batches) + } +} diff --git a/rust/src/dataset.rs b/rust/src/dataset.rs index 111c25fedd..e4ada4a6a1 100644 --- a/rust/src/dataset.rs +++ b/rust/src/dataset.rs @@ -30,10 +30,10 @@ use object_store::path::Path; use uuid::Uuid; pub mod fragment; +mod hash_join; pub mod scanner; pub mod updater; mod write; -mod hash_join; use self::fragment::FileFragment; use self::scanner::Scanner; @@ -526,7 +526,6 @@ impl Dataset { // Hash join - todo!() } diff --git a/rust/src/dataset/hash_join.rs b/rust/src/dataset/hash_join.rs index 97df676faa..c846e1275c 100644 --- a/rust/src/dataset/hash_join.rs +++ b/rust/src/dataset/hash_join.rs @@ -16,36 +16,60 @@ use std::collections::HashMap; -use arrow_array::RecordBatch; +use arrow_array::{Array, RecordBatchReader}; + +use crate::arrow::RecordBatchBuffer; +use crate::{Error, Result}; /// `HashJoiner` does hash join on two datasets. pub(super) struct HashJoiner { /// Hash value to row index map. index_map: HashMap>, + + data: RecordBatchBuffer, + + on_column: String, +} + +/// Hash the values of the array. +fn hash_values(array: &dyn Array) -> Result> { + todo!() } impl HashJoiner { /// Create a new `HashJoiner`. - pub fn new() -> Self { - Self { + pub fn try_new(reader: &mut dyn RecordBatchReader, on: &str) -> Result { + // Check column exist + reader.schema().field_with_name(on)?; + + Ok(Self { index_map: HashMap::new(), - } + + // Hold all data in memory for simple implementation. Can do external sort later. + data: reader.collect::>()?, + + on_column: on.to_string(), + }) } - /// Append a batch to the hash joiner. 
- pub fn append_batch(&mut self, batch: &RecordBatch, on: &str, start_idx: usize) { - let hash_column = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - for i in 0..hash_column.len() { - let hash = hash_column.value(i); - let index = batch.index(); - self.index_map - .entry(hash) - .and_modify(|v| v.push(index)) - .or_insert(vec![index]); + pub fn build(&mut self) -> Result<()> { + let mut start_idx = 0; + + for batch in &self.data.batches { + let key_column = batch.column_by_name(&self.on_column).ok_or_else(|| { + Error::IO(format!("HashJoiner: Column {} not found", self.on_column)) + })?; + + let hashes = hash_values(key_column)?; + for (i, hash_value) in hashes.iter().enumerate() { + let idx = start_idx + i; + self.index_map + .entry(*hash_value) + .or_insert_with(Vec::new) + .push(idx); + } + start_idx += batch.num_rows(); } + Ok(()) } -} \ No newline at end of file +} From 84fa0cadbb03d135e74e89f838321b60a0200f90 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 13 Apr 2023 15:24:08 -0700 Subject: [PATCH 03/16] use hash key --- rust/src/dataset.rs | 2 +- .../dataset/{hash_join.rs => hash_joiner.rs} | 18 ++++++++---------- 2 files changed, 9 insertions(+), 11 deletions(-) rename rust/src/dataset/{hash_join.rs => hash_joiner.rs} (85%) diff --git a/rust/src/dataset.rs b/rust/src/dataset.rs index e4ada4a6a1..929cd75bb9 100644 --- a/rust/src/dataset.rs +++ b/rust/src/dataset.rs @@ -30,7 +30,7 @@ use object_store::path::Path; use uuid::Uuid; pub mod fragment; -mod hash_join; +mod hash_joiner; pub mod scanner; pub mod updater; mod write; diff --git a/rust/src/dataset/hash_join.rs b/rust/src/dataset/hash_joiner.rs similarity index 85% rename from rust/src/dataset/hash_join.rs rename to rust/src/dataset/hash_joiner.rs index c846e1275c..8f2badf249 100644 --- a/rust/src/dataset/hash_join.rs +++ b/rust/src/dataset/hash_joiner.rs @@ -18,7 +18,7 @@ use std::collections::HashMap; use arrow_array::{Array, RecordBatchReader}; -use crate::arrow::RecordBatchBuffer; +use crate::arrow::{RecordBatchBuffer, hash}; use crate::{Error, Result}; /// `HashJoiner` does hash join on two datasets. @@ -31,10 +31,6 @@ pub(super) struct HashJoiner { on_column: String, } -/// Hash the values of the array. -fn hash_values(array: &dyn Array) -> Result> { - todo!() -} impl HashJoiner { /// Create a new `HashJoiner`. @@ -60,13 +56,15 @@ impl HashJoiner { Error::IO(format!("HashJoiner: Column {} not found", self.on_column)) })?; - let hashes = hash_values(key_column)?; + let hashes = hash(key_column.as_ref())?; for (i, hash_value) in hashes.iter().enumerate() { let idx = start_idx + i; - self.index_map - .entry(*hash_value) - .or_insert_with(Vec::new) - .push(idx); + if let Some(v) = hash_value { + self.index_map + .entry(v) + .or_insert_with(Vec::new) + .push(idx); + } } start_idx += batch.num_rows(); } From 3a86e9d9fec87d58b3fcc0db082c8b5930560e6a Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 13 Apr 2023 17:02:57 -0700 Subject: [PATCH 04/16] collect from hash joiner --- rust/src/arrow/record_batch.rs | 35 ++++++++++++++++++++++++++++- rust/src/dataset.rs | 5 ++++- rust/src/dataset/hash_joiner.rs | 40 ++++++++++++++++++++++++--------- 3 files changed, 68 insertions(+), 12 deletions(-) diff --git a/rust/src/arrow/record_batch.rs b/rust/src/arrow/record_batch.rs index 64d9f301fd..19ec747653 100644 --- a/rust/src/arrow/record_batch.rs +++ b/rust/src/arrow/record_batch.rs @@ -15,8 +15,10 @@ //! Additional utility for [`RecordBatch`] //! 
-use arrow_array::{RecordBatch, RecordBatchReader}; +use arrow::array::as_struct_array; +use arrow_array::{Array, RecordBatch, RecordBatchReader, StructArray}; use arrow_schema::{ArrowError, SchemaRef}; +use arrow_select::interleave::interleave; use crate::Result; @@ -48,6 +50,37 @@ impl RecordBatchBuffer { pub fn finish(&self) -> Result> { Ok(self.batches.clone()) } + + fn make_interleaving_indices(&self, indices: &[usize]) -> Vec<(usize, usize)> { + let mut lengths = vec![0_usize]; + for batch in self.batches.iter() { + lengths.push(lengths.last().unwrap() + batch.num_rows()); + } + + let mut idx = vec![]; + for i in indices { + let batch_id = match lengths.binary_search(&i) { + Ok(i) => i, + Err(i) => i - 1, + }; + idx.push((batch_id, i - lengths[batch_id])); + } + idx + } + + /// Take rows by indices. + pub fn take(&self, indices: &[usize]) -> Result { + let arrays = self + .batches + .iter() + .map(|batch| StructArray::from(batch.clone())) + .collect::>(); + let refs = arrays.iter().map(|a| a as &dyn Array).collect::>(); + + let interleaving_indices = self.make_interleaving_indices(indices); + let array = interleave(&refs, &interleaving_indices)?; + Ok(as_struct_array(&array).into()) + } } impl RecordBatchReader for RecordBatchBuffer { diff --git a/rust/src/dataset.rs b/rust/src/dataset.rs index 929cd75bb9..6654c96c38 100644 --- a/rust/src/dataset.rs +++ b/rust/src/dataset.rs @@ -46,6 +46,7 @@ use crate::io::{ }; use crate::session::Session; use crate::{Error, Result}; +use hash_joiner::HashJoiner; pub use scanner::ROW_ID; pub use write::*; @@ -494,7 +495,7 @@ impl Dataset { /// It performs a left-join on the two datasets. pub fn merge( &self, - stream: &dyn RecordBatchReader, + stream: &mut dyn RecordBatchReader, left_on: &str, right_on: &str, ) -> Result { @@ -525,6 +526,8 @@ impl Dataset { } // Hash join + let mut joiner = HashJoiner::try_new(stream, right_on)?; + joiner.build()?; todo!() } diff --git a/rust/src/dataset/hash_joiner.rs b/rust/src/dataset/hash_joiner.rs index 8f2badf249..701a23c0a5 100644 --- a/rust/src/dataset/hash_joiner.rs +++ b/rust/src/dataset/hash_joiner.rs @@ -16,22 +16,21 @@ use std::collections::HashMap; -use arrow_array::{Array, RecordBatchReader}; +use arrow_array::{Array, RecordBatch, RecordBatchReader}; -use crate::arrow::{RecordBatchBuffer, hash}; +use crate::arrow::{hash, RecordBatchBuffer}; use crate::{Error, Result}; /// `HashJoiner` does hash join on two datasets. pub(super) struct HashJoiner { /// Hash value to row index map. - index_map: HashMap>, + index_map: HashMap, data: RecordBatchBuffer, on_column: String, } - impl HashJoiner { /// Create a new `HashJoiner`. pub fn try_new(reader: &mut dyn RecordBatchReader, on: &str) -> Result { @@ -48,7 +47,8 @@ impl HashJoiner { }) } - pub fn build(&mut self) -> Result<()> { + /// Build the hash index. + pub(super) fn build(&mut self) -> Result<()> { let mut start_idx = 0; for batch in &self.data.batches { @@ -59,15 +59,35 @@ impl HashJoiner { let hashes = hash(key_column.as_ref())?; for (i, hash_value) in hashes.iter().enumerate() { let idx = start_idx + i; - if let Some(v) = hash_value { - self.index_map - .entry(v) - .or_insert_with(Vec::new) - .push(idx); + let Some(key) = hash_value else { + continue; + }; + + if self.index_map.contains_key(&key) { + return Err(Error::IO(format!("HashJoiner: Duplicate key {}", key))); } + // TODO: use [`HashMap::try_insert`] when it's stable. 
+ self.index_map.insert(key, idx); } start_idx += batch.num_rows(); } Ok(()) } + + /// Collecting the data using the index column from left table. + pub(super) fn collect(&self, index_column: &dyn Array) -> Result { + let hashes = hash(index_column)?; + let mut indices: Vec = Vec::with_capacity(index_column.len()); + for hash_value in hashes.iter() { + let Some(key) = hash_value else { + continue; + }; + + if let Some(idx) = self.index_map.get(&key) { + indices.push(*idx); + } + } + + todo!() + } } From 3239b2a2a35124e60f70964b40e40cba6531fd8d Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 13 Apr 2023 17:24:46 -0700 Subject: [PATCH 05/16] take rows --- rust/src/arrow/record_batch.rs | 49 +++++++++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 3 deletions(-) diff --git a/rust/src/arrow/record_batch.rs b/rust/src/arrow/record_batch.rs index 19ec747653..d47544eb85 100644 --- a/rust/src/arrow/record_batch.rs +++ b/rust/src/arrow/record_batch.rs @@ -15,8 +15,7 @@ //! Additional utility for [`RecordBatch`] //! -use arrow::array::as_struct_array; -use arrow_array::{Array, RecordBatch, RecordBatchReader, StructArray}; +use arrow_array::{cast::as_struct_array, Array, RecordBatch, RecordBatchReader, StructArray}; use arrow_schema::{ArrowError, SchemaRef}; use arrow_select::interleave::interleave; @@ -51,6 +50,8 @@ impl RecordBatchBuffer { Ok(self.batches.clone()) } + /// Make interleaving indices to be used with [`arrow_select::interleave::interleave`]. + /// fn make_interleaving_indices(&self, indices: &[usize]) -> Vec<(usize, usize)> { let mut lengths = vec![0_usize]; for batch in self.batches.iter() { @@ -69,7 +70,7 @@ impl RecordBatchBuffer { } /// Take rows by indices. - pub fn take(&self, indices: &[usize]) -> Result { + pub fn take_rows(&self, indices: &[usize]) -> Result { let arrays = self .batches .iter() @@ -109,3 +110,45 @@ impl FromIterator for RecordBatchBuffer { Self::new(batches) } } + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + + use arrow_array::{Float64Array, Int32Array, StringArray}; + use arrow_schema::{DataType, Field, Schema}; + + #[test] + fn test_take() { + let schema = Arc::new(Schema::new(vec![ + Field::new("i", DataType::Int32, false), + Field::new("f", DataType::Float64, false), + Field::new("s", DataType::Utf8, false), + ])); + + let batch_buffer: RecordBatchBuffer = (0..5) + .map(|v| { + let values = (v * 10..v * 10 + 10).collect::>(); + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_iter(values.iter().copied())), + Arc::new(Float64Array::from_iter(values.iter().map(|v| *v as f64))), + Arc::new(StringArray::from_iter_values( + values.iter().map(|v| format!("str_{}", v)), + )), + ], + ) + .unwrap() + }) + .collect(); + let batch = batch_buffer.take_rows(&[10, 14, 30, 49, 0, 22]).unwrap(); + assert_eq!(batch.num_rows(), 6); + assert_eq!( + batch.column_by_name("i").unwrap().as_ref(), + &Int32Array::from(vec![10, 14, 30, 49, 0, 22]) + ); + } +} From 09e91bd12110f99d1f8f5d5b5e05e90f00db5e8d Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 13 Apr 2023 18:02:03 -0700 Subject: [PATCH 06/16] handle null --- rust/src/arrow/record_batch.rs | 78 +------------------- rust/src/dataset/hash_joiner.rs | 126 +++++++++++++++++++++++++------- 2 files changed, 101 insertions(+), 103 deletions(-) diff --git a/rust/src/arrow/record_batch.rs b/rust/src/arrow/record_batch.rs index d47544eb85..64d9f301fd 100644 --- a/rust/src/arrow/record_batch.rs +++ b/rust/src/arrow/record_batch.rs @@ -15,9 +15,8 @@ //! 
Additional utility for [`RecordBatch`] //! -use arrow_array::{cast::as_struct_array, Array, RecordBatch, RecordBatchReader, StructArray}; +use arrow_array::{RecordBatch, RecordBatchReader}; use arrow_schema::{ArrowError, SchemaRef}; -use arrow_select::interleave::interleave; use crate::Result; @@ -49,39 +48,6 @@ impl RecordBatchBuffer { pub fn finish(&self) -> Result> { Ok(self.batches.clone()) } - - /// Make interleaving indices to be used with [`arrow_select::interleave::interleave`]. - /// - fn make_interleaving_indices(&self, indices: &[usize]) -> Vec<(usize, usize)> { - let mut lengths = vec![0_usize]; - for batch in self.batches.iter() { - lengths.push(lengths.last().unwrap() + batch.num_rows()); - } - - let mut idx = vec![]; - for i in indices { - let batch_id = match lengths.binary_search(&i) { - Ok(i) => i, - Err(i) => i - 1, - }; - idx.push((batch_id, i - lengths[batch_id])); - } - idx - } - - /// Take rows by indices. - pub fn take_rows(&self, indices: &[usize]) -> Result { - let arrays = self - .batches - .iter() - .map(|batch| StructArray::from(batch.clone())) - .collect::>(); - let refs = arrays.iter().map(|a| a as &dyn Array).collect::>(); - - let interleaving_indices = self.make_interleaving_indices(indices); - let array = interleave(&refs, &interleaving_indices)?; - Ok(as_struct_array(&array).into()) - } } impl RecordBatchReader for RecordBatchBuffer { @@ -110,45 +76,3 @@ impl FromIterator for RecordBatchBuffer { Self::new(batches) } } - -#[cfg(test)] -mod tests { - use super::*; - - use std::sync::Arc; - - use arrow_array::{Float64Array, Int32Array, StringArray}; - use arrow_schema::{DataType, Field, Schema}; - - #[test] - fn test_take() { - let schema = Arc::new(Schema::new(vec![ - Field::new("i", DataType::Int32, false), - Field::new("f", DataType::Float64, false), - Field::new("s", DataType::Utf8, false), - ])); - - let batch_buffer: RecordBatchBuffer = (0..5) - .map(|v| { - let values = (v * 10..v * 10 + 10).collect::>(); - RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from_iter(values.iter().copied())), - Arc::new(Float64Array::from_iter(values.iter().map(|v| *v as f64))), - Arc::new(StringArray::from_iter_values( - values.iter().map(|v| format!("str_{}", v)), - )), - ], - ) - .unwrap() - }) - .collect(); - let batch = batch_buffer.take_rows(&[10, 14, 30, 49, 0, 22]).unwrap(); - assert_eq!(batch.num_rows(), 6); - assert_eq!( - batch.column_by_name("i").unwrap().as_ref(), - &Int32Array::from(vec![10, 14, 30, 49, 0, 22]) - ); - } -} diff --git a/rust/src/dataset/hash_joiner.rs b/rust/src/dataset/hash_joiner.rs index 701a23c0a5..d0a90d7711 100644 --- a/rust/src/dataset/hash_joiner.rs +++ b/rust/src/dataset/hash_joiner.rs @@ -16,9 +16,13 @@ use std::collections::HashMap; -use arrow_array::{Array, RecordBatch, RecordBatchReader}; +use arrow_array::StructArray; +use arrow_array::{ + builder::UInt64Builder, cast::as_struct_array, Array, RecordBatch, RecordBatchReader, +}; +use arrow_select::{concat::concat_batches, take::take}; -use crate::arrow::{hash, RecordBatchBuffer}; +use crate::arrow::hash; use crate::{Error, Result}; /// `HashJoiner` does hash join on two datasets. @@ -26,7 +30,7 @@ pub(super) struct HashJoiner { /// Hash value to row index map. index_map: HashMap, - data: RecordBatchBuffer, + batch: RecordBatch, on_column: String, } @@ -37,57 +41,127 @@ impl HashJoiner { // Check column exist reader.schema().field_with_name(on)?; + // Hold all data in memory for simple implementation. Can do external sort later. 
+ let batches = reader.collect::, _>>()?; + if batches.is_empty() { + return Err(Error::IO("HashJoiner: No data".to_string())); + }; + let batch = concat_batches(&batches[0].schema(), &batches)?; + Ok(Self { index_map: HashMap::new(), - - // Hold all data in memory for simple implementation. Can do external sort later. - data: reader.collect::>()?, - + batch, on_column: on.to_string(), }) } /// Build the hash index. pub(super) fn build(&mut self) -> Result<()> { - let mut start_idx = 0; + let key_column = self + .batch + .column_by_name(&self.on_column) + .ok_or_else(|| Error::IO(format!("HashJoiner: Column {} not found", self.on_column)))?; - for batch in &self.data.batches { - let key_column = batch.column_by_name(&self.on_column).ok_or_else(|| { - Error::IO(format!("HashJoiner: Column {} not found", self.on_column)) - })?; - - let hashes = hash(key_column.as_ref())?; - for (i, hash_value) in hashes.iter().enumerate() { - let idx = start_idx + i; - let Some(key) = hash_value else { + let hashes = hash(key_column.as_ref())?; + for (i, hash_value) in hashes.iter().enumerate() { + let Some(key) = hash_value else { continue; }; - if self.index_map.contains_key(&key) { - return Err(Error::IO(format!("HashJoiner: Duplicate key {}", key))); - } - // TODO: use [`HashMap::try_insert`] when it's stable. - self.index_map.insert(key, idx); + if self.index_map.contains_key(&key) { + return Err(Error::IO(format!("HashJoiner: Duplicate key {}", key))); } - start_idx += batch.num_rows(); + // TODO: use [`HashMap::try_insert`] when it's stable. + self.index_map.insert(key, i); } + Ok(()) } /// Collecting the data using the index column from left table. pub(super) fn collect(&self, index_column: &dyn Array) -> Result { let hashes = hash(index_column)?; - let mut indices: Vec = Vec::with_capacity(index_column.len()); + let mut builder = UInt64Builder::with_capacity(index_column.len()); for hash_value in hashes.iter() { let Some(key) = hash_value else { + builder.append_null(); continue; }; if let Some(idx) = self.index_map.get(&key) { - indices.push(*idx); + builder.append_value(*idx as u64); + } else { + builder.append_null(); } } + let indices = builder.finish(); + + let struct_arr = StructArray::from(self.batch.clone()); + let results = take(&struct_arr, &indices, None)?; + Ok(as_struct_array(&results).into()) + } +} - todo!() +#[cfg(test)] +mod tests { + + use super::*; + + use arrow_array::{Int32Array, StringArray, UInt32Array}; + use arrow_schema::{DataType, Field, Schema}; + use std::sync::Arc; + + use crate::arrow::RecordBatchBuffer; + + #[test] + fn test_joiner_collect() { + let schema = Arc::new(Schema::new(vec![ + Field::new("i", DataType::Int32, false), + Field::new("s", DataType::Utf8, false), + ])); + + let mut batch_buffer: RecordBatchBuffer = (0..5) + .map(|v| { + let values = (v * 10..v * 10 + 10).collect::>(); + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_iter(values.iter().copied())), + Arc::new(StringArray::from_iter_values( + values.iter().map(|v| format!("str_{}", v)), + )), + ], + ) + .unwrap() + }) + .collect(); + let mut joiner = HashJoiner::try_new(&mut batch_buffer, "i").unwrap(); + joiner.build().unwrap(); + + let indices = UInt32Array::from_iter(&[ + Some(15), + None, + Some(10), + Some(0), + None, + None, + Some(22), + Some(11111), // not found + ]); + let results = joiner.collect(&indices).unwrap(); + + assert_eq!( + results.column_by_name("s").unwrap().as_ref(), + &StringArray::from(vec![ + Some("str_15"), + None, + Some("str_10"), + 
Some("str_0"), + None, + None, + Some("str_22"), + None // 11111 not found + ]) + ); } } From 911d0aa955130f822f254a730d286906d0c6a57b Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 14 Apr 2023 09:47:00 -0700 Subject: [PATCH 07/16] use fragment --- rust/src/dataset.rs | 4 ++++ rust/src/dataset/hash_joiner.rs | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/rust/src/dataset.rs b/rust/src/dataset.rs index 6654c96c38..029f69ec68 100644 --- a/rust/src/dataset.rs +++ b/rust/src/dataset.rs @@ -529,6 +529,10 @@ impl Dataset { let mut joiner = HashJoiner::try_new(stream, right_on)?; joiner.build()?; + // Write new data file to each fragment. + let mut new_fragments: Vec = vec![]; + for fragment in self.get_fragments().iter() {} + todo!() } diff --git a/rust/src/dataset/hash_joiner.rs b/rust/src/dataset/hash_joiner.rs index d0a90d7711..4791a58232 100644 --- a/rust/src/dataset/hash_joiner.rs +++ b/rust/src/dataset/hash_joiner.rs @@ -107,9 +107,10 @@ mod tests { use super::*; + use std::sync::Arc; + use arrow_array::{Int32Array, StringArray, UInt32Array}; use arrow_schema::{DataType, Field, Schema}; - use std::sync::Arc; use crate::arrow::RecordBatchBuffer; @@ -160,7 +161,7 @@ mod tests { None, None, Some("str_22"), - None // 11111 not found + None // 11111 not found ]) ); } From 0c9953953afd70574de09da1fd3b74eea0ec2985 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 14 Apr 2023 10:39:10 -0700 Subject: [PATCH 08/16] merge on fragment level --- rust/src/dataset.rs | 12 ++++++------ rust/src/dataset/fragment.rs | 7 +++++++ rust/src/dataset/hash_joiner.rs | 6 +++--- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/rust/src/dataset.rs b/rust/src/dataset.rs index 029f69ec68..399b1af971 100644 --- a/rust/src/dataset.rs +++ b/rust/src/dataset.rs @@ -501,27 +501,27 @@ impl Dataset { ) -> Result { // Sanity check. if self.schema().field(left_on).is_none() { - return Err(Error::IO(format!( + return Err(Error::IO { message: format!( "Column {} does not exist in the left side dataset", left_on - ))); + )}); }; let right_schema = stream.schema(); if right_schema.field_with_name(right_on).is_err() { - return Err(Error::IO(format!( + return Err(Error::IO { message: format!( "Column {} does not exist in the right side dataset", right_on - ))); + )}); }; for field in right_schema.fields() { if field.name() == right_on { continue; } if self.schema().field(field.name()).is_some() { - return Err(Error::IO(format!( + return Err(Error::IO { message: format!( "Column {} exists in both sides of the dataset", field.name() - ))); + )}); } } diff --git a/rust/src/dataset/fragment.rs b/rust/src/dataset/fragment.rs index adba67f920..9e989ae068 100644 --- a/rust/src/dataset/fragment.rs +++ b/rust/src/dataset/fragment.rs @@ -27,6 +27,7 @@ use crate::format::Fragment; use crate::io::{FileReader, FileWriter, ObjectStore, ReadBatchParams}; use crate::{Error, Result}; +use super::hash_joiner::HashJoiner; use super::scanner::Scanner; use super::updater::Updater; use super::WriteParams; @@ -199,6 +200,11 @@ impl FileFragment { Ok(Updater::new(self.clone(), reader)) } + + /// Merge columns from joiner. 
+ pub async fn merge(&self, join_column: &str, joiner: &HashJoiner) -> Result { + todo!() + } } impl From for Fragment { @@ -315,6 +321,7 @@ impl FragmentReader { merge_batches(&batches) } + } #[cfg(test)] diff --git a/rust/src/dataset/hash_joiner.rs b/rust/src/dataset/hash_joiner.rs index 4791a58232..407e7a6e67 100644 --- a/rust/src/dataset/hash_joiner.rs +++ b/rust/src/dataset/hash_joiner.rs @@ -44,7 +44,7 @@ impl HashJoiner { // Hold all data in memory for simple implementation. Can do external sort later. let batches = reader.collect::, _>>()?; if batches.is_empty() { - return Err(Error::IO("HashJoiner: No data".to_string())); + return Err(Error::IO { message: "HashJoiner: No data".to_string()}); }; let batch = concat_batches(&batches[0].schema(), &batches)?; @@ -60,7 +60,7 @@ impl HashJoiner { let key_column = self .batch .column_by_name(&self.on_column) - .ok_or_else(|| Error::IO(format!("HashJoiner: Column {} not found", self.on_column)))?; + .ok_or_else(|| Error::IO { message: format!("HashJoiner: Column {} not found", self.on_column) })?; let hashes = hash(key_column.as_ref())?; for (i, hash_value) in hashes.iter().enumerate() { @@ -69,7 +69,7 @@ impl HashJoiner { }; if self.index_map.contains_key(&key) { - return Err(Error::IO(format!("HashJoiner: Duplicate key {}", key))); + return Err(Error::IO { message:format!("HashJoiner: Duplicate key {}", key) } ); } // TODO: use [`HashMap::try_insert`] when it's stable. self.index_map.insert(key, i); From b1ad06ef447e5b0edd14f0722fbb64cdd4289510 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Sat, 15 Apr 2023 10:21:36 -0700 Subject: [PATCH 09/16] change interface --- rust/src/dataset/fragment.rs | 4 ++-- rust/src/dataset/hash_joiner.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/rust/src/dataset/fragment.rs b/rust/src/dataset/fragment.rs index 9e989ae068..1332033747 100644 --- a/rust/src/dataset/fragment.rs +++ b/rust/src/dataset/fragment.rs @@ -34,7 +34,7 @@ use super::WriteParams; /// A Fragment of a Lance [`Dataset`]. /// -/// The interface is similar to `pyarrow.dataset.Fragment`. +/// The interface is modeled after `pyarrow.dataset.Fragment`. #[derive(Debug, Clone)] pub struct FileFragment { dataset: Arc, @@ -202,7 +202,7 @@ impl FileFragment { } /// Merge columns from joiner. - pub async fn merge(&self, join_column: &str, joiner: &HashJoiner) -> Result { + pub(crate) async fn merge(&self, join_column: &str, joiner: &HashJoiner) -> Result { todo!() } } diff --git a/rust/src/dataset/hash_joiner.rs b/rust/src/dataset/hash_joiner.rs index 407e7a6e67..81f146f450 100644 --- a/rust/src/dataset/hash_joiner.rs +++ b/rust/src/dataset/hash_joiner.rs @@ -26,7 +26,7 @@ use crate::arrow::hash; use crate::{Error, Result}; /// `HashJoiner` does hash join on two datasets. -pub(super) struct HashJoiner { +pub(crate) struct HashJoiner { /// Hash value to row index map. index_map: HashMap, From d79fe49a3ea6417e27c8857128e185fbcf8bdef4 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 4 May 2023 11:43:06 -0700 Subject: [PATCH 10/16] cargo fmat --- rust/src/dataset.rs | 2 +- rust/src/dataset/fragment.rs | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/rust/src/dataset.rs b/rust/src/dataset.rs index 399b1af971..c64db5c455 100644 --- a/rust/src/dataset.rs +++ b/rust/src/dataset.rs @@ -481,7 +481,7 @@ impl Dataset { session: Arc::new(Session::default()), }) } - + /// Merge this dataset with another arrow Table / Dataset, and returns a new version of dataset. 
/// /// Parameters: diff --git a/rust/src/dataset/fragment.rs b/rust/src/dataset/fragment.rs index 1332033747..f7b61858c6 100644 --- a/rust/src/dataset/fragment.rs +++ b/rust/src/dataset/fragment.rs @@ -20,6 +20,9 @@ use std::sync::Arc; use arrow_array::{RecordBatch, RecordBatchReader}; use uuid::Uuid; +use super::hash_joiner::HashJoiner; +use super::scanner::Scanner; +use super::updater::Updater; use crate::arrow::*; use crate::dataset::{Dataset, DATA_DIR}; use crate::datatypes::Schema; @@ -27,9 +30,6 @@ use crate::format::Fragment; use crate::io::{FileReader, FileWriter, ObjectStore, ReadBatchParams}; use crate::{Error, Result}; -use super::hash_joiner::HashJoiner; -use super::scanner::Scanner; -use super::updater::Updater; use super::WriteParams; /// A Fragment of a Lance [`Dataset`]. @@ -321,7 +321,6 @@ impl FragmentReader { merge_batches(&batches) } - } #[cfg(test)] From a6c10dba70eaed4b200a1664bc9d5e181dd7922b Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 2 Jun 2023 16:49:44 -0700 Subject: [PATCH 11/16] draft remainder of implementation --- rust/Cargo.toml | 3 + rust/src/dataset.rs | 65 ++++++---- rust/src/dataset/fragment.rs | 37 +++++- rust/src/dataset/hash_joiner.rs | 215 ++++++++++++++++++++++---------- rust/src/format/manifest.rs | 11 ++ 5 files changed, 242 insertions(+), 89 deletions(-) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index f795141065..5f2a7f7201 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -34,6 +34,7 @@ arrow-cast = "37.0.0" arrow-data = "37.0" arrow-ipc = { version = "37.0", features = ["zstd"] } arrow-ord = "37.0" +arrow-row = "37.0" arrow-schema = "37.0" arrow-select = "37.0" async-recursion = "1.0" @@ -41,6 +42,8 @@ async-trait = "0.1.60" byteorder = "1.4.3" chrono = "0.4.23" clap = { version = "4.1.1", features = ["derive"], optional = true } +# This is already used by datafusion +dashmap = "5" object_store = { version = "0.5.6", features = ["aws_profile", "gcp"] } reqwest = { version = "0.11.16" } aws-config = "0.54" diff --git a/rust/src/dataset.rs b/rust/src/dataset.rs index c64db5c455..e3f6d8989b 100644 --- a/rust/src/dataset.rs +++ b/rust/src/dataset.rs @@ -493,47 +493,70 @@ impl Dataset { /// Returns: a new version of dataset. /// /// It performs a left-join on the two datasets. - pub fn merge( - &self, + pub async fn merge( + &mut self, stream: &mut dyn RecordBatchReader, left_on: &str, right_on: &str, - ) -> Result { + ) -> Result<()> { // Sanity check. 
if self.schema().field(left_on).is_none() { - return Err(Error::IO { message: format!( - "Column {} does not exist in the left side dataset", - left_on - )}); + return Err(Error::IO { + message: format!("Column {} does not exist in the left side dataset", left_on), + }); }; let right_schema = stream.schema(); if right_schema.field_with_name(right_on).is_err() { - return Err(Error::IO { message: format!( - "Column {} does not exist in the right side dataset", - right_on - )}); + return Err(Error::IO { + message: format!( + "Column {} does not exist in the right side dataset", + right_on + ), + }); }; for field in right_schema.fields() { if field.name() == right_on { continue; } if self.schema().field(field.name()).is_some() { - return Err(Error::IO { message: format!( - "Column {} exists in both sides of the dataset", - field.name() - )}); + return Err(Error::IO { + message: format!( + "Column {} exists in both sides of the dataset", + field.name() + ), + }); } } // Hash join - let mut joiner = HashJoiner::try_new(stream, right_on)?; - joiner.build()?; + let joiner = Arc::new(HashJoiner::try_new(stream, right_on).await?); + + // Write new data file to each fragment. Parallism is done over columns, + // so no parallelism done at this level. + let updated_fragments: Vec = stream::iter(self.get_fragments()) + .then(|f| { + let joiner = joiner.clone(); + async move { f.merge(left_on, &joiner).await.map(|f| f.metadata) } + }) + .try_collect::>() + .await?; + + // Inherit the index, since we are just adding columns. + let indices = self.load_indices().await?; + + let mut manifest = Manifest::new(&self.schema(), Arc::new(updated_fragments)); + manifest.version = self + .latest_manifest() + .await + .map(|m| m.version + 1) + .unwrap_or(1); + manifest.set_timestamp(None); + + write_manifest_file(&self.object_store, &mut manifest, Some(indices)).await?; - // Write new data file to each fragment. - let mut new_fragments: Vec = vec![]; - for fragment in self.get_fragments().iter() {} + self.manifest = Arc::new(manifest); - todo!() + Ok(()) } /// Create a Scanner to scan the dataset. diff --git a/rust/src/dataset/fragment.rs b/rust/src/dataset/fragment.rs index f7b61858c6..e2d01fcc4c 100644 --- a/rust/src/dataset/fragment.rs +++ b/rust/src/dataset/fragment.rs @@ -18,6 +18,7 @@ use std::ops::Range; use std::sync::Arc; use arrow_array::{RecordBatch, RecordBatchReader}; +use futures::{StreamExt, TryStreamExt}; use uuid::Uuid; use super::hash_joiner::HashJoiner; @@ -202,8 +203,40 @@ impl FileFragment { } /// Merge columns from joiner. - pub(crate) async fn merge(&self, join_column: &str, joiner: &HashJoiner) -> Result { - todo!() + pub(crate) async fn merge(mut self, join_column: &str, joiner: &HashJoiner) -> Result { + let mut scanner = self.scan(); + scanner.project(&[join_column])?; + + let mut batch_stream = scanner + .try_into_stream() + .await? + .and_then(|batch| joiner.collect(batch.column(0).clone())) + .boxed(); + + let filename = format!("{}.lance", Uuid::new_v4()); + let full_path = self + .dataset + .object_store + .base_path() + .child(DATA_DIR) + .child(filename.clone()); + let mut writer = FileWriter::try_new( + &self.dataset.object_store, + &full_path, + self.dataset.schema().clone(), + ) + .await?; + + while let Some(batch) = batch_stream.try_next().await? 
{ + writer.write(&[batch]).await?; + } + + writer.finish().await?; + + let schema = crate::datatypes::Schema::try_from(joiner.out_schema().as_ref())?; + self.metadata.add_file(full_path.as_ref(), &schema); + + Ok(self) } } diff --git a/rust/src/dataset/hash_joiner.rs b/rust/src/dataset/hash_joiner.rs index 81f146f450..bc9194a845 100644 --- a/rust/src/dataset/hash_joiner.rs +++ b/rust/src/dataset/hash_joiner.rs @@ -14,91 +14,175 @@ //! HashJoiner -use std::collections::HashMap; +use std::sync::Arc; -use arrow_array::StructArray; -use arrow_array::{ - builder::UInt64Builder, cast::as_struct_array, Array, RecordBatch, RecordBatchReader, -}; -use arrow_select::{concat::concat_batches, take::take}; +use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchReader}; +use arrow_array::ArrayRef; +use arrow_row::{OwnedRow, RowConverter, Rows, SortField}; +use arrow_schema::SchemaRef; +use arrow_select::interleave::interleave; +use dashmap::{DashMap, ReadOnlyView}; +use futures::{StreamExt, TryStreamExt}; +use tokio::task; -use crate::arrow::hash; use crate::{Error, Result}; /// `HashJoiner` does hash join on two datasets. pub(crate) struct HashJoiner { - /// Hash value to row index map. - index_map: HashMap, + index_map: ReadOnlyView, - batch: RecordBatch, + batches: Vec, - on_column: String, + out_schema: SchemaRef, +} + +fn column_to_rows(column: ArrayRef) -> Result { + let mut row_converter = RowConverter::new(vec![SortField::new(column.data_type().clone())])?; + let rows = row_converter.convert_columns(&[column])?; + Ok(rows) } impl HashJoiner { - /// Create a new `HashJoiner`. - pub fn try_new(reader: &mut dyn RecordBatchReader, on: &str) -> Result { + /// Create a new `HashJoiner`, building the hash index. + /// + /// Will run in parallel over batches using all available cores. + pub async fn try_new(reader: &mut dyn RecordBatchReader, on: &str) -> Result { // Check column exist reader.schema().field_with_name(on)?; // Hold all data in memory for simple implementation. Can do external sort later. 
let batches = reader.collect::, _>>()?; if batches.is_empty() { - return Err(Error::IO { message: "HashJoiner: No data".to_string()}); + return Err(Error::IO { + message: "HashJoiner: No data".to_string(), + }); }; - let batch = concat_batches(&batches[0].schema(), &batches)?; + + let map = DashMap::new(); + + let schema = reader.schema(); + + let keep_indices: Vec = schema + .fields() + .iter() + .enumerate() + .filter_map(|(i, field)| if field.name() == on { None } else { Some(i) }) + .collect(); + let out_schema: Arc = Arc::new(schema.project(&keep_indices)?); + let right_batches = batches + .iter() + .map(|batch| { + let mut columns = Vec::with_capacity(keep_indices.len()); + for i in &keep_indices { + columns.push(batch.column(*i).clone()); + } + RecordBatch::try_new(out_schema.clone(), columns).unwrap() + }) + .collect::>(); + + let map_ref = ↦ + + futures::stream::iter(batches.iter().enumerate().map(Ok::<_, Error>)) + // Not sure if this can actually run in parallel though + // TODO: use spawn_blocking instead + .try_for_each_concurrent(num_cpus::get(), |(batch_i, batch)| async move { + let column = batch[on].clone(); + let map_ref = map_ref.clone(); + let task_result = task::spawn(async move { + let rows = column_to_rows(column)?; + for (row_i, row) in rows.iter().enumerate() { + map_ref.insert(row.owned(), (batch_i, row_i)); + } + Ok(()) + }) + .await; + match task_result { + Ok(Ok(_)) => Ok(()), + Ok(Err(err)) => Err(err), + Err(err) => Err(Error::IO { + message: format!("HashJoiner: {}", err), + }), + } + }) + .await?; Ok(Self { - index_map: HashMap::new(), - batch, - on_column: on.to_string(), + index_map: map.into_read_only(), + batches: right_batches, + out_schema, }) } - /// Build the hash index. - pub(super) fn build(&mut self) -> Result<()> { - let key_column = self - .batch - .column_by_name(&self.on_column) - .ok_or_else(|| Error::IO { message: format!("HashJoiner: Column {} not found", self.on_column) })?; - - let hashes = hash(key_column.as_ref())?; - for (i, hash_value) in hashes.iter().enumerate() { - let Some(key) = hash_value else { - continue; - }; - - if self.index_map.contains_key(&key) { - return Err(Error::IO { message:format!("HashJoiner: Duplicate key {}", key) } ); - } - // TODO: use [`HashMap::try_insert`] when it's stable. - self.index_map.insert(key, i); - } - - Ok(()) + pub fn out_schema(&self) -> &SchemaRef { + &self.out_schema } /// Collecting the data using the index column from left table. - pub(super) fn collect(&self, index_column: &dyn Array) -> Result { - let hashes = hash(index_column)?; - let mut builder = UInt64Builder::with_capacity(index_column.len()); - for hash_value in hashes.iter() { - let Some(key) = hash_value else { - builder.append_null(); - continue; - }; - - if let Some(idx) = self.index_map.get(&key) { - builder.append_value(*idx as u64); - } else { - builder.append_null(); - } - } - let indices = builder.finish(); - - let struct_arr = StructArray::from(self.batch.clone()); - let results = take(&struct_arr, &indices, None)?; - Ok(as_struct_array(&results).into()) + /// + /// Will run in parallel over columns using all available cores. + pub(super) async fn collect(&self, index_column: ArrayRef) -> Result { + // Index to use for null values + let null_index = self.batches.len(); + + let index_type = index_column.data_type().clone(); + + // collect indices + let indices = column_to_rows(index_column)? 
+ .into_iter() + .map(|row| { + self.index_map + .get(&row.owned()) + .map(|(batch_i, row_i)| (*batch_i, *row_i)) + .unwrap_or((null_index, 0)) + }) + .collect::>(); + let indices = Arc::new(indices); + + // Use interleave to get the data + // https://docs.rs/arrow/40.0.0/arrow/compute/kernels/interleave/fn.interleave.html + // To handle nulls, we'll add an extra null array at the end + let null_array = Arc::new(new_null_array(&index_type, 1)); + + // Do this in parallel over the columns + let columns = futures::stream::iter(0..self.batches[0].num_columns()) + .map(|column_i| { + // TODO: we need to drop the join_on column + + let mut arrays = Vec::with_capacity(self.batches.len() + 1); + for batch in &self.batches { + arrays.push(batch.column(column_i).clone()); + } + arrays.push(null_array.clone()); + + let indices = indices.clone(); + + async move { + let task_result = task::spawn(async move { + let array_refs = arrays.iter().map(|x| x.as_ref()).collect::>(); + interleave(array_refs.as_ref(), indices.as_ref()) + .map_err(|err| Error::IO { + message: format!("HashJoiner: {}", err), + }) + .map(|x| x.clone()) + }) + .await; + match task_result { + Ok(Ok(array)) => Ok(array), + Ok(Err(err)) => Err(err), + Err(err) => Err(Error::IO { + message: format!("HashJoiner: {}", err), + }), + } + } + }) + .buffered(num_cpus::get()) + .try_collect::>() + .await?; + + Ok(RecordBatch::try_new( + self.batches[0].schema().clone(), + columns, + )?) } } @@ -114,8 +198,8 @@ mod tests { use crate::arrow::RecordBatchBuffer; - #[test] - fn test_joiner_collect() { + #[tokio::test] + async fn test_joiner_collect() { let schema = Arc::new(Schema::new(vec![ Field::new("i", DataType::Int32, false), Field::new("s", DataType::Utf8, false), @@ -136,10 +220,9 @@ mod tests { .unwrap() }) .collect(); - let mut joiner = HashJoiner::try_new(&mut batch_buffer, "i").unwrap(); - joiner.build().unwrap(); + let joiner = HashJoiner::try_new(&mut batch_buffer, "i").await.unwrap(); - let indices = UInt32Array::from_iter(&[ + let indices = Arc::new(UInt32Array::from_iter(&[ Some(15), None, Some(10), @@ -148,8 +231,8 @@ mod tests { None, Some(22), Some(11111), // not found - ]); - let results = joiner.collect(&indices).unwrap(); + ])); + let results = joiner.collect(indices).await.unwrap(); assert_eq!( results.column_by_name("s").unwrap().as_ref(), diff --git a/rust/src/format/manifest.rs b/rust/src/format/manifest.rs index 2924dade02..3a8ac07fc3 100644 --- a/rust/src/format/manifest.rs +++ b/rust/src/format/manifest.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use chrono::prelude::*; use prost_types::Timestamp; +use std::time::SystemTime; use super::Fragment; use crate::datatypes::Schema; @@ -79,6 +80,16 @@ impl Manifest { ) } + /// Set the `timestamp_nanos` value from a Utc DateTime + pub fn set_timestamp(&mut self, timestamp: Option) { + let timestamp = timestamp.unwrap_or_else(SystemTime::now); + let nanos = timestamp + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_nanos(); + self.timestamp_nanos = nanos as u128; + } + /// Return the max fragment id. /// Note this does not support recycling of fragment ids. 
pub fn max_fragment_id(&self) -> Option { From 003f22fd42b5dc68187eb261656f98cd98b54cc7 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 5 Jun 2023 14:25:13 -0700 Subject: [PATCH 12/16] get rust working --- rust/src/dataset.rs | 139 +++++++++++++++++++++++++++----- rust/src/dataset/fragment.rs | 21 ++--- rust/src/dataset/hash_joiner.rs | 119 ++++++++++++++++++--------- rust/src/error.rs | 9 +++ 4 files changed, 225 insertions(+), 63 deletions(-) diff --git a/rust/src/dataset.rs b/rust/src/dataset.rs index e3f6d8989b..c2e64ba56b 100644 --- a/rust/src/dataset.rs +++ b/rust/src/dataset.rs @@ -495,48 +495,51 @@ impl Dataset { /// It performs a left-join on the two datasets. pub async fn merge( &mut self, - stream: &mut dyn RecordBatchReader, + stream: &mut Box, left_on: &str, right_on: &str, ) -> Result<()> { // Sanity check. if self.schema().field(left_on).is_none() { - return Err(Error::IO { - message: format!("Column {} does not exist in the left side dataset", left_on), - }); + return Err(Error::invalid_input(format!( + "Column {} does not exist in the left side dataset", + left_on + ))); }; let right_schema = stream.schema(); if right_schema.field_with_name(right_on).is_err() { - return Err(Error::IO { - message: format!( - "Column {} does not exist in the right side dataset", - right_on - ), - }); + return Err(Error::invalid_input(format!( + "Column {} does not exist in the right side dataset", + right_on + ))); }; for field in right_schema.fields() { if field.name() == right_on { continue; } if self.schema().field(field.name()).is_some() { - return Err(Error::IO { - message: format!( - "Column {} exists in both sides of the dataset", - field.name() - ), - }); + return Err(Error::invalid_input(format!( + "Column {} exists in both sides of the dataset", + field.name() + ))); } } // Hash join let joiner = Arc::new(HashJoiner::try_new(stream, right_on).await?); + let new_schema: Schema = self.schema().merge(joiner.out_schema().as_ref())?; - // Write new data file to each fragment. Parallism is done over columns, + // Write new data file to each fragment. Parallelism is done over columns, // so no parallelism done at this level. let updated_fragments: Vec = stream::iter(self.get_fragments()) .then(|f| { let joiner = joiner.clone(); - async move { f.merge(left_on, &joiner).await.map(|f| f.metadata) } + let full_schema = new_schema.clone(); + async move { + f.merge(left_on, &joiner, &full_schema) + .await + .map(|f| f.metadata) + } }) .try_collect::>() .await?; @@ -551,6 +554,7 @@ impl Dataset { .map(|m| m.version + 1) .unwrap_or(1); manifest.set_timestamp(None); + manifest.schema = new_schema; write_manifest_file(&self.object_store, &mut manifest, Some(indices)).await?; @@ -826,6 +830,7 @@ mod tests { use crate::{datatypes::Schema, utils::testing::generate_random_array}; use crate::dataset::WriteMode::Overwrite; + use arrow_array::Float32Array; use arrow_array::{ cast::{as_string_array, as_struct_array}, DictionaryArray, FixedSizeListArray, Int32Array, RecordBatch, StringArray, UInt16Array, @@ -1349,4 +1354,102 @@ mod tests { let result = Dataset::open(".").await; assert!(matches!(result.unwrap_err(), Error::DatasetNotFound { .. 
})); } + + #[tokio::test] + async fn test_merge() { + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("i", DataType::Int32, false), + Field::new("x", DataType::Float32, false), + ])); + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(Float32Array::from(vec![1.0, 2.0])), + ], + ) + .unwrap(); + let batch2 = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(vec![3, 2])), + Arc::new(Float32Array::from(vec![3.0, 4.0])), + ], + ) + .unwrap(); + + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + + let mut write_params = WriteParams::default(); + write_params.mode = WriteMode::Append; + + let mut batches: Box = + Box::new(RecordBatchBuffer::from_iter(vec![batch1])); + Dataset::write(&mut batches, test_uri, Some(write_params)) + .await + .unwrap(); + + let mut batches: Box = + Box::new(RecordBatchBuffer::from_iter(vec![batch2])); + Dataset::write(&mut batches, test_uri, Some(write_params)) + .await + .unwrap(); + + let dataset = Dataset::open(test_uri).await.unwrap(); + assert_eq!(dataset.fragments().len(), 2); + + let right_schema = Arc::new(ArrowSchema::new(vec![ + Field::new("i2", DataType::Int32, false), + Field::new("y", DataType::Utf8, true), + ])); + let right_batch1 = RecordBatch::try_new( + right_schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["a", "b"])), + ], + ) + .unwrap(); + + let mut batches: Box = + Box::new(RecordBatchBuffer::from_iter(vec![right_batch1])); + let mut dataset = Dataset::open(test_uri).await.unwrap(); + dataset.merge(&mut batches, "i", "i2").await.unwrap(); + + assert_eq!(dataset.version().version, 3); + assert_eq!(dataset.fragments().len(), 2); + assert_eq!(dataset.fragments()[0].files.len(), 2); + assert_eq!(dataset.fragments()[1].files.len(), 2); + + let actual_batches = dataset + .scan() + .try_into_stream() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + let actual = concat_batches(&actual_batches[0].schema(), &actual_batches).unwrap(); + let expected = RecordBatch::try_new( + Arc::new(ArrowSchema::new(vec![ + Field::new("i", DataType::Int32, false), + Field::new("x", DataType::Float32, false), + Field::new("y", DataType::Utf8, true), + ])), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 2])), + Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0])), + Arc::new(StringArray::from(vec![ + Some("a"), + Some("b"), + None, + Some("b"), + ])), + ], + ) + .unwrap(); + + assert_eq!(actual, expected); + } } diff --git a/rust/src/dataset/fragment.rs b/rust/src/dataset/fragment.rs index e2d01fcc4c..62c2f64845 100644 --- a/rust/src/dataset/fragment.rs +++ b/rust/src/dataset/fragment.rs @@ -203,7 +203,12 @@ impl FileFragment { } /// Merge columns from joiner. 
- pub(crate) async fn merge(mut self, join_column: &str, joiner: &HashJoiner) -> Result { + pub(crate) async fn merge( + mut self, + join_column: &str, + joiner: &HashJoiner, + full_schema: &Schema, + ) -> Result { let mut scanner = self.scan(); scanner.project(&[join_column])?; @@ -213,6 +218,8 @@ impl FileFragment { .and_then(|batch| joiner.collect(batch.column(0).clone())) .boxed(); + let file_schema = full_schema.project_by_schema(joiner.out_schema().as_ref())?; + let filename = format!("{}.lance", Uuid::new_v4()); let full_path = self .dataset @@ -220,12 +227,9 @@ impl FileFragment { .base_path() .child(DATA_DIR) .child(filename.clone()); - let mut writer = FileWriter::try_new( - &self.dataset.object_store, - &full_path, - self.dataset.schema().clone(), - ) - .await?; + let mut writer = + FileWriter::try_new(&self.dataset.object_store, &full_path, file_schema.clone()) + .await?; while let Some(batch) = batch_stream.try_next().await? { writer.write(&[batch]).await?; @@ -233,8 +237,7 @@ impl FileFragment { writer.finish().await?; - let schema = crate::datatypes::Schema::try_from(joiner.out_schema().as_ref())?; - self.metadata.add_file(full_path.as_ref(), &schema); + self.metadata.add_file(&filename, &file_schema); Ok(self) } diff --git a/rust/src/dataset/hash_joiner.rs b/rust/src/dataset/hash_joiner.rs index bc9194a845..d04ea42874 100644 --- a/rust/src/dataset/hash_joiner.rs +++ b/rust/src/dataset/hash_joiner.rs @@ -16,10 +16,10 @@ use std::sync::Arc; -use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchReader}; use arrow_array::ArrayRef; +use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchReader}; use arrow_row::{OwnedRow, RowConverter, Rows, SortField}; -use arrow_schema::SchemaRef; +use arrow_schema::{DataType as ArrowDataType, SchemaRef}; use arrow_select::interleave::interleave; use dashmap::{DashMap, ReadOnlyView}; use futures::{StreamExt, TryStreamExt}; @@ -31,6 +31,8 @@ use crate::{Error, Result}; pub(crate) struct HashJoiner { index_map: ReadOnlyView, + index_type: ArrowDataType, + batches: Vec, out_schema: SchemaRef, @@ -46,7 +48,7 @@ impl HashJoiner { /// Create a new `HashJoiner`, building the hash index. /// /// Will run in parallel over batches using all available cores. 
- pub async fn try_new(reader: &mut dyn RecordBatchReader, on: &str) -> Result { + pub async fn try_new(reader: &mut Box, on: &str) -> Result { // Check column exist reader.schema().field_with_name(on)?; @@ -80,34 +82,44 @@ impl HashJoiner { }) .collect::>(); - let map_ref = ↦ + let map = Arc::new(map); futures::stream::iter(batches.iter().enumerate().map(Ok::<_, Error>)) // Not sure if this can actually run in parallel though - // TODO: use spawn_blocking instead - .try_for_each_concurrent(num_cpus::get(), |(batch_i, batch)| async move { - let column = batch[on].clone(); - let map_ref = map_ref.clone(); - let task_result = task::spawn(async move { - let rows = column_to_rows(column)?; - for (row_i, row) in rows.iter().enumerate() { - map_ref.insert(row.owned(), (batch_i, row_i)); + .try_for_each_concurrent(num_cpus::get(), |(batch_i, batch)| { + let map = map.clone(); + async move { + let column = batch[on].clone(); + let task_result = task::spawn_blocking(move || { + let rows = column_to_rows(column)?; + for (row_i, row) in rows.iter().enumerate() { + map.insert(row.owned(), (batch_i, row_i)); + } + Ok(()) + }) + .await; + match task_result { + Ok(Ok(_)) => Ok(()), + Ok(Err(err)) => Err(err), + Err(err) => Err(Error::IO { + message: format!("HashJoiner: {}", err), + }), } - Ok(()) - }) - .await; - match task_result { - Ok(Ok(_)) => Ok(()), - Ok(Err(err)) => Err(err), - Err(err) => Err(Error::IO { - message: format!("HashJoiner: {}", err), - }), } }) .await?; + let map = Arc::try_unwrap(map) + .expect("HashJoiner: No remaining tasks should still be referencing map."); + let index_type = batches[0] + .schema() + .field_with_name(on) + .unwrap() + .data_type() + .clone(); Ok(Self { index_map: map.into_read_only(), + index_type, batches: right_batches, out_schema, }) @@ -121,11 +133,17 @@ impl HashJoiner { /// /// Will run in parallel over columns using all available cores. pub(super) async fn collect(&self, index_column: ArrayRef) -> Result { + if index_column.data_type() != &self.index_type { + return Err(Error::invalid_input(format!( + "Index column type mismatch: expected {}, got {}", + self.index_type, + index_column.data_type() + ))); + } + // Index to use for null values let null_index = self.batches.len(); - let index_type = index_column.data_type().clone(); - // collect indices let indices = column_to_rows(index_column)? 
.into_iter() @@ -138,26 +156,22 @@ impl HashJoiner { .collect::>(); let indices = Arc::new(indices); - // Use interleave to get the data - // https://docs.rs/arrow/40.0.0/arrow/compute/kernels/interleave/fn.interleave.html - // To handle nulls, we'll add an extra null array at the end - let null_array = Arc::new(new_null_array(&index_type, 1)); - // Do this in parallel over the columns let columns = futures::stream::iter(0..self.batches[0].num_columns()) .map(|column_i| { - // TODO: we need to drop the join_on column - + // Use interleave to get the data + // https://docs.rs/arrow/40.0.0/arrow/compute/kernels/interleave/fn.interleave.html + // To handle nulls, we'll add an extra null array at the end let mut arrays = Vec::with_capacity(self.batches.len() + 1); for batch in &self.batches { arrays.push(batch.column(column_i).clone()); } - arrays.push(null_array.clone()); + arrays.push(Arc::new(new_null_array(&arrays[0].data_type(), 1))); let indices = indices.clone(); async move { - let task_result = task::spawn(async move { + let task_result = task::spawn_blocking(move || { let array_refs = arrays.iter().map(|x| x.as_ref()).collect::>(); interleave(array_refs.as_ref(), indices.as_ref()) .map_err(|err| Error::IO { @@ -201,11 +215,11 @@ mod tests { #[tokio::test] async fn test_joiner_collect() { let schema = Arc::new(Schema::new(vec![ - Field::new("i", DataType::Int32, false), - Field::new("s", DataType::Utf8, false), + Field::new("i", DataType::Int32, true), + Field::new("s", DataType::Utf8, true), ])); - let mut batch_buffer: RecordBatchBuffer = (0..5) + let batch_buffer: RecordBatchBuffer = (0..5) .map(|v| { let values = (v * 10..v * 10 + 10).collect::>(); RecordBatch::try_new( @@ -220,9 +234,10 @@ mod tests { .unwrap() }) .collect(); + let mut batch_buffer: Box = Box::new(batch_buffer); let joiner = HashJoiner::try_new(&mut batch_buffer, "i").await.unwrap(); - let indices = Arc::new(UInt32Array::from_iter(&[ + let indices = Arc::new(Int32Array::from_iter(&[ Some(15), None, Some(10), @@ -247,5 +262,37 @@ mod tests { None // 11111 not found ]) ); + + assert_eq!(results.num_columns(), 1); + } + + #[tokio::test] + async fn test_reject_invalid() { + let schema = Arc::new(Schema::new(vec![ + Field::new("i", DataType::Int32, true), + Field::new("s", DataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["a", "b"])), + ], + ) + .unwrap(); + let mut batch_buffer: Box = + Box::new(RecordBatchBuffer::from_iter(vec![batch])); + + let joiner = HashJoiner::try_new(&mut batch_buffer, "i").await.unwrap(); + + // Wrong type: was Int32, passing UInt32. 
+        let indices = Arc::new(UInt32Array::from_iter(&[Some(15)]));
+        let result = joiner.collect(indices).await;
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("Index column type mismatch: expected Int32, got UInt32"));
     }
 }
diff --git a/rust/src/error.rs b/rust/src/error.rs
index 4cd3a65669..7a07ad40be 100644
--- a/rust/src/error.rs
+++ b/rust/src/error.rs
@@ -29,6 +29,8 @@ pub(crate) fn box_error(e: impl std::error::Error + Send + Sync + 'static) -> Bo
 
 #[derive(Debug, Snafu)]
 #[snafu(visibility(pub(crate)))]
 pub enum Error {
+    #[snafu(display("Invalid user input: {source}"))]
+    InvalidInput { source: BoxedError },
     #[snafu(display("Attempt to write empty record batches"))]
     EmptyDataset,
     #[snafu(display("Dataset already exists: {uri}"))]
@@ -63,6 +65,13 @@ impl Error {
             source: message.into(),
         }
     }
+
+    pub fn invalid_input(message: impl Into<String>) -> Self {
+        let message: String = message.into();
+        Self::InvalidInput {
+            source: message.into(),
+        }
+    }
 }
 
 pub type Result<T> = std::result::Result<T, Error>;

From f5be5612aa55d4c5ff21a4cc4336cdadfb30a099 Mon Sep 17 00:00:00 2001
From: Will Jones
Date: Mon, 5 Jun 2023 15:39:43 -0700
Subject: [PATCH 13/16] Add Python bindings

---
 python/python/lance/dataset.py      | 61 +++++++++++++++++++++++------
 python/python/tests/test_dataset.py | 30 ++++++++++++++
 python/src/dataset.rs               | 16 ++++++++
 rust/src/dataset/hash_joiner.rs     | 12 +++++-
 rust/src/datatypes.rs               | 12 ++++++
 5 files changed, 118 insertions(+), 13 deletions(-)

diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index e01125b8e6..ceccfe9ea1 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -356,6 +356,39 @@ def join(
         Not implemented (just override pyarrow dataset to prevent segfault)
         """
         raise NotImplementedError("Versioning not yet supported in Rust")
+
+    def merge(
+        self,
+        data_obj: ReaderLike,
+        left_on: str,
+        right_on: Optional[str] = None,
+    ):
+        """
+        Merge another dataset into this one.
+
+        Performs a left join, where the dataset is the left side and data_obj
+        is the right side. Rows in the dataset that have no matching row in
+        data_obj will have their new columns filled with null values, unless
+        Lance doesn't support null values for some types, in which case an
+        error will be raised.
+
+        Parameters
+        ----------
+        data_obj: Reader-like
+            The data to be merged. Acceptable types are:
+            - Pandas DataFrame, Pyarrow Table, Dataset, Scanner, or RecordBatchReader
+        left_on: str
+            The name of the column in the dataset to join on.
+        right_on: str or None
+            The name of the column in data_obj to join on. If None, defaults to
+            left_on.
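+
+        Examples
+        --------
+        A minimal sketch of the intended usage; the dataset path and column
+        names here are only illustrative::
+
+            import lance
+            import pyarrow as pa
+
+            table = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]})
+            dataset = lance.write_dataset(table, "/tmp/merge_example.lance")
+
+            # Attach a new "score" column, matching rows on "id".
+            extra = pa.table({"id": [1, 2, 3], "score": [0.1, 0.2, 0.3]})
+            dataset.merge(extra, left_on="id")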
+ """ + if right_on is None: + right_on = left_on + + reader = _coerce_reader(data_obj) + + self._ds.merge(reader, left_on, right_on) + def versions(self): """ @@ -808,18 +841,7 @@ def write_dataset( The max number of rows before starting a new group (in the same file) """ - if isinstance(data_obj, pd.DataFrame): - reader = pa.Table.from_pandas(data_obj, schema=schema).to_reader() - elif isinstance(data_obj, pa.Table): - reader = data_obj.to_reader() - elif isinstance(data_obj, pa.dataset.Dataset): - reader = pa.dataset.Scanner.from_dataset(data_obj).to_reader() - elif isinstance(data_obj, pa.dataset.Scanner): - reader = data_obj.to_reader() - elif isinstance(data_obj, pa.RecordBatchReader): - reader = data_obj - else: - raise TypeError(f"Unknown data_obj type {type(data_obj)}") + reader = _coerce_reader(data_obj) # TODO add support for passing in LanceDataset and LanceScanner here params = { @@ -831,3 +853,18 @@ def write_dataset( uri = os.fspath(uri) if isinstance(uri, Path) else uri _write_dataset(reader, uri, params) return LanceDataset(uri) + + +def _coerce_reader(data_obj: ReaderLike, schema: Optional[pa.Schema] = None) -> pa.RecordBatchReader: + if isinstance(data_obj, pd.DataFrame): + return pa.Table.from_pandas(data_obj, schema=schema).to_reader() + elif isinstance(data_obj, pa.Table): + return data_obj.to_reader() + elif isinstance(data_obj, pa.dataset.Dataset): + return pa.dataset.Scanner.from_dataset(data_obj).to_reader() + elif isinstance(data_obj, pa.dataset.Scanner): + return data_obj.to_reader() + elif isinstance(data_obj, pa.RecordBatchReader): + return data_obj + else: + raise TypeError(f"Unknown data_obj type {type(data_obj)}") diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 88a5dfd68d..70d2442e31 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -357,3 +357,33 @@ def test_load_scanner_from_fragments(tmp_path: Path): # Accepts an iterator scanner = dataset.scanner(fragments=iter(fragments[0:2]), scan_in_order=False) assert scanner.to_table().num_rows == 2 * 100 + + +def test_merge_data(tmp_path: Path): + tab = pa.table({"a": range(100), "b": range(100)}) + lance.write_dataset(tab, tmp_path / "dataset", mode="append") + + dataset = lance.dataset(tmp_path / "dataset") + + # rejects partial data for non-nullable types + new_tab = pa.table({"a": range(40), "c": range(40)}) + # TODO: this should be ValueError + with pytest.raises(OSError, match=".+Lance does not yet support nulls for type Int64."): + dataset.merge(new_tab, "a") + + # accepts a full merge + new_tab = pa.table({"a": range(100), "c": range(100)}) + dataset.merge(new_tab, "a") + assert dataset.version == 2 + assert dataset.to_table() == pa.table({"a": range(100), "b": range(100), "c": range(100)}) + + # accepts a partial for string + new_tab = pa.table({"a2": range(5), "d": ["a", "b", "c", "d", "e"]}) + dataset.merge(new_tab, left_on="a", right_on="a2") + assert dataset.version == 3 + assert dataset.to_table() == pa.table({ + "a": range(100), + "b": range(100), + "c": range(100), + "d": ["a", "b", "c", "d", "e"] + [None] * 95 + }) diff --git a/python/src/dataset.rs b/python/src/dataset.rs index aa0d5962cc..09c37a0414 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -281,6 +281,22 @@ impl Dataset { batch.to_pyarrow(self_.py()) } + fn merge( + &mut self, + reader: PyArrowType, + left_on: &str, + right_on: &str, + ) -> PyResult<()> { + let mut reader: Box = Box::new(reader.0); + let mut new_self = 
self.ds.as_ref().clone(); + let fut = new_self.merge(&mut reader, left_on, right_on); + self.rt.block_on( + async move { fut.await.map_err(|err| PyIOError::new_err(err.to_string())) }, + )?; + self.ds = Arc::new(new_self); + Ok(()) + } + fn versions(self_: PyRef<'_, Self>) -> PyResult> { let versions = self_ .list_versions() diff --git a/rust/src/dataset/hash_joiner.rs b/rust/src/dataset/hash_joiner.rs index d04ea42874..4fa7e15027 100644 --- a/rust/src/dataset/hash_joiner.rs +++ b/rust/src/dataset/hash_joiner.rs @@ -25,6 +25,7 @@ use dashmap::{DashMap, ReadOnlyView}; use futures::{StreamExt, TryStreamExt}; use tokio::task; +use crate::datatypes::lance_supports_nulls; use crate::{Error, Result}; /// `HashJoiner` does hash join on two datasets. @@ -181,7 +182,16 @@ impl HashJoiner { }) .await; match task_result { - Ok(Ok(array)) => Ok(array), + Ok(Ok(array)) => { + if array.null_count() > 0 && !lance_supports_nulls(array.data_type()) { + return Err(Error::invalid_input(format!( + "Found rows on LHS that do not match any rows on RHS. Lance would need to write \ + nulls on the RHS, but Lance does not yet support nulls for type {:?}.", + array.data_type() + ))); + } + Ok(array) + }, Ok(Err(err)) => Err(err), Err(err) => Err(Error::IO { message: format!("HashJoiner: {}", err), diff --git a/rust/src/datatypes.rs b/rust/src/datatypes.rs index bfb5e794f3..c4b5a7b40f 100644 --- a/rust/src/datatypes.rs +++ b/rust/src/datatypes.rs @@ -310,3 +310,15 @@ impl From<&Dictionary> for pb::Dictionary { } } } + +/// Returns true if Lance supports writing this datatype with nulls. +pub(crate) fn lance_supports_nulls(datatype: &DataType) -> bool { + match datatype { + DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::FixedSizeBinary(_) + | DataType::FixedSizeList(_, _) => true, + _ => false, + } +} From 6b2e77043d96d2aec84f3d51d0144b4c595972bf Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 5 Jun 2023 15:58:46 -0700 Subject: [PATCH 14/16] cleanup --- rust/src/dataset/hash_joiner.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/rust/src/dataset/hash_joiner.rs b/rust/src/dataset/hash_joiner.rs index 4fa7e15027..c08da8fce3 100644 --- a/rust/src/dataset/hash_joiner.rs +++ b/rust/src/dataset/hash_joiner.rs @@ -86,8 +86,8 @@ impl HashJoiner { let map = Arc::new(map); futures::stream::iter(batches.iter().enumerate().map(Ok::<_, Error>)) - // Not sure if this can actually run in parallel though .try_for_each_concurrent(num_cpus::get(), |(batch_i, batch)| { + // A clone of map we can send to a new thread let map = map.clone(); async move { let column = batch[on].clone(); @@ -145,7 +145,9 @@ impl HashJoiner { // Index to use for null values let null_index = self.batches.len(); - // collect indices + // Indices are a pair of (batch_i, row_i). We'll add a null batch at the + // end with one null element, and that's what we resolve when no match is + // found. let indices = column_to_rows(index_column)? 
.into_iter() .map(|row| { @@ -169,6 +171,7 @@ impl HashJoiner { } arrays.push(Arc::new(new_null_array(&arrays[0].data_type(), 1))); + // Clone of indices we can send to a new thread let indices = indices.clone(); async move { From bd903c95eaa31d28386bc013421dd5a5fcdf6a78 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 6 Jun 2023 09:39:36 -0700 Subject: [PATCH 15/16] Add more comments --- rust/src/dataset.rs | 4 ++++ rust/src/dataset/hash_joiner.rs | 3 +++ 2 files changed, 7 insertions(+) diff --git a/rust/src/dataset.rs b/rust/src/dataset.rs index c2e64ba56b..707ddcac58 100644 --- a/rust/src/dataset.rs +++ b/rust/src/dataset.rs @@ -515,6 +515,8 @@ impl Dataset { }; for field in right_schema.fields() { if field.name() == right_on { + // right_on is allowed to exist in the dataset, since it may be + // the same as left_on. continue; } if self.schema().field(field.name()).is_some() { @@ -527,6 +529,8 @@ impl Dataset { // Hash join let joiner = Arc::new(HashJoiner::try_new(stream, right_on).await?); + // Final schema is union of current schema, plus the RHS schema without + // the right_on key. let new_schema: Schema = self.schema().merge(joiner.out_schema().as_ref())?; // Write new data file to each fragment. Parallelism is done over columns, diff --git a/rust/src/dataset/hash_joiner.rs b/rust/src/dataset/hash_joiner.rs index c08da8fce3..1cbe6e70b0 100644 --- a/rust/src/dataset/hash_joiner.rs +++ b/rust/src/dataset/hash_joiner.rs @@ -126,6 +126,9 @@ impl HashJoiner { }) } + /// Returns the schema of data yielded by `collect()`. + /// + /// This excludes the index column on the right-hand side. pub fn out_schema(&self) -> &SchemaRef { &self.out_schema } From 18184729a137b3614bdd106d464ead8130290bb7 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 6 Jun 2023 10:41:58 -0700 Subject: [PATCH 16/16] add list to supported nulls --- rust/src/datatypes.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rust/src/datatypes.rs b/rust/src/datatypes.rs index c4b5a7b40f..3fa531171c 100644 --- a/rust/src/datatypes.rs +++ b/rust/src/datatypes.rs @@ -317,6 +317,7 @@ pub(crate) fn lance_supports_nulls(datatype: &DataType) -> bool { DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary + | DataType::List(_) | DataType::FixedSizeBinary(_) | DataType::FixedSizeList(_, _) => true, _ => false,
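
A note on how the nullability rule above surfaces to users (an illustrative sketch, not part
of the patch series; the path below is made up, and the behavior follows test_merge_data): a
right-hand table that matches only some rows of the dataset merges cleanly when the new
columns use a type Lance can already write nulls for (strings, binary, fixed-size types, and
lists after this patch), and raises otherwise.

    import lance
    import pyarrow as pa

    tab = pa.table({"a": range(100), "b": range(100)})
    dataset = lance.write_dataset(tab, "/tmp/merge_nulls_example")

    # Partial Int64 column: Lance cannot yet write Int64 nulls, so merge raises.
    partial_ints = pa.table({"a": range(40), "c": range(40)})
    try:
        dataset.merge(partial_ints, "a")
    except OSError as err:
        print(err)  # "...Lance does not yet support nulls for type Int64."

    # Partial string column: the 95 unmatched rows get null padding in "d".
    partial_strs = pa.table({"a": range(5), "d": ["a", "b", "c", "d", "e"]})
    dataset.merge(partial_strs, "a")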