From 654d56e4d68df567b6656117dd05cfcd18f55c21 Mon Sep 17 00:00:00 2001
From: Jun Ki Min <42475935+loomlike@users.noreply.github.com>
Date: Mon, 28 Nov 2022 23:29:54 -0800
Subject: [PATCH] Add 'format' arg to get_result_df (#885)

* Add 'format' arg to get_result_df

Signed-off-by: Jun Ki Min <42475935+loomlike@users.noreply.github.com>

* Add unittest for arg alias of get_result_df

Signed-off-by: Jun Ki Min <42475935+loomlike@users.noreply.github.com>

* Update explicit functions to use kwargs and update unit-tests accordingly

Signed-off-by: Jun Ki Min <42475935+loomlike@users.noreply.github.com>

Signed-off-by: Jun Ki Min <42475935+loomlike@users.noreply.github.com>
---
 feathr_project/feathr/utils/job_utils.py      | 15 +++++-
 .../test/unit/utils/test_job_utils.py         | 53 ++++++++++++++++++-
 2 files changed, 64 insertions(+), 4 deletions(-)

diff --git a/feathr_project/feathr/utils/job_utils.py b/feathr_project/feathr/utils/job_utils.py
index d9c73c355..e03645f71 100644
--- a/feathr_project/feathr/utils/job_utils.py
+++ b/feathr_project/feathr/utils/job_utils.py
@@ -31,7 +31,7 @@ def get_result_pandas_df(
     Returns:
         pandas DataFrame
     """
-    return get_result_df(client, data_format, res_url, local_cache_path)
+    return get_result_df(client=client, data_format=data_format, res_url=res_url, local_cache_path=local_cache_path)
 
 
 def get_result_spark_df(
@@ -56,12 +56,19 @@ def get_result_spark_df(
     Returns:
         Spark DataFrame
     """
-    return get_result_df(client, data_format, res_url, local_cache_path, spark=spark)
+    return get_result_df(
+        client=client,
+        data_format=data_format,
+        res_url=res_url,
+        local_cache_path=local_cache_path,
+        spark=spark,
+    )
 
 
 def get_result_df(
     client: FeathrClient,
     data_format: str = None,
+    format: str = None,
     res_url: str = None,
     local_cache_path: str = None,
     spark: SparkSession = None,
@@ -72,6 +79,7 @@ def get_result_df(
         client: Feathr client
         data_format: Format to read the downloaded files. Currently support `parquet`, `delta`, `avro`, and `csv`.
             Default to use client's job tags if exists.
+        format: An alias for `data_format` (for backward compatibility).
         res_url: Result URL to download files from. Note that this will not block the job so you need to make sure
             the job is finished and the result URL contains actual data. Default to use client's job tags if exists.
         local_cache_path (optional): Specify the absolute download directory. if the user does not provide this,
@@ -82,6 +90,9 @@ def get_result_df(
     Returns:
         Either Spark or pandas DataFrame.
""" + if format is not None: + data_format = format + if data_format is None: # May use data format from the job tags if client.get_job_tags() and client.get_job_tags().get(OUTPUT_FORMAT): diff --git a/feathr_project/test/unit/utils/test_job_utils.py b/feathr_project/test/unit/utils/test_job_utils.py index 0909fb56e..4a0d835e5 100644 --- a/feathr_project/test/unit/utils/test_job_utils.py +++ b/feathr_project/test/unit/utils/test_job_utils.py @@ -26,7 +26,12 @@ def test__get_result_pandas_df(mocker: MockerFixture): res_url = "some_res_url" local_cache_path = "some_local_cache_path" get_result_pandas_df(client, data_format, res_url, local_cache_path) - mocked_get_result_df.assert_called_once_with(client, data_format, res_url, local_cache_path) + mocked_get_result_df.assert_called_once_with( + client=client, + data_format=data_format, + res_url=res_url, + local_cache_path=local_cache_path, + ) def test__get_result_spark_df(mocker: MockerFixture): @@ -38,7 +43,13 @@ def test__get_result_spark_df(mocker: MockerFixture): res_url = "some_res_url" local_cache_path = "some_local_cache_path" get_result_spark_df(spark, client, data_format, res_url, local_cache_path) - mocked_get_result_df.assert_called_once_with(client, data_format, res_url, local_cache_path, spark=spark) + mocked_get_result_df.assert_called_once_with( + client=client, + data_format=data_format, + res_url=res_url, + local_cache_path=local_cache_path, + spark=spark, + ) @pytest.mark.parametrize( @@ -226,3 +237,41 @@ def test__get_result_df__with_spark_session( ) assert isinstance(df, DataFrame) assert df.count() == expected_count + + +@pytest.mark.parametrize( + "format,output_filename,expected_count", [ + ("csv", "output.csv", 5), + ] +) +def test__get_result_df__arg_alias( + workspace_dir: str, + format: str, + output_filename: str, + expected_count: int, +): + """Test get_result_df returns pandas DataFrame with the argument alias `format` instead of using `data_format`""" + for spark_runtime in ["local", "databricks", "azure_synapse"]: + # Note: make sure the output file exists in the test_user_workspace + res_url = str(Path(workspace_dir, "mock_results", output_filename)) + local_cache_path = res_url + + # Mock client + client = MagicMock() + client.spark_runtime = spark_runtime + + # Mock feathr_spark_launcher.download_result + if client.spark_runtime == "databricks": + res_url = f"dbfs:/{res_url}" + if client.spark_runtime == "azure_synapse" and format == "delta": + # TODO currently pass the delta table test on Synapse result due to the delta table package bug. + continue + + df = get_result_df( + client=client, + format=format, + res_url=res_url, + local_cache_path=local_cache_path, + ) + assert isinstance(df, pd.DataFrame) + assert len(df) == expected_count