[SPARK-35467][SPARK-35468][SPARK-35477][PYTHON] Fix disallow_untyped_defs mypy checks

### What changes were proposed in this pull request?

Adds more type annotations in the files:

- `python/pyspark/pandas/spark/accessors.py`
- `python/pyspark/pandas/typedef/typehints.py`
- `python/pyspark/pandas/utils.py`

and fixes the mypy check failures.

### Why are the changes needed?

We should enable more `disallow_untyped_defs` mypy checks.
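
For context, `disallow_untyped_defs = True` makes mypy report every function definition that lacks parameter or return annotations in the affected modules. A minimal illustration (not taken from this PR) of what the stricter check rejects and what it expects instead:

```python
from typing import List


# Flagged under disallow_untyped_defs: no parameter or return annotations.
def first_items(rows):
    return [row[0] for row in rows]


# Accepted: every parameter and the return value are annotated.
def first_items_typed(rows: List[List[str]]) -> List[str]:
    return [row[0] for row in rows]
```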

### Does this PR introduce _any_ user-facing change?

Yes.
This PR adds more type annotations to the pandas APIs on Spark module, which can affect how users' development tools, such as IDEs and type checkers, interact with these APIs.

### How was this patch tested?

The mypy check with the new configuration, as well as the existing tests, should pass.
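
For reference, one way to reproduce the check locally is through mypy's Python API; the paths below assume the Spark repository layout (Spark's own lint script may drive mypy differently):

```python
from mypy import api

# Run mypy with the repository's configuration against the pandas-on-Spark package.
stdout, stderr, exit_code = api.run(
    ["--config-file", "python/mypy.ini", "python/pyspark/pandas"]
)
print(stdout)
```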

Closes #32627 from ueshin/issues/SPARK-35467_35468_35477/disallow_untyped_defs.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
ueshin authored and HyukjinKwon committed May 24, 2021
1 parent 9e1b204 commit 1b75c24
Showing 6 changed files with 112 additions and 73 deletions.
9 changes: 0 additions & 9 deletions python/mypy.ini
@@ -159,12 +159,6 @@ ignore_missing_imports = True
[mypy-pyspark.pandas.data_type_ops.*]
disallow_untyped_defs = False

-[mypy-pyspark.pandas.spark.accessors]
-disallow_untyped_defs = False
-
-[mypy-pyspark.pandas.typedef.typehints]
-disallow_untyped_defs = False
-
[mypy-pyspark.pandas.accessors]
disallow_untyped_defs = False

@@ -189,8 +183,5 @@ disallow_untyped_defs = False
[mypy-pyspark.pandas.series]
disallow_untyped_defs = False

-[mypy-pyspark.pandas.utils]
-disallow_untyped_defs = False
-
[mypy-pyspark.pandas.window]
disallow_untyped_defs = False
2 changes: 1 addition & 1 deletion python/pyspark/pandas/base.py
@@ -136,7 +136,7 @@ def align_diff_index_ops(func, this_index_ops: "IndexOpsMixin", *args) -> "Index
name=this_index_ops.name,
)
elif isinstance(this_index_ops, Series):
-this = this_index_ops.reset_index()
+this = cast(DataFrame, this_index_ops.reset_index())
that = [
cast(Series, col.to_series() if isinstance(col, Index) else col)
.rename(i)
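A note on the change above: `typing.cast` does nothing at runtime; it only tells mypy to treat the result of `reset_index()` as a `DataFrame`, presumably because the method's declared return type is broader than what this branch actually produces. A self-contained sketch of the same pattern (the `load` helper is hypothetical, not Spark code):

```python
from typing import Union, cast


class Table:
    pass


class Row:
    pass


def load(flatten: bool) -> Union[Table, Row]:
    # Hypothetical helper whose declared return type is a union.
    return Table() if flatten else Row()


# cast() narrows the static type for the checker only; no runtime check happens.
table: Table = cast(Table, load(flatten=True))
```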
14 changes: 12 additions & 2 deletions python/pyspark/pandas/namespace.py
@@ -18,7 +18,17 @@
"""
Wrappers around spark that correspond to common pandas functions.
"""
-from typing import Any, Optional, Union, List, Tuple, Type, Sized, cast
+from typing import ( # noqa: F401 (SPARK-34943)
+    Any,
+    Dict,
+    List,
+    Optional,
+    Sized,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)
from collections import OrderedDict
from collections.abc import Iterable
from distutils.version import LooseVersion
@@ -307,7 +317,7 @@ def read_csv(

if isinstance(names, str):
sdf = reader.schema(names).csv(path)
-column_labels = OrderedDict((col, col) for col in sdf.columns)
+column_labels = OrderedDict((col, col) for col in sdf.columns) # type: Dict[Any, str]
else:
sdf = reader.csv(path)
if is_list_like(names):
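The `# type: Dict[Any, str]` comment above is the type-comment spelling of a variable annotation: mypy reads it exactly like an inline annotation, and because the name `Dict` then only appears inside a comment, flake8 would report the import as unused, which is presumably why the enlarged `typing` import carries `# noqa: F401`. A small sketch of the two equivalent spellings, using only the standard library:

```python
from collections import OrderedDict
from typing import Any, Dict

columns = ["a", "b", "c"]

# Type-comment form: the annotation lives in a comment that mypy parses.
labels = OrderedDict((col, col) for col in columns)  # type: Dict[Any, str]

# Inline form: equivalent for the type checker on Python 3.6+.
labels_inline: Dict[Any, str] = OrderedDict((col, col) for col in columns)
```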
37 changes: 17 additions & 20 deletions python/pyspark/pandas/spark/accessors.py
@@ -20,23 +20,26 @@
but Spark has it.
"""
from abc import ABCMeta, abstractmethod
-from typing import TYPE_CHECKING, Optional, Union, List, cast
+from typing import TYPE_CHECKING, Callable, List, Optional, Union, cast

from pyspark import StorageLevel
from pyspark.sql import Column, DataFrame as SparkDataFrame
from pyspark.sql.types import DataType, StructType

if TYPE_CHECKING:
+from pyspark.sql._typing import OptionalPrimitiveType # noqa: F401 (SPARK-34943)
+from pyspark._typing import PrimitiveType # noqa: F401 (SPARK-34943)
+
import pyspark.pandas as ps # noqa: F401 (SPARK-34943)
from pyspark.pandas.base import IndexOpsMixin # noqa: F401 (SPARK-34943)
from pyspark.pandas.frame import CachedDataFrame # noqa: F401 (SPARK-34943)


-class SparkIndexOpsMethods(object, metaclass=ABCMeta):
+class SparkIndexOpsMethods(metaclass=ABCMeta):
"""Spark related features. Usually, the features here are missing in pandas
but Spark has it."""

-def __init__(self, data: Union["IndexOpsMixin"]):
+def __init__(self, data: "IndexOpsMixin"):
self._data = data

@property
@@ -59,7 +62,7 @@ def column(self) -> Column:
"""
return self._data._internal.spark_column_for(self._data._column_label)

-def transform(self, func) -> Union["ps.Series", "ps.Index"]:
+def transform(self, func: Callable[[Column], Column]) -> Union["ps.Series", "ps.Index"]:
"""
Applies a function that takes and returns a Spark column. It allows to natively
apply a Spark function and column APIs with the Spark column internally used
@@ -130,12 +133,7 @@ def analyzed(self) -> Union["ps.Series", "ps.Index"]:


class SparkSeriesMethods(SparkIndexOpsMethods):
-def transform(self, func) -> "ps.Series":
-return cast("ps.Series", super().transform(func))
-
-transform.__doc__ = SparkIndexOpsMethods.transform.__doc__
-
-def apply(self, func) -> "ps.Series":
+def apply(self, func: Callable[[Column], Column]) -> "ps.Series":
"""
Applies a function that takes and returns a Spark column. It allows to natively
apply a Spark function and column APIs with the Spark column internally used
@@ -256,11 +254,6 @@ def analyzed(self) -> "ps.Series":


class SparkIndexMethods(SparkIndexOpsMethods):
-def transform(self, func) -> "ps.Index":
-return cast("ps.Index", super().transform(func))
-
-transform.__doc__ = SparkIndexOpsMethods.transform.__doc__
-
@property
def analyzed(self) -> "ps.Index":
"""
@@ -641,7 +634,7 @@ def persist(
)
return CachedDataFrame(self._psdf._internal, storage_level=storage_level)

-def hint(self, name: str, *parameters) -> "ps.DataFrame":
+def hint(self, name: str, *parameters: "PrimitiveType") -> "ps.DataFrame":
"""
Specifies some hint on the current DataFrame.
@@ -685,7 +678,7 @@ def to_table(
mode: str = "overwrite",
partition_cols: Optional[Union[str, List[str]]] = None,
index_col: Optional[Union[str, List[str]]] = None,
-**options
+**options: "OptionalPrimitiveType",
) -> None:
"""
Write the DataFrame into a Spark table. :meth:`DataFrame.spark.to_table`
@@ -760,7 +753,7 @@ def to_spark_io(
mode: str = "overwrite",
partition_cols: Optional[Union[str, List[str]]] = None,
index_col: Optional[Union[str, List[str]]] = None,
-**options
+**options: "OptionalPrimitiveType",
) -> None:
"""Write the DataFrame out to a Spark data source. :meth:`DataFrame.spark.to_spark_io`
is an alias of :meth:`DataFrame.to_spark_io`.
@@ -881,7 +874,11 @@ def explain(self, extended: Optional[bool] = None, mode: Optional[str] = None) -
"""
self._psdf._internal.to_internal_spark_frame.explain(extended, mode)

-def apply(self, func, index_col: Optional[Union[str, List[str]]] = None) -> "ps.DataFrame":
+def apply(
+    self,
+    func: Callable[[SparkDataFrame], SparkDataFrame],
+    index_col: Optional[Union[str, List[str]]] = None,
+) -> "ps.DataFrame":
"""
Applies a function that takes and returns a Spark DataFrame. It allows natively
apply a Spark function and column APIs with the Spark column internally used
@@ -1227,7 +1224,7 @@ def unpersist(self) -> None:
self._psdf._cached.unpersist()


-def _test():
+def _test() -> None:
import os
import doctest
import shutil
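With `transform` and `apply` now typed as `Callable[[Column], Column]`, a checker can verify that the function handed to them maps a Spark `Column` to a `Column`. A hedged usage sketch (assumes a running Spark environment; the data values are made up):

```python
import pyspark.pandas as ps
from pyspark.sql import Column
from pyspark.sql import functions as F


def add_one(col: Column) -> Column:
    # Satisfies the Callable[[Column], Column] annotation.
    return col + 1


psser = ps.Series([1, -2, 3])

# Both a named function and a lambda match the annotation.
transformed = psser.spark.transform(add_one)
applied = psser.spark.apply(lambda c: F.abs(c))
```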
18 changes: 9 additions & 9 deletions python/pyspark/pandas/typedef/typehints.py
@@ -21,7 +21,7 @@
import datetime
import decimal
from inspect import getfullargspec, isclass
-from typing import Generic, List, Optional, Tuple, TypeVar, Union # noqa: F401
+from typing import Any, Callable, Generic, List, Optional, Tuple, TypeVar, Union # noqa: F401

import numpy as np
import pandas as pd
@@ -78,7 +78,7 @@ def __init__(self, dtype: Dtype, spark_type: types.DataType):
self.dtype = dtype
self.spark_type = spark_type

-def __repr__(self):
+def __repr__(self) -> str:
return "SeriesType[{}]".format(self.spark_type)


@@ -96,7 +96,7 @@ def __init__(
]
) # type: types.StructType

-def __repr__(self):
+def __repr__(self) -> str:
return "DataFrameType[{}]".format(self.spark_type)


@@ -106,16 +106,16 @@ def __init__(self, dtype: Dtype, spark_type: types.DataType):
self.dtype = dtype
self.spark_type = spark_type

-def __repr__(self):
+def __repr__(self) -> str:
return "ScalarType[{}]".format(self.spark_type)


# The type is left unspecified or we do not know about this type.
class UnknownType(object):
-def __init__(self, tpe):
+def __init__(self, tpe: Any):
self.tpe = tpe

-def __repr__(self):
+def __repr__(self) -> str:
return "UnknownType[{}]".format(self.tpe)


@@ -262,7 +262,7 @@ def spark_type_to_pandas_dtype(
return np.dtype(to_arrow_type(spark_type).to_pandas_dtype())


-def pandas_on_spark_type(tpe) -> Tuple[Dtype, types.DataType]:
+def pandas_on_spark_type(tpe: Union[str, type, Dtype]) -> Tuple[Dtype, types.DataType]:
"""
Convert input into a pandas only dtype object or a numpy dtype object,
and its corresponding Spark DataType.
@@ -322,7 +322,7 @@ def infer_pd_series_spark_type(pser: pd.Series, dtype: Dtype) -> types.DataType:
return as_spark_type(dtype)


-def infer_return_type(f) -> Union[SeriesType, DataFrameType, ScalarType, UnknownType]:
+def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarType, UnknownType]:
"""
Infer the return type from the return type annotation of the given function.
@@ -517,7 +517,7 @@ def infer_return_type(f) -> Union[SeriesType, DataFrameType, ScalarType, Unknown
return ScalarType(*types)


-def _test():
+def _test() -> None:
import doctest
import sys
import pyspark.pandas.typedef.typehints
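Most of the new annotations here are mechanical (`__repr__` returns `str`, `_test` returns `None`), while `pandas_on_spark_type` now accepts `Union[str, type, Dtype]` and `infer_return_type` takes a `Callable`. A hedged sketch of the kind of annotated function `infer_return_type` is meant to inspect (the exact supported annotation forms are covered by the module's doctests, not reproduced here):

```python
import pyspark.pandas as ps
from pyspark.pandas.typedef.typehints import infer_return_type


def doubled(s: ps.Series) -> ps.Series[int]:
    # The return annotation is what infer_return_type inspects.
    return s * 2


# Maps the annotation to a pandas dtype and the corresponding Spark type.
inferred = infer_return_type(doubled)
print(inferred.dtype, inferred.spark_type)
```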
(Diff for the remaining changed file, python/pyspark/pandas/utils.py, not loaded in this view.)
