Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 'format' arg to get_result_df #885

Merged
merged 3 commits into from
Nov 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions feathr_project/feathr/utils/job_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_result_pandas_df(
Returns:
pandas DataFrame
"""
return get_result_df(client, data_format, res_url, local_cache_path)
return get_result_df(client=client, data_format=data_format, res_url=res_url, local_cache_path=local_cache_path)


def get_result_spark_df(
Expand All @@ -56,12 +56,19 @@ def get_result_spark_df(
Returns:
Spark DataFrame
"""
return get_result_df(client, data_format, res_url, local_cache_path, spark=spark)
return get_result_df(
client=client,
data_format=data_format,
res_url=res_url,
local_cache_path=local_cache_path,
spark=spark,
)


def get_result_df(
client: FeathrClient,
data_format: str = None,
format: str = None,
res_url: str = None,
local_cache_path: str = None,
spark: SparkSession = None,
Expand All @@ -72,6 +79,7 @@ def get_result_df(
client: Feathr client
data_format: Format to read the downloaded files. Currently support `parquet`, `delta`, `avro`, and `csv`.
Default to use client's job tags if exists.
format: An alias for `data_format` (for backward compatibility).
res_url: Result URL to download files from. Note that this will not block the job so you need to make sure
the job is finished and the result URL contains actual data. Default to use client's job tags if exists.
local_cache_path (optional): Specify the absolute download directory. if the user does not provide this,
Expand All @@ -82,6 +90,9 @@ def get_result_df(
Returns:
Either Spark or pandas DataFrame.
"""
if format is not None:
data_format = format
xiaoyongzhu marked this conversation as resolved.
Show resolved Hide resolved

if data_format is None:
# May use data format from the job tags
if client.get_job_tags() and client.get_job_tags().get(OUTPUT_FORMAT):
Expand Down
53 changes: 51 additions & 2 deletions feathr_project/test/unit/utils/test_job_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ def test__get_result_pandas_df(mocker: MockerFixture):
res_url = "some_res_url"
local_cache_path = "some_local_cache_path"
get_result_pandas_df(client, data_format, res_url, local_cache_path)
mocked_get_result_df.assert_called_once_with(client, data_format, res_url, local_cache_path)
mocked_get_result_df.assert_called_once_with(
client=client,
data_format=data_format,
res_url=res_url,
local_cache_path=local_cache_path,
)


def test__get_result_spark_df(mocker: MockerFixture):
Expand All @@ -38,7 +43,13 @@ def test__get_result_spark_df(mocker: MockerFixture):
res_url = "some_res_url"
local_cache_path = "some_local_cache_path"
get_result_spark_df(spark, client, data_format, res_url, local_cache_path)
mocked_get_result_df.assert_called_once_with(client, data_format, res_url, local_cache_path, spark=spark)
mocked_get_result_df.assert_called_once_with(
client=client,
data_format=data_format,
res_url=res_url,
local_cache_path=local_cache_path,
spark=spark,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -226,3 +237,41 @@ def test__get_result_df__with_spark_session(
)
assert isinstance(df, DataFrame)
assert df.count() == expected_count


@pytest.mark.parametrize(
    "format,output_filename,expected_count", [
        ("csv", "output.csv", 5),
    ]
)
def test__get_result_df__arg_alias(
    workspace_dir: str,
    format: str,
    output_filename: str,
    expected_count: int,
):
    """Verify that `get_result_df` honors the `format` keyword as an alias for
    `data_format` and still returns the expected pandas DataFrame, for each
    supported spark runtime.
    """
    for runtime in ("local", "databricks", "azure_synapse"):
        # The mock result file must already exist under the test workspace.
        cache_path = str(Path(workspace_dir, "mock_results", output_filename))
        result_url = cache_path

        # Stub the Feathr client; only `spark_runtime` is consulted here.
        client = MagicMock()
        client.spark_runtime = runtime

        # Databricks results are addressed through the DBFS scheme.
        if runtime == "databricks":
            result_url = f"dbfs:/{result_url}"
        if runtime == "azure_synapse" and format == "delta":
            # TODO: skipped on Synapse for delta results due to the delta table package bug.
            continue

        df = get_result_df(
            client=client,
            format=format,
            res_url=result_url,
            local_cache_path=cache_path,
        )
        assert isinstance(df, pd.DataFrame)
        assert len(df) == expected_count