From 33b7347735ffcf4a63e7a3c2ec7ff04e12b5883a Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 2 Jun 2024 18:40:11 +0100 Subject: [PATCH 1/2] perf: improve `truncate` performance when `every` is just a single duration (and not an expression) --- crates/polars-time/src/truncate.rs | 90 +++++++++++-------- .../namespaces/temporal/test_truncate.py | 53 +++++++++-- 2 files changed, 102 insertions(+), 41 deletions(-) diff --git a/crates/polars-time/src/truncate.rs b/crates/polars-time/src/truncate.rs index a095ff0628e4..7940f4e5a706 100644 --- a/crates/polars-time/src/truncate.rs +++ b/crates/polars-time/src/truncate.rs @@ -1,5 +1,5 @@ use arrow::legacy::time_zone::Tz; -use arrow::temporal_conversions::{MILLISECONDS, SECONDS_IN_DAY}; +use arrow::temporal_conversions::MILLISECONDS_IN_DAY; use polars_core::prelude::arity::broadcast_try_binary_elementwise; use polars_core::prelude::*; use polars_utils::cache::FastFixedCache; @@ -16,7 +16,6 @@ impl PolarsTruncate for DatetimeChunked { fn truncate(&self, tz: Option<&Tz>, every: &StringChunked, offset: &str) -> PolarsResult<Self> { let offset: Duration = Duration::parse(offset); let time_zone = self.time_zone(); - let mut duration_cache_opt: Option<FastFixedCache<String, Duration>> = None; // Let's check if we can use a fastpath... if every.len() == 1 { @@ -42,19 +41,28 @@ impl PolarsTruncate for DatetimeChunked { }) .into_datetime(self.time_unit(), time_zone.clone())); } else { - // A sqrt(n) cache is not too small, not too large. 
- duration_cache_opt = - Some(FastFixedCache::new((every.len() as f64).sqrt() as usize)); - duration_cache_opt - .as_mut() - .map(|cache| *cache.insert(every.to_string(), every_parsed)); + let w = Window::new(every_parsed, every_parsed, offset); + let out = match self.time_unit() { + TimeUnit::Milliseconds => { + self.try_apply_nonnull_values_generic(|t| w.truncate_ms(t, tz)) + }, + TimeUnit::Microseconds => { + self.try_apply_nonnull_values_generic(|t| w.truncate_us(t, tz)) + }, + TimeUnit::Nanoseconds => { + self.try_apply_nonnull_values_generic(|t| w.truncate_ns(t, tz)) + }, + }; + return Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone())); } + } else { + return Ok(Int64Chunked::full_null(self.name(), self.len()) + .into_datetime(self.time_unit(), self.time_zone().clone())); } } - let mut duration_cache = match duration_cache_opt { - Some(cache) => cache, - None => FastFixedCache::new((every.len() as f64).sqrt() as usize), - }; + + // A sqrt(n) cache is not too small, not too large. + let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize); let func = match self.time_unit() { TimeUnit::Nanoseconds => Window::truncate_ns, @@ -62,14 +70,6 @@ impl PolarsTruncate for DatetimeChunked { TimeUnit::Milliseconds => Window::truncate_ms, }; - // TODO: optimize the code below, so it does the following: - // - convert to naive - // - truncate all naively - // - localize, preserving the fold of the original datetime. - // The last step is the non-trivial one. But it should be worth it, - // and faster than the current approach of truncating everything - // as tz-aware. - let out = broadcast_try_binary_elementwise(self, every, |opt_timestamp, opt_every| match ( opt_timestamp, opt_every, @@ -99,26 +99,44 @@ impl PolarsTruncate for DateChunked { offset: &str, ) -> PolarsResult<Self> { let offset = Duration::parse(offset); - // A sqrt(n) cache is not too small, not too large. 
- let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize); - let out = broadcast_try_binary_elementwise(&self.0, every, |opt_t, opt_every| { - match (opt_t, opt_every) { - (Some(t), Some(every)) => { - const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY; - let every = - *duration_cache.get_or_insert_with(every, |every| Duration::parse(every)); + let out = match every.len() { + 1 => { + if let Some(every) = every.get(0) { + let every = Duration::parse(every); if every.negative { polars_bail!(ComputeError: "cannot truncate a Date to a negative duration") } - let w = Window::new(every, every, offset); - Ok(Some( - (w.truncate_ms(MSECS_IN_DAY * t as i64, None)? / MSECS_IN_DAY) as i32, - )) - }, - _ => Ok(None), - } - }); + self.try_apply_nonnull_values_generic(|t| { + Ok((w.truncate_ms(MILLISECONDS_IN_DAY * t as i64, None)? + / MILLISECONDS_IN_DAY) as i32) + }) + } else { + Ok(Int32Chunked::full_null(self.name(), self.len())) + } + }, + _ => broadcast_try_binary_elementwise(self, every, |opt_t, opt_every| { + // A sqrt(n) cache is not too small, not too large. + let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize); + match (opt_t, opt_every) { + (Some(t), Some(every)) => { + let every = *duration_cache + .get_or_insert_with(every, |every| Duration::parse(every)); + + if every.negative { + polars_bail!(ComputeError: "cannot truncate a Date to a negative duration") + } + + let w = Window::new(every, every, offset); + Ok(Some( + (w.truncate_ms(MILLISECONDS_IN_DAY * t as i64, None)? 
+ / MILLISECONDS_IN_DAY) as i32, + )) + }, + _ => Ok(None), + } + }), + }; Ok(out?.into_date()) } } diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py b/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py index 6e684ce130ad..c750fa537d18 100644 --- a/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py @@ -1,24 +1,67 @@ -import datetime as dt +from datetime import date, datetime import hypothesis.strategies as st from hypothesis import given import polars as pl +from polars.testing import assert_series_equal @given( value=st.datetimes( - min_value=dt.datetime(1000, 1, 1), - max_value=dt.datetime(3000, 1, 1), + min_value=datetime(1000, 1, 1), + max_value=datetime(3000, 1, 1), ), n=st.integers(min_value=1, max_value=100), ) -def test_truncate_monthly(value: dt.date, n: int) -> None: +def test_truncate_monthly(value: date, n: int) -> None: result = pl.Series([value]).dt.truncate(f"{n}mo").item() # manual calculation total = value.year * 12 + value.month - 1 remainder = total % n total -= remainder year, month = (total // 12), ((total % 12) + 1) - expected = dt.datetime(year, month, 1) + expected = datetime(year, month, 1) assert result == expected + + +def test_truncate_date() -> None: + # n vs n + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]} + ) + result = df.select(pl.col("a").dt.truncate(pl.col("b")))["a"] + expected = pl.Series("a", [None, None, date(2020, 1, 1)]) + assert_series_equal(result, expected) + + # n vs 1 + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]} + ) + result = df.select(pl.col("a").dt.truncate("1mo"))["a"] + expected = pl.Series("a", [date(2020, 1, 1), None, date(2020, 1, 1)]) + assert_series_equal(result, expected) + + # n vs missing + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], 
"b": [None, "1mo", "1mo"]} + ) + result = df.select(pl.col("a").dt.truncate(pl.lit(None, dtype=pl.String)))["a"] + expected = pl.Series("a", [None, None, None], dtype=pl.Date) + assert_series_equal(result, expected) + + # 1 vs n + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]} + ) + result = df.select(a=pl.date(2020, 1, 1).dt.truncate(pl.col("b")))["a"] + expected = pl.Series("a", [None, date(2020, 1, 1), date(2020, 1, 1)]) + assert_series_equal(result, expected) + + # missing vs n + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]} + ) + result = df.select(a=pl.lit(None, dtype=pl.Date).dt.truncate(pl.col("b")))["a"] + expected = pl.Series("a", [None, None, None], dtype=pl.Date) + assert_series_equal(result, expected) From 61f5d18518f6050dd8af4ceeba8cc84d7382241b Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 3 Jun 2024 09:30:44 +0100 Subject: [PATCH 2/2] extra test coverage for good measure --- .../namespaces/temporal/test_truncate.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py b/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py index c750fa537d18..03e19b453048 100644 --- a/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py @@ -1,11 +1,18 @@ +from __future__ import annotations + from datetime import date, datetime +from typing import TYPE_CHECKING import hypothesis.strategies as st +import pytest from hypothesis import given import polars as pl from polars.testing import assert_series_equal +if TYPE_CHECKING: + from polars.type_aliases import TimeUnit + @given( value=st.datetimes( @@ -65,3 +72,23 @@ def test_truncate_date() -> None: result = df.select(a=pl.lit(None, 
dtype=pl.Date).dt.truncate(pl.col("b")))["a"] expected = pl.Series("a", [None, None, None], dtype=pl.Date) assert_series_equal(result, expected) + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_truncate_datetime_simple(time_unit: TimeUnit) -> None: + s = pl.Series([datetime(2020, 1, 2, 6)], dtype=pl.Datetime(time_unit)) + result = s.dt.truncate("1mo").item() + assert result == datetime(2020, 1, 1) + result = s.dt.truncate("1d").item() + assert result == datetime(2020, 1, 2) + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_truncate_datetime_w_expression(time_unit: TimeUnit) -> None: + df = pl.DataFrame( + {"a": [datetime(2020, 1, 2, 6), datetime(2020, 1, 3, 7)], "b": ["1mo", "1d"]}, + schema_overrides={"a": pl.Datetime(time_unit)}, + ) + result = df.select(pl.col("a").dt.truncate(pl.col("b")))["a"] + assert result[0] == datetime(2020, 1, 1) + assert result[1] == datetime(2020, 1, 3)