Skip to content

Commit

Permalink
PathItem.parameters only contains path params now, Operation.paramete…
Browse files Browse the repository at this point in the history
…rs will now get all the query parameters if the http method doesn't have a body, if it does have a body then the query parameters will be placed in the RequestBody
  • Loading branch information
Jaymon committed Jan 30, 2025
1 parent 64e4850 commit 8937717
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 57 deletions.
51 changes: 39 additions & 12 deletions endpoints/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def reflect_http_methods(self, http_verb=""):
There can be multiple of each http method name, so this can
yield N GET methods, etc.
"""
# method_names = self.value["method_names"]
method_names = self.get_http_method_names()
if http_verb:
items = []
Expand Down Expand Up @@ -185,6 +184,11 @@ class ReflectMethod(ReflectCallable):
these are methods on the controller like GET and POST
"""
@classmethod
def has_body(self, http_verb):
"""Returns True if http_verb accepts a body in the request"""
return http_verb in set(["PUT", "POST", "PATCH"])

def __init__(self, target, http_verb, reflect_controller, **kwargs):
"""
:param target: callable, the controller method
Expand All @@ -202,10 +206,6 @@ def __init__(self, target, http_verb, reflect_controller, **kwargs):
def reflect_class(self):
return self._reflect_controller

def has_body(self):
"""Returns True if this method accepts a body in the request"""
return self.name in set(["PUT", "POST", "PATCH"])

def reflect_params(self):
"""This will reflect all params defined with the @param decorator"""
unwrapped = self.get_unwrapped()
Expand All @@ -216,7 +216,7 @@ def reflect_params(self):
def reflect_body_params(self):
"""This will reflect all the params that are usually passed up using
the body on a POST request"""
if self.has_body():
if self.has_body(self.http_verb):
for rp in self.reflect_params():
if rp.target.is_kwarg:
yield rp
Expand All @@ -225,7 +225,14 @@ def reflect_url_params(self):
"""This will reflect params that need to be in the url path or the
query part of the url"""
for rp in self.reflect_params():
if not rp.target.is_kwarg or not self.has_body():
if not rp.target.is_kwarg or not self.has_body(self.http_verb):
yield rp

def reflect_query_params(self):
"""This will reflect params that need to be in the query part of the
url"""
for rp in self.reflect_params():
if rp.target.is_kwarg and not self.has_body(self.http_verb):
yield rp

def reflect_path_params(self):
Expand Down Expand Up @@ -1009,10 +1016,15 @@ def set_request_method(self, reflect_method):
"""This is called from RequestBody and is the main hook into
customizing and extending the request schema for child projects"""
self.reflect_method = reflect_method

for reflect_param in reflect_method.reflect_body_params():
self.set_object_keys()
self.add_param(reflect_param)

for reflect_param in reflect_method.reflect_url_params():
self.set_object_keys()
self.add_param(reflect_param)

def set_response_method(self, reflect_method):
"""Called from Response and is the main hook into customizing and
extending the response schema for child projects"""
Expand Down Expand Up @@ -1653,7 +1665,7 @@ class Operation(OpenABC):

_operationId = Field(str)

_parameters = Field(list[Parameter|Reference])
_parameters = Field(list[Parameter|Reference], todict_empty_value=None)

_requestBody = Field(RequestBody|Reference)

Expand Down Expand Up @@ -1682,6 +1694,21 @@ def init_instance(self, reflect_method, **kwargs):
self.reflect_method = reflect_method
self.set_docblock(reflect_method.get_docblock())

def get_parameters_value(self, **kwargs):
# if this operation has a body then we put the url params in the body
# instead
if not self.reflect_method.has_body(self.http_verb):
parameters = []

# this is a positional argument (part of path) or query param
# (after the ? in the url)
for reflect_param in self.reflect_method.reflect_query_params():
parameter = self.create_instance("parameter_class")
parameter.set_param(reflect_param)
parameters.append(parameter)

return parameters

def get_tags_value(self, **kwargs):
tags = list(self.reflect_method.reflect_class().value["module_keys"])
if not tags and self.root:
Expand Down Expand Up @@ -1726,7 +1753,7 @@ def get_operation_id_value(self, **kwargs):
return "".join(parts)

def get_request_body_value(self, **kwargs):
if self.reflect_method.has_body():
if self.reflect_method.has_body(self.http_verb):
rb = self.create_instance(
"request_body_class",
**kwargs
Expand Down Expand Up @@ -1886,6 +1913,8 @@ def set_put_operation(self, operation):

_description = Field(str)

_parameters = Field(list[Parameter|Reference], todict_empty_value=None)

_get = Field(Operation)

_put = Field(Operation)
Expand All @@ -1904,8 +1933,6 @@ def set_put_operation(self, operation):

_servers = Field(list[Server])

_parameters = Field(list[Parameter|Reference], todict_empty_value=None)

def add_method(self, reflect_method, **kwargs):
"""Add the method to this path
Expand All @@ -1926,7 +1953,7 @@ def add_parameters(self, reflect_method):

# this is a positional argument (part of path) or query param
# (after the ? in the url)
for reflect_param in reflect_method.reflect_url_params():
for reflect_param in reflect_method.reflect_path_params():
parameter = self.create_instance("parameter_class")
parameter.set_param(reflect_param)
parameters.append(parameter)
Expand Down
127 changes: 82 additions & 45 deletions tests/reflection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_reflect_url_modules(self):


class ReflectMethodTest(TestCase):
def test_reflect_params(self):
def test_reflect_params_post(self):
rc = self.create_reflect_controllers("""
class Foo(Controller):
@param(0)
Expand All @@ -114,6 +114,23 @@ def POST(self, *args, **kwargs):
self.assertEqual(3, len(list(rm.reflect_params())))
self.assertEqual(2, len(list(rm.reflect_body_params())))
self.assertEqual(1, len(list(rm.reflect_url_params())))
self.assertEqual(0, len(list(rm.reflect_query_params())))

def test_reflect_params_get(self):
rc = self.create_reflect_controllers("""
class Foo(Controller):
@param(0)
@param("bar", type=int, help="bar variable")
@param("che", type=str, required=False, help="che variable")
def GET(self, *args, **kwargs):
pass
""")[0]

rm = list(rc.reflect_http_methods("GET"))[0]
self.assertEqual(2, len(list(rm.reflect_query_params())))
self.assertEqual(3, len(list(rm.reflect_params())))
self.assertEqual(3, len(list(rm.reflect_url_params())))
self.assertEqual(1, len(list(rm.reflect_path_params())))

def test_get_url_path_1(self):
rm = self.create_reflect_methods("""
Expand Down Expand Up @@ -215,33 +232,6 @@ def GET(self):
self.assertTrue("/" in oa.paths)
self.assertTrue("get" in oa.paths["/"])

def test_params_positional_named(self):
oa = self.create_openapi("""
class Foo(Controller):
@param(0)
def GET(self, zero, *args, **kwargs):
pass
""")

self.assertTrue("/foo/{zero}" in oa.paths)

parameter = oa.paths["/foo/{zero}"].parameters[0]
self.assertEqual("zero", parameter["name"])
self.assertEqual("path", parameter["in"])

def test_params_query(self):
oa = self.create_openapi("""
class Foo(Controller):
@param("zero")
def GET(self, zero, *args, **kwargs):
pass
""")

pi = oa.paths["/foo"]

self.assertEqual("query", pi.parameters[0]["in"])
self.assertEqual(1, len(pi.parameters))

def test_params_body(self):
oa = self.create_openapi("""
class Foo(Controller):
Expand Down Expand Up @@ -443,23 +433,6 @@ def test_write_yaml(self):
fp = oa.write_yaml(dp)
self.assertTrue(fp.isfile())

def test_multiple_path_with_options(self):
oa = self.create_openapi("""
class Foo(Controller):
cors = True
@param(0)
@param(1)
def GET(self, bar, che):
pass
""")

pi = oa.paths["/foo/{bar}/{che}"]
self.assertEqual(3, len(pi))
self.assertTrue("options" in pi)
self.assertFalse("405" in pi["options"]["responses"])
self.assertTrue("get" in pi)
self.assertFalse("/foo" in oa.paths)

def test_url_path(self):
oa = self.create_openapi("""
class Default(Controller):
Expand Down Expand Up @@ -543,6 +516,70 @@ def GET(self, param, *args, **kwargs) -> str:
self.assertTrue("text/html" in content)


class PathItemTest(TestCase):
def test_multiple_path_with_options(self):
oa = self.create_openapi("""
class Foo(Controller):
cors = True
@param(0)
@param(1)
def GET(self, bar, che):
pass
""")

pi = oa.paths["/foo/{bar}/{che}"]
for http_verb in ["get", "options"]:
self.assertTrue(http_verb in pi)

self.assertFalse("405" in pi["options"]["responses"])
self.assertFalse("/foo" in oa.paths)

def test_params_positional_named(self):
oa = self.create_openapi("""
class Foo(Controller):
@param(0)
def GET(self, zero, *args, **kwargs):
pass
""")

self.assertTrue("/foo/{zero}" in oa.paths)

parameter = oa.paths["/foo/{zero}"].parameters[0]
self.assertEqual("zero", parameter["name"])
self.assertEqual("path", parameter["in"])


class OperationTest(TestCase):
def test_any_parameters(self):
"""Makes sure ANY sets a get operation with parameters and sets a post
with a requestBody with a foo property"""
oa = self.create_openapi("""
class Default(Controller):
@param("foo", type=int, default=1)
def ANY(self, foo) -> None:
pass
""")

pi = oa.paths["/"]

self.assertEqual("foo", pi["get"]["parameters"][0]["name"])
self.assertFalse("Paremeters"in pi["post"])
schema = pi["post"]["requestBody"]["content"]["*/*"]["schema"]
self.assertTrue("foo" in schema["properties"])

def test_params_query(self):
oa = self.create_openapi("""
class Foo(Controller):
@param("zero")
def GET(self, zero, *args, **kwargs):
pass
""")

pi = oa.paths["/foo"]
self.assertEqual(0, len(pi.parameters))
self.assertEqual("query", pi["get"]["parameters"][0]["in"])


class SchemaTest(TestCase):
def test_list_value_types(self):
rt = ReflectType(list[dict[str, int]|tuple[float, float]])
Expand Down

0 comments on commit 8937717

Please sign in to comment.