Skip to content

Commit

Permalink
Merge pull request #319 from saltstack/issue/3003.5/61865
Browse files Browse the repository at this point in the history
[3003.5] Fix bug in tcp transport
  • Loading branch information
garethgreenaway authored May 24, 2022
2 parents 3d53101 + 40478bc commit 88e9b86
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 33 deletions.
1 change: 1 addition & 0 deletions changelog/61865.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug in tcp transport
4 changes: 2 additions & 2 deletions salt/transport/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1651,7 +1651,6 @@ def handle_stream(self, stream, address):
@salt.ext.tornado.gen.coroutine
def publish_payload(self, package, _):
package = self.pack_publish(package)
log.debug("TCP PubServer sending payload: %s", package)
payload = salt.transport.frame.frame_msg(package["payload"])

to_remove = []
Expand Down Expand Up @@ -1725,7 +1724,8 @@ def _publish_daemon(self, **kwargs):

# Check if io_loop was set outside
if self.io_loop is None:
self.io_loop = salt.ext.tornado.ioloop.IOLoop.current()
self.io_loop = salt.ext.tornado.ioloop.IOLoop()
self.io_loop.make_current()

# Spin up the publisher
pub_server = PubServer(
Expand Down
153 changes: 122 additions & 31 deletions tests/pytests/functional/transport/zeromq/test_pub_server_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import multiprocessing
import signal
import socket
import time
from concurrent.futures.thread import ThreadPoolExecutor

Expand All @@ -14,7 +15,9 @@
import salt.master
import salt.transport.client
import salt.transport.server
import salt.transport.tcp
import salt.transport.zeromq
import salt.utils.msgpack
import salt.utils.platform
import salt.utils.process
import salt.utils.stringutils
Expand All @@ -25,13 +28,21 @@
log = logging.getLogger(__name__)


class RecvError(Exception):
"""
Raised by the Collector's _recv method when there is a problem
getting publishes from to the publisher.
"""


class Collector(salt.utils.process.SignalHandlingProcess):
def __init__(
self, minion_config, pub_uri, aes_key, timeout=30, zmq_filtering=False
self, minion_config, interface, port, aes_key, timeout=300, zmq_filtering=False
):
super().__init__()
self.minion_config = minion_config
self.pub_uri = pub_uri
self.interface = interface
self.port = port
self.aes_key = aes_key
self.timeout = timeout
self.hard_timeout = time.time() + timeout + 30
Expand All @@ -41,6 +52,16 @@ def __init__(
self.stopped = multiprocessing.Event()
self.started = multiprocessing.Event()
self.running = multiprocessing.Event()
if salt.utils.msgpack.version >= (0, 5, 2):
# Under Py2 we still want raw to be set to True
msgpack_kwargs = {"raw": False}
else:
msgpack_kwargs = {"encoding": "utf-8"}
self.unpacker = salt.utils.msgpack.Unpacker(**msgpack_kwargs)

@property
def transport(self):
return self.minion_config["transport"]

def _rotate_secrets(self, now=None):
salt.master.SMaster.secrets["aes"] = {
Expand All @@ -57,47 +78,104 @@ def _rotate_secrets(self, now=None):
"rotate_master_key": self._rotate_secrets,
}

def run(self):
"""
Gather results until then number of seconds specified by timeout passes
without receiving a message
"""
ctx = zmq.Context()
sock = ctx.socket(zmq.SUB)
sock.setsockopt(zmq.LINGER, -1)
sock.setsockopt(zmq.SUBSCRIBE, b"")
sock.connect(self.pub_uri)
def _setup_listener(self):
if self.transport == "zeromq":
ctx = zmq.Context()
self.sock = ctx.socket(zmq.SUB)
self.sock.setsockopt(zmq.LINGER, -1)
self.sock.setsockopt(zmq.SUBSCRIBE, b"")
pub_uri = "tcp://{}:{}".format(self.interface, self.port)
self.sock.connect(pub_uri)
else:
end = time.time() + 300
while True:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.connect((self.interface, self.port))
except ConnectionRefusedError:
if time.time() >= end:
raise
time.sleep(1)
else:
break
self.sock = salt.ext.tornado.iostream.IOStream(sock)

@salt.ext.tornado.gen.coroutine
def _recv(self):
exc = None
if self.transport == "zeromq":
# test_zeromq_filtering requires catching the
# SaltDeserializationError in order to pass.
try:
payload = self.sock.recv(zmq.NOBLOCK)
serial_payload = salt.payload.Serial({}).loads(payload)
raise salt.ext.tornado.gen.Return(serial_payload)
except (zmq.ZMQError, salt.exceptions.SaltDeserializationError):
exc = RecvError("ZMQ Error")
else:
for msg in self.unpacker:
serial_payload = salt.payload.Serial({}).loads(msg["body"])
raise salt.ext.tornado.gen.Return(serial_payload)
byts = yield self.sock.read_bytes(8096, partial=True)
self.unpacker.feed(byts)
for msg in self.unpacker:
serial_payload = salt.payload.Serial({}).loads(msg["body"])
raise salt.ext.tornado.gen.Return(serial_payload)
exc = RecvError("TCP Error")
raise exc

@salt.ext.tornado.gen.coroutine
def _run(self, loop):
try:
self._setup_listener()
except Exception: # pylint: disable=broad-except
self.started.set()
log.exception("Failed to start listening")
return
self.started.set()
last_msg = time.time()
serial = salt.payload.Serial(self.minion_config)
crypticle = salt.crypt.Crypticle(self.minion_config, self.aes_key)
self.started.set()
while True:
curr_time = time.time()
if time.time() > self.hard_timeout:
log.error("Hard timeout reaced in test collector!")
break
if curr_time - last_msg >= self.timeout:
log.error("Receive timeout reaced in test collector!")
break
try:
payload = sock.recv(zmq.NOBLOCK)
except zmq.ZMQError:
payload = yield self._recv()
except RecvError:
time.sleep(0.01)
else:
try:
serial_payload = serial.loads(payload)
payload = crypticle.loads(serial_payload["load"])
payload = crypticle.loads(payload["load"])
if not payload:
continue
if "start" in payload:
log.info("Collector started")
self.running.set()
continue
if "stop" in payload:
log.info("Collector stopped")
break
last_msg = time.time()
self.results.append(payload["jid"])
except salt.exceptions.SaltDeserializationError:
log.error("Deserializer Error")
if not self.zmq_filtering:
log.exception("Failed to deserialize...")
break
loop.stop()

def run(self):
"""
Gather results until then number of seconds specified by timeout passes
without receiving a message
"""
loop = salt.ext.tornado.ioloop.IOLoop()
loop.add_callback(self._run, loop)
loop.start()

def __enter__(self):
self.manager.__enter__()
Expand Down Expand Up @@ -140,18 +218,21 @@ def __init__(self, master_config, minion_config, **collector_kwargs):
self.process_manager = salt.utils.process.ProcessManager(
name="ZMQ-PubServer-ProcessManager"
)
self.pub_server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(
self.pub_server_channel = salt.transport.server.PubServerChannel.factory(
self.master_config
)
self.pub_server_channel.pre_fork(
self.process_manager,
kwargs={"log_queue": salt.log.setup.get_multiprocessing_logging_queue()},
)
self.pub_uri = "tcp://{interface}:{publish_port}".format(**self.master_config)
self.queue = multiprocessing.Queue()
self.stopped = multiprocessing.Event()
self.collector = Collector(
self.minion_config, self.pub_uri, self.aes_key, **self.collector_kwargs
self.minion_config,
self.master_config["interface"],
self.master_config["publish_port"],
self.aes_key,
**self.collector_kwargs
)

def run(self):
Expand Down Expand Up @@ -179,8 +260,8 @@ def close(self):
return
self.process_manager.stop_restarting()
self.process_manager.send_signal_to_processes(signal.SIGTERM)
self.pub_server_channel.pub_close()
self.process_manager.kill_children()
if hasattr(self.pub_server_channel, "pub_close"):
self.pub_server_channel.pub_close()
# Really terminate any process still left behind
for pid in self.process_manager._process_map:
terminate_process(pid=pid, kill_children=True, slow_stop=False)
Expand All @@ -192,7 +273,7 @@ def publish(self, payload):
def __enter__(self):
self.start()
self.collector.__enter__()
attempts = 30
attempts = 300
while attempts > 0:
self.publish({"tgt_type": "glob", "tgt": "*", "jid": -1, "start": True})
if self.collector.running.wait(1) is True:
Expand All @@ -210,33 +291,43 @@ def __exit__(self, *args):
# We can safely wait here without a timeout because the Collector instance has a
# hard timeout set, so eventually Collector.stopped will be set
self.collector.stopped.wait()
self.collector.join()
# Stop our own processing
self.queue.put(None)
# Wait at most 10 secs for the above `None` in the queue to be processed
self.stopped.wait(10)
self.close()
self.terminate()
self.join()
log.info("The PubServerChannelProcess has terminated")


@pytest.fixture(params=["tcp", "zeromq"])
def transport(request):
yield request.param


@pytest.mark.skip_on_windows
@pytest.mark.slow_test
def test_publish_to_pubserv_ipc(salt_master, salt_minion):
def test_publish_to_pubserv_ipc(salt_master, salt_minion, transport):
"""
Test sending 10K messags to ZeroMQPubServerChannel using IPC transport
ZMQ's ipc transport not supported on Windows
"""
opts = dict(salt_master.config.copy(), ipc_mode="ipc", pub_hwm=0)
with PubServerChannelProcess(opts, salt_minion.config.copy()) as server_channel:
opts = dict(
salt_master.config.copy(), ipc_mode="ipc", pub_hwm=0, transport=transport
)
minion_opts = dict(salt_minion.config.copy(), transport=transport)
with PubServerChannelProcess(opts, minion_opts) as server_channel:
send_num = 10000
expect = []
for idx in range(send_num):
expect.append(idx)
load = {"tgt_type": "glob", "tgt": "*", "jid": idx}
server_channel.publish(load)
results = server_channel.collector.results
assert len(results) == send_num, "{} != {}, difference: {}".format(
assert len(results) == send_num, "{} != {}, difference: {:.40}".format(
len(results), send_num, set(expect).difference(results)
)

Expand All @@ -252,15 +343,15 @@ def test_issue_36469_tcp(salt_master, salt_minion):
"""

def _send_small(opts, sid, num=10):
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
server_channel = salt.transport.server.PubServerChannel.factory(opts)
for idx in range(num):
load = {"tgt_type": "glob", "tgt": "*", "jid": "{}-s{}".format(sid, idx)}
server_channel.publish(load)
time.sleep(0.3)
server_channel.pub_close()

def _send_large(opts, sid, num=10, size=250000 * 3):
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
server_channel = salt.transport.server.PubServerChannel.factory(opts)
for idx in range(num):
load = {
"tgt_type": "glob",
Expand All @@ -269,7 +360,7 @@ def _send_large(opts, sid, num=10, size=250000 * 3):
"xdata": "0" * size,
}
server_channel.publish(load)
time.sleep(0.3)
time.sleep(3)
server_channel.pub_close()

opts = dict(salt_master.config.copy(), ipc_mode="tcp", pub_hwm=0)
Expand Down

0 comments on commit 88e9b86

Please sign in to comment.