Skip to content

Commit

Permalink
perf: Add fastpath for when rounding by single constant durations (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jul 13, 2024
1 parent ddc3f46 commit f2fed3b
Show file tree
Hide file tree
Showing 4 changed files with 300 additions and 127 deletions.
143 changes: 107 additions & 36 deletions crates/polars-time/src/round.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use arrow::legacy::time_zone::Tz;
use arrow::temporal_conversions::{MILLISECONDS, SECONDS_IN_DAY};
use arrow::temporal_conversions::MILLISECONDS_IN_DAY;
use polars_core::prelude::arity::broadcast_try_binary_elementwise;
use polars_core::prelude::*;
use polars_utils::cache::FastFixedCache;
Expand All @@ -14,56 +14,127 @@ pub trait PolarsRound {

impl PolarsRound for DatetimeChunked {
fn round(&self, every: &StringChunked, tz: Option<&Tz>) -> PolarsResult<Self> {
let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);
let time_zone = self.time_zone();
let offset = Duration::new(0);
let out = broadcast_try_binary_elementwise(self, every, |opt_t, opt_every| {
match (opt_t, opt_every) {
(Some(timestamp), Some(every)) => {
let every =
*duration_cache.get_or_insert_with(every, |every| Duration::parse(every));

if every.negative {
polars_bail!(ComputeError: "Cannot round a Datetime to a negative duration")
}

let w = Window::new(every, every, offset);

let func = match self.time_unit() {
TimeUnit::Nanoseconds => Window::round_ns,
TimeUnit::Microseconds => Window::round_us,
TimeUnit::Milliseconds => Window::round_ms,
// Let's check if we can use a fastpath...
if every.len() == 1 {
if let Some(every) = every.get(0) {
let every_parsed = Duration::parse(every);
if every_parsed.negative {
polars_bail!(ComputeError: "cannot round a Datetime to a negative duration")
}
if (time_zone.is_none() || time_zone.as_deref() == Some("UTC"))
&& (every_parsed.months() == 0 && every_parsed.weeks() == 0)
{
// ... yes we can! Weeks, months, and time zones require extra logic.
// But in this simple case, it's just simple integer arithmetic.
let every = match self.time_unit() {
TimeUnit::Milliseconds => every_parsed.duration_ms(),
TimeUnit::Microseconds => every_parsed.duration_us(),
TimeUnit::Nanoseconds => every_parsed.duration_ns(),
};
return Ok(self
.apply_values(|t| {
// Round half-way values away from zero
let half_away = t.signum() * every / 2;
t + half_away - (t + half_away) % every
})
.into_datetime(self.time_unit(), time_zone.clone()));
} else {
let w = Window::new(every_parsed, every_parsed, offset);
let out = match self.time_unit() {
TimeUnit::Milliseconds => {
self.try_apply_nonnull_values_generic(|t| w.round_ms(t, tz))
},
TimeUnit::Microseconds => {
self.try_apply_nonnull_values_generic(|t| w.round_us(t, tz))
},
TimeUnit::Nanoseconds => {
self.try_apply_nonnull_values_generic(|t| w.round_ns(t, tz))
},
};
func(&w, timestamp, tz).map(Some)
},
_ => Ok(None),
return Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone()));
}
} else {
return Ok(Int64Chunked::full_null(self.name(), self.len())
.into_datetime(self.time_unit(), self.time_zone().clone()));
}
}

// A sqrt(n) cache is not too small, not too large.
let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);

let func = match self.time_unit() {
TimeUnit::Nanoseconds => Window::round_ns,
TimeUnit::Microseconds => Window::round_us,
TimeUnit::Milliseconds => Window::round_ms,
};

let out = broadcast_try_binary_elementwise(self, every, |opt_timestamp, opt_every| match (
opt_timestamp,
opt_every,
) {
(Some(timestamp), Some(every)) => {
let every =
*duration_cache.get_or_insert_with(every, |every| Duration::parse(every));

if every.negative {
polars_bail!(ComputeError: "cannot round a Datetime to a negative duration")
}

let w = Window::new(every, every, offset);
func(&w, timestamp, tz).map(Some)
},
_ => Ok(None),
});
Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone()))
}
}

impl PolarsRound for DateChunked {
fn round(&self, every: &StringChunked, _tz: Option<&Tz>) -> PolarsResult<Self> {
let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);
let offset = Duration::new(0);
const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY;
let out = broadcast_try_binary_elementwise(&self.0, every, |opt_t, opt_every| {
match (opt_t, opt_every) {
(Some(t), Some(every)) => {
let every =
*duration_cache.get_or_insert_with(every, |every| Duration::parse(every));
let out = match every.len() {
1 => {
if let Some(every) = every.get(0) {
let every = Duration::parse(every);
if every.negative {
polars_bail!(ComputeError: "Cannot round a Date to a negative duration")
polars_bail!(ComputeError: "cannot round a Date to a negative duration")
}

let w = Window::new(every, every, offset);
Ok(Some(
(w.round_ms(MSECS_IN_DAY * t as i64, None)? / MSECS_IN_DAY) as i32,
))
},
_ => Ok(None),
}
});
self.try_apply_nonnull_values_generic(|t| {
Ok(
(w.round_ms(MILLISECONDS_IN_DAY * t as i64, None)?
/ MILLISECONDS_IN_DAY) as i32,
)
})
} else {
Ok(Int32Chunked::full_null(self.name(), self.len()))
}
},
_ => broadcast_try_binary_elementwise(self, every, |opt_t, opt_every| {
// A sqrt(n) cache is not too small, not too large.
let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);
match (opt_t, opt_every) {
(Some(t), Some(every)) => {
let every = *duration_cache
.get_or_insert_with(every, |every| Duration::parse(every));

if every.negative {
polars_bail!(ComputeError: "cannot round a Date to a negative duration")
}

let w = Window::new(every, every, offset);
Ok(Some(
(w.round_ms(MILLISECONDS_IN_DAY * t as i64, None)?
/ MILLISECONDS_IN_DAY) as i32,
))
},
_ => Ok(None),
}
}),
};
Ok(out?.into_date())
}
}
89 changes: 0 additions & 89 deletions py-polars/tests/unit/datatypes/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2042,95 +2042,6 @@ def test_truncate_non_existent_14957() -> None:
).dt.truncate("46m")


def test_round_ambiguous() -> None:
    """Check `dt.round("30m")` on Europe/London datetimes from 2020-10-25.

    The expected values carry both a +0100 and a +0000 offset, so the local
    times 01:00 and 01:30 each appear twice in the output.
    """

    def london_half_hours() -> pl.Series:
        # Eager half-hourly range from midnight to 02:00 on 2020-10-25.
        return pl.datetime_range(
            date(2020, 10, 25),
            datetime(2020, 10, 25, 2),
            "30m",
            eager=True,
            time_zone="Europe/London",
        )

    # Series API: shift by a quarter of an hour, then round to the half hour.
    shifted = london_half_hours().alias("datetime").dt.offset_by("15m")
    rounded = shifted.dt.round("30m")

    iso_strings = [
        "2020-10-25T00:30:00+0100",
        "2020-10-25T01:00:00+0100",
        "2020-10-25T01:30:00+0100",
        "2020-10-25T01:00:00+0000",
        "2020-10-25T01:30:00+0000",
        "2020-10-25T02:00:00+0000",
        "2020-10-25T02:30:00+0000",
    ]
    parsed = pl.Series(iso_strings).str.to_datetime()
    expected = parsed.dt.convert_time_zone("Europe/London").rename("datetime")
    assert_series_equal(rounded, expected)

    # Expression API: the same rounding through `DataFrame.select`.
    frame = pl.DataFrame({"date": london_half_hours().dt.offset_by("15m")})
    rounded_frame = frame.select(pl.col("date").dt.round("30m"))

    london = ZoneInfo("Europe/London")
    expected_rows = {
        "date": [
            datetime(2020, 10, 25, 0, 30, tzinfo=london),
            datetime(2020, 10, 25, 1, tzinfo=london),
            datetime(2020, 10, 25, 1, 30, tzinfo=london),
            datetime(2020, 10, 25, 1, tzinfo=london),
            datetime(2020, 10, 25, 1, 30, tzinfo=london),
            datetime(2020, 10, 25, 2, tzinfo=london),
            datetime(2020, 10, 25, 2, 30, tzinfo=london),
        ]
    }
    assert rounded_frame.to_dict(as_series=False) == expected_rows


def test_round_by_week() -> None:
    """Compare rounding dates by "7d" with rounding by "1w"."""
    # Sunday and Monday
    raw_dates = ["1998-04-12", "2022-11-28"]
    frame = pl.DataFrame(
        {"date": pl.Series(raw_dates).str.strptime(pl.Date, "%Y-%m-%d")}
    )

    date_col = pl.col("date")
    rounded = frame.select(
        date_col.dt.round("7d").alias("7d"),
        date_col.dt.round("1w").alias("1w"),
    )

    # The two spellings are expected to land on different dates.
    expected = {
        "7d": [date(1998, 4, 9), date(2022, 12, 1)],
        "1w": [date(1998, 4, 13), date(2022, 11, 28)],
    }
    assert rounded.to_dict(as_series=False) == expected


@pytest.mark.parametrize("time_zone", [None, "Asia/Kathmandu"])
def test_round_by_day_datetime(time_zone: str | None) -> None:
    """Rounding 03:00 by "1d" gives midnight of the same day, with or without a time zone."""

    def localized(value: datetime) -> pl.Series:
        # One-element series carrying the parametrized time zone.
        return pl.Series([value]).dt.replace_time_zone(time_zone)

    rounded = localized(datetime(2021, 11, 7, 3)).dt.round("1d")
    assert_series_equal(rounded, localized(datetime(2021, 11, 7)))


def test_cast_time_to_duration() -> None:
assert pl.Series([time(hour=0, minute=0, second=2)]).cast(
pl.Duration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -612,12 +612,12 @@ def test_round_expr() -> None:
def test_round_negative() -> None:
    """Test that rounding to a negative duration gives a helpful error message."""
    # `match` is applied as a case-sensitive regex search against the error
    # text, so each pattern is given exactly once and in lowercase.
    with pytest.raises(
        ComputeError, match="cannot round a Date to a negative duration"
    ):
        pl.Series([date(1895, 5, 7)]).dt.round("-1m")

    with pytest.raises(
        ComputeError, match="cannot round a Datetime to a negative duration"
    ):
        pl.Series([datetime(1895, 5, 7)]).dt.round("-1m")

Expand Down
Loading

0 comments on commit f2fed3b

Please sign in to comment.