Skip to content

Commit

Permalink
perf: speed up offset_by 2x for constant durations (#16728)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jun 6, 2024
1 parent b0cba6e commit c6ab549
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 31 deletions.
72 changes: 41 additions & 31 deletions crates/polars-plan/src/dsl/function_expr/temporal.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#[cfg(feature = "date_offset")]
use arrow::legacy::time_zone::Tz;
#[cfg(feature = "date_offset")]
use polars_core::chunked_array::ops::arity::try_binary_elementwise;
use polars_core::chunked_array::ops::arity::broadcast_try_binary_elementwise;
#[cfg(feature = "date_offset")]
use polars_time::prelude::*;

Expand Down Expand Up @@ -189,33 +189,52 @@ pub(super) fn datetime(
/// Adds the duration offset(s) in `offsets` to the timestamps in `datetime` and
/// returns the shifted values as a plain `Int64Chunked`.
///
/// Each offset string is parsed with `Duration::parse`. `time_zone` is forwarded
/// unchanged to the per-value `Duration::add_*` function.
///
/// NOTE(review): this block is rendered from a diff without +/- markers, so
/// lines from before and after the change appear interleaved (for example two
/// `match` headers in a row). It does not look compilable as shown — confirm
/// against the real file before relying on it.
fn apply_offsets_to_datetime(
datetime: &Logical<DatetimeType, Int64Type>,
offsets: &StringChunked,
// NOTE(review): presumably the pre-change parameter — the three-argument calls
// in `date_offset` pass no function pointer. Confirm.
offset_fn: fn(&Duration, i64, Option<&Tz>) -> PolarsResult<i64>,
time_zone: Option<&Tz>,
) -> PolarsResult<Int64Chunked> {
// NOTE(review): this dispatch on both lengths, down to the `(_, 1)` arm, looks
// like the pre-change version; `match offsets.len()` below appears to replace
// it. Confirm.
match (datetime.len(), offsets.len()) {
// Single datetime: map over the offsets with that one timestamp. If the
// timestamp is null, the result is an all-null column of `offsets.len()` rows.
(1, _) => match datetime.0.get(0) {
Some(dt) => offsets.try_apply_nonnull_values_generic(|offset| {
offset_fn(&Duration::parse(offset), dt, time_zone)
}),
_ => Ok(Int64Chunked::full_null(datetime.0.name(), offsets.len())),
},
(_, 1) => match offsets.get(0) {
match offsets.len() {
// Single offset: parse it once and reuse it for the whole column.
1 => match offsets.get(0) {
Some(offset) => {
let offset = &Duration::parse(offset);
// NOTE(review): the unconditional per-value application on the next three
// lines appears to be the pre-change body of this arm. Confirm.
datetime
.0
.try_apply_nonnull_values_generic(|v| offset_fn(offset, v, time_zone))
if offset.is_constant_duration(datetime.time_zone().as_deref()) {
// fastpath!
// The offset is a constant duration for this column's time zone, so it is
// converted once to the column's time unit, given its sign, and added to
// the whole column as a scalar instead of being applied value by value.
let mut duration = match datetime.time_unit() {
TimeUnit::Milliseconds => offset.duration_ms(),
TimeUnit::Microseconds => offset.duration_us(),
TimeUnit::Nanoseconds => offset.duration_ns(),
};
if offset.negative() {
duration = -duration;
}
Ok(datetime.0.clone().wrapping_add_scalar(duration))
} else {
// Otherwise apply the offset value by value, using the `Duration::add_*`
// function that matches the column's time unit.
let offset_fn = match datetime.time_unit() {
TimeUnit::Milliseconds => Duration::add_ms,
TimeUnit::Microseconds => Duration::add_us,
TimeUnit::Nanoseconds => Duration::add_ns,
};
datetime
.0
.try_apply_nonnull_values_generic(|v| offset_fn(offset, v, time_zone))
}
},
// A null offset maps every value to null.
_ => Ok(datetime.0.apply(|_| None)),
},
// NOTE(review): presumably the pre-change fallback (`try_binary_elementwise`);
// the block arm below calls `broadcast_try_binary_elementwise` instead. Confirm.
_ => try_binary_elementwise(datetime, offsets, |timestamp_opt, offset_opt| {
match (timestamp_opt, offset_opt) {
// General case: pair timestamps with offsets and parse each offset string
// per row, using the `Duration::add_*` function for the column's time unit.
_ => {
let offset_fn = match datetime.time_unit() {
TimeUnit::Milliseconds => Duration::add_ms,
TimeUnit::Microseconds => Duration::add_us,
TimeUnit::Nanoseconds => Duration::add_ns,
};
broadcast_try_binary_elementwise(datetime, offsets, |timestamp_opt, offset_opt| match (
timestamp_opt,
offset_opt,
) {
(Some(timestamp), Some(offset)) => {
offset_fn(&Duration::parse(offset), timestamp, time_zone).map(Some)
},
// A null on either side produces a null result.
_ => Ok(None),
}
}),
})
},
}
}

Expand All @@ -231,7 +250,7 @@ pub(super) fn date_offset(s: &[Series]) -> PolarsResult<Series> {
.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))
.unwrap();
let datetime = ts.datetime().unwrap();
let out = apply_offsets_to_datetime(datetime, offsets, Duration::add_ms, None)?;
let out = apply_offsets_to_datetime(datetime, offsets, None)?;
// sortedness is only guaranteed to be preserved if a constant offset is being added to every datetime
preserve_sortedness = match offsets.len() {
1 => offsets.get(0).is_some(),
Expand All @@ -244,21 +263,12 @@ pub(super) fn date_offset(s: &[Series]) -> PolarsResult<Series> {
DataType::Datetime(tu, tz) => {
let datetime = ts.datetime().unwrap();

let offset_fn = match tu {
TimeUnit::Nanoseconds => Duration::add_ns,
TimeUnit::Microseconds => Duration::add_us,
TimeUnit::Milliseconds => Duration::add_ms,
};

let out = match tz {
#[cfg(feature = "timezones")]
Some(ref tz) => apply_offsets_to_datetime(
datetime,
offsets,
offset_fn,
tz.parse::<Tz>().ok().as_ref(),
)?,
_ => apply_offsets_to_datetime(datetime, offsets, offset_fn, None)?,
Some(ref tz) => {
apply_offsets_to_datetime(datetime, offsets, tz.parse::<Tz>().ok().as_ref())?
},
_ => apply_offsets_to_datetime(datetime, offsets, None)?,
};
// Sortedness may not be preserved when crossing daylight savings time boundaries
// for calendar-aware durations.
Expand Down
114 changes: 114 additions & 0 deletions py-polars/tests/unit/operations/namespaces/temporal/test_offset_by.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from __future__ import annotations

from datetime import date, datetime
from typing import TYPE_CHECKING

import pytest

import polars as pl
from polars.testing import assert_series_equal

if TYPE_CHECKING:
from polars.type_aliases import TimeUnit


@pytest.mark.parametrize(
    ("inputs", "offset", "outputs"),
    [
        # Every case starts from the same two dates; only the offset string
        # and the expected shifted dates vary.
        ([date(2020, 1, 1), date(2020, 1, 2)], offset_str, shifted_dates)
        for offset_str, shifted_dates in [
            ("1d", [date(2020, 1, 2), date(2020, 1, 3)]),
            ("-1d", [date(2019, 12, 31), date(2020, 1, 1)]),
            ("3d", [date(2020, 1, 4), date(2020, 1, 5)]),
            ("72h", [date(2020, 1, 4), date(2020, 1, 5)]),
            ("2d24h", [date(2020, 1, 4), date(2020, 1, 5)]),
            ("-2mo", [date(2019, 11, 1), date(2019, 11, 2)]),
        ]
    ],
)
def test_date_offset_by(inputs: list[date], offset: str, outputs: list[date]) -> None:
    """Check that ``dt.offset_by(offset)`` on a series built from ``inputs`` equals ``outputs``."""
    shifted = pl.Series(inputs).dt.offset_by(offset)
    assert_series_equal(shifted, pl.Series(outputs))


@pytest.mark.parametrize(
    ("inputs", "offset", "outputs"),
    [
        # Every case starts from the same two dates; only the offset string
        # and the expected shifted values vary.
        ([date(2020, 1, 1), date(2020, 1, 2)], offset_str, shifted_values)
        for offset_str, shifted_values in [
            ("1d", [date(2020, 1, 2), date(2020, 1, 3)]),
            ("-1d", [date(2019, 12, 31), date(2020, 1, 1)]),
            ("3d", [date(2020, 1, 4), date(2020, 1, 5)]),
            ("72h", [date(2020, 1, 4), date(2020, 1, 5)]),
            ("2d24h", [date(2020, 1, 4), date(2020, 1, 5)]),
            ("7m", [datetime(2020, 1, 1, 0, 7), datetime(2020, 1, 2, 0, 7)]),
            ("-3m", [datetime(2019, 12, 31, 23, 57), datetime(2020, 1, 1, 23, 57)]),
            ("2mo", [datetime(2020, 3, 1), datetime(2020, 3, 2)]),
        ]
    ],
)
@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
@pytest.mark.parametrize("time_zone", ["Europe/London", "Asia/Kathmandu", None])
def test_datetime_offset_by(
    inputs: list[date],
    offset: str,
    outputs: list[datetime],
    time_unit: TimeUnit,
    time_zone: str | None,
) -> None:
    """Check ``dt.offset_by(offset)`` on a Datetime series for each time unit and time zone.

    Both the input and the expected series are built with the same
    ``pl.Datetime(time_unit, time_zone)`` dtype.
    """
    dtype = pl.Datetime(time_unit, time_zone)
    shifted = pl.Series(inputs, dtype=dtype).dt.offset_by(offset)
    assert_series_equal(shifted, pl.Series(outputs, dtype=dtype))

0 comments on commit c6ab549

Please sign in to comment.