diff --git a/esd_services_api_client/nexus/README.md b/esd_services_api_client/nexus/README.md index c339a92..5db02ca 100644 --- a/esd_services_api_client/nexus/README.md +++ b/esd_services_api_client/nexus/README.md @@ -239,17 +239,10 @@ class MyAlgorithm(MinimalisticAlgorithm): pass @inject - def __init__( - self, - input_processor: MyInputProcessor, - metrics_provider: MetricsProvider, - logger_factory: LoggerFactory, - ): - super().__init__(input_processor, metrics_provider, logger_factory) + def __init__(self, metrics_provider: MetricsProvider, logger_factory: LoggerFactory, input_processor: MyInputProcessor): + super().__init__(metrics_provider, logger_factory, input_processor) - async def _run( - self, x_ready: PandasDataFrame, y_ready: PandasDataFrame, **kwargs - ) -> PandasDataFrame: + async def _run(self, x_ready: PandasDataFrame, y_ready: PandasDataFrame, **kwargs) -> PandasDataFrame: return pandas.concat([x_ready, y_ready]) diff --git a/esd_services_api_client/nexus/algorithms/_baseline_algorithm.py b/esd_services_api_client/nexus/algorithms/_baseline_algorithm.py index 361a53a..5218850 100644 --- a/esd_services_api_client/nexus/algorithms/_baseline_algorithm.py +++ b/esd_services_api_client/nexus/algorithms/_baseline_algorithm.py @@ -1,6 +1,7 @@ """ Base algorithm """ +import asyncio # Copyright (c) 2023. ECCO Sneaks & Data # @@ -19,6 +20,7 @@ from abc import abstractmethod +from functools import reduce from adapta.metrics import MetricsProvider from pandas import DataFrame as PandasDataFrame @@ -35,12 +37,12 @@ class BaselineAlgorithm(NexusObject): def __init__( self, - input_processor: InputProcessor, metrics_provider: MetricsProvider, logger_factory: LoggerFactory, + *input_processors: InputProcessor, ): super().__init__(metrics_provider, logger_factory) - self._input_processor = input_processor + self._input_processors = input_processors @abstractmethod async def _run(self, **kwargs) -> PandasDataFrame: @@ -52,5 +54,18 @@ async def run(self, **kwargs) -> PandasDataFrame: """ Coroutine that executes the algorithm logic. """ - async with self._input_processor as input_processor: - return await self._run(**(await input_processor.process_input(**kwargs))) + + async def _process(processor: InputProcessor) -> dict[str, PandasDataFrame]: + async with processor as instance: + return await instance.process_input(**kwargs) + + process_tasks: dict[str, asyncio.Task] = { + input_processor.__class__.__name__.lower(): asyncio.create_task( + _process(input_processor) + ) + for input_processor in self._input_processors + } + await asyncio.wait(fs=process_tasks.values()) + results = [task.result() for task in process_tasks.values()] + + return await self._run(**reduce(lambda a, b: a | b, results)) diff --git a/esd_services_api_client/nexus/algorithms/minimalistic.py b/esd_services_api_client/nexus/algorithms/minimalistic.py index c9dc087..47aadc3 100644 --- a/esd_services_api_client/nexus/algorithms/minimalistic.py +++ b/esd_services_api_client/nexus/algorithms/minimalistic.py @@ -37,8 +37,8 @@ class MinimalisticAlgorithm(BaselineAlgorithm, ABC): @inject def __init__( self, - input_processor: InputProcessor, metrics_provider: MetricsProvider, logger_factory: LoggerFactory, + *input_processors: InputProcessor, ): - super().__init__(input_processor, metrics_provider, logger_factory) + super().__init__(metrics_provider, logger_factory, *input_processors) diff --git a/esd_services_api_client/nexus/algorithms/recursive.py b/esd_services_api_client/nexus/algorithms/recursive.py index 4257e8a..601afe2 100644 --- a/esd_services_api_client/nexus/algorithms/recursive.py +++ b/esd_services_api_client/nexus/algorithms/recursive.py @@ -39,20 +39,18 @@ class RecursiveAlgorithm(BaselineAlgorithm): @inject def __init__( self, - input_processor: InputProcessor, metrics_provider: MetricsProvider, logger_factory: LoggerFactory, + *input_processors: InputProcessor, ): - super().__init__(input_processor, metrics_provider, logger_factory) + super().__init__(metrics_provider, logger_factory, *input_processors) @abstractmethod async def _is_finished(self, **kwargs) -> bool: """ """ async def run(self, **kwargs) -> PandasDataFrame: - result = await self._run( - **(await self._input_processor.process_input(**kwargs)) - ) + result = await self._run(**kwargs) if self._is_finished(**result): return result return await self.run(**result) diff --git a/esd_services_api_client/nexus/input/_functions.py b/esd_services_api_client/nexus/input/_functions.py index 5454c9e..08f120d 100644 --- a/esd_services_api_client/nexus/input/_functions.py +++ b/esd_services_api_client/nexus/input/_functions.py @@ -62,7 +62,7 @@ async def _read(input_reader: InputReader): return await instance.read() read_tasks: dict[str, asyncio.Task] = { - reader.socket.alias: asyncio.create_task(_read(reader)) for reader in readers + reader.alias: asyncio.create_task(_read(reader)) for reader in readers } await asyncio.wait(fs=read_tasks.values())