Commit

Add 'format' arg to get_result_df (#885)
* Add 'format' arg to get_result_df

Signed-off-by: Jun Ki Min <[email protected]>

* Add unittest for arg alias of get_result_df

Signed-off-by: Jun Ki Min <[email protected]>

* Update explicit functions to use kwargs and update unit-tests accordingly

Signed-off-by: Jun Ki Min <[email protected]>

Signed-off-by: Jun Ki Min <[email protected]>
loomlike authored Nov 29, 2022
1 parent 15550ca commit 654d56e
Showing 2 changed files with 64 additions and 4 deletions.
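
For orientation before the diffs: the sketch below shows how a caller could use the new `format` keyword once this change is in. It is illustrative only; the config path, result URL, and output format are placeholder values, not part of this commit.

from feathr import FeathrClient
from feathr.utils.job_utils import get_result_df

# Placeholder setup: assumes a configured workspace and a finished Feathr job
# whose output sits at `res_url` (both values here are illustrative).
client = FeathrClient(config_path="./feathr_config.yaml")
res_url = "path/or/url/to/the/job/output"

# Existing keyword, unchanged by this commit:
df = get_result_df(client=client, data_format="csv", res_url=res_url)

# New alias introduced by this commit; equivalent to data_format="csv":
df = get_result_df(client=client, format="csv", res_url=res_url)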
15 changes: 13 additions & 2 deletions feathr_project/feathr/utils/job_utils.py
@@ -31,7 +31,7 @@ def get_result_pandas_df(
     Returns:
         pandas DataFrame
     """
-    return get_result_df(client, data_format, res_url, local_cache_path)
+    return get_result_df(client=client, data_format=data_format, res_url=res_url, local_cache_path=local_cache_path)
 
 
 def get_result_spark_df(
@@ -56,12 +56,19 @@ def get_result_spark_df(
     Returns:
         Spark DataFrame
     """
-    return get_result_df(client, data_format, res_url, local_cache_path, spark=spark)
+    return get_result_df(
+        client=client,
+        data_format=data_format,
+        res_url=res_url,
+        local_cache_path=local_cache_path,
+        spark=spark,
+    )
 
 
 def get_result_df(
     client: FeathrClient,
     data_format: str = None,
+    format: str = None,
     res_url: str = None,
     local_cache_path: str = None,
     spark: SparkSession = None,
@@ -72,6 +79,7 @@
         client: Feathr client
         data_format: Format to read the downloaded files. Currently support `parquet`, `delta`, `avro`, and `csv`.
             Default to use client's job tags if exists.
+        format: An alias for `data_format` (for backward compatibility).
         res_url: Result URL to download files from. Note that this will not block the job so you need to make sure
             the job is finished and the result URL contains actual data. Default to use client's job tags if exists.
         local_cache_path (optional): Specify the absolute download directory. if the user does not provide this,
@@ -82,6 +90,9 @@
     Returns:
         Either Spark or pandas DataFrame.
     """
+    if format is not None:
+        data_format = format
+
     if data_format is None:
         # May use data format from the job tags
         if client.get_job_tags() and client.get_job_tags().get(OUTPUT_FORMAT):
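
The heart of the change is the small block above: a provided `format` is copied into `data_format`, which then falls back to the client's job tags. A minimal standalone sketch of that resolution order follows; the helper name, the plain dict standing in for job tags, and the "output_format" key are hypothetical, used only to make the order runnable in isolation.

def resolve_data_format(data_format=None, format=None, job_tags=None):
    # Hypothetical helper, not part of the Feathr API; it only mirrors the
    # order get_result_df uses after this commit.
    if format is not None:
        # The alias wins whenever it is provided.
        data_format = format
    if data_format is None and job_tags:
        # Fall back to the job tags, as get_result_df does via
        # client.get_job_tags().get(OUTPUT_FORMAT).
        data_format = job_tags.get("output_format")  # illustrative key name
    return data_format

assert resolve_data_format(format="csv") == "csv"
assert resolve_data_format(data_format="parquet") == "parquet"
assert resolve_data_format(job_tags={"output_format": "avro"}) == "avro"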
53 changes: 51 additions & 2 deletions feathr_project/test/unit/utils/test_job_utils.py
@@ -26,7 +26,12 @@ def test__get_result_pandas_df(mocker: MockerFixture):
     res_url = "some_res_url"
     local_cache_path = "some_local_cache_path"
     get_result_pandas_df(client, data_format, res_url, local_cache_path)
-    mocked_get_result_df.assert_called_once_with(client, data_format, res_url, local_cache_path)
+    mocked_get_result_df.assert_called_once_with(
+        client=client,
+        data_format=data_format,
+        res_url=res_url,
+        local_cache_path=local_cache_path,
+    )
 
 
 def test__get_result_spark_df(mocker: MockerFixture):
@@ -38,7 +43,13 @@ def test__get_result_spark_df(mocker: MockerFixture):
     res_url = "some_res_url"
     local_cache_path = "some_local_cache_path"
     get_result_spark_df(spark, client, data_format, res_url, local_cache_path)
-    mocked_get_result_df.assert_called_once_with(client, data_format, res_url, local_cache_path, spark=spark)
+    mocked_get_result_df.assert_called_once_with(
+        client=client,
+        data_format=data_format,
+        res_url=res_url,
+        local_cache_path=local_cache_path,
+        spark=spark,
+    )
 
 
 @pytest.mark.parametrize(
@@ -226,3 +237,41 @@ def test__get_result_df__with_spark_session(
     )
     assert isinstance(df, DataFrame)
     assert df.count() == expected_count
+
+
+@pytest.mark.parametrize(
+    "format,output_filename,expected_count", [
+        ("csv", "output.csv", 5),
+    ]
+)
+def test__get_result_df__arg_alias(
+    workspace_dir: str,
+    format: str,
+    output_filename: str,
+    expected_count: int,
+):
+    """Test get_result_df returns pandas DataFrame with the argument alias `format` instead of using `data_format`"""
+    for spark_runtime in ["local", "databricks", "azure_synapse"]:
+        # Note: make sure the output file exists in the test_user_workspace
+        res_url = str(Path(workspace_dir, "mock_results", output_filename))
+        local_cache_path = res_url
+
+        # Mock client
+        client = MagicMock()
+        client.spark_runtime = spark_runtime
+
+        # Mock feathr_spark_launcher.download_result
+        if client.spark_runtime == "databricks":
+            res_url = f"dbfs:/{res_url}"
+        if client.spark_runtime == "azure_synapse" and format == "delta":
+            # TODO currently pass the delta table test on Synapse result due to the delta table package bug.
+            continue
+
+        df = get_result_df(
+            client=client,
+            format=format,
+            res_url=res_url,
+            local_cache_path=local_cache_path,
+        )
+        assert isinstance(df, pd.DataFrame)
+        assert len(df) == expected_count
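
To exercise the alias end to end, the new test can be run on its own. The invocation below is a suggestion rather than anything documented by this commit; it assumes pytest is installed and the repository root is the working directory.

import pytest

# "arg_alias" matches test__get_result_df__arg_alias added above; adjust the
# path if running from a different directory.
pytest.main(["feathr_project/test/unit/utils/test_job_utils.py", "-k", "arg_alias"])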
