Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix query, form, header model extra not honored #201

Merged
merged 2 commits into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions flask_openapi3/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ def _get_value(model: Type[BaseModel], args: MultiDict, model_field_key: str, mo
def _validate_header(header: Type[BaseModel], func_kwargs: dict):
request_headers = dict(request.headers)
header_dict = {}
model_properties = header.model_json_schema().get("properties", {})
for model_field_key, model_field_value in header.model_fields.items():
key_title = model_field_key.replace("_", "-").title()
model_field_schema = model_properties.get(model_field_value.alias or model_field_key)
if model_field_value.alias and header.model_config.get("populate_by_name"):
key = model_field_value.alias
key_alias_title = model_field_value.alias.replace("_", "-").title()
Expand All @@ -57,6 +59,12 @@ def _validate_header(header: Type[BaseModel], func_kwargs: dict):
value = request_headers[key_title]
if value is not None:
header_dict[key] = value
if model_field_schema.get("type") == "null":
header_dict[key] = value # type:ignore
# extra keys
for key, value in request_headers.items():
if key not in header_dict.keys():
header_dict[key] = value
func_kwargs["header"] = header.model_validate(obj=header_dict)


Expand All @@ -81,6 +89,12 @@ def _validate_query(query: Type[BaseModel], func_kwargs: dict):
key, value = _get_value(query, request_args, model_field_key, model_field_value)
if value is not None and value != []:
query_dict[key] = value
if model_field_schema.get("type") == "null":
query_dict[key] = value
# extra keys
for key, value in request_args.items():
if key not in query_dict.keys():
query_dict[key] = value
func_kwargs["query"] = query.model_validate(obj=query_dict)


Expand Down Expand Up @@ -114,6 +128,12 @@ def _validate_form(form: Type[BaseModel], func_kwargs: dict):
value = _value
if value is not None and value != []:
form_dict[key] = value
if model_field_schema.get("type") == "null":
form_dict[key] = value
# extra keys
for key, value in {**dict(request_form), **dict(request_files)}.items():
if key not in form_dict.keys():
form_dict[key] = value
func_kwargs["form"] = form.model_validate(obj=form_dict)


Expand Down
78 changes: 78 additions & 0 deletions tests/test_model_extra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# @Author : llc
# @Time : 2024/11/20 14:45
from typing import Optional

import pytest
from pydantic import BaseModel, Field, ConfigDict

from flask_openapi3 import OpenAPI

app = OpenAPI(__name__)
app.config["TESTING"] = True


class BookQuery(BaseModel):
age: Optional[int] = Field(None, description="Age")

model_config = ConfigDict(extra="allow")


class BookForm(BaseModel):
string: str

model_config = ConfigDict(extra="forbid")


class BookHeader(BaseModel):
api_key: str = Field(..., description="API Key")

model_config = ConfigDict(extra="forbid")


@pytest.fixture
def client():
client = app.test_client()

return client


@app.get("/book")
def get_books(query: BookQuery):
"""get books
to get all books
"""
assert query.age == 3
assert query.author == "joy"
return {"code": 0, "message": "ok"}


@app.post("/form")
def api_form(form: BookForm):
print(form)
return {"code": 0, "message": "ok"}


def test_query(client):
resp = client.get("/book?age=3&author=joy")
assert resp.status_code == 200


@app.get("/header")
def get_book(header: BookHeader):
return header.model_dump(by_alias=True)


def test_form(client):
data = {
"string": "a",
"string_list": ["a", "b", "c"]
}
r = client.post("/form", data=data, content_type="multipart/form-data")
assert r.status_code == 422


def test_header(client):
headers = {"Hello1": "111", "hello2": "222", "api_key": "333", "api_type": "A", "x-hello": "444"}
resp = client.get("/header", headers=headers)
assert resp.status_code == 422