Skip to content

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
westonpace committed Sep 20, 2024
1 parent dc96273 commit 8a7806b
Show file tree
Hide file tree
Showing 11 changed files with 169 additions and 98 deletions.
12 changes: 9 additions & 3 deletions python/src/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use arrow_schema::Schema as ArrowSchema;
use bytes::Bytes;
use futures::stream::StreamExt;
use lance::io::{ObjectStore, RecordBatchStream};
use lance_core::cache::FileMetadataCache;
use lance_encoding::decoder::{DecoderMiddlewareChain, FilterExpression};
use lance_file::{
v2::{
Expand Down Expand Up @@ -331,9 +332,14 @@ impl LanceFileReader {
},
);
let file = scheduler.open_file(&path).await.infer_error()?;
let inner = FileReader::try_open(file, None, Arc::<DecoderMiddlewareChain>::default())
.await
.infer_error()?;
let inner = FileReader::try_open(
file,
None,
Arc::<DecoderMiddlewareChain>::default(),
&FileMetadataCache::no_cache(),
)
.await
.infer_error()?;
Ok(Self {
inner: Arc::new(inner),
})
Expand Down
39 changes: 24 additions & 15 deletions rust/lance-datafusion/src/substrait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,22 +271,31 @@ pub async fn parse_substrait(expr: &[u8], input_schema: Arc<Schema>) -> Result<E
}),
}?;

let (substrait_schema, input_schema, index_mapping) =
remove_extension_types(envelope.base_schema.as_ref().unwrap(), input_schema.clone())?;
let (substrait_schema, input_schema) =
if envelope.base_schema.as_ref().unwrap().r#struct.is_some() {
let (substrait_schema, input_schema, index_mapping) = remove_extension_types(
envelope.base_schema.as_ref().unwrap(),
input_schema.clone(),
)?;

if substrait_schema.r#struct.as_ref().unwrap().types.len()
!= envelope
.base_schema
.as_ref()
.unwrap()
.r#struct
.as_ref()
.unwrap()
.types
.len()
{
remap_expr_references(&mut expr, &index_mapping)?;
}
if substrait_schema.r#struct.as_ref().unwrap().types.len()
!= envelope
.base_schema
.as_ref()
.unwrap()
.r#struct
.as_ref()
.unwrap()
.types
.len()
{
remap_expr_references(&mut expr, &index_mapping)?;
}

(substrait_schema, input_schema)
} else {
(envelope.base_schema.as_ref().unwrap().clone(), input_schema)
};

// Datafusion's substrait consumer only supports Plan (not ExtendedExpression) and so
// we need to create a dummy plan with a single project node
Expand Down
77 changes: 50 additions & 27 deletions rust/lance-encoding-datafusion/src/zone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use std::{
collections::{HashMap, VecDeque},
ops::Range,
sync::Arc,
sync::{Arc, Mutex},
};

use arrow_array::{cast::AsArray, types::UInt32Type, ArrayRef, RecordBatch, UInt32Array};
Expand Down Expand Up @@ -148,8 +148,6 @@ pub(crate) fn extract_zone_info(
);
let (position, size) =
col_info.buffer_offsets_and_sizes[zone_map_buffer.buffer_index as usize];
let mut new_col_info = col_info.as_ref().clone();
new_col_info.encoding = *inner;
let column = path_to_expr(cur_path);
let unloaded_pushdown = UnloadedPushdown {
data_type: data_type.clone(),
Expand All @@ -158,7 +156,10 @@ pub(crate) fn extract_zone_info(
size,
};
*result_ref = Some((rows_per_zone, unloaded_pushdown));
col_info

let mut col_info = col_info.as_ref().clone();
col_info.encoding = *inner;
Arc::new(col_info)
}
_ => col_info,
}
Expand All @@ -185,6 +186,13 @@ struct ZoneMap {
items: Vec<(Expr, NullableInterval)>,
}

#[derive(Debug)]
struct InitializedState {
zone_maps: Vec<ZoneMap>,
filter: Option<Expr>,
df_schema: Option<DFSchemaRef>,
}

/// A top level scheduler that refines the requested range based on
/// pushdown filtering with zone maps
#[derive(Debug)]
Expand All @@ -195,9 +203,7 @@ pub struct ZoneMapsFieldScheduler {
pushdown_buffers: HashMap<u32, UnloadedPushdown>,
rows_per_zone: u32,
num_rows: u64,
zone_maps: Vec<ZoneMap>,
filter: Option<Expr>,
df_schema: Option<DFSchemaRef>,
initialized_state: Mutex<Option<InitializedState>>,
}

impl ZoneMapsFieldScheduler {
Expand All @@ -215,9 +221,7 @@ impl ZoneMapsFieldScheduler {
rows_per_zone,
num_rows,
// These are set during initialization
zone_maps: Vec::new(),
filter: None,
df_schema: None,
initialized_state: Mutex::new(None),
}
}

Expand All @@ -233,13 +237,14 @@ impl ZoneMapsFieldScheduler {
.map(|pushdown| pushdown.position..pushdown.position + pushdown.size)
.collect();
let buffers = io.submit_request(ranges, 0).await?;
let maps = buffers
.into_iter()
.zip(pushdowns.iter())
.map(|(buffer, pushdown)| {
self.parse_zone(buffer, &pushdown.data_type, &pushdown.column)
})
.collect::<Result<Vec<_>>>()?;
let mut maps = Vec::new();
for (buffer, pushdown) in buffers.into_iter().zip(pushdowns.iter()) {
// There's no point in running this in parallel since it's actually synchronous
let map = self
.parse_zone(buffer, &pushdown.data_type, &pushdown.column)
.await?;
maps.push(map);
}
// A this point each item in `maps` is a vector of guarantees for a single field
// We need to transpose this so that each item is a vector of guarantees for a single zone
let zone_maps = transpose2(maps)
Expand Down Expand Up @@ -269,33 +274,47 @@ impl ZoneMapsFieldScheduler {
}

async fn do_initialize(
&mut self,
&self,
io: &dyn EncodingsIo,
cache: &FileMetadataCache,
filter: &FilterExpression,
) -> Result<()> {
if filter.is_noop() {
return Ok(());
}

let arrow_schema = ArrowSchema::from(self.schema.as_ref());
let df_schema = DFSchema::try_from(arrow_schema.clone())?;
let df_filter = filter.substrait_to_df(Arc::new(arrow_schema))?;

let columns = Planner::column_names_in_expr(&df_filter);
let referenced_schema = self.schema.project(&columns)?;

self.df_schema = Some(Arc::new(df_schema));
self.zone_maps = self.load_maps(io, cache, &referenced_schema).await?;
self.filter = Some(df_filter);
let df_schema = Some(Arc::new(df_schema));
let zone_maps = self.load_maps(io, cache, &referenced_schema).await?;
let filter = Some(df_filter);

let state = InitializedState {
zone_maps,
filter,
df_schema,
};
let mut initialized_state = self.initialized_state.lock().unwrap();
*initialized_state = Some(state);
Ok(())
}

fn create_filter(&self) -> Result<impl Fn(u64) -> bool + '_> {
Ok(move |zone_idx| {
let zone_map = &self.zone_maps[zone_idx as usize];
let state = self.initialized_state.lock().unwrap();
let state = state.as_ref().unwrap();
let zone_map = &state.zone_maps[zone_idx as usize];
let props = ExecutionProps::new();
let context =
SimplifyContext::new(&props).with_schema(self.df_schema.as_ref().unwrap().clone());
SimplifyContext::new(&props).with_schema(state.df_schema.as_ref().unwrap().clone());
let mut simplifier = ExprSimplifier::new(context);
simplifier = simplifier.with_guarantees(zone_map.items.clone());
match simplifier.simplify(self.filter.as_ref().unwrap().clone()) {
match simplifier.simplify(state.filter.as_ref().unwrap().clone()) {
Ok(expr) => match expr {
// Predicate, given guarantees, is always false, we can skip the zone
Expr::Literal(ScalarValue::Boolean(Some(false))) => false,
Expand Down Expand Up @@ -350,7 +369,7 @@ impl ZoneMapsFieldScheduler {
guarantees
}

fn parse_zone(
async fn parse_zone(
&self,
buffer: Bytes,
data_type: &DataType,
Expand All @@ -367,7 +386,8 @@ impl ZoneMapsFieldScheduler {
&zone_maps_batch,
&FilterExpression::no_filter(),
Arc::<DecoderMiddlewareChain>::default(),
)?;
)
.await?;

Ok(Self::extract_guarantees(
&zone_maps_batch,
Expand Down Expand Up @@ -419,7 +439,7 @@ impl SchedulingJob for EmptySchedulingJob {

impl FieldScheduler for ZoneMapsFieldScheduler {
fn initialize<'a>(
&'a mut self,
&'a self,
filter: &'a FilterExpression,
context: &'a SchedulerContext,
) -> BoxFuture<'a, Result<()>> {
Expand All @@ -435,6 +455,9 @@ impl FieldScheduler for ZoneMapsFieldScheduler {
ranges: &[std::ops::Range<u64>],
filter: &FilterExpression,
) -> Result<Box<dyn SchedulingJob + 'a>> {
if filter.is_noop() {
return self.inner.schedule_ranges(ranges, filter);
}
let zone_filter_fn = self.create_filter()?;
let zone_filter = ZoneMapsFilter::new(zone_filter_fn, self.rows_per_zone as u64);
let ranges = zone_filter.refine_ranges(ranges);
Expand Down
66 changes: 35 additions & 31 deletions rust/lance-encoding/benches/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,13 @@ fn bench_decode(c: &mut Criterion) {
let func_name = format!("{:?}", data_type).to_lowercase();
group.bench_function(func_name, |b| {
b.iter(|| {
let batch = lance_encoding::decoder::decode_batch(
&encoded,
&FilterExpression::no_filter(),
Arc::<DecoderMiddlewareChain>::default(),
)
.unwrap();
let batch = rt
.block_on(lance_encoding::decoder::decode_batch(
&encoded,
&FilterExpression::no_filter(),
Arc::<DecoderMiddlewareChain>::default(),
))
.unwrap();
assert_eq!(data.num_rows(), batch.num_rows());
})
});
Expand Down Expand Up @@ -122,12 +123,13 @@ fn bench_decode_fsl(c: &mut Criterion) {
let func_name = format!("{:?}", data_type).to_lowercase();
group.bench_function(func_name, |b| {
b.iter(|| {
let batch = lance_encoding::decoder::decode_batch(
&encoded,
&FilterExpression::no_filter(),
Arc::<DecoderMiddlewareChain>::default(),
)
.unwrap();
let batch = rt
.block_on(lance_encoding::decoder::decode_batch(
&encoded,
&FilterExpression::no_filter(),
Arc::<DecoderMiddlewareChain>::default(),
))
.unwrap();
assert_eq!(data.num_rows(), batch.num_rows());
})
});
Expand Down Expand Up @@ -177,12 +179,13 @@ fn bench_decode_str_with_dict_encoding(c: &mut Criterion) {
let func_name = format!("{:?}", data_type).to_lowercase();
group.bench_function(func_name, |b| {
b.iter(|| {
let batch = lance_encoding::decoder::decode_batch(
&encoded,
&FilterExpression::no_filter(),
Arc::<DecoderMiddlewareChain>::default(),
)
.unwrap();
let batch = rt
.block_on(lance_encoding::decoder::decode_batch(
&encoded,
&FilterExpression::no_filter(),
Arc::<DecoderMiddlewareChain>::default(),
))
.unwrap();
assert_eq!(data.num_rows(), batch.num_rows());
})
});
Expand Down Expand Up @@ -215,7 +218,6 @@ fn bench_decode_packed_struct(c: &mut Criterion) {
.iter()
.map(|field| {
if matches!(field.data_type(), &DataType::Struct(_)) {
println!("Match");
let mut metadata = HashMap::new();
metadata.insert("packed".to_string(), "true".to_string());
let field =
Expand Down Expand Up @@ -246,12 +248,13 @@ fn bench_decode_packed_struct(c: &mut Criterion) {
let func_name = "struct";
group.bench_function(func_name, |b| {
b.iter(|| {
let batch = lance_encoding::decoder::decode_batch(
&encoded,
&FilterExpression::no_filter(),
Arc::<DecoderMiddlewareChain>::default(),
)
.unwrap();
let batch = rt
.block_on(lance_encoding::decoder::decode_batch(
&encoded,
&FilterExpression::no_filter(),
Arc::<DecoderMiddlewareChain>::default(),
))
.unwrap();
assert_eq!(data.num_rows(), batch.num_rows());
})
});
Expand Down Expand Up @@ -293,12 +296,13 @@ fn bench_decode_str_with_fixed_size_binary_encoding(c: &mut Criterion) {
let func_name = "fixed-utf8".to_string();
group.bench_function(func_name, |b| {
b.iter(|| {
let batch = lance_encoding::decoder::decode_batch(
&encoded,
&FilterExpression::no_filter(),
Arc::<DecoderMiddlewareChain>::default(),
)
.unwrap();
let batch = rt
.block_on(lance_encoding::decoder::decode_batch(
&encoded,
&FilterExpression::no_filter(),
Arc::<DecoderMiddlewareChain>::default(),
))
.unwrap();
assert_eq!(data.num_rows(), batch.num_rows());
})
});
Expand Down
Loading

0 comments on commit 8a7806b

Please sign in to comment.