Skip to content

Commit

Permalink
Move coroutine check on boot time
Browse files Browse the repository at this point in the history
  • Loading branch information
Ananto30 committed Jun 23, 2024
1 parent 9533947 commit a036245
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 25 deletions.
2 changes: 1 addition & 1 deletion tests/unit/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_register_rpc(self):
def add(msg: Tuple[int, int]) -> int:
return msg[0] + msg[1]

self.assertEqual(server._rpc_router, {"add": add})
self.assertEqual(server._rpc_router, {"add": (add, False)})
self.assertEqual(server._rpc_input_type_map, {"add": Tuple[int, int]})
self.assertEqual(server._rpc_return_type_map, {"add": int})

Expand Down
8 changes: 5 additions & 3 deletions zero/client_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import os
import signal
import sys
from asyncio import iscoroutinefunction
from functools import partial
from multiprocessing.pool import Pool
from typing import Callable, Dict, Optional
from typing import Callable, Dict, Optional, Tuple

import zmq.utils.win32

Expand Down Expand Up @@ -56,7 +57,8 @@ def __init__(
self._encoder = encoder or get_encoder(config.ENCODER)

# Stores rpc functions against their names
self._rpc_router: Dict[str, Callable] = {}
# and if they are coroutines
self._rpc_router: Dict[str, Tuple[Callable, bool]] = {}

# Stores rpc functions `msg` types
self._rpc_input_type_map: Dict[str, Optional[type]] = {}
Expand Down Expand Up @@ -88,7 +90,7 @@ def register_rpc(self, func: Callable):
func
)

self._rpc_router[func.__name__] = func
self._rpc_router[func.__name__] = (func, iscoroutinefunction(func))
return func

def run(self, workers: int = os.cpu_count() or 1):
Expand Down
8 changes: 2 additions & 6 deletions zero/client_server/worker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import inspect
import logging
import time
from typing import Optional
Expand Down Expand Up @@ -74,14 +73,11 @@ def handle_msg(self, rpc, msg):
logging.error("Function `%s` not found!", rpc)
return {"__zerror__function_not_found": f"Function `{rpc}` not found!"}

func = self._rpc_router[rpc]
func, is_coro = self._rpc_router[rpc]
ret = None

try:
# TODO: is this a bottleneck
if inspect.iscoroutinefunction(func):
# this is blocking
# ret = self._loop.run_until_complete(func(msg) if msg else func())
if is_coro:
ret = async_to_sync(func)(msg) if msg else async_to_sync(func)()
else:
ret = func(msg) if msg else func()
Expand Down
6 changes: 4 additions & 2 deletions zero/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def get_return_type_str(self, func_name: str): # pragma: no cover
return self._rpc_return_type_map[func_name].__name__

def get_function_str(self, func_name: str):
func_lines = inspect.getsourcelines(self._rpc_router[func_name])[0]
func = self._rpc_router[func_name][0]
func_lines = inspect.getsourcelines(func)[0]
def_line = [line for line in func_lines if "def" in line][0]

# put self after the first (
Expand All @@ -71,7 +72,8 @@ def get_function_str(self, func_name: str):
return def_line.replace("\n", "")

def get_function_input_param_name(self, func_name: str):
func_lines = inspect.getsourcelines(self._rpc_router[func_name])[0]
func = self._rpc_router[func_name][0]
func_lines = inspect.getsourcelines(func)[0]
def_line = [line for line in func_lines if "def" in line][0]
params = def_line.split("(")[1].split(")")[0]
return params.split(":")[0].strip()
Expand Down
2 changes: 1 addition & 1 deletion zero/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def start_log_poller(cls, ipc, port):
logging.info(log)
except KeyboardInterrupt:
print("Caught KeyboardInterrupt, terminating async logger")
except Exception as exc:
except Exception as exc: # pylint: disable=broad-except
print(exc)
finally:
log_listener.close()
Expand Down
4 changes: 2 additions & 2 deletions zero/pubsub/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ def __init__(self, host: str, port: int, use_async: bool = True):

def _init_sync_socket(self):
ctx = zmq.Context()
self.__socket: zmq.Socket = ctx.socket(zmq.PUB)
self.__socket = ctx.socket(zmq.PUB)
self._set_socket_opt()
self.__socket.connect(f"tcp://{self.__host}:{self.__port}")

def _init_async_socket(self):
ctx = zmq.asyncio.Context()
self.__socket: zmq.Socket = ctx.socket(zmq.PUB)
self.__socket = ctx.socket(zmq.PUB)
self._set_socket_opt()
self.__socket.connect(f"tcp://{self.__host}:{self.__port}")

Expand Down
11 changes: 6 additions & 5 deletions zero/pubsub/subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ def run(self):
prcs.join()

def _create_zmq_device(self):
gateway: zmq.Socket = None
backend: zmq.Socket = None

gateway = None
backend = None
try:
gateway = self._ctx.socket(zmq.SUB)
gateway.bind(f"tcp://*:{self._port}")
Expand All @@ -70,8 +69,10 @@ def _create_zmq_device(self):
logging.error(exc)
logging.error("bringing down zmq device")
finally:
gateway.close()
backend.close()
if gateway is not None:
gateway.close()
if backend is not None:
backend.close()
self._ctx.term()


Expand Down
7 changes: 4 additions & 3 deletions zero/zero_mq/queue_device/broker.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import logging

import zmq
from zmq import proxy


class ZeroMQBroker:
def __init__(self):
self.context = zmq.Context.instance()

self.gateway: zmq.Socket = self.context.socket(zmq.ROUTER)
self.backend: zmq.Socket = self.context.socket(zmq.DEALER)
self.gateway = self.context.socket(zmq.ROUTER)
self.backend = self.context.socket(zmq.DEALER)

def listen(self, address: str, channel: str) -> None:
self.gateway.bind(f"{address}")
self.backend.bind(f"{channel}")
logging.info("Starting server at %s", address)

zmq.proxy(self.gateway, self.backend)
proxy(self.gateway, self.backend)

def close(self) -> None:
self.gateway.close()
Expand Down
4 changes: 2 additions & 2 deletions zero/zero_mq/queue_device/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, default_timeout):
self._default_timeout = default_timeout
self._context = zmq.Context.instance()

self.socket: zmq.Socket = self._context.socket(zmq.DEALER)
self.socket = self._context.socket(zmq.DEALER)
self.socket.setsockopt(zmq.LINGER, 0) # dont buffer messages
self.socket.setsockopt(zmq.RCVTIMEO, default_timeout)
self.socket.setsockopt(zmq.SNDTIMEO, default_timeout)
Expand Down Expand Up @@ -109,7 +109,7 @@ async def poll(self, timeout: int) -> bool:

async def recv(self) -> bytes:
try:
return await self.socket.recv() # type: ignore
return await self.socket.recv()
except zmqerr.Again as exc:
raise ConnectionException(
f"Connection error for recv at {self._address}"
Expand Down

0 comments on commit a036245

Please sign in to comment.