Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rust] Lance File Writer #419

Merged
merged 20 commits into from
Jan 9, 2023
59 changes: 59 additions & 0 deletions rust/src/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,65 @@ use arrow_data::ArrayDataBuilder;
use arrow_schema::{DataType, Field};

use crate::error::Result;

pub trait DataTypeExt {
/// Returns true if the data type is binary-like, such as (Large)Utf8 and (Large)Binary.
///
/// ```
/// use lance::arrow::*;
/// use arrow_schema::DataType;
///
/// assert!(DataType::Utf8.is_binary_like());
/// assert!(DataType::Binary.is_binary_like());
/// assert!(DataType::LargeUtf8.is_binary_like());
/// assert!(DataType::LargeBinary.is_binary_like());
/// assert!(!DataType::Int32.is_binary_like());
/// ```
fn is_binary_like(&self) -> bool;

fn is_struct(&self) -> bool;

/// Check whether the given Arrow DataType is fixed stride.
/// A fixed stride type has the same byte width for all array elements
/// This includes all PrimitiveType's Boolean, FixedSizeList, FixedSizeBinary, and Decimals
fn is_fixed_stride(&self) -> bool;
}

impl DataTypeExt for DataType {
fn is_binary_like(&self) -> bool {
matches!(
self,
DataType::Utf8 | DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary
)
}

fn is_struct(&self) -> bool {
matches!(self, DataType::Struct(_))
}

fn is_fixed_stride(&self) -> bool {
match self {
DataType::Boolean
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float16
| DataType::Float32
| DataType::Float64
| DataType::Decimal128(_, _)
| DataType::Decimal256(_, _)
| DataType::FixedSizeList(_, _)
| DataType::FixedSizeBinary(_) => true,
_ => false,
}
}
}

pub trait ListArrayExt {
/// Create an [`ListArray`] from values and offsets.
///
Expand Down
51 changes: 7 additions & 44 deletions rust/src/datatypes.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Lance data types, [Schema] and [Field]

use std::cmp::min;
use std::cmp::max;
use std::collections::HashMap;
use std::fmt;
use std::fmt::Formatter;
Expand All @@ -9,36 +9,12 @@ use arrow_array::ArrayRef;
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema, TimeUnit};
use async_recursion::async_recursion;

use crate::arrow::DataTypeExt;
use crate::encodings::Encoding;
use crate::format::pb;
use crate::io::object_reader::ObjectReader;
use crate::{Error, Result};

/// Check whether the given Arrow DataType is fixed stride.
/// A fixed stride type has the same byte width for all array elements
/// This includes all PrimitiveType's Boolean, FixedSizeList, FixedSizeBinary, and Decimals
pub fn is_fixed_stride(arrow_type: &DataType) -> bool {
match arrow_type {
DataType::Boolean
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float16
| DataType::Float32
| DataType::Float64
| DataType::Decimal128(_, _)
| DataType::Decimal256(_, _)
| DataType::FixedSizeList(_, _)
| DataType::FixedSizeBinary(_) => true,
_ => false,
}
}

/// LogicalType is a string presentation of arrow type.
/// to be serialized into protobuf.
#[derive(Debug, Clone, PartialEq)]
Expand Down Expand Up @@ -201,19 +177,6 @@ impl TryFrom<&LogicalType> for DataType {
}
}

fn is_numeric(data_type: &DataType) -> bool {
use DataType::*;
matches!(
data_type,
UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64
)
}

fn is_binary(data_type: &DataType) -> bool {
use DataType::*;
matches!(data_type, Binary | Utf8 | LargeBinary | LargeUtf8)
}

#[derive(Debug, Clone, PartialEq)]
pub struct Dictionary {
offset: usize,
Expand Down Expand Up @@ -242,7 +205,7 @@ pub struct Field {
parent_id: i32,
logical_type: LogicalType,
extension_name: String,
encoding: Option<Encoding>,
pub(crate) encoding: Option<Encoding>,
nullable: bool,

pub children: Vec<Field>,
Expand Down Expand Up @@ -313,9 +276,9 @@ impl Field {

// Get the max field id of itself and all children.
fn max_id(&self) -> i32 {
min(
max(
self.id,
self.children.iter().map(|c| c.id).min().unwrap_or(i32::MAX),
self.children.iter().map(|c| c.max_id()).max().unwrap_or(-1),
)
}

Expand Down Expand Up @@ -420,8 +383,8 @@ impl TryFrom<&ArrowField> for Field {
name: field.name().clone(),
logical_type: LogicalType::try_from(field.data_type())?,
encoding: match field.data_type() {
dt if is_numeric(dt) => Some(Encoding::Plain),
dt if is_binary(dt) => Some(Encoding::VarBinary),
dt if dt.is_fixed_stride() => Some(Encoding::Plain),
dt if dt.is_binary_like() => Some(Encoding::VarBinary),
DataType::Dictionary(_, _) => Some(Encoding::Dictionary),
_ => None,
},
Expand Down
12 changes: 4 additions & 8 deletions rust/src/encodings/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ use crate::io::object_writer::ObjectWriter;

/// Encoder for Var-binary encoding.
pub struct BinaryEncoder<'a> {
writer: &'a mut ObjectWriter<'a>,
writer: &'a mut ObjectWriter,
}

impl<'a> BinaryEncoder<'a> {
pub fn new(writer: &'a mut ObjectWriter<'a>) -> Self {
pub fn new(writer: &'a mut ObjectWriter) -> Self {
Self { writer }
}

Expand Down Expand Up @@ -190,16 +190,12 @@ mod tests {
async fn test_round_trips<O: OffsetSizeTrait>(arr: &GenericStringArray<O>) {
let store = ObjectStore::new(":memory:").unwrap();
let path = Path::from("/foo");
let (_, mut writer) = store.inner.put_multipart(&path).await.unwrap();

let mut object_writer = ObjectWriter::new(writer.as_mut());
let mut object_writer = ObjectWriter::new(&store, &path).await.unwrap();
// Write some gabage to reset "tell()".
object_writer.write_all(b"1234").await.unwrap();
let mut encoder = BinaryEncoder::new(&mut object_writer);

let pos = encoder.encode(&arr).await.unwrap();

writer.shutdown().await.unwrap();
object_writer.shutdown().await.unwrap();

let mut reader = store.open(&path).await.unwrap();
let decoder = BinaryDecoder::<GenericStringType<O>>::new(&mut reader, pos, arr.len());
Expand Down
6 changes: 2 additions & 4 deletions rust/src/encodings/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ mod tests {
};
use arrow_array::Array;
use object_store::path::Path;
use tokio::io::AsyncWriteExt;

async fn test_dict_decoder_for_type<T: ArrowDictionaryKeyType>() {
let values = vec!["a", "b", "b", "a", "c"];
Expand All @@ -137,11 +136,10 @@ mod tests {

let pos;
{
let (_, mut writer) = store.inner.put_multipart(&path).await.unwrap();
let mut object_writer = ObjectWriter::new(writer.as_mut());
let mut object_writer = ObjectWriter::new(&store, &path).await.unwrap();
let mut encoder = PlainEncoder::new(&mut object_writer, arr.keys().data_type());
pos = encoder.encode(arr.keys()).await.unwrap();
writer.shutdown().await.unwrap();
object_writer.shutdown().await.unwrap();
}

let reader = store.open(&path).await.unwrap();
Expand Down
26 changes: 10 additions & 16 deletions rust/src/encodings/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ use arrow_buffer::{bit_util, Buffer};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{DataType, Field};
use async_recursion::async_recursion;
use async_trait::async_trait;
use tokio::io::AsyncWriteExt;

use crate::arrow::FixedSizeBinaryArrayExt;
use crate::arrow::FixedSizeListArrayExt;
use crate::datatypes::is_fixed_stride;
use crate::arrow::*;
use crate::Error;
use async_trait::async_trait;
use tokio::io::AsyncWriteExt;

use super::Decoder;
use crate::error::Result;
Expand All @@ -49,12 +49,12 @@ use crate::io::object_writer::ObjectWriter;
/// Encoder for plain encoding.
///
pub struct PlainEncoder<'a> {
writer: &'a mut ObjectWriter<'a>,
writer: &'a mut ObjectWriter,
data_type: &'a DataType,
}

impl<'a> PlainEncoder<'a> {
pub fn new(writer: &'a mut ObjectWriter<'a>, data_type: &'a DataType) -> PlainEncoder<'a> {
pub fn new(writer: &'a mut ObjectWriter, data_type: &'a DataType) -> PlainEncoder<'a> {
PlainEncoder { writer, data_type }
}

Expand Down Expand Up @@ -153,7 +153,7 @@ impl<'a> PlainDecoder<'a> {
items: &Box<Field>,
list_size: &i32,
) -> Result<ArrayRef> {
if !is_fixed_stride(items.data_type()) {
if !items.data_type().is_fixed_stride() {
return Err(Error::Schema(format!(
"Items for fixed size list should be primitives but found {}",
items.data_type()
Expand Down Expand Up @@ -229,10 +229,8 @@ mod tests {
use rand::prelude::*;
use std::borrow::Borrow;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;

use super::*;
use crate::arrow::*;
use crate::io::object_writer::ObjectWriter;

#[tokio::test]
Expand Down Expand Up @@ -267,15 +265,11 @@ mod tests {
async fn test_round_trip(expected: ArrayRef, data_type: DataType) {
let store = ObjectStore::new(":memory:").unwrap();
let path = Path::from("/foo");
let (_, mut writer) = store.inner.put_multipart(&path).await.unwrap();

{
let mut object_writer = ObjectWriter::new(writer.as_mut());
let mut encoder = PlainEncoder::new(&mut object_writer, &data_type);
let mut object_writer = ObjectWriter::new(&store, &path).await.unwrap();
let mut encoder = PlainEncoder::new(&mut object_writer, &data_type);

assert_eq!(encoder.encode(expected.as_ref()).await.unwrap(), 0);
}
writer.shutdown().await.unwrap();
assert_eq!(encoder.encode(expected.as_ref()).await.unwrap(), 0);
object_writer.shutdown().await.unwrap();

let mut reader = store.open(&path).await.unwrap();
assert!(reader.size().await.unwrap() > 0);
Expand Down
7 changes: 7 additions & 0 deletions rust/src/format.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! On-disk format

mod fragment;
mod manifest;
mod metadata;
Expand All @@ -15,6 +17,11 @@ pub mod pb {
include!(concat!(env!("OUT_DIR"), "/lance.format.pb.rs"));
}

pub const MAJOR_VERSION: i16 = 0;
pub const MINOR_VERSION: i16 = 1;
pub const MAGIC: &[u8; 4] = b"LANC";
pub const INDEX_MAGIC: &[u8; 8] = b"LANC_IDX";

/// Annotation on a struct that can be converted a Protobuf message.
pub trait ProtoStruct {
type Proto: Message;
Expand Down
9 changes: 9 additions & 0 deletions rust/src/format/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,12 @@ impl From<&pb::DataFragment> for Fragment {
}
}
}

impl From<&Fragment> for pb::DataFragment {
fn from(f: &Fragment) -> Self {
Self {
id: f.id,
files: f.files.iter().map(pb::DataFile::from).collect(),
}
}
}
23 changes: 23 additions & 0 deletions rust/src/format/manifest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use super::Fragment;
use crate::datatypes::Schema;
use crate::format::{pb, ProtoStruct};
use std::collections::HashMap;

/// Manifest of a dataset
///
Expand All @@ -36,6 +37,16 @@ pub struct Manifest {
pub fragments: Vec<Fragment>,
}

impl Manifest {
pub fn new(schema: &Schema) -> Self {
Self {
schema: schema.clone(),
version: 1,
fragments: vec![],
}
}
}

impl ProtoStruct for Manifest {
type Proto = pb::Manifest;
}
Expand All @@ -49,3 +60,15 @@ impl From<pb::Manifest> for Manifest {
}
}
}

impl From<&Manifest> for pb::Manifest {
fn from(m: &Manifest) -> Self {
Self {
fields: (&m.schema).into(),
version: m.version,
fragments: m.fragments.iter().map(pb::DataFragment::from).collect(),
metadata: HashMap::default(),
version_aux_data: 0,
}
}
}
Loading