Skip to content

Commit

Permalink
Add strict parameter to lit
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Aug 18, 2024
1 parent 1dc2533 commit d5ef03d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
12 changes: 10 additions & 2 deletions py-polars/polars/functions/lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@


def lit(
value: Any, dtype: PolarsDataType | None = None, *, allow_object: bool = False
value: Any,
dtype: PolarsDataType | None = None,
*,
allow_object: bool = False,
strict: bool = True,
) -> Expr:
"""
Return an expression representing a literal value.
Expand All @@ -41,6 +45,10 @@ def lit(
If type is unknown use an 'object' type.
By default, we will raise a `ValueException`
if the type is unknown.
strict : bool, default True
Throw an error if any value does not exactly match the given or inferred data
type. If set to `False`, values that do not match the data type are cast to
that data type or, if casting is not possible, set to null instead.
Notes
-----
Expand Down Expand Up @@ -147,7 +155,7 @@ def lit(
return lit(pl.Series("literal", value, dtype=dtype))

elif isinstance(value, (list, tuple)):
return lit(pl.Series("literal", [value], dtype=dtype))
return lit(pl.Series("literal", [value], dtype=dtype, strict=strict))

elif isinstance(value, enum.Enum):
lit_value = value.value
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/unit/expr/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,17 @@ def lit_series(value: Any, dtype: PolarsDataType | None) -> pl.Series:
)


def test_lit_strict() -> None:
# Default is strict = False
with pytest.raises(TypeError, match="unexpected value while building Series"):
pl.select(pl.lit([1, 1.0]))

assert_series_equal(
pl.select(pl.lit([1, 1.0], strict=False).alias("a")).to_series(),
pl.Series("a", [[1.0, 1.0]]),
)


def test_lit_empty_tu() -> None:
td = timedelta(1)
assert pl.select(pl.lit(td, dtype=pl.Duration)).item() == td
Expand Down

0 comments on commit d5ef03d

Please sign in to comment.