Skip to content

Commit

Permalink
network and delta sync enhancements
Browse files Browse the repository at this point in the history
  • Loading branch information
dwsutherland committed Oct 18, 2019
1 parent 7a8e501 commit 77607df
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 70 deletions.
50 changes: 26 additions & 24 deletions bin/cylc-subscribe
Original file line number Diff line number Diff line change
Expand Up @@ -25,49 +25,44 @@ Invoke suite subscriber to receive published workflow output.
import asyncio
import json
import sys
if '--use-ssh' in sys.argv[1:]:
sys.argv.remove('--use-ssh')
from cylc.flow.remote import remrun
if remrun():
sys.exit(0)
import time

from google.protobuf.json_format import MessageToDict

from cylc.flow.option_parsers import CylcOptionParser as COP
from cylc.flow.network.scan import get_scan_items_from_fs, re_compile_filters
from cylc.flow.network.subscriber import WorkflowSubscriber
from cylc.flow.network.subscriber import WorkflowSubscriber, process_delta_msg
from cylc.flow.terminal import cli_function
from cylc.flow.ws_data_mgr import DELTAS_MAP

if '--use-ssh' in sys.argv[1:]:
sys.argv.remove('--use-ssh')
from cylc.flow.remote import remrun
if remrun():
sys.exit(0)

def print_workflow(topic, msg):
msg_type = topic.decode('utf-8')
try:
data = DELTAS_MAP[msg_type]()
except KeyError:
return
data.ParseFromString(msg)
print('Received: ', msg_type)

def print_message(_, data):
    """Write a protobuf message to stdout as indented JSON.

    Args:
        _: Unused positional argument; accepted and ignored.
        data: Protobuf message, converted to a dictionary with
            ``MessageToDict`` before being dumped as JSON.

    """
    as_dict = MessageToDict(data)
    rendered = json.dumps(as_dict, indent=4)
    # A single write call, with the trailing newline appended here.
    sys.stdout.write(rendered + '\n')


def get_option_parser():
"""Augment options parser to current context."""
parser = COP(__doc__, comms=True, argdoc=[
('REG', 'Suite name'),
('[TOPIC]', 'Subscription topic to receive')])
('[TOPICS]', 'Subscription topics to receive')])

return parser


@cli_function(get_option_parser)
def main(_, options, suite, topic=None):
def main(_, options, suite, topics=None):
host = None
port = None
cre_owner, cre_name = re_compile_filters(None, ['.*'])
while True:
for s_reg, s_host, s_port, s_pub_port in get_scan_items_from_fs(
for s_reg, s_host, _, s_pub_port in get_scan_items_from_fs(
cre_owner, cre_name):
if s_reg == suite:
host = s_host
Expand All @@ -78,13 +73,20 @@ def main(_, options, suite, topic=None):
time.sleep(5)

print(f'Connecting to tcp://{host}:{port}')
if topic is None:
topic = b'workflow'
topic_set = set()
if topics is None:
topic_set.add(b'workflow')
else:
topic = topic.encode('utf-8')
subscriber = WorkflowSubscriber(host, port, [topic])

asyncio.ensure_future(subscriber.subscribe([topic], print_workflow))
for topic in topics.split(','):
topic_set.add(topic.encode('utf-8'))
subscriber = WorkflowSubscriber(host, port, topics=topic_set)

asyncio.ensure_future(
subscriber.subscribe(
process_delta_msg,
func=print_message
)
)

# run Python run
try:
Expand Down
20 changes: 13 additions & 7 deletions cylc/flow/network/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@
import sys
from functools import partial
from typing import Union

from shutil import which
import jose.exceptions
import zmq
import zmq.asyncio

from shutil import which

import cylc.flow.flags
from cylc.flow import LOG
from cylc.flow.exceptions import (
ClientError,
Expand All @@ -44,11 +42,10 @@
load_contact_file
)

# we should only have one ZMQ context per-process
CONTEXT = zmq.asyncio.Context()


class ZMQClient(object):
class ZMQClient:
"""Initiate the REQ part of a ZMQ REQ-REP pair.
This class contains the logic for the ZMQ message interface and client -
Expand Down Expand Up @@ -94,10 +91,17 @@ class ZMQClient(object):
DEFAULT_TIMEOUT = 5. # 5 seconds

def __init__(self, host, port, encode_method, decode_method, secret_method,
timeout=None, timeout_handler=None, header=None):
context=None, timeout=None,
timeout_handler=None, header=None):
self.encode = encode_method
self.decode = decode_method
self.secret = secret_method
# we should only have one ZMQ context per-process,
# so only instantiate a new context if none was passed in
if context is None:
self.context = zmq.asyncio.Context()
else:
self.context = context
if timeout is None:
timeout = self.DEFAULT_TIMEOUT
else:
Expand All @@ -106,7 +110,7 @@ def __init__(self, host, port, encode_method, decode_method, secret_method,
self.timeout_handler = timeout_handler

# open the ZMQ socket
self.socket = CONTEXT.socket(zmq.REQ)
self.socket = self.context.socket(zmq.REQ)
self.socket.connect('tcp://%s:%d' % (host, port))
# if there is no server don't keep the client hanging around
self.socket.setsockopt(zmq.LINGER, int(self.DEFAULT_TIMEOUT))
Expand Down Expand Up @@ -218,6 +222,7 @@ def __init__(
owner: str = None,
host: str = None,
port: Union[int, str] = None,
context=None,
timeout: Union[float, str] = None
):
"""Initiate a client to the suite runtime API.
Expand Down Expand Up @@ -256,6 +261,7 @@ def __init__(
encode_method=encrypt,
decode_method=decrypt,
secret_method=partial(get_secret, suite),
context=context,
timeout=timeout,
header=self.get_header(),
timeout_handler=partial(self._timeout_handler, suite, host, port)
Expand Down
20 changes: 14 additions & 6 deletions cylc/flow/network/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def serialize_data(data, serializer):
"""Serialize by specified method."""
if callable(serializer):
return serializer(data)
elif isinstance(serializer, str):
if isinstance(serializer, str):
return getattr(data, serializer)()
return data

Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self, context=None):
def start(self, min_port, max_port):
"""Start the ZeroMQ publisher.
Will use a port range provided to select random ports.
Sockets created in alternate thread, port range forwarded.
Args:
min_port (int): minimum socket port number
Expand All @@ -82,13 +82,20 @@ def start(self, min_port, max_port):
# Contexts are thread safe, but sockets are not, so if there are
# multiple sockets each needs to be created on its own thread.
self.thread = Thread(
target=self._create_socket,
target=self._start_publisher,
args=(min_port, max_port)
)
self.thread.start()

def _create_socket(self, min_port, max_port):
"""Create ZeroMQ Publish socket."""
def _start_publisher(self, min_port, max_port):
"""Create ZeroMQ Publish socket.
Will use a port range provided to select random ports.
Args:
min_port (int): minimum socket port number
max_port (int): maximum socket port number
"""
self.socket = self.context.socket(zmq.PUB)
# this limit on messages in queue is more than enough,
# as messages correspond to scheduler loops (*messages/loop):
Expand All @@ -105,8 +112,9 @@ def _create_socket(self, min_port, max_port):
self.socket.close()
raise CylcError(
'could not start Cylc ZMQ publisher: %s' % str(exc))

try:
asyncio.get_running_loop()
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
Expand Down
78 changes: 58 additions & 20 deletions cylc/flow/network/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from threading import Thread
import asyncio
from graphql.execution.executors.asyncio import AsyncioExecutor

import zmq

from cylc.flow import LOG
Expand All @@ -36,16 +35,18 @@
from cylc.flow.suite_status import (
KEY_META, KEY_NAME, KEY_OWNER, KEY_STATES,
KEY_TASKS_BY_STATE, KEY_UPDATE_TIME, KEY_VERSION)
from cylc.flow.ws_data_mgr import DELTAS_MAP
from cylc.flow.ws_messages_pb2 import PbEntireWorkflow
from cylc.flow import __version__ as CYLC_VERSION

# maps server methods to the protobuf message (for client/UIS import)
PB_METHOD_MAP = {
'pb_entire_workflow': PbEntireWorkflow
'pb_entire_workflow': PbEntireWorkflow,
'pb_data_elements': DELTAS_MAP
}


class ZMQServer(object):
class ZMQServer:
"""Initiate the REP part of a ZMQ REQ-REP pair.
This class contains the logic for the ZMQ message interface and client -
Expand Down Expand Up @@ -83,9 +84,13 @@ class ZMQServer(object):
"""

def __init__(self, encode_method, decode_method, secret_method):
def __init__(self, encode_method, decode_method, secret_method,
context=None):
if context is None:
self.context = zmq.Context()
else:
self.context = context
self.port = None
self.context = zmq.Context()
self.socket = None
self.endpoints = None
self.thread = None
Expand All @@ -97,6 +102,25 @@ def __init__(self, encode_method, decode_method, secret_method):
def start(self, min_port, max_port):
"""Start the server running.
Port range passed to socket creation in server thread.
Args:
min_port (int): minimum socket port number
max_port (int): maximum socket port number
"""
# TODO: this in asyncio?
# Requires the Cylc main loop in asyncio first
# And use of concurrent.futures.ThreadPoolExecutor?
self.thread = Thread(
target=self._start_server,
args=(min_port, max_port)
)
self.thread.start()

def _start_server(self, min_port, max_port):
"""Create the thread async loop, and run listener.
Will use a port range provided to select random ports.
Args:
Expand All @@ -119,12 +143,9 @@ def start(self, min_port, max_port):
raise CylcError('could not start Cylc ZMQ server: %s' % str(exc))

# start accepting requests
self.register_endpoints()

self.queue = Queue()
# TODO: this in asyncio? Requires the Cylc main loop in asyncio first
self.thread = Thread(target=self._listener)
self.thread.start()
self.register_endpoints()
self._listener()

def stop(self):
"""Finish serving the current request then stop the server."""
Expand All @@ -147,9 +168,9 @@ def _listener(self):
if self.queue.qsize():
command = self.queue.get()
if command == 'STOP':
self.stop()
break
else:
raise ValueError('Unknown command "%s"' % command)
raise ValueError('Unknown command "%s"' % command)

try:
# wait RECV_TIMEOUT for a message
Expand All @@ -166,7 +187,7 @@ def _listener(self):
except Exception as exc: # purposefully catch generic exception
# failed to decode message, possibly resulting from failed
# authentication
LOG.exception(f'failed to decode message: {str(exc)}')
LOG.exception('failed to decode message: "%s"' % str(exc))
else:
# success case - serve the request
res = self._receiver(message)
Expand Down Expand Up @@ -254,12 +275,13 @@ class SuiteRuntimeServer(ZMQServer):

API = 4 # cylc API version

def __init__(self, schd):
def __init__(self, schd, context=None):
ZMQServer.__init__(
self,
encrypt,
decrypt,
partial(get_secret, schd.suite)
partial(get_secret, schd.suite),
context=context
)
self.schd = schd
self.public_priv = None # update in get_public_priv()
Expand Down Expand Up @@ -418,7 +440,7 @@ def dry_run_tasks(self, task_globs, check_syntax=True):
"""
self.schd.command_queue.put(('dry_run_tasks', (task_globs,),
{'check_syntax': check_syntax}))
{'check_syntax': check_syntax}))
return (True, 'Command queued')

@authorise(Priv.CONTROL)
Expand Down Expand Up @@ -1252,11 +1274,27 @@ def pb_entire_workflow(self):
Returns:
bytes
Protobuf encoded message
Serialised Protobuf message
"""
pb_msg = self.schd.ws_data_mgr.get_entire_workflow()
# Use google.protobuf.json_format.MessageToJson
# to send response through JWT authorisation
# (Request still requires/uses JWT anyway).
return pb_msg.SerializeToString()

@authorise(Priv.READ)
@ZMQServer.expose
def pb_data_elements(self, element_type):
    """Return the requested data elements as a serialised delta message.

    Args:
        element_type (str):
            Key from DELTAS_MAP dictionary.

    Returns:
        bytes
            Serialised Protobuf message, when the data manager returns
            a message for ``element_type``.
        str
            A 'No elements of type "..."' notice, when the data manager
            returns None.

    """
    elements = self.schd.ws_data_mgr.get_data_elements(element_type)
    # Guard clause: nothing to serialise, so reply with a plain notice.
    if elements is None:
        return f'No elements of type "{element_type}"'
    return elements.SerializeToString()
Loading

0 comments on commit 77607df

Please sign in to comment.