#38 Make API URL configurable for use in Europe #39

Merged 3 commits on Sep 17, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion Makefile
@@ -3,7 +3,7 @@ default: release
PLATFORMS = linux/amd64,linux/arm64
APP="castai/hibernate"
TAG_LATEST=$(APP):latest
TAG_VERSION=$(APP):v0.11
TAG_VERSION=$(APP):v0.12

gke:
(cd ./hack/gke && terraform init && terraform apply -auto-approve)
11 changes: 11 additions & 0 deletions README.md
@@ -43,6 +43,17 @@ kubectl get secret castai-hibernate -n castai-agent -o json | jq --arg API_KEY "
Modify the `.spec.schedule` parameter for the Hibernate-pause and Hibernate-resume cronjobs according to  [this syntax](https://kubernetes.io/docs/concepts/workloads/controllers/cron-jobs/#schedule-syntax). Beginning with Kubernetes v1.25 and later versions, it is possible to define a time zone for a CronJob by assigning a valid time zone name to `.spec.timeZone`. For instance, by assigning `.spec.timeZone: "Etc/UTC"`, Kubernetes will interpret the schedule with respect to Coordinated Universal Time (UTC). To access a list of acceptable time zone options, please refer to the following link: [List of Valid Time Zones](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones).
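
For illustration, a minimal sketch of how those two fields might look in a CronJob spec (the schedule value below is only an example, not a default shipped with this project):

```
# Illustrative values only; adjust the schedule to your own pause window.
spec:
  schedule: "0 22 * * 1-5"   # e.g. pause at 22:00 on weekdays
  timeZone: "Etc/UTC"        # requires Kubernetes v1.25 or later
```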


### Set API URL

If you need to use a different API URL (for example, the Europe region), you can provide it via an environment variable:

```
API_URL = https://api.eu.cast.ai
```

The default is `https://api.cast.ai`.
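
For illustration, a minimal sketch of where the variable could be set on the hibernate CronJob container (the container name and surrounding manifest structure are assumptions for this sketch, not taken from this repository's manifests):

```
spec:
  jobTemplate:
    spec:
      template:
        spec:
          containers:
            - name: hibernate          # illustrative name; check the actual manifest
              env:
                - name: API_URL
                  value: "https://api.eu.cast.ai"
```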


## How it works

Hibernate-pause Job will
81 changes: 42 additions & 39 deletions app/cast_utils.py
@@ -10,58 +10,58 @@ class NetworkError(Exception):


@basic_retry(attempts=3, pause=5)
def get_cluster_status(clusterid, castai_apitoken):
url = "https://api.cast.ai/v1/kubernetes/external-clusters/{}".format
def get_cluster_status(cluster_id, castai_api_url, castai_api_token):
url = f"{castai_api_url}/v1/kubernetes/external-clusters/{cluster_id}"
header_dict = {"accept": "application/json",
"X-API-Key": castai_apitoken}
"X-API-Key": castai_api_token}

resp = requests.get(url=url(clusterid), headers=header_dict)
resp = requests.get(url, headers=header_dict)
if resp.status_code == 200:
return resp.json()


@basic_retry(attempts=3, pause=5)
def cluster_ready(clusterid, castai_apitoken):
cluster = get_cluster_status(clusterid, castai_apitoken)
def cluster_ready(cluster_id, castai_api_url, castai_api_token):
cluster = get_cluster_status(cluster_id, castai_api_url, castai_api_token)
logging.info(f"TEST cluster status: {cluster.get('status')}, id: {cluster.get('id')}")
if cluster.get('status') == 'ready':
return True
return False


def get_castai_policy(cluster_id, castai_api_token):
url = "https://api.cast.ai/v1/kubernetes/clusters/{}/policies".format
def get_castai_policy(cluster_id, castai_api_url, castai_api_token):
url = f"{castai_api_url}/v1/kubernetes/clusters/{cluster_id}/policies"
header_dict = {"accept": "application/json",
"X-API-Key": castai_api_token}

resp = requests.get(url=url(cluster_id), headers=header_dict)
resp = requests.get(url, headers=header_dict)
if resp.status_code == 200:
return resp.json()


def set_castai_policy(cluster_id, castai_api_token, updated_policies):
url = "https://api.cast.ai/v1/kubernetes/clusters/{}/policies".format
def set_castai_policy(cluster_id, castai_api_url, castai_api_token, updated_policies):
url = f"{castai_api_url}/v1/kubernetes/clusters/{cluster_id}/policies"
header_dict = {"accept": "application/json",
"X-API-Key": castai_api_token}

resp = requests.put(url=url(cluster_id), json=updated_policies, headers=header_dict)
resp = requests.put(url, json=updated_policies, headers=header_dict)
if resp.status_code == 200:
return resp.json()


@basic_retry(attempts=2, pause=5)
def toggle_autoscaler_top_flag(cluster_id: str, castai_api_token: str, policy_value: bool):
def toggle_autoscaler_top_flag(cluster_id: str, castai_api_url: str, castai_api_token: str, policy_value: bool):
"""" Disable CAST autoscaler to prevent adding new nodes automatically"""

current_policies = get_castai_policy(cluster_id, castai_api_token)
current_policies = get_castai_policy(cluster_id, castai_api_url, castai_api_token)

if current_policies["enabled"] != policy_value:
logging.info("Update policy. mismatch")
logging.info(f'Current: {current_policies["enabled"]} Future: {policy_value}')

current_policies["enabled"] = policy_value

validate_policies = set_castai_policy(cluster_id, castai_api_token, current_policies)
validate_policies = set_castai_policy(cluster_id, castai_api_url, castai_api_token, current_policies)
if validate_policies["enabled"] == policy_value:
logging.info("Update completed")
return True
@@ -74,9 +74,10 @@ def toggle_autoscaler_top_flag(cluster_id: str, castai_api_token: str, policy_va


@basic_retry(attempts=3, pause=5)
def create_hibernation_node(cluster_id: str, castai_api_token: str, instance_type: str, k8s_taint: str, cloud: str):
def create_hibernation_node(cluster_id: str, castai_api_url: str, castai_api_token: str, instance_type: str,
k8s_taint: str, cloud: str):
""" Create Node with Taint that will stay running during hibernation"""
url = "https://api.cast.ai/v1/kubernetes/external-clusters/{}/nodes".format
url = f"{castai_api_url}/v1/kubernetes/external-clusters/{cluster_id}/nodes"
header_dict = {"accept": "application/json",
"X-API-Key": castai_api_token}

@@ -104,7 +105,7 @@ def create_hibernation_node(cluster_id: str, castai_api_token: str, instance_typ

with Session() as session:
try:
with session.post(url=url(cluster_id), json=new_node_body, headers=header_dict) as postresp:
with session.post(url, json=new_node_body, headers=header_dict) as postresp:
postresp.raise_for_status()
add_node_result = postresp.json()
except Exception as e:
@@ -113,13 +114,13 @@ def create_hibernation_node(cluster_id: str, castai_api_token: str, instance_typ
# wait for new node to be created, listen to operation
ops_id = add_node_result["operationId"]
nodeId = add_node_result["nodeId"]
urlOperations = "https://api.cast.ai/v1/kubernetes/external-clusters/operations/{}".format
urlOperations = f"{castai_api_url}/v1/kubernetes/external-clusters/operations/{ops_id}"
done_node_creation = False

while not done_node_creation:
logging.info("checking node creation operation ID: %s", ops_id)
try:
with session.get(url=urlOperations(ops_id), headers=header_dict) as operation:
with session.get(urlOperations, headers=header_dict) as operation:
operation.raise_for_status()
ops_response = operation.json()
except Exception as e:
@@ -135,10 +136,10 @@ def create_hibernation_node(cluster_id: str, castai_api_token: str, instance_typ


@basic_retry(attempts=4, pause=15)
def delete_all_pausable_nodes(cluster_id: str, castai_api_token: str, hibernation_node_id: str,
def delete_all_pausable_nodes(cluster_id: str, castai_api_url: str, castai_api_token: str, hibernation_node_id: str,
protect_removal_disabled: str, job_node_id=None):
"""" Delete all nodes through CAST AI mothership excluding hibernation node"""
node_list_result = get_castai_nodes(cluster_id, castai_api_token)
node_list_result = get_castai_nodes(cluster_id, castai_api_url, castai_api_token)
for node in node_list_result["items"]:
# drain/delete each node
if node["id"] == hibernation_node_id or node["id"] == job_node_id:
@@ -148,12 +149,12 @@ def delete_all_pausable_nodes(cluster_id: str, castai_api_token: str, hibernatio
logging.info("Skipping node protected by removal-disabled ID: %s " % node["id"])
continue
logging.info("Deleting: %s with id: %s" % (node["name"], node["id"]))
delete_castai_node(cluster_id, castai_api_token, node["id"])
delete_castai_node(cluster_id, castai_api_url, castai_api_token, node["id"])


def get_castai_nodes_by_instance_type(cluster_id: str, castai_api_token: str, instance_type: str):
def get_castai_nodes_by_instance_type(cluster_id: str, castai_api_url: str, castai_api_token: str, instance_type: str):
"""" Get all nodes by instance type"""
node_list_result = get_castai_nodes(cluster_id, castai_api_token)
node_list_result = get_castai_nodes(cluster_id, castai_api_url, castai_api_token)
nodes = []
for node in node_list_result["items"]:
if node["instanceType"] == instance_type and node["state"]["phase"] == "ready":
@@ -162,8 +163,10 @@ def get_castai_nodes_by_instance_type(cluster_id: str, castai_api_token: str, in
return nodes


def get_suitable_hibernation_node(cluster_id: str, castai_api_token: str, instance_type: str, cloud: str):
cast_nodes = get_castai_nodes_by_instance_type(cluster_id, castai_api_token, instance_type=instance_type)
def get_suitable_hibernation_node(cluster_id: str, castai_api_url: str, castai_api_token: str, instance_type: str,
cloud: str):
cast_nodes = get_castai_nodes_by_instance_type(cluster_id, castai_api_url, castai_api_token,
instance_type=instance_type)
for node in sorted(cast_nodes, key=lambda k: k['createdAt']):
if node["labels"].get("scheduling.cast.ai/paused-cluster") == "true":
if cloud == "AKS": # Azure special case use system node
@@ -175,24 +178,24 @@ def get_suitable_hibernation_node(cluster_id: str, castai_api_token: str, instan
return node["name"]


def get_castai_nodes(cluster_id, castai_api_token):
def get_castai_nodes(cluster_id, castai_api_url, castai_api_token):
""" Get all nodes from CAST AI API inside the cluster"""
url = "https://api.cast.ai/v1/kubernetes/external-clusters/{}/nodes".format
url = f"{castai_api_url}/v1/kubernetes/external-clusters/{cluster_id}/nodes"
header_dict = {"accept": "application/json",
"X-API-Key": castai_api_token}

resp = requests.get(url=url(cluster_id), headers=header_dict)
resp = requests.get(url, headers=header_dict)
if resp.status_code == 200:
return resp.json()


def get_castai_node_name_by_id(cluster_id, castai_api_token, node_id):
def get_castai_node_name_by_id(cluster_id, castai_api_url, castai_api_token, node_id):
""" Get node by CAST AI id from CAST AI API"""
url = "https://api.cast.ai/v1/kubernetes/external-clusters/{}/nodes/{}".format
url = f"{castai_api_url}/v1/kubernetes/external-clusters/{cluster_id}/nodes/{node_id}"
header_dict = {"accept": "application/json",
"X-API-Key": castai_api_token}

resp = requests.get(url=url(cluster_id, node_id), headers=header_dict)
resp = requests.get(url, headers=header_dict)
if resp.status_code == 200:
if resp.json()['name']:
return resp.json()['name']
@@ -201,17 +204,17 @@ def get_castai_node_name_by_id(cluster_id, castai_api_token, node_id):


@basic_retry(attempts=3, pause=30)
def delete_castai_node(cluster_id, castai_api_token, node_id):
def delete_castai_node(cluster_id, castai_api_url, castai_api_token, node_id):
""" Delete single node"""
url = "https://api.cast.ai/v1/kubernetes/external-clusters/{}/nodes/{}".format
url = f"{castai_api_url}/v1/kubernetes/external-clusters/{cluster_id}/nodes/{node_id}"
header_dict = {"accept": "application/json",
"X-API-Key": castai_api_token}
paramsDelete = {
"forceDelete": True,
"drainTimeout": 60
}

resp = requests.delete(url=url(cluster_id, node_id), headers=header_dict, params=paramsDelete)
resp = requests.delete(url, headers=header_dict, params=paramsDelete)
if resp.status_code == 200:
delete_node_result = resp.json()
logging.info(delete_node_result)
@@ -220,11 +223,11 @@ def delete_castai_node(cluster_id, castai_api_token, node_id):
return False


def get_cluster_details(cluster_id, castai_api_token):
url = "https://api.cast.ai/v1/kubernetes/external-clusters/{}".format
def get_cluster_details(cluster_id, castai_api_url, castai_api_token):
url = f"{castai_api_url}/v1/kubernetes/external-clusters/{cluster_id}"
header_dict = {"accept": "application/json",
"X-API-Key": castai_api_token}

resp = requests.get(url=url(cluster_id), headers=header_dict)
resp = requests.get(url, headers=header_dict)
if resp.status_code == 200:
return resp.json()
37 changes: 22 additions & 15 deletions app/main.py
@@ -34,6 +34,7 @@ def get_logging_level():
k8s_v1 = client.CoreV1Api()
k8s_v1_apps = client.AppsV1Api()

castai_api_url = os.environ.get("API_URL", "https://api.cast.ai")
castai_api_token = os.environ["API_KEY"]
cluster_id = os.environ["CLUSTER_ID"]
hibernate_node_type_override = os.environ.get("HIBERNATE_NODE")
@@ -70,27 +71,28 @@ def get_logging_level():

def handle_resume():
logging.info("Resuming cluster, autoscaling will be enabled")
policy_changed = toggle_autoscaler_top_flag(cluster_id, castai_api_token, True)
policy_changed = toggle_autoscaler_top_flag(cluster_id, castai_api_url, castai_api_token, True)
if not policy_changed:
raise Exception("could not enable CAST AI autoscaler.")

logging.info("Resume operation completed.")


def handle_suspend(cloud):
current_policies = get_castai_policy(cluster_id, castai_api_token)
current_policies = get_castai_policy(cluster_id, castai_api_url, castai_api_token)
if current_policies["enabled"] == False:
logging.info("Cluster is already with disabled autoscaler policies, checking for dirty state.")
if last_run_dirty(client=k8s_v1, cm=configmap_name, ns=ns):
raise Exception("Cluster is already paused, but last run was dirty, clean configMap to retry or wait 12h")
else:
time.sleep(360) # avoid double running
nodes = get_castai_nodes(cluster_id=cluster_id, castai_api_token=castai_api_token)
nodes = get_castai_nodes(cluster_id=cluster_id, castai_api_url=castai_api_url,
castai_api_token=castai_api_token)
logging.info(f'Number of nodes found in the cluster: {len(nodes["items"])}')
logging.info("Cluster is already with disabled autoscaler policies, exiting.")
return 0

toggle_autoscaler_top_flag(cluster_id, castai_api_token, False)
toggle_autoscaler_top_flag(cluster_id, castai_api_url, castai_api_token, False)

my_node_name_id = ""
if my_node_name:
@@ -102,7 +104,8 @@ def handle_suspend(cloud):
else:
hibernate_node_type = instance_type[cloud]

candidate_node = get_suitable_hibernation_node(cluster_id=cluster_id, castai_api_token=castai_api_token,
candidate_node = get_suitable_hibernation_node(cluster_id=cluster_id, castai_api_url=castai_api_url,
castai_api_token=castai_api_token,
instance_type=hibernate_node_type, cloud=cloud)

hibernation_node_id = None
@@ -113,7 +116,7 @@ def handle_suspend(cloud):
node_name=candidate_node)

if my_node_name_id == hibernation_node_id:
node_list_result = get_castai_nodes(cluster_id, castai_api_token)
node_list_result = get_castai_nodes(cluster_id, castai_api_url, castai_api_token)
nodes = []
for node in node_list_result["items"]:
if node["state"]["phase"] == "ready":
@@ -125,13 +128,14 @@ def handle_suspend(cloud):

if not hibernation_node_id:
logging.info("No suitable hibernation node found, should make one")
hibernation_node_id = create_hibernation_node(cluster_id, castai_api_token, instance_type=hibernate_node_type,
hibernation_node_id = create_hibernation_node(cluster_id, castai_api_url, castai_api_token,
instance_type=hibernate_node_type,
k8s_taint=castai_pause_toleration, cloud=cloud)

if not hibernation_node_id:
raise Exception("could not create hibernation node")

node_name = get_castai_node_name_by_id(cluster_id, castai_api_token, hibernation_node_id)
node_name = get_castai_node_name_by_id(cluster_id, castai_api_url, castai_api_token, hibernation_node_id)

hibernation_node_status = check_hibernation_node_readiness(client=k8s_v1, taint=castai_pause_toleration,
node_name=node_name)
@@ -162,23 +166,26 @@ def handle_suspend(cloud):

if my_node_name_id and my_node_name_id != hibernation_node_id:
logging.info("Job pod node id and hibernation node is not the same")
delete_all_pausable_nodes(cluster_id=cluster_id, castai_api_token=castai_api_token,
delete_all_pausable_nodes(cluster_id=cluster_id, castai_api_url=castai_api_url,
castai_api_token=castai_api_token,
hibernation_node_id=hibernation_node_id,
protect_removal_disabled=protect_removal_disabled, job_node_id=my_node_name_id)
defer_job_node_deletion = True
else:
logging.info("Delete all nodes except hibernation node")
delete_all_pausable_nodes(cluster_id, castai_api_token, hibernation_node_id, protect_removal_disabled)
delete_all_pausable_nodes(cluster_id, castai_api_url, castai_api_token, hibernation_node_id,
protect_removal_disabled)

remove_node_taint(client=k8s_v1, pause_taint=castai_pause_toleration, node_id=hibernation_node_id)

if defer_job_node_deletion:
logging.info("Delete jobs node with id %s:", my_node_name_id)
delete_all_pausable_nodes(cluster_id=cluster_id, castai_api_token=castai_api_token,
delete_all_pausable_nodes(cluster_id=cluster_id, castai_api_url=castai_api_url,
castai_api_token=castai_api_token,
hibernation_node_id=hibernation_node_id,
protect_removal_disabled=protect_removal_disabled)

if cluster_ready(clusterid=cluster_id, castai_apitoken=castai_api_token):
if cluster_ready(cluster_id=cluster_id, castai_api_url=castai_api_url, castai_api_token=castai_api_token):
logging.info(f"cluster ready, updating last run status to success.")
update_last_run_status(client=k8s_v1, cm=configmap_name, ns=ns, status="success")
logging.info("Pause operation completed.")
@@ -187,10 +194,10 @@ def handle_suspend(cloud):
raise Exception("Pause finished, but cluster is not ready")


def get_cloud_provider(cluster_id: str, castai_api_token):
def get_cloud_provider(cluster_id: str, castai_api_url: str, castai_api_token: str):
'''Detect cloud provider from CAST AI, then node labels or fallback to env var CLOUD'''

cloud_var = get_cluster_details(cluster_id, castai_api_token)["providerType"].upper()
cloud_var = get_cluster_details(cluster_id, castai_api_url, castai_api_token)["providerType"].upper()
if cloud_var is not None:
logging.info(f"Cloud %s auto-detected from CAST AI cluster details API", cloud_var)
return cloud_var
@@ -201,7 +208,7 @@ def get_cloud_provider(cluster_id: str, castai_api_token):

def main():
try:
cloud = get_cloud_provider(cluster_id, castai_api_token)
cloud = get_cloud_provider(cluster_id, castai_api_url, castai_api_token)
except:
logging.error("could not detect cloud provider, check API key or network problems")
exit(1)