diff --git a/requirements.txt b/requirements.txt index ed775ba..8d9b47e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -39,7 +39,7 @@ pytest-asyncio==0.21.1 pytest-cov==4.1.0 pytest-mock==3.12.0 python-dateutil==2.8.2 -readme-renderer==40.0 +readme-renderer==42.0 requests==2.31.0 requests-toolbelt==1.0.0 rfc3986==2.0.0 diff --git a/streamdal/__init__.py b/streamdal/__init__.py index 263e15e..838f73f 100644 --- a/streamdal/__init__.py +++ b/streamdal/__init__.py @@ -113,6 +113,7 @@ class StreamdalClient: workers: list audiences: dict tails: dict + paused_tails: dict host: str port: int schemas: dict @@ -160,6 +161,7 @@ def __init__(self, cfg: StreamdalConfig): self.paused_pipelines = {} self.audiences = {} self.tails = {} + self.paused_tails = {} self.schemas = {} self.log = log self.exit = cfg.exit @@ -899,6 +901,10 @@ def _tail_request(self, cmd: protos.Command): self._start_tail(cmd) elif cmd.tail.request.type == protos.TailRequestType.TAIL_REQUEST_TYPE_STOP: self._stop_tail(cmd) + elif cmd.tail.request.type == protos.TailRequestType.TAIL_REQUEST_TYPE_PAUSE: + self._pause_tail(cmd) + elif cmd.tail.request.type == protos.TailRequestType.TAIL_REQUEST_TYPE_RESUME: + self._resume_tail(cmd) def _send_tail( self, @@ -907,7 +913,7 @@ def _send_tail( original_data: bytes, new_data: bytes, ): - tails = self._get_tails(aud) + tails = self._get_active_tails_for_audience(aud) if len(tails) == 0: return @@ -964,9 +970,9 @@ def _start_tail(self, cmd: protos.Command): if aud_str in self.audiences: t.start_tail_workers() - self._set_tail(t) + self._set_active_tail(t) - def _set_tail(self, t: Tail): + def _set_active_tail(self, t: Tail): key = common.aud_to_str(t.request.audience) if key not in self.tails: @@ -974,30 +980,31 @@ def _set_tail(self, t: Tail): self.tails[key][t.request.id] = t + def _set_paused_tail(self, t: Tail): + key = common.aud_to_str(t.request.audience) + + if key not in self.paused_tails: + self.paused_tails[key] = {} + + self.paused_tails[key][t.request.id] = t + def _stop_tail(self, cmd: protos.Command): validation.tail_request(cmd) aud = cmd.tail.request.audience tail_id = cmd.tail.request.id - tails = self._get_tails(aud) - if len(tails) == 0: - self.log.debug( - f"received stop tail command for non-existent tail: {tail_id}" - ) - return - - if tail_id not in tails.keys(): - self.log.debug( - f"received stop tail command for non-existent tail: {tail_id}" - ) - return - - self.log.debug(f"Stopping tail: {tail_id}") + tails = self._get_active_tails_for_audience(aud) + if tail_id in tails.keys(): + self.log.debug(f"Stopping active tail: {tail_id}") + tails[tail_id].exit.set() + self._remove_active_tail(aud, tail_id) - tails[tail_id].exit.set() - - self._remove_tail(aud, tail_id) + paused_tails = self._get_paused_tails_for_audience(aud) + if tail_id in paused_tails.keys(): + self.log.debug(f"Stopping paused tail: {tail_id}") + paused_tails[tail_id].exit.set() + self._remove_paused_tail(aud, tail_id) def _stop_all_tails(self): """ @@ -1009,9 +1016,43 @@ def _stop_all_tails(self): for audience in audiences: for t in audience.values(): t.exit.set() - self._remove_tail(t.request.audience, t.request.id) + self._remove_active_tail(t.request.audience, t.request.id) + + audiences = self.paused_tails.values() + for audience in audiences: + for t in audience.values(): + t.exit.set() + self._remove_paused_tail(t.request.audience, t.request.id) + + def _pause_tail(self, cmd: protos.Command): + # Remove from active tails + t = self._remove_active_tail(cmd.tail.request.audience, cmd.tail.request.id) + if t is None: + self.log.debug( + f"Received paused tail for unknown tail request {cmd.tail.request.id}" + ) + return + + # Add to paused tails + self._set_paused_tail(t) + + self.log.debug(f"Pausing tail: {cmd.tail.request.id}") + + def _resume_tail(self, cmd: protos.Command): + # Remove from paused tails + t = self._remove_paused_tail(cmd.tail.request.audience, cmd.tail.request.id) + if t is None: + self.log.debug( + f"Received resumed tail for unknown tail request {cmd.tail.request.id}" + ) + return + + # Add to active tails + self._set_active_tail(t) - def _get_tails( + self.log.debug(f"Resuming tail: {cmd.tail.request.id}") + + def _get_active_tails_for_audience( self, aud: protos.Audience, ) -> dict: @@ -1021,19 +1062,46 @@ def _get_tails( return {} - def _remove_tail(self, aud: protos.Audience, tail_id: str): + def _get_paused_tails_for_audience( + self, + aud: protos.Audience, + ) -> dict: + key = common.aud_to_str(aud) + if key in self.paused_tails: + return self.paused_tails[key] + + return {} + + def _remove_active_tail(self, aud: protos.Audience, tail_id: str) -> Tail: key = common.aud_to_str(aud) if key not in self.tails: - return + return None if tail_id not in self.tails[key]: - return + return None - self.tails[key].pop(tail_id) + t = self.tails[key].pop(tail_id) if len(self.tails[key]) == 0: self.tails.pop(key) + return t + + def _remove_paused_tail(self, aud: protos.Audience, tail_id: str) -> Tail: + key = common.aud_to_str(aud) + if key not in self.paused_tails: + return None + + if tail_id not in self.paused_tails[key]: + return None + + t = self.paused_tails[key].pop(tail_id) + + if len(self.paused_tails[key]) == 0: + self.paused_tails.pop(key) + + return t + def _get_schema(self, aud: protos.Audience) -> bytes: schema = self.schemas.get(common.aud_to_str(aud)) if schema is None: diff --git a/test_streamdal.py b/test_streamdal.py index 14db79d..4255c02 100644 --- a/test_streamdal.py +++ b/test_streamdal.py @@ -7,6 +7,7 @@ import unittest.mock as mock import streamdal from streamdal import StreamdalClient, StreamdalConfig +from streamdal.tail import Tail class TestStreamdalClient: @@ -31,6 +32,7 @@ def before_each(self): client.paused_pipelines = {} client.audiences = {} client.tails = {} + client.paused_tails = {} client.schemas = {} self.client = client @@ -290,6 +292,74 @@ def test_tail_request_stop(self, mocker): self.client._tail_request(cmd) m.assert_called_once() + def test_tail_request_pause(self): + tail_id = uuid.uuid4().__str__() + + aud = protos.Audience( + component_name="kafka", + service_name="testing", + operation_name="test-topic", + operation_type=protos.OperationType.OPERATION_TYPE_PRODUCER, + ) + + t = Tail.__new__(Tail) + t.request = protos.TailRequest( + audience=aud, + id=tail_id, + type=protos.TailRequestType.TAIL_REQUEST_TYPE_START, + ) + + self.client.tails[common.aud_to_str(aud)] = {} + self.client.tails[common.aud_to_str(aud)][tail_id] = t + + pause_cmd = protos.Command( + tail=protos.TailCommand( + request=protos.TailRequest( + audience=aud, + id=tail_id, + type=protos.TailRequestType.TAIL_REQUEST_TYPE_PAUSE, + ) + ), + ) + + self.client._tail_request(pause_cmd) + assert len(self.client.paused_tails) == 1 + assert len(self.client.tails) == 0 + + def test_tail_request_resume(self): + tail_id = uuid.uuid4().__str__() + + aud = protos.Audience( + component_name="kafka", + service_name="testing", + operation_name="test-topic", + operation_type=protos.OperationType.OPERATION_TYPE_PRODUCER, + ) + + t = Tail.__new__(Tail) + t.request = protos.TailRequest( + audience=aud, + id=tail_id, + type=protos.TailRequestType.TAIL_REQUEST_TYPE_START, + ) + + self.client.paused_tails[common.aud_to_str(aud)] = {} + self.client.paused_tails[common.aud_to_str(aud)][tail_id] = t + + resume_cmd = protos.Command( + tail=protos.TailCommand( + request=protos.TailRequest( + audience=aud, + id=tail_id, + type=protos.TailRequestType.TAIL_REQUEST_TYPE_RESUME, + ) + ), + ) + + self.client._tail_request(resume_cmd) + assert len(self.client.paused_tails) == 0 + assert len(self.client.tails) == 1 + def test_set_tail(self): tail_id = uuid.uuid4().__str__() @@ -308,7 +378,7 @@ def test_set_tail(self): assert len(self.client.tails) == 0 - self.client._set_tail(tail) + self.client._set_active_tail(tail) assert len(self.client.tails) == 1 @@ -370,7 +440,7 @@ def test_stop_tail(self): tail=protos.TailCommand(request=req), ) - self.client._set_tail(tail) + self.client._set_active_tail(tail) assert len(self.client.tails) == 1 cmd.tail.request.type = (protos.TailRequestType.TAIL_REQUEST_TYPE_STOP,) @@ -390,7 +460,7 @@ def test_remove_tail(self): aud_str = common.aud_to_str(aud) self.client.tails = {aud_str: {tail_id: mock.Mock()}} - self.client._remove_tail(aud, tail_id) + self.client._remove_active_tail(aud, tail_id) assert len(self.client.tails) == 0 def test_set_schema(self):