diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index f1a870e9ad079..ff02ef8c7ef69 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -40,7 +40,7 @@ itertools = { workspace = true } object_store = { workspace = true } pbjson-types = "0.7" prost = "0.13" -substrait = { version = "0.38", features = ["serde"] } +substrait = { version = "0.41", features = ["serde"] } url = { workspace = true } [dev-dependencies] diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 51c3004948a18..648dd8f10f2f6 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -69,6 +69,7 @@ use datafusion::{ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use substrait::proto::exchange_rel::ExchangeKind; +use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; use substrait::proto::expression::literal::user_defined::Val; use substrait::proto::expression::literal::{ IntervalDayToSecond, IntervalYearToMonth, UserDefined, @@ -884,8 +885,8 @@ fn from_substrait_jointype(join_type: i32) -> Result { join_rel::JoinType::Left => Ok(JoinType::Left), join_rel::JoinType::Right => Ok(JoinType::Right), join_rel::JoinType::Outer => Ok(JoinType::Full), - join_rel::JoinType::Anti => Ok(JoinType::LeftAnti), - join_rel::JoinType::Semi => Ok(JoinType::LeftSemi), + join_rel::JoinType::LeftAnti => Ok(JoinType::LeftAnti), + join_rel::JoinType::LeftSemi => Ok(JoinType::LeftSemi), _ => plan_err!("unsupported join type {substrait_join_type:?}"), } } else { @@ -1500,10 +1501,10 @@ fn from_substrait_type( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, - r#type::Kind::IntervalYear(i) => { + r#type::Kind::IntervalYear(_) => { Ok(DataType::Interval(IntervalUnit::YearMonth)) } - r#type::Kind::IntervalDay(i) => Ok(DataType::Interval(IntervalUnit::DayTime)), + r#type::Kind::IntervalDay(_) => Ok(DataType::Interval(IntervalUnit::DayTime)), r#type::Kind::UserDefined(u) => { if let Some(name) = extensions.types.get(&u.type_reference) { match name.as_ref() { @@ -1942,10 +1943,24 @@ fn from_substrait_literal( Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { days, seconds, - microseconds, + subseconds, + precision_mode, })) => { - // DF only supports millisecond precision, so we lose the micros here - ScalarValue::new_interval_dt(*days, (seconds * 1000) + (microseconds / 1000)) + // DF only supports millisecond precision, so for any more granular type we lose precision + let milliseconds = match precision_mode { + Some(PrecisionMode::Microseconds(ms)) => ms / 1000, + Some(PrecisionMode::Precision(0)) => *subseconds as i32 * 1000, + Some(PrecisionMode::Precision(3)) => *subseconds as i32, + Some(PrecisionMode::Precision(6)) => (subseconds / 1000) as i32, + Some(PrecisionMode::Precision(9)) => (subseconds / 1000 / 1000) as i32, + _ => { + return not_impl_err!( + "Unsupported Substrait interval day to second precision mode" + ) + } + }; + + ScalarValue::new_interval_dt(*days, (seconds * 1000) + milliseconds) } Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => { ScalarValue::new_interval_ym(*years, *months) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 458894dc6448e..74b8c5046de13 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -53,10 +53,11 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera use datafusion::prelude::Expr; use pbjson_types::Any as ProtoAny; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; use substrait::proto::expression::literal::map::KeyValue; use substrait::proto::expression::literal::{ - user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map, PrecisionTimestamp, - Struct, UserDefined, + user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map, + PrecisionTimestamp, Struct, UserDefined, }; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; @@ -656,8 +657,8 @@ fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { JoinType::Left => join_rel::JoinType::Left, JoinType::Right => join_rel::JoinType::Right, JoinType::Full => join_rel::JoinType::Outer, - JoinType::LeftAnti => join_rel::JoinType::Anti, - JoinType::LeftSemi => join_rel::JoinType::Semi, + JoinType::LeftAnti => join_rel::JoinType::LeftAnti, + JoinType::LeftSemi => join_rel::JoinType::LeftSemi, JoinType::RightAnti | JoinType::RightSemi => unimplemented!(), } } @@ -1424,6 +1425,7 @@ fn to_substrait_type( kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability, + precision: Some(3), // DayTime precision is always milliseconds })), }), IntervalUnit::MonthDayNano => { @@ -1810,28 +1812,28 @@ fn to_substrait_literal( ScalarValue::TimestampSecond(Some(t), None) => ( LiteralType::PrecisionTimestamp(PrecisionTimestamp { precision: 0, - value: *t as u64, + value: *t, }), DEFAULT_TYPE_VARIATION_REF, ), ScalarValue::TimestampMillisecond(Some(t), None) => ( LiteralType::PrecisionTimestamp(PrecisionTimestamp { precision: 3, - value: *t as u64, + value: *t, }), DEFAULT_TYPE_VARIATION_REF, ), ScalarValue::TimestampMicrosecond(Some(t), None) => ( LiteralType::PrecisionTimestamp(PrecisionTimestamp { precision: 6, - value: *t as u64, + value: *t, }), DEFAULT_TYPE_VARIATION_REF, ), ScalarValue::TimestampNanosecond(Some(t), None) => ( LiteralType::PrecisionTimestamp(PrecisionTimestamp { precision: 9, - value: *t as u64, + value: *t, }), DEFAULT_TYPE_VARIATION_REF, ), @@ -1841,28 +1843,28 @@ fn to_substrait_literal( ScalarValue::TimestampSecond(Some(t), Some(_)) => ( LiteralType::PrecisionTimestampTz(PrecisionTimestamp { precision: 0, - value: *t as u64, + value: *t, }), DEFAULT_TYPE_VARIATION_REF, ), ScalarValue::TimestampMillisecond(Some(t), Some(_)) => ( LiteralType::PrecisionTimestampTz(PrecisionTimestamp { precision: 3, - value: *t as u64, + value: *t, }), DEFAULT_TYPE_VARIATION_REF, ), ScalarValue::TimestampMicrosecond(Some(t), Some(_)) => ( LiteralType::PrecisionTimestampTz(PrecisionTimestamp { precision: 6, - value: *t as u64, + value: *t, }), DEFAULT_TYPE_VARIATION_REF, ), ScalarValue::TimestampNanosecond(Some(t), Some(_)) => ( LiteralType::PrecisionTimestampTz(PrecisionTimestamp { precision: 9, - value: *t as u64, + value: *t, }), DEFAULT_TYPE_VARIATION_REF, ), @@ -1899,7 +1901,8 @@ fn to_substrait_literal( LiteralType::IntervalDayToSecond(IntervalDayToSecond { days: i.days, seconds: i.milliseconds / 1000, - microseconds: (i.milliseconds % 1000) * 1000, + subseconds: (i.milliseconds % 1000) as i64, + precision_mode: Some(PrecisionMode::Precision(3)), // 3 for milliseconds }), DEFAULT_TYPE_VARIATION_REF, ),