Skip to content

Commit

Permalink
add timeout to ZmqClientTransport
Browse files Browse the repository at this point in the history
  • Loading branch information
lnoor committed Jan 26, 2022
1 parent 72a686b commit c7b3ed4
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
3 changes: 3 additions & 0 deletions tinyrpc/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,6 @@ class InvalidParamsError(RPCError, ABC):

class ServerError(RPCError, ABC):
"""An internal error in the RPC system occurred."""

class TimeoutError(Exception):
"""No reply received within the timeout period."""
27 changes: 21 additions & 6 deletions tinyrpc/transports/zmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

from __future__ import absolute_import # needed for zmq import

from typing import Tuple, Any
from typing import Tuple, Any, Dict

import zmq

from . import ServerTransport, ClientTransport
from .. import exc


class ZmqServerTransport(ServerTransport):
Expand Down Expand Up @@ -50,23 +51,36 @@ class ZmqClientTransport(ClientTransport):
:param socket: A :py:const:`zmq.REQ` socket instance, connected to the
server socket.
:param timeout: An optional integer. When set it defines the time period
in milliseconds to wait for a reply.
It will generate a :py:class:`exc.TimeoutError` exception
if no reply was received in time.
"""

def __init__(self, socket: zmq.Socket) -> None:
def __init__(self, socket: zmq.Socket, timeout: int = None) -> None:
self.socket = socket
self.timeout = timeout

def send_message(self, message: bytes, expect_reply: bool = True) -> bytes:
self.socket.send(message)

# zmq contains a state machine preventing a new request
# until the previous one is answered, so always receive
reply = self.socket.recv()

if self.timeout is None:
reply = self.socket.recv()
else:
poller = zmq.Poller()
poller.register(self.socket, zmq.POLLIN)
ready = dict(poller.poll(self.timeout))
if ready.get(self.socket) == zmq.POLLIN:
reply = self.socket.recv()
else:
raise exc.TimeoutError()
if expect_reply:
return reply

@classmethod
def create(cls, zmq_context: zmq.Context, endpoint: str) -> 'ZmqClientTransport':
def create(cls, zmq_context: zmq.Context, endpoint: str, timeout: int = None) -> 'ZmqClientTransport':
"""Create new client transport.
Instead of creating the socket yourself, you can call this function and
Expand All @@ -77,7 +91,8 @@ def create(cls, zmq_context: zmq.Context, endpoint: str) -> 'ZmqClientTransport'
:param zmq_context: A 0mq context.
:param endpoint: The endpoint the server is bound to.
:param timeout: Optional period in milliseconds to wait for reply
"""
socket = zmq_context.socket(zmq.REQ)
socket.connect(endpoint)
return cls(socket)
return cls(socket, timeout)

0 comments on commit c7b3ed4

Please sign in to comment.