Skip to content

Commit

Permalink
fix(python): Convert date and datetime in literal construction (pola-…
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller authored Aug 15, 2024
1 parent a739825 commit 5e5506c
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 16 deletions.
50 changes: 36 additions & 14 deletions py-polars/polars/functions/lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

import polars._reexport as pl
from polars._utils.convert import (
date_to_int,
datetime_to_int,
time_to_int,
timedelta_to_int,
)
Expand Down Expand Up @@ -78,25 +76,40 @@ def lit(
time_unit: TimeUnit

if isinstance(value, datetime):
if dtype == Date:
return wrap_expr(plr.lit(value.date(), allow_object=False))

# parse time unit
if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None:
time_unit = tu # type: ignore[assignment]
else:
time_unit = "us"

time_zone: str | None = getattr(dtype, "time_zone", None)
if (tzinfo := value.tzinfo) is not None:
tzinfo_str = str(tzinfo)
if time_zone is not None and time_zone != tzinfo_str:
msg = f"time zone of dtype ({time_zone!r}) differs from time zone of value ({tzinfo!r})"
# parse time zone
dtype_tz = getattr(dtype, "time_zone", None)
value_tz = value.tzinfo
if value_tz is None:
tz = dtype_tz
else:
if dtype_tz is None:
# value has time zone, but dtype does not: keep value time zone
tz = str(value_tz)
elif str(value_tz) == dtype_tz:
# dtype and value both have same time zone
tz = str(value_tz)
else:
# value has time zone that differs from dtype time zone
msg = (
f"time zone of dtype ({dtype_tz!r}) differs from time zone of "
f"value ({value_tz!r})"
)
raise TypeError(msg)
time_zone = tzinfo_str

dt_utc = value.replace(tzinfo=timezone.utc)
dt_int = datetime_to_int(dt_utc, time_unit)
expr = lit(dt_int).cast(Datetime(time_unit))
if time_zone is not None:
expr = wrap_expr(plr.lit(dt_utc, allow_object=False)).cast(Datetime(time_unit))
if tz is not None:
expr = expr.dt.replace_time_zone(
time_zone, ambiguous="earliest" if value.fold == 0 else "latest"
tz, ambiguous="earliest" if value.fold == 0 else "latest"
)
return expr

Expand All @@ -114,8 +127,17 @@ def lit(
return lit(time_int).cast(Time)

elif isinstance(value, date):
date_int = date_to_int(value)
return lit(date_int).cast(Date)
if dtype == Datetime:
time_unit = getattr(dtype, "time_unit", "us") or "us"
dt_utc = datetime(value.year, value.month, value.day)
expr = wrap_expr(plr.lit(dt_utc, allow_object=False)).cast(
Datetime(time_unit)
)
if (time_zone := getattr(dtype, "time_zone", None)) is not None:
expr = expr.dt.replace_time_zone(str(time_zone))
return expr
else:
return wrap_expr(plr.lit(value, allow_object=False))

elif isinstance(value, pl.Series):
value = value._s
Expand Down
5 changes: 4 additions & 1 deletion py-polars/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,10 @@ pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool) -> PyResult<PyExpr> {
Ok(dsl::lit(Null {}).into())
} else if let Ok(value) = value.downcast::<PyBytes>() {
Ok(dsl::lit(value.as_bytes()).into())
} else if value.get_type().qualname().unwrap() == "Decimal" {
} else if matches!(
value.get_type().qualname().unwrap().as_str(),
"date" | "datetime" | "Decimal"
) {
let av = py_object_to_any_value(value, true)?;
Ok(Expr::Literal(LiteralValue::try_from(av).unwrap()).into())
} else if allow_object {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections import OrderedDict
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING

Expand All @@ -15,7 +16,7 @@
if TYPE_CHECKING:
from zoneinfo import ZoneInfo

from polars._typing import TemporalLiteral, TimeUnit
from polars._typing import PolarsDataType, TemporalLiteral, TimeUnit
else:
from polars._utils.convert import string_to_zoneinfo as ZoneInfo

Expand Down Expand Up @@ -1350,3 +1351,79 @@ def test_dt_mean_deprecated() -> None:
with pytest.deprecated_call():
result = s.dt.mean()
assert result == s.mean()


@pytest.mark.parametrize(
    "dtype",
    [
        pl.Date,
        pl.Datetime("ms"),
        pl.Datetime("ms", "EST"),
        pl.Datetime("us"),
        pl.Datetime("us", "EST"),
        pl.Datetime("ns"),
        pl.Datetime("ns", "EST"),
    ],
)
@pytest.mark.parametrize(
    "value",
    [
        date(1677, 9, 22),
        date(1970, 1, 1),
        date(2024, 2, 29),
        date(2262, 4, 11),
    ],
)
def test_literal_from_date(
    value: date,
    dtype: PolarsDataType,
) -> None:
    """Check that a `date` literal built with an explicit dtype round-trips.

    The resulting single-column frame must carry exactly the requested dtype.
    For Datetime dtypes the expected item is midnight of the given date,
    made time-zone aware when the dtype specifies a time zone.
    """
    result = pl.select(pl.lit(value, dtype=dtype))
    assert result.schema == OrderedDict({"literal": dtype})

    expected: date = value
    if dtype == pl.Datetime:
        zone_name = dtype.time_zone  # type: ignore[union-attr]
        zone = None if zone_name is None else ZoneInfo(zone_name)
        # a date promoted to Datetime is compared against midnight of that day
        expected = datetime(value.year, value.month, value.day, tzinfo=zone)

    assert result.item() == expected


@pytest.mark.parametrize(
    "dtype",
    [
        pl.Date,
        pl.Datetime("ms"),
        pl.Datetime("ms", "EST"),
        pl.Datetime("us"),
        pl.Datetime("us", "EST"),
        pl.Datetime("ns"),
        pl.Datetime("ns", "EST"),
    ],
)
@pytest.mark.parametrize(
    "value",
    [
        datetime(1677, 9, 22),
        datetime(1677, 9, 22, tzinfo=ZoneInfo("EST")),
        datetime(1970, 1, 1),
        datetime(1970, 1, 1, tzinfo=ZoneInfo("EST")),
        datetime(2024, 2, 29),
        datetime(2024, 2, 29, tzinfo=ZoneInfo("EST")),
        datetime(2262, 4, 11),
        datetime(2262, 4, 11, tzinfo=ZoneInfo("EST")),
    ],
)
def test_literal_from_datetime(
    value: datetime,
    dtype: pl.Date | pl.Datetime,
) -> None:
    """Check that a `datetime` literal built with an explicit dtype round-trips.

    Expectations per case:
    - Date dtype: the item is the date part of the value.
    - Datetime dtype without a time zone, aware value: the schema picks up
      the value's time zone.
    - Datetime dtype with a time zone, naive value: the item is the value
      with the dtype's time zone attached.
    - Otherwise both schema and item are compared unchanged.
    """
    result = pl.select(pl.lit(value, dtype=dtype))

    expected_value: date = value
    expected_dtype = dtype
    if dtype == pl.Date:
        expected_value = value.date()
    else:
        dtype_zone = dtype.time_zone  # type: ignore[union-attr]
        value_zone = value.tzinfo
        if dtype_zone is None and value_zone is not None:
            # the resulting dtype takes the time zone supplied by the value
            expected_dtype = pl.Datetime(dtype.time_unit, str(value_zone))  # type: ignore[union-attr]
        elif dtype_zone is not None and value_zone is None:
            # naive value combined with a tz-aware dtype: attach the dtype's zone
            expected_value = value.replace(tzinfo=ZoneInfo(dtype_zone))

    assert result.schema == OrderedDict({"literal": expected_dtype})
    assert result.item() == expected_value

0 comments on commit 5e5506c

Please sign in to comment.