Use dataclass for GCP instance
mr0re1 committed Jan 23, 2025 · 1 parent b98b7d1 · commit 75c2db8
Showing 5 changed files with 61 additions and 81 deletions.
@@ -563,7 +563,7 @@ def add_nodeset_topology(
except Exception:
continue

phys_host = inst.resourceStatus.get("physicalHost", "")
phys_host = inst.resource_status.get("physicalHost", "")
bldr.summary.physical_host[inst.name] = phys_host
up_nodes.add(inst.name)

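In the hunk above, the camelCase resourceStatus attribute of the old NSDict-based instance gives way to the resource_status field of the new util.Instance dataclass (introduced in the last file of this diff). The field still wraps the raw resourceStatus JSON object in an NSDict, so dict-style .get() access is unchanged. A minimal sketch, with fabricated field values and assuming util is importable as in the controller scripts:

from datetime import datetime, timezone

import util

# Hypothetical instance; all values below are made up for illustration.
inst = util.Instance(
    name="m22-blue-0",
    status="RUNNING",
    creation_timestamp=datetime.now(timezone.utc),
    resource_status=util.NSDict(physicalHost="/a/a/a"),
    scheduling=util.NSDict(),
)
# NSDict keeps dict-style access, so the lookup reads the same as before:
assert inst.resource_status.get("physicalHost", "") == "/a/a/a"
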
@@ -280,7 +280,7 @@ def get_node_action(nodename: str) -> NodeAction:
elif (state is None or "POWERED_DOWN" in state.flags) and inst.status == "RUNNING":
log.info("%s is potential orphan node", nodename)
threshold = timedelta(seconds=90)
age = datetime.now() - parse_gcp_timestamp(inst.creationTimestamp)
age = datetime.now() - inst.creation_timestamp
log.info(f"{nodename} state: {state}, age: {age}")
if age < threshold:
log.info(f"{nodename} not marked as orphan, it started less than {threshold.seconds}s ago ({age.seconds}s)")
@@ -464,9 +464,9 @@ def get_slurm_reservation_maintenance(lkp: util.Lookup) -> Dict[str, datetime]:
def get_upcoming_maintenance(lkp: util.Lookup) -> Dict[str, Tuple[str, datetime]]:
upc_maint_map = {}

for node, properties in lkp.instances().items():
if 'upcomingMaintenance' in properties:
start_time = parse_gcp_timestamp(properties['upcomingMaintenance']['startTimeWindow']['earliest'])
for node, inst in lkp.instances().items():
if inst.upcoming_maintenance:
start_time = parse_gcp_timestamp(inst.upcoming_maintenance['startTimeWindow']['earliest'])
upc_maint_map[node + "_maintenance"] = (node, start_time)

return upc_maint_map
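For reference, the new upcoming_maintenance field simply stores the raw upcomingMaintenance object when the API returns one (and None otherwise), so the nested-key access above mirrors the JSON shape. A hedged sketch of that lookup, using a fabricated maintenance window and the parse_gcp_timestamp helper defined in util.py:

import util

# Fabricated payload; the key names mirror those used in the hunk above.
maint = util.NSDict({
    "startTimeWindow": {"earliest": "2018-09-03T20:56:35.450686+00:00"},
})
start_time = util.parse_gcp_timestamp(maint["startTimeWindow"]["earliest"])
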
@@ -15,13 +15,16 @@
from typing import Optional, Any
import sys
from dataclasses import dataclass, field
from datetime import datetime

SCRIPTS_DIR = "community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts"
if SCRIPTS_DIR not in sys.path:
sys.path.append(SCRIPTS_DIR) # TODO: make this more robust

import util


SOME_TS = datetime.fromisoformat("2018-09-03T20:56:35.450686+00:00")
# TODO: use "real" classes once they are defined (instead of NSDict)

@dataclass
@@ -83,17 +86,17 @@ class TstMachineConf:
class TstTemplateInfo:
gpu: Optional[util.AcceleratorInfo]

@dataclass
class TstInstance:
name: str
region: str = "gondor"
zone: str = "anorien"
placementPolicyId: Optional[str] = None
physicalHost: Optional[str] = None

@property
def resourceStatus(self):
return {"physicalHost": self.physicalHost}
def tstInstance(name: str, physical_host: Optional[str] = None):
return util.Instance(
name=name,
status="RUNNING",
creation_timestamp=SOME_TS,
resource_status=util.NSDict(
physicalHost = physical_host
),
scheduling=util.NSDict(),
upcoming_maintenance=None,
)

def make_to_hostnames_mock(tbl: Optional[dict[str, list[str]]]):
tbl = tbl or {}
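Because util.Instance is declared @dataclass(frozen=True) (see the last file in this diff), test fixtures can no longer be tweaked by assigning attributes, which is presumably why the mutable TstInstance test class above gives way to the small tstInstance factory. A hypothetical illustration of what the frozen contract means for test code:

import dataclasses

inst = tstInstance("m22-blue-0", physical_host="/a/a/a")
try:
    inst.name = "renamed"  # assignment on a frozen dataclass raises
except dataclasses.FrozenInstanceError:
    pass
# To vary a field, build a modified copy instead of mutating in place:
other = dataclasses.replace(inst, name="m22-blue-1")
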
@@ -16,7 +16,7 @@
import json
import mock
from pytest_unordered import unordered
from common import TstCfg, TstNodeset, TstTPU, TstInstance
from common import TstCfg, TstNodeset, TstTPU, tstInstance
import sort_nodes

import util
@@ -62,13 +62,13 @@ def tpu_se(ns: str, lkp) -> TstTPU:
lkp = util.Lookup(cfg)
lkp.instances = lambda: { n.name: n for n in [
# nodeset blue
TstInstance("m22-blue-0"), # no physicalHost
TstInstance("m22-blue-0", physicalHost="/a/a/a"),
TstInstance("m22-blue-1", physicalHost="/a/a/b"),
TstInstance("m22-blue-2", physicalHost="/a/b/a"),
TstInstance("m22-blue-3", physicalHost="/b/a/a"),
tstInstance("m22-blue-0"), # no physicalHost
tstInstance("m22-blue-0", physical_host="/a/a/a"),
tstInstance("m22-blue-1", physical_host="/a/a/b"),
tstInstance("m22-blue-2", physical_host="/a/b/a"),
tstInstance("m22-blue-3", physical_host="/b/a/a"),
# nodeset green
TstInstance("m22-green-3", physicalHost="/a/a/c"),
tstInstance("m22-green-3", physical_host="/a/a/c"),
]}

uncompressed = conf.gen_topology(lkp)
@@ -173,19 +173,19 @@ def test_gen_topology_conf_update():
# don't dump

# set empty physicalHost - no reconfigure
lkp.instances = lambda: { n.name: n for n in [TstInstance("m22-green-0", physicalHost="")]}
lkp.instances = lambda: { n.name: n for n in [tstInstance("m22-green-0", physical_host="")]}
upd, sum = conf.gen_topology_conf(lkp)
assert upd == False
# don't dump

# set physicalHost - reconfigure
lkp.instances = lambda: { n.name: n for n in [TstInstance("m22-green-0", physicalHost="/a/b/c")]}
lkp.instances = lambda: { n.name: n for n in [tstInstance("m22-green-0", physical_host="/a/b/c")]}
upd, sum = conf.gen_topology_conf(lkp)
assert upd == True
sum.dump(lkp)

# change physicalHost - reconfigure
lkp.instances = lambda: { n.name: n for n in [TstInstance("m22-green-0", physicalHost="/a/b/z")]}
lkp.instances = lambda: { n.name: n for n in [tstInstance("m22-green-0", physical_host="/a/b/z")]}
upd, sum = conf.gen_topology_conf(lkp)
assert upd == True
sum.dump(lkp)
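A side note on the test pattern above: Lookup.instances is an lru_cache-decorated method (see the last file in this diff), and the tests bypass both the cache and the Compute API by assigning a zero-argument lambda directly on the lkp object. A minimal sketch of the same idea with a hypothetical one-node inventory (cfg as built earlier in these tests):

lkp = util.Lookup(cfg)
lkp.instances = lambda: {"m22-green-0": tstInstance("m22-green-0")}
# Lookup.instance() goes through instances(), so it sees the stub too:
assert lkp.instance("m22-green-0").status == "RUNNING"
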
@@ -184,7 +184,29 @@ def sockets(self) -> int:
self.family, 1, # assume 1 socket for all other families
)


@dataclass(frozen=True)
class Instance:
name: str
status: str
creation_timestamp: datetime

# TODO: use proper InstanceResourceStatus class
resource_status: NSDict
# TODO: use proper InstanceScheduling class
scheduling: NSDict
# TODO: use proper UpcomingMaintenance class
upcoming_maintenance: Optional[NSDict] = None

@classmethod
def from_json(cls, jo: dict) -> "Instance":
return cls(
name=jo["name"],
status=jo["status"],
creation_timestamp=parse_gcp_timestamp(jo["creationTimestamp"]),
resource_status=NSDict(jo["resourceStatus"]),
scheduling=NSDict(jo["scheduling"]),
upcoming_maintenance=NSDict(jo["upcomingMaintenance"]) if "upcomingMaintenance" in jo else None
)

@lru_cache(maxsize=1)
def default_credentials():
@@ -1500,84 +1522,39 @@ def node_state(self, nodename: str) -> Optional[NodeState]:


@lru_cache(maxsize=1)
def instances(self) -> Dict[str, object]:
def instances(self) -> Dict[str, Instance]:
instance_information_fields = [
"advancedMachineFeatures",
"cpuPlatform",
"creationTimestamp",
"disks",
"disks",
"fingerprint",
"guestAccelerators",
"hostname",
"id",
"kind",
"labelFingerprint",
"labels",
"lastStartTimestamp",
"lastStopTimestamp",
"lastSuspendedTimestamp",
"machineType",
"metadata",
"name",
"networkInterfaces",
"resourceStatus",
"scheduling",
"selfLink",
"serviceAccounts",
"shieldedInstanceConfig",
"shieldedInstanceIntegrityPolicy",
"sourceMachineImage",
"status",
"statusMessage",
"tags",
"zone",
# "deletionProtection",
# "startRestricted",
]

# TODO: Merge this with all fields when upcoming maintenance is
# supported in beta.
if endpoint_version(ApiEndpoint.COMPUTE) == 'alpha':
instance_information_fields.append("upcomingMaintenance")

instance_information_fields = sorted(set(instance_information_fields))
instance_fields = ",".join(instance_information_fields)
instance_fields = ",".join(sorted(instance_information_fields))
fields = f"items.zones.instances({instance_fields}),nextPageToken"
flt = f"labels.slurm_cluster_name={self.cfg.slurm_cluster_name} AND name:{self.cfg.slurm_cluster_name}-*"
act = self.compute.instances()
op = act.aggregatedList(project=self.project, fields=fields, filter=flt)

def properties(inst):
"""change instance properties to a preferred format"""
inst["zone"] = trim_self_link(inst["zone"])
inst["machineType"] = trim_self_link(inst["machineType"])
# metadata is fetched as a dict of dicts like:
# {'key': key, 'value': value}, kinda silly
metadata = {i["key"]: i["value"] for i in inst["metadata"].get("items", [])}
if "slurm_instance_role" not in metadata:
return None
inst["role"] = metadata["slurm_instance_role"]
inst["metadata"] = metadata
# del inst["metadata"] # no need to store all the metadata
return NSDict(inst)

instances = {}
while op is not None:
result = ensure_execute(op)
instance_iter = (
(inst["name"], properties(inst))
for inst in chain.from_iterable(
zone.get("instances", []) for zone in result.get("items", {}).values()
)
)
instances.update(
{name: props for name, props in instance_iter if props is not None}
)
for zone in result.get("items", {}).values():
for jo in zone.get("instances", []):
inst = Instance.from_json(jo)
if inst.name in instances:
log.error(f"Duplicate VM name {inst.name} across multiple zones")
instances[inst.name] = inst
op = act.aggregatedList_next(op, result)
return instances

def instance(self, instance_name: str) -> Optional[object]:
def instance(self, instance_name: str) -> Optional[Instance]:
return self.instances().get(instance_name)

@lru_cache()
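To make the new parsing path concrete, here is a hedged example of feeding Instance.from_json one item from an aggregatedList page; the payload is fabricated and carries only the fields from_json actually reads (upcomingMaintenance is omitted, so the field defaults to None):

import util

jo = {
    "name": "mycluster-blue-0",
    "status": "RUNNING",
    "creationTimestamp": "2018-09-03T20:56:35.450686+00:00",
    "resourceStatus": {"physicalHost": "/a/a/a"},
    "scheduling": {"onHostMaintenance": "MIGRATE"},
}
inst = util.Instance.from_json(jo)
assert inst.name == "mycluster-blue-0"
assert inst.upcoming_maintenance is None
assert inst.resource_status.get("physicalHost") == "/a/a/a"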