diff --git a/sdk/python/kubeflow/training/api/training_client_test.py b/sdk/python/kubeflow/training/api/training_client_test.py index 90ae04637f..2a8eb05603 100644 --- a/sdk/python/kubeflow/training/api/training_client_test.py +++ b/sdk/python/kubeflow/training/api/training_client_test.py @@ -17,26 +17,51 @@ from kubernetes.client import V1ResourceRequirements import pytest -LIST_RESPONSE = [{"metadata": {"name": "Dummy V1PodList"}}] TEST_NAME = "test" +TIMEOUT = "timeout" +RUNTIME = "runtime" +MOCK_POD_OBJ = "mock_pod_obj" +NO_PODS = "no_pods" +DUMMY_POD_NAME = "Dummy V1PodList" +LIST_RESPONSE = [ + {"metadata": {"name": DUMMY_POD_NAME}}, +] -def create_namespaced_custom_object_response(*args, **kwargs): - if args[2] == "timeout": +def conditional_error_handler(*args, **kwargs): + if args[2] == TIMEOUT: raise multiprocessing.TimeoutError() - elif args[2] == "runtime": + elif args[2] == RUNTIME: raise RuntimeError() def list_namespaced_pod_response(*args, **kwargs): class MockResponse: def get(self, timeout): - # Simulate a response from the Kubernetes API, and pass timeout for verification - LIST_RESPONSE[0]["timeout"] = timeout - if args[0] == "timeout": + """ + Simulates Kubernetes API response for listing namespaced pods, + and pass timeout for verification + + :return: + - If `args[0] == "timeout"`, raises `TimeoutError`. + - If `args[0] == "runtime"`, raises `Exception`. + - If `args[0] == "mock_pod_obj"`, returns a mock pod object + with `metadata.name = "Dummy V1PodList"`. + - If `args[0] == "no_pods"`, returns an empty list of pods. + - Otherwise, returns a default list of dicts representing pods, + with `timeout` included, for testing. + """ + LIST_RESPONSE[0][TIMEOUT] = timeout + if args[0] == TIMEOUT: raise multiprocessing.TimeoutError() - if args[0] == "runtime": + if args[0] == RUNTIME: raise Exception() + if args[0] == MOCK_POD_OBJ: + pod_obj = Mock(metadata=Mock()) + pod_obj.metadata.name = DUMMY_POD_NAME + return Mock(items=[pod_obj]) + if args[0] == NO_PODS: + return Mock(items=[]) return Mock(items=LIST_RESPONSE) return MockResponse() @@ -156,12 +181,12 @@ def __init__(self, kind) -> None: ), ( "create_namespaced_custom_object timeout error", - {"job": create_job(), "namespace": "timeout"}, + {"job": create_job(), "namespace": TIMEOUT}, TimeoutError, ), ( "create_namespaced_custom_object runtime error", - {"job": create_job(), "namespace": "runtime"}, + {"job": create_job(), "namespace": RUNTIME}, RuntimeError, ), ( @@ -233,7 +258,7 @@ def __init__(self, kind) -> None: "invalid flow with TimeoutError", { "name": TEST_NAME, - "namespace": "timeout", + "namespace": TIMEOUT, }, "Label not relevant", TimeoutError, @@ -242,7 +267,7 @@ def __init__(self, kind) -> None: "invalid flow with RuntimeError", { "name": TEST_NAME, - "namespace": "runtime", + "namespace": RUNTIME, }, "Label not relevant", RuntimeError, @@ -250,14 +275,72 @@ def __init__(self, kind) -> None: ] +test_data_get_job_pod_names = [ + ( + "valid flow", + { + "name": TEST_NAME, + "namespace": MOCK_POD_OBJ, + }, + [DUMMY_POD_NAME], + ), + ( + "valid flow with no pods available", + { + "name": TEST_NAME, + "namespace": NO_PODS, + }, + [], + ), +] + + +test_data_update_job = [ + ( + "valid flow", + { + "name": TEST_NAME, + "job": create_job(), + }, + "No output", + ), + ( + "invalid job_kind", + { + "name": TEST_NAME, + "job": create_job(), + "job_kind": "invalid_job_kind", + }, + ValueError, + ), + ( + "invalid flow with TimeoutError", + { + "name": TEST_NAME, + "namespace": TIMEOUT, + "job": create_job(), + }, + TimeoutError, + ), + ( + "invalid flow with RuntimeError", + { + "name": TEST_NAME, + "namespace": RUNTIME, + "job": create_job(), + }, + RuntimeError, + ), +] + + @pytest.fixture def training_client(): with patch( "kubernetes.client.CustomObjectsApi", return_value=Mock( - create_namespaced_custom_object=Mock( - side_effect=create_namespaced_custom_object_response - ) + create_namespaced_custom_object=Mock(side_effect=conditional_error_handler), + patch_namespaced_custom_object=Mock(side_effect=conditional_error_handler), ), ), patch( "kubernetes.client.CoreV1Api", @@ -304,8 +387,45 @@ def test_get_job_pods( label_selector=expected_label_selector, async_req=True, ) - assert out[0].pop("timeout") == kwargs.get("timeout", constants.DEFAULT_TIMEOUT) + assert out[0].pop(TIMEOUT) == kwargs.get(TIMEOUT, constants.DEFAULT_TIMEOUT) assert out == expected_output except Exception as e: assert type(e) is expected_output print("test execution complete") + + +@pytest.mark.parametrize( + "test_name,kwargs,expected_output", + test_data_get_job_pod_names, +) +def test_get_job_pod_names(training_client, test_name, kwargs, expected_output): + """ + test get_job_pod_names function of training client + """ + print("Executing test:", test_name) + out = training_client.get_job_pod_names(**kwargs) + assert out == expected_output + print("test execution complete") + + +@pytest.mark.parametrize("test_name,kwargs,expected_output", test_data_update_job) +def test_update_job(training_client, test_name, kwargs, expected_output): + """ + test update_job function of training client + """ + print("Executing test:", test_name) + try: + training_client.update_job(**kwargs) + training_client.custom_api.patch_namespaced_custom_object.assert_called_with( + constants.GROUP, + constants.VERSION, + kwargs.get("namespace", constants.DEFAULT_NAMESPACE), + constants.JOB_PARAMETERS[kwargs.get("job_kind", training_client.job_kind)][ + "plural" + ], + kwargs.get("name"), + kwargs.get("job"), + ) + except Exception as e: + assert type(e) is expected_output + print("test execution complete")