-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfields.py
186 lines (159 loc) · 5.87 KB
/
fields.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
from collections.abc import Callable, Generator, Mapping
from types import NoneType, UnionType
from typing import Annotated, Any, Union, get_args, get_origin
from mex.common.models import (
ADDITIVE_MODEL_CLASSES_BY_NAME,
EXTRACTED_MODEL_CLASSES_BY_NAME,
MERGED_MODEL_CLASSES_BY_NAME,
PREVENTIVE_MODEL_CLASSES_BY_NAME,
SUBTRACTIVE_MODEL_CLASSES_BY_NAME,
BaseModel,
GenericFieldInfo,
)
from mex.common.types import MERGED_IDENTIFIER_CLASSES, Link, LiteralStringType, Text
def _get_inner_types(annotation: Any) -> Generator[type, None, None]:
"""Yield all inner types from unions, lists and type annotations (except NoneType).
Args:
annotation: A valid python type annotation
Returns:
A generator for all (non-NoneType) types found in the annotation
"""
if get_origin(annotation) == Annotated:
yield from _get_inner_types(get_args(annotation)[0])
elif get_origin(annotation) in (Union, UnionType, list):
for arg in get_args(annotation):
yield from _get_inner_types(arg)
elif annotation not in (None, NoneType):
yield annotation
def _contains_only_types(field: GenericFieldInfo, *types: type) -> bool:
    """Return whether every inner type of a `field`'s annotation is one of `types`.

    The annotation is unpacked with `_get_inner_types`. Unions, lists and
    `Annotated` wrappers are reduced to their inner types, and `NoneType` is
    ignored. So `list[Text] | None` counts as containing only `Text`.

    A field whose annotation yields no inner types at all (e.g. a bare `None`)
    is reported as `False`.

    Args:
        field: A `GenericFieldInfo` instance
        types: Types the field's inner types are allowed to be

    Returns:
        True if the annotation has at least one inner type and all of its inner
        types are among `types`, otherwise False
    """
    inner_types = list(_get_inner_types(field.annotation))
    # `all([])` is True, so an empty annotation must be rejected explicitly
    if not inner_types:
        return False
    return all(inner_type in types for inner_type in inner_types)
def _group_fields_by_class_name(
    model_classes_by_name: Mapping[str, type[BaseModel]],
    predicate: Callable[[GenericFieldInfo], bool],
) -> dict[str, list[str]]:
    """Collect, per model class, the sorted names of the fields matching `predicate`.

    Every class in `model_classes_by_name` gets an entry, even if none of its
    fields match. In that case the entry is an empty list.

    Args:
        model_classes_by_name: Map from class names to model classes
        predicate: Called with each field's info; fields for which it returns a
            truthy value are kept

    Returns:
        Dictionary mapping each class name to the alphabetically sorted names of
        its fields (as returned by `get_all_fields()`) that satisfy `predicate`
    """
    grouped: dict[str, list[str]] = {}
    for class_name, model_class in model_classes_by_name.items():
        matching_names = {
            field_name
            for field_name, info in model_class.get_all_fields().items()
            if predicate(info)
        }
        grouped[class_name] = sorted(matching_names)
    return grouped
# every model class (additive, extracted, merged, preventive and subtractive) keyed
# by class name; should two mappings share a name, the one listed later wins
ALL_MODEL_CLASSES_BY_NAME = {
    class_name: model_class
    for classes_by_name in (
        ADDITIVE_MODEL_CLASSES_BY_NAME,
        EXTRACTED_MODEL_CLASSES_BY_NAME,
        MERGED_MODEL_CLASSES_BY_NAME,
        PREVENTIVE_MODEL_CLASSES_BY_NAME,
        SUBTRACTIVE_MODEL_CLASSES_BY_NAME,
    )
    for class_name, model_class in classes_by_name.items()
}
# immutable fields that can only be set once: their field info has `frozen` set to
# exactly `True`
FROZEN_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
    model_classes_by_name=ALL_MODEL_CLASSES_BY_NAME,
    predicate=lambda info: info.frozen is True,
)
# static fields whose annotation is a `LiteralStringType`, i.e. fixed on class level
LITERAL_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
    model_classes_by_name=ALL_MODEL_CLASSES_BY_NAME,
    predicate=lambda info: isinstance(info.annotation, LiteralStringType),
)
# fields annotated only with merged identifier types, i.e. references to merged items
REFERENCE_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
    model_classes_by_name=ALL_MODEL_CLASSES_BY_NAME,
    predicate=lambda info: _contains_only_types(info, *MERGED_IDENTIFIER_CLASSES),
)
# nested fields annotated only with `Text`
TEXT_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
    model_classes_by_name=ALL_MODEL_CLASSES_BY_NAME,
    predicate=lambda info: _contains_only_types(info, Text),
)
# nested fields annotated only with `Link`
LINK_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
    model_classes_by_name=ALL_MODEL_CLASSES_BY_NAME,
    predicate=lambda info: _contains_only_types(info, Link),
)
# fields annotated only with `str`
STRING_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
    model_classes_by_name=ALL_MODEL_CLASSES_BY_NAME,
    predicate=lambda info: _contains_only_types(info, str),
)
# distinct names of all `str`-only fields across every model class, sorted; these
# are the fields that should be indexed as searchable
SEARCHABLE_FIELDS = sorted(set().union(*STRING_FIELDS_BY_CLASS_NAME.values()))
# sorted names of the model classes that have at least one `str`-only field
SEARCHABLE_CLASSES = sorted(
    class_name
    for class_name, string_fields in STRING_FIELDS_BY_CLASS_NAME.items()
    if string_fields
)
# fields with changeable values: every field that is neither frozen, nor a reference
# to merged items, nor a nested `Text` / `Link` object
MUTABLE_FIELDS_BY_CLASS_NAME = {
    class_name: sorted(
        set(model_class.get_all_fields())
        - {
            *FROZEN_FIELDS_BY_CLASS_NAME[class_name],
            *REFERENCE_FIELDS_BY_CLASS_NAME[class_name],
            *TEXT_FIELDS_BY_CLASS_NAME[class_name],
            *LINK_FIELDS_BY_CLASS_NAME[class_name],
        }
    )
    for class_name, model_class in ALL_MODEL_CLASSES_BY_NAME.items()
}
# fields of the merged model classes whose values are mergeable: every field that is
# neither literal nor frozen.
# NOTE(review): this reads `model_fields` instead of `get_all_fields()` like the other
# groupings, presumably so computed fields are excluded — confirm this is intended
MERGEABLE_FIELDS_BY_CLASS_NAME = {
    class_name: sorted(
        set(model_class.model_fields)
        - {
            *FROZEN_FIELDS_BY_CLASS_NAME[class_name],
            *LITERAL_FIELDS_BY_CLASS_NAME[class_name],
        }
    )
    for class_name, model_class in MERGED_MODEL_CLASSES_BY_NAME.items()
}
# frozen fields (set once) that are neither literal nor references to merged items
FINAL_FIELDS_BY_CLASS_NAME = {
    class_name: sorted(
        (
            set(model_class.get_all_fields())
            & set(FROZEN_FIELDS_BY_CLASS_NAME[class_name])
        )
        - set(LITERAL_FIELDS_BY_CLASS_NAME[class_name])
        - set(REFERENCE_FIELDS_BY_CLASS_NAME[class_name])
    )
    for class_name, model_class in ALL_MODEL_CLASSES_BY_NAME.items()
}