Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Partition into dedicated dataclass #3563

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2024 "Google LLC"
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Dict

from dataclasses import dataclass, field

@dataclass(frozen=True)
class Partition:
    """Immutable description of a single Slurm partition from the cluster config.

    Instances are produced from the "partitions" section of the config via
    `from_json` and consumed when rendering partition lines for slurm.conf.
    """

    # Slurm PartitionName.
    name: str
    # When True the partition is rendered with Oversubscribe=Exclusive.
    enable_job_exclusive: bool = False
    # Extra key/value pairs merged verbatim into the partition's slurm.conf line.
    conf: Dict[str, Any] = field(default_factory=dict)

    # Names of the nodesets attached to this partition, split by flavor.
    nodesets: List[str] = field(default_factory=list)
    nodesets_dyn: List[str] = field(default_factory=list)
    nodesets_tpu: List[str] = field(default_factory=list)

    @property
    def is_tpu(self) -> bool:
        """True if the partition contains at least one TPU nodeset."""
        return bool(self.nodesets_tpu)

    @property
    def any_dynamic(self) -> bool:
        """True if the partition contains at least one dynamic nodeset."""
        return bool(self.nodesets_dyn)

    @classmethod
    def from_json(cls, jo: dict) -> "Partition":
        """Deserialize a Partition from its config-dict representation.

        Only "partition_name" is required; every other key falls back to the
        dataclass default, matching the defaults declared on the fields above.

        Raises:
            KeyError: if the required "partition_name" key is missing.
        """
        return cls(
            name=jo["partition_name"],
            # Optional: previously a missing key raised KeyError even though
            # the field declares a default of False.
            enable_job_exclusive=jo.get("enable_job_exclusive", False),
            conf=jo.get("partition_conf", {}),
            nodesets=jo.get("nodesets", []),
            nodesets_dyn=jo.get("nodesets_dyn", []),
            nodesets_tpu=jo.get("nodesets_tpu", []),
        )
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import util
from util import dirs, slurmdirs
import tpu
from base import Partition

FILE_PREAMBLE = """
# Warning:
Expand Down Expand Up @@ -76,13 +77,9 @@ def get(key, default):
for nodeset in lkp.cfg.nodeset.values()
)

any_tpu = any(
tpu_nodeset is not None
for part in lkp.cfg.partitions.values()
for tpu_nodeset in part.partition_nodeset_tpu
)
any_tpu = any(p.is_tpu for p in lkp.partitions)
any_dynamic = any(p.any_dynamic for p in lkp.partitions)

any_dynamic = any(bool(p.partition_feature) for p in lkp.cfg.partitions.values())
comma_params = {
"LaunchParameters": [
"enable_nss_slurm",
Expand Down Expand Up @@ -180,7 +177,7 @@ def nodeset_dyn_lines(nodeset):
)


def partitionlines(partition, lkp: util.Lookup) -> str:
def partitionlines(partition: Partition, lkp: util.Lookup) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Could we change this to partition_lines?

"""Make a partition line for the slurm.conf"""
MIN_MEM_PER_CPU = 100

Expand All @@ -192,32 +189,23 @@ def defmempercpu(nodeset_name: str) -> int:
return max(MIN_MEM_PER_CPU, (machine.memory - mem_spec_limit) // machine.cpus)

defmem = min(
map(defmempercpu, partition.partition_nodeset), default=MIN_MEM_PER_CPU
)

nodesets = list(
chain(
partition.partition_nodeset,
partition.partition_nodeset_dyn,
partition.partition_nodeset_tpu,
)
map(defmempercpu, partition.nodesets), default=MIN_MEM_PER_CPU
)

is_tpu = len(partition.partition_nodeset_tpu) > 0
is_dyn = len(partition.partition_nodeset_dyn) > 0
nodesets = list(chain(partition.nodesets, partition.nodesets_dyn, partition.nodesets_tpu))

oversub_exlusive = partition.enable_job_exclusive or is_tpu
power_down_on_idle = partition.enable_job_exclusive and not is_dyn
oversub_exlusive = partition.enable_job_exclusive or partition.is_tpu
power_down_on_idle = partition.enable_job_exclusive and not partition.any_dynamic

line_elements = {
"PartitionName": partition.partition_name,
"PartitionName": partition.name,
"Nodes": ",".join(nodesets),
"State": "UP",
"DefMemPerCPU": defmem,
"SuspendTime": 300,
"Oversubscribe": "Exclusive" if oversub_exlusive else None,
"PowerDownOnIdle": "YES" if power_down_on_idle else None,
**partition.partition_conf,
**partition.conf,
}

return dict_to_conf(line_elements)
Expand All @@ -231,12 +219,8 @@ def suspend_exc_lines(lkp: util.Lookup) -> Iterable[str]:
static_nodelists.append(nodelist)
suspend_exc_nodes = {"SuspendExcNodes": static_nodelists}

dyn_parts = [
p.partition_name
for p in lkp.cfg.partitions.values()
if len(p.partition_nodeset_dyn) > 0
]
suspend_exc_parts = {"SuspendExcParts": [*dyn_parts]}
dyn_parts = [p.name for p in lkp.partitions if p.any_dynamic]
suspend_exc_parts = {"SuspendExcParts": dyn_parts}

return filter(
None,
Expand All @@ -255,7 +239,7 @@ def make_cloud_conf(lkp: util.Lookup) -> str:
*(nodeset_lines(n, lkp) for n in lkp.cfg.nodeset.values()),
*(nodeset_dyn_lines(n) for n in lkp.cfg.nodeset_dyn.values()),
*(nodeset_tpu_lines(n, lkp) for n in lkp.cfg.nodeset_tpu.values()),
*(partitionlines(p, lkp) for p in lkp.cfg.partitions.values()),
*(partitionlines(p, lkp) for p in lkp.partitions),
*(suspend_exc_lines(lkp)),
]
return "\n\n".join(filter(None, lines))
Expand Down Expand Up @@ -341,11 +325,7 @@ def install_cgroup_conf(lkp: util.Lookup) -> None:

def install_jobsubmit_lua(lkp: util.Lookup) -> None:
"""install job_submit.lua if there are tpu nodes in the cluster"""
if not any(
tpu_nodeset is not None
for part in lkp.cfg.partitions.values()
for tpu_nodeset in part.partition_nodeset_tpu
):
if not any(p.is_tpu for p in lkp.partitions):
return # No TPU partitions, no need for job_submit.lua

scripts_dir = lkp.cfg.slurm_scripts_dir or dirs.scripts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
import argparse
import util
import tpu
from base import Partition


def get_vmcount_of_tpu_part(part):
def get_vmcount_of_part(part: Partition):
res = 0
lkp = util.lookup()
for ns in lkp.cfg.partitions[part].partition_nodeset_tpu:
for ns in part.nodesets_tpu:
tpu_obj = tpu.TPU.make(ns, lkp)
if res == 0:
res = tpu_obj.vmcount
Expand Down Expand Up @@ -54,19 +54,19 @@ def get_vmcount_of_tpu_part(part):
vmcounts = []
# valid equals to 0 means that we are ok, otherwise it will be set to one of the previously defined exit codes
valid = 0
for part in args.partitions.split(","):
if part not in util.lookup().cfg.partitions:
for part_name in args.partitions.split(","):
try:
part = util.lookup().partition(part_name)
except:
valid = PART_INVALID
break
else:
if util.lookup().partition_is_tpu(part):
vmcount = get_vmcount_of_tpu_part(part)
if vmcount == -1:
valid = DIFF_VMCOUNTS_SAME_PART
break
vmcounts.append(vmcount)
else:
vmcounts.append(0)
vmcount = get_vmcount_of_part(part)
if vmcount == -1:
valid = DIFF_VMCOUNTS_SAME_PART
break
vmcounts.append(vmcount)

# this means that there are different vmcounts for these partitions
if valid == 0 and len(set(vmcounts)) != 1:
valid = DIFF_PART_DIFFERENT_VMCOUNTS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def group_nodes_bulk(nodes: List[str], resume_data: Optional[ResumeData], lkp: u

# expand all exclusive job nodelists
for job in resume_data.jobs:
if not lkp.cfg.partitions[job.partition].enable_job_exclusive:
if not lkp.partition(job.partition).enable_job_exclusive:
continue

groups[job.job_id] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,6 @@ class TstNodeset:
enable_placement: bool = True
placement_max_distance: Optional[int] = None

@dataclass
class TstPartition:
    """Test double for a partition config entry (legacy attribute-per-field shape)."""
    # Partition name; defaults let tests construct an instance with no args.
    partition_name: str = "euler"
    # Names of ordinary nodesets attached to the partition.
    partition_nodeset: list[str] = field(default_factory=list)
    # Names of TPU nodesets attached to the partition.
    partition_nodeset_tpu: list[str] = field(default_factory=list)
    # Mirrors the real config flag controlling exclusive-job scheduling.
    enable_job_exclusive: bool = False

@dataclass
class TstCfg:
slurm_cluster_name: str = "m22"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import unittest
import tempfile

from common import TstCfg, TstNodeset, TstPartition, TstTPU # needed to import util
from common import TstCfg, TstNodeset, TstTPU # needed to import util
import util
import resume
from resume import ResumeData, ResumeJobData, BulkChunk, PlacementAndNodes
Expand Down Expand Up @@ -74,11 +74,11 @@ def test_group_nodes_bulk(mock_create_placements, mock_tpu):
"t": TstNodeset(nodeset_name="t"),
},
partitions={
"p1": TstPartition(
"p1": dict(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use the Partition class?

partition_name="p1",
enable_job_exclusive=True,
),
"p2": TstPartition(
"p2": dict(
partition_name="p2",
partition_nodeset_tpu=["t"],
enable_job_exclusive=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@

import yaml # noqa: E402
from addict import Dict as NSDict # noqa: E402

from base import Partition

USER_AGENT = "Slurm_GCP_Scripts/1.5 (GPN:SchedMD)"
ENV_CONFIG_YAML = os.getenv("SLURM_CONFIG_YAML")
Expand Down Expand Up @@ -536,8 +536,7 @@ def _assemble_config(
# add partition configs
for p_yaml in partitions:
p_cfg = NSDict(p_yaml)
assert p_cfg.get("partition_name"), "partition_name is required"
p_name = p_cfg.partition_name
p_name = Partition.from_json(p_cfg).name # + de-serialization check
assert p_name not in cfg.partitions, f"partition {p_name} already defined"
cfg.partitions[p_name] = p_cfg

Expand Down Expand Up @@ -1314,6 +1313,15 @@ def hostname_fqdn(self):
def zone(self):
return instance_metadata("zone")


@property
def partitions(self) -> List[Partition]:
    """All configured partitions, deserialized into Partition objects.

    ``self.cfg.partitions`` maps partition name -> config dict, so we must
    iterate ``.values()``; iterating the mapping itself would yield the name
    strings and make ``Partition.from_json`` fail.
    """
    return [Partition.from_json(jo) for jo in self.cfg.partitions.values()]

def partition(self, name: str) -> Partition:
    """Return the Partition deserialized from the config entry for *name*."""
    raw = self.cfg.partitions[name]
    return Partition.from_json(raw)


node_desc_regex = re.compile(
r"^(?P<prefix>(?P<cluster>[^\s\-]+)-(?P<nodeset>\S+))-(?P<node>(?P<suffix>\w+)|(?P<range>\[[\d,-]+\]))$"
)
Expand Down Expand Up @@ -1351,11 +1359,6 @@ def node_nodeset(self, node_name=None):

return self.cfg.nodeset[nodeset_name]

def partition_is_tpu(self, part: str) -> bool:
"""check if partition with name part contains a nodeset of type tpu"""
return len(self.cfg.partitions[part].partition_nodeset_tpu) > 0


def node_is_tpu(self, node_name=None):
nodeset_name = self.node_nodeset_name(node_name)
return self.cfg.nodeset_tpu.get(nodeset_name) is not None
Expand Down
Loading