diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index 6d4f0b45..b94e39b4 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -30,6 +30,9 @@ pub enum CoreError { #[error("Config error: {0}")] Config(#[from] ConfigError), + #[error("Data type error: {0}")] + DataType(String), + #[error("File group error: {0}")] FileGroup(String), diff --git a/crates/core/src/expr/filter.rs b/crates/core/src/expr/filter.rs new file mode 100644 index 00000000..cf7ee5d8 --- /dev/null +++ b/crates/core/src/expr/filter.rs @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::error::CoreError; +use crate::expr::ExprOperator; +use crate::Result; +use std::str::FromStr; + +#[derive(Debug, Clone)] +pub struct Filter { + pub field_name: String, + pub operator: ExprOperator, + pub field_value: String, +} + +impl Filter {} + +impl TryFrom<(&str, &str, &str)> for Filter { + type Error = CoreError; + + fn try_from(binary_expr_tuple: (&str, &str, &str)) -> Result { + let (field_name, operator_str, field_value) = binary_expr_tuple; + + let field_name = field_name.to_string(); + + let operator = ExprOperator::from_str(operator_str)?; + + let field_value = field_value.to_string(); + + Ok(Filter { + field_name, + operator, + field_value, + }) + } +} diff --git a/crates/core/src/expr/mod.rs b/crates/core/src/expr/mod.rs new file mode 100644 index 00000000..d592c3f6 --- /dev/null +++ b/crates/core/src/expr/mod.rs @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +pub mod filter; + +use crate::error::CoreError; +use crate::error::CoreError::Unsupported; + +use std::cmp::PartialEq; +use std::fmt::{Display, Formatter, Result as FmtResult}; +use std::str::FromStr; + +/// An operator that represents a comparison operation used in a partition filter expression. 
+#[derive(Debug, Clone, Copy, PartialEq)] +pub enum ExprOperator { + Eq, + Ne, + Lt, + Lte, + Gt, + Gte, +} + +impl Display for ExprOperator { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + // Binary Operators + ExprOperator::Eq => write!(f, "="), + ExprOperator::Ne => write!(f, "!="), + ExprOperator::Lt => write!(f, "<"), + ExprOperator::Lte => write!(f, "<="), + ExprOperator::Gt => write!(f, ">"), + ExprOperator::Gte => write!(f, ">="), + } + } +} + +impl ExprOperator { + pub const TOKEN_OP_PAIRS: [(&'static str, ExprOperator); 6] = [ + ("=", ExprOperator::Eq), + ("!=", ExprOperator::Ne), + ("<", ExprOperator::Lt), + ("<=", ExprOperator::Lte), + (">", ExprOperator::Gt), + (">=", ExprOperator::Gte), + ]; + + /// Negates the operator. + pub fn negate(&self) -> Option { + match self { + ExprOperator::Eq => Some(ExprOperator::Ne), + ExprOperator::Ne => Some(ExprOperator::Eq), + ExprOperator::Lt => Some(ExprOperator::Gte), + ExprOperator::Lte => Some(ExprOperator::Gt), + ExprOperator::Gt => Some(ExprOperator::Lte), + ExprOperator::Gte => Some(ExprOperator::Lt), + } + } +} + +impl FromStr for ExprOperator { + type Err = CoreError; + + fn from_str(s: &str) -> Result { + ExprOperator::TOKEN_OP_PAIRS + .iter() + .find_map(|&(token, op)| { + if token.eq_ignore_ascii_case(s) { + Some(op) + } else { + None + } + }) + .ok_or_else(|| Unsupported(format!("Unsupported operator: {}", s))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_operator_from_str() { + assert_eq!(ExprOperator::from_str("=").unwrap(), ExprOperator::Eq); + assert_eq!(ExprOperator::from_str("!=").unwrap(), ExprOperator::Ne); + assert_eq!(ExprOperator::from_str("<").unwrap(), ExprOperator::Lt); + assert_eq!(ExprOperator::from_str("<=").unwrap(), ExprOperator::Lte); + assert_eq!(ExprOperator::from_str(">").unwrap(), ExprOperator::Gt); + assert_eq!(ExprOperator::from_str(">=").unwrap(), ExprOperator::Gte); + assert!(ExprOperator::from_str("??").is_err()); + } + + #[test] + fn test_operator_display() { + assert_eq!(ExprOperator::Eq.to_string(), "="); + assert_eq!(ExprOperator::Ne.to_string(), "!="); + assert_eq!(ExprOperator::Lt.to_string(), "<"); + assert_eq!(ExprOperator::Lte.to_string(), "<="); + assert_eq!(ExprOperator::Gt.to_string(), ">"); + assert_eq!(ExprOperator::Gte.to_string(), ">="); + } +} diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 42079a17..c9ce8154 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -45,6 +45,7 @@ pub mod config; pub mod error; +pub mod expr; pub mod file_group; pub mod storage; pub mod table; diff --git a/crates/core/src/table/fs_view.rs b/crates/core/src/table/fs_view.rs index 092b352a..6ca5d9b1 100644 --- a/crates/core/src/table/fs_view.rs +++ b/crates/core/src/table/fs_view.rs @@ -180,10 +180,12 @@ impl FileSystemView { mod tests { use crate::config::table::HudiTableConfig; use crate::config::HudiConfigs; + use crate::expr::filter::Filter; use crate::storage::Storage; use crate::table::fs_view::FileSystemView; use crate::table::partition::PartitionPruner; use crate::table::Table; + use hudi_tests::TestTable; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -298,12 +300,16 @@ mod tests { .await .unwrap(); let partition_schema = hudi_table.get_partition_schema().await.unwrap(); + + let filter_lt_20 = Filter::try_from(("byteField", "<", "20")).unwrap(); + let filter_eq_300 = Filter::try_from(("shortField", "=", "300")).unwrap(); let partition_pruner = PartitionPruner::new( - &[("byteField", "<", 
"20"), ("shortField", "=", "300")], + &[filter_lt_20, filter_eq_300], &partition_schema, hudi_table.hudi_configs.as_ref(), ) .unwrap(); + let file_slices = fs_view .get_file_slices_as_of("20240418173235694", &partition_pruner, excludes) .await diff --git a/crates/core/src/table/mod.rs b/crates/core/src/table/mod.rs index 9e7a495c..f83aa91a 100644 --- a/crates/core/src/table/mod.rs +++ b/crates/core/src/table/mod.rs @@ -84,18 +84,17 @@ //! } //! ``` -use std::collections::{HashMap, HashSet}; -use std::sync::Arc; - -use arrow::record_batch::RecordBatch; -use arrow_schema::{Field, Schema}; -use url::Url; +pub mod builder; +mod fs_view; +pub mod partition; +mod timeline; use crate::config::read::HudiReadConfig::AsOfTimestamp; use crate::config::table::HudiTableConfig; use crate::config::table::HudiTableConfig::PartitionFields; use crate::config::HudiConfigs; use crate::error::CoreError; +use crate::expr::filter::Filter; use crate::file_group::reader::FileGroupReader; use crate::file_group::FileSlice; use crate::table::builder::TableBuilder; @@ -104,10 +103,11 @@ use crate::table::partition::PartitionPruner; use crate::table::timeline::Timeline; use crate::Result; -pub mod builder; -mod fs_view; -mod partition; -mod timeline; +use arrow::record_batch::RecordBatch; +use arrow_schema::{Field, Schema}; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; +use url::Url; /// Hudi Table in-memory #[derive(Clone, Debug)] @@ -195,10 +195,11 @@ impl Table { /// The file slices are split into `n` chunks. /// /// If the [AsOfTimestamp] configuration is set, the file slices at the specified timestamp will be returned. + /// pub async fn get_file_slices_splits( &self, n: usize, - filters: &[(&str, &str, &str)], + filters: &[Filter], ) -> Result>> { let file_slices = self.get_file_slices(filters).await?; if file_slices.is_empty() { @@ -217,7 +218,7 @@ impl Table { /// Get all the [FileSlice]s in the table. /// /// If the [AsOfTimestamp] configuration is set, the file slices at the specified timestamp will be returned. - pub async fn get_file_slices(&self, filters: &[(&str, &str, &str)]) -> Result> { + pub async fn get_file_slices(&self, filters: &[Filter]) -> Result> { if let Some(timestamp) = self.hudi_configs.try_get(AsOfTimestamp) { self.get_file_slices_as_of(timestamp.to::().as_str(), filters) .await @@ -232,7 +233,7 @@ impl Table { async fn get_file_slices_as_of( &self, timestamp: &str, - filters: &[(&str, &str, &str)], + filters: &[Filter], ) -> Result> { let excludes = self.timeline.get_replaced_file_groups().await?; let partition_schema = self.get_partition_schema().await?; @@ -250,7 +251,7 @@ impl Table { /// Get all the latest records in the table. /// /// If the [AsOfTimestamp] configuration is set, the records at the specified timestamp will be returned. 
- pub async fn read_snapshot(&self, filters: &[(&str, &str, &str)]) -> Result> { + pub async fn read_snapshot(&self, filters: &[Filter]) -> Result> { if let Some(timestamp) = self.hudi_configs.try_get(AsOfTimestamp) { self.read_snapshot_as_of(timestamp.to::().as_str(), filters) .await @@ -265,7 +266,7 @@ impl Table { async fn read_snapshot_as_of( &self, timestamp: &str, - filters: &[(&str, &str, &str)], + filters: &[Filter], ) -> Result> { let file_slices = self.get_file_slices_as_of(timestamp, filters).await?; let fg_reader = self.create_file_group_reader(); @@ -298,12 +299,13 @@ impl Table { mod tests { use super::*; use arrow_array::StringArray; + use hudi_tests::{assert_not, TestTable}; + use std::collections::HashSet; use std::fs::canonicalize; use std::path::PathBuf; use std::{env, panic}; - use hudi_tests::{assert_not, TestTable}; - + use crate::config::read::HudiReadConfig::AsOfTimestamp; use crate::config::table::HudiTableConfig::{ BaseFileFormat, Checksum, DatabaseName, DropsPartitionFields, IsHiveStylePartitioning, IsPartitionPathUrlencoded, KeyGeneratorClass, PartitionFields, PopulatesMetaFields, @@ -313,6 +315,7 @@ mod tests { use crate::config::HUDI_CONF_DIR; use crate::storage::util::join_url_segments; use crate::storage::Storage; + use crate::table::Filter; /// Test helper to create a new `Table` instance without validating the configuration. /// @@ -333,10 +336,7 @@ mod tests { } /// Test helper to get relative file paths from the table with filters. - async fn get_file_paths_with_filters( - table: &Table, - filters: &[(&str, &str, &str)], - ) -> Result> { + async fn get_file_paths_with_filters(table: &Table, filters: &[Filter]) -> Result> { let mut file_paths = Vec::new(); for f in table.get_file_slices(filters).await? { file_paths.push(f.base_file_path().to_string()); @@ -733,8 +733,11 @@ mod tests { .collect::>(); assert_eq!(actual, expected); - let partition_filters = &[("byteField", ">=", "10"), ("byteField", "<", "30")]; - let actual = get_file_paths_with_filters(&hudi_table, partition_filters) + let filter_ge_10 = Filter::try_from(("byteField", ">=", "10")).unwrap(); + + let filter_lt_30 = Filter::try_from(("byteField", "<", "30")).unwrap(); + + let actual = get_file_paths_with_filters(&hudi_table, &[filter_ge_10, filter_lt_30]) .await .unwrap() .into_iter() @@ -748,8 +751,8 @@ mod tests { .collect::>(); assert_eq!(actual, expected); - let partition_filters = &[("byteField", ">", "30")]; - let actual = get_file_paths_with_filters(&hudi_table, partition_filters) + let filter_gt_30 = Filter::try_from(("byteField", ">", "30")).unwrap(); + let actual = get_file_paths_with_filters(&hudi_table, &[filter_gt_30]) .await .unwrap() .into_iter() @@ -780,16 +783,16 @@ mod tests { .collect::>(); assert_eq!(actual, expected); - let partition_filters = &[ - ("byteField", ">=", "10"), - ("byteField", "<", "20"), - ("shortField", "!=", "100"), - ]; - let actual = get_file_paths_with_filters(&hudi_table, partition_filters) - .await - .unwrap() - .into_iter() - .collect::>(); + let filter_gte_10 = Filter::try_from(("byteField", ">=", "10")).unwrap(); + let filter_lt_20 = Filter::try_from(("byteField", "<", "20")).unwrap(); + let filter_ne_100 = Filter::try_from(("shortField", "!=", "100")).unwrap(); + + let actual = + get_file_paths_with_filters(&hudi_table, &[filter_gte_10, filter_lt_20, filter_ne_100]) + .await + .unwrap() + .into_iter() + .collect::>(); let expected = [ "byteField=10/shortField=300/a22e8257-e249-45e9-ba46-115bc85adcba-0_0-161-223_20240418173235694.parquet", ] @@ 
-797,9 +800,10 @@ mod tests { .into_iter() .collect::<HashSet<_>>(); assert_eq!(actual, expected); + let filter_gt_20 = Filter::try_from(("byteField", ">", "20")).unwrap(); + let filter_eq_300 = Filter::try_from(("shortField", "=", "300")).unwrap(); - let partition_filters = &[("byteField", ">", "20"), ("shortField", "=", "300")]; - let actual = get_file_paths_with_filters(&hudi_table, partition_filters) + let actual = get_file_paths_with_filters(&hudi_table, &[filter_gt_20, filter_eq_300]) .await .unwrap() .into_iter() @@ -812,12 +816,15 @@ async fn hudi_table_read_snapshot_for_complex_keygen_hive_style() { let base_url = TestTable::V6ComplexkeygenHivestyle.url(); let hudi_table = Table::new(base_url.path()).await.unwrap(); - let partition_filters = &[ - ("byteField", ">=", "10"), - ("byteField", "<", "20"), - ("shortField", "!=", "100"), - ]; - let records = hudi_table.read_snapshot(partition_filters).await.unwrap(); + + let filter_gte_10 = Filter::try_from(("byteField", ">=", "10")).unwrap(); + let filter_lt_20 = Filter::try_from(("byteField", "<", "20")).unwrap(); + let filter_ne_100 = Filter::try_from(("shortField", "!=", "100")).unwrap(); + + let records = hudi_table + .read_snapshot(&[filter_gte_10, filter_lt_20, filter_ne_100]) + .await + .unwrap(); assert_eq!(records.len(), 1); assert_eq!(records[0].num_rows(), 2); let actual_partition_paths: HashSet<&str> = HashSet::from_iter( diff --git a/crates/core/src/table/partition.rs b/crates/core/src/table/partition.rs index ca41f731..14f6b30f 100644 --- a/crates/core/src/table/partition.rs +++ b/crates/core/src/table/partition.rs @@ -18,16 +18,19 @@ */ use crate::config::table::HudiTableConfig; use crate::config::HudiConfigs; -use crate::error::CoreError; -use crate::error::CoreError::{InvalidPartitionPath, Unsupported}; +use crate::error::CoreError::InvalidPartitionPath; +use crate::expr::filter::Filter; +use crate::expr::ExprOperator; use crate::Result; + use arrow_array::{ArrayRef, Scalar, StringArray}; use arrow_cast::{cast_with_options, CastOptions}; use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq}; -use arrow_schema::{DataType, Field, Schema}; -use std::cmp::PartialEq; +use arrow_schema::Schema; +use arrow_schema::{DataType, Field}; + +use crate::table::CoreError; use std::collections::HashMap; -use std::str::FromStr; use std::sync::Arc; /// A partition pruner that filters partitions based on the partition path and its filters. 
@@ -41,13 +44,13 @@ pub struct PartitionPruner { impl PartitionPruner { pub fn new( - and_filters: &[(&str, &str, &str)], + and_filters: &[Filter], partition_schema: &Schema, hudi_configs: &HudiConfigs, ) -> Result { let and_filters = and_filters .iter() - .map(|filter| PartitionFilter::try_from((*filter, partition_schema))) + .map(|filter| PartitionFilter::try_from((filter.clone(), partition_schema))) .collect::>>()?; let schema = Arc::new(partition_schema.clone()); @@ -91,12 +94,12 @@ impl PartitionPruner { match segments.get(filter.field.name()) { Some(segment_value) => { let comparison_result = match filter.operator { - Operator::Eq => eq(segment_value, &filter.value), - Operator::Ne => neq(segment_value, &filter.value), - Operator::Lt => lt(segment_value, &filter.value), - Operator::Lte => lt_eq(segment_value, &filter.value), - Operator::Gt => gt(segment_value, &filter.value), - Operator::Gte => gt_eq(segment_value, &filter.value), + ExprOperator::Eq => eq(segment_value, &filter.value), + ExprOperator::Ne => neq(segment_value, &filter.value), + ExprOperator::Lt => lt(segment_value, &filter.value), + ExprOperator::Lte => lt_eq(segment_value, &filter.value), + ExprOperator::Gt => gt(segment_value, &filter.value), + ExprOperator::Gte => gt_eq(segment_value, &filter.value), }; match comparison_result { @@ -155,60 +158,24 @@ impl PartitionPruner { } } -/// An operator that represents a comparison operation used in a partition filter expression. -#[derive(Debug, Clone, Copy, PartialEq)] -enum Operator { - Eq, - Ne, - Lt, - Lte, - Gt, - Gte, -} - -impl Operator { - const TOKEN_OP_PAIRS: [(&'static str, Operator); 6] = [ - ("=", Operator::Eq), - ("!=", Operator::Ne), - ("<", Operator::Lt), - ("<=", Operator::Lte), - (">", Operator::Gt), - (">=", Operator::Gte), - ]; -} - -impl FromStr for Operator { - type Err = CoreError; - - fn from_str(s: &str) -> Result { - Operator::TOKEN_OP_PAIRS - .iter() - .find_map(|&(token, op)| if token == s { Some(op) } else { None }) - .ok_or(Unsupported(format!("Unsupported operator: {}", s))) - } -} - /// A partition filter that represents a filter expression for partition pruning. 
#[derive(Debug, Clone)] pub struct PartitionFilter { - field: Field, - operator: Operator, - value: Scalar, + pub field: Field, + pub operator: ExprOperator, + pub value: Scalar, } -impl TryFrom<((&str, &str, &str), &Schema)> for PartitionFilter { +impl TryFrom<(Filter, &Schema)> for PartitionFilter { type Error = CoreError; - fn try_from( - (filter, partition_schema): ((&str, &str, &str), &Schema), - ) -> Result { - let (field_name, operator_str, value_str) = filter; + fn try_from((filter, partition_schema): (Filter, &Schema)) -> Result { + let field: &Field = partition_schema + .field_with_name(&filter.field_name) + .map_err(|_| InvalidPartitionPath("Partition path should be in schema.".to_string()))?; - let field: &Field = partition_schema.field_with_name(field_name)?; - - let operator = Operator::from_str(operator_str)?; - - let value = &[value_str]; + let operator = filter.operator; + let value = &[filter.field_value.as_str()]; let value = Self::cast_value(value, field.data_type())?; let field = field.clone(); @@ -221,7 +188,7 @@ impl TryFrom<((&str, &str, &str), &Schema)> for PartitionFilter { } impl PartitionFilter { - fn cast_value(value: &[&str; 1], data_type: &DataType) -> Result> { + pub fn cast_value(value: &[&str; 1], data_type: &DataType) -> Result> { let cast_options = CastOptions { safe: false, format_options: Default::default(), @@ -229,11 +196,11 @@ impl PartitionFilter { let value = StringArray::from(Vec::from(value)); - Ok(Scalar::new(cast_with_options( - &value, - data_type, - &cast_options, - )?)) + Ok(Scalar::new( + cast_with_options(&value, data_type, &cast_options).map_err(|e| { + CoreError::DataType(format!("Unable to cast {:?}: {:?}", data_type, e)) + })?, + )) } } @@ -243,8 +210,9 @@ mod tests { use crate::config::table::HudiTableConfig::{ IsHiveStylePartitioning, IsPartitionPathUrlencoded, }; + use arrow::datatypes::{DataType, Field, Schema}; - use arrow_array::{Array, Datum}; + use arrow_array::Date32Array; use hudi_tests::assert_not; use std::str::FromStr; @@ -256,90 +224,6 @@ mod tests { ]) } - #[test] - fn test_partition_filter_try_from_valid() { - let schema = create_test_schema(); - let filter_tuple = ("date", "=", "2023-01-01"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_ok()); - let filter = filter.unwrap(); - assert_eq!(filter.field.name(), "date"); - assert_eq!(filter.operator, Operator::Eq); - assert_eq!(filter.value.get().0.len(), 1); - - let filter_tuple = ("category", "!=", "foo"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_ok()); - let filter = filter.unwrap(); - assert_eq!(filter.field.name(), "category"); - assert_eq!(filter.operator, Operator::Ne); - assert_eq!(filter.value.get().0.len(), 1); - assert_eq!( - StringArray::from(filter.value.into_inner().to_data()).value(0), - "foo" - ) - } - - #[test] - fn test_partition_filter_try_from_invalid_field() { - let schema = create_test_schema(); - let filter_tuple = ("invalid_field", "=", "2023-01-01"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_err()); - assert!(filter - .unwrap_err() - .to_string() - .contains("Unable to get field named")); - } - - #[test] - fn test_partition_filter_try_from_invalid_operator() { - let schema = create_test_schema(); - let filter_tuple = ("date", "??", "2023-01-01"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_err()); - assert!(filter - .unwrap_err() - .to_string() - .contains("Unsupported 
operator: ??")); - } - - #[test] - fn test_partition_filter_try_from_invalid_value() { - let schema = create_test_schema(); - let filter_tuple = ("count", "=", "not_a_number"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_err()); - assert!(filter - .unwrap_err() - .to_string() - .contains("Cannot cast string")); - } - - #[test] - fn test_partition_filter_try_from_all_operators() { - let schema = create_test_schema(); - for (op, _) in Operator::TOKEN_OP_PAIRS { - let filter_tuple = ("count", op, "10"); - let filter = PartitionFilter::try_from((filter_tuple, &schema)); - assert!(filter.is_ok(), "Failed for operator: {}", op); - let filter = filter.unwrap(); - assert_eq!(filter.field.name(), "count"); - assert_eq!(filter.operator, Operator::from_str(op).unwrap()); - } - } - - #[test] - fn test_operator_from_str() { - assert_eq!(Operator::from_str("=").unwrap(), Operator::Eq); - assert_eq!(Operator::from_str("!=").unwrap(), Operator::Ne); - assert_eq!(Operator::from_str("<").unwrap(), Operator::Lt); - assert_eq!(Operator::from_str("<=").unwrap(), Operator::Lte); - assert_eq!(Operator::from_str(">").unwrap(), Operator::Gt); - assert_eq!(Operator::from_str(">=").unwrap(), Operator::Gte); - assert!(Operator::from_str("??").is_err()); - } - fn create_hudi_configs(is_hive_style: bool, is_url_encoded: bool) -> HudiConfigs { HudiConfigs::new([ (IsHiveStylePartitioning, is_hive_style.to_string()), @@ -350,9 +234,11 @@ mod tests { fn test_partition_pruner_new() { let schema = create_test_schema(); let configs = create_hudi_configs(true, false); - let filters = vec![("date", ">", "2023-01-01"), ("category", "=", "A")]; - let pruner = PartitionPruner::new(&filters, &schema, &configs); + let filter_gt_date = Filter::try_from(("date", ">", "2023-01-01")).unwrap(); + let filter_eq_a = Filter::try_from(("category", "=", "A")).unwrap(); + + let pruner = PartitionPruner::new(&[filter_gt_date, filter_eq_a], &schema, &configs); assert!(pruner.is_ok()); let pruner = pruner.unwrap(); @@ -377,8 +263,8 @@ mod tests { let pruner_empty = PartitionPruner::new(&[], &schema, &configs).unwrap(); assert!(pruner_empty.is_empty()); - let pruner_non_empty = - PartitionPruner::new(&[("date", ">", "2023-01-01")], &schema, &configs).unwrap(); + let filter_gt_date = Filter::try_from(("date", ">", "2023-01-01")).unwrap(); + let pruner_non_empty = PartitionPruner::new(&[filter_gt_date], &schema, &configs).unwrap(); assert_not!(pruner_non_empty.is_empty()); } @@ -386,13 +272,17 @@ mod tests { fn test_partition_pruner_should_include() { let schema = create_test_schema(); let configs = create_hudi_configs(true, false); - let filters = vec![ - ("date", ">", "2023-01-01"), - ("category", "=", "A"), - ("count", "<=", "100"), - ]; - let pruner = PartitionPruner::new(&filters, &schema, &configs).unwrap(); + let filter_gt_date = Filter::try_from(("date", ">", "2023-01-01")).unwrap(); + let filter_eq_a = Filter::try_from(("category", "=", "A")).unwrap(); + let filter_lte_100 = Filter::try_from(("count", "<=", "100")).unwrap(); + + let pruner = PartitionPruner::new( + &[filter_gt_date, filter_eq_a, filter_lte_100], + &schema, + &configs, + ) + .unwrap(); assert!(pruner.should_include("date=2023-02-01/category=A/count=10")); assert!(pruner.should_include("date=2023-02-01/category=A/count=100")); @@ -445,4 +335,69 @@ mod tests { let result = pruner.parse_segments("date=2023-02-01/category=A/non_exist_field=10"); assert!(matches!(result.unwrap_err(), InvalidPartitionPath(_))); } + + #[test] + fn 
test_partition_filter_try_from_valid() { + let schema = create_test_schema(); + let filter = Filter { + field_name: "date".to_string(), + operator: ExprOperator::Eq, + field_value: "2023-01-01".to_string(), + }; + + let partition_filter = PartitionFilter::try_from((filter, &schema)).unwrap(); + assert_eq!(partition_filter.field.name(), "date"); + assert_eq!(partition_filter.operator, ExprOperator::Eq); + + let value_inner = partition_filter.value.into_inner(); + + let date_array = value_inner.as_any().downcast_ref::().unwrap(); + + let date_value = date_array.value_as_date(0).unwrap(); + assert_eq!(date_value.to_string(), "2023-01-01"); + } + + #[test] + fn test_partition_filter_try_from_invalid_field() { + let schema = create_test_schema(); + let filter = Filter { + field_name: "invalid_field".to_string(), + operator: ExprOperator::Eq, + field_value: "2023-01-01".to_string(), + }; + let result = PartitionFilter::try_from((filter, &schema)); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Partition path should be in schema.")); + } + + #[test] + fn test_partition_filter_try_from_invalid_value() { + let schema = create_test_schema(); + let filter = Filter { + field_name: "count".to_string(), + operator: ExprOperator::Eq, + field_value: "not_a_number".to_string(), + }; + let result = PartitionFilter::try_from((filter, &schema)); + assert!(result.is_err()); + } + + #[test] + fn test_partition_filter_try_from_all_operators() { + let schema = create_test_schema(); + for (op, _) in ExprOperator::TOKEN_OP_PAIRS { + let filter = Filter { + field_name: "count".to_string(), + operator: ExprOperator::from_str(op).unwrap(), + field_value: "5".to_string(), + }; + let partition_filter = PartitionFilter::try_from((filter, &schema)); + let filter = partition_filter.unwrap(); + assert_eq!(filter.field.name(), "count"); + assert_eq!(filter.operator, ExprOperator::from_str(op).unwrap()); + } + } } diff --git a/crates/datafusion/Cargo.toml b/crates/datafusion/Cargo.toml index 120aa8a1..53bc2e45 100644 --- a/crates/datafusion/Cargo.toml +++ b/crates/datafusion/Cargo.toml @@ -30,6 +30,9 @@ repository.workspace = true [dependencies] hudi-core = { version = "0.3.0", path = "../core", features = ["datafusion"] } # arrow +arrow = { workspace = true } +arrow-array = { workspace = true } +arrow-cast = { workspace = true } arrow-schema = { workspace = true } # datafusion diff --git a/crates/datafusion/src/lib.rs b/crates/datafusion/src/lib.rs index 33f39870..a976a0f7 100644 --- a/crates/datafusion/src/lib.rs +++ b/crates/datafusion/src/lib.rs @@ -17,6 +17,8 @@ * under the License. 
*/ +pub(crate) mod util; + use std::any::Any; use std::collections::HashMap; use std::fmt::Debug; @@ -31,14 +33,16 @@ use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::parquet::ParquetExecBuilder; use datafusion::datasource::physical_plan::FileScanConfig; use datafusion::datasource::TableProvider; +use datafusion::logical_expr::Operator; use datafusion::physical_plan::ExecutionPlan; use datafusion_common::config::TableParquetOptions; use datafusion_common::DFSchema; use datafusion_common::DataFusionError::Execution; use datafusion_common::Result; -use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use datafusion_expr::{CreateExternalTable, Expr, TableProviderFilterPushDown, TableType}; use datafusion_physical_expr::create_physical_expr; +use crate::util::expr::exprs_to_filters; use hudi_core::config::read::HudiReadConfig::InputPartitions; use hudi_core::config::util::empty_options; use hudi_core::storage::util::{get_scheme_authority, parse_uri}; @@ -54,7 +58,7 @@ use hudi_core::table::Table as HudiTable; /// /// use datafusion::error::Result; /// use datafusion::prelude::{DataFrame, SessionContext}; -/// use hudi::HudiDataSource; +/// use hudi_datafusion::HudiDataSource; /// /// // Initialize a new DataFusion session context /// let ctx = SessionContext::new(); @@ -62,7 +66,7 @@ use hudi_core::table::Table as HudiTable; /// // Create a new HudiDataSource with specific read options /// let hudi = HudiDataSource::new_with_options( /// "/tmp/trips_table", -/// [("hoodie.read.as.of.timestamp", "20241122010827898")]).await?; +/// [("hoodie.read.as.of.timestamp", "20241122010827898")]).await?; /// /// // Register the Hudi table with the session context /// ctx.register_table("trips_table", Arc::new(hudi))?; @@ -98,6 +102,42 @@ impl HudiDataSource { .get_or_default(InputPartitions) .to::() } + + /// Check if the given expression can be pushed down to the Hudi table. + /// + /// The expression can be pushed down if it is a binary expression with a supported operator and operands. 
+ fn can_push_down(&self, expr: &Expr) -> bool { + match expr { + Expr::BinaryExpr(binary_expr) => { + let left = &binary_expr.left; + let op = &binary_expr.op; + let right = &binary_expr.right; + self.is_supported_operator(op) + && self.is_supported_operand(left) + && self.is_supported_operand(right) + } + Expr::Not(inner_expr) => { + // Recursively check if the inner expression can be pushed down + self.can_push_down(inner_expr) + } + _ => false, + } + } + + fn is_supported_operator(&self, op: &Operator) -> bool { + matches!( + op, + Operator::Eq | Operator::Gt | Operator::Lt | Operator::GtEq | Operator::LtEq + ) + } + + fn is_supported_operand(&self, expr: &Expr) -> bool { + match expr { + Expr::Column(col) => self.schema().field_with_name(&col.name).is_ok(), + Expr::Literal(_) => true, + _ => false, + } + } } #[async_trait] @@ -129,10 +169,11 @@ impl TableProvider for HudiDataSource { ) -> Result> { self.table.register_storage(state.runtime_env().clone()); + // Convert Datafusion `Expr` to `Filter` + let pushdown_filters = exprs_to_filters(filters); let file_slices = self .table - // TODO: implement supports_filters_pushdown() to pass filters to Hudi table API - .get_file_slices_splits(self.get_input_partitions(), &[]) + .get_file_slices_splits(self.get_input_partitions(), pushdown_filters.as_slice()) .await .map_err(|e| Execution(format!("Failed to get file slices from Hudi table: {}", e)))?; let mut parquet_file_groups: Vec> = Vec::new(); @@ -176,6 +217,22 @@ impl TableProvider for HudiDataSource { Ok(exec_builder.build_arc()) } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + filters + .iter() + .map(|expr| { + if self.can_push_down(expr) { + Ok(TableProviderFilterPushDown::Inexact) + } else { + Ok(TableProviderFilterPushDown::Unsupported) + } + }) + .collect() + } } /// `HudiTableFactory` is responsible for creating and configuring Hudi tables. 
@@ -188,7 +245,8 @@ impl TableProvider for HudiDataSource { /// Creating a new `HudiTableFactory` instance: /// /// ```rust -/// use hudi::HudiTableFactory; +/// use datafusion::prelude::SessionContext; +/// use hudi_datafusion::HudiTableFactory; /// /// // Initialize a new HudiTableFactory /// let factory = HudiTableFactory::new(); @@ -265,12 +323,13 @@ mod tests { use super::*; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::{SessionConfig, SessionContext}; - use datafusion_common::{DataFusionError, ScalarValue}; + use datafusion_common::{Column, DataFusionError, ScalarValue}; use std::fs::canonicalize; use std::path::Path; use std::sync::Arc; use url::Url; + use datafusion::logical_expr::BinaryExpr; use hudi_core::config::read::HudiReadConfig::InputPartitions; use hudi_tests::TestTable::{ V6ComplexkeygenHivestyle, V6Empty, V6Nonpartitioned, V6SimplekeygenHivestyleNoMetafields, @@ -479,4 +538,57 @@ mod tests { verify_data_with_replacecommits(&ctx, &sql, test_table.as_ref()).await } } + + #[tokio::test] + async fn test_supports_filters_pushdown() { + let table_provider = + HudiDataSource::new_with_options(V6Nonpartitioned.path().as_str(), empty_options()) + .await + .unwrap(); + + let expr1 = Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::Column(Column::from_name("name".to_string()))), + op: Operator::Eq, + right: Box::new(Expr::Literal(ScalarValue::Utf8(Some("Alice".to_string())))), + }); + + let expr2 = Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::Column(Column::from_name("intField".to_string()))), + op: Operator::Gt, + right: Box::new(Expr::Literal(ScalarValue::Int32(Some(20000)))), + }); + + let expr3 = Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::Column(Column::from_name( + "nonexistent_column".to_string(), + ))), + op: Operator::Eq, + right: Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), + }); + + let expr4 = Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::Column(Column::from_name("name".to_string()))), + op: Operator::NotEq, + right: Box::new(Expr::Literal(ScalarValue::Utf8(Some("Diana".to_string())))), + }); + + let expr5 = Expr::Literal(ScalarValue::Int32(Some(10))); + + let expr6 = Expr::Not(Box::new(Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::Column(Column::from_name("intField".to_string()))), + op: Operator::Gt, + right: Box::new(Expr::Literal(ScalarValue::Int32(Some(20000)))), + }))); + + let filters = vec![&expr1, &expr2, &expr3, &expr4, &expr5, &expr6]; + let result = table_provider.supports_filters_pushdown(&filters).unwrap(); + + assert_eq!(result.len(), 6); + assert_eq!(result[0], TableProviderFilterPushDown::Inexact); + assert_eq!(result[1], TableProviderFilterPushDown::Inexact); + assert_eq!(result[2], TableProviderFilterPushDown::Unsupported); + assert_eq!(result[3], TableProviderFilterPushDown::Unsupported); + assert_eq!(result[4], TableProviderFilterPushDown::Unsupported); + assert_eq!(result[5], TableProviderFilterPushDown::Inexact); + } } diff --git a/crates/datafusion/src/util/expr.rs b/crates/datafusion/src/util/expr.rs new file mode 100644 index 00000000..daa21b61 --- /dev/null +++ b/crates/datafusion/src/util/expr.rs @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use datafusion::logical_expr::Operator; +use datafusion_expr::{BinaryExpr, Expr}; +use hudi_core::expr::filter::Filter as HudiFilter; +use hudi_core::expr::ExprOperator; + +/// Converts DataFusion expressions into Hudi filters. +/// +/// Takes a slice of DataFusion [`Expr`] and attempts to convert each expression +/// into a [`HudiFilter`]. Only binary expressions and NOT expressions are currently supported. +/// +/// # Arguments +/// * `exprs` - A slice of DataFusion expressions to convert +/// +/// # Returns +/// Returns a `Vec<HudiFilter>` containing the expressions that were successfully converted; +/// expressions that cannot be converted are skipped. +/// +/// TODO: Handle other DataFusion [`Expr`] +pub fn exprs_to_filters(exprs: &[Expr]) -> Vec<HudiFilter> { + let mut filters: Vec<HudiFilter> = Vec::new(); + + for expr in exprs { + match expr { + Expr::BinaryExpr(binary_expr) => { + if let Some(filter) = binary_expr_to_filter(binary_expr) { + filters.push(filter); + } + } + Expr::Not(not_expr) => { + if let Some(filter) = not_expr_to_filter(not_expr) { + filters.push(filter); + } + } + _ => {} + } + } + + filters +} + +/// Converts a binary expression [`Expr::BinaryExpr`] into a [`HudiFilter`]. +fn binary_expr_to_filter(binary_expr: &BinaryExpr) -> Option<HudiFilter> { + // extract the column and literal from the binary expression + let (column, literal) = match (&*binary_expr.left, &*binary_expr.right) { + (Expr::Column(col), Expr::Literal(lit)) => (col, lit), + (Expr::Literal(lit), Expr::Column(col)) => (col, lit), + _ => return None, + }; + + let field_name = column.name().to_string(); + + let operator = match binary_expr.op { + Operator::Eq => ExprOperator::Eq, + Operator::NotEq => ExprOperator::Ne, + Operator::Lt => ExprOperator::Lt, + Operator::LtEq => ExprOperator::Lte, + Operator::Gt => ExprOperator::Gt, + Operator::GtEq => ExprOperator::Gte, + _ => return None, + }; + + let value = literal.to_string(); + + Some(HudiFilter { + field_name, + operator, + field_value: value, + }) +} + +/// Converts a NOT expression (`Expr::Not`) into a [`HudiFilter`]. 
+fn not_expr_to_filter(not_expr: &Expr) -> Option { + match not_expr { + Expr::BinaryExpr(ref binary_expr) => { + let mut filter = binary_expr_to_filter(binary_expr)?; + filter.operator = filter.operator.negate()?; + Some(filter) + } + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::{DataType, Field, Schema}; + use datafusion::logical_expr::{col, lit}; + use datafusion_expr::{BinaryExpr, Expr}; + use hudi_core::expr::ExprOperator; + use std::str::FromStr; + use std::sync::Arc; + + #[test] + fn test_convert_simple_binary_expr() { + let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Int32, false)])); + + let expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col("col")), + Operator::Eq, + Box::new(lit(42i32)), + )); + + let filters = vec![expr]; + + let result = exprs_to_filters(&filters); + + assert_eq!(result.len(), 1); + + let expected_filter = HudiFilter { + field_name: schema.field(0).name().to_string(), + operator: ExprOperator::Eq, + field_value: "42".to_string(), + }; + + assert_eq!(result[0].field_name, expected_filter.field_name); + assert_eq!(result[0].operator, expected_filter.operator); + assert_eq!(*result[0].field_value.clone(), expected_filter.field_value); + } + + #[test] + fn test_convert_not_expr() { + let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Int32, false)])); + + let inner_expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col("col")), + Operator::Eq, + Box::new(lit(42i32)), + )); + let expr = Expr::Not(Box::new(inner_expr)); + + let filters = vec![expr]; + + let result = exprs_to_filters(&filters); + + assert_eq!(result.len(), 1); + + let expected_filter = HudiFilter { + field_name: schema.field(0).name().to_string(), + operator: ExprOperator::Ne, + field_value: "42".to_string(), + }; + + assert_eq!(result[0].field_name, expected_filter.field_name); + assert_eq!(result[0].operator, expected_filter.operator); + assert_eq!(*result[0].field_value.clone(), expected_filter.field_value); + } + + #[test] + fn test_convert_binary_expr_extensive() { + // list of test cases with different operators and data types + let test_cases = vec![ + ( + col("int32_col").eq(lit(42i32)), + Some(HudiFilter { + field_name: String::from("int32_col"), + operator: ExprOperator::Eq, + field_value: String::from("42"), + }), + ), + ( + col("int64_col").gt_eq(lit(100i64)), + Some(HudiFilter { + field_name: String::from("int64_col"), + operator: ExprOperator::Gte, + field_value: String::from("100"), + }), + ), + ( + col("float64_col").lt(lit(32.666)), + Some(HudiFilter { + field_name: String::from("float64_col"), + operator: ExprOperator::Lt, + field_value: "32.666".to_string(), + }), + ), + ( + col("string_col").not_eq(lit("test")), + Some(HudiFilter { + field_name: String::from("string_col"), + operator: ExprOperator::Ne, + field_value: String::from("test"), + }), + ), + ]; + + let filters: Vec = test_cases.iter().map(|(expr, _)| expr.clone()).collect(); + let result = exprs_to_filters(&filters); + let expected_filters: Vec<&HudiFilter> = test_cases + .iter() + .filter_map(|(_, opt_filter)| opt_filter.as_ref()) + .collect(); + + assert_eq!(result.len(), expected_filters.len()); + + for (result, expected_filter) in result.iter().zip(expected_filters.iter()) { + assert_eq!(result.field_name, expected_filter.field_name); + assert_eq!(result.operator, expected_filter.operator); + assert_eq!(*result.field_value.clone(), expected_filter.field_value); + } + } + + // Tests conversion with different operators (e.g., <, <=, >, >=) + 
#[test] + fn test_convert_various_operators() { + let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Int32, false)])); + + let operators = vec![ + (Operator::Lt, ExprOperator::Lt), + (Operator::LtEq, ExprOperator::Lte), + (Operator::Gt, ExprOperator::Gt), + (Operator::GtEq, ExprOperator::Gte), + ]; + + for (op, expected_op) in operators { + let expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col("col")), + op, + Box::new(lit(42i32)), + )); + + let filters = vec![expr]; + + let result = exprs_to_filters(&filters); + + assert_eq!(result.len(), 1); + + let expected_filter = HudiFilter { + field_name: schema.field(0).name().to_string(), + operator: expected_op, + field_value: String::from("42"), + }; + + assert_eq!(result[0].field_name, expected_filter.field_name); + assert_eq!(result[0].operator, expected_filter.operator); + assert_eq!(*result[0].field_value.clone(), expected_filter.field_value); + } + } + + #[test] + fn test_convert_expr_with_unsupported_operator() { + let expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col("col")), + Operator::And, + Box::new(lit("value")), + )); + + let filters = vec![expr]; + let result = exprs_to_filters(&filters); + assert!(result.is_empty()); + } + + #[test] + fn test_negate_operator_for_all_ops() { + for (op, _) in ExprOperator::TOKEN_OP_PAIRS { + if let Some(negated_op) = ExprOperator::from_str(op).unwrap().negate() { + let double_negated_op = negated_op + .negate() + .expect("Negation should be defined for all operators"); + + assert_eq!(double_negated_op, ExprOperator::from_str(op).unwrap()); + } + } + } +} diff --git a/crates/datafusion/src/util/mod.rs b/crates/datafusion/src/util/mod.rs new file mode 100644 index 00000000..ff2dcf36 --- /dev/null +++ b/crates/datafusion/src/util/mod.rs @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +pub mod expr; diff --git a/python/src/internal.rs b/python/src/internal.rs index a6a35524..4a76f7d2 100644 --- a/python/src/internal.rs +++ b/python/src/internal.rs @@ -23,19 +23,17 @@ use std::path::PathBuf; use std::sync::OnceLock; use arrow::pyarrow::ToPyArrow; -use pyo3::{pyclass, pyfunction, pymethods, PyErr, PyObject, PyResult, Python}; use tokio::runtime::Runtime; use hudi::error::CoreError; +use hudi::expr::filter::Filter; use hudi::file_group::reader::FileGroupReader; use hudi::file_group::FileSlice; use hudi::storage::error::StorageError; use hudi::table::builder::TableBuilder; use hudi::table::Table; -use hudi::util::convert_vec_to_slice; -use hudi::util::vec_to_slice; -use pyo3::create_exception; -use pyo3::exceptions::PyException; +use pyo3::exceptions::{PyException, PyValueError}; +use pyo3::{create_exception, pyclass, pyfunction, pymethods, PyErr, PyObject, PyResult, Python}; create_exception!(_internal, HudiCoreError, PyException); @@ -197,12 +195,11 @@ impl HudiTable { filters: Option>, py: Python, ) -> PyResult>> { + let filters = convert_filters(filters)?; + py.allow_threads(|| { let file_slices = rt() - .block_on( - self.inner - .get_file_slices_splits(n, vec_to_slice!(filters.unwrap_or_default())), - ) + .block_on(self.inner.get_file_slices_splits(n, &filters)) .map_err(PythonError::from)?; Ok(file_slices .iter() @@ -217,12 +214,11 @@ impl HudiTable { filters: Option>, py: Python, ) -> PyResult> { + let filters = convert_filters(filters)?; + py.allow_threads(|| { let file_slices = rt() - .block_on( - self.inner - .get_file_slices(vec_to_slice!(filters.unwrap_or_default())), - ) + .block_on(self.inner.get_file_slices(&filters)) .map_err(PythonError::from)?; Ok(file_slices.iter().map(convert_file_slice).collect()) }) @@ -239,15 +235,29 @@ impl HudiTable { filters: Option>, py: Python, ) -> PyResult { - rt().block_on( - self.inner - .read_snapshot(vec_to_slice!(filters.unwrap_or_default())), - ) - .map_err(PythonError::from)? - .to_pyarrow(py) + let filters = convert_filters(filters)?; + + rt().block_on(self.inner.read_snapshot(&filters)) + .map_err(PythonError::from)? + .to_pyarrow(py) } } +fn convert_filters(filters: Option>) -> PyResult> { + filters + .unwrap_or_default() + .into_iter() + .map(|(field, op, value)| { + Filter::try_from((field.as_str(), op.as_str(), value.as_str())).map_err(|e| { + PyValueError::new_err(format!( + "Invalid filter ({}, {}, {}): {}", + field, op, value, e + )) + }) + }) + .collect() +} + #[cfg(not(tarpaulin))] #[pyfunction] #[pyo3(signature = (base_uri, hudi_options=None, storage_options=None, options=None))] diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index baebdffd..f8986a8e 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -193,3 +193,23 @@ def test_read_table_as_of_timestamp(get_sample_table): "fare": 34.15, }, ] + + +def test_convert_filters_valid(get_sample_table): + table_path = get_sample_table + table = HudiTable(table_path) + + filters = [ + ("city", "=", "san_francisco"), + ("city", ">", "san_francisco"), + ("city", "<", "san_francisco"), + ("city", "<=", "san_francisco"), + ("city", ">=", "san_francisco"), + ] + + result = [3, 1, 1, 4, 4] + + for i in range(len(filters)): + filter_list = [filters[i]] + file_slices = table.get_file_slices(filters=filter_list) + assert len(file_slices) == result[i]
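Reviewer note: taken together, the core changes replace the `(&str, &str, &str)` filter tuples with the new `Filter` / `ExprOperator` API in `hudi-core`. A minimal usage sketch of that API follows; it is not part of the patch, it assumes a Tokio async runtime, and the table path and field names are placeholders drawn from the test tables above.

```rust
use hudi_core::expr::filter::Filter;
use hudi_core::table::Table;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder path; point this at an existing Hudi table.
    let table = Table::new("/tmp/hudi_table").await?;

    // Filters are built from (field, operator, value) triples; Filter::try_from
    // validates the operator via ExprOperator::from_str and errors on unknown tokens.
    let filter_gte_10 = Filter::try_from(("byteField", ">=", "10"))?;
    let filter_lt_20 = Filter::try_from(("byteField", "<", "20"))?;

    // read_snapshot prunes partitions using the filters and returns Arrow record batches.
    let batches = table.read_snapshot(&[filter_gte_10, filter_lt_20]).await?;
    println!("read {} record batches", batches.len());
    Ok(())
}
```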
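On the DataFusion side, `supports_filters_pushdown` now classifies simple column-vs-literal predicates as `Inexact`, and `scan` converts them with `exprs_to_filters` before calling `get_file_slices_splits`. A rough end-to-end sketch adapted from the doc example in this patch; the table path, timestamp, and the `city` predicate are illustrative, and DataFusion still re-applies the predicate because the pushdown is `Inexact`.

```rust
use std::sync::Arc;

use datafusion::prelude::SessionContext;
use hudi_datafusion::HudiDataSource;

async fn query_with_pushdown() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();

    // Read options and path mirror the doc example above.
    let hudi = HudiDataSource::new_with_options(
        "/tmp/trips_table",
        [("hoodie.read.as.of.timestamp", "20241122010827898")],
    )
    .await?;
    ctx.register_table("trips_table", Arc::new(hudi))?;

    // Predicates accepted by supports_filters_pushdown are converted to Hudi
    // Filters and forwarded to get_file_slices_splits before the Parquet scan.
    let df = ctx
        .sql("SELECT * FROM trips_table WHERE city = 'san_francisco'")
        .await?;
    df.show().await?;
    Ok(())
}
```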