Skip to content
This repository has been archived by the owner on Feb 9, 2024. It is now read-only.

Commit

Permalink
Merge pull request #22 from streamdal/blinktag/pause_resume_tail
Browse files Browse the repository at this point in the history
Pause/Resume Tail
  • Loading branch information
blinktag authored Dec 6, 2023
2 parents 8c198ca + 1067286 commit 84fd0b3
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 30 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pytest-asyncio==0.21.1
pytest-cov==4.1.0
pytest-mock==3.12.0
python-dateutil==2.8.2
readme-renderer==40.0
readme-renderer==42.0
requests==2.31.0
requests-toolbelt==1.0.0
rfc3986==2.0.0
Expand Down
120 changes: 94 additions & 26 deletions streamdal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class StreamdalClient:
workers: list
audiences: dict
tails: dict
paused_tails: dict
host: str
port: int
schemas: dict
Expand Down Expand Up @@ -160,6 +161,7 @@ def __init__(self, cfg: StreamdalConfig):
self.paused_pipelines = {}
self.audiences = {}
self.tails = {}
self.paused_tails = {}
self.schemas = {}
self.log = log
self.exit = cfg.exit
Expand Down Expand Up @@ -899,6 +901,10 @@ def _tail_request(self, cmd: protos.Command):
self._start_tail(cmd)
elif cmd.tail.request.type == protos.TailRequestType.TAIL_REQUEST_TYPE_STOP:
self._stop_tail(cmd)
elif cmd.tail.request.type == protos.TailRequestType.TAIL_REQUEST_TYPE_PAUSE:
self._pause_tail(cmd)
elif cmd.tail.request.type == protos.TailRequestType.TAIL_REQUEST_TYPE_RESUME:
self._resume_tail(cmd)

def _send_tail(
self,
Expand All @@ -907,7 +913,7 @@ def _send_tail(
original_data: bytes,
new_data: bytes,
):
tails = self._get_tails(aud)
tails = self._get_active_tails_for_audience(aud)
if len(tails) == 0:
return

Expand Down Expand Up @@ -964,40 +970,41 @@ def _start_tail(self, cmd: protos.Command):
if aud_str in self.audiences:
t.start_tail_workers()

self._set_tail(t)
self._set_active_tail(t)

def _set_tail(self, t: Tail):
def _set_active_tail(self, t: Tail):
key = common.aud_to_str(t.request.audience)

if key not in self.tails:
self.tails[key] = {}

self.tails[key][t.request.id] = t

def _set_paused_tail(self, t: Tail):
key = common.aud_to_str(t.request.audience)

if key not in self.paused_tails:
self.paused_tails[key] = {}

self.paused_tails[key][t.request.id] = t

def _stop_tail(self, cmd: protos.Command):
validation.tail_request(cmd)

aud = cmd.tail.request.audience
tail_id = cmd.tail.request.id

tails = self._get_tails(aud)
if len(tails) == 0:
self.log.debug(
f"received stop tail command for non-existent tail: {tail_id}"
)
return

if tail_id not in tails.keys():
self.log.debug(
f"received stop tail command for non-existent tail: {tail_id}"
)
return

self.log.debug(f"Stopping tail: {tail_id}")
tails = self._get_active_tails_for_audience(aud)
if tail_id in tails.keys():
self.log.debug(f"Stopping active tail: {tail_id}")
tails[tail_id].exit.set()
self._remove_active_tail(aud, tail_id)

tails[tail_id].exit.set()

self._remove_tail(aud, tail_id)
paused_tails = self._get_paused_tails_for_audience(aud)
if tail_id in paused_tails.keys():
self.log.debug(f"Stopping paused tail: {tail_id}")
paused_tails[tail_id].exit.set()
self._remove_paused_tail(aud, tail_id)

def _stop_all_tails(self):
"""
Expand All @@ -1009,9 +1016,43 @@ def _stop_all_tails(self):
for audience in audiences:
for t in audience.values():
t.exit.set()
self._remove_tail(t.request.audience, t.request.id)
self._remove_active_tail(t.request.audience, t.request.id)

audiences = self.paused_tails.values()
for audience in audiences:
for t in audience.values():
t.exit.set()
self._remove_paused_tail(t.request.audience, t.request.id)

def _pause_tail(self, cmd: protos.Command):
# Remove from active tails
t = self._remove_active_tail(cmd.tail.request.audience, cmd.tail.request.id)
if t is None:
self.log.debug(
f"Received paused tail for unknown tail request {cmd.tail.request.id}"
)
return

# Add to paused tails
self._set_paused_tail(t)

self.log.debug(f"Pausing tail: {cmd.tail.request.id}")

def _resume_tail(self, cmd: protos.Command):
# Remove from paused tails
t = self._remove_paused_tail(cmd.tail.request.audience, cmd.tail.request.id)
if t is None:
self.log.debug(
f"Received resumed tail for unknown tail request {cmd.tail.request.id}"
)
return

# Add to active tails
self._set_active_tail(t)

def _get_tails(
self.log.debug(f"Resuming tail: {cmd.tail.request.id}")

def _get_active_tails_for_audience(
self,
aud: protos.Audience,
) -> dict:
Expand All @@ -1021,19 +1062,46 @@ def _get_tails(

return {}

def _remove_tail(self, aud: protos.Audience, tail_id: str):
def _get_paused_tails_for_audience(
self,
aud: protos.Audience,
) -> dict:
key = common.aud_to_str(aud)
if key in self.paused_tails:
return self.paused_tails[key]

return {}

def _remove_active_tail(self, aud: protos.Audience, tail_id: str) -> Tail:
key = common.aud_to_str(aud)
if key not in self.tails:
return
return None

if tail_id not in self.tails[key]:
return
return None

self.tails[key].pop(tail_id)
t = self.tails[key].pop(tail_id)

if len(self.tails[key]) == 0:
self.tails.pop(key)

return t

def _remove_paused_tail(self, aud: protos.Audience, tail_id: str) -> Tail:
key = common.aud_to_str(aud)
if key not in self.paused_tails:
return None

if tail_id not in self.paused_tails[key]:
return None

t = self.paused_tails[key].pop(tail_id)

if len(self.paused_tails[key]) == 0:
self.paused_tails.pop(key)

return t

def _get_schema(self, aud: protos.Audience) -> bytes:
schema = self.schemas.get(common.aud_to_str(aud))
if schema is None:
Expand Down
76 changes: 73 additions & 3 deletions test_streamdal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import unittest.mock as mock
import streamdal
from streamdal import StreamdalClient, StreamdalConfig
from streamdal.tail import Tail


class TestStreamdalClient:
Expand All @@ -31,6 +32,7 @@ def before_each(self):
client.paused_pipelines = {}
client.audiences = {}
client.tails = {}
client.paused_tails = {}
client.schemas = {}

self.client = client
Expand Down Expand Up @@ -290,6 +292,74 @@ def test_tail_request_stop(self, mocker):
self.client._tail_request(cmd)
m.assert_called_once()

def test_tail_request_pause(self):
tail_id = uuid.uuid4().__str__()

aud = protos.Audience(
component_name="kafka",
service_name="testing",
operation_name="test-topic",
operation_type=protos.OperationType.OPERATION_TYPE_PRODUCER,
)

t = Tail.__new__(Tail)
t.request = protos.TailRequest(
audience=aud,
id=tail_id,
type=protos.TailRequestType.TAIL_REQUEST_TYPE_START,
)

self.client.tails[common.aud_to_str(aud)] = {}
self.client.tails[common.aud_to_str(aud)][tail_id] = t

pause_cmd = protos.Command(
tail=protos.TailCommand(
request=protos.TailRequest(
audience=aud,
id=tail_id,
type=protos.TailRequestType.TAIL_REQUEST_TYPE_PAUSE,
)
),
)

self.client._tail_request(pause_cmd)
assert len(self.client.paused_tails) == 1
assert len(self.client.tails) == 0

def test_tail_request_resume(self):
tail_id = uuid.uuid4().__str__()

aud = protos.Audience(
component_name="kafka",
service_name="testing",
operation_name="test-topic",
operation_type=protos.OperationType.OPERATION_TYPE_PRODUCER,
)

t = Tail.__new__(Tail)
t.request = protos.TailRequest(
audience=aud,
id=tail_id,
type=protos.TailRequestType.TAIL_REQUEST_TYPE_START,
)

self.client.paused_tails[common.aud_to_str(aud)] = {}
self.client.paused_tails[common.aud_to_str(aud)][tail_id] = t

resume_cmd = protos.Command(
tail=protos.TailCommand(
request=protos.TailRequest(
audience=aud,
id=tail_id,
type=protos.TailRequestType.TAIL_REQUEST_TYPE_RESUME,
)
),
)

self.client._tail_request(resume_cmd)
assert len(self.client.paused_tails) == 0
assert len(self.client.tails) == 1

def test_set_tail(self):
tail_id = uuid.uuid4().__str__()

Expand All @@ -308,7 +378,7 @@ def test_set_tail(self):

assert len(self.client.tails) == 0

self.client._set_tail(tail)
self.client._set_active_tail(tail)

assert len(self.client.tails) == 1

Expand Down Expand Up @@ -370,7 +440,7 @@ def test_stop_tail(self):
tail=protos.TailCommand(request=req),
)

self.client._set_tail(tail)
self.client._set_active_tail(tail)
assert len(self.client.tails) == 1

cmd.tail.request.type = (protos.TailRequestType.TAIL_REQUEST_TYPE_STOP,)
Expand All @@ -390,7 +460,7 @@ def test_remove_tail(self):
aud_str = common.aud_to_str(aud)

self.client.tails = {aud_str: {tail_id: mock.Mock()}}
self.client._remove_tail(aud, tail_id)
self.client._remove_active_tail(aud, tail_id)
assert len(self.client.tails) == 0

def test_set_schema(self):
Expand Down

0 comments on commit 84fd0b3

Please sign in to comment.