
Enable separation between forked and non-forked runs (#128)
george-zubrienko authored Nov 25, 2024
1 parent 0114a02 commit adfbb1e
Showing 1 changed file with 63 additions and 14 deletions.
77 changes: 63 additions & 14 deletions esd_services_api_client/nexus/algorithms/forked_algorithm.py
@@ -64,19 +64,61 @@ def __init__(
self,
metrics_provider: MetricsProvider,
logger_factory: LoggerFactory,
forks: list[RemoteAlgorithm],
*input_processors: InputProcessor,
cache: InputCache,
):
super().__init__(metrics_provider, logger_factory)
self._input_processors = input_processors
self._forks = forks
self._cache = cache
self._inputs: dict = {}

@property
def inputs(self) -> dict:
"""
Inputs generated for this algorithm run.
"""
return self._inputs

@abstractmethod
async def _get_forks(self, **kwargs) -> list[RemoteAlgorithm]:
"""
Resolve the forks to be used in this run, if any.
"""

@abstractmethod
async def _main_run(self, **kwargs) -> AlgorithmResult:
"""
Logic to use for the main run. If this node is the root node, **this result** will be returned to the client.
"""

@abstractmethod
async def _fork_run(self, **kwargs) -> AlgorithmResult:
"""
Logic to use for the fork run. If this node is **NOT** the root node, the **result will be ignored by the client**.
"""

@abstractmethod
async def _run(self, **kwargs) -> AlgorithmResult:
async def _is_forked(self, **kwargs) -> bool:
"""
Core logic for this algorithm. Implementing this method is mandatory.
Determine if this is the main run or a fork run.
"""

async def _default_inputs(self, **kwargs) -> dict:
"""
Generate inputs by invoking all processors.
"""
return await self._cache.resolve(*self._input_processors, **kwargs)

@abstractmethod
async def _main_inputs(self, **kwargs) -> dict:
"""
Set inputs for the main run, used when this node is the root node.
"""

@abstractmethod
async def _fork_inputs(self, **kwargs) -> dict:
"""
Set inputs for the fork run, used when this node is **NOT** the root node.
"""

@property
@@ -96,29 +138,36 @@ async def run(self, **kwargs) -> AlgorithmResult:
},
)
async def _measured_run(**run_args) -> AlgorithmResult:
return await self._run(**run_args)
if await self._is_forked(**run_args):
return await self._fork_run(**run_args)

return await self._main_run(**run_args)

if len(self._forks) > 0:
# evaluate whether additional forks will be spawned
forks = await self._get_forks(**kwargs)

if len(forks) > 0:
self._logger.info(
"This algorithm has forks attached: {forks}. They will be executed after the main run",
forks=",".join([fork.alias() for fork in self._forks]),
"Forking node with: {forks}, after the node run",
forks=",".join([fork.alias() for fork in forks]),
)
else:
self._logger.info(
"This algorithm supports forks but none were injected. Proceeding with a main run only"
)
self._logger.info("Leaf algorithm node: proceeding with this node run only")

results = await self._cache.resolve(*self._input_processors, **kwargs)
if await self._is_forked(**kwargs):
self._inputs = await self._fork_inputs(**kwargs)
else:
self._inputs = await self._main_inputs(**kwargs)

run_result = await partial(
_measured_run,
**results,
**self._inputs,
metric_tags=self._metric_tags,
metrics_provider=self._metrics_provider,
logger=self._logger,
)()

# now await callback scheduling
await asyncio.wait([fork.run(**kwargs) for fork in self._forks])
await asyncio.wait([fork.run(**kwargs) for fork in forks])

return run_result

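For orientation, the sketch below shows what a concrete algorithm might look like against the new interface: after this change a subclass no longer receives forks through its constructor, but instead implements _get_forks, _is_forked, and the split _main_run/_fork_run and _main_inputs/_fork_inputs hooks. This is an illustrative sketch only; the class name, the "is_forked" request field, and the method bodies are assumptions rather than part of this commit, and only the import path follows the file shown in the diff.

from esd_services_api_client.nexus.algorithms.forked_algorithm import ForkedAlgorithm


class MyForkedAlgorithm(ForkedAlgorithm):
    """Hypothetical subclass illustrating the new root-vs-fork split."""

    async def _is_forked(self, **kwargs) -> bool:
        # Decide from the incoming payload whether this node was spawned as a fork;
        # "is_forked" is an assumed request field, not something this commit defines.
        return bool(kwargs.get("is_forked", False))

    async def _get_forks(self, **kwargs) -> list:
        # Leaf node in this sketch; a real implementation would return
        # RemoteAlgorithm instances to be scheduled after the node run.
        return []

    async def _main_inputs(self, **kwargs) -> dict:
        # Root run: resolve the full input set through the attached input processors.
        return await self._default_inputs(**kwargs)

    async def _fork_inputs(self, **kwargs) -> dict:
        # Fork run: reuse what the parent run passed down instead of recomputing.
        return {"payload": kwargs.get("payload")}

    async def _main_run(self, **kwargs):
        # Result returned to the client when this node is the root node.
        ...

    async def _fork_run(self, **kwargs):
        # Result is ignored by the client when this node is NOT the root node.
        ...

The net effect of the commit is that run() now selects _fork_inputs/_fork_run or _main_inputs/_main_run based on _is_forked, and resolves the fork list lazily through _get_forks instead of taking it as a constructor argument.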