-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathnexus_object.py
110 lines (90 loc) · 3.04 KB
/
nexus_object.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Base classes for all objects used by Nexus.
"""
# Copyright (c) 2023-2024. ECCO Sneaks & Data
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC, abstractmethod
import re
from typing import Generic, TypeVar, Union, Any, Dict
import pandas
import polars
from adapta.metrics import MetricsProvider
from dataclasses_json.stringcase import snakecase
from esd_services_api_client.nexus.abstractions.logger_factory import LoggerFactory
class AlgorithmResult(ABC):
    """
    Interface for algorithm run result. You can store arbitrary data here, but the `result` and
    `to_kwargs` methods must be implemented.
    """

    @abstractmethod
    def result(self) -> Union[pandas.DataFrame, polars.DataFrame, Dict]:
        """
        Returns the main result. This will be written to the linked output storage.
        """

    @abstractmethod
    def to_kwargs(self) -> dict[str, Any]:
        """
        Convert result to kwargs for the next iteration (for recursive algorithms).
        """
# Type variables used to parametrise `NexusObject` (`Generic[TPayload, TResult]`).
# TPayload is unconstrained.
TPayload = TypeVar("TPayload") # pylint: disable=C0103
# TResult is constrained: it may only be pandas.DataFrame or polars.DataFrame.
TResult = TypeVar( # pylint: disable=C0103
    "TResult", pandas.DataFrame, polars.DataFrame
)
class NexusCoreObject(ABC):
    """
    Base class for all Nexus objects.

    Instances are asynchronous context managers. Entering starts the logger and then runs
    `_context_open`; exiting runs `_context_close` and then stops the logger.
    """

    def __init__(
        self,
        metrics_provider: MetricsProvider,
        logger_factory: LoggerFactory,
    ):
        """
        :param metrics_provider: Metrics provider, stored on the instance as `_metrics_provider`.
        :param logger_factory: Factory used to create the logger for the concrete class of this instance.
        """
        self._metrics_provider = metrics_provider
        self._logger = logger_factory.create_logger(logger_type=self.__class__)

    async def __aenter__(self):
        self._logger.start()
        try:
            await self._context_open()
        except BaseException:
            # `__aexit__` is not invoked when `__aenter__` raises, so the logger
            # started above must be stopped here before propagating the error.
            self._logger.stop()
            raise
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Tear down in reverse order of `__aenter__`: the context was opened after the
        # logger was started, so it is closed while the logger is still running.
        # `finally` guarantees the logger is stopped even if `_context_close` raises.
        try:
            await self._context_close()
        finally:
            self._logger.stop()

    @abstractmethod
    async def _context_open(self):
        """
        Actions to perform on context activation. Awaited by `__aenter__` after the logger has been started.
        Subclasses must provide an implementation.
        """

    @abstractmethod
    async def _context_close(self):
        """
        Actions to perform on context closure. Awaited by `__aexit__` before the logger is stopped.
        Subclasses must provide an implementation.
        """
class NexusObject(Generic[TPayload, TResult], NexusCoreObject, ABC):
    """
    Base class for Nexus objects that operate on the algorithm payload.
    Generic over `TPayload` and `TResult`.
    """

    @classmethod
    def alias(cls) -> str:
        """
        Name under which instances of this class are identified when passed through kwargs.

        Built from the class name: lowercased, with the substrings "reader", "processor" and
        "algorithm" removed (in that order), then run through the underscore-insertion regex
        and `snakecase`.
        """
        stem = cls.__name__.lower()
        for role in ("reader", "processor", "algorithm"):
            stem = stem.replace(role, "")
        # NOTE(review): `stem` is already lowercase at this point, so the `[A-Z]` look-ahead
        # cannot match and this substitution inserts no underscores. Confirm whether the
        # lowercasing was meant to happen after it. Behaviour is kept as is because callers
        # may depend on the aliases produced today.
        separated = re.sub(r"(?<!^)(?=[A-Z])", "_", stem)
        return snakecase(separated)