Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support http communication implement for DLRover Master and Agent. #1429

Open
wants to merge 26 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ runs:
args:
- "/bin/bash"
- "-c"
- " python -m grpc_tools.protoc -I. \
- "sh scripts/ci_install.sh basic && python -m grpc_tools.protoc -I. \
dlrover/proto/*.proto --python_out=. --grpc_python_out=. \
&& export PYTHONPATH=`pwd` \
&& cd examples/tensorflow/criteo_deeprec\
Expand Down
2 changes: 1 addition & 1 deletion .github/actions/dlrover-system-test-deepfm/action.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ runs:
args:
- "/bin/bash"
- "-c"
- " python -m grpc_tools.protoc -I. \
- "sh scripts/ci_install.sh basic && python -m grpc_tools.protoc -I. \
dlrover/proto/*.proto --python_out=. --grpc_python_out=. \
&& pip install deepctr deprecated\
&& export PYTHONPATH=`pwd` \
Expand Down
3 changes: 1 addition & 2 deletions .github/actions/dlrover-system-test-tf2/action.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ runs:
args:
- "/bin/bash"
- "-c"
- "pip install protobuf==3.20 kubernetes grpcio-tools psutil deprecated\
&& python -m grpc_tools.protoc -I. \
- "sh scripts/ci_install.sh basic && python -m grpc_tools.protoc -I. \
dlrover/proto/*.proto --python_out=. --grpc_python_out=. \
&& pip install deepctr \
&& pip install h5py==3.7.0 \
Expand Down
4 changes: 2 additions & 2 deletions dlrover/python/brain/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os

from dlrover.proto import brain_pb2, brain_pb2_grpc
from dlrover.python.common.grpc import build_channel, grpc_server_ready
from dlrover.python.common.comm import build_grpc_channel, grpc_server_ready
from dlrover.python.common.log import default_logger as logger

DATA_STORE = "base_datastore"
Expand Down Expand Up @@ -268,7 +268,7 @@ def build_brain_client():
```
"""
brain_addr = os.getenv(_ENV_BRAIN_ADDR_KEY, "")
channel = build_channel(brain_addr)
channel = build_grpc_channel(brain_addr)
if channel and grpc_server_ready(channel):
return BrainClient(channel)
else:
Expand Down
142 changes: 68 additions & 74 deletions dlrover/python/common/grpc.py → dlrover/python/common/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import pickle
import random
import socket
from contextlib import closing
from dataclasses import dataclass, field
from typing import Dict, List

import grpc

from dlrover.python.common.constants import GRPC, AscendConstants
from dlrover.python.common.constants import GRPC
from dlrover.python.common.log import default_logger as logger
from dlrover.python.common.serialize import JsonSerializable

TIMEOUT_SEC = 5


def build_channel(addr):
def build_grpc_channel(addr):
if not addr_connected(addr):
return None
channel = grpc.insecure_channel(
Expand Down Expand Up @@ -68,74 +66,6 @@
return False


def find_free_port(port=0):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", port))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]


def find_free_port_in_range(start=0, end=65535, random_port=True):
"""Find a free port from a range."""
bind_ports = set()
while True:
if random_port:
port = random.randint(start, end)
else:
port = start + len(bind_ports)
if port in bind_ports:
continue
try:
return find_free_port(port)
except OSError:
logger.warning(f"Socket creation attempt failed with {port}.")
bind_ports.add(port)
if len(bind_ports) == end - start + 1:
break
raise RuntimeError(f"Fail to find a free port in [{start}, {end})")


def find_free_port_in_set(ports):
for port in ports:
try:
return find_free_port(port)
except OSError:
logger.warning(f"Socket creation attempt failed with {port}.")
raise RuntimeError(f"Fail to find a free port in {ports}")


def find_free_port_for_hccl(
start=AscendConstants.HCCL_PORT_START_DEFAULT,
) -> int:
max_port = 65500
cur_start = start
end = start + 10000
if end > max_port:
end = max_port
logger.info(f"Try to find available port for hccl from {start}")
checking_port = 0
while True:
try:
cur_end = cur_start + AscendConstants.NPU_PER_NODE
for port in range(cur_start, cur_end):
checking_port = port
find_free_port(port)
logger.info(f"Find available port start from: {cur_start}")
break
except OSError:
logger.warning(
f"Target port has already been used: {checking_port}."
)
if checking_port > 0:
cur_start = checking_port + 1
else:
cur_start = cur_start + AscendConstants.NPU_PER_NODE
if cur_start > end:
cur_start = 0
break
return cur_start


def grpc_server_ready(channel) -> bool:
try:
grpc.channel_ready_future(channel).result(timeout=TIMEOUT_SEC)
Expand All @@ -144,11 +74,25 @@
return False


def deserialize_message(data: bytes):
def serialize_message(message):
"""The method will create a message instance with the content.
Args:
pickle_data: pickle bytes of a class instance.
"""
data = None
if message:
try:
data = pickle.dumps(message)
except Exception as e:
logger.warning(f"Pickle failed to load {str(data)}", e)
return data

Check warning on line 88 in dlrover/python/common/comm.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/common/comm.py#L82-L88

Added lines #L82 - L88 were not covered by tests


def deserialize_message(data: bytes):
"""The method will create a message instance with the content.
Args:
data: pickle bytes of a class instance.
"""
message = None
if data:
try:
Expand All @@ -163,6 +107,47 @@
return pickle.dumps(self)


@dataclass
class BaseRequest(Message):
node_id: int = -1
node_type: str = ""
data: bytes = b""

def to_json(self):
return {
"node_id": self.node_id,
"node_type": self.node_type,
"data": base64.b64encode(self.data).decode("utf-8"),
}

@staticmethod
def from_json(data):
return BaseRequest(
node_id=data.get("node_id"),
node_type=data.get("node_type"),
data=base64.b64decode(data.get("data")),
)


@dataclass
class BaseResponse(Message):
success: bool = False
data: bytes = b""

def to_json(self):
return {

Check warning on line 138 in dlrover/python/common/comm.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/common/comm.py#L138

Added line #L138 was not covered by tests
"success": self.success,
"data": base64.b64encode(self.data).decode("utf-8"),
}

@staticmethod
def from_json(data):
return BaseResponse(

Check warning on line 145 in dlrover/python/common/comm.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/common/comm.py#L145

Added line #L145 was not covered by tests
success=bool(data.get("success")),
data=base64.b64decode(data.get("data")),
)


@dataclass
class TaskRequest(Message):
dataset_name: str = ""
Expand Down Expand Up @@ -526,3 +511,12 @@
@dataclass
class HeartbeatResponse(Message):
action: DiagnosisAction = field(default_factory=DiagnosisAction)


class TaskType(object):
NONE = "NONE"
TRAINING = "TRAINING"
EVALUATION = "EVALUATION"
PREDICTION = "PREDICTION"
WAIT = "WAIT"
TRAIN_END_CALLBACK = "TRAIN_END_CALLBACK"
9 changes: 9 additions & 0 deletions dlrover/python/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ class PlatformType(object):
LOCAL = "local"


class CommunicationType(object):
COMM_SERVICE_GRPC = "grpc"
COMM_SERVICE_HTTP = "http"


class ElasticJobApi(object):
GROUP = "elastic.iml.github.io"
VERION = "v1alpha1"
Expand Down Expand Up @@ -248,6 +253,7 @@ class TrainingLoopStatus(object):
class NodeEnv(object):
RELAUNCHED_POD = "RELAUNCHED_POD"
DLROVER_MASTER_ADDR = "DLROVER_MASTER_ADDR"
DLROVER_MASTER_SERVICE_TYPE = "DLROVER_MASTER_SERVICE_TYPE"
GRPC_ENABLE_FORK = "GRPC_ENABLE_FORK_SUPPORT"
GRPC_POLL_STRATEGY = "GRPC_POLL_STRATEGY"
POD_NAME = "POD_NAME"
Expand Down Expand Up @@ -359,6 +365,9 @@ class JobConstant(object):
INSUFFICIENT_NODE_TIMEOUT_DEFAULT_MAX = 3600
PENDING_NODE_TIMEOUT_DEFAULT_MIN = 600

# timeout 60s
MASTER_CLIENT_DEFAULT_TIMEOUT = 60

# grpc timeout 60s
MASTER_CLIENT_GRPC_DEFAULT_TIMEOUT = 60

Expand Down
13 changes: 9 additions & 4 deletions dlrover/python/common/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@

import os

from dlrover.python.common import grpc
from dlrover.python.common.constants import UserEnv
from dlrover.python.common.constants import CommunicationType, UserEnv
from dlrover.python.common.log import default_logger as logger
from dlrover.python.common.singleton import Singleton
from dlrover.python.util.common_util import (
find_free_port_in_range,
find_free_port_in_set,
)


class ConfigKeys(object):
Expand All @@ -38,6 +41,7 @@ class ConfigKeys(object):


class DefaultValues(object):
SERVICE_TYPE = CommunicationType.COMM_SERVICE_GRPC
TRAIN_SPEED_RECORD_NUM = 50
SEC_TO_START_AUTOSCALE_WORKER = 90
STEP_TO_ADJUST_WORKER = 200
Expand All @@ -61,6 +65,7 @@ class DefaultValues(object):

class Context(Singleton):
def __init__(self):
self.master_service_type = DefaultValues.SERVICE_TYPE
self.train_speed_record_num = DefaultValues.TRAIN_SPEED_RECORD_NUM
self.seconds_to_autoscale_worker = (
DefaultValues.SEC_TO_START_AUTOSCALE_WORKER
Expand Down Expand Up @@ -173,13 +178,13 @@ def config_master_port(self, port=0):
for port in host_ports_env.split(","):
ports.append(int(port))
try:
self.master_port = grpc.find_free_port_in_set(ports)
self.master_port = find_free_port_in_set(ports)
except RuntimeError as e:
logger.warning(e)
elif port > 0:
self.master_port = port
if self.master_port is None:
self.master_port = grpc.find_free_port_in_range(20000, 30000)
self.master_port = find_free_port_in_range(20000, 30000)

def get_param_value_from_brain(self, key_name, default_value, dtype=float):
"""TODO: Get the configured value from Brain service."""
Expand Down
Loading
Loading