diff --git a/py-polars/docs/source/reference/index.rst b/py-polars/docs/source/reference/index.rst index c662f96a0f94..1b5116eea4b5 100644 --- a/py-polars/docs/source/reference/index.rst +++ b/py-polars/docs/source/reference/index.rst @@ -53,6 +53,7 @@ methods. All classes and functions exposed in the ``polars.*`` namespace are pub :maxdepth: 2 datatypes + schema/index .. grid:: diff --git a/py-polars/docs/source/reference/schema/index.rst b/py-polars/docs/source/reference/schema/index.rst new file mode 100644 index 000000000000..e383fe899862 --- /dev/null +++ b/py-polars/docs/source/reference/schema/index.rst @@ -0,0 +1,11 @@ +====== +Schema +====== + +.. currentmodule:: polars + +.. autoclass:: Schema + :members: + :noindex: + :autosummary: + :autosummary-nosignatures: diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index e869efec4370..fc405d4de542 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -209,6 +209,7 @@ thread_pool_size, threadpool_size, ) +from polars.schema import Schema from polars.series import Series from polars.sql import SQLContext, sql from polars.string_cache import ( @@ -252,7 +253,9 @@ "Expr", "LazyFrame", "Series", + # other classes "InProcessQuery", + "Schema", # polars.datatypes "Array", "Binary", diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 20bbef081349..513e72e8f925 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -5,7 +5,7 @@ import contextlib import os import random -from collections import OrderedDict, defaultdict +from collections import defaultdict from collections.abc import Sized from io import BytesIO, StringIO, TextIOWrapper from operator import itemgetter @@ -94,6 +94,7 @@ TooManyRowsReturnedError, ) from polars.functions import col, lit +from polars.schema import Schema from polars.selectors import _expand_selector_dicts, _expand_selectors from polars.type_aliases import DbWriteMode, JaxExportType, TorchExportType @@ -765,14 +766,9 @@ def flags(self) -> dict[str, dict[str, bool]]: return {name: self[name].flags for name in self.columns} @property - def schema(self) -> OrderedDict[str, DataType]: + def schema(self) -> Schema: """ - Get a mapping of column names to their data type. - - Returns - ------- - OrderedDict - An ordered mapping of column names to their data type. + Get an ordered mapping of column names to their data type. Examples -------- @@ -784,9 +780,9 @@ def schema(self) -> OrderedDict[str, DataType]: ... } ... ) >>> df.schema - OrderedDict({'foo': Int64, 'bar': Float64, 'ham': String}) + Schema({'foo': Int64, 'bar': Float64, 'ham': String}) """ - return OrderedDict(zip(self.columns, self.dtypes)) + return Schema(zip(self.columns, self.dtypes)) def __array__( self, dtype: npt.DTypeLike | None = None, copy: bool | None = None diff --git a/py-polars/polars/expr/name.py b/py-polars/polars/expr/name.py index 482b30ef60ff..9c730d2d3206 100644 --- a/py-polars/polars/expr/name.py +++ b/py-polars/polars/expr/name.py @@ -301,7 +301,7 @@ def map_fields(self, function: Callable[[str], str]) -> Expr: -------- >>> df = pl.DataFrame({"x": {"a": 1, "b": 2}}) >>> df.select(pl.col("x").name.map_fields(lambda x: x.upper())).schema - OrderedDict({'x': Struct({'A': Int64, 'B': Int64})}) + Schema({'x': Struct({'A': Int64, 'B': Int64})}) """ return self._from_pyexpr(self._pyexpr.name_map_fields(function)) @@ -322,7 +322,7 @@ def prefix_fields(self, prefix: str) -> Expr: -------- >>> df = pl.DataFrame({"x": {"a": 1, "b": 2}}) >>> df.select(pl.col("x").name.prefix_fields("prefix_")).schema - OrderedDict({'x': Struct({'prefix_a': Int64, 'prefix_b': Int64})}) + Schema({'x': Struct({'prefix_a': Int64, 'prefix_b': Int64})}) """ return self._from_pyexpr(self._pyexpr.name_prefix_fields(prefix)) @@ -343,6 +343,6 @@ def suffix_fields(self, suffix: str) -> Expr: -------- >>> df = pl.DataFrame({"x": {"a": 1, "b": 2}}) >>> df.select(pl.col("x").name.suffix_fields("_suffix")).schema - OrderedDict({'x': Struct({'a_suffix': Int64, 'b_suffix': Int64})}) + Schema({'x': Struct({'a_suffix': Int64, 'b_suffix': Int64})}) """ return self._from_pyexpr(self._pyexpr.name_suffix_fields(suffix)) diff --git a/py-polars/polars/functions/as_datatype.py b/py-polars/polars/functions/as_datatype.py index aa5a994ba73c..4febc8d92880 100644 --- a/py-polars/polars/functions/as_datatype.py +++ b/py-polars/polars/functions/as_datatype.py @@ -563,7 +563,7 @@ def struct( Use keyword arguments to easily name each struct field. >>> df.select(pl.struct(p="int", q="bool").alias("my_struct")).schema - OrderedDict({'my_struct': Struct({'p': Int64, 'q': Boolean})}) + Schema({'my_struct': Struct({'p': Int64, 'q': Boolean})}) """ pyexprs = parse_as_list_of_expressions(*exprs, **named_exprs) expr = wrap_expr(plr.as_struct(pyexprs)) diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 9818301c0ee0..c14109b36167 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -2,7 +2,6 @@ import contextlib import os -from collections import OrderedDict from datetime import date, datetime, time, timedelta from functools import lru_cache, reduce from io import BytesIO, StringIO @@ -77,6 +76,7 @@ from polars.dependencies import import_optional, subprocess from polars.lazyframe.group_by import LazyGroupBy from polars.lazyframe.in_process import InProcessQuery +from polars.schema import Schema from polars.selectors import _expand_selectors, by_dtype, expand_selector from polars.slice import LazyPolarsSlice @@ -452,14 +452,9 @@ def dtypes(self) -> list[DataType]: return self._ldf.dtypes() @property - def schema(self) -> OrderedDict[str, DataType]: + def schema(self) -> Schema: """ - Get a mapping of column names to their data type. - - Returns - ------- - OrderedDict - An ordered mapping of column names to their data type. + Get an ordered mapping of column names to their data type. Warnings -------- @@ -476,9 +471,9 @@ def schema(self) -> OrderedDict[str, DataType]: ... } ... ) >>> lf.schema - OrderedDict({'foo': Int64, 'bar': Float64, 'ham': String}) + Schema({'foo': Int64, 'bar': Float64, 'ham': String}) """ - return OrderedDict(self._ldf.schema()) + return Schema(self._ldf.schema()) @property def width(self) -> int: diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py new file mode 100644 index 000000000000..216147872517 --- /dev/null +++ b/py-polars/polars/schema.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from collections import OrderedDict +from typing import TYPE_CHECKING, Iterable, Mapping + +if TYPE_CHECKING: + from polars.datatypes import DataType + + BaseSchema = OrderedDict[str, DataType] +else: + # Python 3.8 does not support generic OrderedDict at runtime + BaseSchema = OrderedDict + + +class Schema(BaseSchema): + """ + Ordered mapping of column names to their data type. + + Parameters + ---------- + schema + The schema definition given by column names and their associated *instantiated* + Polars data type. Accepts a mapping or an iterable of tuples. + + Examples + -------- + Define a schema by passing *instantiated* data types. + + >>> schema = pl.Schema({"foo": pl.Int8(), "bar": pl.String()}) + >>> schema + Schema({'foo': Int8, 'bar': String}) + + Access the data type associated with a specific column name. + + >>> schema["foo"] + Int8 + + Access various schema properties using the `names`, `dtypes`, and `len` methods. + + >>> schema.names() + ['foo', 'bar'] + >>> schema.dtypes() + [Int8, String] + >>> schema.len() + 2 + """ + + def __init__(self, schema: Mapping[str, DataType] | Iterable[tuple[str, DataType]]): + super().__init__(schema) + + def names(self) -> list[str]: + """Get the column names of the schema.""" + return list(self.keys()) + + def dtypes(self) -> list[DataType]: + """Get the data types of the schema.""" + return list(self.values()) + + def len(self) -> int: + """Get the number of columns in the schema.""" + return len(self) diff --git a/py-polars/polars/series/struct.py b/py-polars/polars/series/struct.py index 3d4c164b0056..160558a207b8 100644 --- a/py-polars/polars/series/struct.py +++ b/py-polars/polars/series/struct.py @@ -1,14 +1,14 @@ from __future__ import annotations -from collections import OrderedDict from typing import TYPE_CHECKING, Sequence from polars._utils.various import BUILDING_SPHINX_DOCS, sphinx_accessor from polars._utils.wrap import wrap_df +from polars.schema import Schema from polars.series.utils import expr_dispatch if TYPE_CHECKING: - from polars import DataFrame, DataType, Series + from polars import DataFrame, Series from polars.polars import PySeries elif BUILDING_SPHINX_DOCS: property = sphinx_accessor @@ -91,7 +91,7 @@ def rename_fields(self, names: Sequence[str]) -> Series: """ @property - def schema(self) -> OrderedDict[str, DataType]: + def schema(self) -> Schema: """ Get the struct definition as a name/dtype schema dict. @@ -99,11 +99,13 @@ def schema(self) -> OrderedDict[str, DataType]: -------- >>> s = pl.Series([{"a": 1, "b": 2}, {"a": 3, "b": 4}]) >>> s.struct.schema - OrderedDict({'a': Int64, 'b': Int64}) + Schema({'a': Int64, 'b': Int64}) """ if getattr(self, "_s", None) is None: - return OrderedDict() - return OrderedDict(self._s.dtype().to_schema()) + return Schema({}) + + schema = self._s.dtype().to_schema() + return Schema(schema) def unnest(self) -> DataFrame: """ diff --git a/py-polars/tests/unit/constructors/test_dataframe.py b/py-polars/tests/unit/constructors/test_dataframe.py index 9de21d42bbbe..d2f7ac19c7eb 100644 --- a/py-polars/tests/unit/constructors/test_dataframe.py +++ b/py-polars/tests/unit/constructors/test_dataframe.py @@ -191,3 +191,11 @@ def test_list_null_constructor_schema() -> None: expected = pl.List(pl.Null) assert pl.DataFrame({"a": [[]]}).dtypes[0] == expected assert pl.DataFrame(schema={"a": pl.List}).dtypes[0] == expected + + +def test_df_init_schema_object() -> None: + schema = pl.Schema({"a": pl.Int8(), "b": pl.String()}) + df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}, schema=schema) + + assert df.columns == schema.names() + assert df.dtypes == schema.dtypes() diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py new file mode 100644 index 000000000000..eab0f06a3fa4 --- /dev/null +++ b/py-polars/tests/unit/test_schema.py @@ -0,0 +1,34 @@ +import polars as pl + + +def test_schema() -> None: + s = pl.Schema({"foo": pl.Int8(), "bar": pl.String()}) + + assert s["foo"] == pl.Int8() + assert s["bar"] == pl.String() + assert s.len() == 2 + assert s.names() == ["foo", "bar"] + assert s.dtypes() == [pl.Int8(), pl.String()] + + +def test_schema_parse_nonpolars_dtypes() -> None: + # Currently, no parsing is being done. + s = pl.Schema({"foo": pl.List, "bar": int}) # type: ignore[arg-type] + + assert s["foo"] == pl.List + assert s["bar"] == int + assert s.len() == 2 + assert s.names() == ["foo", "bar"] + assert s.dtypes() == [pl.List, int] + + +def test_schema_equality() -> None: + s1 = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()}) + s2 = pl.Schema({"foo": pl.Int8(), "bar": pl.String()}) + s3 = pl.Schema({"bar": pl.Float64(), "foo": pl.Int8()}) + assert s1 == s1 + assert s2 == s2 + assert s3 == s3 + assert s1 != s2 + assert s1 != s3 + assert s2 != s3