Commit

Merge branch 'main' into lei/tf_data
eddyxu authored Jul 26, 2023
2 parents 2936771 + d728044 commit 6964640
Showing 13 changed files with 659 additions and 145 deletions.
11 changes: 10 additions & 1 deletion python/python/lance/dataset.py
@@ -62,10 +62,13 @@ def __init__(
version: Optional[int] = None,
block_size: Optional[int] = None,
index_cache_size: Optional[int] = None,
metadata_cache_size: Optional[int] = None,
):
uri = os.fspath(uri) if isinstance(uri, Path) else uri
self._uri = uri
self._ds = _Dataset(uri, version, block_size, index_cache_size)
self._ds = _Dataset(
uri, version, block_size, index_cache_size, metadata_cache_size
)

def __reduce__(self):
return LanceDataset, (self.uri, self._ds.version())
@@ -196,6 +199,7 @@ def to_table(
limit: Optional[int] = None,
offset: Optional[int] = None,
nearest: Optional[dict] = None,
batch_size: Optional[int] = None,
batch_readahead: Optional[int] = None,
fragment_readahead: Optional[int] = None,
scan_in_order: bool = True,
@@ -229,6 +233,8 @@ def to_table(
"refine_factor": 1
}
batch_size: int, optional
The number of rows to read at a time.
batch_readahead: int, optional
The number of batches to read ahead.
fragment_readahead: int, optional
@@ -250,6 +256,7 @@ def to_table(
limit=limit,
offset=offset,
nearest=nearest,
batch_size=batch_size,
batch_readahead=batch_readahead,
fragment_readahead=fragment_readahead,
scan_in_order=scan_in_order,
@@ -291,6 +298,7 @@ def to_batches(
limit: Optional[int] = None,
offset: Optional[int] = None,
nearest: Optional[dict] = None,
batch_size: Optional[int] = None,
batch_readahead: Optional[int] = None,
fragment_readahead: Optional[int] = None,
scan_in_order: bool = True,
@@ -313,6 +321,7 @@ def to_batches(
limit=limit,
offset=offset,
nearest=nearest,
batch_size=batch_size,
batch_readahead=batch_readahead,
fragment_readahead=fragment_readahead,
scan_in_order=scan_in_order,
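Taken together, the dataset.py changes add two user-facing knobs: metadata_cache_size on the LanceDataset constructor and batch_size on to_table/to_batches. A minimal usage sketch, assuming a local dataset at ./example.lance (the path and cache sizes below are illustrative only):

    from lance.dataset import LanceDataset

    # Open with an explicit metadata cache size; per the Rust side, 0 disables the cache.
    ds = LanceDataset(
        "./example.lance",
        index_cache_size=256,
        metadata_cache_size=256,
    )

    # batch_size controls how many rows each scanned batch holds.
    table = ds.to_table(batch_size=100)
    for batch in ds.to_batches(batch_size=100):
        print(batch.num_rows)

Both keyword arguments are optional; when omitted, the Rust defaults apply (256 entries for each cache).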
3 changes: 3 additions & 0 deletions python/src/dataset.rs
@@ -46,6 +46,7 @@ use lance::index::{

const DEFAULT_NPROBS: usize = 1;
const DEFAULT_INDEX_CACHE_SIZE: usize = 256;
const DEFAULT_METADATA_CACHE_SIZE: usize = 256;

/// Lance Dataset that will be wrapped by another class in Python
#[pyclass(name = "_Dataset", module = "_lib")]
@@ -65,11 +66,13 @@ impl Dataset {
version: Option<u64>,
block_size: Option<usize>,
index_cache_size: Option<usize>,
metadata_cache_size: Option<usize>,
) -> PyResult<Self> {
let rt = Runtime::new()?;
let params = ReadParams {
block_size,
index_cache_size: index_cache_size.unwrap_or(DEFAULT_INDEX_CACHE_SIZE),
metadata_cache_size: metadata_cache_size.unwrap_or(DEFAULT_METADATA_CACHE_SIZE),
session: None,
store_options: None,
};
152 changes: 81 additions & 71 deletions rust/src/dataset.rs
@@ -29,7 +29,6 @@ use chrono::prelude::*;
use futures::stream::{self, StreamExt, TryStreamExt};
use log::warn;
use object_store::path::Path;
use uuid::Uuid;

mod feature_flags;
pub mod fragment;
@@ -41,26 +40,28 @@ mod write;
use self::feature_flags::{apply_feature_flags, can_read_dataset, can_write_dataset};
use self::fragment::FileFragment;
use self::scanner::Scanner;
use self::write::{reader_to_stream, write_fragments};
use crate::datatypes::Schema;
use crate::error::box_error;
use crate::format::{pb, Fragment, Index, Manifest};
use crate::io::object_store::ObjectStoreParams;
use crate::io::{
object_reader::{read_message, read_struct},
read_manifest, read_metadata_offset, write_manifest, FileWriter, ObjectStore,
read_manifest, read_metadata_offset, write_manifest, ObjectStore,
};
use crate::session::Session;
use crate::{Error, Result};
use hash_joiner::HashJoiner;
pub use scanner::ROW_ID;
pub use write::*;
pub use write::{WriteMode, WriteParams};

const LATEST_MANIFEST_NAME: &str = "_latest.manifest";
const VERSIONS_DIR: &str = "_versions";
const INDICES_DIR: &str = "_indices";
pub(crate) const DELETION_DIRS: &str = "_deletions";
const DATA_DIR: &str = "data";
pub(crate) const DEFAULT_INDEX_CACHE_SIZE: usize = 256;
pub(crate) const DEFAULT_METADATA_CACHE_SIZE: usize = 256;

/// Lance Dataset
#[derive(Debug, Clone)]
@@ -95,17 +96,6 @@ impl From<&Manifest> for Version {
}
}

/// Create a new [FileWriter] with the related `data_file_path` under `<DATA_DIR>`.
async fn new_file_writer(
object_store: &ObjectStore,
base_dir: &Path,
data_file_path: &str,
schema: &Schema,
) -> Result<FileWriter> {
let full_path = base_dir.child(DATA_DIR).child(data_file_path);
FileWriter::try_new(object_store, &full_path, schema.clone()).await
}

/// Get the manifest file path for a version.
fn manifest_path(base: &Path, version: u64) -> Path {
base.child(VERSIONS_DIR)
@@ -128,6 +118,10 @@ pub struct ReadParams {
///
pub index_cache_size: usize,

/// Metadata cache size for the fragment metadata. If it is zero, the
/// metadata cache is disabled.
pub metadata_cache_size: usize,

/// If present, the dataset will use this shared [`Session`] instead of creating a new one.
///
/// This is useful for sharing the same session across multiple datasets.
@@ -143,6 +137,12 @@ impl ReadParams {
self
}

/// Set the cache size for the file metadata. Set to zero to disable this cache.
pub fn metadata_cache_size(&mut self, cache_size: usize) -> &mut Self {
self.metadata_cache_size = cache_size;
self
}

/// Set a shared session for the datasets.
pub fn session(&mut self, session: Arc<Session>) -> &mut Self {
self.session = Some(session);
@@ -155,6 +155,7 @@ impl Default for ReadParams {
Self {
block_size: None,
index_cache_size: DEFAULT_INDEX_CACHE_SIZE,
metadata_cache_size: DEFAULT_METADATA_CACHE_SIZE,
session: None,
store_options: None,
}
@@ -184,7 +185,10 @@ impl Dataset {
let session = if let Some(session) = params.session.as_ref() {
session.clone()
} else {
Arc::new(Session::new(params.index_cache_size))
Arc::new(Session::new(
params.index_cache_size,
params.metadata_cache_size,
))
};
Self::checkout_manifest(
Arc::new(object_store),
@@ -216,7 +220,10 @@ impl Dataset {
let session = if let Some(session) = params.session.as_ref() {
session.clone()
} else {
Arc::new(Session::new(params.index_cache_size))
Arc::new(Session::new(
params.index_cache_size,
params.metadata_cache_size,
))
};
Self::checkout_manifest(Arc::new(object_store), base_path, &manifest_file, session).await
}
@@ -304,16 +311,7 @@ impl Dataset {
let latest_manifest_path = latest_manifest_path(&base);
let flag_dataset_exists = object_store.exists(&latest_manifest_path).await?;

let mut schema: Schema = Schema::try_from(batches.schema().as_ref())?;
let mut peekable = batches.peekable();
if let Some(batch) = peekable.peek() {
if let Ok(b) = batch {
schema.set_dictionary(b)?;
} else {
return Err(Error::from(batch.as_ref().unwrap_err()));
}
}
schema.validate()?;
let (stream, schema) = reader_to_stream(batches)?;

// Running checks for the different write modes
// create + dataset already exists = error
@@ -393,52 +391,16 @@ impl Dataset {
vec![]
};

let mut writer = None;
let mut batches: Vec<RecordBatch> = Vec::new();
let mut num_rows: usize = 0;
for batch_result in peekable {
let batch: RecordBatch = batch_result?;
batches.push(batch.clone());
num_rows += batch.num_rows();
if num_rows >= params.max_rows_per_group {
// TODO: the max rows per group boundary is not accurately calculated yet.
if writer.is_none() {
writer = {
let file_path = format!("{}.lance", Uuid::new_v4());
let fragment = Fragment::with_file(fragment_id, &file_path, &schema);
fragments.push(fragment);
fragment_id += 1;
Some(new_file_writer(&object_store, &base, &file_path, &schema).await?)
}
};
let object_store = Arc::new(object_store);
let mut new_fragments =
write_fragments(object_store.clone(), &base, &schema, stream, params.clone()).await?;

writer.as_mut().unwrap().write(&batches).await?;
batches = Vec::new();
num_rows = 0;
}
if let Some(w) = writer.as_mut() {
if w.len() >= params.max_rows_per_file {
w.finish().await?;
writer = None;
}
}
// Assign IDs
for fragment in &mut new_fragments {
fragment.id = fragment_id;
fragment_id += 1;
}
if num_rows > 0 {
if writer.is_none() {
writer = {
let file_path = format!("{}.lance", Uuid::new_v4());
let fragment = Fragment::with_file(fragment_id, &file_path, &schema);
fragments.push(fragment);
Some(new_file_writer(&object_store, &base, &file_path, &schema).await?)
}
};
writer.as_mut().unwrap().write(&batches).await?;
}
if let Some(w) = writer.as_mut() {
// Drop the last writer.
w.finish().await?;
drop(writer);
};
fragments.extend(new_fragments);

let mut manifest = Manifest::new(&schema, Arc::new(fragments));
manifest.version = match dataset.as_ref() {
Expand All @@ -464,7 +426,7 @@ impl Dataset {
.await?;

Ok(Self {
object_store: Arc::new(object_store),
object_store,
base,
manifest: Arc::new(manifest.clone()),
session: Arc::new(Session::default()),
@@ -1071,6 +1033,12 @@ mod tests {
.try_collect::<Vec<_>>()
.await
.unwrap();

// The batch size matches the group size.
for batch in &actual_batches {
assert_eq!(batch.num_rows(), 10);
}

// sort
let actual_batch = concat_batches(&schema, &actual_batches).unwrap();
let idx_arr = actual_batch.column_by_name("i").unwrap();
Expand Down Expand Up @@ -1174,6 +1142,48 @@ mod tests {
assert_eq!(&expected_struct_arr, as_struct_array(sorted_arr.as_ref()));
}

#[tokio::test]
async fn test_write_params() {
let test_dir = tempdir().unwrap();
let test_uri = test_dir.path().to_str().unwrap();

let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"i",
DataType::Int32,
false,
)]));
let num_rows: usize = 1_000;
let batches = vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..num_rows as i32))],
)
.unwrap()];

let batches = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());

let write_params = WriteParams {
max_rows_per_file: 100,
max_rows_per_group: 10,
..Default::default()
};
let dataset = Dataset::write(batches, test_uri, Some(write_params))
.await
.unwrap();

assert_eq!(dataset.count_rows().await.unwrap(), num_rows);

let fragments = dataset.get_fragments();
assert_eq!(fragments.len(), 10);
for fragment in &fragments {
assert_eq!(fragment.count_rows().await.unwrap(), 100);
let reader = fragment.open(dataset.schema()).await.unwrap();
assert_eq!(reader.num_batches(), 10);
for i in 0..reader.num_batches() {
assert_eq!(reader.num_rows_in_batch(i), 10);
}
}
}

#[tokio::test]
async fn test_write_manifest() {
let test_dir = tempdir().unwrap();
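For comparison with test_write_params above, the same write knobs are reachable from the Python API. This is a hedged sketch: the max_rows_per_file and max_rows_per_group keyword names are assumed to be exposed by lance.write_dataset in this version, and the dataset path is illustrative.

    import lance
    import pyarrow as pa

    table = pa.table({"i": pa.array(range(1_000), type=pa.int32())})

    # Assumed keywords mirroring the Rust WriteParams fields exercised by the test.
    lance.write_dataset(
        table,
        "./write_params_demo.lance",
        max_rows_per_file=100,   # expect 10 fragments of 100 rows each
        max_rows_per_group=10,   # expect 10 batches of 10 rows per fragment
    )

    ds = lance.dataset("./write_params_demo.lance")
    assert ds.count_rows() == 1_000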
