diff --git a/flask_pydantic_spec/flask_backend.py b/flask_pydantic_spec/flask_backend.py index b7bc179..e0f26f2 100644 --- a/flask_pydantic_spec/flask_backend.py +++ b/flask_pydantic_spec/flask_backend.py @@ -148,9 +148,16 @@ def request_validation( else: parsed_body = request.get_json(silent=True) or {} elif request.content_type and "multipart/form-data" in request.content_type: - parsed_body = parse_multi_dict(request.form) if request.form else {} + # It's possible there is a binary json object in the files - iterate through and find it + parsed_body = {} + for key, value in request.files.items(): + if value.mimetype == "application/json": + parsed_body[key] = json.loads(value.stream.read().decode(encoding="utf-8")) + # Finally, find any JSON objects in the form and add them to the body + parsed_body.update(parse_multi_dict(request.form) or {}) else: parsed_body = request.get_data() or {} + req_headers: Optional[Headers] = request.headers or None req_cookies: Optional[Mapping[str, str]] = request.cookies or None setattr( @@ -158,9 +165,11 @@ def request_validation( "context", Context( query=query.parse_obj(req_query) if query else None, - body=getattr(body, "model").parse_obj(parsed_body) - if body and getattr(body, "model") - else None, + body=( + getattr(body, "model").parse_obj(parsed_body) + if body and getattr(body, "model") + else None + ), headers=headers.parse_obj(req_headers or {}) if headers else None, cookies=cookies.parse_obj(req_cookies or {}) if cookies else None, ), diff --git a/tests/test_plugin_flask.py b/tests/test_plugin_flask.py index db8485c..471d2f3 100644 --- a/tests/test_plugin_flask.py +++ b/tests/test_plugin_flask.py @@ -2,10 +2,13 @@ from io import BytesIO from random import randint import gzip +from typing import Union + import pytest import json from flask import Flask, jsonify, request from werkzeug.datastructures import FileStorage +from werkzeug.test import Client from flask_pydantic_spec.types import Response, MultipartFormRequest from flask_pydantic_spec import FlaskPydanticSpec @@ -111,7 +114,7 @@ def client(request): @pytest.mark.parametrize("client", [422], indirect=True) -def test_flask_validate(client): +def test_flask_validate(client: Client): resp = client.get("/ping") assert resp.status_code == 422 assert resp.headers.get("X-Error") == "Validation Error" @@ -158,23 +161,31 @@ def test_flask_validate(client): @pytest.mark.parametrize("client", [422], indirect=True) -def test_sending_file(client): +@pytest.mark.parametrize( + "data", + [ + FileStorage( + BytesIO(json.dumps({"type": "foo", "created_at": str(datetime.now().date())}).encode()), + ), + json.dumps({"type": "foo", "created_at": str(datetime.now().date())}), + ], +) +def test_sending_file(client: Client, data: Union[FileStorage, str]): file = FileStorage(BytesIO(b"abcde"), filename="test.jpg", name="test.jpg") resp = client.post( "/api/file", data={ "file": file, "file_name": "another_test.jpg", - "data": json.dumps({"type": "foo", "created_at": str(datetime.now().date())}), + "data": data, }, - content_type="multipart/form-data", ) assert resp.status_code == 200 assert resp.json["name"] == "another_test.jpg" @pytest.mark.parametrize("client", [422], indirect=True) -def test_query_params(client): +def test_query_params(client: Client): resp = client.get("api/user?name=james&name=bethany&name=claire") assert resp.status_code == 200 assert len(resp.json["data"]) == 2 @@ -189,7 +200,7 @@ def test_query_params(client): @pytest.mark.parametrize("client", [200], indirect=True) -def test_flask_skip_validation(client): +def test_flask_skip_validation(client: Client): resp = client.get("api/group/test") assert resp.status_code == 200 assert resp.json["name"] == "test" @@ -197,7 +208,7 @@ def test_flask_skip_validation(client): @pytest.mark.parametrize("client", [422], indirect=True) -def test_flask_doc(client): +def test_flask_doc(client: Client): resp = client.get("/apidoc/openapi.json") assert resp.json == api.spec @@ -211,7 +222,7 @@ def test_flask_doc(client): @pytest.mark.parametrize("client", [400], indirect=True) -def test_flask_validate_with_alternative_code(client): +def test_flask_validate_with_alternative_code(client: Client): resp = client.get("/ping") assert resp.status_code == 400 assert resp.headers.get("X-Error") == "Validation Error" @@ -222,7 +233,7 @@ def test_flask_validate_with_alternative_code(client): @pytest.mark.parametrize("client", [400], indirect=True) -def test_flask_post_gzip(client): +def test_flask_post_gzip(client: Client): body = dict(name="flask", limit=10) compressed = gzip.compress(bytes(json.dumps(body), encoding="utf-8")) @@ -240,7 +251,7 @@ def test_flask_post_gzip(client): @pytest.mark.parametrize("client", [400], indirect=True) -def test_flask_post_gzip_failure(client): +def test_flask_post_gzip_failure(client: Client): body = dict(name="flask") compressed = gzip.compress(bytes(json.dumps(body), encoding="utf-8"))