Skip to content

Commit

Permalink
bump substrait to latest
Browse files Browse the repository at this point in the history
  • Loading branch information
Blizzara committed Aug 20, 2024
1 parent 6c9dff3 commit e8753f6
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 21 deletions.
2 changes: 1 addition & 1 deletion datafusion/substrait/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ itertools = { workspace = true }
object_store = { workspace = true }
pbjson-types = "0.7"
prost = "0.13"
substrait = { version = "0.38", features = ["serde"] }
substrait = { version = "0.41", features = ["serde"] }
url = { workspace = true }

[dev-dependencies]
Expand Down
29 changes: 22 additions & 7 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ use datafusion::{
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use substrait::proto::exchange_rel::ExchangeKind;
use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode;
use substrait::proto::expression::literal::user_defined::Val;
use substrait::proto::expression::literal::{
IntervalDayToSecond, IntervalYearToMonth, UserDefined,
Expand Down Expand Up @@ -884,8 +885,8 @@ fn from_substrait_jointype(join_type: i32) -> Result<JoinType> {
join_rel::JoinType::Left => Ok(JoinType::Left),
join_rel::JoinType::Right => Ok(JoinType::Right),
join_rel::JoinType::Outer => Ok(JoinType::Full),
join_rel::JoinType::Anti => Ok(JoinType::LeftAnti),
join_rel::JoinType::Semi => Ok(JoinType::LeftSemi),
join_rel::JoinType::LeftAnti => Ok(JoinType::LeftAnti),
join_rel::JoinType::LeftSemi => Ok(JoinType::LeftSemi),
_ => plan_err!("unsupported join type {substrait_join_type:?}"),
}
} else {
Expand Down Expand Up @@ -1500,10 +1501,10 @@ fn from_substrait_type(
"Unsupported Substrait type variation {v} of type {s_kind:?}"
),
},
r#type::Kind::IntervalYear(i) => {
r#type::Kind::IntervalYear(_) => {
Ok(DataType::Interval(IntervalUnit::YearMonth))
}
r#type::Kind::IntervalDay(i) => Ok(DataType::Interval(IntervalUnit::DayTime)),
r#type::Kind::IntervalDay(_) => Ok(DataType::Interval(IntervalUnit::DayTime)),
r#type::Kind::UserDefined(u) => {
if let Some(name) = extensions.types.get(&u.type_reference) {
match name.as_ref() {
Expand Down Expand Up @@ -1942,10 +1943,24 @@ fn from_substrait_literal(
Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond {
days,
seconds,
microseconds,
subseconds,
precision_mode,
})) => {
// DF only supports millisecond precision, so we lose the micros here
ScalarValue::new_interval_dt(*days, (seconds * 1000) + (microseconds / 1000))
// DF only supports millisecond precision, so for any more granular type we lose precision
let milliseconds = match precision_mode {
Some(PrecisionMode::Microseconds(ms)) => ms / 1000,
Some(PrecisionMode::Precision(0)) => *subseconds as i32 * 1000,
Some(PrecisionMode::Precision(3)) => *subseconds as i32,
Some(PrecisionMode::Precision(6)) => (subseconds / 1000) as i32,
Some(PrecisionMode::Precision(9)) => (subseconds / 1000 / 1000) as i32,
_ => {
return not_impl_err!(
"Unsupported Substrait interval day to second precision mode"
)
}
};

ScalarValue::new_interval_dt(*days, (seconds * 1000) + milliseconds)
}
Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => {
ScalarValue::new_interval_ym(*years, *months)
Expand Down
29 changes: 16 additions & 13 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,11 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera
use datafusion::prelude::Expr;
use pbjson_types::Any as ProtoAny;
use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode;
use substrait::proto::expression::literal::map::KeyValue;
use substrait::proto::expression::literal::{
user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map, PrecisionTimestamp,
Struct, UserDefined,
user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map,
PrecisionTimestamp, Struct, UserDefined,
};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
Expand Down Expand Up @@ -656,8 +657,8 @@ fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType {
JoinType::Left => join_rel::JoinType::Left,
JoinType::Right => join_rel::JoinType::Right,
JoinType::Full => join_rel::JoinType::Outer,
JoinType::LeftAnti => join_rel::JoinType::Anti,
JoinType::LeftSemi => join_rel::JoinType::Semi,
JoinType::LeftAnti => join_rel::JoinType::LeftAnti,
JoinType::LeftSemi => join_rel::JoinType::LeftSemi,
JoinType::RightAnti | JoinType::RightSemi => unimplemented!(),
}
}
Expand Down Expand Up @@ -1424,6 +1425,7 @@ fn to_substrait_type(
kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay {
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
nullability,
precision: Some(3), // DayTime precision is always milliseconds
})),
}),
IntervalUnit::MonthDayNano => {
Expand Down Expand Up @@ -1810,28 +1812,28 @@ fn to_substrait_literal(
ScalarValue::TimestampSecond(Some(t), None) => (
LiteralType::PrecisionTimestamp(PrecisionTimestamp {
precision: 0,
value: *t as u64,
value: *t,
}),
DEFAULT_TYPE_VARIATION_REF,
),
ScalarValue::TimestampMillisecond(Some(t), None) => (
LiteralType::PrecisionTimestamp(PrecisionTimestamp {
precision: 3,
value: *t as u64,
value: *t,
}),
DEFAULT_TYPE_VARIATION_REF,
),
ScalarValue::TimestampMicrosecond(Some(t), None) => (
LiteralType::PrecisionTimestamp(PrecisionTimestamp {
precision: 6,
value: *t as u64,
value: *t,
}),
DEFAULT_TYPE_VARIATION_REF,
),
ScalarValue::TimestampNanosecond(Some(t), None) => (
LiteralType::PrecisionTimestamp(PrecisionTimestamp {
precision: 9,
value: *t as u64,
value: *t,
}),
DEFAULT_TYPE_VARIATION_REF,
),
Expand All @@ -1841,28 +1843,28 @@ fn to_substrait_literal(
ScalarValue::TimestampSecond(Some(t), Some(_)) => (
LiteralType::PrecisionTimestampTz(PrecisionTimestamp {
precision: 0,
value: *t as u64,
value: *t,
}),
DEFAULT_TYPE_VARIATION_REF,
),
ScalarValue::TimestampMillisecond(Some(t), Some(_)) => (
LiteralType::PrecisionTimestampTz(PrecisionTimestamp {
precision: 3,
value: *t as u64,
value: *t,
}),
DEFAULT_TYPE_VARIATION_REF,
),
ScalarValue::TimestampMicrosecond(Some(t), Some(_)) => (
LiteralType::PrecisionTimestampTz(PrecisionTimestamp {
precision: 6,
value: *t as u64,
value: *t,
}),
DEFAULT_TYPE_VARIATION_REF,
),
ScalarValue::TimestampNanosecond(Some(t), Some(_)) => (
LiteralType::PrecisionTimestampTz(PrecisionTimestamp {
precision: 9,
value: *t as u64,
value: *t,
}),
DEFAULT_TYPE_VARIATION_REF,
),
Expand Down Expand Up @@ -1899,7 +1901,8 @@ fn to_substrait_literal(
LiteralType::IntervalDayToSecond(IntervalDayToSecond {
days: i.days,
seconds: i.milliseconds / 1000,
microseconds: (i.milliseconds % 1000) * 1000,
subseconds: (i.milliseconds % 1000) as i64,
precision_mode: Some(PrecisionMode::Precision(3)), // 3 for milliseconds
}),
DEFAULT_TYPE_VARIATION_REF,
),
Expand Down

0 comments on commit e8753f6

Please sign in to comment.