Add more supported type annotations, fix spark connect issue #542

Merged · 8 commits · Jun 14, 2024
2 changes: 1 addition & 1 deletion .devcontainer/devcontainer.json
@@ -38,7 +38,7 @@
     ],
     "postCreateCommand": "make devenv",
     "features": {
-        "ghcr.io/devcontainers/features/docker-in-docker:2": {},
+        "ghcr.io/devcontainers/features/docker-in-docker:2.11.0": {},
         "ghcr.io/devcontainers/features/java:1": {
             "version": "11"
         }
2 changes: 1 addition & 1 deletion .github/workflows/test_all.yml
@@ -25,7 +25,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.8, "3.10"] # TODO: add back 3.11 when dask-sql is compatible
+        python-version: [3.8, "3.10", "3.11"]
 
     steps:
       - uses: actions/checkout@v2
5 changes: 5 additions & 0 deletions RELEASE.md
@@ -1,5 +1,10 @@
 # Release Notes
 
+## 0.9.1
+
+- [543](https://github.com/fugue-project/fugue/issues/543) Support type hinting with standard collections
+- [544](https://github.com/fugue-project/fugue/issues/544) Fix Spark connect import issue on worker side
+
 ## 0.9.0
 
 - [482](https://github.com/fugue-project/fugue/issues/482) Move Fugue SQL dependencies into extra `[sql]` and functions to become soft dependencies
29 changes: 13 additions & 16 deletions fugue/dataframe/function_wrapper.py
@@ -20,6 +20,7 @@
     PositionalParam,
     function_wrapper,
 )
+from triad.utils.convert import compare_annotations
 from triad.utils.iter import EmptyAwareIterable, make_empty_aware
 
 from ..constants import FUGUE_ENTRYPOINT
@@ -37,6 +38,14 @@
 from .pandas_dataframe import PandasDataFrame
 
 
+def _compare_iter(tp: Any) -> Any:
+    return lambda x: compare_annotations(
+        x, Iterable[tp]  # type:ignore
+    ) or compare_annotations(
+        x, Iterator[tp]  # type:ignore
+    )
+
+
 @function_wrapper(FUGUE_ENTRYPOINT)
 class DataFrameFunctionWrapper(FunctionWrapper):
     @property
@@ -228,10 +237,7 @@ def count(self, df: List[List[Any]]) -> int:
         return len(df)
 
 
-@fugue_annotated_param(
-    Iterable[List[Any]],
-    matcher=lambda x: x == Iterable[List[Any]] or x == Iterator[List[Any]],
-)
+@fugue_annotated_param(Iterable[List[Any]], matcher=_compare_iter(List[Any]))
 class _IterableListParam(_LocalNoSchemaDataFrameParam):
     @no_type_check
     def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[List[Any]]:
@@ -288,10 +294,7 @@ def count(self, df: List[Dict[str, Any]]) -> int:
         return len(df)
 
 
-@fugue_annotated_param(
-    Iterable[Dict[str, Any]],
-    matcher=lambda x: x == Iterable[Dict[str, Any]] or x == Iterator[Dict[str, Any]],
-)
+@fugue_annotated_param(Iterable[Dict[str, Any]], matcher=_compare_iter(Dict[str, Any]))
 class _IterableDictParam(_LocalNoSchemaDataFrameParam):
     @no_type_check
     def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[Dict[str, Any]]:
@@ -360,10 +363,7 @@ def format_hint(self) -> Optional[str]:
         return "pandas"
 
 
-@fugue_annotated_param(
-    Iterable[pd.DataFrame],
-    matcher=lambda x: x == Iterable[pd.DataFrame] or x == Iterator[pd.DataFrame],
-)
+@fugue_annotated_param(Iterable[pd.DataFrame], matcher=_compare_iter(pd.DataFrame))
 class _IterablePandasParam(LocalDataFrameParam):
     @no_type_check
     def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[pd.DataFrame]:
@@ -419,10 +419,7 @@ def format_hint(self) -> Optional[str]:
         return "pyarrow"
 
 
-@fugue_annotated_param(
-    Iterable[pa.Table],
-    matcher=lambda x: x == Iterable[pa.Table] or x == Iterator[pa.Table],
-)
+@fugue_annotated_param(Iterable[pa.Table], matcher=_compare_iter(pa.Table))
 class _IterableArrowParam(LocalDataFrameParam):
     @no_type_check
     def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[pa.Table]:
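
The refactor above replaces the per-type equality lambdas with one shared matcher built on triad's compare_annotations, which is what lets standard-collection generics like list[dict[str, Any]] match their typing counterparts. A minimal sketch of the resulting behavior (not part of the diff; it assumes triad>=0.9.7, and the last assert reflects what this PR's tests imply rather than a documented contract):

    from typing import Any, Dict, Iterable, Iterator

    from triad.utils.convert import compare_annotations

    def _compare_iter(tp: Any) -> Any:
        # Build a matcher accepting either Iterable[tp] or Iterator[tp]
        return lambda x: compare_annotations(x, Iterable[tp]) or compare_annotations(
            x, Iterator[tp]
        )

    matcher = _compare_iter(Dict[str, Any])
    assert matcher(Iterable[Dict[str, Any]])  # typing generics match
    assert matcher(Iterator[Dict[str, Any]])  # both Iterable and Iterator forms
    assert matcher(Iterable[dict[str, Any]])  # Python 3.9+ standard collections too
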
2 changes: 1 addition & 1 deletion fugue_spark/_utils/misc.py
@@ -3,7 +3,7 @@
 try:
     from pyspark.sql.connect.session import SparkSession as SparkConnectSession
     from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
-except ImportError:  # pragma: no cover
+except Exception:  # pragma: no cover
     SparkConnectSession = None
     SparkConnectDataFrame = None
 import pyspark.sql as ps
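
Widening except ImportError to except Exception is the worker-side fix from issue 544: on a Spark Connect worker the import can fail with errors other than ImportError (for example, a missing or incompatible optional dependency raising a different exception type during module import), which previously escaped the guard. A sketch of the resulting defensive-import pattern; is_spark_connect_dataframe is a hypothetical helper for illustration:

    from typing import Any

    try:
        from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
    except Exception:  # any failure means Spark Connect is unavailable here
        SparkConnectDataFrame = None

    def is_spark_connect_dataframe(df: Any) -> bool:
        # Only attempt the isinstance check when the import succeeded
        return SparkConnectDataFrame is not None and isinstance(df, SparkConnectDataFrame)
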
2 changes: 1 addition & 1 deletion fugue_version/__init__.py
@@ -1 +1 @@
-__version__ = "0.9.0"
+__version__ = "0.9.1"
2 changes: 1 addition & 1 deletion setup.py
@@ -38,7 +38,7 @@ def get_version() -> str:
     keywords="distributed spark dask ray sql dsl domain specific language",
     url="http://github.com/fugue-project/fugue",
     install_requires=[
-        "triad>=0.9.6",
+        "triad>=0.9.7",
         "adagio>=0.2.4",
     ],
     extras_require={
16 changes: 15 additions & 1 deletion tests/fugue/dataframe/test_function_wrapper.py
@@ -1,4 +1,7 @@
+from __future__ import annotations
+
 import copy
+import sys
 from typing import Any, Dict, Iterable, Iterator, List
 
 import pandas as pd
@@ -29,7 +32,10 @@
 
 
 def test_function_wrapper():
-    for f in [f20, f21, f212, f22, f23, f24, f25, f26, f30, f31, f32, f35, f36]:
+    fs = [f20, f21, f212, f22, f23, f24, f25, f26, f30, f31, f32, f35, f36]
+    if sys.version_info >= (3, 9):
+        fs.append(f33)
+    for f in fs:
         df = ArrayDataFrame([[0]], "a:int")
         w = DataFrameFunctionWrapper(f, "^[ldsp][ldsp]$", "[ldspq]")
         res = w.run([df], dict(a=df), ignore_unknown=False, output_schema="a:int")
@@ -372,6 +378,14 @@ def f32(
     return ArrayDataFrame(arr, "a:int").as_dict_iterable()
 
 
+def f33(
+    e: list[dict[str, Any]], a: Iterable[dict[str, Any]]
+) -> EmptyAwareIterable[Dict[str, Any]]:
+    e += list(a)
+    arr = [[x["a"]] for x in e]
+    return ArrayDataFrame(arr, "a:int").as_dict_iterable()
+
+
 def f35(e: pd.DataFrame, a: LocalDataFrame) -> Iterable[pd.DataFrame]:
     e = PandasDataFrame(e, "a:int").as_pandas()
     a = ArrayDataFrame(a, "a:int").as_pandas()
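
The new f33 exercises exactly what issue 543 adds: lowercase standard-collection annotations (list[...], dict[...]) accepted on Python 3.9+, on par with typing.List and typing.Dict. A hedged end-user sketch of the same feature through fugue.api.transform (the schema-hint string follows Fugue's existing conventions):

    from typing import Any

    import pandas as pd
    import fugue.api as fa

    # On Python 3.9+, standard-collection annotations now drive Fugue's
    # input/output conversion just like their typing.* equivalents.
    def add_one(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
        return [dict(a=r["a"] + 1) for r in rows]

    df = pd.DataFrame(dict(a=[0, 1, 2]))
    result = fa.transform(df, add_one, schema="a:int")
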
42 changes: 42 additions & 0 deletions tests/fugue_dask/test_execution_engine.py
@@ -24,6 +24,8 @@
 from fugue_dask.execution_engine import DaskExecutionEngine
 from fugue_test.builtin_suite import BuiltInTests
 from fugue_test.execution_suite import ExecutionEngineTests
+from fugue.column import col, all_cols
+import fugue.column.functions as ff
 
 _CONF = {
     "fugue.rpc.server": "fugue.rpc.flask.FlaskRPCServer",
@@ -50,6 +52,46 @@ def test_get_parallelism(self):
     def test__join_outer_pandas_incompatible(self):
         return
 
+    # TODO: dask-sql 2024.5.0 has a bug, can't pass the HAVING tests
+    def test_select(self):
+        try:
+            import qpd
+            import dask_sql
+        except ImportError:
+            return
+
+        a = ArrayDataFrame(
+            [[1, 2], [None, 2], [None, 1], [3, 4], [None, 4]], "a:double,b:int"
+        )
+
+        # simple
+        b = fa.select(a, col("b"), (col("b") + 1).alias("c").cast(str))
+        self.df_eq(
+            b,
+            [[2, "3"], [2, "3"], [1, "2"], [4, "5"], [4, "5"]],
+            "b:int,c:str",
+            throw=True,
+        )
+
+        # with distinct
+        b = fa.select(
+            a, col("b"), (col("b") + 1).alias("c").cast(str), distinct=True
+        )
+        self.df_eq(
+            b,
+            [[2, "3"], [1, "2"], [4, "5"]],
+            "b:int,c:str",
+            throw=True,
+        )
+
+        # wildcard
+        b = fa.select(a, all_cols(), where=col("a") + col("b") == 3)
+        self.df_eq(b, [[1, 2]], "a:double,b:int", throw=True)
+
+        # aggregation
+        b = fa.select(a, col("a"), ff.sum(col("b")).cast(float).alias("b"))
+        self.df_eq(b, [[1, 2], [3, 4], [None, 7]], "a:double,b:double", throw=True)
+
     def test_to_df(self):
         e = self.engine
         a = e.to_df([[1, 2], [3, 4]], "a:int,b:int")
4 changes: 2 additions & 2 deletions tests/fugue_duckdb/test_dataframe.py
@@ -15,7 +15,7 @@
 class DuckDataFrameTests(DataFrameTests.Tests):
     def df(self, data: Any = None, schema: Any = None) -> DuckDataFrame:
         df = ArrowDataFrame(data, schema)
-        return DuckDataFrame(duckdb.from_arrow(df.native, self.context.session))
+        return DuckDataFrame(duckdb.from_arrow(df.native, connection=self.context.session))
 
     def test_as_array_special_values(self):
         for func in [
@@ -69,7 +69,7 @@ def test_duck_as_local(self):
 class NativeDuckDataFrameTests(DataFrameTests.NativeTests):
     def df(self, data: Any = None, schema: Any = None) -> DuckDataFrame:
         df = ArrowDataFrame(data, schema)
-        return DuckDataFrame(duckdb.from_arrow(df.native, self.context.session)).native
+        return DuckDataFrame(duckdb.from_arrow(df.native, connection=self.context.session)).native
 
     def to_native_df(self, pdf: pd.DataFrame) -> Any:
         return duckdb.from_df(pdf)
2 changes: 1 addition & 1 deletion tests/fugue_duckdb/test_utils.py
@@ -45,7 +45,7 @@ def test_type_conversion(backend_context):
 
     def assert_(tp):
         dt = duckdb.from_arrow(
-            pa.Table.from_pydict(dict(a=pa.nulls(2, tp))), con
+            pa.Table.from_pydict(dict(a=pa.nulls(2, tp))), connection=con
         ).types[0]
         assert to_pa_type(dt) == tp
         dt = to_duck_type(tp)
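
Both DuckDB call sites above switch the connection to a keyword argument; newer DuckDB releases appear to expect connection= for the module-level relational helpers (an inference from this change rather than a documented guarantee). A minimal sketch:

    import duckdb
    import pyarrow as pa

    con = duckdb.connect()
    tbl = pa.table(dict(a=[1, 2, 3]))
    # Pass the connection by keyword; the positional form is what this PR removes
    rel = duckdb.from_arrow(tbl, connection=con)
    print(rel.types)
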
16 changes: 9 additions & 7 deletions tests/fugue_ibis/mock/execution_engine.py
@@ -81,15 +81,17 @@ def sample(
                 f"one and only one of n and frac should be non-negative, {n}, {frac}"
             ),
         )
-        tn = self.get_temp_table_name()
+        idf = self.to_df(df)
+        tn = f"({idf.native.compile()})"
+        if seed is not None:
+            _seed = f",{seed}"
+        else:
+            _seed = ""
         if frac is not None:
-            sql = f"SELECT * FROM {tn} USING SAMPLE bernoulli({frac*100} PERCENT)"
+            sql = f"SELECT * FROM {tn} USING SAMPLE {frac*100}% (bernoulli{_seed})"
         else:
-            sql = f"SELECT * FROM {tn} USING SAMPLE reservoir({n} ROWS)"
-            if seed is not None:
-                sql += f" REPEATABLE ({seed})"
-        idf = self.to_df(df)
-        _res = f"WITH {tn} AS ({idf.native.compile()}) " + sql
+            sql = f"SELECT * FROM {tn} USING SAMPLE {n} ROWS (reservoir{_seed})"
+        _res = f"SELECT * FROM ({sql})"  # ibis has a bug to inject LIMIT
         return self.to_df(self.backend.sql(_res))
 
     def _register_df(
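
The rewrite drops the WITH ... AS temp-table wrapper and the REPEATABLE clause in favor of inlining the compiled query and DuckDB's parenthesized sampler syntax. For reference, a sketch of the two query shapes it now emits (not part of the diff; the subquery and seed values are placeholders):

    tn = "(SELECT * FROM t)"  # stands in for f"({idf.native.compile()})"
    _seed = ",42"             # empty string when no seed is given

    frac_sql = f"SELECT * FROM {tn} USING SAMPLE 50.0% (bernoulli{_seed})"
    n_sql = f"SELECT * FROM {tn} USING SAMPLE 5 ROWS (reservoir{_seed})"

    # The extra wrapping guards against ibis injecting a LIMIT into the query
    final = f"SELECT * FROM ({n_sql})"
    print(frac_sql)
    print(final)
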
3 changes: 3 additions & 0 deletions tests/fugue_ibis/test_execution_engine.py
@@ -23,6 +23,9 @@ def test_properties(self):
         assert not self.engine.map_engine.is_distributed
         assert not self.engine.sql_engine.is_distributed
 
+        assert self.engine.sql_engine.get_temp_table_name(
+        ) != self.engine.sql_engine.get_temp_table_name()
+
     def test_select(self):
         # it can't work properly with DuckDB (hugeint is not recognized)
         pass