diff --git a/rust/src/arrow.rs b/rust/src/arrow.rs
index 24c0fbf80b..da958f7d60 100644
--- a/rust/src/arrow.rs
+++ b/rust/src/arrow.rs
@@ -26,6 +26,65 @@ use arrow_data::ArrayDataBuilder;
 use arrow_schema::{DataType, Field};
 
 use crate::error::Result;
+
+pub trait DataTypeExt {
+    /// Returns true if the data type is binary-like, such as (Large)Utf8 and (Large)Binary.
+    ///
+    /// ```
+    /// use lance::arrow::*;
+    /// use arrow_schema::DataType;
+    ///
+    /// assert!(DataType::Utf8.is_binary_like());
+    /// assert!(DataType::Binary.is_binary_like());
+    /// assert!(DataType::LargeUtf8.is_binary_like());
+    /// assert!(DataType::LargeBinary.is_binary_like());
+    /// assert!(!DataType::Int32.is_binary_like());
+    /// ```
+    fn is_binary_like(&self) -> bool;
+
+    fn is_struct(&self) -> bool;
+
+    /// Check whether the given Arrow DataType is fixed stride.
+    /// A fixed stride type has the same byte width for all array elements.
+    /// This includes all primitive types, Boolean, FixedSizeList, FixedSizeBinary, and Decimals.
+    fn is_fixed_stride(&self) -> bool;
+}
+
+impl DataTypeExt for DataType {
+    fn is_binary_like(&self) -> bool {
+        matches!(
+            self,
+            DataType::Utf8 | DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary
+        )
+    }
+
+    fn is_struct(&self) -> bool {
+        matches!(self, DataType::Struct(_))
+    }
+
+    fn is_fixed_stride(&self) -> bool {
+        match self {
+            DataType::Boolean
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::Float16
+            | DataType::Float32
+            | DataType::Float64
+            | DataType::Decimal128(_, _)
+            | DataType::Decimal256(_, _)
+            | DataType::FixedSizeList(_, _)
+            | DataType::FixedSizeBinary(_) => true,
+            _ => false,
+        }
+    }
+}
+
 pub trait ListArrayExt {
     /// Create a [`ListArray`] from values and offsets.
     ///
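Not part of the patch: a minimal sketch of how these predicates compose into encoding selection, mirroring the `TryFrom<&ArrowField>` change in datatypes.rs below (the `pick_encoding` helper is hypothetical):

```rust
use arrow_schema::DataType;
use lance::arrow::DataTypeExt;

// Hypothetical helper: fixed-stride types get plain encoding,
// binary-like types get var-binary, everything else is undecided.
fn pick_encoding(dt: &DataType) -> Option<&'static str> {
    if dt.is_fixed_stride() {
        Some("plain")
    } else if dt.is_binary_like() {
        Some("var-binary")
    } else {
        None
    }
}

fn main() {
    assert_eq!(pick_encoding(&DataType::Int32), Some("plain"));
    assert_eq!(pick_encoding(&DataType::Utf8), Some("var-binary"));
    assert_eq!(pick_encoding(&DataType::Null), None);
}
```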
diff --git a/rust/src/datatypes.rs b/rust/src/datatypes.rs
index 440d57476f..02c0a177db 100644
--- a/rust/src/datatypes.rs
+++ b/rust/src/datatypes.rs
@@ -1,6 +1,6 @@
 //! Lance data types, [Schema] and [Field]
 
-use std::cmp::min;
+use std::cmp::max;
 use std::collections::HashMap;
 use std::fmt;
 use std::fmt::Formatter;
@@ -9,36 +9,12 @@ use arrow_array::ArrayRef;
 use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema, TimeUnit};
 use async_recursion::async_recursion;
 
+use crate::arrow::DataTypeExt;
 use crate::encodings::Encoding;
 use crate::format::pb;
 use crate::io::object_reader::ObjectReader;
 use crate::{Error, Result};
 
-/// Check whether the given Arrow DataType is fixed stride.
-/// A fixed stride type has the same byte width for all array elements
-/// This includes all PrimitiveType's Boolean, FixedSizeList, FixedSizeBinary, and Decimals
-pub fn is_fixed_stride(arrow_type: &DataType) -> bool {
-    match arrow_type {
-        DataType::Boolean
-        | DataType::UInt8
-        | DataType::UInt16
-        | DataType::UInt32
-        | DataType::UInt64
-        | DataType::Int8
-        | DataType::Int16
-        | DataType::Int32
-        | DataType::Int64
-        | DataType::Float16
-        | DataType::Float32
-        | DataType::Float64
-        | DataType::Decimal128(_, _)
-        | DataType::Decimal256(_, _)
-        | DataType::FixedSizeList(_, _)
-        | DataType::FixedSizeBinary(_) => true,
-        _ => false,
-    }
-}
-
 /// LogicalType is a string representation of an Arrow type,
 /// to be serialized into protobuf.
 #[derive(Debug, Clone, PartialEq)]
@@ -201,19 +177,6 @@ impl TryFrom<&LogicalType> for DataType {
     }
 }
 
-fn is_numeric(data_type: &DataType) -> bool {
-    use DataType::*;
-    matches!(
-        data_type,
-        UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64
-    )
-}
-
-fn is_binary(data_type: &DataType) -> bool {
-    use DataType::*;
-    matches!(data_type, Binary | Utf8 | LargeBinary | LargeUtf8)
-}
-
 #[derive(Debug, Clone, PartialEq)]
 pub struct Dictionary {
     offset: usize,
@@ -242,7 +205,7 @@ pub struct Field {
     parent_id: i32,
     logical_type: LogicalType,
     extension_name: String,
-    encoding: Option<Encoding>,
+    pub(crate) encoding: Option<Encoding>,
     nullable: bool,
 
     pub children: Vec<Field>,
@@ -313,9 +276,9 @@ impl Field {
 
     // Get the max field id of itself and all children.
     fn max_id(&self) -> i32 {
-        min(
+        max(
             self.id,
-            self.children.iter().map(|c| c.id).min().unwrap_or(i32::MAX),
+            self.children.iter().map(|c| c.max_id()).max().unwrap_or(-1),
         )
     }
@@ -420,8 +383,8 @@ impl TryFrom<&ArrowField> for Field {
             name: field.name().clone(),
             logical_type: LogicalType::try_from(field.data_type())?,
             encoding: match field.data_type() {
-                dt if is_numeric(dt) => Some(Encoding::Plain),
-                dt if is_binary(dt) => Some(Encoding::VarBinary),
+                dt if dt.is_fixed_stride() => Some(Encoding::Plain),
+                dt if dt.is_binary_like() => Some(Encoding::VarBinary),
                 DataType::Dictionary(_, _) => Some(Encoding::Dictionary),
                 _ => None,
             },
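Worked example (illustrative, not from the patch) of why the old `min`-based `max_id` was doubly wrong — it picked the smallest id and never recursed into grandchildren. A standalone sketch with a hypothetical `Node` standing in for `Field`:

```rust
// Hypothetical stand-in for `Field`: an id plus nested children.
struct Node {
    id: i32,
    children: Vec<Node>,
}

impl Node {
    // The fixed version from the diff: recurse and take the maximum,
    // falling back to -1 for a leaf with no children.
    fn max_id(&self) -> i32 {
        std::cmp::max(
            self.id,
            self.children.iter().map(|c| c.max_id()).max().unwrap_or(-1),
        )
    }
}

fn main() {
    let root = Node {
        id: 0,
        children: vec![
            Node { id: 1, children: vec![] },
            Node {
                id: 2,
                children: vec![Node { id: 3, children: vec![] }],
            },
        ],
    };
    // The old `min(self.id, children.map(|c| c.id).min())` returned 0 here
    // and never saw the nested id 3; the fixed version finds it.
    assert_eq!(root.max_id(), 3);
}
```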
object_writer.write_all(b"1234").await.unwrap(); let mut encoder = BinaryEncoder::new(&mut object_writer); - let pos = encoder.encode(&arr).await.unwrap(); - - writer.shutdown().await.unwrap(); + object_writer.shutdown().await.unwrap(); let mut reader = store.open(&path).await.unwrap(); let decoder = BinaryDecoder::>::new(&mut reader, pos, arr.len()); diff --git a/rust/src/encodings/dictionary.rs b/rust/src/encodings/dictionary.rs index bc252091be..3e8888102e 100644 --- a/rust/src/encodings/dictionary.rs +++ b/rust/src/encodings/dictionary.rs @@ -126,7 +126,6 @@ mod tests { }; use arrow_array::Array; use object_store::path::Path; - use tokio::io::AsyncWriteExt; async fn test_dict_decoder_for_type() { let values = vec!["a", "b", "b", "a", "c"]; @@ -137,11 +136,10 @@ mod tests { let pos; { - let (_, mut writer) = store.inner.put_multipart(&path).await.unwrap(); - let mut object_writer = ObjectWriter::new(writer.as_mut()); + let mut object_writer = ObjectWriter::new(&store, &path).await.unwrap(); let mut encoder = PlainEncoder::new(&mut object_writer, arr.keys().data_type()); pos = encoder.encode(arr.keys()).await.unwrap(); - writer.shutdown().await.unwrap(); + object_writer.shutdown().await.unwrap(); } let reader = store.open(&path).await.unwrap(); diff --git a/rust/src/encodings/plain.rs b/rust/src/encodings/plain.rs index 00a0cc3707..d0d8f464e2 100644 --- a/rust/src/encodings/plain.rs +++ b/rust/src/encodings/plain.rs @@ -33,13 +33,13 @@ use arrow_buffer::{bit_util, Buffer}; use arrow_data::ArrayDataBuilder; use arrow_schema::{DataType, Field}; use async_recursion::async_recursion; +use async_trait::async_trait; +use tokio::io::AsyncWriteExt; use crate::arrow::FixedSizeBinaryArrayExt; use crate::arrow::FixedSizeListArrayExt; -use crate::datatypes::is_fixed_stride; +use crate::arrow::*; use crate::Error; -use async_trait::async_trait; -use tokio::io::AsyncWriteExt; use super::Decoder; use crate::error::Result; @@ -49,12 +49,12 @@ use crate::io::object_writer::ObjectWriter; /// Encoder for plain encoding. 
diff --git a/rust/src/encodings/plain.rs b/rust/src/encodings/plain.rs
index 00a0cc3707..d0d8f464e2 100644
--- a/rust/src/encodings/plain.rs
+++ b/rust/src/encodings/plain.rs
@@ -33,13 +33,13 @@ use arrow_buffer::{bit_util, Buffer};
 use arrow_data::ArrayDataBuilder;
 use arrow_schema::{DataType, Field};
 use async_recursion::async_recursion;
+use async_trait::async_trait;
+use tokio::io::AsyncWriteExt;
 
 use crate::arrow::FixedSizeBinaryArrayExt;
 use crate::arrow::FixedSizeListArrayExt;
-use crate::datatypes::is_fixed_stride;
+use crate::arrow::*;
 use crate::Error;
-use async_trait::async_trait;
-use tokio::io::AsyncWriteExt;
 
 use super::Decoder;
 use crate::error::Result;
@@ -49,12 +49,12 @@ use crate::io::object_writer::ObjectWriter;
 
 /// Encoder for plain encoding.
 ///
 pub struct PlainEncoder<'a> {
-    writer: &'a mut ObjectWriter<'a>,
+    writer: &'a mut ObjectWriter,
     data_type: &'a DataType,
 }
 
 impl<'a> PlainEncoder<'a> {
-    pub fn new(writer: &'a mut ObjectWriter<'a>, data_type: &'a DataType) -> PlainEncoder<'a> {
+    pub fn new(writer: &'a mut ObjectWriter, data_type: &'a DataType) -> PlainEncoder<'a> {
         PlainEncoder { writer, data_type }
     }
 
@@ -153,7 +153,7 @@ impl<'a> PlainDecoder<'a> {
         items: &Box<Field>,
         list_size: &i32,
     ) -> Result<ArrayRef> {
-        if !is_fixed_stride(items.data_type()) {
+        if !items.data_type().is_fixed_stride() {
             return Err(Error::Schema(format!(
                 "Items for fixed size list should be primitives but found {}",
                 items.data_type()
@@ -229,10 +229,8 @@ mod tests {
     use rand::prelude::*;
     use std::borrow::Borrow;
     use std::sync::Arc;
-    use tokio::io::AsyncWriteExt;
 
     use super::*;
-    use crate::arrow::*;
     use crate::io::object_writer::ObjectWriter;
 
     #[tokio::test]
@@ -267,15 +265,11 @@ async fn test_round_trip(expected: ArrayRef, data_type: DataType) {
         let store = ObjectStore::new(":memory:").unwrap();
         let path = Path::from("/foo");
-        let (_, mut writer) = store.inner.put_multipart(&path).await.unwrap();
-
-        {
-            let mut object_writer = ObjectWriter::new(writer.as_mut());
-            let mut encoder = PlainEncoder::new(&mut object_writer, &data_type);
+        let mut object_writer = ObjectWriter::new(&store, &path).await.unwrap();
+        let mut encoder = PlainEncoder::new(&mut object_writer, &data_type);
 
-            assert_eq!(encoder.encode(expected.as_ref()).await.unwrap(), 0);
-        }
-        writer.shutdown().await.unwrap();
+        assert_eq!(encoder.encode(expected.as_ref()).await.unwrap(), 0);
+        object_writer.shutdown().await.unwrap();
 
         let mut reader = store.open(&path).await.unwrap();
         assert!(reader.size().await.unwrap() > 0);
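One detail worth calling out, since `FileWriter` below depends on it: `encode` returns the file offset at which the page starts. A test-style sketch (the `/page` path is illustrative, and the `pos == 4` assertion assumes the encoder writes at the current cursor without padding, as the binary.rs test's "reset tell()" trick implies):

```rust
use arrow_array::Int32Array;
use arrow_schema::DataType;
use object_store::path::Path;
use tokio::io::AsyncWriteExt;

use crate::encodings::plain::PlainEncoder;
use crate::encodings::Encoder;
use crate::io::object_writer::ObjectWriter;
use crate::io::ObjectStore;

#[tokio::test]
async fn encode_returns_page_offset() {
    let store = ObjectStore::new(":memory:").unwrap();
    let path = Path::from("/page");
    let mut writer = ObjectWriter::new(&store, &path).await.unwrap();

    // Push the cursor past zero so the returned offset is visible.
    writer.write_all(b"head").await.unwrap();

    let arr = Int32Array::from(vec![1, 2, 3, 4]);
    let mut encoder = PlainEncoder::new(&mut writer, &DataType::Int32);
    let pos = encoder.encode(&arr).await.unwrap();
    // The page begins right after the 4 header bytes; this is exactly the
    // value FileWriter records in the page table.
    assert_eq!(pos, 4);

    writer.shutdown().await.unwrap();
}
```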
diff --git a/rust/src/format.rs b/rust/src/format.rs
index 24698f0e99..af0aa3a907 100644
--- a/rust/src/format.rs
+++ b/rust/src/format.rs
@@ -1,3 +1,5 @@
+//! On-disk format
+
 mod fragment;
 mod manifest;
 mod metadata;
@@ -15,6 +17,11 @@ pub mod pb {
     include!(concat!(env!("OUT_DIR"), "/lance.format.pb.rs"));
 }
 
+pub const MAJOR_VERSION: i16 = 0;
+pub const MINOR_VERSION: i16 = 1;
+pub const MAGIC: &[u8; 4] = b"LANC";
+pub const INDEX_MAGIC: &[u8; 8] = b"LANC_IDX";
+
 /// Annotation on a struct that can be converted to a Protobuf message.
 pub trait ProtoStruct {
     type Proto: Message;
diff --git a/rust/src/format/fragment.rs b/rust/src/format/fragment.rs
index 901713e40b..807e94c14d 100644
--- a/rust/src/format/fragment.rs
+++ b/rust/src/format/fragment.rs
@@ -79,3 +79,12 @@ impl From<&pb::DataFragment> for Fragment {
         }
     }
 }
+
+impl From<&Fragment> for pb::DataFragment {
+    fn from(f: &Fragment) -> Self {
+        Self {
+            id: f.id,
+            files: f.files.iter().map(pb::DataFile::from).collect(),
+        }
+    }
+}
diff --git a/rust/src/format/manifest.rs b/rust/src/format/manifest.rs
index 750c50c2b2..6ce1d6f2db 100644
--- a/rust/src/format/manifest.rs
+++ b/rust/src/format/manifest.rs
@@ -18,6 +18,7 @@ use super::Fragment;
 use crate::datatypes::Schema;
 use crate::format::{pb, ProtoStruct};
+use std::collections::HashMap;
 
 /// Manifest of a dataset
 ///
@@ -36,6 +37,16 @@ pub struct Manifest {
     pub fragments: Vec<Fragment>,
 }
 
+impl Manifest {
+    pub fn new(schema: &Schema) -> Self {
+        Self {
+            schema: schema.clone(),
+            version: 1,
+            fragments: vec![],
+        }
+    }
+}
+
 impl ProtoStruct for Manifest {
     type Proto = pb::Manifest;
 }
@@ -49,3 +60,15 @@ impl From<pb::Manifest> for Manifest {
         }
     }
 }
+
+impl From<&Manifest> for pb::Manifest {
+    fn from(m: &Manifest) -> Self {
+        Self {
+            fields: (&m.schema).into(),
+            version: m.version,
+            fragments: m.fragments.iter().map(pb::DataFragment::from).collect(),
+            metadata: HashMap::default(),
+            version_aux_data: 0,
+        }
+    }
+}
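A small sketch of the new conversion direction added above (assuming `Manifest` and `pb` are reachable via `crate::format` re-exports, as the other modules in this patch use them): a fresh `Manifest` starts at version 1 with no fragments, and `From<&Manifest>` produces the protobuf message that `FileWriter` later serializes into the footer.

```rust
use crate::datatypes::Schema;
use crate::format::{pb, Manifest};

// Build the protobuf manifest for a brand-new, fragment-less dataset file.
fn manifest_proto(schema: &Schema) -> pb::Manifest {
    let manifest = Manifest::new(schema);
    let proto = pb::Manifest::from(&manifest);
    assert_eq!(proto.version, 1);
    assert!(proto.fragments.is_empty());
    proto
}
```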
diff --git a/rust/src/format/page_table.rs b/rust/src/format/page_table.rs
index 7c47733e73..850be85e07 100644
--- a/rust/src/format/page_table.rs
+++ b/rust/src/format/page_table.rs
@@ -15,42 +15,58 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow_array::Int64Array;
+use arrow_array::builder::Int64Builder;
+use arrow_array::{Array, Int64Array};
 use arrow_schema::DataType;
-use std::collections::HashMap;
+use std::collections::BTreeMap;
+use tokio::io::AsyncWriteExt;
 
 use crate::encodings::plain::PlainDecoder;
 use crate::encodings::Decoder;
 use crate::error::Result;
 use crate::io::object_reader::ObjectReader;
+use crate::io::object_writer::ObjectWriter;
 
-#[derive(Debug)]
+#[derive(Clone, Debug, PartialEq)]
 pub struct PageInfo {
     pub position: usize,
     pub length: usize,
 }
 
+impl PageInfo {
+    pub fn new(position: usize, length: usize) -> Self {
+        Self { position, length }
+    }
+}
+
+/// Page lookup table.
+///
 #[derive(Debug, Default)]
 pub struct PageTable {
-    pages: HashMap<i32, HashMap<i32, PageInfo>>,
+    /// map[field-id,  map[batch-id, PageInfo]]
+    pages: BTreeMap<i32, BTreeMap<i32, PageInfo>>,
 }
 
 impl PageTable {
-    /// Create page table from disk.
-    pub async fn new<'a>(
+    /// Load [PageTable] from disk.
+    pub async fn load<'a>(
         reader: &'a ObjectReader<'_>,
         position: usize,
         num_columns: i32,
         num_batches: i32,
     ) -> Result<Self> {
+        println!(
+            "Loading page table: columns={} batches={}",
+            num_columns, num_batches
+        );
         let length = num_columns * num_batches * 2;
         let decoder = PlainDecoder::new(reader, &DataType::Int64, position, length as usize)?;
         let raw_arr = decoder.decode().await?;
         let arr = raw_arr.as_any().downcast_ref::<Int64Array>().unwrap();
 
-        let mut pages = HashMap::default();
+        let mut pages = BTreeMap::default();
         for col in 0..num_columns {
-            pages.insert(col, HashMap::default());
+            pages.insert(col, BTreeMap::default());
             for batch in 0..num_batches {
                 let idx = col * num_batches + batch;
                 let batch_position = &arr.value((idx * 2) as usize);
@@ -68,7 +84,47 @@ impl PageTable {
         Ok(Self { pages })
     }
 
-    pub fn set_page_info(&mut self) {}
+    pub async fn write(&self, writer: &mut ObjectWriter) -> Result<usize> {
+        let pos = writer.tell();
+        assert!(!self.pages.is_empty());
+        let num_columns = self.pages.keys().max().unwrap() + 1;
+        let num_batches = self
+            .pages
+            .values()
+            .map(|c_map| c_map.keys().max())
+            .flatten()
+            .max()
+            .unwrap()
+            + 1;
+
+        let mut builder = Int64Builder::with_capacity((num_columns * num_batches) as usize);
+        for col in 0..num_columns {
+            for batch in 0..num_batches {
+                if let Some(page_info) = self.get(col, batch) {
+                    builder.append_value(page_info.position as i64);
+                    builder.append_value(page_info.length as i64);
+                } else {
+                    builder.append_slice(&[0, 0]);
+                }
+            }
+        }
+        let arr = builder.finish();
+        writer
+            .write_all(arr.into_data().buffers()[0].as_slice())
+            .await?;
+
+        Ok(pos)
+    }
+
+    /// Set page lookup info for a page identified by `(column, batch)` pair.
+    pub fn set(&mut self, column: i32, batch: i32, page_info: PageInfo) {
+        if !self.pages.contains_key(&column) {
+            self.pages.insert(column, BTreeMap::default());
+        }
+        self.pages
+            .get_mut(&column)
+            .map(|c_map| c_map.insert(batch, page_info));
+    }
 
     pub fn get(&self, column: i32, batch: i32) -> Option<&PageInfo> {
         self.pages
@@ -79,4 +135,16 @@ impl PageTable {
 }
 
 #[cfg(test)]
-mod tests {}
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_set_page_info() {
+        let mut page_table = PageTable::default();
+        let page_info = PageInfo::new(1, 2);
+        page_table.set(10, 20, page_info.clone());
+
+        let actual = page_table.get(10, 20).unwrap();
+        assert_eq!(actual, &page_info);
+    }
+}
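A sketch of the layout that `load` and `write` agree on (derived from the index arithmetic above, not new behavior): the table is a dense `num_columns x num_batches` grid, laid out column-major with two little-endian i64 values, position then length, per cell; pages that were never set are written as `(0, 0)`.

```rust
// Index of a cell's `position` slot in the flat Int64 array; its `length`
// slot follows immediately at the returned index + 1.
fn cell_index(col: i32, batch: i32, num_batches: i32) -> usize {
    let idx = (col * num_batches + batch) as usize;
    idx * 2
}

fn main() {
    // With 3 batches per column, column 1 / batch 2 lands at Int64 slot 10,
    // matching `idx = col * num_batches + batch` in PageTable::load.
    assert_eq!(cell_index(1, 2, 3), 10);
}
```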
diff --git a/rust/src/io.rs b/rust/src/io.rs
index d1bdad17b8..ec7cec1092 100644
--- a/rust/src/io.rs
+++ b/rust/src/io.rs
@@ -29,13 +29,13 @@ pub mod object_reader;
 pub mod object_store;
 pub mod object_writer;
 pub mod reader;
+mod writer;
 
-use crate::format::ProtoStruct;
+use crate::format::{ProtoStruct, INDEX_MAGIC, MAGIC};
 
 pub use self::object_store::ObjectStore;
-
-const MAGIC: &[u8; 4] = b"LANC";
-const INDEX_MAGIC: &[u8; 8] = b"LANC_IDX";
+pub use reader::FileReader;
+pub use writer::FileWriter;
 
 #[async_trait]
 pub trait AsyncWriteProtoExt {
diff --git a/rust/src/io/object_reader.rs b/rust/src/io/object_reader.rs
index dffc266e3f..575c85f23c 100644
--- a/rust/src/io/object_reader.rs
+++ b/rust/src/io/object_reader.rs
@@ -19,7 +19,6 @@ use std::cmp::min;
 use std::ops::Range;
 
-use crate::datatypes::is_fixed_stride;
 use arrow_array::{
     types::{BinaryType, LargeBinaryType, LargeUtf8Type, Utf8Type},
     ArrayRef,
@@ -30,6 +29,7 @@ use bytes::Bytes;
 use object_store::{path::Path, ObjectMeta};
 use prost::Message;
 
+use crate::arrow::*;
 use crate::encodings::{binary::BinaryDecoder, plain::PlainDecoder, Decoder};
 use crate::error::{Error, Result};
 use crate::format::ProtoStruct;
@@ -109,7 +109,7 @@ impl<'a> ObjectReader<'a> {
         position: usize,
         length: usize,
     ) -> Result<ArrayRef> {
-        if !is_fixed_stride(data_type) {
+        if !data_type.is_fixed_stride() {
             return Err(Error::Schema(format!(
                 "{} is not a fixed stride type",
                 data_type
diff --git a/rust/src/io/object_store.rs b/rust/src/io/object_store.rs
index 4d153f71dd..1f593abd30 100644
--- a/rust/src/io/object_store.rs
+++ b/rust/src/io/object_store.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 //! Wraps [ObjectStore](object_store::ObjectStore)
+
 use std::sync::Arc;
 
 use ::object_store::{
@@ -24,11 +25,12 @@ use ::object_store::{
 use object_store::local::LocalFileSystem;
 use url::{ParseError, Url};
 
-use super::object_reader::ObjectReader;
 use crate::error::{Error, Result};
+use crate::io::object_reader::ObjectReader;
+use crate::io::object_writer::ObjectWriter;
 
 /// Wraps [ObjectStore](object_store::ObjectStore)
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct ObjectStore {
     // Inner object store
     pub inner: Arc<dyn OSObjectStore>,
@@ -47,12 +49,7 @@ impl ObjectStore {
     /// Create an ObjectStore instance from a given URL.
     pub fn new(uri: &str) -> Result<Self> {
         if uri == ":memory:" {
-            return Ok(Self {
-                inner: Arc::new(InMemory::new()),
-                scheme: String::from("memory"),
-                base_path: Path::from("/"),
-                prefetch_size: 64 * 1024,
-            });
+            return Ok(ObjectStore::memory());
         };
 
         let parsed = match Url::parse(uri) {
@@ -98,6 +95,16 @@ impl ObjectStore {
         })
     }
 
+    /// Create an in-memory object store directly.
+    pub(crate) fn memory() -> Self {
+        Self {
+            inner: Arc::new(InMemory::new()),
+            scheme: String::from("memory"),
+            base_path: Path::from("/"),
+            prefetch_size: 64 * 1024,
+        }
+    }
+
     pub fn prefetch_size(&self) -> usize {
         self.prefetch_size
     }
@@ -116,4 +123,9 @@ impl ObjectStore {
             Err(e) => Err(e),
         }
     }
+
+    /// Create a new file.
+    pub async fn create(&self, path: &Path) -> Result<ObjectWriter> {
+        ObjectWriter::new(self, path).await
+    }
 }
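A crate-internal sketch of the two new conveniences (crate-internal because `memory()` is `pub(crate)`): `ObjectStore::memory()` replaces the `":memory:"` special-casing in tests, and `create` replaces manual `ObjectWriter` construction.

```rust
use object_store::path::Path;

use crate::io::ObjectStore;

// Create an in-memory store and open a writable file in one call each.
async fn create_in_memory_file() -> crate::Result<()> {
    let store = ObjectStore::memory();
    let mut writer = store.create(&Path::from("/f")).await?;
    writer.shutdown().await?;
    Ok(())
}
```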
diff --git a/rust/src/io/object_writer.rs b/rust/src/io/object_writer.rs
index 4d570c9b2f..af631f2792 100644
--- a/rust/src/io/object_writer.rs
+++ b/rust/src/io/object_writer.rs
@@ -15,29 +15,41 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::io::Error;
 use std::pin::Pin;
 use std::task::{Context, Poll};
 
+use object_store::{path::Path, MultipartId};
 use pin_project::pin_project;
 use prost::Message;
 use tokio::io::{AsyncWrite, AsyncWriteExt};
 
 use crate::format::ProtoStruct;
+use crate::io::ObjectStore;
+use crate::Result;
 
 /// AsyncWrite with the capability to tell the position the data is written.
 ///
 #[pin_project]
-pub struct ObjectWriter<'a> {
+pub struct ObjectWriter {
+    store: ObjectStore,
+
+    // TODO: wrap writer with a BufWriter.
     #[pin]
-    writer: &'a mut (dyn AsyncWrite + Unpin + Send),
+    writer: Box<dyn AsyncWrite + Unpin + Send>,
+
+    multipart_id: MultipartId,
+
     cursor: usize,
 }
 
-impl<'a> ObjectWriter<'a> {
-    pub fn new(writer: &'a mut (dyn AsyncWrite + Unpin + Send)) -> ObjectWriter<'a> {
-        ObjectWriter { writer, cursor: 0 }
+impl ObjectWriter {
+    pub async fn new(object_store: &ObjectStore, path: &Path) -> Result<Self> {
+        let (multipart_id, writer) = object_store.inner.put_multipart(path).await?;
+
+        Ok(Self {
+            store: object_store.clone(),
+            writer,
+            multipart_id,
+            cursor: 0,
+        })
     }
 
     /// Tell the current position (file size).
@@ -46,7 +58,7 @@ impl ObjectWriter {
     }
 
     /// Write a protobuf message to the object, and returns the file position of the protobuf.
-    pub async fn write_protobuf(&mut self, msg: &impl Message) -> Result<usize, Error> {
+    pub async fn write_protobuf(&mut self, msg: &impl Message) -> Result<usize> {
         let offset = self.tell();
 
         let len = msg.encoded_len();
@@ -60,35 +72,38 @@ impl ObjectWriter {
     pub async fn write_struct<'b, M: Message + From<&'b T>, T: ProtoStruct<Proto = M> + 'b>(
         &mut self,
         obj: &'b T,
-    ) -> Result<usize, Error> {
+    ) -> Result<usize> {
         let msg: M = M::from(obj);
         self.write_protobuf(&msg).await
     }
+
+    pub async fn shutdown(&mut self) -> Result<()> {
+        Ok(self.writer.shutdown().await?)
+    }
 }
 
-impl AsyncWrite for ObjectWriter<'_> {
+impl AsyncWrite for ObjectWriter {
     fn poll_write(
         self: Pin<&mut Self>,
         cx: &mut Context<'_>,
         buf: &[u8],
-    ) -> Poll<Result<usize, Error>> {
+    ) -> Poll<std::io::Result<usize>> {
         let mut this = self.project();
         *this.cursor += buf.len();
         this.writer.as_mut().poll_write(cx, buf)
     }
 
-    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
         self.project().writer.as_mut().poll_flush(cx)
     }
 
-    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
         self.project().writer.as_mut().poll_shutdown(cx)
     }
 }
 
 #[cfg(test)]
 mod tests {
-
     use object_store::path::Path;
     use tokio::io::AsyncWriteExt;
 
@@ -101,13 +116,10 @@ mod tests {
     #[tokio::test]
     async fn test_write() {
         let store = ObjectStore::new(":memory:").unwrap();
-        let (_, mut writer) = store
-            .inner
-            .put_multipart(&Path::from("/foo"))
+
+        let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo"))
             .await
             .unwrap();
-
-        let mut object_writer = ObjectWriter::new(writer.as_mut());
         assert_eq!(object_writer.tell(), 0);
 
         let mut buf = Vec::<u8>::new();
@@ -122,20 +134,14 @@ mod tests {
         assert_eq!(object_writer.tell(), 256 * 3);
 
         object_writer.shutdown().await.unwrap();
-
-        assert_eq!(
-            store.inner.head(&Path::from("/foo")).await.unwrap().size,
-            256 * 3
-        );
     }
 
     #[tokio::test]
     async fn test_write_proto_structs() {
         let store = ObjectStore::new(":memory:").unwrap();
         let path = Path::from("/foo");
-        let (_, mut writer) = store.inner.put_multipart(&path).await.unwrap();
 
-        let mut object_writer = ObjectWriter::new(writer.as_mut());
+        let mut object_writer = ObjectWriter::new(&store, &path).await.unwrap();
         assert_eq!(object_writer.tell(), 0);
 
         let mut metadata = Metadata::default();
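A sketch of the `write_struct` contract (the `/meta` path is illustrative, and the call relies on the `ProtoStruct<Proto = M>` bound to infer the message type): it converts a `ProtoStruct` into its protobuf message and returns the offset where the message begins — the value `FileWriter` records as the manifest/metadata position.

```rust
use object_store::path::Path;

use crate::format::Metadata;
use crate::io::object_writer::ObjectWriter;
use crate::io::ObjectStore;

async fn write_metadata_sketch() -> crate::Result<()> {
    let store = ObjectStore::memory();
    let mut writer = ObjectWriter::new(&store, &Path::from("/meta")).await?;

    let metadata = Metadata::default();
    let pos = writer.write_struct(&metadata).await?;
    assert_eq!(pos, 0); // the first thing written lands at offset zero

    writer.shutdown().await?;
    Ok(())
}
```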
diff --git a/rust/src/io/reader.rs b/rust/src/io/reader.rs
index 6d5ae6f60b..0fd066b498 100644
--- a/rust/src/io/reader.rs
+++ b/rust/src/io/reader.rs
@@ -18,7 +18,6 @@
 //! Lance Data File Reader
 
 // Standard
-use std::cmp::max;
 use std::ops::Range;
 use std::sync::Arc;
 
@@ -32,7 +31,6 @@ use object_store::path::Path;
 use prost::Message;
 
 use crate::arrow::*;
-use crate::datatypes::is_fixed_stride;
 use crate::encodings::{dictionary::DictionaryDecoder, Decoder};
 use crate::error::{Error, Result};
 use crate::format::Manifest;
@@ -93,13 +91,15 @@ impl<'a> FileReader<'a> {
             ObjectReader::new(object_store, path.clone(), object_store.prefetch_size())?;
         let file_size = object_reader.size().await?;
 
+        let begin = if file_size < object_store.prefetch_size() {
+            0
+        } else {
+            file_size - object_store.prefetch_size()
+        };
         let tail_bytes = object_reader
             .object_store
             .inner
-            .get_range(
-                path,
-                max(0, file_size - object_store.prefetch_size())..file_size,
-            )
+            .get_range(path, begin..file_size)
             .await?;
         let metadata_pos = read_metadata_offset(&tail_bytes)?;
 
@@ -112,14 +112,14 @@ impl<'a> FileReader<'a> {
         };
 
         let (projection, num_columns) = if let Some(m) = manifest {
-            (m.schema.clone(), m.schema.max_field_id().unwrap())
+            (m.schema.clone(), m.schema.max_field_id().unwrap() + 1)
         } else {
             let m: Manifest = object_reader
                 .read_struct(metadata.manifest_position.unwrap())
                 .await?;
-            (m.schema.clone(), m.schema.max_field_id().unwrap())
+            (m.schema.clone(), m.schema.max_field_id().unwrap() + 1)
         };
-        let page_table = PageTable::new(
+        let page_table = PageTable::load(
             &object_reader,
             metadata.page_table_position,
             num_columns,
@@ -167,8 +167,8 @@ impl<'a> FileReader<'a> {
         let column = field.id;
         self.page_table.get(column, batch_id).ok_or_else(|| {
             Error::IO(format!(
                "No page info found for field: {}, field_id={} batch={}",
                field.name, field.id, batch_id
-                "No page info found for field: {}, batch={}",
-                field.name, batch_id
+                "No page info found for field: {}, field_id={} batch={}",
+                field.name, field.id, batch_id
             ))
         })
     }
@@ -243,7 +243,7 @@ impl<'a> FileReader<'a> {
 
         use DataType::*;
 
-        if is_fixed_stride(&data_type) {
+        if data_type.is_fixed_stride() {
            self.read_fixed_stride_array(field, batch_id).await
        } else {
            match data_type {
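The new `begin` computation fixes a real bug: in `max(0, file_size - prefetch_size)` both operands are unsigned, so for files smaller than the prefetch window the subtraction wraps (panicking in debug builds) before `max` ever runs. The same clamp can be written with `usize::saturating_sub`, shown here as an equivalent alternative rather than what the patch does:

```rust
// Clamp the tail-read start to zero for files smaller than the window.
fn tail_begin(file_size: usize, prefetch_size: usize) -> usize {
    file_size.saturating_sub(prefetch_size)
}

fn main() {
    assert_eq!(tail_begin(10, 64 * 1024), 0); // small file: read from the start
    assert_eq!(tail_begin(100_000, 64 * 1024), 100_000 - 64 * 1024);
}
```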
diff --git a/rust/src/io/writer.rs b/rust/src/io/writer.rs
new file mode 100644
index 0000000000..f5168ce6cb
--- /dev/null
+++ b/rust/src/io/writer.rs
@@ -0,0 +1,224 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_array::cast::as_struct_array;
+use arrow_array::{Array, ArrayRef, RecordBatch, StructArray};
+use async_recursion::async_recursion;
+use tokio::io::AsyncWriteExt;
+
+use crate::arrow::*;
+use crate::datatypes::{Field, Schema};
+use crate::encodings::binary::BinaryEncoder;
+use crate::encodings::Encoder;
+use crate::encodings::{plain::PlainEncoder, Encoding};
+use crate::format::{Manifest, Metadata, PageInfo, PageTable, MAGIC, MAJOR_VERSION, MINOR_VERSION};
+use crate::io::object_writer::ObjectWriter;
+use crate::{Error, Result};
+
+/// FileWriter writes an Arrow Table to a file.
+pub struct FileWriter<'a> {
+    object_writer: ObjectWriter,
+    schema: &'a Schema,
+    batch_id: i32,
+    page_table: PageTable,
+    metadata: Metadata,
+}
+
+impl<'a> FileWriter<'a> {
+    pub fn new(object_writer: ObjectWriter, schema: &'a Schema) -> Self {
+        Self {
+            object_writer,
+            schema,
+            batch_id: 0,
+            page_table: PageTable::default(),
+            metadata: Metadata::default(),
+        }
+    }
+
+    /// Write a [RecordBatch] to the open file.
+    ///
+    /// Returns [Err] if the schema does not match the batch.
+    pub async fn write(&mut self, batch: &RecordBatch) -> Result<()> {
+        for field in &self.schema.fields {
+            let column_id = batch.schema().index_of(&field.name)?;
+            let array = batch.column(column_id);
+            self.write_array(field, array).await?;
+        }
+        self.metadata.push_batch_length(batch.num_rows() as i32);
+        self.batch_id += 1;
+        Ok(())
+    }
+
+    // pub async fn abort(&self) -> Result<()> {
+    //     // self.object_writer.
+    //     Ok(())
+    // }
+
+    pub async fn finish(&mut self) -> Result<()> {
+        self.write_footer().await?;
+        self.object_writer.shutdown().await
+    }
+
+    #[async_recursion]
+    async fn write_array(&mut self, field: &Field, array: &ArrayRef) -> Result<()> {
+        let data_type = array.data_type();
+        if data_type.is_fixed_stride() {
+            self.write_fixed_stride_array(field, array).await?;
+        } else if data_type.is_struct() {
+            let struct_arr = as_struct_array(array);
+            self.write_struct_array(field, struct_arr).await?;
+        } else if data_type.is_binary_like() {
+            self.write_binary_array(field, array).await?;
+        };
+
+        Ok(())
+    }
+
+    /// Write a fixed stride array, including primitives, fixed size binary, and fixed size list.
+    async fn write_fixed_stride_array(&mut self, field: &Field, array: &ArrayRef) -> Result<()> {
+        assert_eq!(field.encoding, Some(Encoding::Plain));
+        let mut encoder = PlainEncoder::new(&mut self.object_writer, array.data_type());
+        let pos = encoder.encode(array).await?;
+        let page_info = PageInfo::new(pos, array.len());
+        self.page_table.set(field.id, self.batch_id, page_info);
+        Ok(())
+    }
+
+    /// Write var-length binary arrays.
+    async fn write_binary_array(&mut self, field: &Field, array: &ArrayRef) -> Result<()> {
+        assert_eq!(field.encoding, Some(Encoding::VarBinary));
+        let mut encoder = BinaryEncoder::new(&mut self.object_writer);
+        let pos = encoder.encode(array).await?;
+        let page_info = PageInfo::new(pos, array.len());
+        self.page_table.set(field.id, self.batch_id, page_info);
+        Ok(())
+    }
+
+    #[async_recursion]
+    async fn write_struct_array(&mut self, field: &Field, array: &StructArray) -> Result<()> {
+        assert_eq!(array.num_columns(), field.children.len());
+        for child in &field.children {
+            if let Some(arr) = array.column_by_name(&child.name) {
+                self.write_array(child, arr).await?;
+            } else {
+                return Err(Error::Schema(format!(
+                    "FileWriter: schema mismatch: column {} does not exist in array: {:?}",
+                    child.name,
+                    array.data_type()
+                )));
+            }
+        }
+        Ok(())
+    }
+
+    async fn write_footer(&mut self) -> Result<()> {
+        // Step 1. write dictionary values.
+
+        // Step 2. Write page table.
+        let pos = self.page_table.write(&mut self.object_writer).await?;
+        self.metadata.page_table_position = pos;
+
+        // Step 3. Write manifest.
+        let manifest = Manifest::new(self.schema);
+        let pos = self.object_writer.write_struct(&manifest).await?;
+
+        // Step 4. Write metadata.
+        self.metadata.manifest_position = Some(pos);
+        let pos = self.object_writer.write_struct(&self.metadata).await?;
+
+        // Step 5. Write magics.
+        self.write_magics(pos).await
+    }
+
+    async fn write_magics(&mut self, pos: usize) -> Result<()> {
+        self.object_writer.write_i64_le(pos as i64).await?;
+        self.object_writer.write_i16_le(MAJOR_VERSION).await?;
+        self.object_writer.write_i16_le(MINOR_VERSION).await?;
+        self.object_writer.write_all(MAGIC).await?;
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use std::sync::Arc;
+
+    use arrow_array::{BooleanArray, Float32Array, Int64Array, StringArray};
+    use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
+    use object_store::path::Path;
+
+    use crate::io::{FileReader, ObjectStore};
+
+    #[tokio::test]
+    async fn test_write_file() {
+        let arrow_schema = ArrowSchema::new(vec![
+            ArrowField::new("bool", DataType::Boolean, true),
+            ArrowField::new("i", DataType::Int64, true),
+            ArrowField::new("f", DataType::Float32, false),
+            ArrowField::new("b", DataType::Utf8, true),
+            ArrowField::new(
+                "s",
+                DataType::Struct(vec![
+                    ArrowField::new("si", DataType::Int64, true),
+                    ArrowField::new("sb", DataType::Utf8, true),
+                ]),
+                true,
+            ),
+        ]);
+        let schema = Schema::try_from(&arrow_schema).unwrap();
+
+        let store = ObjectStore::memory();
+        let path = Path::from("/foo");
+        let writer = store.create(&path).await.unwrap();
+
+        let mut file_writer = FileWriter::new(writer, &schema);
+
+        let columns: Vec<ArrayRef> = vec![
+            Arc::new(BooleanArray::from_iter(
+                (0..100).map(|f| Some(f % 3 == 0)).collect::<Vec<_>>(),
+            )),
+            Arc::new(Int64Array::from_iter((0..100).collect::<Vec<i64>>())),
+            Arc::new(Float32Array::from_iter(
+                (0..100).map(|n| n as f32).collect::<Vec<_>>(),
+            )),
+            Arc::new(StringArray::from(
+                (0..100).map(|n| n.to_string()).collect::<Vec<String>>(),
+            )),
+            Arc::new(StructArray::from(vec![
+                (
+                    ArrowField::new("si", DataType::Int64, true),
+                    Arc::new(Int64Array::from_iter((100..200).collect::<Vec<i64>>())) as ArrayRef,
+                ),
+                (
+                    ArrowField::new("sb", DataType::Utf8, true),
+                    Arc::new(StringArray::from(
+                        (0..100).map(|n| n.to_string()).collect::<Vec<String>>(),
+                    )) as ArrayRef,
+                ),
+            ])),
+        ];
+        let batch = RecordBatch::try_new(Arc::new(arrow_schema), columns).unwrap();
+        file_writer.write(&batch).await.unwrap();
+        file_writer.finish().await.unwrap();
+
+        let reader = FileReader::new(&store, &path, None).await.unwrap();
+        let actual = reader.read_batch(0).await.unwrap();
+        assert_eq!(actual, batch);
+    }
+}
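A sketch of the 16-byte footer `write_magics` emits, derived from the write calls above (`write_i64_le`, two `write_i16_le`, then the magic): a little-endian i64 metadata offset, the i16 major/minor versions, then the 4-byte magic. This is what `FileReader`'s `read_metadata_offset` walks backwards from the file tail.

```rust
use crate::format::{MAGIC, MAJOR_VERSION, MINOR_VERSION};

// Assemble the footer bytes in the same order write_magics writes them.
fn footer_bytes(metadata_pos: i64) -> Vec<u8> {
    let mut buf = Vec::with_capacity(16);
    buf.extend_from_slice(&metadata_pos.to_le_bytes()); // 8 bytes
    buf.extend_from_slice(&MAJOR_VERSION.to_le_bytes()); // 2 bytes
    buf.extend_from_slice(&MINOR_VERSION.to_le_bytes()); // 2 bytes
    buf.extend_from_slice(MAGIC); // 4 bytes: b"LANC"
    assert_eq!(buf.len(), 16);
    buf
}
```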