From 7ebcfabddf3c8b40838533ad5bbe62c5273b991c Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Fri, 14 Jun 2024 11:44:54 +0200 Subject: [PATCH] feat(python): Support Decimal inputs for `lit` (#16950) --- crates/polars-expr/src/expressions/literal.rs | 4 +++ crates/polars-plan/src/logical_plan/lit.rs | 9 +++++++ py-polars/src/functions/lazy.rs | 4 +++ py-polars/src/lazyframe/visitor/expr_nodes.rs | 2 +- py-polars/tests/unit/functions/test_lit.py | 26 +++++++++++++++++++ 5 files changed, 44 insertions(+), 1 deletion(-) diff --git a/crates/polars-expr/src/expressions/literal.rs b/crates/polars-expr/src/expressions/literal.rs index 9e6427e78b0d..6b43825087a1 100644 --- a/crates/polars-expr/src/expressions/literal.rs +++ b/crates/polars-expr/src/expressions/literal.rs @@ -37,6 +37,10 @@ impl PhysicalExpr for LiteralExpr { UInt64(v) => UInt64Chunked::full(LITERAL_NAME, *v, 1).into_series(), Float32(v) => Float32Chunked::full(LITERAL_NAME, *v, 1).into_series(), Float64(v) => Float64Chunked::full(LITERAL_NAME, *v, 1).into_series(), + #[cfg(feature = "dtype-decimal")] + Decimal(v, scale) => Int128Chunked::full(LITERAL_NAME, *v, 1) + .into_decimal_unchecked(None, *scale) + .into_series(), Boolean(v) => BooleanChunked::full(LITERAL_NAME, *v, 1).into_series(), Null => polars_core::prelude::Series::new_null(LITERAL_NAME, 1), Range { diff --git a/crates/polars-plan/src/logical_plan/lit.rs b/crates/polars-plan/src/logical_plan/lit.rs index 1c2dedddafa4..c0dcab76d3c6 100644 --- a/crates/polars-plan/src/logical_plan/lit.rs +++ b/crates/polars-plan/src/logical_plan/lit.rs @@ -45,6 +45,9 @@ pub enum LiteralValue { Float32(f32), /// A 64-bit floating point number. Float64(f64), + /// A 128-bit decimal number with a maximum scale of 38. + #[cfg(feature = "dtype-decimal")] + Decimal(i128, usize), Range { low: i64, high: i64, @@ -121,6 +124,8 @@ impl LiteralValue { Int64(v) => AnyValue::Int64(*v), Float32(v) => AnyValue::Float32(*v), Float64(v) => AnyValue::Float64(*v), + #[cfg(feature = "dtype-decimal")] + Decimal(v, scale) => AnyValue::Decimal(*v, *scale), String(v) => AnyValue::String(v), #[cfg(feature = "dtype-duration")] Duration(v, tu) => AnyValue::Duration(*v, *tu), @@ -192,6 +197,8 @@ impl LiteralValue { LiteralValue::Int64(_) => DataType::Int64, LiteralValue::Float32(_) => DataType::Float32, LiteralValue::Float64(_) => DataType::Float64, + #[cfg(feature = "dtype-decimal")] + LiteralValue::Decimal(_, scale) => DataType::Decimal(None, Some(*scale)), LiteralValue::String(_) => DataType::String, LiteralValue::Binary(_) => DataType::Binary, LiteralValue::Range { data_type, .. } => data_type.clone(), @@ -276,6 +283,8 @@ impl TryFrom> for LiteralValue { AnyValue::Int64(i) => Ok(Self::Int64(i)), AnyValue::Float32(f) => Ok(Self::Float32(f)), AnyValue::Float64(f) => Ok(Self::Float64(f)), + #[cfg(feature = "dtype-decimal")] + AnyValue::Decimal(v, scale) => Ok(Self::Decimal(v, scale)), #[cfg(feature = "dtype-date")] AnyValue::Date(v) => Ok(LiteralValue::Date(v)), #[cfg(feature = "dtype-datetime")] diff --git a/py-polars/src/functions/lazy.rs b/py-polars/src/functions/lazy.rs index cc46b9f05120..8e241b701f8d 100644 --- a/py-polars/src/functions/lazy.rs +++ b/py-polars/src/functions/lazy.rs @@ -5,6 +5,7 @@ use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; use pyo3::types::{PyBool, PyBytes, PyFloat, PyInt, PyString}; +use crate::conversion::any_value::py_object_to_any_value; use crate::conversion::{get_lf, Wrap}; use crate::expr::ToExprs; use crate::map::lazy::binary_lambda; @@ -428,6 +429,9 @@ pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool) -> PyResult { Ok(dsl::lit(Null {}).into()) } else if let Ok(value) = value.downcast::() { Ok(dsl::lit(value.as_bytes()).into()) + } else if value.get_type().qualname().unwrap() == "Decimal" { + let av = py_object_to_any_value(value, true)?; + Ok(Expr::Literal(LiteralValue::try_from(av).unwrap()).into()) } else if allow_object { let s = Python::with_gil(|py| { PySeries::new_object(py, "", vec![ObjectValue::from(value.into_py(py))], false).series diff --git a/py-polars/src/lazyframe/visitor/expr_nodes.rs b/py-polars/src/lazyframe/visitor/expr_nodes.rs index 03cf03d40b77..51aed93d2fc1 100644 --- a/py-polars/src/lazyframe/visitor/expr_nodes.rs +++ b/py-polars/src/lazyframe/visitor/expr_nodes.rs @@ -556,7 +556,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { }, Binary(_) => return Err(PyNotImplementedError::new_err("binary literal")), Range { .. } => return Err(PyNotImplementedError::new_err("range literal")), - Date(..) | DateTime(..) => Literal { + Date(..) | DateTime(..) | Decimal(..) => Literal { value: Wrap(lit.to_any_value().unwrap()).to_object(py), dtype, }, diff --git a/py-polars/tests/unit/functions/test_lit.py b/py-polars/tests/unit/functions/test_lit.py index 79c7391048b6..b9964f857183 100644 --- a/py-polars/tests/unit/functions/test_lit.py +++ b/py-polars/tests/unit/functions/test_lit.py @@ -2,6 +2,7 @@ import enum from datetime import datetime, timedelta +from decimal import Decimal from typing import Any import numpy as np @@ -10,6 +11,7 @@ import polars as pl from polars.testing import assert_frame_equal +from polars.testing.parametric.strategies import series from polars.testing.parametric.strategies.data import datetimes @@ -155,3 +157,27 @@ def test_datetime_ms(value: datetime) -> None: result = pl.select(pl.lit(value, dtype=pl.Datetime("ms")))["literal"][0] expected_microsecond = value.microsecond // 1000 * 1000 assert result == value.replace(microsecond=expected_microsecond) + + +def test_lit_decimal() -> None: + value = Decimal("0.1") + + expr = pl.lit(value) + df = pl.select(expr) + result = df.item() + + assert df.dtypes[0] == pl.Decimal(None, 1) + assert result == value + + +@given(s=series(min_size=1, max_size=1, allow_null=False, allowed_dtypes=pl.Decimal)) +def test_lit_decimal_parametric(s: pl.Series) -> None: + scale = s.dtype.scale # type: ignore[attr-defined] + value = s.item() + + expr = pl.lit(value) + df = pl.select(expr) + result = df.item() + + assert df.dtypes[0] == pl.Decimal(None, scale) + assert result == value