diff --git a/python/mypy.ini b/python/mypy.ini
index 395502d323752..2bc41d9ebcc0d 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -159,12 +159,6 @@ ignore_missing_imports = True
 [mypy-pyspark.pandas.data_type_ops.*]
 disallow_untyped_defs = False
 
-[mypy-pyspark.pandas.spark.accessors]
-disallow_untyped_defs = False
-
-[mypy-pyspark.pandas.typedef.typehints]
-disallow_untyped_defs = False
-
 [mypy-pyspark.pandas.accessors]
 disallow_untyped_defs = False
 
@@ -189,8 +183,5 @@ disallow_untyped_defs = False
 [mypy-pyspark.pandas.series]
 disallow_untyped_defs = False
 
-[mypy-pyspark.pandas.utils]
-disallow_untyped_defs = False
-
 [mypy-pyspark.pandas.window]
 disallow_untyped_defs = False
diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py
index c9e0a7e77378f..cb8a2c9a7390c 100644
--- a/python/pyspark/pandas/base.py
+++ b/python/pyspark/pandas/base.py
@@ -136,7 +136,7 @@ def align_diff_index_ops(func, this_index_ops: "IndexOpsMixin", *args) -> "Index
             name=this_index_ops.name,
         )
     elif isinstance(this_index_ops, Series):
-        this = this_index_ops.reset_index()
+        this = cast(DataFrame, this_index_ops.reset_index())
         that = [
             cast(Series, col.to_series() if isinstance(col, Index) else col)
             .rename(i)
diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py
index 8f73976a8b830..961d88856672a 100644
--- a/python/pyspark/pandas/namespace.py
+++ b/python/pyspark/pandas/namespace.py
@@ -18,7 +18,17 @@
 """
 Wrappers around spark that correspond to common pandas functions.
 """
-from typing import Any, Optional, Union, List, Tuple, Type, Sized, cast
+from typing import (  # noqa: F401 (SPARK-34943)
+    Any,
+    Dict,
+    List,
+    Optional,
+    Sized,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)
 from collections import OrderedDict
 from collections.abc import Iterable
 from distutils.version import LooseVersion
@@ -307,7 +317,7 @@ def read_csv(
 
         if isinstance(names, str):
             sdf = reader.schema(names).csv(path)
-            column_labels = OrderedDict((col, col) for col in sdf.columns)
+            column_labels = OrderedDict((col, col) for col in sdf.columns)  # type: Dict[Any, str]
         else:
             sdf = reader.csv(path)
             if is_list_like(names):
diff --git a/python/pyspark/pandas/spark/accessors.py b/python/pyspark/pandas/spark/accessors.py
index c0a6fb2e3276a..d94061acf28d4 100644
--- a/python/pyspark/pandas/spark/accessors.py
+++ b/python/pyspark/pandas/spark/accessors.py
@@ -20,23 +20,26 @@
 but Spark has it.
 """
 from abc import ABCMeta, abstractmethod
-from typing import TYPE_CHECKING, Optional, Union, List, cast
+from typing import TYPE_CHECKING, Callable, List, Optional, Union, cast
 
 from pyspark import StorageLevel
 from pyspark.sql import Column, DataFrame as SparkDataFrame
 from pyspark.sql.types import DataType, StructType
 
 if TYPE_CHECKING:
+    from pyspark.sql._typing import OptionalPrimitiveType  # noqa: F401 (SPARK-34943)
+    from pyspark._typing import PrimitiveType  # noqa: F401 (SPARK-34943)
+
     import pyspark.pandas as ps  # noqa: F401 (SPARK-34943)
     from pyspark.pandas.base import IndexOpsMixin  # noqa: F401 (SPARK-34943)
     from pyspark.pandas.frame import CachedDataFrame  # noqa: F401 (SPARK-34943)
 
 
-class SparkIndexOpsMethods(object, metaclass=ABCMeta):
+class SparkIndexOpsMethods(metaclass=ABCMeta):
     """Spark related features.
     Usually, the features here are missing in pandas but Spark has it."""
 
-    def __init__(self, data: Union["IndexOpsMixin"]):
+    def __init__(self, data: "IndexOpsMixin"):
         self._data = data
 
     @property
@@ -59,7 +62,7 @@ def column(self) -> Column:
         """
         return self._data._internal.spark_column_for(self._data._column_label)
 
-    def transform(self, func) -> Union["ps.Series", "ps.Index"]:
+    def transform(self, func: Callable[[Column], Column]) -> Union["ps.Series", "ps.Index"]:
         """
         Applies a function that takes and returns a Spark column. It allows to natively
         apply a Spark function and column APIs with the Spark column internally used
@@ -130,12 +133,7 @@ def analyzed(self) -> Union["ps.Series", "ps.Index"]:
 
 
 class SparkSeriesMethods(SparkIndexOpsMethods):
-    def transform(self, func) -> "ps.Series":
-        return cast("ps.Series", super().transform(func))
-
-    transform.__doc__ = SparkIndexOpsMethods.transform.__doc__
-
-    def apply(self, func) -> "ps.Series":
+    def apply(self, func: Callable[[Column], Column]) -> "ps.Series":
         """
         Applies a function that takes and returns a Spark column. It allows to natively
         apply a Spark function and column APIs with the Spark column internally used
@@ -256,11 +254,6 @@ def analyzed(self) -> "ps.Series":
 
 
 class SparkIndexMethods(SparkIndexOpsMethods):
-    def transform(self, func) -> "ps.Index":
-        return cast("ps.Index", super().transform(func))
-
-    transform.__doc__ = SparkIndexOpsMethods.transform.__doc__
-
     @property
     def analyzed(self) -> "ps.Index":
         """
@@ -641,7 +634,7 @@ def persist(
         )
         return CachedDataFrame(self._psdf._internal, storage_level=storage_level)
 
-    def hint(self, name: str, *parameters) -> "ps.DataFrame":
+    def hint(self, name: str, *parameters: "PrimitiveType") -> "ps.DataFrame":
         """
         Specifies some hint on the current DataFrame.
 
@@ -685,7 +678,7 @@ def to_table(
         mode: str = "overwrite",
         partition_cols: Optional[Union[str, List[str]]] = None,
         index_col: Optional[Union[str, List[str]]] = None,
-        **options
+        **options: "OptionalPrimitiveType",
     ) -> None:
         """
         Write the DataFrame into a Spark table. :meth:`DataFrame.spark.to_table`
@@ -760,7 +753,7 @@ def to_spark_io(
         mode: str = "overwrite",
         partition_cols: Optional[Union[str, List[str]]] = None,
         index_col: Optional[Union[str, List[str]]] = None,
-        **options
+        **options: "OptionalPrimitiveType",
     ) -> None:
         """Write the DataFrame out to a Spark data source. :meth:`DataFrame.spark.to_spark_io`
         is an alias of :meth:`DataFrame.to_spark_io`.
@@ -881,7 +874,11 @@ def explain(self, extended: Optional[bool] = None, mode: Optional[str] = None) -
         """
         self._psdf._internal.to_internal_spark_frame.explain(extended, mode)
 
-    def apply(self, func, index_col: Optional[Union[str, List[str]]] = None) -> "ps.DataFrame":
+    def apply(
+        self,
+        func: Callable[[SparkDataFrame], SparkDataFrame],
+        index_col: Optional[Union[str, List[str]]] = None,
+    ) -> "ps.DataFrame":
         """
         Applies a function that takes and returns a Spark DataFrame. It allows natively
         apply a Spark function and column APIs with the Spark column internally used
@@ -1227,7 +1224,7 @@ def unpersist(self) -> None:
         self._psdf._cached.unpersist()
 
 
-def _test():
+def _test() -> None:
     import os
     import doctest
     import shutil
diff --git a/python/pyspark/pandas/typedef/typehints.py b/python/pyspark/pandas/typedef/typehints.py
index 76d265aa86cf4..819dca554923b 100644
--- a/python/pyspark/pandas/typedef/typehints.py
+++ b/python/pyspark/pandas/typedef/typehints.py
@@ -21,7 +21,7 @@
 import datetime
 import decimal
 from inspect import getfullargspec, isclass
-from typing import Generic, List, Optional, Tuple, TypeVar, Union  # noqa: F401
+from typing import Any, Callable, Generic, List, Optional, Tuple, TypeVar, Union  # noqa: F401
 
 import numpy as np
 import pandas as pd
@@ -78,7 +78,7 @@ def __init__(self, dtype: Dtype, spark_type: types.DataType):
         self.dtype = dtype
         self.spark_type = spark_type
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "SeriesType[{}]".format(self.spark_type)
 
 
@@ -96,7 +96,7 @@ def __init__(
             ]
         )  # type: types.StructType
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "DataFrameType[{}]".format(self.spark_type)
 
 
@@ -106,16 +106,16 @@ def __init__(self, dtype: Dtype, spark_type: types.DataType):
         self.dtype = dtype
         self.spark_type = spark_type
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "ScalarType[{}]".format(self.spark_type)
 
 
 # The type is left unspecified or we do not know about this type.
 class UnknownType(object):
-    def __init__(self, tpe):
+    def __init__(self, tpe: Any):
         self.tpe = tpe
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "UnknownType[{}]".format(self.tpe)
 
 
@@ -262,7 +262,7 @@ def spark_type_to_pandas_dtype(
     return np.dtype(to_arrow_type(spark_type).to_pandas_dtype())
 
 
-def pandas_on_spark_type(tpe) -> Tuple[Dtype, types.DataType]:
+def pandas_on_spark_type(tpe: Union[str, type, Dtype]) -> Tuple[Dtype, types.DataType]:
     """
     Convert input into a pandas only dtype object or a numpy dtype object,
     and its corresponding Spark DataType.
@@ -322,7 +322,7 @@ def infer_pd_series_spark_type(pser: pd.Series, dtype: Dtype) -> types.DataType:
     return as_spark_type(dtype)
 
 
-def infer_return_type(f) -> Union[SeriesType, DataFrameType, ScalarType, UnknownType]:
+def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarType, UnknownType]:
     """
     Infer the return type from the return type annotation of the given function.
 
@@ -517,7 +517,7 @@ def infer_return_type(f) -> Union[SeriesType, DataFrameType, ScalarType, Unknown
         return ScalarType(*types)
 
 
-def _test():
+def _test() -> None:
     import doctest
     import sys
     import pyspark.pandas.typedef.typehints
diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py
index c8294295d1e7d..efea2b035f4f7 100644
--- a/python/pyspark/pandas/utils.py
+++ b/python/pyspark/pandas/utils.py
@@ -22,7 +22,20 @@
 from collections import OrderedDict
 from contextlib import contextmanager
 import os
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING, overload
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    TYPE_CHECKING,
+    cast,
+    no_type_check,
+    overload,
+)
 import warnings
 
 from pyspark import sql as spark
@@ -44,6 +57,7 @@
     from pyspark.pandas.base import IndexOpsMixin  # noqa: F401 (SPARK-34943)
     from pyspark.pandas.frame import DataFrame  # noqa: F401 (SPARK-34943)
     from pyspark.pandas.internal import InternalFrame  # noqa: F401 (SPARK-34943)
+    from pyspark.pandas.series import Series  # noqa: F401 (SPARK-34943)
 
 
 ERROR_MESSAGE_CANNOT_COMBINE = (
@@ -90,7 +104,12 @@ def same_anchor(
     )
 
 
-def combine_frames(this, *args, how="full", preserve_order_column=False):
+def combine_frames(
+    this: "DataFrame",
+    *args: Union["DataFrame", "Series"],
+    how: str = "full",
+    preserve_order_column: bool = False
+) -> "DataFrame":
     """
     This method combines `this` DataFrame with a different `that` DataFrame or
     Series from a different DataFrame.
@@ -129,17 +148,17 @@ def combine_frames(this, *args, how="full", preserve_order_column=False):
 
     if get_option("compute.ops_on_diff_frames"):
 
-        def resolve(internal, side):
+        def resolve(internal: InternalFrame, side: str) -> InternalFrame:
             rename = lambda col: "__{}_{}".format(side, col)
             internal = internal.resolved_copy
             sdf = internal.spark_frame
             sdf = internal.spark_frame.select(
-                [
+                *[
                     scol_for(sdf, col).alias(rename(col))
                     for col in sdf.columns
                     if col not in HIDDEN_COLUMNS
-                ]
-                + list(HIDDEN_COLUMNS)
+                ],
+                *HIDDEN_COLUMNS
             )
             return internal.copy(
                 spark_frame=sdf,
@@ -216,16 +235,16 @@ def resolve(internal, side):
             order_column = []
 
         joined_df = joined_df.select(
-            merged_index_scols
-            + [
+            *merged_index_scols,
+            *(
                 scol_for(this_sdf, this_internal.spark_column_name_for(label))
                 for label in this_internal.column_labels
-            ]
-            + [
+            ),
+            *(
                 scol_for(that_sdf, that_internal.spark_column_name_for(label))
                 for label in that_internal.column_labels
-            ]
-            + order_column
+            ),
+            *order_column
         )
 
         index_spark_columns = [scol_for(joined_df, col) for col in index_column_names]
@@ -246,7 +265,7 @@ def resolve(internal, side):
 
         level = max(this_internal.column_labels_level, that_internal.column_labels_level)
 
-        def fill_label(label):
+        def fill_label(label: Optional[Tuple]) -> List:
             if label is None:
                 return ([""] * (level - 1)) + [None]
             else:
@@ -256,7 +275,7 @@ def fill_label(label):
             tuple(["this"] + fill_label(label)) for label in this_internal.column_labels
         ] + [tuple(["that"] + fill_label(label)) for label in that_internal.column_labels]
         column_label_names = (
-            [None] * (1 + level - this_internal.column_labels_level)
+            cast(List[Optional[Tuple]], [None]) * (1 + level - this_internal.column_labels_level)
         ) + this_internal.column_label_names
         return DataFrame(
             InternalFrame(
@@ -275,7 +294,7 @@ def fill_label(label):
 
 
 def align_diff_frames(
-    resolve_func,
+    resolve_func: Callable[["DataFrame", List[Tuple], List[Tuple]], Tuple["Series", Tuple]],
     this: "DataFrame",
     that: "DataFrame",
     fillna: bool = True,
@@ -385,11 +404,11 @@ def align_diff_frames(
     # Should extract columns to apply and do it in a batch in case
     # it adds new columns for example.
     if len(this_columns_to_apply) > 0 or len(that_columns_to_apply) > 0:
-        psser_set, column_labels_applied = zip(
+        psser_set, column_labels_set = zip(
            *resolve_func(combined, this_columns_to_apply, that_columns_to_apply)
        )
        columns_applied = list(psser_set)
-        column_labels_applied = list(column_labels_applied)
+        column_labels_applied = list(column_labels_set)
    else:
        columns_applied = []
        column_labels_applied = []
@@ -420,12 +439,12 @@ def align_diff_frames(
     return psdf
 
 
-def is_testing():
+def is_testing() -> bool:
     """ Indicates whether Spark is currently running tests. """
     return "SPARK_TESTING" in os.environ
 
 
-def default_session(conf=None):
+def default_session(conf: Optional[Dict[str, Any]] = None) -> spark.SparkSession:
     if conf is None:
         conf = dict()
 
@@ -443,7 +462,9 @@ def default_session(conf=None):
 
 
 @contextmanager
-def sql_conf(pairs, *, spark=None):
+def sql_conf(
+    pairs: Dict[str, Any], *, spark: Optional[spark.SparkSession] = None
+) -> Iterator[None]:
     """
     A convenient context manager to set `value` to the Spark SQL configuration `key` and
     then restores it back when it exits.
@@ -473,7 +494,7 @@ def validate_arguments_and_invoke_function(
     pandas_on_spark_func: Callable,
     pandas_func: Callable,
     input_args: Dict,
-):
+) -> Any:
     """
     Invokes a pandas function.
 
@@ -529,7 +550,8 @@ def validate_arguments_and_invoke_function(
     return pandas_func(**args)
 
 
-def lazy_property(fn):
+@no_type_check
+def lazy_property(fn: Callable[[Any], Any]) -> property:
     """
     Decorator that makes a property lazy-evaluated.
 
@@ -677,16 +699,19 @@ def is_name_like_value(
     return True
 
 
-def validate_axis(axis=0, none_axis=0):
+def validate_axis(axis: Optional[Union[int, str]] = 0, none_axis: int = 0) -> int:
     """ Check the given axis is valid. """
     # convert to numeric axis
-    axis = {None: none_axis, "index": 0, "columns": 1}.get(axis, axis)
-    if axis not in (none_axis, 0, 1):
+    axis = cast(
+        Dict[Optional[Union[int, str]], int], {None: none_axis, "index": 0, "columns": 1}
+    ).get(axis, axis)
+    if axis in (none_axis, 0, 1):
+        return cast(int, axis)
+    else:
         raise ValueError("No axis named {0}".format(axis))
-    return axis
 
 
-def validate_bool_kwarg(value, arg_name):
+def validate_bool_kwarg(value: Any, arg_name: str) -> Optional[bool]:
     """ Ensures that argument passed in arg_name is of type bool. """
     if not (isinstance(value, bool) or value is None):
         raise TypeError(
@@ -836,27 +861,43 @@ def verify_temp_column_name(
     return column_name_or_label
 
 
-def compare_null_first(left, right, comp):
+def compare_null_first(
+    left: spark.Column,
+    right: spark.Column,
+    comp: Callable[[spark.Column, spark.Column], spark.Column],
+) -> spark.Column:
     return (left.isNotNull() & right.isNotNull() & comp(left, right)) | (
         left.isNull() & right.isNotNull()
     )
 
 
-def compare_null_last(left, right, comp):
+def compare_null_last(
+    left: spark.Column,
+    right: spark.Column,
+    comp: Callable[[spark.Column, spark.Column], spark.Column],
+) -> spark.Column:
     return (left.isNotNull() & right.isNotNull() & comp(left, right)) | (
         left.isNotNull() & right.isNull()
     )
 
 
-def compare_disallow_null(left, right, comp):
+def compare_disallow_null(
+    left: spark.Column,
+    right: spark.Column,
+    comp: Callable[[spark.Column, spark.Column], spark.Column],
+) -> spark.Column:
     return left.isNotNull() & right.isNotNull() & comp(left, right)
 
 
-def compare_allow_null(left, right, comp):
+def compare_allow_null(
+    left: spark.Column,
+    right: spark.Column,
+    comp: Callable[[spark.Column, spark.Column], spark.Column],
+) -> spark.Column:
     return left.isNull() | right.isNull() | comp(left, right)
 
 
-def _test():
+def _test() -> None:
     import os
     import doctest
     import sys