Skip to content

Commit

Permalink
Bug fix: Update CRD feature fixed (#25)
Browse files Browse the repository at this point in the history
* bug fix

* unit test updated

* minor change
  • Loading branch information
venkatajagannath authored Jul 22, 2024
1 parent b64e7b6 commit 96009cc
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 6 deletions.
11 changes: 7 additions & 4 deletions ray_provider/operators/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
use_gpu: bool = False,
kuberay_version: str = "1.0.0",
gpu_device_plugin_yaml: str = "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml",
update_if_exists: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -43,6 +44,7 @@ def __init__(
self.use_gpu = use_gpu
self.kuberay_version = kuberay_version
self.gpu_device_plugin_yaml = gpu_device_plugin_yaml
self.update_if_exists = update_if_exists

self._validate_yaml_file(ray_cluster_yaml)

Expand Down Expand Up @@ -70,10 +72,11 @@ def _create_or_update_cluster(
"""Create or update the Ray cluster based on the cluster specification."""
try:
self.hook.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace)
self.log.info(f"Updating existing Ray cluster: {name}")
self.hook.patch_custom_object(
group=group, version=version, namespace=namespace, plural=plural, name=name, body=cluster_spec
)
if self.update_if_exists:
self.log.info(f"Updating existing Ray cluster: {name}")
self.hook.custom_object_client.patch_namespaced_custom_object(
group=group, version=version, namespace=namespace, plural=plural, name=name, body=cluster_spec
)
except client.exceptions.ApiException as e:
if e.status == 404:
self.log.info(f"Creating new Ray cluster: {name}")
Expand Down
27 changes: 25 additions & 2 deletions tests/operators/test_ray_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ def test_init(self, mock_file_ops):
use_gpu=True,
kuberay_version="1.1.0",
gpu_device_plugin_yaml="custom_gpu_plugin.yaml",
update_if_exists=True,
)

assert operator.conn_id == "test_conn"
assert operator.ray_cluster_yaml == "cluster.yaml"
assert operator.use_gpu is True
assert operator.kuberay_version == "1.1.0"
assert operator.gpu_device_plugin_yaml == "custom_gpu_plugin.yaml"
assert operator.update_if_exists is True

def test_validate_yaml_file_not_exist(self):
with patch("os.path.isfile", return_value=False):
Expand Down Expand Up @@ -69,11 +71,32 @@ def test_create_or_update_cluster_create(self, mock_hook, operator):

def test_create_or_update_cluster_update(self, mock_hook, operator):
mock_hook.get_custom_object.return_value = {}
operator.update_if_exists = True # Set this to True to test the update logic

operator.hook = mock_hook
operator._create_or_update_cluster("test_group", "v1", "rayclusters", "test-cluster", "default", {})
operator._create_or_update_cluster(
"test_group", "v1", "rayclusters", "test-cluster", "default", {"spec": "updated"}
)

mock_hook.custom_object_client.patch_namespaced_custom_object.assert_called_once_with(
group="test_group",
version="v1",
namespace="default",
plural="rayclusters",
name="test-cluster",
body={"spec": "updated"},
)

def test_create_or_update_cluster_no_update(self, mock_hook, operator):
mock_hook.get_custom_object.return_value = {}
operator.update_if_exists = False # Set this to False to test when update is not allowed

operator.hook = mock_hook
operator._create_or_update_cluster(
"test_group", "v1", "rayclusters", "test-cluster", "default", {"spec": "updated"}
)

mock_hook.patch_custom_object.assert_called_once()
mock_hook.custom_object_client.patch_namespaced_custom_object.assert_not_called()

def test_setup_gpu_driver(self, mock_hook, operator):
mock_hook.load_yaml_content.return_value = {"metadata": {"name": "gpu-plugin"}}
Expand Down

0 comments on commit 96009cc

Please sign in to comment.