Skip to content

Commit

Permalink
Temporal datatype support for interval arithmetic (#5971)
Browse files Browse the repository at this point in the history
* first implementation and tests of timestamp subtraction

* improvement after review

* postgre interval format option

* random tests extended

* corrections after review

* operator check

* flag is removed

* clippy fix

* toml conflict

* minor changes

* deterministic matches

* simplifications (clippy error)

* test format changed

* minor test fix

* Update scalar.rs

* Refactoring and simplifications

* Make ScalarValue support interval comparison

* naming tests

* macro renaming

* renaming macro

* ok till arrow kernel ops

* macro will replace matches inside evaluate

add tests

macro will replace matches inside evaluate

ready for review

* Code refactor

* retract changes in scalar and datetime

* ts op interval with chrono functions

* bug fix and refactor

* test refactor

* Enhance commenting

* new binary operation logic, handling the inside errors

* slt and minor changes

* tz parsing excluded

* replace try_binary and as_datetime, and keep timezone for ts+interval op

* fix after merge

* delete unused functions

* ready to review

* correction after merge

* change match order

* minor changes

* simplifications

* update lock file

* Refactoring tests

You can add a millisecond array as well, but I used Nano.

* bug detected

* bug fixed

* update cargo

* tests added

* minor changes after merge

* fix after merge

* code simplification

* Some simplifications

* Update min_max.rs

* arithmetics moved into macros

* fix cargo.lock

* remove unwraps from tests

* Remove run-time string comparison from the interval min/max macro

* adapt upstream changes of timezone signature

---------

Co-authored-by: Mehmet Ozan Kabak <[email protected]>
Co-authored-by: metesynnada <[email protected]>
Co-authored-by: Mustafa Akur <[email protected]>
  • Loading branch information
4 people authored Apr 13, 2023
1 parent ca60ff1 commit bd705fe
Show file tree
Hide file tree
Showing 9 changed files with 983 additions and 97 deletions.
68 changes: 39 additions & 29 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1192,9 +1192,9 @@ pub fn seconds_add_array<const INTERVAL_MODE: i8>(

#[inline]
pub fn milliseconds_add(ts_ms: i64, scalar: &ScalarValue, sign: i32) -> Result<i64> {
let secs = ts_ms / 1000;
let nsecs = ((ts_ms % 1000) * 1_000_000) as u32;
do_date_time_math(secs, nsecs, scalar, sign).map(|dt| dt.timestamp_millis())
let secs = ts_ms.div_euclid(1000);
let nsecs = ts_ms.rem_euclid(1000) * 1_000_000;
do_date_time_math(secs, nsecs as u32, scalar, sign).map(|dt| dt.timestamp_millis())
}

#[inline]
Expand All @@ -1203,21 +1203,18 @@ pub fn milliseconds_add_array<const INTERVAL_MODE: i8>(
interval: i128,
sign: i32,
) -> Result<i64> {
let mut secs = ts_ms / 1000;
let mut nsecs = ((ts_ms % 1000) * 1_000_000) as i32;
if nsecs < 0 {
secs -= 1;
nsecs += 1_000_000_000;
}
let secs = ts_ms.div_euclid(1000);
let nsecs = ts_ms.rem_euclid(1000) * 1_000_000;
do_date_time_math_array::<INTERVAL_MODE>(secs, nsecs as u32, interval, sign)
.map(|dt| dt.timestamp_millis())
}

#[inline]
pub fn microseconds_add(ts_us: i64, scalar: &ScalarValue, sign: i32) -> Result<i64> {
let secs = ts_us / 1_000_000;
let nsecs = ((ts_us % 1_000_000) * 1000) as u32;
do_date_time_math(secs, nsecs, scalar, sign).map(|dt| dt.timestamp_nanos() / 1000)
let secs = ts_us.div_euclid(1_000_000);
let nsecs = ts_us.rem_euclid(1_000_000) * 1_000;
do_date_time_math(secs, nsecs as u32, scalar, sign)
.map(|dt| dt.timestamp_nanos() / 1000)
}

#[inline]
Expand All @@ -1226,21 +1223,17 @@ pub fn microseconds_add_array<const INTERVAL_MODE: i8>(
interval: i128,
sign: i32,
) -> Result<i64> {
let mut secs = ts_us / 1_000_000;
let mut nsecs = ((ts_us % 1_000_000) * 1000) as i32;
if nsecs < 0 {
secs -= 1;
nsecs += 1_000_000_000;
}
let secs = ts_us.div_euclid(1_000_000);
let nsecs = ts_us.rem_euclid(1_000_000) * 1_000;
do_date_time_math_array::<INTERVAL_MODE>(secs, nsecs as u32, interval, sign)
.map(|dt| dt.timestamp_nanos() / 1000)
}

#[inline]
pub fn nanoseconds_add(ts_ns: i64, scalar: &ScalarValue, sign: i32) -> Result<i64> {
let secs = ts_ns / 1_000_000_000;
let nsecs = (ts_ns % 1_000_000_000) as u32;
do_date_time_math(secs, nsecs, scalar, sign).map(|dt| dt.timestamp_nanos())
let secs = ts_ns.div_euclid(1_000_000_000);
let nsecs = ts_ns.rem_euclid(1_000_000_000);
do_date_time_math(secs, nsecs as u32, scalar, sign).map(|dt| dt.timestamp_nanos())
}

#[inline]
Expand All @@ -1249,12 +1242,8 @@ pub fn nanoseconds_add_array<const INTERVAL_MODE: i8>(
interval: i128,
sign: i32,
) -> Result<i64> {
let mut secs = ts_ns / 1_000_000_000;
let mut nsecs = (ts_ns % 1_000_000_000) as i32;
if nsecs < 0 {
secs -= 1;
nsecs += 1_000_000_000;
}
let secs = ts_ns.div_euclid(1_000_000_000);
let nsecs = ts_ns.rem_euclid(1_000_000_000);
do_date_time_math_array::<INTERVAL_MODE>(secs, nsecs as u32, interval, sign)
.map(|dt| dt.timestamp_nanos())
}
Expand Down Expand Up @@ -1297,7 +1286,7 @@ fn do_date_time_math(
) -> Result<NaiveDateTime> {
let prior = NaiveDateTime::from_timestamp_opt(secs, nsecs).ok_or_else(|| {
DataFusionError::Internal(format!(
"Could not conert to NaiveDateTime: secs {secs} nsecs {nsecs} scalar {scalar:?} sign {sign}"
"Could not convert to NaiveDateTime: secs {secs} nsecs {nsecs} scalar {scalar:?} sign {sign}"
))
})?;
do_date_math(prior, scalar, sign)
Expand All @@ -1312,7 +1301,7 @@ fn do_date_time_math_array<const INTERVAL_MODE: i8>(
) -> Result<NaiveDateTime> {
let prior = NaiveDateTime::from_timestamp_opt(secs, nsecs).ok_or_else(|| {
DataFusionError::Internal(format!(
"Could not conert to NaiveDateTime: secs {secs} nsecs {nsecs}"
"Could not convert to NaiveDateTime: secs {secs} nsecs {nsecs}"
))
})?;
do_date_math_array::<_, INTERVAL_MODE>(prior, interval, sign)
Expand Down Expand Up @@ -1768,6 +1757,27 @@ impl ScalarValue {
DataType::UInt64 => ScalarValue::UInt64(Some(0)),
DataType::Float32 => ScalarValue::Float32(Some(0.0)),
DataType::Float64 => ScalarValue::Float64(Some(0.0)),
DataType::Timestamp(TimeUnit::Second, tz) => {
ScalarValue::TimestampSecond(Some(0), tz.clone())
}
DataType::Timestamp(TimeUnit::Millisecond, tz) => {
ScalarValue::TimestampMillisecond(Some(0), tz.clone())
}
DataType::Timestamp(TimeUnit::Microsecond, tz) => {
ScalarValue::TimestampMicrosecond(Some(0), tz.clone())
}
DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
ScalarValue::TimestampNanosecond(Some(0), tz.clone())
}
DataType::Interval(IntervalUnit::YearMonth) => {
ScalarValue::IntervalYearMonth(Some(0))
}
DataType::Interval(IntervalUnit::DayTime) => {
ScalarValue::IntervalDayTime(Some(0))
}
DataType::Interval(IntervalUnit::MonthDayNano) => {
ScalarValue::IntervalMonthDayNano(Some(0))
}
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a zero scalar from data_type \"{datatype:?}\""
Expand Down
222 changes: 216 additions & 6 deletions datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1554,17 +1554,19 @@ impl SymmetricHashJoinStream {
mod tests {
use std::fs::File;

use arrow::array::ArrayRef;
use arrow::array::{Int32Array, TimestampNanosecondArray};
use arrow::array::{ArrayRef, IntervalDayTimeArray};
use arrow::array::{Int32Array, TimestampMillisecondArray};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
use arrow::util::pretty::pretty_format_batches;
use rstest::*;
use tempfile::TempDir;

use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{binary, col, Column};
use datafusion_physical_expr::intervals::test_utils::gen_conjunctive_numeric_expr;
use datafusion_physical_expr::intervals::test_utils::{
gen_conjunctive_numeric_expr, gen_conjunctive_temporal_expr,
};
use datafusion_physical_expr::PhysicalExpr;

use crate::physical_plan::joins::{
Expand Down Expand Up @@ -1789,6 +1791,44 @@ mod tests {
_ => unreachable!(),
}
}
fn join_expr_tests_fixture_temporal(
expr_id: usize,
left_col: Arc<dyn PhysicalExpr>,
right_col: Arc<dyn PhysicalExpr>,
schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
match expr_id {
// constructs ((left_col - INTERVAL '100ms') > (right_col - INTERVAL '200ms')) AND ((left_col - INTERVAL '450ms') < (right_col - INTERVAL '300ms'))
0 => gen_conjunctive_temporal_expr(
left_col,
right_col,
Operator::Minus,
Operator::Minus,
Operator::Minus,
Operator::Minus,
ScalarValue::new_interval_dt(0, 100), // 100 ms
ScalarValue::new_interval_dt(0, 200), // 200 ms
ScalarValue::new_interval_dt(0, 450), // 450 ms
ScalarValue::new_interval_dt(0, 300), // 300 ms
schema,
),
// constructs ((left_col - TIMESTAMP '2023-01-01:12.00.03') > (right_col - TIMESTAMP '2023-01-01:12.00.01')) AND ((left_col - TIMESTAMP '2023-01-01:12.00.00') < (right_col - TIMESTAMP '2023-01-01:12.00.02'))
1 => gen_conjunctive_temporal_expr(
left_col,
right_col,
Operator::Minus,
Operator::Minus,
Operator::Minus,
Operator::Minus,
ScalarValue::TimestampMillisecond(Some(1672574403000), None), // 2023-01-01:12.00.03
ScalarValue::TimestampMillisecond(Some(1672574401000), None), // 2023-01-01:12.00.01
ScalarValue::TimestampMillisecond(Some(1672574400000), None), // 2023-01-01:12.00.00
ScalarValue::TimestampMillisecond(Some(1672574402000), None), // 2023-01-01:12.00.02
schema,
),
_ => unreachable!(),
}
}
fn build_sides_record_batches(
table_size: i32,
key_cardinality: (i32, i32),
Expand Down Expand Up @@ -1833,9 +1873,15 @@ mod tests {
.collect::<Vec<Option<i32>>>()
}));

let time = Arc::new(TimestampNanosecondArray::from(
let time = Arc::new(TimestampMillisecondArray::from(
initial_range
.clone()
.map(|x| x as i64 + 1672531200000) // x + 2023-01-01:00.00.00
.collect::<Vec<i64>>(),
));
let interval_time: ArrayRef = Arc::new(IntervalDayTimeArray::from(
initial_range
.map(|x| 1664264591000000000 + (5000000000 * (x as i64)))
.map(|x| x as i64 * 100) // x * 100ms
.collect::<Vec<i64>>(),
));

Expand All @@ -1849,6 +1895,7 @@ mod tests {
("l_asc_null_first", ordered_asc_null_first.clone()),
("l_asc_null_last", ordered_asc_null_last.clone()),
("l_desc_null_first", ordered_desc_null_first.clone()),
("li1", interval_time.clone()),
])?;
let right = RecordBatch::try_from_iter(vec![
("ra1", ordered.clone()),
Expand All @@ -1860,6 +1907,7 @@ mod tests {
("r_asc_null_first", ordered_asc_null_first),
("r_asc_null_last", ordered_asc_null_last),
("r_desc_null_first", ordered_desc_null_first),
("ri1", interval_time),
])?;
Ok((left, right))
}
Expand Down Expand Up @@ -2781,4 +2829,166 @@ mod tests {
assert_eq!(left_side_joiner.visited_rows.is_empty(), should_be_empty);
Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn testing_with_temporal_columns(
#[values(
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
JoinType::RightAnti,
JoinType::Full
)]
join_type: JoinType,
#[values(
(4, 5),
(99, 12),
)]
cardinality: (i32, i32),
#[values(0, 1)] case_expr: usize,
) -> Result<()> {
let config = SessionConfig::new().with_repartition_joins(false);
let session_ctx = SessionContext::with_config(config);
let task_ctx = session_ctx.task_ctx();
let (left_batch, right_batch) =
build_sides_record_batches(TABLE_SIZE, cardinality)?;
let left_schema = &left_batch.schema();
let right_schema = &right_batch.schema();
let on = vec![(
Column::new_with_schema("lc1", left_schema)?,
Column::new_with_schema("rc1", right_schema)?,
)];
let left_sorted = vec![PhysicalSortExpr {
expr: col("lt1", left_schema)?,
options: SortOptions {
descending: false,
nulls_first: true,
},
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("rt1", right_schema)?,
options: SortOptions {
descending: false,
nulls_first: true,
},
}];
let (left, right) = create_memory_table(
left_batch,
right_batch,
Some(left_sorted),
Some(right_sorted),
13,
)?;
let intermediate_schema = Schema::new(vec![
Field::new(
"left",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new(
"right",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
]);
let filter_expr = join_expr_tests_fixture_temporal(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
&intermediate_schema,
)?;
let column_indices = vec![
ColumnIndex {
index: 3,
side: JoinSide::Left,
},
ColumnIndex {
index: 3,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
Ok(())
}
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn test_with_interval_columns(
#[values(
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
JoinType::RightAnti,
JoinType::Full
)]
join_type: JoinType,
#[values(
(4, 5),
(99, 12),
)]
cardinality: (i32, i32),
) -> Result<()> {
let config = SessionConfig::new().with_repartition_joins(false);
let session_ctx = SessionContext::with_config(config);
let task_ctx = session_ctx.task_ctx();
let (left_batch, right_batch) =
build_sides_record_batches(TABLE_SIZE, cardinality)?;
let left_schema = &left_batch.schema();
let right_schema = &right_batch.schema();
let on = vec![(
Column::new_with_schema("lc1", left_schema)?,
Column::new_with_schema("rc1", right_schema)?,
)];
let left_sorted = vec![PhysicalSortExpr {
expr: col("li1", left_schema)?,
options: SortOptions {
descending: false,
nulls_first: true,
},
}];
let right_sorted = vec![PhysicalSortExpr {
expr: col("ri1", right_schema)?,
options: SortOptions {
descending: false,
nulls_first: true,
},
}];
let (left, right) = create_memory_table(
left_batch,
right_batch,
Some(left_sorted),
Some(right_sorted),
13,
)?;
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Interval(IntervalUnit::DayTime), false),
Field::new("right", DataType::Interval(IntervalUnit::DayTime), false),
]);
let filter_expr = join_expr_tests_fixture_temporal(
0,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
&intermediate_schema,
)?;
let column_indices = vec![
ColumnIndex {
index: 9,
side: JoinSide::Left,
},
ColumnIndex {
index: 9,
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
experiment(left, right, Some(filter), join_type, on, task_ctx).await?;

Ok(())
}
}
Loading

0 comments on commit bd705fe

Please sign in to comment.