diff --git a/py-polars/polars/datatypes/_parse.py b/py-polars/polars/datatypes/_parse.py index e7ac78cae6dd..3a5424bc8b59 100644 --- a/py-polars/polars/datatypes/_parse.py +++ b/py-polars/polars/datatypes/_parse.py @@ -1,10 +1,12 @@ from __future__ import annotations +import enum import functools import re import sys from datetime import date, datetime, time, timedelta from decimal import Decimal as PyDecimal +from inspect import isclass from typing import TYPE_CHECKING, Any, ForwardRef, NoReturn, Union, get_args from polars.datatypes.classes import ( @@ -14,6 +16,7 @@ Datetime, Decimal, Duration, + Enum, Float64, Int64, List, @@ -94,6 +97,8 @@ def parse_py_type_into_dtype(input: PythonDataType | type[object]) -> PolarsData return Null() elif input is list or input is tuple: return List + elif isclass(input) and issubclass(input, enum.Enum): + return Enum(input) # this is required as pass through. Don't remove elif input == Unknown: return Unknown diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index 5543f629a620..bc60c3d6bad1 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import enum from collections import OrderedDict from collections.abc import Mapping from datetime import timezone @@ -596,7 +597,7 @@ class Enum(DataType): categories: Series - def __init__(self, categories: Series | Iterable[str]) -> None: + def __init__(self, categories: Series | Iterable[str] | type[enum.Enum]) -> None: # Issuing the warning on `__init__` does not trigger when the class is used # without being instantiated, but it's better than nothing from polars._utils.unstable import issue_unstable_warning @@ -606,7 +607,9 @@ def __init__(self, categories: Series | Iterable[str]) -> None: " It is a work-in-progress feature and may not always work as expected." ) - if not isinstance(categories, pl.Series): + if isclass(categories) and issubclass(categories, enum.Enum): + categories = pl.Series(values=categories.__members__.values()) + elif not isinstance(categories, pl.Series): categories = pl.Series(values=categories) if categories.is_empty(): diff --git a/py-polars/tests/unit/constructors/test_dataframe.py b/py-polars/tests/unit/constructors/test_dataframe.py index e885919294d1..dd5ae1aba896 100644 --- a/py-polars/tests/unit/constructors/test_dataframe.py +++ b/py-polars/tests/unit/constructors/test_dataframe.py @@ -1,5 +1,6 @@ from __future__ import annotations +import enum import sys from collections import OrderedDict from collections.abc import Mapping @@ -194,3 +195,13 @@ def test_df_init_schema_object() -> None: def test_df_init_data_orientation_inference_warning() -> None: with pytest.warns(DataOrientationWarning): pl.from_records([[1, 2, 3], [4, 5, 6]], schema=["a", "b", "c"]) + + +def test_df_init_enum_dtype() -> None: + class PythonEnum(str, enum.Enum): + A = "A" + B = "B" + C = "C" + + df = pl.DataFrame({"Col 1": ["A", "B", "C"]}, schema={"Col 1": PythonEnum}) + assert df.dtypes[0] == pl.Enum(["A", "B", "C"]) diff --git a/py-polars/tests/unit/datatypes/test_enum.py b/py-polars/tests/unit/datatypes/test_enum.py index 9bd4a49ddd1d..bc5a9370a222 100644 --- a/py-polars/tests/unit/datatypes/test_enum.py +++ b/py-polars/tests/unit/datatypes/test_enum.py @@ -1,5 +1,6 @@ from __future__ import annotations +import enum import operator import re from datetime import date @@ -41,6 +42,26 @@ def test_enum_init_empty(categories: pl.Series | list[str] | None) -> None: assert_series_equal(dtype.categories, expected) +def test_enum_init_python_enum_19724() -> None: + class PythonEnum(str, enum.Enum): + CAT1 = "A" + CAT2 = "B" + CAT3 = "C" + + result = pl.Enum(PythonEnum) + assert result == pl.Enum(["A", "B", "C"]) + + +def test_enum_init_python_enum_ints_19724() -> None: + class PythonEnum(int, enum.Enum): + CAT1 = 1 + CAT2 = 2 + CAT3 = 3 + + with pytest.raises(TypeError, match="Enum categories must be strings"): + pl.Enum(PythonEnum) + + def test_enum_non_existent() -> None: with pytest.raises( InvalidOperationError, diff --git a/py-polars/tests/unit/datatypes/test_parse.py b/py-polars/tests/unit/datatypes/test_parse.py index 0979292e8e50..017e7dd82b93 100644 --- a/py-polars/tests/unit/datatypes/test_parse.py +++ b/py-polars/tests/unit/datatypes/test_parse.py @@ -1,5 +1,6 @@ from __future__ import annotations +import enum from datetime import date, datetime from typing import ( TYPE_CHECKING, @@ -43,6 +44,27 @@ def test_parse_into_dtype(input: Any, expected: PolarsDataType) -> None: assert_dtype_equal(result, expected) +def test_parse_into_dtype_enum_19724() -> None: + class PythonEnum(str, enum.Enum): + CAT1 = "A" + CAT2 = "B" + CAT3 = "C" + + result = parse_into_dtype(PythonEnum) + expected = pl.Enum(["A", "B", "C"]) + assert_dtype_equal(result, expected) + + +def test_parse_into_dtype_enum_ints_19724() -> None: + class PythonEnum(int, enum.Enum): + CAT1 = 1 + CAT2 = 2 + CAT3 = 3 + + with pytest.raises(TypeError, match="Enum categories must be strings"): + parse_into_dtype(PythonEnum) + + @pytest.mark.parametrize( ("input", "expected"), [