Skip to content

Commit

Permalink
Set the node unit into rdzv_config (#772)
Browse files Browse the repository at this point in the history
  • Loading branch information
workingloong authored Oct 23, 2023
1 parent 80ea8de commit 924c2d2
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
5 changes: 5 additions & 0 deletions dlrover/python/elastic_agent/torch/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ class ElasticLaunchConfig(LaunchConfig):
auto_tunning: bool = False
exclude_straggler: bool = False

def set_node_unit(self, node_unit):
    """Set the node unit and mirror it into the rendezvous configs.

    The value is stored on the config as ``node_unit`` and also written to
    ``rdzv_configs`` under the ``"node_unit"`` key.

    Args:
        node_unit: the unit in which nodes are counted (presumably an
            int; confirm against callers).
    """
    # Chained assignment stores left to right: attribute first, then the
    # rdzv_configs entry.
    self.node_unit = self.rdzv_configs["node_unit"] = node_unit


@dataclass
class ProcessError:
Expand Down
5 changes: 5 additions & 0 deletions dlrover/python/tests/test_elastic_training_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def setUp(self) -> None:
run_id="test",
)
self.config = ElasticLaunchConfig(**launch_config.__dict__)
self.config.set_node_unit(2)
rdzv_parameters = RendezvousParameters(
backend=self.config.rdzv_backend,
endpoint=self.config.rdzv_endpoint,
Expand Down Expand Up @@ -90,6 +91,10 @@ def setUp(self) -> None:
# Stop the master held on this test case as ``self._master``.
# NOTE(review): if the enclosing class derives from unittest.TestCase, this
# definition shadows TestCase.addCleanup(function, ...) and is not invoked
# automatically after each test; ``tearDown`` may have been intended - confirm.
def addCleanup(self):
self._master.stop()

def test_node_unit(self):
    """Check that the rendezvous handler's params carry a node unit of 2.

    The fixture calls ``set_node_unit(2)`` on the launch config in
    ``setUp``; the value is read back from the handler's ``_rdzv_params``
    (falling back to ``"1"`` when the key is absent) and compared as an int.
    """
    raw_unit = self.rdzv_handler._rdzv_params.get("node_unit", "1")
    self.assertEqual(int(raw_unit), 2)

def test_rank0_rendzevous(self):
node_id = 0
agent = ElasticTrainingAgent(
Expand Down
2 changes: 1 addition & 1 deletion dlrover/trainer/torch/elastic_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,11 @@ def _elastic_config_from_args(
config, cmd, cmd_args = config_from_args(args)
elastic_config = ElasticLaunchConfig(**config.__dict__)
elastic_config.network_check = getattr(args, "network_check", False)
elastic_config.node_unit = getattr(args, "node_unit", 1)
elastic_config.auto_tunning = getattr(args, "auto_tunning", False)
elastic_config.exclude_straggler = getattr(
args, "exclude_straggler", False
)
elastic_config.set_node_unit(getattr(args, "node_unit", 1))
return elastic_config, cmd, cmd_args


Expand Down

0 comments on commit 924c2d2

Please sign in to comment.