Skip to content

Commit

Permalink
Refactor concurrent read from multiple readers so it can be used outside…
Browse files Browse the repository at this point in the history
… InputProcessor (#90)
  • Loading branch information
george-zubrienko authored Jan 25, 2024
1 parent e0cda8f commit d66ff59
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 54 deletions.
30 changes: 18 additions & 12 deletions esd_services_api_client/nexus/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ Example usage:
```python
import asyncio
import json
import os
import socketserver
import threading
import os
from dataclasses import dataclass
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
from typing import Dict, Optional
Expand All @@ -32,10 +32,12 @@ from esd_services_api_client.nexus.abstractions.logger_factory import LoggerFact
from esd_services_api_client.nexus.abstractions.socket_provider import (
ExternalSocketProvider,
)
from esd_services_api_client.nexus.configurations.algorithm_configuration import (
NexusConfiguration,
)
from esd_services_api_client.nexus.core.app_core import Nexus
from esd_services_api_client.nexus.algorithms import MinimalisticAlgorithm
from esd_services_api_client.nexus.input import InputReader, InputProcessor
from esd_services_api_client.nexus.configurations.algorithm_configuration import NexusConfiguration
from pandas import DataFrame as PandasDataFrame

from esd_services_api_client.nexus.input.payload_reader import AlgorithmPayload
Expand Down Expand Up @@ -142,11 +144,11 @@ class XReader(InputReader[MyAlgorithmPayload]):
*readers: "InputReader"
):
super().__init__(
socket_provider.socket("x"),
store,
metrics_provider,
logger_factory,
payload,
socket=socket_provider.socket("x"),
store=store,
metrics_provider=metrics_provider,
logger_factory=logger_factory,
payload=payload,
*readers
)

Expand Down Expand Up @@ -177,11 +179,11 @@ class YReader(InputReader[MyAlgorithmPayload2]):
*readers: "InputReader"
):
super().__init__(
socket_provider.socket("y"),
store,
metrics_provider,
logger_factory,
payload,
socket=socket_provider.socket("y"),
store=store,
metrics_provider=metrics_provider,
logger_factory=logger_factory,
payload=payload,
*readers
)

Expand All @@ -208,6 +210,7 @@ class MyInputProcessor(InputProcessor):
y: YReader,
metrics_provider: MetricsProvider,
logger_factory: LoggerFactory,
my_conf: MyAlgorithmConfiguration,
):
super().__init__(
x,
Expand All @@ -217,7 +220,10 @@ class MyInputProcessor(InputProcessor):
payload=None,
)

self.conf = my_conf

async def process_input(self, **_) -> Dict[str, PandasDataFrame]:
self._logger.info("Config: {config}", config=self.conf.to_json())
inputs = await self._read_input()
return {
"x_ready": inputs["x"].assign(c=[-1, 1]),
Expand Down
1 change: 1 addition & 0 deletions esd_services_api_client/nexus/input/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@

from esd_services_api_client.nexus.input.input_processor import *
from esd_services_api_client.nexus.input.input_reader import *
from esd_services_api_client.nexus.input._functions import *
69 changes: 69 additions & 0 deletions esd_services_api_client/nexus/input/_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
Utility functions to handle input processing.
"""

# Copyright (c) 2023. ECCO Sneaks & Data
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import asyncio
from typing import Dict, Union, Type
import azure.core.exceptions
import deltalake
from pandas import DataFrame as PandasDataFrame

from esd_services_api_client.nexus.exceptions.input_reader_error import (
FatalInputReaderError,
TransientInputReaderError,
)
from esd_services_api_client.nexus.input.input_reader import InputReader


def resolve_reader_exc_type(
ex: BaseException,
) -> Union[Type[FatalInputReaderError], Type[TransientInputReaderError]]:
"""
Resolve base exception into a specific Nexus exception.
"""
match type(ex):
case azure.core.exceptions.HttpResponseError, deltalake.PyDeltaTableError:
return TransientInputReaderError
case azure.core.exceptions.AzureError, azure.core.exceptions.ClientAuthenticationError:
return FatalInputReaderError
case _:
return FatalInputReaderError


async def resolve_readers(*readers: InputReader) -> Dict[str, PandasDataFrame]:
    """
    Concurrently resolve `data` property of all readers by invoking their `read` method.

    Each reader is entered as an async context manager and read in its own task.

    :param readers: Readers to resolve.
    :return: Result of each reader's `read`, keyed by the reader's alias. Empty if no readers are provided.
    :raises FatalInputReaderError: If a reader failed and the failure is resolved as fatal.
    :raises TransientInputReaderError: If a reader failed and the failure is resolved as transient.
    """

    def get_result(alias: str, completed_task: asyncio.Task) -> PandasDataFrame:
        reader_exc = completed_task.exception()
        if reader_exc is not None:
            # Chain the original error so the root cause stays visible in the traceback.
            raise resolve_reader_exc_type(reader_exc)(alias, reader_exc) from reader_exc

        return completed_task.result()

    async def _read(input_reader: InputReader):
        async with input_reader as instance:
            return await instance.read()

    if not readers:
        # asyncio.wait raises ValueError when given an empty collection of tasks.
        return {}

    # Use the reader's own `alias` rather than `socket.alias`: the socket is optional,
    # and `alias` falls back to another identifier when no socket is set.
    read_tasks: dict[str, asyncio.Task] = {
        reader.alias: asyncio.create_task(_read(reader)) for reader in readers
    }
    await asyncio.wait(read_tasks.values())

    return {alias: get_result(alias, task) for alias, task in read_tasks.items()}
41 changes: 3 additions & 38 deletions esd_services_api_client/nexus/input/input_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,19 @@
# limitations under the License.
#

import asyncio
from abc import abstractmethod
from typing import Dict, Union, Type
from typing import Dict

import deltalake
from adapta.metrics import MetricsProvider

import azure.core.exceptions

from pandas import DataFrame as PandasDataFrame

from esd_services_api_client.nexus.abstractions.nexus_object import (
NexusObject,
TPayload,
)
from esd_services_api_client.nexus.abstractions.logger_factory import LoggerFactory
from esd_services_api_client.nexus.exceptions.input_reader_error import (
FatalInputReaderError,
TransientInputReaderError,
)
from esd_services_api_client.nexus.input._functions import resolve_readers
from esd_services_api_client.nexus.input.input_reader import InputReader


Expand All @@ -56,36 +49,8 @@ def __init__(
self._readers = readers
self._payload = payload

def _get_exc_type(
self, ex: BaseException
) -> Union[Type[FatalInputReaderError], Type[TransientInputReaderError]]:
match type(ex):
case azure.core.exceptions.HttpResponseError, deltalake.PyDeltaTableError:
return TransientInputReaderError
case azure.core.exceptions.AzureError, azure.core.exceptions.ClientAuthenticationError:
return FatalInputReaderError
case _:
return FatalInputReaderError

async def _read_input(self) -> Dict[str, PandasDataFrame]:
def get_result(alias: str, completed_task: asyncio.Task) -> PandasDataFrame:
reader_exc = completed_task.exception()
if reader_exc:
raise self._get_exc_type(reader_exc)(alias, reader_exc)

return completed_task.result()

async def _read(input_reader: InputReader):
async with input_reader as instance:
return await instance.read()

read_tasks: dict[str, asyncio.Task] = {
reader.socket.alias: asyncio.create_task(_read(reader))
for reader in self._readers
}
await asyncio.wait(fs=read_tasks.values())

return {alias: get_result(alias, task) for alias, task in read_tasks.items()}
return await resolve_readers(*self._readers)

@abstractmethod
async def process_input(self, **kwargs) -> Dict[str, PandasDataFrame]:
Expand Down
21 changes: 17 additions & 4 deletions esd_services_api_client/nexus/input/input_reader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Input reader.
"""
import functools

# Copyright (c) 2023. ECCO Sneaks & Data
#
Expand Down Expand Up @@ -43,12 +44,12 @@ class InputReader(NexusObject[TPayload]):

def __init__(
self,
socket: DataSocket,
store: QueryEnabledStore,
metrics_provider: MetricsProvider,
logger_factory: LoggerFactory,
payload: TPayload,
*readers: "InputReader"
*readers: "InputReader",
socket: Optional[DataSocket] = None,
):
super().__init__(metrics_provider, logger_factory)
self.socket = socket
Expand All @@ -58,6 +59,16 @@ def __init__(
self._payload = payload

@property
def alias(self) -> str:
    """
    Name under which this reader's output is identified: the socket alias when a socket is set,
    otherwise the reader's metric name.
    """
    return self.socket.alias if self.socket else self._metric_name

@functools.cached_property
def data(self) -> Optional[PandasDataFrame]:
"""
Data read by this reader.
Expand Down Expand Up @@ -92,8 +103,10 @@ async def read(self) -> PandasDataFrame:
on_finish_message_template="Finished reading {entity} from path {data_path} in {elapsed:.2f}s seconds",
template_args={
"entity": self._metric_name.upper(),
"data_path": self.socket.data_path,
},
}
| {"data_path": self.socket.data_path}
if self.socket
else {},
)
async def _read(**_) -> PandasDataFrame:
if not self._data:
Expand Down

0 comments on commit d66ff59

Please sign in to comment.