Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factor out OneOfSchema from OneOfStringSchema and OneOfIntSchema #131

Merged
merged 31 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
335 changes: 128 additions & 207 deletions src/arcaflow_plugin_sdk/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2578,7 +2578,114 @@ def _to_openapi_fragment(


@dataclass
class OneOfStringSchema(_JSONSchemaGenerator, _OpenAPIGenerator):
class OneOfSchema(_JSONSchemaGenerator, _OpenAPIGenerator):
types: typing.Union[
Dict[str, typing.Annotated[_OBJECT_LIKE, discriminator("type_id")]],
Dict[int, typing.Annotated[_OBJECT_LIKE, discriminator("type_id")]],
]
discriminator_inlined: typing.Annotated[
bool,
_name("Discriminator field inlined"),
_description(
"Whether or not the discriminator is inlined in the underlying"
" objects' schema"
),
]
oneof_type: typing.Annotated[str, _name("One Of Type Schema Name")] = None
discriminator_type: typing.Annotated[str, _name("Discriminator Type")] = (
None
)
discriminator_field_name: typing.Annotated[
str,
_name("Discriminator field name"),
_description(
"Name of the field used to discriminate between possible values."
),
] = "_type"

def _insert_discriminator(
self,
discriminated_object: typing.Dict[str, typing.Any],
discriminator_val: str,
) -> typing.Dict[str, typing.Any]:
"""Add a discriminator field as a property of a member type.

This function adds a member type's discriminator field as a property
with a constant value equal to its discriminated value. The
discriminator field is moved to the zeroth index of the list of
required fields in a data packet.

:param discriminated_object: A Python dict which represents the
relevant fragment of the scope's JSON definition.
:param discriminator_val: The value that represents the given object in
its discriminated union.
"""
if self.discriminator_inlined:
# update the object's schema to show the only valid value
# for this object's discriminator
discriminated_object["properties"][
self.discriminator_field_name
] = {
"type": self.discriminator_type,
"const": discriminator_val,
}
# discriminator field is already present in the required
# list when the discriminator is inlined
discriminated_object["required"].remove(
self.discriminator_field_name
)
# discriminator must have the first position
discriminated_object["required"].insert(
0, self.discriminator_field_name
)

def _to_jsonschema_fragment(
self, scope: typing.ForwardRef("ScopeSchema"), defs: _JSONSchemaDefs
) -> any:
one_of = []
for k, v in self.types.items():
# noinspection PyProtectedMember
_ = scope.objects[v.id]._to_jsonschema_fragment(scope, defs)
self._insert_discriminator(defs.defs[v.id], str(k))
if v.display is not None:
if v.display.name is not None:
defs.defs[v.id]["title"] = v.display.name
if v.display.description is not None:
defs.defs[v.id]["description"] = v.display.description
name = v.id + self.oneof_type + str(k)
defs.defs[name] = defs.defs[v.id]
one_of.append({"$ref": "#/$defs/" + name})
return {"oneOf": one_of}

def _to_openapi_fragment(
self, scope: typing.ForwardRef("ScopeSchema"), defs: _OpenAPIComponents
) -> any:
one_of = []
discriminator_mapping = {}
for k, v in self.types.items():
# noinspection PyProtectedMember
_ = scope.objects[v.id]._to_openapi_fragment(scope, defs)
name = v.id + self.oneof_type + str(k)
discriminator_mapping[k] = "#/components/schemas/" + name
self._insert_discriminator(defs.defs[v.id], str(k))
if v.display is not None:
if v.display.name is not None:
defs.defs[v.id]["title"] = v.display.name
if v.display.description is not None:
defs.defs[v.id]["description"] = v.display.description
defs.components[name] = defs.defs[v.id]
one_of.append({"$ref": "#/components/schemas/" + name})
return {
"oneOf": one_of,
"discriminator": {
"propertyName": self.discriminator_field_name,
"mapping": discriminator_mapping,
},
}


@dataclass
class OneOfStringSchema(OneOfSchema):
"""This class holds the definition of variable types with a string
discriminator. This type acts as a split for a case where multiple possible
object types can be present in a field. This type requires that there be a
Expand Down Expand Up @@ -2701,112 +2808,14 @@ class OneOfStringSchema(_JSONSchemaGenerator, _OpenAPIGenerator):
""" # noqa: E501

types: Dict[str, typing.Annotated[_OBJECT_LIKE, discriminator("type_id")]]
discriminator_inlined: typing.Annotated[
bool,
_name("Discriminator field inlined"),
_description(
"True if the discriminator is a field in each schema of the"
" underlying objects"
),
]
discriminator_field_name: typing.Annotated[
str,
_name("Discriminator field name"),
_description(
"Name of the field whose value is used to discriminate between"
" possible subobject types. If this field is present in any of the"
" subobjects it must have a type of string."
),
] = "_type"

def _insert_discriminator(
self,
discriminated_object: typing.Dict[str, typing.Any],
discriminator_val: str,
) -> typing.Dict[str, typing.Any]:
"""Add a discriminator field as a property of a member type.

This function adds a member type's discriminator field as a property
with a constant value equal to its discriminated value. The
discriminator field is moved to the zeroth index of the list of
required fields in a data packet.

:param discriminated_object: A Python dict which represents the
relevant fragment of the scope's JSON definition.
:param discriminator_val: The value that represents the given object in
its discriminated union.
"""
if self.discriminator_inlined:
# update the object's schema to show the only valid value
# for this object's discriminator
discriminated_object["properties"][
self.discriminator_field_name
] = {
"type": "string",
"const": discriminator_val,
}
# discriminator field is already present in the required
# list when the discriminator is inlined
discriminated_object["required"].remove(
self.discriminator_field_name
)
# discriminator must have the first position
discriminated_object["required"].insert(
0, self.discriminator_field_name
)

def _to_jsonschema_fragment(
self, scope: typing.ForwardRef("ScopeSchema"), defs: _JSONSchemaDefs
) -> any:
one_of = []
for k, v in self.types.items():
# noinspection PyProtectedMember
scope.objects[v.id]._to_jsonschema_fragment(scope, defs)

self._insert_discriminator(defs.defs[v.id], k)

if v.display is not None:
if v.display.name is not None:
defs.defs[v.id]["title"] = v.display.name
if v.display.description is not None:
defs.defs[v.id]["description"] = v.display.description

name = v.id + "_discriminated_string_" + _id_typeize(k)
defs.defs[name] = defs.defs[v.id]
one_of.append({"$ref": "#/$defs/" + name})
return {"oneOf": one_of}

def _to_openapi_fragment(
self, scope: typing.ForwardRef("ScopeSchema"), defs: _OpenAPIComponents
) -> any:
one_of = []
discriminator_mapping = {}
for k, v in self.types.items():
# noinspection PyProtectedMember
scope.objects[v.id]._to_openapi_fragment(scope, defs)

name = v.id + "_discriminated_string_" + _id_typeize(k)
discriminator_mapping[k] = "#/components/schemas/" + name
self._insert_discriminator(defs.defs[v.id], k)
if v.display is not None:
if v.display.name is not None:
defs.defs[v.id]["title"] = v.display.name
if v.display.description is not None:
defs.defs[v.id]["description"] = v.display.description

defs.components[name] = defs.defs[v.id]
one_of.append({"$ref": "#/components/schemas/" + name})
return {
"oneOf": one_of,
"discriminator": {
"propertyName": self.discriminator_field_name,
"mapping": discriminator_mapping,
},
}
def __post_init__(self):
self.oneof_type = "_discriminated_string_"
self.discriminator_type = "string"


@dataclass
class OneOfIntSchema(_JSONSchemaGenerator, _OpenAPIGenerator):
class OneOfIntSchema(OneOfSchema):
"""This class holds the definition of variable types with an integer
discriminator. This type acts as a split for a case where multiple possible
object types can be present in a field. This type requires that there be a
Expand Down Expand Up @@ -2912,106 +2921,10 @@ class OneOfIntSchema(_JSONSchemaGenerator, _OpenAPIGenerator):
""" # noqa: E501

types: Dict[int, typing.Annotated[_OBJECT_LIKE, discriminator("type_id")]]
discriminator_inlined: typing.Annotated[
bool,
_name("Discriminator field inlined"),
_description(
"Whether or not the discriminator is inlined in the underlying"
" objects' schema"
),
]
discriminator_field_name: typing.Annotated[
str,
_name("Discriminator field name"),
_description(
"Name of the field used to discriminate between possible values."
" If this field ispresent on any of the component objects it must"
" also be an int."
),
] = "_type"

def _insert_discriminator(
self,
discriminated_object: typing.Dict[str, typing.Any],
discriminator_val: str,
) -> typing.Dict[str, typing.Any]:
"""Add a discriminator field as a property of a member type.

This function adds a member type's discriminator field as a property
with a constant value equal to its discriminated value. The
discriminator field is moved to the zeroth index of the list of
required fields in a data packet.

:param discriminated_object: A Python dict which represents the
relevant fragment of the scope's JSON definition.
:param discriminator_val: The value that represents the given object in
its discriminated union.
"""
if self.discriminator_inlined:
# update the object's schema to show the only valid value
# for this object's discriminator
discriminated_object["properties"][
self.discriminator_field_name
] = {
"type": "string",
"const": discriminator_val,
}
# discriminator field is already present in the required
# list when the discriminator is inlined
discriminated_object["required"].remove(
self.discriminator_field_name
)
# discriminator must have the first position
discriminated_object["required"].insert(
0, self.discriminator_field_name
)

def _to_jsonschema_fragment(
self, scope: typing.ForwardRef("ScopeSchema"), defs: _JSONSchemaDefs
) -> any:
one_of = []
for k, v in self.types.items():
# noinspection PyProtectedMember
scope.objects[v.id]._to_jsonschema_fragment(scope, defs)

self._insert_discriminator(defs.defs[v.id], str(k))
if v.display is not None:
if v.display.name is not None:
defs.defs[v.id]["title"] = v.display.name
if v.display.description is not None:
defs.defs[v.id]["description"] = v.display.description
name = v.id + "_discriminated_int_" + str(k)
defs.defs[name] = defs.defs[v.id]
one_of.append({"$ref": "#/$defs/" + name})
return {"oneOf": one_of}

def _to_openapi_fragment(
self, scope: typing.ForwardRef("ScopeSchema"), defs: _OpenAPIComponents
) -> any:
one_of = []
discriminator_mapping = {}
for k, v in self.types.items():
# noinspection PyProtectedMember
scope.objects[v.id]._to_openapi_fragment(scope, defs)
name = v.id + "_discriminated_int_" + str(k)
discriminator_mapping[k] = "#/components/schemas/" + name

self._insert_discriminator(defs.defs[v.id], str(k))
if v.display is not None:
if v.display.name is not None:
defs.defs[v.id]["title"] = v.display.name
if v.display.description is not None:
defs.defs[v.id]["description"] = v.display.description

defs.components[name] = defs.defs[v.id]
one_of.append({"$ref": "#/components/schemas/" + name})
return {
"oneOf": one_of,
"discriminator": {
"propertyName": self.discriminator_field_name,
"mapping": discriminator_mapping,
},
}
def __post_init__(self):
self.oneof_type = "_discriminated_int_"
self.discriminator_type = "integer"


@dataclass
Expand Down Expand Up @@ -5621,7 +5534,10 @@ def __init__(
):
# noinspection PyArgumentList
OneOfStringSchema.__init__(
self, types, discriminator_inlined, discriminator_field_name
self,
types=types,
discriminator_inlined=discriminator_inlined,
discriminator_field_name=discriminator_field_name,
)
_OneOfType.__init__(
self,
Expand Down Expand Up @@ -5659,7 +5575,10 @@ def __init__(
):
# noinspection PyArgumentList
OneOfIntSchema.__init__(
self, types, discriminator_inlined, discriminator_field_name
self,
types=types,
discriminator_inlined=discriminator_inlined,
discriminator_field_name=discriminator_field_name,
)
_OneOfType.__init__(
self,
Expand Down Expand Up @@ -7019,12 +6938,14 @@ def _resolve_union(
types[discriminator_value] = f.type
if discriminator_type is str:
return OneOfStringType(
types,
scope,
types=types,
scope=scope,
discriminator_inlined=False,
)
else:
return OneOfIntType(types, scope, discriminator_inlined=False)
return OneOfIntType(
types=types, scope=scope, discriminator_inlined=False
)

@classmethod
def _resolve_pattern(cls, t, type_hints: type, path, scope: ScopeType):
Expand Down
Loading