Skip to content

Commit

Permalink
feat(sdk): support volume mount in tune API
Browse files Browse the repository at this point in the history
  • Loading branch information
truc0 committed Feb 5, 2025
1 parent 4d2a230 commit a2d2c5f
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
import multiprocessing
import time
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union, TypedDict

import grpc
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
Expand All @@ -30,6 +30,10 @@

logger = logging.getLogger(__name__)

TuneStoragePerTrialType = TypedDict(
"TuneStoragePerTrial",
{"volume": client.V1Volume, "mount_path": str},
)

class KatibClient(object):
def __init__(
Expand Down Expand Up @@ -186,6 +190,7 @@ def tune(
env_per_trial: Optional[
Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]
] = None,
storage_per_trial: Optional[Dict[str, TuneStoragePerTrialType]] = None,
algorithm_name: str = "random",
algorithm_settings: Union[
dict, List[models.V1beta1AlgorithmSetting], None
Expand Down Expand Up @@ -468,6 +473,19 @@ class name in this argument.
f"Incorrect value for env_per_trial: {env_per_trial}"
)

volumes: List[client.V1Volume] = []
volume_mounts: List[client.V1VolumeMount] = []
if storage_per_trial:
for name, storage in storage_per_trial.items():
volumes.append(storage["volume"])
volume_mounts.append(
client.V1VolumeMount(name=name, mount_path=storage["mount_path"]),
)
print('='*100)
print("volumes", volumes)
print("volume_mounts", volume_mounts)
print('='*100)

# Create Trial specification.
trial_spec = client.V1Job(
api_version="batch/v1",
Expand All @@ -488,8 +506,10 @@ class name in this argument.
env=env if env else None,
env_from=env_from if env_from else None,
resources=resources_per_trial,
volume_mounts=volume_mounts if volume_mounts else None,
)
],
volumes=volumes if volumes else None,
),
)
),
Expand Down Expand Up @@ -576,7 +596,7 @@ class name in this argument.
f"It must also start and end with an alphanumeric character."
)
elif hasattr(e, "status") and e.status == 409:
print(f"PVC '{name}' already exists in namespace " f"{namespace}.")
print(f"PVC '{name}' already exists in namespace {namespace}.")
else:
raise RuntimeError(f"failed to create PVC. Error: {e}")

Expand Down

0 comments on commit a2d2c5f

Please sign in to comment.