Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Add Schema class #16873

Merged
merged 5 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ methods. All classes and functions exposed in the ``polars.*`` namespace are pub
:maxdepth: 2

datatypes
schema/index


.. grid::
Expand Down
11 changes: 11 additions & 0 deletions py-polars/docs/source/reference/schema/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
======
Schema
======

.. currentmodule:: polars

.. autoclass:: Schema
    :members:
    :noindex:
    :autosummary:
    :autosummary-nosignatures:
3 changes: 3 additions & 0 deletions py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@
thread_pool_size,
threadpool_size,
)
from polars.schema import Schema
from polars.series import Series
from polars.sql import SQLContext, sql
from polars.string_cache import (
Expand Down Expand Up @@ -252,7 +253,9 @@
"Expr",
"LazyFrame",
"Series",
# other classes
"InProcessQuery",
"Schema",
# polars.datatypes
"Array",
"Binary",
Expand Down
16 changes: 6 additions & 10 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import contextlib
import os
import random
from collections import OrderedDict, defaultdict
from collections import defaultdict
from collections.abc import Sized
from io import BytesIO, StringIO, TextIOWrapper
from operator import itemgetter
Expand Down Expand Up @@ -94,6 +94,7 @@
TooManyRowsReturnedError,
)
from polars.functions import col, lit
from polars.schema import Schema
from polars.selectors import _expand_selector_dicts, _expand_selectors
from polars.type_aliases import DbWriteMode, JaxExportType, TorchExportType

Expand Down Expand Up @@ -765,14 +766,9 @@ def flags(self) -> dict[str, dict[str, bool]]:
return {name: self[name].flags for name in self.columns}

@property
def schema(self) -> OrderedDict[str, DataType]:
def schema(self) -> Schema:
"""
Get a mapping of column names to their data type.

Returns
-------
OrderedDict
An ordered mapping of column names to their data type.
Get an ordered mapping of column names to their data type.

Examples
--------
Expand All @@ -784,9 +780,9 @@ def schema(self) -> OrderedDict[str, DataType]:
... }
... )
>>> df.schema
OrderedDict({'foo': Int64, 'bar': Float64, 'ham': String})
Schema({'foo': Int64, 'bar': Float64, 'ham': String})
"""
return OrderedDict(zip(self.columns, self.dtypes))
return Schema(zip(self.columns, self.dtypes))

def __array__(
self, dtype: npt.DTypeLike | None = None, copy: bool | None = None
Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/expr/name.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def map_fields(self, function: Callable[[str], str]) -> Expr:
--------
>>> df = pl.DataFrame({"x": {"a": 1, "b": 2}})
>>> df.select(pl.col("x").name.map_fields(lambda x: x.upper())).schema
OrderedDict({'x': Struct({'A': Int64, 'B': Int64})})
Schema({'x': Struct({'A': Int64, 'B': Int64})})
"""
return self._from_pyexpr(self._pyexpr.name_map_fields(function))

Expand All @@ -322,7 +322,7 @@ def prefix_fields(self, prefix: str) -> Expr:
--------
>>> df = pl.DataFrame({"x": {"a": 1, "b": 2}})
>>> df.select(pl.col("x").name.prefix_fields("prefix_")).schema
OrderedDict({'x': Struct({'prefix_a': Int64, 'prefix_b': Int64})})
Schema({'x': Struct({'prefix_a': Int64, 'prefix_b': Int64})})
"""
return self._from_pyexpr(self._pyexpr.name_prefix_fields(prefix))

Expand All @@ -343,6 +343,6 @@ def suffix_fields(self, suffix: str) -> Expr:
--------
>>> df = pl.DataFrame({"x": {"a": 1, "b": 2}})
>>> df.select(pl.col("x").name.suffix_fields("_suffix")).schema
OrderedDict({'x': Struct({'a_suffix': Int64, 'b_suffix': Int64})})
Schema({'x': Struct({'a_suffix': Int64, 'b_suffix': Int64})})
"""
return self._from_pyexpr(self._pyexpr.name_suffix_fields(suffix))
2 changes: 1 addition & 1 deletion py-polars/polars/functions/as_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def struct(
Use keyword arguments to easily name each struct field.

>>> df.select(pl.struct(p="int", q="bool").alias("my_struct")).schema
OrderedDict({'my_struct': Struct({'p': Int64, 'q': Boolean})})
Schema({'my_struct': Struct({'p': Int64, 'q': Boolean})})
"""
pyexprs = parse_as_list_of_expressions(*exprs, **named_exprs)
expr = wrap_expr(plr.as_struct(pyexprs))
Expand Down
15 changes: 5 additions & 10 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import contextlib
import os
from collections import OrderedDict
from datetime import date, datetime, time, timedelta
from functools import lru_cache, reduce
from io import BytesIO, StringIO
Expand Down Expand Up @@ -77,6 +76,7 @@
from polars.dependencies import import_optional, subprocess
from polars.lazyframe.group_by import LazyGroupBy
from polars.lazyframe.in_process import InProcessQuery
from polars.schema import Schema
from polars.selectors import _expand_selectors, by_dtype, expand_selector
from polars.slice import LazyPolarsSlice

Expand Down Expand Up @@ -452,14 +452,9 @@ def dtypes(self) -> list[DataType]:
return self._ldf.dtypes()

@property
def schema(self) -> OrderedDict[str, DataType]:
def schema(self) -> Schema:
"""
Get a mapping of column names to their data type.

Returns
-------
OrderedDict
An ordered mapping of column names to their data type.
Get an ordered mapping of column names to their data type.

Warnings
--------
Expand All @@ -476,9 +471,9 @@ def schema(self) -> OrderedDict[str, DataType]:
... }
... )
>>> lf.schema
OrderedDict({'foo': Int64, 'bar': Float64, 'ham': String})
Schema({'foo': Int64, 'bar': Float64, 'ham': String})
"""
return OrderedDict(self._ldf.schema())
return Schema(self._ldf.schema())

@property
def width(self) -> int:
Expand Down
61 changes: 61 additions & 0 deletions py-polars/polars/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

from collections import OrderedDict
from typing import TYPE_CHECKING, Iterable, Mapping

if TYPE_CHECKING:
    from polars.datatypes import DataType

    BaseSchema = OrderedDict[str, DataType]
else:
    # Python 3.8 does not support generic OrderedDict at runtime
    BaseSchema = OrderedDict


class Schema(BaseSchema):
    """
    Ordered mapping of column names to their data type.

    Parameters
    ----------
    schema
        The schema definition given by column names and their associated *instantiated*
        Polars data type. Accepts a mapping or an iterable of tuples. If omitted
        (or `None`), an empty schema is created.

    Examples
    --------
    Define a schema by passing *instantiated* data types.

    >>> schema = pl.Schema({"foo": pl.Int8(), "bar": pl.String()})
    >>> schema
    Schema({'foo': Int8, 'bar': String})

    Access the data type associated with a specific column name.

    >>> schema["foo"]
    Int8

    Access various schema properties using the `names`, `dtypes`, and `len` methods.

    >>> schema.names()
    ['foo', 'bar']
    >>> schema.dtypes()
    [Int8, String]
    >>> schema.len()
    2
    """

    def __init__(
        self,
        schema: Mapping[str, DataType] | Iterable[tuple[str, DataType]] | None = None,
    ):
        # `OrderedDict.__reduce__` reconstructs an instance by calling the class
        # with no arguments and then re-inserting the items one by one. The
        # `schema` argument must therefore be optional, otherwise `copy.copy`,
        # `copy.deepcopy` and `pickle` fail with a `TypeError`.
        super().__init__(() if schema is None else schema)

    def names(self) -> list[str]:
        """Get the column names of the schema."""
        return list(self.keys())

    def dtypes(self) -> list[DataType]:
        """Get the data types of the schema."""
        return list(self.values())

    def len(self) -> int:
        """Get the number of columns in the schema."""
        return len(self)
14 changes: 8 additions & 6 deletions py-polars/polars/series/struct.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

from collections import OrderedDict
from typing import TYPE_CHECKING, Sequence

from polars._utils.various import BUILDING_SPHINX_DOCS, sphinx_accessor
from polars._utils.wrap import wrap_df
from polars.schema import Schema
from polars.series.utils import expr_dispatch

if TYPE_CHECKING:
from polars import DataFrame, DataType, Series
from polars import DataFrame, Series
from polars.polars import PySeries
elif BUILDING_SPHINX_DOCS:
property = sphinx_accessor
Expand Down Expand Up @@ -91,19 +91,21 @@ def rename_fields(self, names: Sequence[str]) -> Series:
"""

@property
def schema(self) -> OrderedDict[str, DataType]:
def schema(self) -> Schema:
"""
Get the struct definition as a name/dtype schema dict.

Examples
--------
>>> s = pl.Series([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
>>> s.struct.schema
OrderedDict({'a': Int64, 'b': Int64})
Schema({'a': Int64, 'b': Int64})
"""
if getattr(self, "_s", None) is None:
return OrderedDict()
return OrderedDict(self._s.dtype().to_schema())
return Schema({})

schema = self._s.dtype().to_schema()
return Schema(schema)

def unnest(self) -> DataFrame:
"""
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/constructors/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,11 @@ def test_list_null_constructor_schema() -> None:
expected = pl.List(pl.Null)
assert pl.DataFrame({"a": [[]]}).dtypes[0] == expected
assert pl.DataFrame(schema={"a": pl.List}).dtypes[0] == expected


def test_df_init_schema_object() -> None:
    """A `pl.Schema` instance can be passed as the `schema` of a DataFrame."""
    expected_schema = pl.Schema({"a": pl.Int8(), "b": pl.String()})
    frame = pl.DataFrame(
        {"a": [1, 2, 3], "b": ["x", "y", "z"]},
        schema=expected_schema,
    )

    # Column order and dtypes of the frame must mirror the schema object.
    assert frame.columns == expected_schema.names()
    assert frame.dtypes == expected_schema.dtypes()
34 changes: 34 additions & 0 deletions py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import polars as pl


def test_schema() -> None:
    """Item access and the `len`/`names`/`dtypes` accessors of a schema."""
    schema = pl.Schema({"foo": pl.Int8(), "bar": pl.String()})

    # Lookup by column name.
    assert schema["foo"] == pl.Int8()
    assert schema["bar"] == pl.String()

    # Aggregate accessors preserve insertion order.
    assert schema.len() == 2
    assert schema.names() == ["foo", "bar"]
    assert schema.dtypes() == [pl.Int8(), pl.String()]


def test_schema_parse_nonpolars_dtypes() -> None:
    """Values that are not instantiated Polars dtypes are stored unchanged."""
    # Currently, no parsing is being done.
    schema = pl.Schema({"foo": pl.List, "bar": int})  # type: ignore[arg-type]

    # Values come back exactly as they were passed in.
    assert schema["foo"] == pl.List
    assert schema["bar"] == int

    assert schema.len() == 2
    assert schema.names() == ["foo", "bar"]
    assert schema.dtypes() == [pl.List, int]


def test_schema_equality() -> None:
    """Schemas compare equal only when names, dtypes and order all match."""
    base = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()})
    other_dtype = pl.Schema({"foo": pl.Int8(), "bar": pl.String()})
    other_order = pl.Schema({"bar": pl.Float64(), "foo": pl.Int8()})

    # Every schema equals itself.
    for schema in (base, other_dtype, other_order):
        assert schema == schema

    # A differing dtype or a differing column order breaks equality.
    assert base != other_dtype
    assert base != other_order
    assert other_dtype != other_order