diff --git a/tests/unit/test_server.py b/tests/unit/test_server.py index fed6de4..372862a 100644 --- a/tests/unit/test_server.py +++ b/tests/unit/test_server.py @@ -166,7 +166,7 @@ def test_register_rpc(self): def add(msg: Tuple[int, int]) -> int: return msg[0] + msg[1] - self.assertEqual(server._rpc_router, {"add": add}) + self.assertEqual(server._rpc_router, {"add": (add, False)}) self.assertEqual(server._rpc_input_type_map, {"add": Tuple[int, int]}) self.assertEqual(server._rpc_return_type_map, {"add": int}) diff --git a/zero/client_server/server.py b/zero/client_server/server.py index 61dbfa4..37677c6 100644 --- a/zero/client_server/server.py +++ b/zero/client_server/server.py @@ -2,9 +2,10 @@ import os import signal import sys +from asyncio import iscoroutinefunction from functools import partial from multiprocessing.pool import Pool -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Optional, Tuple import zmq.utils.win32 @@ -56,7 +57,8 @@ def __init__( self._encoder = encoder or get_encoder(config.ENCODER) # Stores rpc functions against their names - self._rpc_router: Dict[str, Callable] = {} + # and if they are coroutines + self._rpc_router: Dict[str, Tuple[Callable, bool]] = {} # Stores rpc functions `msg` types self._rpc_input_type_map: Dict[str, Optional[type]] = {} @@ -88,7 +90,7 @@ def register_rpc(self, func: Callable): func ) - self._rpc_router[func.__name__] = func + self._rpc_router[func.__name__] = (func, iscoroutinefunction(func)) return func def run(self, workers: int = os.cpu_count() or 1): diff --git a/zero/client_server/worker.py b/zero/client_server/worker.py index e8d8bf2..feec284 100644 --- a/zero/client_server/worker.py +++ b/zero/client_server/worker.py @@ -1,5 +1,4 @@ import asyncio -import inspect import logging import time from typing import Optional @@ -74,14 +73,11 @@ def handle_msg(self, rpc, msg): logging.error("Function `%s` not found!", rpc) return {"__zerror__function_not_found": f"Function `{rpc}` not found!"} - func = self._rpc_router[rpc] + func, is_coro = self._rpc_router[rpc] ret = None try: - # TODO: is this a bottleneck - if inspect.iscoroutinefunction(func): - # this is blocking - # ret = self._loop.run_until_complete(func(msg) if msg else func()) + if is_coro: ret = async_to_sync(func)(msg) if msg else async_to_sync(func)() else: ret = func(msg) if msg else func() diff --git a/zero/codegen/codegen.py b/zero/codegen/codegen.py index 2d44866..9ffd6f4 100644 --- a/zero/codegen/codegen.py +++ b/zero/codegen/codegen.py @@ -56,7 +56,8 @@ def get_return_type_str(self, func_name: str): # pragma: no cover return self._rpc_return_type_map[func_name].__name__ def get_function_str(self, func_name: str): - func_lines = inspect.getsourcelines(self._rpc_router[func_name])[0] + func = self._rpc_router[func_name][0] + func_lines = inspect.getsourcelines(func)[0] def_line = [line for line in func_lines if "def" in line][0] # put self after the first ( @@ -71,7 +72,8 @@ def get_function_str(self, func_name: str): return def_line.replace("\n", "") def get_function_input_param_name(self, func_name: str): - func_lines = inspect.getsourcelines(self._rpc_router[func_name])[0] + func = self._rpc_router[func_name][0] + func_lines = inspect.getsourcelines(func)[0] def_line = [line for line in func_lines if "def" in line][0] params = def_line.split("(")[1].split(")")[0] return params.split(":")[0].strip() diff --git a/zero/logger.py b/zero/logger.py index 990d812..32e03e1 100644 --- a/zero/logger.py +++ b/zero/logger.py @@ -63,7 +63,7 @@ def start_log_poller(cls, ipc, port): logging.info(log) except KeyboardInterrupt: print("Caught KeyboardInterrupt, terminating async logger") - except Exception as exc: + except Exception as exc: # pylint: disable=broad-except print(exc) finally: log_listener.close() diff --git a/zero/pubsub/publisher.py b/zero/pubsub/publisher.py index b8f3e9b..6a42141 100644 --- a/zero/pubsub/publisher.py +++ b/zero/pubsub/publisher.py @@ -17,13 +17,13 @@ def __init__(self, host: str, port: int, use_async: bool = True): def _init_sync_socket(self): ctx = zmq.Context() - self.__socket: zmq.Socket = ctx.socket(zmq.PUB) + self.__socket = ctx.socket(zmq.PUB) self._set_socket_opt() self.__socket.connect(f"tcp://{self.__host}:{self.__port}") def _init_async_socket(self): ctx = zmq.asyncio.Context() - self.__socket: zmq.Socket = ctx.socket(zmq.PUB) + self.__socket = ctx.socket(zmq.PUB) self._set_socket_opt() self.__socket.connect(f"tcp://{self.__host}:{self.__port}") diff --git a/zero/pubsub/subscriber.py b/zero/pubsub/subscriber.py index 5ada7bd..a6a3761 100644 --- a/zero/pubsub/subscriber.py +++ b/zero/pubsub/subscriber.py @@ -48,9 +48,8 @@ def run(self): prcs.join() def _create_zmq_device(self): - gateway: zmq.Socket = None - backend: zmq.Socket = None - + gateway = None + backend = None try: gateway = self._ctx.socket(zmq.SUB) gateway.bind(f"tcp://*:{self._port}") @@ -70,8 +69,10 @@ def _create_zmq_device(self): logging.error(exc) logging.error("bringing down zmq device") finally: - gateway.close() - backend.close() + if gateway is not None: + gateway.close() + if backend is not None: + backend.close() self._ctx.term() diff --git a/zero/zero_mq/queue_device/broker.py b/zero/zero_mq/queue_device/broker.py index f4f33c6..7d5645b 100644 --- a/zero/zero_mq/queue_device/broker.py +++ b/zero/zero_mq/queue_device/broker.py @@ -1,21 +1,22 @@ import logging import zmq +from zmq import proxy class ZeroMQBroker: def __init__(self): self.context = zmq.Context.instance() - self.gateway: zmq.Socket = self.context.socket(zmq.ROUTER) - self.backend: zmq.Socket = self.context.socket(zmq.DEALER) + self.gateway = self.context.socket(zmq.ROUTER) + self.backend = self.context.socket(zmq.DEALER) def listen(self, address: str, channel: str) -> None: self.gateway.bind(f"{address}") self.backend.bind(f"{channel}") logging.info("Starting server at %s", address) - zmq.proxy(self.gateway, self.backend) + proxy(self.gateway, self.backend) def close(self) -> None: self.gateway.close() diff --git a/zero/zero_mq/queue_device/client.py b/zero/zero_mq/queue_device/client.py index 1fe70ac..c768651 100644 --- a/zero/zero_mq/queue_device/client.py +++ b/zero/zero_mq/queue_device/client.py @@ -15,7 +15,7 @@ def __init__(self, default_timeout): self._default_timeout = default_timeout self._context = zmq.Context.instance() - self.socket: zmq.Socket = self._context.socket(zmq.DEALER) + self.socket = self._context.socket(zmq.DEALER) self.socket.setsockopt(zmq.LINGER, 0) # dont buffer messages self.socket.setsockopt(zmq.RCVTIMEO, default_timeout) self.socket.setsockopt(zmq.SNDTIMEO, default_timeout) @@ -109,7 +109,7 @@ async def poll(self, timeout: int) -> bool: async def recv(self) -> bytes: try: - return await self.socket.recv() # type: ignore + return await self.socket.recv() except zmqerr.Again as exc: raise ConnectionException( f"Connection error for recv at {self._address}"