diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index ad85fdcc2..de146ce67 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -104,6 +104,11 @@ class ElasticLaunchConfig(LaunchConfig): auto_tunning: bool = False exclude_straggler: bool = False + def set_node_unit(self, node_unit): + """Set the unit number of nodes and mirror it into rdzv_configs.""" + self.node_unit = node_unit + self.rdzv_configs["node_unit"] = node_unit + @dataclass class ProcessError: diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py index 6f7077852..08a94fceb 100644 --- a/dlrover/python/tests/test_elastic_training_agent.py +++ b/dlrover/python/tests/test_elastic_training_agent.py @@ -52,6 +52,7 @@ def setUp(self) -> None: run_id="test", ) self.config = ElasticLaunchConfig(**launch_config.__dict__) + self.config.set_node_unit(2) rdzv_parameters = RendezvousParameters( backend=self.config.rdzv_backend, endpoint=self.config.rdzv_endpoint, @@ -90,6 +91,10 @@ def setUp(self) -> None: def addCleanup(self): self._master.stop() + def test_node_unit(self): + node_unit = int(self.rdzv_handler._rdzv_params.get("node_unit", "1")) + self.assertEqual(node_unit, 2) + def test_rank0_rendzevous(self): node_id = 0 agent = ElasticTrainingAgent( diff --git a/dlrover/trainer/torch/elastic_run.py b/dlrover/trainer/torch/elastic_run.py index 485c8918b..a9f2da75c 100644 --- a/dlrover/trainer/torch/elastic_run.py +++ b/dlrover/trainer/torch/elastic_run.py @@ -231,11 +231,11 @@ def _elastic_config_from_args( config, cmd, cmd_args = config_from_args(args) elastic_config = ElasticLaunchConfig(**config.__dict__) elastic_config.network_check = getattr(args, "network_check", False) - elastic_config.node_unit = getattr(args, "node_unit", 1) elastic_config.auto_tunning = getattr(args, "auto_tunning", False) elastic_config.exclude_straggler = getattr( args, 
"exclude_straggler", False ) + elastic_config.set_node_unit(getattr(args, "node_unit", 1)) return elastic_config, cmd, cmd_args