From b1bea1a3dce2b0cb489a4f461893c7687dac0f3c Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Thu, 21 Nov 2024 15:48:05 +0800 Subject: [PATCH 01/20] impl http server --- dlrover/python/brain/client.py | 4 +- dlrover/python/common/constants.py | 2 + dlrover/python/common/grpc.py | 2 +- dlrover/python/common/http_server.py | 69 ++++++++++++ dlrover/python/common/test.py | 46 ++++++++ dlrover/python/elastic_agent/master_client.py | 4 +- dlrover/python/master/dist_master.py | 4 +- dlrover/python/master/local_master.py | 4 +- dlrover/python/master/servicer.py | 102 ++++++++++++------ dlrover/python/tests/test_common_util.py | 5 +- dlrover/python/tests/test_http_server.py | 95 ++++++++++++++++ dlrover/python/tests/test_servicer.py | 6 +- dlrover/python/util/common_util.py | 11 ++ 13 files changed, 310 insertions(+), 44 deletions(-) create mode 100644 dlrover/python/common/http_server.py create mode 100644 dlrover/python/common/test.py create mode 100644 dlrover/python/tests/test_http_server.py diff --git a/dlrover/python/brain/client.py b/dlrover/python/brain/client.py index d23746f46..c1ac869f5 100644 --- a/dlrover/python/brain/client.py +++ b/dlrover/python/brain/client.py @@ -14,7 +14,7 @@ import os from dlrover.proto import brain_pb2, brain_pb2_grpc -from dlrover.python.common.grpc import build_channel, grpc_server_ready +from dlrover.python.common.grpc import build_grpc_channel, grpc_server_ready from dlrover.python.common.log import default_logger as logger DATA_STORE = "base_datastore" @@ -268,7 +268,7 @@ def build_brain_client(): ``` """ brain_addr = os.getenv(_ENV_BRAIN_ADDR_KEY, "") - channel = build_channel(brain_addr) + channel = build_grpc_channel(brain_addr) if channel and grpc_server_ready(channel): return BrainClient(channel) else: diff --git a/dlrover/python/common/constants.py b/dlrover/python/common/constants.py index ec2c76faf..c45623cec 100644 --- a/dlrover/python/common/constants.py +++ b/dlrover/python/common/constants.py @@ -14,6 +14,8 @@ 
class BasicClass(object): LOG_LEVEL_ENV = "DLROVER_LOG_LEVEL" + COMM_SERVICE_GRPC = "grpc" + COMM_SERVICE_HTTP = "http" class PriorityClass(object): diff --git a/dlrover/python/common/grpc.py b/dlrover/python/common/grpc.py index 8d691f3a7..b73c9820b 100644 --- a/dlrover/python/common/grpc.py +++ b/dlrover/python/common/grpc.py @@ -27,7 +27,7 @@ TIMEOUT_SEC = 5 -def build_channel(addr): +def build_grpc_channel(addr): if not addr_connected(addr): return None channel = grpc.insecure_channel( diff --git a/dlrover/python/common/http_server.py b/dlrover/python/common/http_server.py new file mode 100644 index 000000000..d344c20d2 --- /dev/null +++ b/dlrover/python/common/http_server.py @@ -0,0 +1,69 @@ +# Copyright 2024 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import signal +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from http.server import HTTPServer +from socketserver import ThreadingMixIn + +import tornado + +from dlrover.python.common.log import default_logger as logger + + +class CustomHTTPServer(object): + + SERVING_THREAD_NAME = "http-server-serving-thread" + + def __init__(self, address, port, handler_class): + self._address = address + self._port = port + self._handler_class = handler_class + + self._io_loop = None + self._server = None + self._serving_started = False + + def start_serving(self): + if not self.is_serving(): + self._serving_started = True + + server_thread = threading.Thread( + target=self._start_server, + name=CustomHTTPServer.SERVING_THREAD_NAME, + ) + server_thread.start() + + # wait 3s for sever start + time.sleep(3) + + def _start_server(self): + try: + self._server = tornado.httpserver.HTTPServer( + tornado.web.Application([(r"/", self._handler_class)])) + self._server.listen(self._port) + self._io_loop = tornado.ioloop.IOLoop.current() + self._io_loop.start() + except Exception as e: + logger.error(f"Http server start with error: {e}") + + def stop_serving(self): + if self._server: + self._server.stop() + self._io_loop.add_callback(self._io_loop.stop) + + self._serving_started = False + + def is_serving(self): + return self._serving_started diff --git a/dlrover/python/common/test.py b/dlrover/python/common/test.py new file mode 100644 index 000000000..697873678 --- /dev/null +++ b/dlrover/python/common/test.py @@ -0,0 +1,46 @@ +import threading +import tornado.ioloop +import tornado.web +import tornado.httpserver +import time +import signal + +class MainHandler(tornado.web.RequestHandler): + def get(self): + self.write("Hello, world") + +def make_app(): + return tornado.web.Application([ + (r"/", MainHandler), + ]) + +def start_tornado_server(): + app = make_app() + server = tornado.httpserver.HTTPServer(app) + server.listen(8000) + 
tornado.ioloop.IOLoop.current().start() + +def stop_tornado_server(): + tornado.ioloop.IOLoop.current().stop() + +if __name__ == "__main__": + # 启动 Tornado 服务器的后台线程 + server_thread = threading.Thread(target=start_tornado_server) + server_thread.start() + + # 处理系统信号以优雅地关闭服务器 + def signal_handler(signum, frame): + print("Stopping Tornado server") + stop_tornado_server() + server_thread.join() + print("Tornado server stopped") + + signal.signal(signal.SIGINT, signal_handler) + + # 主线程继续做其他事情 + try: + while True: + print("Main thread is doing other things") + time.sleep(1) + except KeyboardInterrupt: + signal_handler(signal.SIGINT, None) \ No newline at end of file diff --git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py index 93b833ba4..8da02af3e 100644 --- a/dlrover/python/elastic_agent/master_client.py +++ b/dlrover/python/elastic_agent/master_client.py @@ -86,7 +86,7 @@ def __init__(self, master_addr, node_id, node_type, timeout=5): ) self._timeout = timeout self._master_addr = master_addr - self._channel = grpc.build_channel(master_addr) + self._channel = grpc.build_grpc_channel(master_addr) self._stub = elastic_training_pb2_grpc.MasterStub(self._channel) self._node_id = node_id self._node_type = node_type @@ -107,7 +107,7 @@ def close_channel(self): self._channel.close() def open_channel(self): - self._channel = grpc.build_channel(self._master_addr) + self._channel = grpc.build_grpc_channel(self._master_addr) self._stub = elastic_training_pb2_grpc.MasterStub(self._channel) def find_free_port(self): diff --git a/dlrover/python/master/dist_master.py b/dlrover/python/master/dist_master.py index c68942e2c..7af0eb6ca 100644 --- a/dlrover/python/master/dist_master.py +++ b/dlrover/python/master/dist_master.py @@ -149,14 +149,14 @@ def __init__( ) self.elastic_ps_service = _create_elastic_ps_service_if_needed(args) self.sync_service = SyncService(self.job_manager) - self._master_server = 
self._create_master_grpc_service(port, args) + self._master_server = self._create_master_service(port, args) self._job_args = args self._stop_requested = False self._exit_code = 0 self._exit_reason = None self._error_monitor = error_monitor - def _create_master_grpc_service(self, port, params: JobArgs): + def _create_master_service(self, port, params: JobArgs): return create_master_service( port, self.task_manager, diff --git a/dlrover/python/master/local_master.py b/dlrover/python/master/local_master.py index e37d38fde..d0d066278 100644 --- a/dlrover/python/master/local_master.py +++ b/dlrover/python/master/local_master.py @@ -48,13 +48,13 @@ def __init__(self, port, args: JobArgs): self.job_metric_collector = self._create_metric_collector_if_needed( args ) - self._master_server = self._create_master_grpc_service(port, args) + self._master_server = self._create_master_service(port, args) self._job_args = args for i in range(args.node_args[NodeType.WORKER].group_resource.count): self.speed_monitor.add_running_worker(NodeType.WORKER, i) self.speed_monitor.set_target_worker_num(1) - def _create_master_grpc_service(self, port, params: JobArgs): + def _create_master_service(self, port, params: JobArgs): return create_master_service( port, self.task_manager, diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py index 9a1db93c6..5ee196ae1 100644 --- a/dlrover/python/master/servicer.py +++ b/dlrover/python/master/servicer.py @@ -13,7 +13,9 @@ import importlib import threading import time +from abc import ABC from concurrent import futures +from http.server import ThreadingHTTPServer from typing import Dict, List, Optional import grpc as grpc_lib @@ -28,7 +30,7 @@ NodeType, RendezvousName, TrainingExceptionLevel, - TrainingLoopStatus, + TrainingLoopStatus, BasicClass, ) from dlrover.python.common.global_context import Context from dlrover.python.common.log import default_logger as logger @@ -65,8 +67,8 @@ ray_event_queue = 
RayEventQueue.singleton_instance() -class MasterServicer(elastic_training_pb2_grpc.MasterServicer): - """Master service implementation""" +class MasterServicer(ABC): + """Master service base class.""" def __init__( self, @@ -98,6 +100,40 @@ def __init__( "dlrover.python.diagnosis.common.diagnosis_data" ) + +class HttpMasterServicer(MasterServicer): + """Master service with http implementation.""" + + def __init__( + self, + task_manager, + job_manager, + speed_monitor: SpeedMonitor, + rdzv_managers: Dict[str, RendezvousManager], + diagnosis_manager: DiagnosisManager, + job_metric_collector=None, + elastic_ps_service=None, + sync_service=None, + ): + super(HttpMasterServicer, self).__init__(task_manager, job_manager, speed_monitor, rdzv_managers, diagnosis_manager, job_metric_collector, elastic_ps_service, sync_service) + + +class GrpcMasterServicer(MasterServicer, elastic_training_pb2_grpc.MasterServicer): + """Master service with grpc implementation.""" + + def __init__( + self, + task_manager, + job_manager, + speed_monitor: SpeedMonitor, + rdzv_managers: Dict[str, RendezvousManager], + diagnosis_manager: DiagnosisManager, + job_metric_collector=None, + elastic_ps_service=None, + sync_service=None, + ): + super(GrpcMasterServicer, self).__init__(task_manager, job_manager, speed_monitor, rdzv_managers, diagnosis_manager, job_metric_collector, elastic_ps_service, sync_service) + def get(self, request, _): node_type = request.node_type node_id = request.node_id @@ -665,34 +701,38 @@ def create_master_service( job_metric_collector, elastic_ps_service, sync_service, + service_type=BasicClass.COMM_SERVICE_GRPC, + max_threads=64, ) -> MasterServicer: - """Create GRPC server""" - logger.info("Creating master service") - server = grpc_lib.server( - futures.ThreadPoolExecutor(max_workers=64), - options=[ - ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH), - ( - "grpc.max_receive_message_length", - GRPC.MAX_RECEIVE_MESSAGE_LENGTH, - ), - ], - ) - 
master_servicer = MasterServicer( - task_manager=task_manager, - job_manager=job_manager, - speed_monitor=speed_monitor, - rdzv_managers=rdzv_managers, - diagnosis_manager=diagnosis_manager, - job_metric_collector=job_metric_collector, - elastic_ps_service=elastic_ps_service, - sync_service=sync_service, - ) - elastic_training_pb2_grpc.add_MasterServicer_to_server( - master_servicer, server - ) - server.add_insecure_port("[::]:{}".format(port)) - logger.info("The port of the master server is: %d", port) + logger.info(f"Creating master {service_type} service with port: {port}") + + if service_type == BasicClass.COMM_SERVICE_GRPC: + server = grpc_lib.server( + futures.ThreadPoolExecutor(max_workers=max_threads), + options=[ + ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH), + ( + "grpc.max_receive_message_length", + GRPC.MAX_RECEIVE_MESSAGE_LENGTH, + ), + ], + ) + master_servicer = GrpcMasterServicer( + task_manager=task_manager, + job_manager=job_manager, + speed_monitor=speed_monitor, + rdzv_managers=rdzv_managers, + diagnosis_manager=diagnosis_manager, + job_metric_collector=job_metric_collector, + elastic_ps_service=elastic_ps_service, + sync_service=sync_service, + ) - return server + elastic_training_pb2_grpc.add_MasterServicer_to_server( + master_servicer, server + ) + server.add_insecure_port("[::]:{}".format(port)) + return server + else: + pass diff --git a/dlrover/python/tests/test_common_util.py b/dlrover/python/tests/test_common_util.py index 895eebbc9..3235adbae 100644 --- a/dlrover/python/tests/test_common_util.py +++ b/dlrover/python/tests/test_common_util.py @@ -1,4 +1,4 @@ -# Copyright 2022 The DLRover Authors. All rights reserved. +# Copyright 2024 The DLRover Authors. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -20,3 +20,6 @@ class CommonUtilTest(unittest.TestCase): def test_get_dlrover_version(self): self.assertIsNotNone(cu.get_dlrover_version()) self.assertNotEqual(cu.get_dlrover_version(), "Unknown") + + def test_is_port_in_use(self): + self.assertFalse(cu.is_port_in_use(65530)) diff --git a/dlrover/python/tests/test_http_server.py b/dlrover/python/tests/test_http_server.py new file mode 100644 index 000000000..b80251219 --- /dev/null +++ b/dlrover/python/tests/test_http_server.py @@ -0,0 +1,95 @@ +# Copyright 2022 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import threading +import time +import unittest +from concurrent.futures import ThreadPoolExecutor, as_completed + +import requests +import tornado + +from dlrover.python.common.http_server import CustomHTTPServer +from util.common_util import is_port_in_use + +TEST_SERVER_ADDR = "localhost" +TEST_SERVER_PORT = 8000 + + +class HttpServerClientTest(unittest.TestCase): + + def setUp(self): + self.server = None + + def tearDown(self): + if self.server is not None: + self.server.stop_serving() + self.server = None + + def test_server_basic(self): + self.server = CustomHTTPServer(TEST_SERVER_ADDR, TEST_SERVER_PORT, TestRequestHandler, 2) + self.assertIsNotNone(self.server) + self.assertFalse(is_port_in_use(TEST_SERVER_PORT)) + + self.assertFalse(self.server.is_serving()) + self.server.start_serving() + self.assertTrue(self.server.is_serving()) + self.assertTrue(is_port_in_use(TEST_SERVER_PORT)) + self.server.start_serving() + self.assertTrue(self.server.is_serving()) + + active_threads_name = [t.name for t in threading.enumerate()] + self.assertIn(CustomHTTPServer.SERVING_THREAD_NAME, active_threads_name) + time.sleep(1) + + # test get request + self._test_get_request() + + self.server.stop_serving() + self.assertFalse(self.server.is_serving()) + + def _test_get_request(self): + try: + with requests.get("http://localhost:8000") as response: + self.assertEqual(response.status_code, 200) + self.assertEqual(response.text, "Hello, world!") + return response + except Exception as e: + raise e + + def test_server_concurrency(self): + self.server = CustomHTTPServer(TEST_SERVER_ADDR, TEST_SERVER_PORT, + TestRequestHandler) + self.server.start_serving() + + futures = [] + result_num = 0 + client_size = 1000 + with ThreadPoolExecutor(max_workers=client_size) as executor: + for i in range(client_size): + futures.append( + executor.submit(self._test_get_request) + ) + + for future in as_completed(futures): + if future.result().status_code == 200: + result_num += 1 + 
self.assertEqual(len(futures), client_size) + self.assertEqual(result_num, client_size) + + self.server.stop_serving() + + +class TestRequestHandler(tornado.web.RequestHandler): + def get(self): + self.write("Hello, world!") diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py index 548bc3f72..dc2483f5b 100644 --- a/dlrover/python/tests/test_servicer.py +++ b/dlrover/python/tests/test_servicer.py @@ -38,7 +38,7 @@ from dlrover.python.master.monitor.speed_monitor import SpeedMonitor from dlrover.python.master.node.dist_job_manager import create_job_manager from dlrover.python.master.node.job_context import get_job_context -from dlrover.python.master.servicer import MasterServicer +from dlrover.python.master.servicer import GrpcMasterServicer from dlrover.python.master.shard.task_manager import TaskManager from dlrover.python.master.stats.job_collector import JobMetricCollector from dlrover.python.tests.test_utils import ( @@ -81,7 +81,7 @@ def setUp(self) -> None: RendezvousName.NETWORK_CHECK: NetworkCheckRendezvousManager(), } sync_service = SyncService(self.job_manager) - self.servicer = MasterServicer( + self.servicer = GrpcMasterServicer( task_manager=self.task_manager, job_manager=self.job_manager, speed_monitor=speed_monitor, @@ -507,7 +507,7 @@ def setUp(self) -> None: "1", "default", "local", "dlrover" ) self.elastic_ps_service = ElasticPsService() - self.servicer = MasterServicer( + self.servicer = GrpcMasterServicer( task_manager=self.task_manager, job_manager=self.job_manager, speed_monitor=speed_monitor, diff --git a/dlrover/python/util/common_util.py b/dlrover/python/util/common_util.py index 824fa5341..b9ba393f2 100644 --- a/dlrover/python/util/common_util.py +++ b/dlrover/python/util/common_util.py @@ -13,6 +13,7 @@ import importlib.metadata import re +import socket import dlrover.python.util.file_util as fu @@ -57,3 +58,13 @@ def get_installed_version(package_name): return version except 
importlib.metadata.PackageNotFoundError: return None + + +def is_port_in_use(port): + """ + Check if the port is in use. + """ + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + result = sock.connect_ex(('localhost', port)) + return result == 0 From 00f4a758165fbec699bd7b43a5e9f4b06ed80a6f Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Wed, 4 Dec 2024 16:48:06 +0800 Subject: [PATCH 02/20] fix ut --- dlrover/python/tests/test_http_server.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dlrover/python/tests/test_http_server.py b/dlrover/python/tests/test_http_server.py index b80251219..1ac3c0b5b 100644 --- a/dlrover/python/tests/test_http_server.py +++ b/dlrover/python/tests/test_http_server.py @@ -37,7 +37,7 @@ def tearDown(self): self.server = None def test_server_basic(self): - self.server = CustomHTTPServer(TEST_SERVER_ADDR, TEST_SERVER_PORT, TestRequestHandler, 2) + self.server = CustomHTTPServer(TEST_SERVER_ADDR, TEST_SERVER_PORT, TestRequestHandler) self.assertIsNotNone(self.server) self.assertFalse(is_port_in_use(TEST_SERVER_PORT)) @@ -68,13 +68,12 @@ def _test_get_request(self): raise e def test_server_concurrency(self): - self.server = CustomHTTPServer(TEST_SERVER_ADDR, TEST_SERVER_PORT, - TestRequestHandler) + self.server = CustomHTTPServer(TEST_SERVER_ADDR, TEST_SERVER_PORT, TestRequestHandler) self.server.start_serving() futures = [] result_num = 0 - client_size = 1000 + client_size = 100 with ThreadPoolExecutor(max_workers=client_size) as executor: for i in range(client_size): futures.append( From 5c1007453e46f9960c80a7db0d97fe28658c0100 Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Wed, 4 Dec 2024 19:57:04 +0800 Subject: [PATCH 03/20] impl new servicer based on http --- dlrover/python/brain/client.py | 2 +- dlrover/python/common/{grpc.py => comm.py} | 79 +---- dlrover/python/common/global_context.py | 9 +- dlrover/python/common/http_server.py | 50 ++- dlrover/python/common/node.py | 2 +- 
dlrover/python/common/test.py | 35 +- .../config/paral_config_tuner.py | 4 +- dlrover/python/elastic_agent/master_client.py | 122 +++---- .../python/elastic_agent/monitor/resource.py | 2 +- .../python/elastic_agent/sharding/client.py | 4 +- .../python/elastic_agent/tensorflow/hooks.py | 2 +- .../python/elastic_agent/torch/training.py | 10 +- .../hyperparams/simple_strategy_generator.py | 4 +- .../python/master/node/dist_job_manager.py | 2 +- .../python/master/node/local_job_manager.py | 2 +- dlrover/python/master/servicer.py | 314 ++++++++++-------- dlrover/python/master/shard/task_manager.py | 6 +- dlrover/python/master/stats/job_collector.py | 2 +- dlrover/python/master/stats/reporter.py | 2 +- .../python/tests/test_agent_config_tuner.py | 4 +- dlrover/python/tests/test_agent_monitor.py | 2 +- dlrover/python/tests/test_common_util.py | 34 ++ dlrover/python/tests/test_grpc_utils.py | 36 +- dlrover/python/tests/test_http_server.py | 35 +- dlrover/python/tests/test_job_manager.py | 12 +- dlrover/python/tests/test_master_client.py | 13 +- dlrover/python/tests/test_servicer.py | 94 +++--- .../python/tests/test_strategy_generator.py | 4 +- dlrover/python/tests/test_task_manager.py | 2 +- dlrover/python/tests/test_utils.py | 2 +- dlrover/python/util/common_util.py | 72 ++++ .../tests/torch/checkpoint_egine_test.py | 2 +- .../tests/torch/elastic_dataloader_test.py | 2 +- dlrover/trainer/tests/torch/elastic_test.py | 2 +- dlrover/trainer/tests/torch/fsdp_ckpt_test.py | 4 +- dlrover/trainer/torch/elastic_run.py | 6 +- 36 files changed, 546 insertions(+), 433 deletions(-) rename dlrover/python/common/{grpc.py => comm.py} (80%) diff --git a/dlrover/python/brain/client.py b/dlrover/python/brain/client.py index c1ac869f5..aef75b84b 100644 --- a/dlrover/python/brain/client.py +++ b/dlrover/python/brain/client.py @@ -14,7 +14,7 @@ import os from dlrover.proto import brain_pb2, brain_pb2_grpc -from dlrover.python.common.grpc import build_grpc_channel, grpc_server_ready +from 
dlrover.python.common.comm import build_grpc_channel, grpc_server_ready from dlrover.python.common.log import default_logger as logger DATA_STORE = "base_datastore" diff --git a/dlrover/python/common/grpc.py b/dlrover/python/common/comm.py similarity index 80% rename from dlrover/python/common/grpc.py rename to dlrover/python/common/comm.py index 066c87c7a..b762ba9d2 100644 --- a/dlrover/python/common/grpc.py +++ b/dlrover/python/common/comm.py @@ -12,15 +12,13 @@ # limitations under the License. import pickle -import random import socket -from contextlib import closing from dataclasses import dataclass, field from typing import Dict, List import grpc -from dlrover.python.common.constants import GRPC, AscendConstants +from dlrover.python.common.constants import GRPC from dlrover.python.common.log import default_logger as logger from dlrover.python.common.serialize import JsonSerializable @@ -68,74 +66,6 @@ def addr_connected(addr): return False -def find_free_port(port=0): - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: - s.bind(("", port)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - return s.getsockname()[1] - - -def find_free_port_in_range(start=0, end=65535, random_port=True): - """Find a free port from a range.""" - bind_ports = set() - while True: - if random_port: - port = random.randint(start, end) - else: - port = start + len(bind_ports) - if port in bind_ports: - continue - try: - return find_free_port(port) - except OSError: - logger.warning(f"Socket creation attempt failed with {port}.") - bind_ports.add(port) - if len(bind_ports) == end - start + 1: - break - raise RuntimeError(f"Fail to find a free port in [{start}, {end})") - - -def find_free_port_in_set(ports): - for port in ports: - try: - return find_free_port(port) - except OSError: - logger.warning(f"Socket creation attempt failed with {port}.") - raise RuntimeError(f"Fail to find a free port in {ports}") - - -def find_free_port_for_hccl( - 
start=AscendConstants.HCCL_PORT_START_DEFAULT, -) -> int: - max_port = 65500 - cur_start = start - end = start + 10000 - if end > max_port: - end = max_port - logger.info(f"Try to find available port for hccl from {start}") - checking_port = 0 - while True: - try: - cur_end = cur_start + AscendConstants.NPU_PER_NODE - for port in range(cur_start, cur_end): - checking_port = port - find_free_port(port) - logger.info(f"Find available port start from: {cur_start}") - break - except OSError: - logger.warning( - f"Target port has already been used: {checking_port}." - ) - if checking_port > 0: - cur_start = checking_port + 1 - else: - cur_start = cur_start + AscendConstants.NPU_PER_NODE - if cur_start > end: - cur_start = 0 - break - return cur_start - - def grpc_server_ready(channel) -> bool: try: grpc.channel_ready_future(channel).result(timeout=TIMEOUT_SEC) @@ -163,6 +93,13 @@ def serialize(self): return pickle.dumps(self) +@dataclass +class BaseMessage(Message): + node_id: int = -1 + node_type: str = "" + data: bytes = bytes() + + @dataclass class TaskRequest(Message): dataset_name: str = "" diff --git a/dlrover/python/common/global_context.py b/dlrover/python/common/global_context.py index 30c4335d1..f039d2c4d 100644 --- a/dlrover/python/common/global_context.py +++ b/dlrover/python/common/global_context.py @@ -13,10 +13,13 @@ import os -from dlrover.python.common import grpc from dlrover.python.common.constants import UserEnv from dlrover.python.common.log import default_logger as logger from dlrover.python.common.singleton import Singleton +from dlrover.python.util.common_util import ( + find_free_port_in_range, + find_free_port_in_set, +) class ConfigKeys(object): @@ -167,13 +170,13 @@ def config_master_port(self, port=0): for port in host_ports_env.split(","): ports.append(int(port)) try: - self.master_port = grpc.find_free_port_in_set(ports) + self.master_port = find_free_port_in_set(ports) except RuntimeError as e: logger.warning(e) elif port > 0: 
self.master_port = port if self.master_port is None: - self.master_port = grpc.find_free_port_in_range(20000, 30000) + self.master_port = find_free_port_in_range(20000, 30000) def get_param_value_from_brain(self, key_name, default_value, dtype=float): """TODO: Get the configured value from Brain service.""" diff --git a/dlrover/python/common/http_server.py b/dlrover/python/common/http_server.py index d344c20d2..701949f37 100644 --- a/dlrover/python/common/http_server.py +++ b/dlrover/python/common/http_server.py @@ -10,38 +10,63 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import signal + +import abc import threading import time -from concurrent.futures import ThreadPoolExecutor -from http.server import HTTPServer -from socketserver import ThreadingMixIn import tornado from dlrover.python.common.log import default_logger as logger -class CustomHTTPServer(object): - - SERVING_THREAD_NAME = "http-server-serving-thread" - +class CustomHTTPServer(abc.ABC): def __init__(self, address, port, handler_class): self._address = address self._port = port self._handler_class = handler_class + @property + def address(self): + return self._address + + @property + def port(self): + return self._port + + @property + def handler_class(self): + return self._handler_class + + @abc.abstractmethod + def start(self): + """Start the server.""" + pass + + @abc.abstractmethod + def stop(self): + """Stop the server.""" + pass + + +class TornadoHTTPServer(CustomHTTPServer): + + SERVING_THREAD_NAME = "http-server-serving-thread" + + def __init__(self, address, port, handler_class): + super().__init__(address, port, handler_class) + self._io_loop = None self._server = None self._serving_started = False - def start_serving(self): + def start(self): if not self.is_serving(): self._serving_started = True server_thread = threading.Thread( 
target=self._start_server, - name=CustomHTTPServer.SERVING_THREAD_NAME, + name=TornadoHTTPServer.SERVING_THREAD_NAME, ) server_thread.start() @@ -51,14 +76,15 @@ def start_serving(self): def _start_server(self): try: self._server = tornado.httpserver.HTTPServer( - tornado.web.Application([(r"/", self._handler_class)])) + tornado.web.Application([(r"/", self._handler_class)]) + ) self._server.listen(self._port) self._io_loop = tornado.ioloop.IOLoop.current() self._io_loop.start() except Exception as e: logger.error(f"Http server start with error: {e}") - def stop_serving(self): + def stop(self): if self._server: self._server.stop() self._io_loop.add_callback(self._io_loop.stop) diff --git a/dlrover/python/common/node.py b/dlrover/python/common/node.py index 220b46d29..df7780416 100644 --- a/dlrover/python/common/node.py +++ b/dlrover/python/common/node.py @@ -14,6 +14,7 @@ import copy import time +from dlrover.python.common.comm import ParallelConfig from dlrover.python.common.constants import ( NodeEventType, NodeExitReason, @@ -21,7 +22,6 @@ NodeStatus, PriorityClass, ) -from dlrover.python.common.grpc import ParallelConfig from dlrover.python.common.serialize import JsonSerializable diff --git a/dlrover/python/common/test.py b/dlrover/python/common/test.py index 697873678..294398b58 100644 --- a/dlrover/python/common/test.py +++ b/dlrover/python/common/test.py @@ -1,18 +1,37 @@ +# Copyright 2024 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import signal import threading +import time + +import tornado.httpserver import tornado.ioloop import tornado.web -import tornado.httpserver -import time -import signal + class MainHandler(tornado.web.RequestHandler): def get(self): self.write("Hello, world") + def make_app(): - return tornado.web.Application([ - (r"/", MainHandler), - ]) + return tornado.web.Application( + [ + (r"/", MainHandler), + ] + ) + def start_tornado_server(): app = make_app() @@ -20,9 +39,11 @@ def start_tornado_server(): server.listen(8000) tornado.ioloop.IOLoop.current().start() + def stop_tornado_server(): tornado.ioloop.IOLoop.current().stop() + if __name__ == "__main__": # 启动 Tornado 服务器的后台线程 server_thread = threading.Thread(target=start_tornado_server) @@ -43,4 +64,4 @@ def signal_handler(signum, frame): print("Main thread is doing other things") time.sleep(1) except KeyboardInterrupt: - signal_handler(signal.SIGINT, None) \ No newline at end of file + signal_handler(signal.SIGINT, None) diff --git a/dlrover/python/elastic_agent/config/paral_config_tuner.py b/dlrover/python/elastic_agent/config/paral_config_tuner.py index afb4bc306..ca9d5390e 100644 --- a/dlrover/python/elastic_agent/config/paral_config_tuner.py +++ b/dlrover/python/elastic_agent/config/paral_config_tuner.py @@ -16,12 +16,12 @@ import threading import time -from dlrover.python.common.constants import ConfigPath -from dlrover.python.common.grpc import ( +from dlrover.python.common.comm import ( DataLoaderConfig, OptimizerConfig, ParallelConfig, ) +from dlrover.python.common.constants import ConfigPath from dlrover.python.common.log import default_logger as logger from dlrover.python.common.singleton import Singleton from dlrover.python.elastic_agent.master_client import MasterClient diff --git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py index 
32ceb7bc7..5a1d9c617 100644 --- a/dlrover/python/elastic_agent/master_client.py +++ b/dlrover/python/elastic_agent/master_client.py @@ -20,7 +20,7 @@ from typing import Dict, Optional from dlrover.proto import elastic_training_pb2, elastic_training_pb2_grpc -from dlrover.python.common import env_utils, grpc +from dlrover.python.common import comm, env_utils from dlrover.python.common.constants import ( NetworkFailureReason, NodeEnv, @@ -86,7 +86,7 @@ def __init__(self, master_addr, node_id, node_type, timeout=5): ) self._timeout = timeout self._master_addr = master_addr - self._channel = grpc.build_grpc_channel(master_addr) + self._channel = comm.build_grpc_channel(master_addr) self._stub = elastic_training_pb2_grpc.MasterStub(self._channel) self._node_id = node_id self._node_type = node_type @@ -107,7 +107,7 @@ def close_channel(self): self._channel.close() def open_channel(self): - self._channel = grpc.build_grpc_channel(self._master_addr) + self._channel = comm.build_grpc_channel(self._master_addr) self._stub = elastic_training_pb2_grpc.MasterStub(self._channel) def find_free_port(self): @@ -120,7 +120,7 @@ def find_free_port(self): return port @retry_grpc_request - def _report(self, message: grpc.Message): + def _report(self, message: comm.Message): request = elastic_training_pb2.Message() request.node_id = self._node_id request.node_type = self._node_type @@ -128,23 +128,23 @@ def _report(self, message: grpc.Message): return self._stub.report(request, timeout=self._timeout) @retry_grpc_request - def _get(self, message: grpc.Message): + def _get(self, message: comm.Message): request = elastic_training_pb2.Message() request.node_id = self._node_id request.node_type = self._node_type request.data = message.serialize() response = self._stub.get(request, timeout=self._timeout) - res_message = grpc.deserialize_message(response.data) + res_message = comm.deserialize_message(response.data) return res_message def kv_store_set(self, key, value): - message = 
grpc.KeyValuePair(key, value) + message = comm.KeyValuePair(key, value) response = self._report(message) return response.success def kv_store_get(self, key): - request = grpc.KeyValuePair(key) - result: grpc.KeyValuePair = self._get(request) + request = comm.KeyValuePair(key) + result: comm.KeyValuePair = self._get(request) return result.value def get_task(self, dataset_name): @@ -159,7 +159,7 @@ def get_task(self, dataset_name): c.f. /dlrover/proto/dlrover.proto """ - req = grpc.TaskRequest(dataset_name) + req = comm.TaskRequest(dataset_name) success = False res = None @@ -175,7 +175,7 @@ def get_task(self, dataset_name): if not success: logger.warning(exception) if not res: - res = grpc.Task() + res = comm.Task() return success, res def report_task_result(self, dataset_name, task_id, err_msg): @@ -188,7 +188,7 @@ def report_task_result(self, dataset_name, task_id, err_msg): err_msg: string the error message on training. """ - message = grpc.TaskResult(dataset_name, task_id, err_msg) + message = comm.TaskResult(dataset_name, task_id, err_msg) return self._report(message) def report_dataset_shard_params( @@ -202,7 +202,7 @@ def report_dataset_shard_params( task_type=elastic_training_pb2.NONE, storage_type="", ): - message = grpc.DatasetShardParams( + message = comm.DatasetShardParams( batch_size=batch_size, num_epochs=num_epochs, dataset_size=dataset_size, @@ -215,20 +215,20 @@ def report_dataset_shard_params( return self._report(message) def ready_for_ps_relaunch(self): - message = grpc.PsReady() + message = comm.PsReady() return self._report(message) def get_shard_checkpoint(self, dataset_name): - req = grpc.ShardCheckpointRequest(dataset_name) - res: grpc.ShardCheckpoint = self._get(req) + req = comm.ShardCheckpointRequest(dataset_name) + res: comm.ShardCheckpoint = self._get(req) return res.content def report_shard_checkpoint(self, shard_checkpoint): - request = grpc.ShardCheckpoint(shard_checkpoint) + request = comm.ShardCheckpoint(shard_checkpoint) return 
self._report(request) def report_used_resource(self, memory, cpu, gpu_stats): - message = grpc.ResourceStats(memory, cpu, gpu_stats) + message = comm.ResourceStats(memory, cpu, gpu_stats) return self._report(message) def report_model_info(self, model_info): @@ -237,7 +237,7 @@ def report_model_info(self, model_info): def report_global_step( self, global_step, timestamp, elapsed_time_per_step=0 ): - message = grpc.GlobalStep( + message = comm.GlobalStep( timestamp=timestamp, step=global_step, elapsed_time_per_step=elapsed_time_per_step, @@ -245,8 +245,8 @@ def report_global_step( return self._report(message) def report_heart_beat(self, timestamp) -> DiagnosisAction: - message = grpc.HeartBeat(timestamp=timestamp) - response: grpc.HeartbeatResponse = self._get(message) + message = comm.HeartBeat(timestamp=timestamp) + response: comm.HeartbeatResponse = self._get(message) action = NoAction() if not response: @@ -267,16 +267,16 @@ def report_heart_beat(self, timestamp) -> DiagnosisAction: return action def get_cluster_version(self, version_type, task_type, task_id): - request = grpc.ClusterVersionRequest( + request = comm.ClusterVersionRequest( task_type=task_type, task_id=task_id, version_type=version_type, ) - result: grpc.ClusterVersion = self._get(request) + result: comm.ClusterVersion = self._get(request) return result.version def update_node_addr(self, task_type, task_id, node_addr): - message = grpc.NodeAddress(type=task_type, id=task_id, addr=node_addr) + message = comm.NodeAddress(type=task_type, id=task_id, addr=node_addr) res = self._report(message) return res @@ -288,12 +288,12 @@ def report_node_event( event_elapsed_time=0, node_rank=-1, ): - message = grpc.NodeEvent( + message = comm.NodeEvent( event_type=event_type, event_message=event_msg, event_time=event_time, event_elapsed_time=event_elapsed_time, - node=grpc.NodeMeta( + node=comm.NodeMeta( type=self._node_type, id=self._node_id, addr=self._node_ip ), ) @@ -319,7 +319,7 @@ def 
report_succeeded_exited(self): def update_cluster_version( self, version_type, version, task_type, task_id ): - message = grpc.ClusterVersion( + message = comm.ClusterVersion( task_type=task_type, task_id=task_id, version_type=version_type, @@ -328,17 +328,17 @@ def update_cluster_version( self._report(message) def query_ps_nodes(self): - request = grpc.PsNodesRequest() - result: grpc.PsNodes = self._get(request) + request = comm.PsNodesRequest() + result: comm.PsNodes = self._get(request) return result.nodes, result.ps_failure def query_training_status(self): - request = grpc.TrainingStatusRequest() - response: grpc.TrainingStatus = self._get(request) + request = comm.TrainingStatusRequest() + response: comm.TrainingStatus = self._get(request) return response.status def join_sync(self, sync_name): - message = grpc.SyncJoin(sync_name) + message = comm.SyncJoin(sync_name) logger.info( " {}:{} join sync {}".format( self._node_id, self._node_type, sync_name @@ -348,50 +348,50 @@ def join_sync(self, sync_name): return response.success def sync_finished(self, sync_name): - message = grpc.SyncFinish(sync_name) + message = comm.SyncFinish(sync_name) response = self._report(message) return response.success def barrier(self, barrier_name, notify=False): - message = grpc.SyncBarrier(barrier_name, notify) + message = comm.SyncBarrier(barrier_name, notify) response = self._report(message) return response.success def get_running_nodes(self): - request = grpc.RunningNodesRequest() - result: grpc.RunningNodes = self._get(request) + request = comm.RunningNodesRequest() + result: comm.RunningNodes = self._get(request) return result.nodes def num_nodes_waiting(self, rdzv_name): - request = grpc.WaitingNodeNumRequest(rdzv_name=rdzv_name) + request = comm.WaitingNodeNumRequest(rdzv_name=rdzv_name) try: - result: grpc.RendezvousState = self._get(request) + result: comm.RendezvousState = self._get(request) return result.waiting_num except Exception: logger.warning("Fail to query the 
number of waiting nodes.") return 0 def join_rendezvous(self, node_rank, local_world_size, rdzv_name=""): - request = grpc.JoinRendezvousRequest( + request = comm.JoinRendezvousRequest( node_id=self._node_id, node_rank=node_rank, local_world_size=local_world_size, rdzv_name=rdzv_name, node_ip=self._node_ip, ) - result: grpc.RendezvousState = self._get(request) + result: comm.RendezvousState = self._get(request) return result.round def get_comm_world(self, rdzv_name, node_rank): - request = grpc.CommWorldRequest(node_id=node_rank, rdzv_name=rdzv_name) - result: grpc.RendezvousState = self._get(request) + request = comm.CommWorldRequest(node_id=node_rank, rdzv_name=rdzv_name) + result: comm.RendezvousState = self._get(request) return result.round, result.group, result.world def check_fault_node(self, timeout=300): - request = grpc.NetworkReadyRequest() + request = comm.NetworkReadyRequest() start = time.time() while True: - result: grpc.NetworkCheckResult = self._get(request) + result: comm.NetworkCheckResult = self._get(request) if ( result.reason == NetworkFailureReason.WAITING_NODE and time.time() - start < timeout @@ -402,10 +402,10 @@ def check_fault_node(self, timeout=300): return result.nodes def check_straggler(self, timeout=300): - request = grpc.StragglerExistRequest() + request = comm.StragglerExistRequest() start = time.time() while True: - result: grpc.NetworkCheckResult = self._get(request) + result: comm.NetworkCheckResult = self._get(request) if ( result.reason == NetworkFailureReason.WAITING_NODE and time.time() - start < timeout @@ -418,7 +418,7 @@ def check_straggler(self, timeout=300): def report_rdzv_params( self, min_nodes, max_nodes, waiting_timeout, node_unit, joint_timeout ): - message = grpc.RendezvousParams( + message = comm.RendezvousParams( min_nodes, max_nodes, waiting_timeout, @@ -429,48 +429,48 @@ def report_rdzv_params( return response.success def report_failures(self, error_data, restart_count=-1, level=""): - message = 
grpc.NodeFailure(error_data, restart_count, level) + message = comm.NodeFailure(error_data, restart_count, level) self._report(message) - def report_paral_config(self, config: grpc.ParallelConfig): + def report_paral_config(self, config: comm.ParallelConfig): self._report(config) def report_diagnosis_agent_metrics(self, data: DiagnosisData): - message = grpc.DiagnosisReportData( + message = comm.DiagnosisReportData( data.__class__.__name__, data.to_json(), data.node_rank, ) self._report(message) - def get_paral_config(self) -> grpc.ParallelConfig: - request = grpc.ParallelConfigRequest() + def get_paral_config(self) -> comm.ParallelConfig: + request = comm.ParallelConfigRequest() result = self._get(request) return result def need_to_restart_training(self): - request = grpc.CheckHardwareResetRequest() + request = comm.CheckHardwareResetRequest() try: - result: grpc.ParallelConfig = self._get(request) + result: comm.ParallelConfig = self._get(request) return result.restart except Exception: logger.warning("Fail to verify restarting training processes.") return False def sync_checkpoint(self, step): - request = grpc.NodeCheckpointState() + request = comm.NodeCheckpointState() request.step = step response = self._report(request) return response.success - def sync_training_ports(self, port) -> grpc.SyncTrainingPort: - request = grpc.SyncTrainingPort(port=port) - response: grpc.SyncTrainingPort = self._get(request) + def sync_training_ports(self, port) -> comm.SyncTrainingPort: + request = comm.SyncTrainingPort(port=port) + response: comm.SyncTrainingPort = self._get(request) return response def get_elastic_run_config(self) -> Dict[str, str]: - request = grpc.ElasticRunConfigRequest() - response: grpc.ElasticRunConfig = self._get(request) + request = comm.ElasticRunConfigRequest() + response: comm.ElasticRunConfig = self._get(request) return response.configs def report_event( @@ -483,7 +483,7 @@ def report_event( ): if labels is None: labels = {} - message = grpc.Event( 
+ message = comm.Event( event_type=event_type, instance=instance, action=action, diff --git a/dlrover/python/elastic_agent/monitor/resource.py b/dlrover/python/elastic_agent/monitor/resource.py index 65d33726c..02472f2d9 100644 --- a/dlrover/python/elastic_agent/monitor/resource.py +++ b/dlrover/python/elastic_agent/monitor/resource.py @@ -18,8 +18,8 @@ import psutil import pynvml +from dlrover.python.common.comm import GPUStats from dlrover.python.common.constants import NodeEnv -from dlrover.python.common.grpc import GPUStats from dlrover.python.common.log import default_logger as logger from dlrover.python.common.singleton import Singleton from dlrover.python.elastic_agent.master_client import MasterClient diff --git a/dlrover/python/elastic_agent/sharding/client.py b/dlrover/python/elastic_agent/sharding/client.py index 9c7f4b1ac..5ae64b36c 100644 --- a/dlrover/python/elastic_agent/sharding/client.py +++ b/dlrover/python/elastic_agent/sharding/client.py @@ -18,7 +18,7 @@ from multiprocessing import SimpleQueue from dlrover.proto import elastic_training_pb2 -from dlrover.python.common import grpc +from dlrover.python.common import comm from dlrover.python.common.log import default_logger as logger from dlrover.python.elastic_agent.master_client import MasterClient from dlrover.python.elastic_agent.monitor.training import TFTrainingReporter @@ -223,7 +223,7 @@ def get_shard_checkpoint(self): return shard_checkpoint def restore_shard_from_checkpoint(self, shard_checkpoint): - message = grpc.ShardCheckpoint(shard_checkpoint) + message = comm.ShardCheckpoint(shard_checkpoint) res = self._mc.report_shard_checkpoint(message) return res.success diff --git a/dlrover/python/elastic_agent/tensorflow/hooks.py b/dlrover/python/elastic_agent/tensorflow/hooks.py index eb942b5f9..ff53847bf 100644 --- a/dlrover/python/elastic_agent/tensorflow/hooks.py +++ b/dlrover/python/elastic_agent/tensorflow/hooks.py @@ -18,7 +18,7 @@ SessionRunHook, ) -from dlrover.python.common.grpc 
import ModelInfo, OpStats, TensorStats +from dlrover.python.common.comm import ModelInfo, OpStats, TensorStats from dlrover.python.common.log import default_logger as logger from dlrover.python.elastic_agent.master_client import MasterClient from dlrover.python.elastic_agent.monitor.training import ( diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index 5c99da91d..d44cab9c3 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -83,11 +83,6 @@ TrainingExceptionLevel, ) from dlrover.python.common.error import ProcessError -from dlrover.python.common.grpc import ( - find_free_port_for_hccl, - find_free_port_in_range, - find_free_port_in_set, -) from dlrover.python.common.log import default_logger as logger from dlrover.python.diagnosis.common.constants import DiagnosisActionType from dlrover.python.diagnosis.common.diagnosis_action import NodeAction @@ -102,6 +97,11 @@ from dlrover.python.elastic_agent.monitor.training import TorchTrainingMonitor from dlrover.python.elastic_agent.torch.ckpt_saver import AsyncCheckpointSaver from dlrover.python.elastic_agent.torch.master_kv_store import MasterKVStore +from dlrover.python.util.common_util import ( + find_free_port_for_hccl, + find_free_port_in_range, + find_free_port_in_set, +) from dlrover.trainer.torch.utils import ( version_less_than_230, version_less_than_240, diff --git a/dlrover/python/master/hyperparams/simple_strategy_generator.py b/dlrover/python/master/hyperparams/simple_strategy_generator.py index 8e28e8f15..597afd1e0 100644 --- a/dlrover/python/master/hyperparams/simple_strategy_generator.py +++ b/dlrover/python/master/hyperparams/simple_strategy_generator.py @@ -14,12 +14,12 @@ import math from typing import Dict, List -from dlrover.python.common.constants import NodeType -from dlrover.python.common.grpc import ( +from dlrover.python.common.comm import ( DataLoaderConfig, 
OptimizerConfig, ParallelConfig, ) +from dlrover.python.common.constants import NodeType from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node from dlrover.python.master.hyperparams.strategy_generator import ( diff --git a/dlrover/python/master/node/dist_job_manager.py b/dlrover/python/master/node/dist_job_manager.py index 2c8b22e05..2b92a4bd8 100644 --- a/dlrover/python/master/node/dist_job_manager.py +++ b/dlrover/python/master/node/dist_job_manager.py @@ -20,6 +20,7 @@ from datetime import datetime from typing import Dict, List, Optional +from dlrover.python.common.comm import ParallelConfig from dlrover.python.common.constants import ( DistributionStrategy, ElasticJobLabel, @@ -33,7 +34,6 @@ TrainingExceptionLevel, ) from dlrover.python.common.global_context import Context -from dlrover.python.common.grpc import ParallelConfig from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node, NodeGroupResource from dlrover.python.diagnosis.common.constants import DiagnosisConstant diff --git a/dlrover/python/master/node/local_job_manager.py b/dlrover/python/master/node/local_job_manager.py index c0a911b04..4f1027112 100644 --- a/dlrover/python/master/node/local_job_manager.py +++ b/dlrover/python/master/node/local_job_manager.py @@ -11,8 +11,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from dlrover.python.common.comm import ParallelConfig from dlrover.python.common.constants import NodeStatus, NodeType -from dlrover.python.common.grpc import ParallelConfig from dlrover.python.common.node import Node from dlrover.python.diagnosis.common.diagnosis_action import ( DiagnosisAction, diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py index d18c3c184..921ebd62d 100644 --- a/dlrover/python/master/servicer.py +++ b/dlrover/python/master/servicer.py @@ -12,28 +12,32 @@ # limitations under the License. import importlib +import json import threading import time -from abc import ABC +from abc import ABC, abstractmethod from concurrent import futures -from http.server import ThreadingHTTPServer from typing import Dict, List, Optional import grpc as grpc_lib +import tornado from dlrover.proto import elastic_training_pb2, elastic_training_pb2_grpc -from dlrover.python.common import grpc +from dlrover.python.common import comm +from dlrover.python.common.comm import BaseMessage from dlrover.python.common.constants import ( GRPC, + BasicClass, CustomMetricKeys, JobConstant, NodeEventType, NodeType, RendezvousName, TrainingExceptionLevel, - TrainingLoopStatus, BasicClass, + TrainingLoopStatus, ) from dlrover.python.common.global_context import Context +from dlrover.python.common.http_server import TornadoHTTPServer from dlrover.python.common.log import default_logger as logger from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData from dlrover.python.master.diagnosis.diagnosis_manager import DiagnosisManager @@ -104,94 +108,65 @@ def __init__( ) self._kv_store.clear() - -class HttpMasterServicer(MasterServicer): - """Master service with http implementation.""" - - def __init__( - self, - task_manager, - job_manager, - speed_monitor: SpeedMonitor, - rdzv_managers: Dict[str, RendezvousManager], - diagnosis_manager: DiagnosisManager, - job_metric_collector=None, - elastic_ps_service=None, - sync_service=None, - ): - 
super(HttpMasterServicer, self).__init__(task_manager, job_manager, speed_monitor, rdzv_managers, diagnosis_manager, job_metric_collector, elastic_ps_service, sync_service) - - -class GrpcMasterServicer(MasterServicer, elastic_training_pb2_grpc.MasterServicer): - """Master service with grpc implementation.""" - - def __init__( - self, - task_manager, - job_manager, - speed_monitor: SpeedMonitor, - rdzv_managers: Dict[str, RendezvousManager], - diagnosis_manager: DiagnosisManager, - job_metric_collector=None, - elastic_ps_service=None, - sync_service=None, - ): - super(GrpcMasterServicer, self).__init__(task_manager, job_manager, speed_monitor, rdzv_managers, diagnosis_manager, job_metric_collector, elastic_ps_service, sync_service) + @abstractmethod + def get_response(self): + """Should be implemented by subclasses.""" + pass def get(self, request, _): node_type = request.node_type node_id = request.node_id - req_message = grpc.deserialize_message(request.data) + req_message = comm.deserialize_message(request.data) - response = elastic_training_pb2.Message() + response = self.get_response() if not req_message: return response message = None - if isinstance(req_message, grpc.TaskRequest): + if isinstance(req_message, comm.TaskRequest): message = self._get_task(node_type, node_id, req_message) - elif isinstance(req_message, grpc.ShardCheckpointRequest): + elif isinstance(req_message, comm.ShardCheckpointRequest): message = self._get_shard_checkpoint(req_message) - elif isinstance(req_message, grpc.ClusterVersionRequest): + elif isinstance(req_message, comm.ClusterVersionRequest): message = self._get_cluster_version(req_message) - elif isinstance(req_message, grpc.RunningNodesRequest): + elif isinstance(req_message, comm.RunningNodesRequest): message = self._get_running_nodes() - elif isinstance(req_message, grpc.JoinRendezvousRequest): + elif isinstance(req_message, comm.JoinRendezvousRequest): message = self._join_rendezvous(req_message) - elif 
isinstance(req_message, grpc.WaitingNodeNumRequest): + elif isinstance(req_message, comm.WaitingNodeNumRequest): message = self._num_nodes_waiting(req_message.rdzv_name) - elif isinstance(req_message, grpc.NetworkReadyRequest): + elif isinstance(req_message, comm.NetworkReadyRequest): message = self._check_fault_node() - elif isinstance(req_message, grpc.StragglerExistRequest): + elif isinstance(req_message, comm.StragglerExistRequest): message = self._check_straggler() - elif isinstance(req_message, grpc.CommWorldRequest): + elif isinstance(req_message, comm.CommWorldRequest): message = self._get_comm_world(req_message) - elif isinstance(req_message, grpc.KeyValuePair): + elif isinstance(req_message, comm.KeyValuePair): message = self._kv_store_get(req_message) - elif isinstance(req_message, grpc.PsNodesRequest): + elif isinstance(req_message, comm.PsNodesRequest): message = self._query_ps_nodes() - elif isinstance(req_message, grpc.TrainingStatusRequest): + elif isinstance(req_message, comm.TrainingStatusRequest): message = self._get_training_status() - elif isinstance(req_message, grpc.ParallelConfigRequest): + elif isinstance(req_message, comm.ParallelConfigRequest): message = self._get_paral_config() - elif isinstance(req_message, grpc.CheckHardwareResetRequest): + elif isinstance(req_message, comm.CheckHardwareResetRequest): message = self._need_to_restart_training(node_type, node_id) - elif isinstance(req_message, grpc.SyncTrainingPort): + elif isinstance(req_message, comm.SyncTrainingPort): message = self._sync_training_ports(node_id, req_message) - elif isinstance(req_message, grpc.ElasticRunConfigRequest): + elif isinstance(req_message, comm.ElasticRunConfigRequest): configs = self._job_manager.get_elastic_run_configs() - message = grpc.ElasticRunConfig(configs=configs) - elif isinstance(req_message, grpc.HeartBeat): + message = comm.ElasticRunConfig(configs=configs) + elif isinstance(req_message, comm.HeartBeat): message = 
self._report_heartbeat(node_type, node_id, req_message) if message: response.data = message.serialize() return response - def _get_task(self, node_type, node_id, request: grpc.TaskRequest): + def _get_task(self, node_type, node_id, request: comm.TaskRequest): if not self._start_training_time: self._start_training_time = int(time.time()) - shard = grpc.Shard() - res = grpc.Task(shard=shard) + shard = comm.Shard() + res = comm.Task(shard=shard) ds_name = request.dataset_name dataset = self._task_manager.get_dataset(ds_name) if not dataset: @@ -212,16 +187,16 @@ def _get_task(self, node_type, node_id, request: grpc.TaskRequest): self._task_manager.reset_worker_start_task_time(node_id) return res - def _get_shard_checkpoint(self, request: grpc.ShardCheckpointRequest): - response = grpc.ShardCheckpoint() + def _get_shard_checkpoint(self, request: comm.ShardCheckpointRequest): + response = comm.ShardCheckpoint() dataset = self._task_manager.get_dataset(request.dataset_name) checkpoint = dataset.checkpoint() if checkpoint: response.content = checkpoint.to_json() return response - def _get_cluster_version(self, request: grpc.ClusterVersionRequest): - message = grpc.ClusterVersion() + def _get_cluster_version(self, request: comm.ClusterVersionRequest): + message = comm.ClusterVersion() if not self._elastic_ps_service: return message @@ -236,12 +211,12 @@ def _get_cluster_version(self, request: grpc.ClusterVersionRequest): return message def _query_ps_nodes(self): - res = grpc.PsNodes(nodes=[]) + res = comm.PsNodes(nodes=[]) training_ps: List[Node] = self._job_manager.get_next_cluster_ps() ready = self._job_manager.ready_for_new_ps_cluster() ps_failure = self._job_manager.has_ps_failure() for ps in training_ps: - ps_meta = grpc.NodeMeta() + ps_meta = comm.NodeMeta() ps_meta.type = NodeType.PS ps_meta.addr = ps.service_addr ps_meta.cpu = ps.config_resource.cpu @@ -252,10 +227,10 @@ def _query_ps_nodes(self): return res def _get_running_nodes(self): - res = 
grpc.RunningNodes(nodes=[]) + res = comm.RunningNodes(nodes=[]) nodes: List[Node] = self._job_manager.get_running_nodes() for node in nodes: - meta = grpc.NodeMeta() + meta = comm.NodeMeta() meta.type = node.type meta.addr = node.service_addr meta.cpu = node.config_resource.cpu @@ -267,7 +242,7 @@ def _get_running_nodes(self): return res def _get_training_status(self): - res = grpc.TrainingStatus() + res = comm.TrainingStatus() if self._task_manager.training_started(): res.status = TrainingLoopStatus.START else: @@ -279,7 +254,7 @@ def _check_fault_node(self): RendezvousName.NETWORK_CHECK ] nodes, reason = rdzv_manager.check_fault_node() - res = grpc.NetworkCheckResult(nodes=nodes, reason=reason) + res = comm.NetworkCheckResult(nodes=nodes, reason=reason) return res def _check_straggler(self): @@ -287,10 +262,10 @@ def _check_straggler(self): RendezvousName.NETWORK_CHECK ] nodes, reason = rdzv_manager.get_straggler() - res = grpc.NetworkCheckResult(nodes=nodes, reason=reason) + res = comm.NetworkCheckResult(nodes=nodes, reason=reason) return res - def _join_rendezvous(self, request: grpc.JoinRendezvousRequest): + def _join_rendezvous(self, request: comm.JoinRendezvousRequest): rdzv_manager = self._rdzv_managers[request.rdzv_name] node_rank = request.node_rank if node_rank == -1: # Back compatibility @@ -308,18 +283,18 @@ def _join_rendezvous(self, request: grpc.JoinRendezvousRequest): RendezvousName.ELASTIC_TRAINING ] training_manager.clear_waiting_nodes() - res = grpc.RendezvousState(round=round) + res = comm.RendezvousState(round=round) return res def _num_nodes_waiting(self, rdzv_name): waiting_num = self._rdzv_managers[rdzv_name].num_nodes_waiting() - res = grpc.RendezvousState(waiting_num=waiting_num) + res = comm.RendezvousState(waiting_num=waiting_num) return res - def _get_comm_world(self, request: grpc.CommWorldRequest): + def _get_comm_world(self, request: comm.CommWorldRequest): rdzv_manager = self._rdzv_managers[request.rdzv_name] rdzv_round, group, 
nodes = rdzv_manager.get_comm_world(request.node_id) - res = grpc.RendezvousState(world={}) + res = comm.RendezvousState(world={}) res.group = group res.round = rdzv_round for rank, meta in nodes.items(): @@ -330,76 +305,76 @@ def _get_comm_world(self, request: grpc.CommWorldRequest): self._job_metric_collector.collect_custom_data(metrics) return res - def _kv_store_get(self, request: grpc.KeyValuePair): + def _kv_store_get(self, request: comm.KeyValuePair): value = self._kv_store.get(request.key) - res = grpc.KeyValuePair(request.key, value) + res = comm.KeyValuePair(request.key, value) return res def _get_paral_config(self): res = self._job_manager.get_opt_strategy() if not res: - res = grpc.ParallelConfig() + res = comm.ParallelConfig() return res def _need_to_restart_training(self, node_type, node_id): restart = self._job_manager.verify_restarting_worker_training( node_type, node_id ) - res = grpc.ParallelConfig() + res = comm.ParallelConfig() res.restart = restart return res def report(self, request, _): node_type = request.node_type node_id = request.node_id - message = grpc.deserialize_message(request.data) + message = comm.deserialize_message(request.data) response = elastic_training_pb2.Response() if not message: return response success = False - if isinstance(message, grpc.DatasetShardParams): + if isinstance(message, comm.DatasetShardParams): success = self._collect_dataset_shard_params(message) - elif isinstance(message, grpc.ResourceStats): + elif isinstance(message, comm.ResourceStats): success = self._update_node_resource_usage( node_type, node_id, message ) - elif isinstance(message, grpc.ModelInfo): + elif isinstance(message, comm.ModelInfo): success = self._collect_model_info(message) - elif isinstance(message, grpc.GlobalStep): + elif isinstance(message, comm.GlobalStep): success = self._collect_global_step(message) - elif isinstance(message, grpc.ShardCheckpoint): + elif isinstance(message, comm.ShardCheckpoint): success = 
self._restore_shard_checkpoint(message) - elif isinstance(message, grpc.TaskResult): + elif isinstance(message, comm.TaskResult): success = self._report_task_result(message) - elif isinstance(message, grpc.ClusterVersion): + elif isinstance(message, comm.ClusterVersion): success = self._update_cluster_version(message) - elif isinstance(message, grpc.NodeAddress): + elif isinstance(message, comm.NodeAddress): success = self._update_node_address(message) - elif isinstance(message, grpc.NodeEvent): + elif isinstance(message, comm.NodeEvent): success = self._deal_with_reported_node_event(message) - elif isinstance(message, grpc.SyncJoin): + elif isinstance(message, comm.SyncJoin): success = self._join_sync(node_type, node_id, message) - elif isinstance(message, grpc.SyncFinish): + elif isinstance(message, comm.SyncFinish): success = self._sync_finished(message) - elif isinstance(message, grpc.SyncBarrier): + elif isinstance(message, comm.SyncBarrier): success = self._barrier(message) - elif isinstance(message, grpc.NodeFailure): + elif isinstance(message, comm.NodeFailure): success = self._report_failure(node_type, node_id, message) - elif isinstance(message, grpc.RendezvousParams): + elif isinstance(message, comm.RendezvousParams): success = self._report_rdzv_params(message) - elif isinstance(message, grpc.PsReady): + elif isinstance(message, comm.PsReady): success = self._ready_for_ps_relaunch() - elif isinstance(message, grpc.KeyValuePair): + elif isinstance(message, comm.KeyValuePair): success = self._kv_store_set(message) - elif isinstance(message, grpc.ParallelConfig): + elif isinstance(message, comm.ParallelConfig): success = self._report_paral_config(node_type, node_id, message) - elif isinstance(message, grpc.NodeCheckpointState): + elif isinstance(message, comm.NodeCheckpointState): success = self._sync_checkpoint(node_type, node_id, message) - elif isinstance(message, grpc.DiagnosisReportData): + elif isinstance(message, comm.DiagnosisReportData): success = 
self._report_node_diagnosis_data(message) - elif isinstance(message, grpc.Event): + elif isinstance(message, comm.Event): success = self._report_event(message) response.success = success @@ -409,7 +384,7 @@ def _ready_for_ps_relaunch(self): self._job_manager.post_ps_ready() return True - def _collect_dataset_shard_params(self, metrics: grpc.DatasetShardParams): + def _collect_dataset_shard_params(self, metrics: comm.DatasetShardParams): num_minibatches_per_task = ( metrics.num_minibatches_per_shard or _DEFAULT_NUM_MINIBATCHES_PER_SHARD @@ -443,7 +418,7 @@ def _collect_dataset_shard_params(self, metrics: grpc.DatasetShardParams): return True def _update_node_resource_usage( - self, node_type, node_id, metrics: grpc.ResourceStats + self, node_type, node_id, metrics: comm.ResourceStats ): logger.debug( f"Update resource usage for {node_type}-{node_id}," @@ -460,12 +435,12 @@ def _update_node_resource_usage( ) return True - def _collect_model_info(self, metrics: grpc.ModelInfo): + def _collect_model_info(self, metrics: comm.ModelInfo): if self._job_metric_collector: self._job_metric_collector.collect_model_metric(metrics) return True - def _collect_global_step(self, metrics: grpc.GlobalStep): + def _collect_global_step(self, metrics: comm.GlobalStep): self._speed_monitor.collect_global_step( metrics.step, metrics.timestamp ) @@ -473,7 +448,7 @@ def _collect_global_step(self, metrics: grpc.GlobalStep): self._check_start_auto_scale_worker() return True - def _restore_shard_checkpoint(self, message: grpc.ShardCheckpoint): + def _restore_shard_checkpoint(self, message: comm.ShardCheckpoint): success = self._task_manager.restore_dataset_from_checkpoint( message.content ) @@ -486,7 +461,7 @@ def _collect_runtime_stats(self): self._speed_monitor, nodes ) - def _report_task_result(self, request: grpc.TaskResult): + def _report_task_result(self, request: comm.TaskResult): success = True if request.err_message: logger.warning("Worker reported error: " + request.err_message) @@ 
-525,7 +500,7 @@ def _check_start_auto_scale_worker(self): self._job_manager.start_auto_scaling() self._start_autoscale = True - def _update_cluster_version(self, message: grpc.ClusterVersion): + def _update_cluster_version(self, message: comm.ClusterVersion): if not self._elastic_ps_service: return False @@ -539,7 +514,7 @@ def _update_cluster_version(self, message: grpc.ClusterVersion): ) return True - def _update_node_address(self, message: grpc.NodeAddress): + def _update_node_address(self, message: comm.NodeAddress): self._job_manager.update_node_service_addr( node_type=message.type, node_id=message.id, @@ -547,7 +522,7 @@ def _update_node_address(self, message: grpc.NodeAddress): ) return True - def _deal_with_reported_node_event(self, message: grpc.NodeEvent): + def _deal_with_reported_node_event(self, message: comm.NodeEvent): node = Node( node_type=message.node.type, node_id=message.node.id, @@ -572,7 +547,7 @@ def _deal_with_reported_node_event(self, message: grpc.NodeEvent): self._job_manager.process_reported_node_event(event) return True - def _join_sync(self, node_type, node_id, message: grpc.SyncJoin): + def _join_sync(self, node_type, node_id, message: comm.SyncJoin): success = False if self._sync_service: success = self._sync_service.join_sync( @@ -580,13 +555,13 @@ def _join_sync(self, node_type, node_id, message: grpc.SyncJoin): ) return success - def _sync_finished(self, message: grpc.SyncFinish): + def _sync_finished(self, message: comm.SyncFinish): success = False if self._sync_service: success = self._sync_service.sync_finished(message.sync_name) return success - def _barrier(self, message: grpc.SyncBarrier): + def _barrier(self, message: comm.SyncBarrier): if not self._sync_service: return False if message.notify: @@ -595,7 +570,7 @@ def _barrier(self, message: grpc.SyncBarrier): success = self._sync_service.barrier(message.barrier_name) return success - def _report_rdzv_params(self, message: grpc.RendezvousParams): + def 
_report_rdzv_params(self, message: comm.RendezvousParams): # Enable auto-scaling workers if elasticity is enabled. for manager in self._rdzv_managers.values(): manager.update_rdzv_params( @@ -613,7 +588,7 @@ def _report_rdzv_params(self, message: grpc.RendezvousParams): ) return True - def _report_failure(self, node_type, node_id, message: grpc.NodeFailure): + def _report_failure(self, node_type, node_id, message: comm.NodeFailure): self._job_manager.handle_training_failure( node_type, node_id, @@ -629,12 +604,12 @@ def _report_failure(self, node_type, node_id, message: grpc.NodeFailure): self._job_metric_collector.collect_custom_data(custom_data) return True - def _kv_store_set(self, message: grpc.KeyValuePair): + def _kv_store_set(self, message: comm.KeyValuePair): self._kv_store.set(message.key, message.value) return True def _report_paral_config( - self, node_type, node_id, message: grpc.ParallelConfig + self, node_type, node_id, message: comm.ParallelConfig ): if self._job_manager: logger.debug( @@ -649,14 +624,14 @@ def _report_paral_config( return True def _sync_checkpoint( - self, node_type, node_id, message: grpc.NodeCheckpointState + self, node_type, node_id, message: comm.NodeCheckpointState ): if RendezvousName.ELASTIC_TRAINING not in self._rdzv_managers: return False rdzv_manager = self._rdzv_managers[RendezvousName.ELASTIC_TRAINING] return rdzv_manager.sync_ckpt_nodes(node_id, message.step) - def _report_node_diagnosis_data(self, message: grpc.DiagnosisReportData): + def _report_node_diagnosis_data(self, message: comm.DiagnosisReportData): if self._diagnosis_manager: data_cls: Optional[DiagnosisData] = getattr( self._diagnosis_data_module, @@ -673,17 +648,17 @@ def _report_node_diagnosis_data(self, message: grpc.DiagnosisReportData): return True def _sync_training_ports( - self, node_id, message: grpc.SyncTrainingPort - ) -> grpc.SyncTrainingPort: + self, node_id, message: comm.SyncTrainingPort + ) -> comm.SyncTrainingPort: logger.info(f"try to sync 
port {message.port} from {node_id}") sync_ports: SyncNodeTrainingPorts = ( self._job_manager.sync_node_training_port(node_id, message.port) ) - return grpc.SyncTrainingPort( + return comm.SyncTrainingPort( port=sync_ports.training_port, newport=sync_ports.next_check_port ) - def _report_event(self, message: grpc.Event): + def _report_event(self, message: comm.Event): if self._error_monitor: self._error_monitor.report_event( message.event_type, @@ -695,16 +670,91 @@ def _report_event(self, message: grpc.Event): return True def _report_heartbeat( - self, node_type, node_id, message: grpc.HeartBeat - ) -> grpc.HeartbeatResponse: + self, node_type, node_id, message: comm.HeartBeat + ) -> comm.HeartbeatResponse: action = self._job_manager.collect_node_heart_beat( node_type, node_id, message.timestamp ) - grpc_action = grpc.DiagnosisAction( + grpc_action = comm.DiagnosisAction( action.__class__.__name__, action.to_json(), ) - return grpc.HeartbeatResponse(action=grpc_action) + return comm.HeartbeatResponse(action=grpc_action) + + +class HttpMasterServicer(MasterServicer, tornado.web.RequestHandler): + """Master service with http implementation.""" + + def __init__( + self, + task_manager, + job_manager, + speed_monitor: SpeedMonitor, + rdzv_managers: Dict[str, RendezvousManager], + diagnosis_manager: DiagnosisManager, + job_metric_collector=None, + elastic_ps_service=None, + sync_service=None, + error_monitor=None, + ): + super(HttpMasterServicer, self).__init__( + task_manager, + job_manager, + speed_monitor, + rdzv_managers, + diagnosis_manager, + job_metric_collector, + elastic_ps_service, + sync_service, + error_monitor, + ) + + def get_response(self): + return BaseMessage() + + def post(self, path): + data = self.get_body_argument("data", default=None) + request: BaseMessage = json.loads(data) + if path == "get": + return self.get(request, None) + elif path == "report": + return self.report(request, None) + else: + self.set_status(404) + self.write(f"No service 
found for {path}.") + + +class GrpcMasterServicer( + MasterServicer, elastic_training_pb2_grpc.MasterServicer +): + """Master service with grpc implementation.""" + + def __init__( + self, + task_manager, + job_manager, + speed_monitor: SpeedMonitor, + rdzv_managers: Dict[str, RendezvousManager], + diagnosis_manager: DiagnosisManager, + job_metric_collector=None, + elastic_ps_service=None, + sync_service=None, + error_monitor=None, + ): + super(GrpcMasterServicer, self).__init__( + task_manager, + job_manager, + speed_monitor, + rdzv_managers, + diagnosis_manager, + job_metric_collector, + elastic_ps_service, + sync_service, + error_monitor, + ) + + def get_response(self): + return elastic_training_pb2.Message() def create_master_service( @@ -720,17 +770,16 @@ def create_master_service( error_monitor=None, service_type=BasicClass.COMM_SERVICE_GRPC, max_threads=64, -) -> MasterServicer: - +): logger.info(f"Creating master {service_type} service with port: {port}") if service_type == BasicClass.COMM_SERVICE_GRPC: server = grpc_lib.server( futures.ThreadPoolExecutor(max_workers=max_threads), options=[ - ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH), + ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH), ( - "grpc.max_receive_message_length", + "grpc.max_receive_message_length", GRPC.MAX_RECEIVE_MESSAGE_LENGTH, ), ], @@ -753,4 +802,5 @@ def create_master_service( server.add_insecure_port("[::]:{}".format(port)) return server else: - pass + server = TornadoHTTPServer("localhost", port, HttpMasterServicer) + return server diff --git a/dlrover/python/master/shard/task_manager.py b/dlrover/python/master/shard/task_manager.py index aad906b29..a5bab3cb8 100644 --- a/dlrover/python/master/shard/task_manager.py +++ b/dlrover/python/master/shard/task_manager.py @@ -18,7 +18,7 @@ from typing import Dict, List from dlrover.proto import elastic_training_pb2 -from dlrover.python.common import grpc +from dlrover.python.common import comm from 
dlrover.python.common.constants import NodeType from dlrover.python.common.log import default_logger as logger from dlrover.python.master.monitor.speed_monitor import SpeedMonitor @@ -123,7 +123,7 @@ def get_dataset_task(self, node_type, node_id, dataset_name): def get_dataset(self, dataset_name): return self._datasets.get(dataset_name, None) - def report_dataset_task(self, request: grpc.TaskResult, success: bool): + def report_dataset_task(self, request: comm.TaskResult, success: bool): """Report if the task is successful or not""" task_id = request.task_id @@ -180,7 +180,7 @@ def recover_tasks(self, node_type, node_id): ] if not ids: continue - request = grpc.TaskResult() + request = comm.TaskResult() recover_tasks = [] for id in ids: request.task_id = id diff --git a/dlrover/python/master/stats/job_collector.py b/dlrover/python/master/stats/job_collector.py index c15c6ff01..1bb02ea6b 100644 --- a/dlrover/python/master/stats/job_collector.py +++ b/dlrover/python/master/stats/job_collector.py @@ -17,8 +17,8 @@ from abc import ABCMeta, abstractmethod from typing import Dict, List +from dlrover.python.common.comm import ModelInfo from dlrover.python.common.constants import MemoryUnit -from dlrover.python.common.grpc import ModelInfo from dlrover.python.common.log import default_logger as logger from dlrover.python.master.monitor.speed_monitor import SpeedMonitor from dlrover.python.master.stats.reporter import JobMeta, StatsReporter diff --git a/dlrover/python/master/stats/reporter.py b/dlrover/python/master/stats/reporter.py index b2fafae66..c801a59b7 100644 --- a/dlrover/python/master/stats/reporter.py +++ b/dlrover/python/master/stats/reporter.py @@ -18,8 +18,8 @@ from dlrover.proto import brain_pb2 from dlrover.python.brain.client import GlobalBrainClient +from dlrover.python.common.comm import ModelInfo from dlrover.python.common.constants import ReporterType -from dlrover.python.common.grpc import ModelInfo from dlrover.python.common.log import default_logger 
as logger from dlrover.python.common.singleton import Singleton from dlrover.python.master.stats.training_metrics import ( diff --git a/dlrover/python/tests/test_agent_config_tuner.py b/dlrover/python/tests/test_agent_config_tuner.py index 15333f485..e0904048d 100644 --- a/dlrover/python/tests/test_agent_config_tuner.py +++ b/dlrover/python/tests/test_agent_config_tuner.py @@ -15,12 +15,12 @@ import os import unittest -from dlrover.python.common.constants import ConfigPath -from dlrover.python.common.grpc import ( +from dlrover.python.common.comm import ( DataLoaderConfig, OptimizerConfig, ParallelConfig, ) +from dlrover.python.common.constants import ConfigPath from dlrover.python.elastic_agent.config.paral_config_tuner import ( ParalConfigTuner, ) diff --git a/dlrover/python/tests/test_agent_monitor.py b/dlrover/python/tests/test_agent_monitor.py index 111c5126f..eb0669ad3 100644 --- a/dlrover/python/tests/test_agent_monitor.py +++ b/dlrover/python/tests/test_agent_monitor.py @@ -17,8 +17,8 @@ import unittest from unittest.mock import patch +from dlrover.python.common.comm import GPUStats from dlrover.python.common.constants import NodeEnv -from dlrover.python.common.grpc import GPUStats from dlrover.python.elastic_agent.master_client import ( MasterClient, build_master_client, diff --git a/dlrover/python/tests/test_common_util.py b/dlrover/python/tests/test_common_util.py index 3235adbae..6eef07c97 100644 --- a/dlrover/python/tests/test_common_util.py +++ b/dlrover/python/tests/test_common_util.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import socket import unittest import dlrover.python.util.common_util as cu @@ -23,3 +24,36 @@ def test_get_dlrover_version(self): def test_is_port_in_use(self): self.assertFalse(cu.is_port_in_use(65530)) + + def test_find_free_port(self): + port = cu.find_free_port() + self.assertTrue(port > 0) + port = cu.find_free_port_in_range(50001, 65535) + self.assertTrue(port > 50000) + + port = cu.find_free_port_in_range(50001, 65535, False) + self.assertTrue(port > 50000) + + ports = [] + for i in range(20): + ports.append(20000 + i) + port = cu.find_free_port_in_set(ports) + self.assertTrue(port in ports) + + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 10000)) + with self.assertRaises(RuntimeError): + cu.find_free_port_in_set([10000]) + with self.assertRaises(RuntimeError): + cu.find_free_port_in_range(10000, 10000) + s.close() + + def test_find_free_port_for_hccl(self): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 64003)) + port = cu.find_free_port_for_hccl() + self.assertEqual(port, 64004) + + +if __name__ == "__main__": + unittest.main() diff --git a/dlrover/python/tests/test_grpc_utils.py b/dlrover/python/tests/test_grpc_utils.py index 51e2bfa89..7d4b5edbf 100644 --- a/dlrover/python/tests/test_grpc_utils.py +++ b/dlrover/python/tests/test_grpc_utils.py @@ -11,44 +11,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import socket import unittest -from dlrover.python.common.grpc import ( +from dlrover.python.common.comm import ( Message, addr_connected, deserialize_message, - find_free_port, - find_free_port_for_hccl, - find_free_port_in_range, - find_free_port_in_set, ) class GRPCUtilTest(unittest.TestCase): - def test_find_free_port(self): - port = find_free_port() - self.assertTrue(port > 0) - port = find_free_port_in_range(50001, 65535) - self.assertTrue(port > 50000) - - port = find_free_port_in_range(50001, 65535, False) - self.assertTrue(port > 50000) - - ports = [] - for i in range(20): - ports.append(20000 + i) - port = find_free_port_in_set(ports) - self.assertTrue(port in ports) - - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("", 10000)) - with self.assertRaises(RuntimeError): - find_free_port_in_set([10000]) - with self.assertRaises(RuntimeError): - find_free_port_in_range(10000, 10000) - s.close() - def test_addr_connected(self): connected = addr_connected("") self.assertFalse(connected) @@ -63,12 +35,6 @@ def test_deserialize_message(self): de_message = deserialize_message(b"") self.assertIsNone(de_message) - def test_find_free_port_for_hccl(self): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("", 64003)) - port = find_free_port_for_hccl() - self.assertEqual(port, 64004) - if __name__ == "__main__": unittest.main() diff --git a/dlrover/python/tests/test_http_server.py b/dlrover/python/tests/test_http_server.py index 1ac3c0b5b..2a8a4a7da 100644 --- a/dlrover/python/tests/test_http_server.py +++ b/dlrover/python/tests/test_http_server.py @@ -18,44 +18,47 @@ import requests import tornado - -from dlrover.python.common.http_server import CustomHTTPServer from util.common_util import is_port_in_use +from dlrover.python.common.http_server import TornadoHTTPServer + TEST_SERVER_ADDR = "localhost" TEST_SERVER_PORT = 8000 class HttpServerClientTest(unittest.TestCase): - def setUp(self): self.server = None def tearDown(self): if 
self.server is not None: - self.server.stop_serving() + self.server.stop() self.server = None - def test_server_basic(self): - self.server = CustomHTTPServer(TEST_SERVER_ADDR, TEST_SERVER_PORT, TestRequestHandler) + def test_tornado_server_basic(self): + self.server = TornadoHTTPServer( + TEST_SERVER_ADDR, TEST_SERVER_PORT, TestRequestHandler + ) self.assertIsNotNone(self.server) self.assertFalse(is_port_in_use(TEST_SERVER_PORT)) self.assertFalse(self.server.is_serving()) - self.server.start_serving() + self.server.start() self.assertTrue(self.server.is_serving()) self.assertTrue(is_port_in_use(TEST_SERVER_PORT)) - self.server.start_serving() + self.server.start() self.assertTrue(self.server.is_serving()) active_threads_name = [t.name for t in threading.enumerate()] - self.assertIn(CustomHTTPServer.SERVING_THREAD_NAME, active_threads_name) + self.assertIn( + TornadoHTTPServer.SERVING_THREAD_NAME, active_threads_name + ) time.sleep(1) # test get request self._test_get_request() - self.server.stop_serving() + self.server.stop() self.assertFalse(self.server.is_serving()) def _test_get_request(self): @@ -68,17 +71,17 @@ def _test_get_request(self): raise e def test_server_concurrency(self): - self.server = CustomHTTPServer(TEST_SERVER_ADDR, TEST_SERVER_PORT, TestRequestHandler) - self.server.start_serving() + self.server = TornadoHTTPServer( + TEST_SERVER_ADDR, TEST_SERVER_PORT, TestRequestHandler + ) + self.server.start() futures = [] result_num = 0 client_size = 100 with ThreadPoolExecutor(max_workers=client_size) as executor: for i in range(client_size): - futures.append( - executor.submit(self._test_get_request) - ) + futures.append(executor.submit(self._test_get_request)) for future in as_completed(futures): if future.result().status_code == 200: @@ -86,7 +89,7 @@ def test_server_concurrency(self): self.assertEqual(len(futures), client_size) self.assertEqual(result_num, client_size) - self.server.stop_serving() + self.server.stop() class 
TestRequestHandler(tornado.web.RequestHandler): diff --git a/dlrover/python/tests/test_job_manager.py b/dlrover/python/tests/test_job_manager.py index e052e5f36..bfcbed7fc 100644 --- a/dlrover/python/tests/test_job_manager.py +++ b/dlrover/python/tests/test_job_manager.py @@ -22,6 +22,12 @@ from kubernetes import client from dlrover.proto import elastic_training_pb2 +from dlrover.python.common.comm import ( + DataLoaderConfig, + GPUStats, + OptimizerConfig, + ParallelConfig, +) from dlrover.python.common.constants import ( DistributionStrategy, ElasticJobLabel, @@ -32,12 +38,6 @@ NodeType, TrainingExceptionLevel, ) -from dlrover.python.common.grpc import ( - DataLoaderConfig, - GPUStats, - OptimizerConfig, - ParallelConfig, -) from dlrover.python.common.node import NodeGroupResource, NodeResource from dlrover.python.diagnosis.common.diagnosis_action import ( EventAction, diff --git a/dlrover/python/tests/test_master_client.py b/dlrover/python/tests/test_master_client.py index 9c309a026..c790170ec 100644 --- a/dlrover/python/tests/test_master_client.py +++ b/dlrover/python/tests/test_master_client.py @@ -14,16 +14,17 @@ import json import time import unittest +from typing import List from unittest import mock -from dlrover.python.common import grpc +from dlrover.python.common import comm +from dlrover.python.common.comm import DiagnosisAction, HeartbeatResponse from dlrover.python.common.constants import ( NodeEventType, NodeType, RendezvousName, TrainingExceptionLevel, ) -from dlrover.python.common.grpc import DiagnosisAction, HeartbeatResponse from dlrover.python.diagnosis.common.diagnosis_action import ( EventAction, NoAction, @@ -47,8 +48,8 @@ def test_open_channel(self): self._master_client.open_channel() def test_report_used_resource(self): - gpu_stats: list[grpc.GPUStats] = [ - grpc.GPUStats( + gpu_stats: List[comm.GPUStats] = [ + comm.GPUStats( index=0, total_memory_mb=24000, used_memory_mb=4000, @@ -121,7 +122,7 @@ def test_report(self): ts = 
int(time.time()) self._master_client.report_global_step(100, ts) - model_info = grpc.ModelInfo() + model_info = comm.ModelInfo() self._master_client.report_model_info(model_info) success = self._master_client.join_sync("test-sync") @@ -160,7 +161,7 @@ def test_get(self): config = self._master_client.get_paral_config() if config: - self.assertIsInstance(config, grpc.ParallelConfig) + self.assertIsInstance(config, comm.ParallelConfig) def test_num_nodes_waiting(self): rdzv_name = RendezvousName.ELASTIC_TRAINING diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py index 667c7693f..c52731a45 100644 --- a/dlrover/python/tests/test_servicer.py +++ b/dlrover/python/tests/test_servicer.py @@ -18,7 +18,8 @@ import ray from dlrover.proto import elastic_training_pb2 -from dlrover.python.common import env_utils, grpc +from dlrover.python.common import comm, env_utils +from dlrover.python.common.comm import GPUStats from dlrover.python.common.constants import ( NodeEventType, NodeStatus, @@ -26,7 +27,6 @@ PSClusterVersionType, RendezvousName, ) -from dlrover.python.common.grpc import GPUStats from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric from dlrover.python.master.diagnosis.diagnosis_manager import DiagnosisManager from dlrover.python.master.elastic_training.elastic_ps import ElasticPsService @@ -98,14 +98,14 @@ def tearDown(self) -> None: def test_query_running_nodes(self): request = elastic_training_pb2.Message() - message = grpc.RunningNodesRequest() + message = comm.RunningNodesRequest() request.data = message.serialize() res = self.servicer.get(request, None) - ret: grpc.RunningNodes = grpc.deserialize_message(res.data) + ret: comm.RunningNodes = comm.deserialize_message(res.data) self.assertEqual(len(ret.nodes), 3) def test_dataset_service(self): - request = grpc.DatasetShardParams() + request = comm.DatasetShardParams() request.batch_size = 10 request.num_epochs = 1 request.dataset_size = 1000 @@ 
-119,19 +119,19 @@ def test_dataset_service(self): collector = self.job_metric_collector._stats_reporter self.assertEqual(collector._dataset_metric.get_size(), 1000) - request = grpc.TaskRequest("test") - task: grpc.Task = self.servicer._get_task(NodeType.WORKER, 0, request) + request = comm.TaskRequest("test") + task: comm.Task = self.servicer._get_task(NodeType.WORKER, 0, request) self.assertEqual(task.task_id, 0) self.assertEqual(task.shard.start, 0) self.assertEqual(task.shard.end, 100) - request = grpc.TaskResult(dataset_name="test", task_id=0) + request = comm.TaskResult(dataset_name="test", task_id=0) request.task_id = 0 request.dataset_name = "test" self.servicer._report_task_result(request) self.assertEqual(len(self.task_manager._datasets["test"].doing), 0) - request = grpc.ShardCheckpointRequest("test") + request = comm.ShardCheckpointRequest("test") request.dataset_name = "test" checkpoint = self.servicer._get_shard_checkpoint(request) @@ -141,7 +141,7 @@ def test_dataset_service(self): def test_metric_service(self): self.job_manager._init_nodes() self.job_manager._init_job_auto_scaler() - request = grpc.ResourceStats(gpu_stats=[]) + request = comm.ResourceStats(gpu_stats=[]) request.memory = 4096 request.cpu = 2 gpu_stats: list[GPUStats] = [ @@ -153,7 +153,7 @@ def test_metric_service(self): ) ] for gpu in gpu_stats: - gpu_stats_message = grpc.GPUStats() + gpu_stats_message = comm.GPUStats() gpu_stats_message.index = gpu.index gpu_stats_message.total_memory_mb = gpu.total_memory_mb gpu_stats_message.used_memory_mb = gpu.used_memory_mb @@ -162,8 +162,8 @@ def test_metric_service(self): self.servicer._update_node_resource_usage(NodeType.WORKER, 0, request) self.servicer._update_node_resource_usage(NodeType.PS, 0, request) - request = grpc.ModelInfo( - tensor_stats=grpc.TensorStats(), op_stats=grpc.OpStats() + request = comm.ModelInfo( + tensor_stats=comm.TensorStats(), op_stats=comm.OpStats() ) request.tensor_stats.variable_count = 100 
request.tensor_stats.total_variable_size = 10000 @@ -184,7 +184,7 @@ def test_metric_service(self): ps0.status = NodeStatus.RUNNING self.job_context.update_job_node(ps0) - request = grpc.GlobalStep() + request = comm.GlobalStep() self.task_manager._speed_monitor.add_running_worker(NodeType.WORKER, 0) self.task_manager._speed_monitor.set_target_worker_num(1) ts = int(time.time()) @@ -235,30 +235,30 @@ def test_get(self): self.assertEqual(response.data, b"") def test_get_cluster_version(self): - message = grpc.ClusterVersionRequest(NodeType.WORKER, 0, "local") + message = comm.ClusterVersionRequest(NodeType.WORKER, 0, "local") request = elastic_training_pb2.Message() request.data = message.serialize() response = self.servicer.get(request, None) - res_msg = grpc.deserialize_message(response.data) + res_msg = comm.deserialize_message(response.data) self.assertEqual(res_msg.version, 0) - message = grpc.ClusterVersionRequest(NodeType.PS, 0, "local") + message = comm.ClusterVersionRequest(NodeType.PS, 0, "local") request = elastic_training_pb2.Message() request.data = message.serialize() response = self.servicer.get(request, None) - res_msg = grpc.deserialize_message(response.data) + res_msg = comm.deserialize_message(response.data) self.assertEqual(res_msg.version, 0) def test_get_training_status(self): - message = grpc.TrainingStatusRequest() + message = comm.TrainingStatusRequest() request = elastic_training_pb2.Message() request.data = message.serialize() response = self.servicer.get(request, None) - res_msg: grpc.TrainingStatus = grpc.deserialize_message(response.data) + res_msg: comm.TrainingStatus = comm.deserialize_message(response.data) self.assertEqual(res_msg.status, 3) def test_num_nodes_waiting(self): - message = grpc.WaitingNodeNumRequest( + message = comm.WaitingNodeNumRequest( 0, 8, RendezvousName.ELASTIC_TRAINING ) request = elastic_training_pb2.Message() @@ -267,7 +267,7 @@ def test_num_nodes_waiting(self): RendezvousName.ELASTIC_TRAINING 
]._waiting_nodes = {0: 8} response = self.servicer.get(request, None) - res_msg: grpc.RendezvousState = grpc.deserialize_message(response.data) + res_msg: comm.RendezvousState = comm.deserialize_message(response.data) self.assertEqual(res_msg.waiting_num, 1) def test_report(self): @@ -278,8 +278,8 @@ def test_report(self): def test_report_task_result(self): request = elastic_training_pb2.Message() - message = grpc.TaskResult("test", 0, "error") - dataset_params = grpc.DatasetShardParams( + message = comm.TaskResult("test", 0, "error") + dataset_params = comm.DatasetShardParams( batch_size=64, num_epochs=1, dataset_size=10000, @@ -292,7 +292,7 @@ def test_report_task_result(self): response = self.servicer.report(request, None) self.assertFalse(response.success, False) - message = grpc.TaskResult("test", 0, "") + message = comm.TaskResult("test", 0, "") request.data = message.serialize() self.servicer._start_autoscale = False self.servicer._speed_monitor.completed_global_step == 0 @@ -303,7 +303,7 @@ def test_report_task_result(self): def test_update_cluster_version(self): request = elastic_training_pb2.Message() - message = grpc.ClusterVersion( + message = comm.ClusterVersion( NodeType.WORKER, 0, PSClusterVersionType.LOCAL, 1 ) request.data = message.serialize() @@ -313,7 +313,7 @@ def test_update_cluster_version(self): self.servicer._elastic_ps_service._worker_local_version[0], 1 ) - message = grpc.ClusterVersion( + message = comm.ClusterVersion( NodeType.WORKER, 0, PSClusterVersionType.RESTORED, 1 ) request.data = message.serialize() @@ -323,7 +323,7 @@ def test_update_cluster_version(self): self.servicer._elastic_ps_service._worker_restored_version[0], 1 ) - message = grpc.ClusterVersion( + message = comm.ClusterVersion( NodeType.PS, 0, PSClusterVersionType.GLOBAL, 1 ) request.data = message.serialize() @@ -333,7 +333,7 @@ def test_update_cluster_version(self): def test_sync(self): request = elastic_training_pb2.Message() - message = grpc.SyncJoin("test") + 
message = comm.SyncJoin("test") request.data = message.serialize() request.node_type = NodeType.WORKER request.node_id = 0 @@ -342,7 +342,7 @@ def test_sync(self): sync_obj = self.servicer._sync_service._sync_objs_target["test"] self.assertEqual(len(sync_obj), 2) - message = grpc.SyncFinish("test") + message = comm.SyncFinish("test") request.data = message.serialize() response = self.servicer.report(request, None) self.assertFalse(response.success) @@ -351,7 +351,7 @@ def test_sync(self): response = self.servicer.report(request, None) self.assertTrue(response.success) - message = grpc.SyncBarrier("test") + message = comm.SyncBarrier("test") request.data = message.serialize() response = self.servicer.report(request, None) self.assertFalse(response.success) @@ -361,41 +361,41 @@ def test_sync(self): self.assertTrue(response.success) def test_get_paral_config(self): - message = grpc.ParallelConfigRequest() + message = comm.ParallelConfigRequest() request = elastic_training_pb2.Message() request.data = message.serialize() self.servicer.report(request, None) response = self.servicer.get(request, None) - config = grpc.deserialize_message(response.data) + config = comm.deserialize_message(response.data) if config: - self.assertIsInstance(config, grpc.ParallelConfig) + self.assertIsInstance(config, comm.ParallelConfig) def test_get_straggler(self): - message = grpc.StragglerExistRequest() + message = comm.StragglerExistRequest() request = elastic_training_pb2.Message() request.data = message.serialize() self.servicer.report(request, None) response = self.servicer.get(request, None) - config = grpc.deserialize_message(response.data) - self.assertIsInstance(config, grpc.NetworkCheckResult) + config = comm.deserialize_message(response.data) + self.assertIsInstance(config, comm.NetworkCheckResult) def test_check_hardware_reset(self): - message = grpc.CheckHardwareResetRequest() + message = comm.CheckHardwareResetRequest() request = elastic_training_pb2.Message() request.data = 
message.serialize() response = self.servicer.get(request, None) - config = grpc.deserialize_message(response.data) - self.assertIsInstance(config, grpc.ParallelConfig) + config = comm.deserialize_message(response.data) + self.assertIsInstance(config, comm.ParallelConfig) self.assertFalse(config.restart) def test_join_rendezvous(self): - request = grpc.JoinRendezvousRequest( + request = comm.JoinRendezvousRequest( 0, 8, RendezvousName.ELASTIC_TRAINING ) self.servicer._join_rendezvous(request) res = self.servicer._num_nodes_waiting(RendezvousName.ELASTIC_TRAINING) self.assertEqual(res.waiting_num, 1) - request = grpc.JoinRendezvousRequest( + request = comm.JoinRendezvousRequest( 0, 8, RendezvousName.NETWORK_CHECK ) self.servicer._join_rendezvous(request) @@ -405,7 +405,7 @@ def test_join_rendezvous(self): def test_report_heartbeat(self): request = elastic_training_pb2.Message() ts = int(time.time()) - message = grpc.HeartBeat(ts) + message = comm.HeartBeat(ts) request.data = message.serialize() request.node_type = NodeType.WORKER request.node_id = 0 @@ -415,7 +415,7 @@ def test_report_heartbeat(self): self.assertEqual(worker0.heartbeat_time, ts) def test_sync_checkpoint(self): - message = grpc.NodeCheckpointState(step=100) + message = comm.NodeCheckpointState(step=100) et_name = RendezvousName.ELASTIC_TRAINING rdzv_manager = self.servicer._rdzv_managers[et_name] rdzv_manager._latest_rdzv_nodes = [0, 1] @@ -433,7 +433,7 @@ def test_report_node_diagnosis_data(self): is_final_result=True, ) - request = grpc.DiagnosisReportData( + request = comm.DiagnosisReportData( test.__class__.__name__, test.to_json(), test.node_rank, @@ -441,7 +441,7 @@ def test_report_node_diagnosis_data(self): self.assertTrue(self.servicer._report_node_diagnosis_data(request)) def test_deal_with_reported_node_event(self): - request = grpc.NodeEvent(node=grpc.NodeMeta()) + request = comm.NodeEvent(node=comm.NodeMeta()) task_id = 1 task_type = NodeType.PS request.node.type = task_type @@ -522,7 
+522,7 @@ def tearDown(self) -> None: self.job_context.clear_job_nodes() def test_update_node_addr(self): - request = grpc.NodeMeta() + request = comm.NodeMeta() task_id = 1 task_type = NodeType.PS addr = "localhost:5001" diff --git a/dlrover/python/tests/test_strategy_generator.py b/dlrover/python/tests/test_strategy_generator.py index b3cf20fe1..4cbfa179b 100644 --- a/dlrover/python/tests/test_strategy_generator.py +++ b/dlrover/python/tests/test_strategy_generator.py @@ -16,12 +16,12 @@ from typing import Dict, List from unittest.mock import patch -from dlrover.python.common.constants import NodeType -from dlrover.python.common.grpc import ( +from dlrover.python.common.comm import ( DataLoaderConfig, GPUStats, OptimizerConfig, ) +from dlrover.python.common.constants import NodeType from dlrover.python.common.node import Node from dlrover.python.master.hyperparams.simple_strategy_generator import ( SimpleStrategyGenerator, diff --git a/dlrover/python/tests/test_task_manager.py b/dlrover/python/tests/test_task_manager.py index f8eba0dde..6f5d6f16b 100644 --- a/dlrover/python/tests/test_task_manager.py +++ b/dlrover/python/tests/test_task_manager.py @@ -15,8 +15,8 @@ import unittest from dlrover.proto import elastic_training_pb2 +from dlrover.python.common.comm import TaskResult from dlrover.python.common.constants import NodeType -from dlrover.python.common.grpc import TaskResult from dlrover.python.master.shard.task_manager import DatasetShardCheckpoint from dlrover.python.tests.test_utils import ( create_task_manager, diff --git a/dlrover/python/tests/test_utils.py b/dlrover/python/tests/test_utils.py index ed9e426ba..8b745a64f 100644 --- a/dlrover/python/tests/test_utils.py +++ b/dlrover/python/tests/test_utils.py @@ -27,7 +27,6 @@ NodeType, PlatformType, ) -from dlrover.python.common.grpc import find_free_port from dlrover.python.common.node import NodeGroupResource, NodeResource from dlrover.python.master.local_master import LocalJobMaster from 
dlrover.python.master.monitor.speed_monitor import SpeedMonitor @@ -35,6 +34,7 @@ from dlrover.python.master.shard.task_manager import TaskManager from dlrover.python.scheduler.job import JobArgs, LocalJobArgs, NodeArgs from dlrover.python.scheduler.kubernetes import k8sClient +from dlrover.python.util.common_util import find_free_port WITH_TO_DELETED = "WITH_TO_DELETED" diff --git a/dlrover/python/util/common_util.py b/dlrover/python/util/common_util.py index 1ae36cbe6..20de4eee3 100644 --- a/dlrover/python/util/common_util.py +++ b/dlrover/python/util/common_util.py @@ -12,10 +12,14 @@ # limitations under the License. import importlib.metadata +import random import re import socket +from contextlib import closing import dlrover.python.util.file_util as fu +from dlrover.python.common.constants import AscendConstants +from dlrover.python.common.log import default_logger as logger def get_dlrover_version(): @@ -68,3 +72,71 @@ def is_port_in_use(port=0) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: result = sock.connect_ex(("localhost", int(port))) return result == 0 + + +def find_free_port(port=0): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", port)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +def find_free_port_in_range(start=0, end=65535, random_port=True): + """Find a free port from a range.""" + bind_ports = set() + while True: + if random_port: + port = random.randint(start, end) + else: + port = start + len(bind_ports) + if port in bind_ports: + continue + try: + return find_free_port(port) + except OSError: + logger.warning(f"Socket creation attempt failed with {port}.") + bind_ports.add(port) + if len(bind_ports) == end - start + 1: + break + raise RuntimeError(f"Fail to find a free port in [{start}, {end})") + + +def find_free_port_in_set(ports): + for port in ports: + try: + return find_free_port(port) + except OSError: + 
logger.warning(f"Socket creation attempt failed with {port}.") + raise RuntimeError(f"Fail to find a free port in {ports}") + + +def find_free_port_for_hccl( + start=AscendConstants.HCCL_PORT_START_DEFAULT, +) -> int: + max_port = 65500 + cur_start = start + end = start + 10000 + if end > max_port: + end = max_port + logger.info(f"Try to find available port for hccl from {start}") + checking_port = 0 + while True: + try: + cur_end = cur_start + AscendConstants.NPU_PER_NODE + for port in range(cur_start, cur_end): + checking_port = port + find_free_port(port) + logger.info(f"Find available port start from: {cur_start}") + break + except OSError: + logger.warning( + f"Target port has already been used: {checking_port}." + ) + if checking_port > 0: + cur_start = checking_port + 1 + else: + cur_start = cur_start + AscendConstants.NPU_PER_NODE + if cur_start > end: + cur_start = 0 + break + return cur_start diff --git a/dlrover/trainer/tests/torch/checkpoint_egine_test.py b/dlrover/trainer/tests/torch/checkpoint_egine_test.py index 210de9a53..55e9242a9 100644 --- a/dlrover/trainer/tests/torch/checkpoint_egine_test.py +++ b/dlrover/trainer/tests/torch/checkpoint_egine_test.py @@ -26,7 +26,6 @@ import torch.optim as optim from dlrover.python.common.constants import CheckpointConstant, NodeEnv -from dlrover.python.common.grpc import find_free_port from dlrover.python.common.multi_process import clear_sock_dir from dlrover.python.common.storage import PosixDiskStorage from dlrover.python.elastic_agent.master_client import ( @@ -41,6 +40,7 @@ TempDirCheckpointSaver, ) from dlrover.python.tests.test_utils import start_local_master +from dlrover.python.util.common_util import find_free_port from dlrover.trainer.torch.flash_checkpoint.deepspeed_engine import ( DeepSpeedCheckpointEngine, ) diff --git a/dlrover/trainer/tests/torch/elastic_dataloader_test.py b/dlrover/trainer/tests/torch/elastic_dataloader_test.py index 5fb0d687d..ba6850c7e 100644 --- 
a/dlrover/trainer/tests/torch/elastic_dataloader_test.py +++ b/dlrover/trainer/tests/torch/elastic_dataloader_test.py @@ -18,7 +18,7 @@ import numpy as np from torch.utils.data import Dataset -from dlrover.python.common.grpc import ParallelConfig +from dlrover.python.common.comm import ParallelConfig from dlrover.trainer.torch.elastic.dataloader import ElasticDataLoader diff --git a/dlrover/trainer/tests/torch/elastic_test.py b/dlrover/trainer/tests/torch/elastic_test.py index 2db720d90..934d5b233 100644 --- a/dlrover/trainer/tests/torch/elastic_test.py +++ b/dlrover/trainer/tests/torch/elastic_test.py @@ -21,8 +21,8 @@ import torch from torch.utils.data import Dataset +from dlrover.python.common.comm import ParallelConfig from dlrover.python.common.constants import ConfigPath -from dlrover.python.common.grpc import ParallelConfig from dlrover.trainer.torch.elastic.dataloader import ElasticDataLoader from dlrover.trainer.torch.elastic.trainer import ( ElasticTrainer, diff --git a/dlrover/trainer/tests/torch/fsdp_ckpt_test.py b/dlrover/trainer/tests/torch/fsdp_ckpt_test.py index ecbdbab5b..6dbcda8dc 100644 --- a/dlrover/trainer/tests/torch/fsdp_ckpt_test.py +++ b/dlrover/trainer/tests/torch/fsdp_ckpt_test.py @@ -46,7 +46,6 @@ WriteItemType, ) -from dlrover.python.common import grpc from dlrover.python.common.constants import CheckpointConstant from dlrover.python.common.multi_process import SharedMemory, clear_sock_dir from dlrover.python.common.storage import PosixDiskStorage @@ -59,6 +58,7 @@ SharedMemoryHandler, ) from dlrover.python.tests.test_utils import start_local_master +from dlrover.python.util.common_util import find_free_port from dlrover.trainer.torch.flash_checkpoint.fsdp import FsdpShardCheckpointer from dlrover.trainer.torch.flash_checkpoint.fsdp_engine import ( FileReader, @@ -170,7 +170,7 @@ def setUp(self): AsyncCheckpointSaver.start_async_saving_ckpt() os.environ["LOCAL_RANK"] = "0" os.environ["LOCAL_WORLD_SIZE"] = "1" - port = 
grpc.find_free_port() + port = find_free_port() set_torch_dist_env(port) dist.init_process_group(backend="gloo") diff --git a/dlrover/trainer/torch/elastic_run.py b/dlrover/trainer/torch/elastic_run.py index 7761d104f..bfb67518b 100644 --- a/dlrover/trainer/torch/elastic_run.py +++ b/dlrover/trainer/torch/elastic_run.py @@ -106,7 +106,7 @@ ) import dlrover.python.util.common_util as cu -from dlrover.python.common import env_utils, grpc +from dlrover.python.common import comm, env_utils from dlrover.python.common.constants import ( Accelerators, NodeEnv, @@ -247,7 +247,7 @@ def _launch_dlrover_local_master(master_addr, job_name, node_num): logger.info(f"Start dlrover master with addr {master_addr}") if not master_addr: host = "127.0.0.1" - port = grpc.find_free_port() + port = cu.find_free_port() else: host = master_addr.split(":")[0] port = int(master_addr.split(":")[1]) @@ -418,7 +418,7 @@ def run(args): timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") job_name = os.getenv(NodeEnv.JOB_NAME, f"standalone_{timestamp}") os.environ[NodeEnv.TORCHELASTIC_RUN_ID] = job_name - dlrover_master_ready = grpc.addr_connected(master_addr) + dlrover_master_ready = comm.addr_connected(master_addr) _, max_nodes = parse_min_max_nnodes(args.nnodes) if not dlrover_master_ready and node_rank == 0: # Only start the dlrover master on the rank-0 node. 
From c253c0e9310a6814e732c2872a25fc0f3af48421 Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Wed, 4 Dec 2024 20:08:20 +0800 Subject: [PATCH 04/20] add deps --- docker/ci.dockerfile | 2 +- scripts/ci_install.sh | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docker/ci.dockerfile b/docker/ci.dockerfile index 34b89a126..ff4d3efe8 100644 --- a/docker/ci.dockerfile +++ b/docker/ci.dockerfile @@ -34,7 +34,7 @@ RUN /install-go.bash ${GO_MIRROR_URL} && rm /install-go.bash COPY docker/scripts/install-protobuf.bash / RUN /install-protobuf.bash && rm /install-protobuf.bash -# Install Pre-commit +# Install python deps RUN pip install pre-commit pytest kubernetes grpcio-tools psutil \ deprecated -i https://mirrors.aliyun.com/pypi/simple/ diff --git a/scripts/ci_install.sh b/scripts/ci_install.sh index 44d752abf..26fa7191b 100644 --- a/scripts/ci_install.sh +++ b/scripts/ci_install.sh @@ -21,6 +21,7 @@ pip install pyhocon pip install pytest-cov pip install pytest-ordering pip install packaging +pip install tornado pip install tensorflow==2.13.0 pip install deepspeed==0.12.6 pip install accelerate==0.29.2 From cb57a9fcb9a560d445cf63f43bca28ab213fb789 Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Thu, 5 Dec 2024 15:42:42 +0800 Subject: [PATCH 05/20] fix ut --- .github/actions/dlrover-python-test/action.yml | 2 +- .../dlrover-system-test-criteo-deeprec/action.yaml | 2 +- .github/actions/dlrover-system-test-deepfm/action.yaml | 2 +- .github/actions/dlrover-system-test-tf2/action.yaml | 3 +-- dlrover/python/tests/test_http_server.py | 2 +- scripts/ci_install.sh | 8 +++++++- 6 files changed, 12 insertions(+), 7 deletions(-) diff --git a/.github/actions/dlrover-python-test/action.yml b/.github/actions/dlrover-python-test/action.yml index c4530d968..7dfb00b21 100644 --- a/.github/actions/dlrover-python-test/action.yml +++ b/.github/actions/dlrover-python-test/action.yml @@ -7,7 +7,7 @@ runs: args: - "/bin/bash" - "-c" - - "sh scripts/ci_install.sh && 
python -m grpc_tools.protoc -I. \ + - "sh scripts/ci_install.sh basic && python -m grpc_tools.protoc -I. \ dlrover/proto/*.proto --python_out=. --grpc_python_out=. \ && ROLE_NAME=dlrover-trainer \ python -m pytest --durations=10 dlrover/python/tests dlrover/trainer/tests \ diff --git a/.github/actions/dlrover-system-test-criteo-deeprec/action.yaml b/.github/actions/dlrover-system-test-criteo-deeprec/action.yaml index 450b63ae6..1ca373736 100644 --- a/.github/actions/dlrover-system-test-criteo-deeprec/action.yaml +++ b/.github/actions/dlrover-system-test-criteo-deeprec/action.yaml @@ -7,7 +7,7 @@ runs: args: - "/bin/bash" - "-c" - - " python -m grpc_tools.protoc -I. \ + - "sh scripts/ci_install.sh basic && python -m grpc_tools.protoc -I. \ dlrover/proto/*.proto --python_out=. --grpc_python_out=. \ && export PYTHONPATH=`pwd` \ && cd examples/tensorflow/criteo_deeprec\ diff --git a/.github/actions/dlrover-system-test-deepfm/action.yaml b/.github/actions/dlrover-system-test-deepfm/action.yaml index 01c74b9ef..08264e373 100644 --- a/.github/actions/dlrover-system-test-deepfm/action.yaml +++ b/.github/actions/dlrover-system-test-deepfm/action.yaml @@ -7,7 +7,7 @@ runs: args: - "/bin/bash" - "-c" - - " python -m grpc_tools.protoc -I. \ + - "sh scripts/ci_install.sh basic && python -m grpc_tools.protoc -I. \ dlrover/proto/*.proto --python_out=. --grpc_python_out=. \ && pip install deepctr deprecated\ && export PYTHONPATH=`pwd` \ diff --git a/.github/actions/dlrover-system-test-tf2/action.yaml b/.github/actions/dlrover-system-test-tf2/action.yaml index 9f2ff0d3e..26a86d115 100644 --- a/.github/actions/dlrover-system-test-tf2/action.yaml +++ b/.github/actions/dlrover-system-test-tf2/action.yaml @@ -7,8 +7,7 @@ runs: args: - "/bin/bash" - "-c" - - "pip install protobuf==3.20 kubernetes grpcio-tools psutil deprecated\ -&& python -m grpc_tools.protoc -I. \ + - "sh scripts/ci_install.sh basic && python -m grpc_tools.protoc -I. \ dlrover/proto/*.proto --python_out=. 
--grpc_python_out=. \ && pip install deepctr \ && pip install h5py==3.7.0 \ diff --git a/dlrover/python/tests/test_http_server.py b/dlrover/python/tests/test_http_server.py index 2a8a4a7da..f96277478 100644 --- a/dlrover/python/tests/test_http_server.py +++ b/dlrover/python/tests/test_http_server.py @@ -18,9 +18,9 @@ import requests import tornado -from util.common_util import is_port_in_use from dlrover.python.common.http_server import TornadoHTTPServer +from dlrover.python.util.common_util import is_port_in_use TEST_SERVER_ADDR = "localhost" TEST_SERVER_PORT = 8000 diff --git a/scripts/ci_install.sh b/scripts/ci_install.sh index 26fa7191b..2abe95fa7 100644 --- a/scripts/ci_install.sh +++ b/scripts/ci_install.sh @@ -16,12 +16,18 @@ pip install kubernetes pip install grpcio-tools pip install psutil pip install deprecated -pip install 'ray[default]' pip install pyhocon pip install pytest-cov pip install pytest-ordering pip install packaging pip install tornado + +if [ "$1" = "basic" ]; then + echo "'Basic' dependencies only." + exit 0 +fi + +pip install 'ray[default]' pip install tensorflow==2.13.0 pip install deepspeed==0.12.6 pip install accelerate==0.29.2 From 95869dd13d1f042e1d5ccd3b8a0399236a845a5b Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Thu, 5 Dec 2024 15:47:36 +0800 Subject: [PATCH 06/20] fix ut --- .github/actions/dlrover-python-test/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/dlrover-python-test/action.yml b/.github/actions/dlrover-python-test/action.yml index 7dfb00b21..c4530d968 100644 --- a/.github/actions/dlrover-python-test/action.yml +++ b/.github/actions/dlrover-python-test/action.yml @@ -7,7 +7,7 @@ runs: args: - "/bin/bash" - "-c" - - "sh scripts/ci_install.sh basic && python -m grpc_tools.protoc -I. \ + - "sh scripts/ci_install.sh && python -m grpc_tools.protoc -I. \ dlrover/proto/*.proto --python_out=. --grpc_python_out=. 
\ && ROLE_NAME=dlrover-trainer \ python -m pytest --durations=10 dlrover/python/tests dlrover/trainer/tests \ From 4c10514de0c90a1928978eb8e3f90279ff590e3c Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Fri, 6 Dec 2024 11:03:19 +0800 Subject: [PATCH 07/20] stash --- dlrover/python/common/constants.py | 7 ++- dlrover/python/common/http_server.py | 11 ++++- dlrover/python/master/servicer.py | 6 +-- dlrover/python/tests/test_servicer.py | 70 +++++++++++++++++++++++++-- 4 files changed, 83 insertions(+), 11 deletions(-) diff --git a/dlrover/python/common/constants.py b/dlrover/python/common/constants.py index be5fbb19e..348c5a14c 100644 --- a/dlrover/python/common/constants.py +++ b/dlrover/python/common/constants.py @@ -14,8 +14,6 @@ class BasicClass(object): LOG_LEVEL_ENV = "DLROVER_LOG_LEVEL" - COMM_SERVICE_GRPC = "grpc" - COMM_SERVICE_HTTP = "http" class PriorityClass(object): @@ -30,6 +28,11 @@ class PlatformType(object): LOCAL = "local" +class CommunicationType(object): + COMM_SERVICE_GRPC = "grpc" + COMM_SERVICE_HTTP = "http" + + class ElasticJobApi(object): GROUP = "elastic.iml.github.io" VERION = "v1alpha1" diff --git a/dlrover/python/common/http_server.py b/dlrover/python/common/http_server.py index 701949f37..39456668a 100644 --- a/dlrover/python/common/http_server.py +++ b/dlrover/python/common/http_server.py @@ -21,6 +21,8 @@ class CustomHTTPServer(abc.ABC): + """Self designed http server.""" + def __init__(self, address, port, handler_class): self._address = address self._port = port @@ -44,8 +46,13 @@ def start(self): pass @abc.abstractmethod - def stop(self): - """Stop the server.""" + def stop(self, grace=None): + """ + Stop the server. + + Arg: + grace (Optional[float]): Grace period. 
+ """ pass diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py index 921ebd62d..948e48425 100644 --- a/dlrover/python/master/servicer.py +++ b/dlrover/python/master/servicer.py @@ -34,7 +34,7 @@ NodeType, RendezvousName, TrainingExceptionLevel, - TrainingLoopStatus, + TrainingLoopStatus, CommunicationType, ) from dlrover.python.common.global_context import Context from dlrover.python.common.http_server import TornadoHTTPServer @@ -768,12 +768,12 @@ def create_master_service( elastic_ps_service, sync_service, error_monitor=None, - service_type=BasicClass.COMM_SERVICE_GRPC, + service_type=CommunicationType.COMM_SERVICE_GRPC, max_threads=64, ): logger.info(f"Creating master {service_type} service with port: {port}") - if service_type == BasicClass.COMM_SERVICE_GRPC: + if service_type == CommunicationType.COMM_SERVICE_GRPC: server = grpc_lib.server( futures.ThreadPoolExecutor(max_workers=max_threads), options=[ diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py index c52731a45..e79a8151d 100644 --- a/dlrover/python/tests/test_servicer.py +++ b/dlrover/python/tests/test_servicer.py @@ -15,7 +15,7 @@ import unittest from unittest import mock -import ray +# import ray from dlrover.proto import elastic_training_pb2 from dlrover.python.common import comm, env_utils @@ -25,7 +25,7 @@ NodeStatus, NodeType, PSClusterVersionType, - RendezvousName, + RendezvousName, CommunicationType, ) from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric from dlrover.python.master.diagnosis.diagnosis_manager import DiagnosisManager @@ -38,7 +38,8 @@ from dlrover.python.master.monitor.speed_monitor import SpeedMonitor from dlrover.python.master.node.dist_job_manager import create_job_manager from dlrover.python.master.node.job_context import get_job_context -from dlrover.python.master.servicer import GrpcMasterServicer +from dlrover.python.master.servicer import GrpcMasterServicer, \ + 
create_master_service from dlrover.python.master.shard.task_manager import TaskManager from dlrover.python.master.stats.job_collector import JobMetricCollector from dlrover.python.tests.test_utils import ( @@ -51,7 +52,68 @@ ray_event_queue = RayEventQueue.singleton_instance() -class MasterServicerTest(unittest.TestCase): +class MasterServicerBasicTest(unittest.TestCase): + def setUp(self) -> None: + self.grpc_servicer = create_master_service( + 8080, + None, + None, + None, + None, + None, + None, + None, + None, + None, + service_type=CommunicationType.COMM_SERVICE_GRPC, + ) + + def tearDown(self) -> None: + pass + + def test_http_start_and_stop(self): + http_servicer = create_master_service( + 8081, + None, + None, + None, + None, + None, + None, + None, + None, + None, + service_type=CommunicationType.COMM_SERVICE_HTTP, + ) + self.assertIsNotNone(http_servicer) + self.assertFalse(http_servicer.is_serving()) + + http_servicer.start() + self.assertTrue(http_servicer.is_serving()) + + http_servicer.stop() + self.assertFalse(http_servicer.is_serving()) + + def test_grpc_start_and_stop(self): + grpc_servicer = create_master_service( + 8081, + None, + None, + None, + None, + None, + None, + None, + None, + None, + service_type=CommunicationType.COMM_SERVICE_GRPC, + ) + self.assertIsNotNone(grpc_servicer) + grpc_servicer.start() + grpc_servicer.stop(grace=None) + + +class MasterServicerFunctionalTest(unittest.TestCase): def setUp(self) -> None: mock_k8s_client() params = MockK8sPSJobArgs() From d8f390cb731bf805eaf120a8e8f7fe8b514fa47b Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Thu, 2 Jan 2025 19:29:41 +0800 Subject: [PATCH 08/20] fix --- dlrover/python/common/comm.py | 2 +- dlrover/python/common/constants.py | 5 +- dlrover/python/common/global_context.py | 4 +- dlrover/python/common/http_server.py | 13 +- dlrover/python/elastic_agent/master_client.py | 129 +++++++++++------- .../python/elastic_agent/torch/training.py | 2 +- 
dlrover/python/master/scaler/pod_scaler.py | 8 ++ dlrover/python/master/servicer.py | 82 ++++++++--- dlrover/python/tests/test_http_server.py | 23 +++- dlrover/python/tests/test_master_client.py | 4 +- dlrover/python/tests/test_servicer.py | 66 ++++++--- 11 files changed, 238 insertions(+), 100 deletions(-) diff --git a/dlrover/python/common/comm.py b/dlrover/python/common/comm.py index c42d99f58..22d0e83b4 100644 --- a/dlrover/python/common/comm.py +++ b/dlrover/python/common/comm.py @@ -97,7 +97,7 @@ def serialize(self): class BaseMessage(Message): node_id: int = -1 node_type: str = "" - data: bytes = bytes() + data: str = "" @dataclass diff --git a/dlrover/python/common/constants.py b/dlrover/python/common/constants.py index 4282c3d67..6c08a94d1 100644 --- a/dlrover/python/common/constants.py +++ b/dlrover/python/common/constants.py @@ -253,6 +253,7 @@ class TrainingLoopStatus(object): class NodeEnv(object): RELAUNCHED_POD = "RELAUNCHED_POD" DLROVER_MASTER_ADDR = "DLROVER_MASTER_ADDR" + DLROVER_MASTER_SERVICE_TYPE = "DLROVER_MASTER_SERVICE_TYPE" GRPC_ENABLE_FORK = "GRPC_ENABLE_FORK_SUPPORT" GRPC_POLL_STRATEGY = "GRPC_POLL_STRATEGY" POD_NAME = "POD_NAME" @@ -363,8 +364,8 @@ class JobConstant(object): INSUFFICIENT_NODE_TIMEOUT_DEFAULT_MIN = 600 INSUFFICIENT_NODE_TIMEOUT_DEFAULT_MAX = 3600 PENDING_NODE_TIMEOUT_DEFAULT_MIN = 600 - # grpc timeout 60s - MASTER_CLIENT_GRPC_DEFAULT_TIMEOUT = 60 + # timeout 60s + MASTER_CLIENT_DEFAULT_TIMEOUT = 60 # sleep 3s on NetworkFailureReason.WAITING_NODE MASTER_CLIENT_CHECK_FAULT_TIMEOUT = 3 # sleep 3s on NetworkFailureReason.WAITING_NODE diff --git a/dlrover/python/common/global_context.py b/dlrover/python/common/global_context.py index 572dc6446..1f29c81c0 100644 --- a/dlrover/python/common/global_context.py +++ b/dlrover/python/common/global_context.py @@ -13,7 +13,7 @@ import os -from dlrover.python.common.constants import UserEnv +from dlrover.python.common.constants import CommunicationType, UserEnv from 
dlrover.python.common.log import default_logger as logger from dlrover.python.common.singleton import Singleton from dlrover.python.util.common_util import ( @@ -41,6 +41,7 @@ class ConfigKeys(object): class DefaultValues(object): + SERVICE_TYPE = CommunicationType.COMM_SERVICE_GRPC TRAIN_SPEED_RECORD_NUM = 50 SEC_TO_START_AUTOSCALE_WORKER = 90 STEP_TO_ADJUST_WORKER = 200 @@ -64,6 +65,7 @@ class DefaultValues(object): class Context(Singleton): def __init__(self): + self.master_service_type = DefaultValues.SERVICE_TYPE self.train_speed_record_num = DefaultValues.TRAIN_SPEED_RECORD_NUM self.seconds_to_autoscale_worker = ( DefaultValues.SEC_TO_START_AUTOSCALE_WORKER diff --git a/dlrover/python/common/http_server.py b/dlrover/python/common/http_server.py index 39456668a..989972bac 100644 --- a/dlrover/python/common/http_server.py +++ b/dlrover/python/common/http_server.py @@ -23,10 +23,10 @@ class CustomHTTPServer(abc.ABC): """Self designed http server.""" - def __init__(self, address, port, handler_class): + def __init__(self, address, port, handler_classes): self._address = address self._port = port - self._handler_class = handler_class + self._handler_classes = handler_classes @property def address(self): @@ -37,8 +37,8 @@ def port(self): return self._port @property - def handler_class(self): - return self._handler_class + def handler_classes(self): + return self._handler_classes @abc.abstractmethod def start(self): @@ -83,7 +83,7 @@ def start(self): def _start_server(self): try: self._server = tornado.httpserver.HTTPServer( - tornado.web.Application([(r"/", self._handler_class)]) + tornado.web.Application(self._handler_classes) ) self._server.listen(self._port) self._io_loop = tornado.ioloop.IOLoop.current() @@ -94,7 +94,8 @@ def _start_server(self): def stop(self): if self._server: self._server.stop() - self._io_loop.add_callback(self._io_loop.stop) + if self._io_loop: + self._io_loop.add_callback(self._io_loop.stop) self._serving_started = False diff --git 
a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py index e41480f0d..dad1b9f97 100644 --- a/dlrover/python/elastic_agent/master_client.py +++ b/dlrover/python/elastic_agent/master_client.py @@ -16,12 +16,14 @@ import socket import threading import time +from abc import ABC, abstractmethod from contextlib import closing from typing import Dict, Optional from dlrover.proto import elastic_training_pb2, elastic_training_pb2_grpc from dlrover.python.common import comm, env_utils from dlrover.python.common.constants import ( + CommunicationType, JobConstant, NetworkFailureReason, NodeEnv, @@ -34,9 +36,10 @@ NoAction, ) from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData +from dlrover.python.util.common_util import find_free_port -def retry_grpc_request(func): +def retry_request(func): def wrapper(self, *args, **kwargs): retry = kwargs.get("retry", 10) exception = None @@ -58,7 +61,7 @@ def wrapper(self, *args, **kwargs): return wrapper -class MasterClient(Singleton): +class MasterClient(Singleton, ABC): """MasterClient provides some APIs connect with the master service via gRPC call. 
Args: @@ -87,56 +90,26 @@ def __init__(self, master_addr, node_id, node_type, timeout=5): ) self._timeout = timeout self._master_addr = master_addr - self._channel = comm.build_grpc_channel(master_addr) - self._stub = elastic_training_pb2_grpc.MasterStub(self._channel) self._node_id = node_id self._node_type = node_type self._node_ip = os.getenv("NODE_IP", "") self._worker_local_process_id = int(os.getenv("LOCAL_RANK", 0)) - self._ddp_server_port = self.find_free_port() - + self._ddp_server_port = find_free_port() self._diagnosis_action_module = importlib.import_module( "dlrover.python.diagnosis.common.diagnosis_action" ) - def __del__(self): - if self._channel: - self._channel.close() - - def close_channel(self): - if self._channel: - self._channel.close() - - def open_channel(self): - self._channel = comm.build_grpc_channel(self._master_addr) - self._stub = elastic_training_pb2_grpc.MasterStub(self._channel) - - def find_free_port(self): - with closing( - socket.socket(socket.AF_INET, socket.SOCK_STREAM) - ) as sock: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("localhost", 0)) - _, port = sock.getsockname() - return port - - @retry_grpc_request + @retry_request + @abstractmethod def _report(self, message: comm.Message): - request = elastic_training_pb2.Message() - request.node_id = self._node_id - request.node_type = self._node_type - request.data = message.serialize() - return self._stub.report(request, timeout=self._timeout) + """Abstraction of report function.""" + pass - @retry_grpc_request + @retry_request + @abstractmethod def _get(self, message: comm.Message): - request = elastic_training_pb2.Message() - request.node_id = self._node_id - request.node_type = self._node_type - request.data = message.serialize() - response = self._stub.get(request, timeout=self._timeout) - res_message = comm.deserialize_message(response.data) - return res_message + """Abstraction of get function.""" + pass def kv_store_set(self, key, value): 
message = comm.KeyValuePair(key, value) @@ -502,8 +475,60 @@ def singleton_instance(cls, *args, **kwargs): return cls._instance +class GrpcMasterClient(MasterClient): + def __init__(self, master_addr, node_id, node_type, timeout=5): + super(GrpcMasterClient, self).__init__( + master_addr, node_id, node_type, timeout + ) + self._open_grpc_channel() + + def __del__(self): + self._close_grpc_channel() + + def _close_grpc_channel(self): + if self._channel: + self._channel.close() + + def _open_grpc_channel(self): + self._channel = comm.build_grpc_channel(self._master_addr) + self._stub = elastic_training_pb2_grpc.MasterStub(self._channel) + + @retry_request + def _report(self, message: comm.Message): + request = elastic_training_pb2.Message() + request.node_id = self._node_id + request.node_type = self._node_type + request.data = message.serialize() + return self._stub.report(request, timeout=self._timeout) + + @retry_request + def _get(self, message: comm.Message): + request = elastic_training_pb2.Message() + request.node_id = self._node_id + request.node_type = self._node_type + request.data = message.serialize() + response = self._stub.get(request, timeout=self._timeout) + res_message = comm.deserialize_message(response.data) + return res_message + + +class HttpMasterClient(MasterClient): + def __init__(self, master_addr, node_id, node_type, timeout=5): + super(HttpMasterClient, self).__init__( + master_addr, node_id, node_type, timeout + ) + + @retry_request + def _report(self, message: comm.Message): + pass + + @retry_request + def _get(self, message: comm.Message): + pass + + def build_master_client( - master_addr=None, timeout=JobConstant.MASTER_CLIENT_GRPC_DEFAULT_TIMEOUT + master_addr=None, timeout=JobConstant.MASTER_CLIENT_DEFAULT_TIMEOUT ): """ Build a master client. 
@@ -526,12 +551,24 @@ def build_master_client( logger.info(f"set master_client timeout to {_timeout}") master_client = None - logger.info(f"Build master client with addr {master_addr}.") + master_service_type = os.getenv( + NodeEnv.DLROVER_MASTER_SERVICE_TYPE, + CommunicationType.COMM_SERVICE_GRPC, + ) + logger.info( + f"Build {master_service_type} master client " + f"with addr {master_addr}." + ) if master_addr: try: - master_client = MasterClient( - master_addr, node_id, node_type, timeout - ) + if master_service_type == CommunicationType.COMM_SERVICE_GRPC: + master_client = GrpcMasterClient( + master_addr, node_id, node_type, timeout + ) + else: + master_client = HttpMasterClient( + master_addr, node_id, node_type, timeout + ) except Exception: logger.info("The master is not available now.") return master_client diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index 8124f4d1e..7937bd9b4 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -97,12 +97,12 @@ from dlrover.python.elastic_agent.monitor.training import TorchTrainingMonitor from dlrover.python.elastic_agent.torch.ckpt_saver import AsyncCheckpointSaver from dlrover.python.elastic_agent.torch.master_kv_store import MasterKVStore -from dlrover.python.util.numa_util import get_gpu_affinity, get_npu_affinity from dlrover.python.util.common_util import ( find_free_port_for_hccl, find_free_port_in_range, find_free_port_in_set, ) +from dlrover.python.util.numa_util import get_gpu_affinity, get_npu_affinity from dlrover.trainer.torch.utils import ( version_less_than_230, version_less_than_240, diff --git a/dlrover/python/master/scaler/pod_scaler.py b/dlrover/python/master/scaler/pod_scaler.py index cf0a4149f..6c065e5a2 100644 --- a/dlrover/python/master/scaler/pod_scaler.py +++ b/dlrover/python/master/scaler/pod_scaler.py @@ -101,6 +101,7 @@ def __init__(self, job_name, namespace, 
error_monitor=None): self._job_uid = "" self.api_client = client.ApiClient() self._master_addr = "" + self._master_service_type = _dlrover_context.master_service_type self._error_monitor = error_monitor self._started = False @@ -500,6 +501,7 @@ def _create_pod(self, node: Node): # Deprecated env vars env.append(V1EnvVar(name=NodeEnv.WORKER_TYPE, value=node.type)) + env.append(V1EnvVar(name=NodeEnv.WORKER_ID, value=str(node.id))) env.append(V1EnvVar(name=NodeEnv.WORKER_NUM, value=str(worker_num))) env.append( @@ -508,6 +510,12 @@ def _create_pod(self, node: Node): env.append( V1EnvVar(name=NodeEnv.DLROVER_MASTER_ADDR, value=self._master_addr) ) + env.append( + V1EnvVar( + name=NodeEnv.DLROVER_MASTER_SERVICE_TYPE, + value=self._master_service_type, + ) + ) env.append( V1EnvVar( diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py index 948e48425..af86d1c93 100644 --- a/dlrover/python/master/servicer.py +++ b/dlrover/python/master/servicer.py @@ -27,14 +27,14 @@ from dlrover.python.common.comm import BaseMessage from dlrover.python.common.constants import ( GRPC, - BasicClass, + CommunicationType, CustomMetricKeys, JobConstant, NodeEventType, NodeType, RendezvousName, TrainingExceptionLevel, - TrainingLoopStatus, CommunicationType, + TrainingLoopStatus, ) from dlrover.python.common.global_context import Context from dlrover.python.common.http_server import TornadoHTTPServer @@ -682,7 +682,7 @@ def _report_heartbeat( return comm.HeartbeatResponse(action=grpc_action) -class HttpMasterServicer(MasterServicer, tornado.web.RequestHandler): +class HttpMasterServicer(MasterServicer): """Master service with http implementation.""" def __init__( @@ -697,7 +697,7 @@ def __init__( sync_service=None, error_monitor=None, ): - super(HttpMasterServicer, self).__init__( + super().__init__( task_manager, job_manager, speed_monitor, @@ -712,17 +712,6 @@ def __init__( def get_response(self): return BaseMessage() - def post(self, path): - data = 
self.get_body_argument("data", default=None) - request: BaseMessage = json.loads(data) - if path == "get": - return self.get(request, None) - elif path == "report": - return self.report(request, None) - else: - self.set_status(404) - self.write(f"No service found for {path}.") - class GrpcMasterServicer( MasterServicer, elastic_training_pb2_grpc.MasterServicer @@ -757,6 +746,34 @@ def get_response(self): return elastic_training_pb2.Message() +class HttpMasterHandler(tornado.web.RequestHandler): + def initialize(self, master_servicer: HttpMasterServicer): + self._handler = master_servicer + + def get(self): + self.write("Not supported") + + def post(self): + response = BaseMessage() + + try: + path = self.request.path + request = BaseMessage(**json.loads(self.request.body)) + + if path == "/get": + response = self._handler.get(request, BaseMessage()) + elif path == "/report": + response = self._handler.report(request, BaseMessage()) + else: + self.set_status(404) + logger.error(f"No service found for {path}.") + except Exception as e: + logger.error(f"Unexpected error: {e}") + self.set_status(500) + finally: + self.write(response.serialize()) + + def create_master_service( port, task_manager, @@ -768,9 +785,9 @@ def create_master_service( elastic_ps_service, sync_service, error_monitor=None, - service_type=CommunicationType.COMM_SERVICE_GRPC, max_threads=64, ): + service_type = _dlrover_context.master_service_type logger.info(f"Creating master {service_type} service with port: {port}") if service_type == CommunicationType.COMM_SERVICE_GRPC: @@ -802,5 +819,36 @@ def create_master_service( server.add_insecure_port("[::]:{}".format(port)) return server else: - server = TornadoHTTPServer("localhost", port, HttpMasterServicer) + master_servicer = HttpMasterServicer( + task_manager=task_manager, + job_manager=job_manager, + speed_monitor=speed_monitor, + rdzv_managers=rdzv_managers, + diagnosis_manager=diagnosis_manager, + job_metric_collector=job_metric_collector, + 
elastic_ps_service=elastic_ps_service, + sync_service=sync_service, + error_monitor=error_monitor, + ) + server = TornadoHTTPServer( + "localhost", + port, + [ + ( + r"/", + HttpMasterHandler, + dict(master_servicer=master_servicer), + ), + ( + r"/get", + HttpMasterHandler, + dict(master_servicer=master_servicer), + ), + ( + r"/report", + HttpMasterHandler, + dict(master_servicer=master_servicer), + ), + ], + ) return server diff --git a/dlrover/python/tests/test_http_server.py b/dlrover/python/tests/test_http_server.py index f96277478..ea6ac063a 100644 --- a/dlrover/python/tests/test_http_server.py +++ b/dlrover/python/tests/test_http_server.py @@ -37,7 +37,9 @@ def tearDown(self): def test_tornado_server_basic(self): self.server = TornadoHTTPServer( - TEST_SERVER_ADDR, TEST_SERVER_PORT, TestRequestHandler + TEST_SERVER_ADDR, + TEST_SERVER_PORT, + [(r"/", TestRequestHandler), (r"/report", TestRequestHandler)], ) self.assertIsNotNone(self.server) self.assertFalse(is_port_in_use(TEST_SERVER_PORT)) @@ -55,8 +57,9 @@ def test_tornado_server_basic(self): ) time.sleep(1) - # test get request + # test get and post request self._test_get_request() + self._test_post_request() self.server.stop() self.assertFalse(self.server.is_serving()) @@ -70,9 +73,20 @@ def _test_get_request(self): except Exception as e: raise e + def _test_post_request(self): + try: + with requests.post("http://localhost:8000/report") as response: + self.assertEqual(response.status_code, 200) + self.assertEqual(response.text, "Hello, world!!") + return response + except Exception as e: + raise e + def test_server_concurrency(self): self.server = TornadoHTTPServer( - TEST_SERVER_ADDR, TEST_SERVER_PORT, TestRequestHandler + TEST_SERVER_ADDR, + TEST_SERVER_PORT, + [(r"/", TestRequestHandler), (r"/report", TestRequestHandler)], ) self.server.start() @@ -95,3 +109,6 @@ def test_server_concurrency(self): class TestRequestHandler(tornado.web.RequestHandler): def get(self): self.write("Hello, world!") + + def 
post(self): + self.write("Hello, world!!") diff --git a/dlrover/python/tests/test_master_client.py b/dlrover/python/tests/test_master_client.py index b549fb399..f654cd410 100644 --- a/dlrover/python/tests/test_master_client.py +++ b/dlrover/python/tests/test_master_client.py @@ -44,8 +44,8 @@ def tearDown(self): def test_open_channel(self): self.assertEqual(self._master_client._timeout, 1) self.assertEqual(self._master_client._timeout, 1) - self._master_client.close_channel() - self._master_client.open_channel() + self._master_client._close_grpc_channel() + self._master_client._open_grpc_channel() def test_report_used_resource(self): gpu_stats: List[comm.GPUStats] = [ diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py index e79a8151d..8c7a35268 100644 --- a/dlrover/python/tests/test_servicer.py +++ b/dlrover/python/tests/test_servicer.py @@ -10,12 +10,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import os import time import unittest from unittest import mock -# import ray +import ray +import requests from dlrover.proto import elastic_training_pb2 from dlrover.python.common import comm, env_utils @@ -25,8 +27,9 @@ NodeStatus, NodeType, PSClusterVersionType, - RendezvousName, CommunicationType, + RendezvousName, ) +from dlrover.python.common.global_context import Context from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric from dlrover.python.master.diagnosis.diagnosis_manager import DiagnosisManager from dlrover.python.master.elastic_training.elastic_ps import ElasticPsService @@ -38,8 +41,10 @@ from dlrover.python.master.monitor.speed_monitor import SpeedMonitor from dlrover.python.master.node.dist_job_manager import create_job_manager from dlrover.python.master.node.job_context import get_job_context -from dlrover.python.master.servicer import GrpcMasterServicer, \ - create_master_service +from dlrover.python.master.servicer import ( + GrpcMasterServicer, + create_master_service, +) from dlrover.python.master.shard.task_manager import TaskManager from dlrover.python.master.stats.job_collector import JobMetricCollector from dlrover.python.tests.test_utils import ( @@ -50,30 +55,19 @@ from dlrover.python.util.queue.queue import RayEventQueue ray_event_queue = RayEventQueue.singleton_instance() +TEST_SERVER_PORT = 8000 class MasterServicerBasicTest(unittest.TestCase): def setUp(self) -> None: - self.grpc_servicer = create_master_service( - 8080, - None, - None, - None, - None, - None, - None, - None, - None, - None, - service_type=CommunicationType.COMM_SERVICE_GRPC, - ) + pass def tearDown(self) -> None: pass def test_http_start_and_stop(self): http_servicer = create_master_service( - 8081, + TEST_SERVER_PORT, None, None, None, @@ -83,7 +77,6 @@ def test_http_start_and_stop(self): None, None, None, - service_type=CommunicationType.COMM_SERVICE_HTTP, ) self.assertIsNotNone(http_servicer) 
self.assertFalse(http_servicer.is_serving()) @@ -96,7 +89,7 @@ def test_http_start_and_stop(self): def test_grpc_start_and_stop(self): grpc_servicer = create_master_service( - 8081, + TEST_SERVER_PORT, None, None, None, @@ -106,12 +99,43 @@ def test_grpc_start_and_stop(self): None, None, None, - service_type=CommunicationType.COMM_SERVICE_GRPC, ) self.assertIsNotNone(grpc_servicer) grpc_servicer.start() grpc_servicer.stop(grace=None) + def test_http_basic(self): + context = Context.singleton_instance() + context.master_service_type = "http" + http_servicer = create_master_service( + TEST_SERVER_PORT, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + http_servicer.start() + + response = requests.get("http://localhost:8000/") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.text, "Not supported") + + response = requests.post( + "http://localhost:8000/get", + json={"node_type": "worker", "node_id": "1", "data": "test"}, + ) + self.assertEqual(response.status_code, 200) + response_content = comm.deserialize_message(response.content) + self.assertIsNotNone(response_content) + self.assertEqual(response_content.node_type, "") + + http_servicer.stop() + class MasterServicerFunctionalTest(unittest.TestCase): def setUp(self) -> None: From a3feea737d17b83bdb1cf7995e425b7924661a3e Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Fri, 3 Jan 2025 17:51:13 +0800 Subject: [PATCH 09/20] done master http server/client ut --- dlrover/python/common/comm.py | 56 +++++++++++++++++-- dlrover/python/common/http_server.py | 2 +- dlrover/python/elastic_agent/master_client.py | 48 ++++++++++++++-- dlrover/python/master/local_master.py | 3 +- dlrover/python/master/servicer.py | 24 ++++---- dlrover/python/tests/test_master_client.py | 31 ++++++++++ dlrover/python/tests/test_servicer.py | 9 ++- 7 files changed, 148 insertions(+), 25 deletions(-) diff --git a/dlrover/python/common/comm.py b/dlrover/python/common/comm.py index 
22d0e83b4..6265c5962 100644 --- a/dlrover/python/common/comm.py +++ b/dlrover/python/common/comm.py @@ -10,7 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import base64 import pickle import socket from dataclasses import dataclass, field @@ -74,11 +74,25 @@ def grpc_server_ready(channel) -> bool: return False -def deserialize_message(data: bytes): +def serialize_message(message): """The method will create a message instance with the content. Args: pickle_data: pickle bytes of a class instance. """ + data = None + if message: + try: + data = pickle.dumps(message) + except Exception as e: + logger.warning(f"Pickle failed to load {str(data)}", e) + return data + + +def deserialize_message(data: bytes): + """The method will create a message instance with the content. + Args: + data: pickle bytes of a class instance. + """ message = None if data: try: @@ -94,10 +108,44 @@ def serialize(self): @dataclass -class BaseMessage(Message): +class BaseRequest(Message): node_id: int = -1 node_type: str = "" - data: str = "" + data: bytes = b"" + + def to_json(self): + return { + "node_id": self.node_id, + "node_type": self.node_type, + "data": base64.b64encode(self.data).decode("utf-8"), + } + + @staticmethod + def from_json(data): + return BaseRequest( + node_id=data.get("node_id"), + node_type=data.get("node_type"), + data=base64.b64decode(data.get("data")), + ) + + +@dataclass +class BaseResponse(Message): + success: bool = False + data: bytes = b"" + + def to_json(self): + return { + "success": self.success, + "data": base64.b64encode(self.data).decode("utf-8"), + } + + @staticmethod + def from_json(data): + return BaseResponse( + success=bool(data.get("success")), + data=base64.b64decode(data.get("data")), + ) @dataclass diff --git a/dlrover/python/common/http_server.py b/dlrover/python/common/http_server.py index 
989972bac..caa3eec04 100644 --- a/dlrover/python/common/http_server.py +++ b/dlrover/python/common/http_server.py @@ -91,7 +91,7 @@ def _start_server(self): except Exception as e: logger.error(f"Http server start with error: {e}") - def stop(self): + def stop(self, grace=None): if self._server: self._server.stop() if self._io_loop: diff --git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py index dad1b9f97..aaba0c9e5 100644 --- a/dlrover/python/elastic_agent/master_client.py +++ b/dlrover/python/elastic_agent/master_client.py @@ -13,15 +13,16 @@ import importlib import os -import socket import threading import time from abc import ABC, abstractmethod -from contextlib import closing from typing import Dict, Optional +import requests + from dlrover.proto import elastic_training_pb2, elastic_training_pb2_grpc from dlrover.python.common import comm, env_utils +from dlrover.python.common.comm import BaseRequest, BaseResponse from dlrover.python.common.constants import ( CommunicationType, JobConstant, @@ -343,7 +344,7 @@ def num_nodes_waiting(self, rdzv_name): return result.waiting_num except Exception: logger.warning("Fail to query the number of waiting nodes.") - return 0 + return -1 def join_rendezvous(self, node_rank, local_world_size, rdzv_name=""): request = comm.JoinRendezvousRequest( @@ -518,13 +519,50 @@ def __init__(self, master_addr, node_id, node_type, timeout=5): master_addr, node_id, node_type, timeout ) + def _get_http_request_url(self, path: str) -> str: + return "http://" + self._master_addr + path + @retry_request def _report(self, message: comm.Message): - pass + with requests.post( + self._get_http_request_url("/report"), + json=self._gen_request(message).to_json(), + ) as response: + if response.status_code != 200: + error_msg = ( + "Failed to report master " + f"with http request: {type(message)}." 
+ ) + raise RuntimeError(error_msg) + response_data: BaseResponse = comm.deserialize_message( + response.content + ) + return response_data @retry_request def _get(self, message: comm.Message): - pass + with requests.post( + self._get_http_request_url("/get"), + json=self._gen_request(message).to_json(), + ) as response: + if response.status_code != 200: + error_msg = ( + "Failed to get from master " + f"with http request: {type(message)}." + ) + raise RuntimeError(error_msg) + response_data: BaseResponse = comm.deserialize_message( + response.content + ) + return comm.deserialize_message(response_data.data) + + def _gen_request(self, message: comm.Message): + request = BaseRequest() + request.node_id = self._node_id + request.node_type = self._node_type + request.data = message.serialize() + + return request def build_master_client( diff --git a/dlrover/python/master/local_master.py b/dlrover/python/master/local_master.py index d0d066278..f713078f9 100644 --- a/dlrover/python/master/local_master.py +++ b/dlrover/python/master/local_master.py @@ -113,8 +113,7 @@ def stop(self): """ logger.info("Stopping master!") logger.info("Stopping RPC server!") - self._master_server.stop(None) - # self._master_server.stop(grace=0.1) + self._master_server.stop() logger.info("RPC server stopped!") logger.info("Master stopped!") diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py index af86d1c93..087030783 100644 --- a/dlrover/python/master/servicer.py +++ b/dlrover/python/master/servicer.py @@ -24,7 +24,7 @@ from dlrover.proto import elastic_training_pb2, elastic_training_pb2_grpc from dlrover.python.common import comm -from dlrover.python.common.comm import BaseMessage +from dlrover.python.common.comm import BaseRequest, BaseResponse from dlrover.python.common.constants import ( GRPC, CommunicationType, @@ -329,7 +329,7 @@ def report(self, request, _): node_id = request.node_id message = comm.deserialize_message(request.data) - response = 
elastic_training_pb2.Response() + response = self.get_response() if not message: return response @@ -710,7 +710,7 @@ def __init__( ) def get_response(self): - return BaseMessage() + return BaseResponse() class GrpcMasterServicer( @@ -754,24 +754,28 @@ def get(self): self.write("Not supported") def post(self): - response = BaseMessage() - try: path = self.request.path - request = BaseMessage(**json.loads(self.request.body)) + request = BaseRequest.from_json(json.loads(self.request.body)) if path == "/get": - response = self._handler.get(request, BaseMessage()) + # return message + response = self._handler.get(request, BaseRequest()) + if not response.data: + response.success = True + self.write(response.serialize()) elif path == "/report": - response = self._handler.report(request, BaseMessage()) + # return boolean + self.write( + self._handler.report(request, BaseRequest()).serialize() + ) else: self.set_status(404) logger.error(f"No service found for {path}.") except Exception as e: logger.error(f"Unexpected error: {e}") self.set_status(500) - finally: - self.write(response.serialize()) + self.write(f"{str(e)}") def create_master_service( diff --git a/dlrover/python/tests/test_master_client.py b/dlrover/python/tests/test_master_client.py index f654cd410..e1f931a07 100644 --- a/dlrover/python/tests/test_master_client.py +++ b/dlrover/python/tests/test_master_client.py @@ -12,6 +12,7 @@ # limitations under the License. 
import json +import os import time import unittest from typing import List @@ -20,11 +21,14 @@ from dlrover.python.common import comm from dlrover.python.common.comm import DiagnosisAction, HeartbeatResponse from dlrover.python.common.constants import ( + CommunicationType, + NodeEnv, NodeEventType, NodeType, RendezvousName, TrainingExceptionLevel, ) +from dlrover.python.common.global_context import Context from dlrover.python.diagnosis.common.diagnosis_action import ( EventAction, NoAction, @@ -185,3 +189,30 @@ def test_report_heartbeat(self): self._master_client._get = mock.MagicMock(return_value=response_dto) action = self._master_client.report_heart_beat(now) self.assertTrue(isinstance(action, EventAction)) + + +class MasterHttpClientTest(unittest.TestCase): + def setUp(self) -> None: + os.environ[ + NodeEnv.DLROVER_MASTER_SERVICE_TYPE + ] = CommunicationType.COMM_SERVICE_HTTP + context = Context.singleton_instance() + context.master_service_type = "http" + self._master, addr = start_local_master() + self._master_client = build_master_client(addr, 3) + + def tearDown(self): + self._master.stop() + context = Context.singleton_instance() + context.master_service_type = "grpc" + os.environ.clear() + + def test_http_client(self): + # get request + rdzv_name = RendezvousName.ELASTIC_TRAINING + num = self._master_client.num_nodes_waiting(rdzv_name) + self.assertEqual(num, 0) + + # report request + res = self._master_client.ready_for_ps_relaunch() + self.assertTrue(res.success) diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py index 8c7a35268..e870e5128 100644 --- a/dlrover/python/tests/test_servicer.py +++ b/dlrover/python/tests/test_servicer.py @@ -21,7 +21,7 @@ from dlrover.proto import elastic_training_pb2 from dlrover.python.common import comm, env_utils -from dlrover.python.common.comm import GPUStats +from dlrover.python.common.comm import BaseRequest, GPUStats from dlrover.python.common.constants import ( NodeEventType, 
NodeStatus, @@ -125,9 +125,12 @@ def test_http_basic(self): self.assertEqual(response.status_code, 200) self.assertEqual(response.text, "Not supported") + request = BaseRequest() + request.node_id = 1 + request.node_type = "worker" + request.data = "test".encode() response = requests.post( - "http://localhost:8000/get", - json={"node_type": "worker", "node_id": "1", "data": "test"}, + "http://localhost:8000/get", json=request.to_json() ) self.assertEqual(response.status_code, 200) response_content = comm.deserialize_message(response.content) From 1e28aca615838eabb716fda5b0f0d99877f8f2d1 Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Tue, 7 Jan 2025 10:19:12 +0800 Subject: [PATCH 10/20] fix http/grpc response --- dlrover/python/master/dist_master.py | 8 ++++---- dlrover/python/master/local_master.py | 8 ++++---- dlrover/python/master/master.py | 9 +++++++++ dlrover/python/master/servicer.py | 19 +++++++++++-------- 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/dlrover/python/master/dist_master.py b/dlrover/python/master/dist_master.py index 8240030d0..c10ec853e 100644 --- a/dlrover/python/master/dist_master.py +++ b/dlrover/python/master/dist_master.py @@ -34,7 +34,7 @@ RendezvousManager, ) from dlrover.python.master.elastic_training.sync_service import SyncService -from dlrover.python.master.master import JobMaster +from dlrover.python.master.master import JobMaster, get_service_type from dlrover.python.master.monitor.error_monitor import ErrorMonitor from dlrover.python.master.monitor.speed_monitor import SpeedMonitor from dlrover.python.master.node.dist_job_manager import create_job_manager @@ -184,10 +184,10 @@ def _create_metric_collector_if_needed(self, params: JobArgs): return collector def prepare(self): - # Start the master GRPC server - logger.info("Starting master RPC server") + # start the master server + logger.info(f"Starting master {get_service_type()} server") self._master_server.start() - logger.info("Master RPC server 
started") + logger.info(f"Master {get_service_type()} server started") # Composite the components if self.task_manager and self.job_manager: diff --git a/dlrover/python/master/local_master.py b/dlrover/python/master/local_master.py index f713078f9..7ccf31bbf 100644 --- a/dlrover/python/master/local_master.py +++ b/dlrover/python/master/local_master.py @@ -26,7 +26,7 @@ NetworkCheckRendezvousManager, RendezvousManager, ) -from dlrover.python.master.master import JobMaster +from dlrover.python.master.master import JobMaster, get_service_type from dlrover.python.master.monitor.speed_monitor import SpeedMonitor from dlrover.python.master.node.local_job_manager import create_job_manager from dlrover.python.master.servicer import create_master_service @@ -79,10 +79,10 @@ def _create_metric_collector_if_needed(self, params: JobArgs): return collector def prepare(self): - # Start the master GRPC server - logger.info("Starting master RPC server") + # start the master server + logger.info(f"Starting master {get_service_type()} server") self._master_server.start() - logger.info("Master RPC server started") + logger.info(f"Master {get_service_type()} server started") self.task_manager.start() self.job_manager.start() diff --git a/dlrover/python/master/master.py b/dlrover/python/master/master.py index 4bbc6f81f..559382da1 100644 --- a/dlrover/python/master/master.py +++ b/dlrover/python/master/master.py @@ -13,6 +13,15 @@ from abc import ABCMeta, abstractmethod +from dlrover.python.common.global_context import Context + + +_dlrover_context = Context.singleton_instance() + + +def get_service_type(): + return _dlrover_context.master_service_type + class JobMaster(metaclass=ABCMeta): @abstractmethod diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py index 087030783..48a82148b 100644 --- a/dlrover/python/master/servicer.py +++ b/dlrover/python/master/servicer.py @@ -109,7 +109,7 @@ def __init__( self._kv_store.clear() @abstractmethod - def 
get_response(self): + def get_response(self, method): """Should be implemented by subclasses.""" pass @@ -118,7 +118,7 @@ def get(self, request, _): node_id = request.node_id req_message = comm.deserialize_message(request.data) - response = self.get_response() + response = self.get_response("get") if not req_message: return response message = None @@ -329,7 +329,7 @@ def report(self, request, _): node_id = request.node_id message = comm.deserialize_message(request.data) - response = self.get_response() + response = self.get_response("report") if not message: return response @@ -709,7 +709,7 @@ def __init__( error_monitor, ) - def get_response(self): + def get_response(self, method): return BaseResponse() @@ -742,8 +742,11 @@ def __init__( error_monitor, ) - def get_response(self): - return elastic_training_pb2.Message() + def get_response(self, method): + if method == "report": + return elastic_training_pb2.Response() + else: + return elastic_training_pb2.Message() class HttpMasterHandler(tornado.web.RequestHandler): @@ -772,6 +775,7 @@ def post(self): else: self.set_status(404) logger.error(f"No service found for {path}.") + self.write("") except Exception as e: logger.error(f"Unexpected error: {e}") self.set_status(500) @@ -821,7 +825,6 @@ def create_master_service( master_servicer, server ) server.add_insecure_port("[::]:{}".format(port)) - return server else: master_servicer = HttpMasterServicer( task_manager=task_manager, @@ -855,4 +858,4 @@ def create_master_service( ), ], ) - return server + return server From 72745068a07e26dbb788387b1c7da2e5c1ec456d Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Tue, 7 Jan 2025 10:57:33 +0800 Subject: [PATCH 11/20] add args params --- dlrover/python/master/args.py | 7 +++++++ dlrover/python/tests/test_args.py | 11 +++++++++++ 2 files changed, 18 insertions(+) diff --git a/dlrover/python/master/args.py b/dlrover/python/master/args.py index 177c91f71..80073f798 100644 --- a/dlrover/python/master/args.py +++ 
b/dlrover/python/master/args.py @@ -31,6 +31,13 @@ def add_params(parser): type=str, help="The name of platform which can be pyk8s, k8s, ray or local.", ) + parser.add_argument( + "--service_type", + "--service-type", + default="grpc", + type=str, + help="The service type of master: grpc/http.", + ) def print_args(args, exclude_args=[], groups=None): diff --git a/dlrover/python/tests/test_args.py b/dlrover/python/tests/test_args.py index d6b2d13b4..08463704c 100644 --- a/dlrover/python/tests/test_args.py +++ b/dlrover/python/tests/test_args.py @@ -27,3 +27,14 @@ def test_parse_master_args(self): parsed_args = parse_master_args(original_args) self.assertEqual(parsed_args.job_name, "test") self.assertTrue(parsed_args.namespace, "default") + self.assertTrue(parsed_args.service_type, "grpc") + + original_args = [ + "--job_name", + "test", + "--namespace", + "default", + "--service_type", + "http" + ] + self.assertTrue(parsed_args.service_type, "http") From 6c0f213e2c440e09ec4af0d6f45e46ceb577fd82 Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Wed, 8 Jan 2025 16:25:19 +0800 Subject: [PATCH 12/20] lint --- dlrover/python/common/log.py | 7 +++++++ dlrover/python/elastic_agent/master_client.py | 6 ++---- dlrover/python/master/master.py | 1 - dlrover/python/tests/test_args.py | 2 +- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/dlrover/python/common/log.py b/dlrover/python/common/log.py index 0a90a420b..7415bbf8f 100644 --- a/dlrover/python/common/log.py +++ b/dlrover/python/common/log.py @@ -45,6 +45,8 @@ def get_log_level(): def get_logger(name, handlers=None, update=False): + __setup_extra_logger() + if name in _LOGGER_CACHE and not update: return _LOGGER_CACHE[name] logger = logging.getLogger(name) @@ -54,4 +56,9 @@ def get_logger(name, handlers=None, update=False): return logger +def __setup_extra_logger(): + # tornado logger + logging.getLogger("tornado.access").setLevel(logging.WARNING) + + default_logger = get_logger(_DEFAULT_LOGGER) diff 
--git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py index aaba0c9e5..7979dc2cb 100644 --- a/dlrover/python/elastic_agent/master_client.py +++ b/dlrover/python/elastic_agent/master_client.py @@ -593,10 +593,8 @@ def build_master_client( NodeEnv.DLROVER_MASTER_SERVICE_TYPE, CommunicationType.COMM_SERVICE_GRPC, ) - logger.info( - f"Build {master_service_type} master client " - f"with addr {master_addr}." - ) + logger.info(f"Use [{master_service_type}] type for master client.") + if master_addr: try: if master_service_type == CommunicationType.COMM_SERVICE_GRPC: diff --git a/dlrover/python/master/master.py b/dlrover/python/master/master.py index 559382da1..13b9aa34d 100644 --- a/dlrover/python/master/master.py +++ b/dlrover/python/master/master.py @@ -15,7 +15,6 @@ from dlrover.python.common.global_context import Context - _dlrover_context = Context.singleton_instance() diff --git a/dlrover/python/tests/test_args.py b/dlrover/python/tests/test_args.py index 08463704c..6b908eeb6 100644 --- a/dlrover/python/tests/test_args.py +++ b/dlrover/python/tests/test_args.py @@ -35,6 +35,6 @@ def test_parse_master_args(self): "--namespace", "default", "--service_type", - "http" + "http", ] self.assertTrue(parsed_args.service_type, "http") From a6862bcc01e2437748e3494f09e5348c491db756 Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Thu, 9 Jan 2025 16:57:43 +0800 Subject: [PATCH 13/20] merged --- dlrover/python/tests/test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dlrover/python/tests/test_utils.py b/dlrover/python/tests/test_utils.py index de27d016f..6beb12d78 100644 --- a/dlrover/python/tests/test_utils.py +++ b/dlrover/python/tests/test_utils.py @@ -21,6 +21,7 @@ import dlrover.python.util.k8s_util as ku from dlrover.proto import elastic_training_pb2 +from dlrover.python.common.comm import addr_connected from dlrover.python.common.constants import ( DistributionStrategy, ElasticJobLabel, From 
ca1973e1f22a1422e0856e89b3a1d1e69434f320 Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Thu, 9 Jan 2025 17:07:29 +0800 Subject: [PATCH 14/20] add deps --- scripts/ci_install.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/ci_install.sh b/scripts/ci_install.sh index 7b94fbab3..fd0fb7b0f 100644 --- a/scripts/ci_install.sh +++ b/scripts/ci_install.sh @@ -19,6 +19,7 @@ pip install -q kubernetes pip install -q grpcio-tools pip install -q psutil pip install -q deprecated +pip install -q tornado if [ "$1" = "basic" ]; then echo "" From 6778f2a8dc147ec46f6641e9d3beb3cf33719c5b Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Fri, 10 Jan 2025 15:54:10 +0800 Subject: [PATCH 15/20] ut fix --- dlrover/python/master/local_master.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlrover/python/master/local_master.py b/dlrover/python/master/local_master.py index 7ccf31bbf..aa03ed865 100644 --- a/dlrover/python/master/local_master.py +++ b/dlrover/python/master/local_master.py @@ -113,7 +113,7 @@ def stop(self): """ logger.info("Stopping master!") logger.info("Stopping RPC server!") - self._master_server.stop() + self._master_server.stop(grace=None) logger.info("RPC server stopped!") logger.info("Master stopped!") From a95ec249bf01b28281d04329f9a5570afaa998fc Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Fri, 10 Jan 2025 16:26:56 +0800 Subject: [PATCH 16/20] ut fix --- dlrover/python/tests/test_servicer.py | 35 ++++++++++++++------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py index e870e5128..3abfb74be 100644 --- a/dlrover/python/tests/test_servicer.py +++ b/dlrover/python/tests/test_servicer.py @@ -60,13 +60,14 @@ class MasterServicerBasicTest(unittest.TestCase): def setUp(self) -> None: - pass + self.server = None def tearDown(self) -> None: - pass + if self.server: + self.server.stop(grace=None) def 
test_http_start_and_stop(self): - http_servicer = create_master_service( + self.server = create_master_service( TEST_SERVER_PORT, None, None, @@ -78,17 +79,17 @@ def test_http_start_and_stop(self): None, None, ) - self.assertIsNotNone(http_servicer) - self.assertFalse(http_servicer.is_serving()) + self.assertIsNotNone(self.server) + self.assertFalse(self.server.is_serving()) - http_servicer.start() - self.assertTrue(http_servicer.is_serving()) + self.server.start() + self.assertTrue(self.server.is_serving()) - http_servicer.stop() - self.assertFalse(http_servicer.is_serving()) + self.server.stop() + self.assertFalse(self.server.is_serving()) def test_grpc_start_and_stop(self): - grpc_servicer = create_master_service( + self.server = create_master_service( TEST_SERVER_PORT, None, None, @@ -100,14 +101,14 @@ def test_grpc_start_and_stop(self): None, None, ) - self.assertIsNotNone(grpc_servicer) - grpc_servicer.start() - grpc_servicer.stop(grace=None) + self.assertIsNotNone(self.server) + self.server.start() + self.server.stop(grace=None) def test_http_basic(self): context = Context.singleton_instance() context.master_service_type = "http" - http_servicer = create_master_service( + self.server = create_master_service( TEST_SERVER_PORT, None, None, @@ -119,7 +120,7 @@ def test_http_basic(self): None, None, ) - http_servicer.start() + self.server.start() response = requests.get("http://localhost:8000/") self.assertEqual(response.status_code, 200) @@ -135,9 +136,9 @@ def test_http_basic(self): self.assertEqual(response.status_code, 200) response_content = comm.deserialize_message(response.content) self.assertIsNotNone(response_content) - self.assertEqual(response_content.node_type, "") + self.assertTrue(response_content.success) - http_servicer.stop() + self.server.stop() class MasterServicerFunctionalTest(unittest.TestCase): From 59161e6e4d3916ffc0a3e7b6d58ee55216a738e3 Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Fri, 10 Jan 2025 17:34:15 +0800 Subject: [PATCH 
17/20] ut fix --- dlrover/python/common/comm.py | 9 ++++++ dlrover/python/master/servicer.py | 30 +++++++++++++++++--- dlrover/python/tests/test_sharding_client.py | 2 +- 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/dlrover/python/common/comm.py b/dlrover/python/common/comm.py index 6265c5962..47e0083e4 100644 --- a/dlrover/python/common/comm.py +++ b/dlrover/python/common/comm.py @@ -511,3 +511,12 @@ class DiagnosisAction(Message): @dataclass class HeartbeatResponse(Message): action: DiagnosisAction = field(default_factory=DiagnosisAction) + + +class TaskType(object): + NONE = "NONE" + TRAINING = "TRAINING" + EVALUATION = "EVALUATION" + PREDICTION = "PREDICTION" + WAIT = "WAIT" + TRAIN_END_CALLBACK = "TRAIN_END_CALLBACK" diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py index 9821a71e6..3e1d16484 100644 --- a/dlrover/python/master/servicer.py +++ b/dlrover/python/master/servicer.py @@ -24,7 +24,7 @@ from dlrover.proto import elastic_training_pb2, elastic_training_pb2_grpc from dlrover.python.common import comm -from dlrover.python.common.comm import BaseRequest, BaseResponse +from dlrover.python.common.comm import BaseRequest, BaseResponse, TaskType from dlrover.python.common.constants import ( GRPC, CommunicationType, @@ -114,6 +114,11 @@ def get_response(self, method): """Should be implemented by subclasses.""" pass + @abstractmethod + def get_task_type(self, task_type): + """Should be implemented by subclasses.""" + pass + def get(self, request, _): node_type = request.node_type node_id = request.node_id @@ -183,7 +188,7 @@ def _get_task(self, node_type, node_id, request: comm.TaskRequest): if task.shard.record_indices: res.shard.indices = task.shard.record_indices elif not dataset.completed(): - res.type = elastic_training_pb2.WAIT + res.type = self.get_task_type(TaskType.WAIT) with self._lock: self._task_manager.reset_worker_start_task_time(node_id) return res @@ -412,7 +417,7 @@ def 
_collect_dataset_shard_params(self, metrics: comm.DatasetShardParams): metrics.dataset_size, metrics.storage_type, ) - if metrics.task_type == elastic_training_pb2.TRAINING: + if metrics.task_type == self.get_task_type(TaskType.TRAINING): self._job_metric_collector.collect_training_hyper_params( metrics.num_epochs, metrics.batch_size ) @@ -482,7 +487,7 @@ def _report_task_result(self, request: comm.TaskResult): if ( self._job_metric_collector and task - and task.task_type == elastic_training_pb2.PREDICTION + and task.task_type == self.get_task_type(TaskType.PREDICTION) ): self._collect_runtime_stats() self._check_start_auto_scale_worker() @@ -713,6 +718,9 @@ def __init__( def get_response(self, method): return BaseResponse() + def get_task_type(self, task_type): + return task_type + class GrpcMasterServicer( MasterServicer, elastic_training_pb2_grpc.MasterServicer @@ -749,6 +757,20 @@ def get_response(self, method): else: return elastic_training_pb2.Message() + def get_task_type(self, task_type): + if task_type == TaskType.WAIT: + return elastic_training_pb2.WAIT + elif task_type == TaskType.TRAINING: + return elastic_training_pb2.TRAINING + elif task_type == TaskType.EVALUATION: + return elastic_training_pb2.EVALUATION + elif task_type == TaskType.PREDICTION: + return elastic_training_pb2.PREDICTION + elif task_type == TaskType.TRAIN_END_CALLBACK: + return elastic_training_pb2.TRAIN_END_CALLBACK + else: + return elastic_training_pb2.NONE + class HttpMasterHandler(tornado.web.RequestHandler): def initialize(self, master_servicer: HttpMasterServicer): diff --git a/dlrover/python/tests/test_sharding_client.py b/dlrover/python/tests/test_sharding_client.py index 6cad0dbb9..09d8777aa 100644 --- a/dlrover/python/tests/test_sharding_client.py +++ b/dlrover/python/tests/test_sharding_client.py @@ -28,7 +28,7 @@ class DataShardClientTest(unittest.TestCase): def setUp(self) -> None: self._master, addr = start_local_master() - MasterClient._instance = 
build_master_client(addr, 0.5) + MasterClient._instance = build_master_client(addr, 1) def tearDown(self): self._master.stop() From b343935c536350c3081d71d9939353255ff02c95 Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Fri, 10 Jan 2025 17:50:05 +0800 Subject: [PATCH 18/20] ut fix --- dlrover/python/tests/test_servicer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py index 3abfb74be..5428ee392 100644 --- a/dlrover/python/tests/test_servicer.py +++ b/dlrover/python/tests/test_servicer.py @@ -63,10 +63,14 @@ def setUp(self) -> None: self.server = None def tearDown(self) -> None: + context = Context.singleton_instance() + context.master_service_type = "grpc" if self.server: self.server.stop(grace=None) def test_http_start_and_stop(self): + context = Context.singleton_instance() + context.master_service_type = "http" self.server = create_master_service( TEST_SERVER_PORT, None, From 711fd9ba5b76eefca5ebeb99fd9d05eeffb28278 Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Mon, 13 Jan 2025 10:53:35 +0800 Subject: [PATCH 19/20] optimized http server --- dlrover/python/common/http_server.py | 13 +++++- dlrover/python/common/test.py | 67 ---------------------------- 2 files changed, 11 insertions(+), 69 deletions(-) delete mode 100644 dlrover/python/common/test.py diff --git a/dlrover/python/common/http_server.py b/dlrover/python/common/http_server.py index caa3eec04..60dfccbdf 100644 --- a/dlrover/python/common/http_server.py +++ b/dlrover/python/common/http_server.py @@ -12,6 +12,7 @@ # limitations under the License. 
import abc +import asyncio import threading import time @@ -20,6 +21,14 @@ from dlrover.python.common.log import default_logger as logger +def is_asyncio_loop_running(): + try: + asyncio.get_running_loop() + return True + except RuntimeError: + return False + + class CustomHTTPServer(abc.ABC): """Self designed http server.""" @@ -77,8 +86,8 @@ def start(self): ) server_thread.start() - # wait 3s for sever start - time.sleep(3) + while not self._io_loop or is_asyncio_loop_running(): + time.sleep(0.1) def _start_server(self): try: diff --git a/dlrover/python/common/test.py b/dlrover/python/common/test.py deleted file mode 100644 index 294398b58..000000000 --- a/dlrover/python/common/test.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2024 The DLRover Authors. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import signal -import threading -import time - -import tornado.httpserver -import tornado.ioloop -import tornado.web - - -class MainHandler(tornado.web.RequestHandler): - def get(self): - self.write("Hello, world") - - -def make_app(): - return tornado.web.Application( - [ - (r"/", MainHandler), - ] - ) - - -def start_tornado_server(): - app = make_app() - server = tornado.httpserver.HTTPServer(app) - server.listen(8000) - tornado.ioloop.IOLoop.current().start() - - -def stop_tornado_server(): - tornado.ioloop.IOLoop.current().stop() - - -if __name__ == "__main__": - # 启动 Tornado 服务器的后台线程 - server_thread = threading.Thread(target=start_tornado_server) - server_thread.start() - - # 处理系统信号以优雅地关闭服务器 - def signal_handler(signum, frame): - print("Stopping Tornado server") - stop_tornado_server() - server_thread.join() - print("Tornado server stopped") - - signal.signal(signal.SIGINT, signal_handler) - - # 主线程继续做其他事情 - try: - while True: - print("Main thread is doing other things") - time.sleep(1) - except KeyboardInterrupt: - signal_handler(signal.SIGINT, None) From 10e994ca417d07542fca345e880b7ef7399c79ba Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Tue, 14 Jan 2025 11:38:11 +0800 Subject: [PATCH 20/20] add params support --- docs/deployment/argument.md | 15 ++++++++------- go/elasticjob/pkg/controllers/master/master.go | 3 ++- .../pkg/controllers/master/master_test.go | 3 ++- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/docs/deployment/argument.md b/docs/deployment/argument.md index 18d319df0..e728fdbae 100644 --- a/docs/deployment/argument.md +++ b/docs/deployment/argument.md @@ -6,13 +6,14 @@ when training with DLRover. ## 1. DLRover Master Arguments * For master(dlrover.python.master.main) initiation. User can use annotations(except: job_name, namespace) to express the following arguments. 
-| name | description | mandatory | format | default | options | -|-----------|-----------------------------------------------------------------|----|----------------------|-----|-----------------------------------------------------------------------------------| -| job_name |
The name of the job defined by user. | Yes | string | n/a |
n/a | -| namespace | The name of the Kubernetes namespace where ElasticJob pods will be created. | No | string | default | n/a | -| platform | The name of platform. | No | string | pyk8s | pyk8s, k8s, ray or local | -| pending_timeout | The timeout value of pending. | No | integer(unit: second) | 900 | \>=0 | -| pending_fail_strategy | The fail strategy for pending case. | No | integer | 1 | -1: disabled
0: skip
1: verify necessary parts
2: verify all parts | +| name | description | mandatory | format | default | options | +|-----------------------|-----------------------------------------------------------------------------|----|----------------------|---------|---------------------------------------------------------------------------------| +| job_name |
The name of the job defined by user. | Yes | string | n/a |
n/a | +| namespace | The name of the Kubernetes namespace where ElasticJob pods will be created. | No | string | default | n/a | +| platform | The name of platform. | No | string | pyk8s | pyk8s, k8s, ray or local | +| pending_timeout | The timeout value of pending. | No | integer(unit: second) | 900 | \>=0 | +| pending_fail_strategy | The fail strategy for pending case. | No | integer | 1 | -1: disabled
0: skip
1: verify necessary parts
2: verify all parts | +| service_type | The type of master service. | No | string | grpc | grpc,http | ## 2. Training Arguments diff --git a/go/elasticjob/pkg/controllers/master/master.go b/go/elasticjob/pkg/controllers/master/master.go index 07f5855b0..4828da3d5 100644 --- a/go/elasticjob/pkg/controllers/master/master.go +++ b/go/elasticjob/pkg/controllers/master/master.go @@ -48,6 +48,7 @@ const ( // supported arguments(should be supported in 'dlrover.python.master.args') pendingTimeoutArg = "pending_timeout" pendingFailStrategyArg = "pending_fail_strategy" + serviceType = "service_type" ) // Manager generates a master pod object. @@ -238,7 +239,7 @@ func (m *Manager) StopRunningPods( } func getMasterArguments() []string { - return []string{pendingTimeoutArg, pendingFailStrategyArg} + return []string{pendingTimeoutArg, pendingFailStrategyArg, serviceType} } // NewMasterTemplateToJob sets configurations to the master template of a job. diff --git a/go/elasticjob/pkg/controllers/master/master_test.go b/go/elasticjob/pkg/controllers/master/master_test.go index 360df0ade..c8af64ea2 100644 --- a/go/elasticjob/pkg/controllers/master/master_test.go +++ b/go/elasticjob/pkg/controllers/master/master_test.go @@ -29,7 +29,7 @@ func TestCreateMasterPod(t *testing.T) { ObjectMeta: metav1.ObjectMeta{ Name: "test-ps", Namespace: "dlrover", - Annotations: map[string]string{"pending_timeout": "300"}, + Annotations: map[string]string{"pending_timeout": "300", "service_type": "http"}, Labels: map[string]string{}, }, } @@ -44,6 +44,7 @@ func TestCreateMasterPod(t *testing.T) { assert.True(t, strings.Contains(pod.Spec.Containers[0].Command[2], "--job_name test-ps")) assert.True(t, strings.Contains(pod.Spec.Containers[0].Command[2], "--port 50001")) assert.True(t, strings.Contains(pod.Spec.Containers[0].Command[2], "--pending_timeout 300")) + assert.True(t, strings.Contains(pod.Spec.Containers[0].Command[2], "--service_type http")) } func TestCreateMasterPodWithImage(t *testing.T) {