diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py index e7a8663af18..9f853f62a9c 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py @@ -50,7 +50,7 @@ def get_observation_log_response(*args, **kwargs): metric_logs=[ katib_api_pb2.MetricLog( time_stamp="2024-07-29T15:09:08Z", - metric=katib_api_pb2.Metric(name="result",value="0.99") + metric=katib_api_pb2.Metric(name="result", value="0.99"), ) ] ) @@ -245,36 +245,28 @@ def create_experiment( test_get_trial_metrics_data = [ ( "valid trial name", - { - "name": "example", - "namespace": "valid", - "timeout": constants.DEFAULT_TIMEOUT - }, + {"name": "example", "namespace": "valid", "timeout": constants.DEFAULT_TIMEOUT}, [ katib_api_pb2.MetricLog( time_stamp="2024-07-29T15:09:08Z", - metric=katib_api_pb2.Metric(name="result",value="0.99") + metric=katib_api_pb2.Metric(name="result", value="0.99"), ) - ] + ], ), ( "invalid trial name", { "name": "invalid", "namespace": "invalid", - "timeout": constants.DEFAULT_TIMEOUT + "timeout": constants.DEFAULT_TIMEOUT, }, - RuntimeError + RuntimeError, ), ( "GetObservationLog timeout error", - { - "name": "example", - "namespace": "valid", - "timeout": 0 - }, - RuntimeError - ) + {"name": "example", "namespace": "valid", "timeout": 0}, + RuntimeError, + ), ] @@ -287,16 +279,11 @@ def katib_client(): side_effect=create_namespaced_custom_object_response ) ), - ), patch( - "kubernetes.config.load_kube_config", - return_value=Mock() - ), patch( - "kubeflow.katib.katib_api_pb2_grpc.DBManagerStub", + ), patch("kubernetes.config.load_kube_config", return_value=Mock()), patch( + "kubeflow.katib.katib_api_pb2_grpc.DBManagerStub", return_value=Mock( - GetObservationLog=Mock( - side_effect=get_observation_log_response - ) - ) + GetObservationLog=Mock(side_effect=get_observation_log_response) + ), ): client = KatibClient() yield client @@ -318,7 +305,9 @@ def test_create_experiment(katib_client, test_name, kwargs, expected_output): print("test execution complete") -@pytest.mark.parametrize("test_name,kwargs,expected_output", test_get_trial_metrics_data) +@pytest.mark.parametrize( + "test_name,kwargs,expected_output", test_get_trial_metrics_data +) def test_get_trial_metrics(katib_client, test_name, kwargs, expected_output): """ test get_trial_metrics function of katib client diff --git a/sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py b/sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py index b8c513e9b78..50dd02c9c20 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/report_metrics.py @@ -68,11 +68,11 @@ def report_metrics( metric_logs=[ katib_api_pb2.MetricLog( time_stamp=timestamp, - metric=katib_api_pb2.Metric(name=name,value=str(value)) + metric=katib_api_pb2.Metric(name=name, value=str(value)), ) for name, value in metrics.items() ] - ) + ), ), timeout=timeout, ) diff --git a/sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py b/sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py index 129afcf4577..69f13698b13 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py @@ -17,71 +17,40 @@ def report_observation_log_response(*args, **kwargs): test_report_metrics_data = [ ( "valid metrics with float type", - { - "metrics": { - "result": 0.99 - }, - "timeout": constants.DEFAULT_TIMEOUT - - }, + {"metrics": {"result": 0.99}, "timeout": constants.DEFAULT_TIMEOUT}, TEST_RESULT_SUCCESS, - ENV_VARIABLE_NOT_EMPTY + ENV_VARIABLE_NOT_EMPTY, ), ( "valid metrics with string type", - { - "metrics": { - "result": "0.99" - }, - "timeout": constants.DEFAULT_TIMEOUT - }, + {"metrics": {"result": "0.99"}, "timeout": constants.DEFAULT_TIMEOUT}, TEST_RESULT_SUCCESS, - ENV_VARIABLE_NOT_EMPTY + ENV_VARIABLE_NOT_EMPTY, ), ( "valid metrics with int type", - { - "metrics": { - "result": 1 - }, - "timeout": constants.DEFAULT_TIMEOUT - }, + {"metrics": {"result": 1}, "timeout": constants.DEFAULT_TIMEOUT}, TEST_RESULT_SUCCESS, - ENV_VARIABLE_NOT_EMPTY + ENV_VARIABLE_NOT_EMPTY, ), ( "ReportObservationLog timeout error", - { - "metrics": { - "result": 0.99 - }, - "timeout": 0 - }, + {"metrics": {"result": 0.99}, "timeout": 0}, RuntimeError, - ENV_VARIABLE_NOT_EMPTY + ENV_VARIABLE_NOT_EMPTY, ), ( "invalid metrics with type string", - { - "metrics": { - "result": "abc" - }, - "timeout": constants.DEFAULT_TIMEOUT - }, + {"metrics": {"result": "abc"}, "timeout": constants.DEFAULT_TIMEOUT}, ValueError, - ENV_VARIABLE_NOT_EMPTY + ENV_VARIABLE_NOT_EMPTY, ), ( "Trial name is not passed to env variables", - { - "metrics": { - "result": 0.99 - }, - "timeout": constants.DEFAULT_TIMEOUT - }, + {"metrics": {"result": 0.99}, "timeout": constants.DEFAULT_TIMEOUT}, ValueError, - ENV_VARIABLE_EMPTY - ) + ENV_VARIABLE_EMPTY, + ), ] @@ -113,9 +82,16 @@ def mock_report_observation_log(): @pytest.mark.parametrize( "test_name,kwargs,expected_output,mock_getenv", test_report_metrics_data, - indirect=["mock_getenv"] + indirect=["mock_getenv"], ) -def test_report_metrics(test_name, kwargs, expected_output, mock_getenv, mock_get_current_k8s_namespace, mock_report_observation_log): +def test_report_metrics( + test_name, + kwargs, + expected_output, + mock_getenv, + mock_get_current_k8s_namespace, + mock_report_observation_log, +): """ test report_metrics function """ diff --git a/sdk/python/v1beta1/setup.py b/sdk/python/v1beta1/setup.py index fbf6b865156..49c689a235c 100644 --- a/sdk/python/v1beta1/setup.py +++ b/sdk/python/v1beta1/setup.py @@ -46,7 +46,7 @@ "kubeflow/katib/katib_api_pb2_grpc.py", ) - with open("kubeflow/katib/katib_api_pb2_grpc.py", 'r+') as file: + with open("kubeflow/katib/katib_api_pb2_grpc.py", "r+") as file: content = file.read() new_content = content.replace("api_pb2", "kubeflow.katib.katib_api_pb2") file.seek(0)