The master notifies the worker agent to restart the training process. (#775)

* Add test cases.

* Add test cases.

* Fix bugs to update reset_hardware

* Restart the training process on worker nodes.
workingloong authored Nov 23, 2023
1 parent 25a06ed commit 9bb113d
Showing 17 changed files with 167 additions and 3 deletions.
6 changes: 6 additions & 0 deletions dlrover/python/common/grpc.py
@@ -429,7 +429,13 @@ class ParallelConfigRequest(Message):
pass


@dataclass
class CheckHardwareResetRequest(Message):
pass


@dataclass
class ParallelConfig(Message):
dataloader: DataLoaderConfig = DataLoaderConfig()
optimizer: OptimizerConfig = OptimizerConfig()
restart: bool = False
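
For orientation, the sketch below mirrors the two message shapes involved in the new round trip: an empty CheckHardwareResetRequest from the agent and a ParallelConfig reply whose restart field carries the master's decision. The pickle-based serialize/deserialize_message helpers are only stand-ins for whatever the real Message base class provides, and the dataloader/optimizer fields of ParallelConfig are omitted.

import pickle
from dataclasses import dataclass


@dataclass
class Message:
    """Stand-in for dlrover's gRPC message base class."""

    def serialize(self) -> bytes:
        return pickle.dumps(self)


def deserialize_message(data: bytes):
    return pickle.loads(data) if data else None


@dataclass
class CheckHardwareResetRequest(Message):
    """Empty request: the agent only asks whether it should restart training."""


@dataclass
class ParallelConfig(Message):
    restart: bool = False  # the master sets True when workers must restart


# Round trip: the agent sends the empty request; the master answers with a
# ParallelConfig whose `restart` field drives the agent's decision.
request = CheckHardwareResetRequest()
assert isinstance(request.serialize(), bytes)
reply = deserialize_message(ParallelConfig(restart=True).serialize())
assert isinstance(reply, ParallelConfig) and reply.restart
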
8 changes: 6 additions & 2 deletions dlrover/python/common/node.py
@@ -183,7 +183,8 @@ def __init__(
service_addr=None,
host_name=None,
host_ip=None,
paral_config=None,
paral_config=ParallelConfig(),
restart_training=False,
):
self.type = node_type
self.id = node_id
@@ -209,7 +210,8 @@ def __init__(
self.host_name = host_name
self.host_ip = host_ip
self.hang = False
self.paral_config = ParallelConfig()
self.paral_config = paral_config
self.restart_training = restart_training

def inc_relaunch_count(self):
self.relaunch_count += 1
@@ -221,6 +223,7 @@ def update_info(
create_time=None,
host_name=None,
host_ip=None,
restart_training=False,
):
if name is not None:
self.name = name
@@ -232,6 +235,7 @@ def update_info(
self.host_name = host_name
if host_ip:
self.host_ip = host_ip
self.restart_training = restart_training

def update_status(self, status=None):
if status is not None:
9 changes: 9 additions & 0 deletions dlrover/python/elastic_agent/master_client.py
@@ -377,6 +377,15 @@ def get_paral_config(self) -> grpc.ParallelConfig:
result = self._get(request)
return result

def need_to_restart_training(self):
request = grpc.CheckHardwareResetRequest()
try:
result: grpc.ParallelConfig = self._get(request)
return result.restart
except Exception:
logger.warning("Failed to check whether to restart the training processes.")
return False

@classmethod
def singleton_instance(cls, *args, **kwargs):
if not MasterClient._instance:
11 changes: 11 additions & 0 deletions dlrover/python/elastic_agent/torch/training.py
@@ -514,6 +514,7 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
while True:
assert self._worker_group.state != WorkerState.INIT
time.sleep(monitor_interval)
self._stop_workers_to_restart()
try:
run_result: RunResult = self._monitor_workers(
self._worker_group
@@ -557,6 +558,16 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
else:
raise Exception(f"[{role}] Worker group in {state.name} state")

def _stop_workers_to_restart(self):
"""
The agent queries the DLRover job master to check whether to restart
the workers. If so, the agent first stops all workers.
"""
restart = self._client.need_to_restart_training()
if not restart:
return
self._stop_workers(self._worker_group)

def _report_failure_to_master(self, failures: Dict[int, ProcessFailure]):
errors = {}
if len(failures) == 0:
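Putting the client and agent changes together, the worker side reduces to a poll-and-stop step at the top of the monitoring loop: ask the master, and if it says so, stop the workers so the agent's existing monitor/restart handling can relaunch them. Note that need_to_restart_training deliberately swallows errors and returns False, so a flaky master connection never stops workers by accident. The sketch below is a simplified model of that flow, not the real ElasticTrainingAgent; MasterStub, stop_workers, and the interval are illustrative stand-ins.

import time


class MasterStub:
    """Illustrative stand-in for MasterClient.need_to_restart_training()."""

    def __init__(self, answers):
        self._answers = iter(answers)

    def need_to_restart_training(self) -> bool:
        try:
            return next(self._answers)
        except StopIteration:
            return False


def monitor_loop(client, stop_workers, monitor_interval=0.01, max_ticks=3):
    """Each tick asks the master first; on True, the workers are stopped.

    In the real agent, _monitor_workers() then observes the stopped
    workers and the usual restart logic is expected to relaunch them.
    """
    for _ in range(max_ticks):
        time.sleep(monitor_interval)
        if client.need_to_restart_training():
            stop_workers()
        # ... _monitor_workers(...) and state handling follow here ...


monitor_loop(MasterStub([False, True, False]),
             stop_workers=lambda: print("stopping workers"))
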
2 changes: 1 addition & 1 deletion dlrover/python/master/local_master.py
@@ -108,7 +108,7 @@ def stop(self):
"""
logger.info("Stopping master!")
logger.info("Stopping RPC server!")
self._master_server.stop(grace=None)
self._master_server.stop(grace=0.1)
logger.info("RPC server stopped!")
logger.info("Master stopped!")

4 changes: 4 additions & 0 deletions dlrover/python/master/node/dist_job_manager.py
@@ -411,6 +411,7 @@ def _process_event(self, event: NodeEvent):
create_time=event.node.create_time,
host_name=event.node.host_name,
host_ip=event.node.host_ip,
restart_training=event.node.restart_training,
)

# For the given node id, check whether it meets
@@ -734,6 +735,9 @@ def update_node_paral_config(self, node_type, node_id, paral_config):
node = self._job_nodes[node_type][node_id]
node.update_paral_config(paral_config)

def verify_restarting_worker_training(self):
return self._worker_manager.verify_restarting_training()


def create_job_manager(args: JobArgs, speed_monitor) -> DistributedJobManager:
critical_worker_index = get_critical_worker_index(args)
11 changes: 11 additions & 0 deletions dlrover/python/master/node/job_manager.py
@@ -128,6 +128,17 @@ def get_opt_strategy(self):
def update_node_paral_config(self, node_type, node_id, paral_config):
pass

@abstractclassmethod
def verify_restarting_worker_training(self):
"""
Check whether the training process on the worker nodes
needs to be restarted.
Returns:
bool
"""
pass

def handle_training_failure(
self, node_type, node_id, restart_count=-1, error_data="", level=""
):
3 changes: 3 additions & 0 deletions dlrover/python/master/node/local_job_manager.py
@@ -148,6 +148,9 @@ def pend_without_workers(self):
def update_allreduce_node_unit(self, node_unit):
pass

def verify_restarting_worker_training(self):
return False

def get_opt_strategy(self) -> ParallelConfig:
strategy = self._job_strategy_generator.generate_opt_strategy()
return strategy
19 changes: 19 additions & 0 deletions dlrover/python/master/node/worker.py
@@ -275,3 +275,22 @@ def wait_worker_restart(self):
):
return True
return False

def verify_restarting_training(self):
"""
Check whether any worker requires restarting its training process.
The worker restarts the training processes if any of the
following conditions is met:
1. A RestartTrain action is present in the Pod annotations.
2. A training process crashes on the worker.
Returns:
bool
"""
restart = False
for worker in self._nodes.values():
if not worker.is_released and worker.restart_training:
restart = True
# Set to False to avoid repeated restarts.
worker.restart_training = False
return restart
10 changes: 10 additions & 0 deletions dlrover/python/master/servicer.py
@@ -122,6 +122,8 @@ def get(self, request, _):
message = self._get_training_status()
elif isinstance(req_message, grpc.ParallelConfigRequest):
message = self._get_paral_config()
elif isinstance(req_message, grpc.CheckHardwareResetRequest):
message = self._need_to_restart_training()

if message:
response.data = message.serialize()
@@ -268,6 +270,14 @@ def _kv_store_get(self, request: grpc.KeyValuePair):

def _get_paral_config(self):
res = self._job_manager.get_opt_strategy()
if not res:
res = grpc.ParallelConfig()
return res

def _need_to_restart_training(self):
restart = self._job_manager.verify_restarting_worker_training()
res = grpc.ParallelConfig()
res.restart = restart
return res

def report(self, request, _):
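On the master side, the new request type threads through three layers: the servicer wraps the answer into a ParallelConfig, the job manager delegates to the worker manager, and the worker manager scans its nodes and clears each worker's flag so a restart is signalled exactly once. A condensed, self-contained model of that chain follows; the Fake* classes are stand-ins for the real servicer and managers.

from dataclasses import dataclass


@dataclass
class ParallelConfig:
    restart: bool = False


@dataclass
class FakeWorker:
    restart_training: bool = False
    is_released: bool = False


class FakeWorkerManager:
    def __init__(self, nodes):
        self._nodes = nodes

    def verify_restarting_training(self) -> bool:
        restart = False
        for worker in self._nodes.values():
            if not worker.is_released and worker.restart_training:
                restart = True
                worker.restart_training = False  # one-shot: clear the flag
        return restart


class FakeJobManager:
    def __init__(self, worker_manager):
        self._worker_manager = worker_manager

    def verify_restarting_worker_training(self) -> bool:
        return self._worker_manager.verify_restarting_training()


def need_to_restart_training(job_manager) -> ParallelConfig:
    # Mirrors MasterServicer._need_to_restart_training(): wrap the verdict
    # into the same ParallelConfig message the agent already parses.
    return ParallelConfig(restart=job_manager.verify_restarting_worker_training())


nodes = {0: FakeWorker(restart_training=True)}
manager = FakeJobManager(FakeWorkerManager(nodes))
assert need_to_restart_training(manager).restart is True   # first poll sees the flag
assert need_to_restart_training(manager).restart is False  # flag was cleared
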
22 changes: 22 additions & 0 deletions dlrover/python/master/watcher/k8s_watcher.py
@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import List

from kubernetes import watch
@@ -99,6 +100,9 @@ def _convert_pod_event_to_node_event(event):
host_name = evt_obj.spec.node_name
host_ip = evt_obj.status.host_ip

restart = _verify_restarting_training(evt_obj)
logger.info(f"{evt_obj.metadata.name} hardware reset: {restart}")

resource = _parse_container_resource(evt_obj.spec.containers[0])
status = evt_obj.status.phase
if evt_obj.metadata.deletion_timestamp:
@@ -114,6 +118,7 @@ def _convert_pod_event_to_node_event(event):
config_resource=resource,
host_name=host_name,
host_ip=host_ip,
restart_training=restart,
)
node.create_time = evt_obj.metadata.creation_timestamp
node.set_exit_reason(_get_pod_exit_reason(evt_obj))
@@ -127,6 +132,21 @@ def _parse_container_resource(container):
return NodeResource(cpu, memory)


def _verify_restarting_training(pod):
if not pod.metadata.annotations:
return False
action_str = pod.metadata.annotations.get(
"pod.sigma.ali/scheduled-action", ""
)
if not action_str:
return False
action_config = json.loads(action_str)
action = action_config.get("scheduledAction", "")
if action == "RestartTrain_Observe":
return True
return False


class PodWatcher(NodeWatcher):
"""PodWatcher monitors all Pods of a k8s Job."""

@@ -178,6 +198,7 @@ def list(self) -> List[Node]:
resource = _parse_container_resource(pod.spec.containers[0])
status = pod.status.phase
start_time = _get_start_timestamp(pod.status)
restart_training = _verify_restarting_training(pod)
node = Node(
node_type=pod_type,
node_id=pod_id,
@@ -186,6 +207,7 @@ def list(self) -> List[Node]:
status=status,
start_time=start_time,
config_resource=resource,
restart_training=restart_training,
)
node.set_exit_reason(_get_pod_exit_reason(pod))
nodes.append(node)
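The trigger on Kubernetes is a Pod annotation: _verify_restarting_training above only looks at the scheduledAction field of the JSON stored under pod.sigma.ali/scheduled-action and treats the value RestartTrain_Observe as a restart request. A small sketch of the same parsing logic on a plain annotation dict, using the payload from the test below:

import json


def restart_requested(annotations: dict) -> bool:
    """Mirror of _verify_restarting_training, but on a plain annotations dict."""
    action_str = annotations.get("pod.sigma.ali/scheduled-action", "")
    if not action_str:
        return False
    action = json.loads(action_str).get("scheduledAction", "")
    return action == "RestartTrain_Observe"


annotations = {
    "pod.sigma.ali/scheduled-action": json.dumps(
        {
            "observedTime": "2020-04-30 00:00:00",
            "scheduledExecutionTime": "2020-04-30 00:10:00",
            "scheduledAction": "RestartTrain_Observe",  # the only field checked here
            "device_ids": ["npu_id_1", "npu_id_2"],
            "eventType": "NPU_reset",
        }
    )
}
assert restart_requested(annotations) is True
assert restart_requested({}) is False
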
12 changes: 12 additions & 0 deletions dlrover/python/tests/test_elastic_training_agent.py
@@ -260,6 +260,18 @@ def test_get_free_port(self):
port = agent._get_free_port()
self.assertTrue(port > 20000)

def test_restart_training(self):
self.config.restart = True
agent = ElasticTrainingAgent(
node_rank=0,
config=self.config,
entrypoint="echo",
spec=self.spec,
start_method=self.config.start_method,
log_dir=self.config.log_dir,
)
agent._stop_workers_to_restart()


class NetworkCheckElasticAgentTest(unittest.TestCase):
def setUp(self) -> None:
25 changes: 25 additions & 0 deletions dlrover/python/tests/test_k8s_watcher.py
@@ -12,6 +12,7 @@
# limitations under the License.

import datetime
import json
import unittest
from typing import List

@@ -31,6 +32,7 @@
PodWatcher,
_convert_pod_event_to_node_event,
_get_pod_exit_reason,
_verify_restarting_training,
)
from dlrover.python.tests.test_utils import (
create_pod,
@@ -102,6 +104,29 @@ def test_get_pod_exit_reason(self):
exit_reason = _get_pod_exit_reason(pod)
self.assertEqual(exit_reason, NodeExitReason.FATAL_ERROR)

def test_verify_restarting_training(self):
labels = {
ElasticJobLabel.APP_NAME: "test",
ElasticJobLabel.REPLICA_TYPE_KEY: NodeType.WORKER,
ElasticJobLabel.REPLICA_INDEX_KEY: "0",
ElasticJobLabel.RANK_INDEX_KEY: "0",
}
pod = create_pod(labels)
reset = _verify_restarting_training(pod)
self.assertFalse(reset)
action = {
"observedTime": "2020-04-30 00:00:00",
"scheduledExecutionTime": "2020-04-30 00:10:00",
"scheduledAction": "RestartTrain_Observe",
"device_ids": ["npu_id_1", "npu_id_2"],
"eventType": "NPU_reset",
}
pod.metadata.annotations[
"pod.sigma.ali/scheduled-action"
] = json.dumps(action)
reset = _verify_restarting_training(pod)
self.assertTrue(reset)


class ScalePlanWatcherTest(unittest.TestCase):
def setUp(self) -> None:
9 changes: 9 additions & 0 deletions dlrover/python/tests/test_servicer.py
@@ -358,6 +358,15 @@ def test_get_straggler(self):
config = grpc.deserialize_message(response.data)
self.assertIsInstance(config, grpc.NetworkCheckResult)

def test_check_hardware_reset(self):
message = grpc.CheckHardwareResetRequest()
request = elastic_training_pb2.Message()
request.data = message.serialize()
response = self.servicer.get(request, None)
config = grpc.deserialize_message(response.data)
self.assertIsInstance(config, grpc.ParallelConfig)
self.assertFalse(config.restart)

def test_join_rendezvous(self):
request = grpc.JoinRendezvousRequest(
0, 8, RendezvousName.ELASTIC_TRAINING
1 change: 1 addition & 0 deletions dlrover/python/tests/test_utils.py
@@ -183,6 +183,7 @@ def create_pod(labels):
metadata=client.V1ObjectMeta(
name="test-worker-0",
labels=labels,
annotations={},
),
status=status,
)
17 changes: 17 additions & 0 deletions dlrover/python/tests/test_worker_manager.py
@@ -169,3 +169,20 @@ def test_pending_without_workers(self):

wait = worker_manager.wait_worker_restart()
self.assertFalse(wait)

def test_verify_restarting_training(self):
worker_manager = WorkerManager(
self._job_nodes[NodeType.WORKER],
self._job_resource,
3,
self._elastic_job.get_node_service_addr,
self._elastic_job.get_node_name,
)
reset = worker_manager.verify_restarting_training()
self.assertFalse(reset)
worker_manager._nodes[0].restart_training = True
reset = worker_manager.verify_restarting_training()
self.assertTrue(reset)
worker_manager._nodes[0].is_released = True
reset = worker_manager.verify_restarting_training()
self.assertFalse(reset)
1 change: 1 addition & 0 deletions dlrover/trainer/tests/torch/elastic_run_test.py
@@ -45,5 +45,6 @@ def test_elastic_config_from_args(self):
self.assertTrue(config.network_check)
self.assertTrue(config.auto_tunning)
self.assertEqual(config.node_unit, 4)
self.assertEqual(config.rdzv_configs["node_unit"], 4)
self.assertEqual(cmd, "/usr/local/bin/python")
self.assertListEqual(cmd_args, ["-u", "test.py", "--batch_size", "16"])
