Skip to content

Commit

Permalink
feat(python): Allow Python Enums as dtype inputs (#19926)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Nov 22, 2024
1 parent 4af1c43 commit 414d883
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 2 deletions.
5 changes: 5 additions & 0 deletions py-polars/polars/datatypes/_parse.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import enum
import functools
import re
import sys
from datetime import date, datetime, time, timedelta
from decimal import Decimal as PyDecimal
from inspect import isclass
from typing import TYPE_CHECKING, Any, ForwardRef, NoReturn, Union, get_args

from polars.datatypes.classes import (
Expand All @@ -14,6 +16,7 @@
Datetime,
Decimal,
Duration,
Enum,
Float64,
Int64,
List,
Expand Down Expand Up @@ -94,6 +97,8 @@ def parse_py_type_into_dtype(input: PythonDataType | type[object]) -> PolarsData
return Null()
elif input is list or input is tuple:
return List
elif isclass(input) and issubclass(input, enum.Enum):
return Enum(input)
# this is required as pass through. Don't remove
elif input == Unknown:
return Unknown
Expand Down
7 changes: 5 additions & 2 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import enum
from collections import OrderedDict
from collections.abc import Mapping
from datetime import timezone
Expand Down Expand Up @@ -596,7 +597,7 @@ class Enum(DataType):

categories: Series

def __init__(self, categories: Series | Iterable[str]) -> None:
def __init__(self, categories: Series | Iterable[str] | type[enum.Enum]) -> None:
# Issuing the warning on `__init__` does not trigger when the class is used
# without being instantiated, but it's better than nothing
from polars._utils.unstable import issue_unstable_warning
Expand All @@ -606,7 +607,9 @@ def __init__(self, categories: Series | Iterable[str]) -> None:
" It is a work-in-progress feature and may not always work as expected."
)

if not isinstance(categories, pl.Series):
if isclass(categories) and issubclass(categories, enum.Enum):
categories = pl.Series(values=categories.__members__.values())
elif not isinstance(categories, pl.Series):
categories = pl.Series(values=categories)

if categories.is_empty():
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/unit/constructors/test_dataframe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
import sys
from collections import OrderedDict
from collections.abc import Mapping
Expand Down Expand Up @@ -194,3 +195,13 @@ def test_df_init_schema_object() -> None:
def test_df_init_data_orientation_inference_warning() -> None:
with pytest.warns(DataOrientationWarning):
pl.from_records([[1, 2, 3], [4, 5, 6]], schema=["a", "b", "c"])


def test_df_init_enum_dtype() -> None:
    """A Python enum class given as a schema dtype yields a Polars Enum column."""

    class Letters(str, enum.Enum):
        A = "A"
        B = "B"
        C = "C"

    # The enum class itself (not an instance) is passed as the column dtype.
    schema = {"Col 1": Letters}
    frame = pl.DataFrame({"Col 1": ["A", "B", "C"]}, schema=schema)

    first_dtype = frame.dtypes[0]
    assert first_dtype == pl.Enum(["A", "B", "C"])
21 changes: 21 additions & 0 deletions py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
import operator
import re
from datetime import date
Expand Down Expand Up @@ -41,6 +42,26 @@ def test_enum_init_empty(categories: pl.Series | list[str] | None) -> None:
assert_series_equal(dtype.categories, expected)


def test_enum_init_python_enum_19724() -> None:
    """`pl.Enum` built from a str-valued Python enum equals one built from its values."""

    class Categories(str, enum.Enum):
        CAT1 = "A"
        CAT2 = "B"
        CAT3 = "C"

    # Member names (CAT1..CAT3) differ from the values on purpose: the
    # comparison below is against the values "A", "B", "C".
    from_python_enum = pl.Enum(Categories)
    from_strings = pl.Enum(["A", "B", "C"])
    assert from_python_enum == from_strings


def test_enum_init_python_enum_ints_19724() -> None:
    """Constructing `pl.Enum` from an int-valued Python enum raises a TypeError."""

    class IntCategories(int, enum.Enum):
        CAT1 = 1
        CAT2 = 2
        CAT3 = 3

    message_pattern = "Enum categories must be strings"
    with pytest.raises(TypeError, match=message_pattern):
        pl.Enum(IntCategories)


def test_enum_non_existent() -> None:
with pytest.raises(
InvalidOperationError,
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/datatypes/test_parse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
from datetime import date, datetime
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -43,6 +44,27 @@ def test_parse_into_dtype(input: Any, expected: PolarsDataType) -> None:
assert_dtype_equal(result, expected)


def test_parse_into_dtype_enum_19724() -> None:
    """`parse_into_dtype` maps a str-valued Python enum class to a Polars Enum."""

    class Categories(str, enum.Enum):
        CAT1 = "A"
        CAT2 = "B"
        CAT3 = "C"

    # The expected dtype is spelled out from the enum's values, not its names.
    parsed = parse_into_dtype(Categories)
    assert_dtype_equal(parsed, pl.Enum(["A", "B", "C"]))


def test_parse_into_dtype_enum_ints_19724() -> None:
    """Parsing an int-valued Python enum class as a dtype raises a TypeError."""

    class IntCategories(int, enum.Enum):
        CAT1 = 1
        CAT2 = 2
        CAT3 = 3

    message_pattern = "Enum categories must be strings"
    with pytest.raises(TypeError, match=message_pattern):
        parse_into_dtype(IntCategories)


@pytest.mark.parametrize(
("input", "expected"),
[
Expand Down

0 comments on commit 414d883

Please sign in to comment.