Skip to content

Commit

Permalink
Allow using multiple input processors (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
george-zubrienko authored Jan 30, 2024
1 parent d66ff59 commit 898a9d8
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 22 deletions.
13 changes: 3 additions & 10 deletions esd_services_api_client/nexus/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,17 +239,10 @@ class MyAlgorithm(MinimalisticAlgorithm):
pass

@inject
def __init__(
self,
input_processor: MyInputProcessor,
metrics_provider: MetricsProvider,
logger_factory: LoggerFactory,
):
super().__init__(input_processor, metrics_provider, logger_factory)
def __init__(self, metrics_provider: MetricsProvider, logger_factory: LoggerFactory, input_processor: MyInputProcessor):
super().__init__(metrics_provider, logger_factory, input_processor)

async def _run(
self, x_ready: PandasDataFrame, y_ready: PandasDataFrame, **kwargs
) -> PandasDataFrame:
async def _run(self, x_ready: PandasDataFrame, y_ready: PandasDataFrame, **kwargs) -> PandasDataFrame:
return pandas.concat([x_ready, y_ready])


Expand Down
23 changes: 19 additions & 4 deletions esd_services_api_client/nexus/algorithms/_baseline_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Base algorithm
"""
import asyncio

# Copyright (c) 2023. ECCO Sneaks & Data
#
Expand All @@ -19,6 +20,7 @@


from abc import abstractmethod
from functools import reduce

from adapta.metrics import MetricsProvider
from pandas import DataFrame as PandasDataFrame
Expand All @@ -35,12 +37,12 @@ class BaselineAlgorithm(NexusObject):

def __init__(
self,
input_processor: InputProcessor,
metrics_provider: MetricsProvider,
logger_factory: LoggerFactory,
*input_processors: InputProcessor,
):
super().__init__(metrics_provider, logger_factory)
self._input_processor = input_processor
self._input_processors = input_processors

@abstractmethod
async def _run(self, **kwargs) -> PandasDataFrame:
Expand All @@ -52,5 +54,18 @@ async def run(self, **kwargs) -> PandasDataFrame:
"""
Coroutine that executes the algorithm logic.
"""
async with self._input_processor as input_processor:
return await self._run(**(await input_processor.process_input(**kwargs)))

async def _process(processor: InputProcessor) -> dict[str, PandasDataFrame]:
async with processor as instance:
return await instance.process_input(**kwargs)

process_tasks: dict[str, asyncio.Task] = {
input_processor.__class__.__name__.lower(): asyncio.create_task(
_process(input_processor)
)
for input_processor in self._input_processors
}
await asyncio.wait(fs=process_tasks.values())
results = [task.result() for task in process_tasks.values()]

return await self._run(**reduce(lambda a, b: a | b, results))
4 changes: 2 additions & 2 deletions esd_services_api_client/nexus/algorithms/minimalistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ class MinimalisticAlgorithm(BaselineAlgorithm, ABC):
@inject
def __init__(
self,
input_processor: InputProcessor,
metrics_provider: MetricsProvider,
logger_factory: LoggerFactory,
*input_processors: InputProcessor,
):
super().__init__(input_processor, metrics_provider, logger_factory)
super().__init__(metrics_provider, logger_factory, *input_processors)
8 changes: 3 additions & 5 deletions esd_services_api_client/nexus/algorithms/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,18 @@ class RecursiveAlgorithm(BaselineAlgorithm):
@inject
def __init__(
self,
input_processor: InputProcessor,
metrics_provider: MetricsProvider,
logger_factory: LoggerFactory,
*input_processors: InputProcessor,
):
super().__init__(input_processor, metrics_provider, logger_factory)
super().__init__(metrics_provider, logger_factory, *input_processors)

@abstractmethod
async def _is_finished(self, **kwargs) -> bool:
""" """

async def run(self, **kwargs) -> PandasDataFrame:
result = await self._run(
**(await self._input_processor.process_input(**kwargs))
)
result = await self._run(**kwargs)
if self._is_finished(**result):
return result
return await self.run(**result)
2 changes: 1 addition & 1 deletion esd_services_api_client/nexus/input/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def _read(input_reader: InputReader):
return await instance.read()

read_tasks: dict[str, asyncio.Task] = {
reader.socket.alias: asyncio.create_task(_read(reader)) for reader in readers
reader.alias: asyncio.create_task(_read(reader)) for reader in readers
}
await asyncio.wait(fs=read_tasks.values())

Expand Down

0 comments on commit 898a9d8

Please sign in to comment.