Skip to content

Commit

Permalink
Implement ref properties
Browse files Browse the repository at this point in the history
  • Loading branch information
jcoelho93 committed Jun 10, 2024
1 parent 9b35049 commit 9011ede
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 63 deletions.
18 changes: 9 additions & 9 deletions examples/openapi.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@
}
},
"definitions": {
"schemesList": {
"type": "array",
"description": "The transfer protocol of the API.",
"items": {
"type": "string",
"enum": ["http", "https", "ws", "wss"]
},
"uniqueItems": true
},
"info": {
"type": "object",
"description": "General information about the API.",
Expand Down Expand Up @@ -1381,15 +1390,6 @@
},
"uniqueItems": true
},
"schemesList": {
"type": "array",
"description": "The transfer protocol of the API.",
"items": {
"type": "string",
"enum": ["http", "https", "ws", "wss"]
},
"uniqueItems": true
},
"collectionFormat": {
"type": "string",
"enum": ["csv", "ssv", "tsv", "pipes"],
Expand Down
115 changes: 61 additions & 54 deletions steer/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from questionary import prompt
from pydantic import BaseModel
from steer.models import OutputType
from typing import Optional, List, Any, Dict
from typing import Optional, List, Any, Dict, Union


class Property(BaseModel):
Expand Down Expand Up @@ -121,7 +121,7 @@ def _get_prompt_args(self):
'name': self.name,
'message': f"{self.name} ({self.type}): ",
'choices': self.enum,
'default': str(self.default)
'default': self.default
}
return {k: v for k, v in args.items() if v is not None}

Expand Down Expand Up @@ -207,34 +207,28 @@ def prompt(self):

class ArrayProperty(Property):
type: str = 'array'
items: Optional[Property] = None
items: Optional[Union[StringProperty, IntegerProperty, NumberProperty]] = None
uniqueItems: Optional[bool] = False

@classmethod
def from_dict(cls, key, obj, parent):
return cls(
name=key,
type=obj.get('type'),
items=obj.get('items'),
items=PropertyFactory.get_property(key, obj.get('items'), parent),
uniqueItems=obj.get('uniqueItems'),
path=parent + key
)

def prompt(self):
elements = []
while True:
if self.items.type == 'integer':
property = IntegerProperty.from_dict('Array element:', self.items.model_dump(), '.')
elif self.items.type == 'string':
property = StringProperty.from_dict('Array element:', self.items.model_dump(), '.')
elif self.items.type == 'number':
property = NumberProperty.from_dict('Array element:', self.items.model_dump(), '.')
elif self.items.type == 'boolean':
property = BooleanProperty.from_dict('Array element:', self.items.model_dump(), '.')
else:
raise NotImplementedError()
property.prompt()
elements.append(property)
try:
prop = PropertyFactory.get_property('Array element:', self.items.model_dump(), '.')
prop.prompt()
except NotImplementedError:
continue
elements.append(prop)
if not self._add_more_elements():
break
else:
Expand Down Expand Up @@ -269,16 +263,11 @@ def add_property(self, property: Property):

def with_properties(self, properties, parent):
for key, value in properties.items():
if value.get('type') == 'string':
self.add_property(StringProperty.from_dict(key, value, parent + '.'))
elif value.get('type') == 'integer':
self.add_property(IntegerProperty.from_dict(key, value, parent + '.'))
elif value.get('type') == 'object':
self.add_property(ObjectProperty.from_dict(key, value, parent + '.'))
elif value.get('type') == 'boolean':
self.add_property(BooleanProperty.from_dict(key, value, parent + '.'))
elif value.get('type') == 'array':
self.add_property(ArrayProperty.from_dict(key, value, parent + '.'))
try:
prop = PropertyFactory.get_property(key, value, parent + '.')
self.add_property(prop)
except NotImplementedError:
continue
return self

def save(self, data: Any):
Expand All @@ -297,17 +286,17 @@ def prompt(self):
p.prompt()


class RefProperty(Property):
type: str = 'ref'
class Reference:
reference: str

@classmethod
def from_dict(cls, key, obj, parent):
return cls(
name=key,
type=obj.get('type'),
ref=obj.get('$ref'),
path=parent + key
)
def __init__(self, ref: str):
self.ref = ref

def get_reference(self, definitions: List[Property]) -> Property:
reference_name = self.ref.split('/')[-1]
for definition in definitions:
if definition.name == reference_name:
return definition


class Schema(BaseModel):
Expand Down Expand Up @@ -370,25 +359,43 @@ def from_dict(cls, obj):
)

for key, value in obj.get('definitions', {}).items():
schema.add_definition(ObjectProperty.from_dict(key, value, '$.'))
try:
definition = PropertyFactory.get_property(key, value, '$.', schema)
schema.add_definition(definition)
except NotImplementedError:
continue

for key, value in obj['properties'].items():
match value.get('type'):
case 'string':
schema.add_property(StringProperty.from_dict(key, value, '$.'))
case 'integer':
schema.add_property(IntegerProperty.from_dict(key, value, '$.'))
case 'number':
schema.add_property(NumberProperty.from_dict(key, value, '$.'))
case 'object':
schema.add_property(ObjectProperty.from_dict(key, value, '$.'))
case 'boolean':
schema.add_property(BooleanProperty.from_dict(key, value, '$.'))
case 'array':
schema.add_property(ArrayProperty.from_dict(key, value, '$.'))
case 'ref':
schema.add_property(RefProperty.from_dict(key, value, '$.'))
case _:
logging.warn(f"Type {value.get('type')} not supported")
try:
prop = PropertyFactory.get_property(key, value, '$.', schema)
schema.add_property(prop)
except NotImplementedError:
continue

return schema


class PropertyFactory:
@classmethod
def get_property(self, key: str, obj: Dict, parent: str, schema: Schema = None):
if obj.get('$ref') and schema is not None:
ref = Reference(obj.get('$ref'))
prop = ref.get_reference(schema.definitions)
return prop

match obj.get('type'):
case 'string':
return StringProperty.from_dict(key, obj, parent)
case 'integer':
return IntegerProperty.from_dict(key, obj, parent)
case 'number':
return NumberProperty.from_dict(key, obj, parent)
case 'object':
return ObjectProperty.from_dict(key, obj, parent)
case 'boolean':
return BooleanProperty.from_dict(key, obj, parent)
case 'array':
return ArrayProperty.from_dict(key, obj, parent)
case _:
logging.warn(f"Type {type} not supported")
raise NotImplementedError(f"Type {type} not supported yet")
46 changes: 46 additions & 0 deletions tests/test_array_property.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import unittest
from steer.schema import (
Schema, ArrayProperty,
IntegerProperty, StringProperty
)


json_schema = {
"$schema": "http://example.com",
"type": "object",
"properties": {
"phones": {
"type": "array",
"items": {
"type": "integer"
}
},
"emails": {
"type": "array",
"items": {
"type": "string"
}
}
}
}


class TestArrayProperty(unittest.TestCase):
def setUp(self):
self.schema = Schema.from_dict(json_schema)

def test_integer_array_property(self):
self.assertIsInstance(self.schema.properties[0], ArrayProperty)

self.assertEqual(self.schema.properties[0].name, 'phones')
self.assertIsInstance(self.schema.properties[0].items, IntegerProperty)

def test_string_array_property(self):
self.assertIsInstance(self.schema.properties[1], ArrayProperty)

self.assertEqual(self.schema.properties[1].name, 'emails')
self.assertIsInstance(self.schema.properties[1].items, StringProperty)


if __name__ == '__main__':
unittest.main()
82 changes: 82 additions & 0 deletions tests/test_ref_property.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import unittest
from steer.schema import (
Schema, ObjectProperty, StringProperty,
IntegerProperty, NumberProperty, BooleanProperty
)


json_schema = {
"$schema": "http://example.com",
"type": "object",
"properties": {
"name": {
"$ref": "#/definitions/name"
},
"age": {
"$ref": "#/definitions/age"
},
"net_worth": {
"$ref": "#/definitions/net_worth"
},
"married": {
"$ref": "#/definitions/married"
},
"address": {
"$ref": "#/definitions/address"
}
},
"definitions": {
"name": {
"type": "string",
"pattern": "^[A-Z]+$"
},
"age": {
"type": "integer"
},
"net_worth": {
"type": "number"
},
"married": {
"type": "boolean"
},
"address": {
"type": "object",
"properties": {
"street": {
"type": "string"
},
"city": {
"type": "string"
},
"postal_code": {
"type": "string",
"pattern": "^[0-9]{4}-[0-9]{3}$"
}
}
}
}
}


class TestRefProperty(unittest.TestCase):
def setUp(self):
self.schema = Schema.from_dict(json_schema)

def test_loading_definitions(self):
self.assertIsInstance(self.schema.definitions[0], StringProperty)
self.assertIsInstance(self.schema.definitions[1], IntegerProperty)
self.assertIsInstance(self.schema.definitions[2], NumberProperty)
self.assertIsInstance(self.schema.definitions[3], BooleanProperty)
self.assertIsInstance(self.schema.definitions[4], ObjectProperty)

def test_referenced_object_fields(self):
prop = self.schema.properties[4]

self.assertEqual(prop.name, 'address')
self.assertEqual(prop.type, 'object')
for p in prop.properties:
self.assertIsInstance(p, StringProperty)


if __name__ == '__main__':
unittest.main()

0 comments on commit 9011ede

Please sign in to comment.