From a4bd218eb5f5b635c7b7253da3c921f191ec6d8c Mon Sep 17 00:00:00 2001 From: Robert Myers Date: Wed, 27 Dec 2023 17:12:40 -0600 Subject: [PATCH 1/2] Adding lru cache for validation --- cannula/api.py | 8 ++++++-- performance/test_performance.py | 24 ++++++++++++++++++++---- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/cannula/api.py b/cannula/api.py index 21bd796..669c750 100644 --- a/cannula/api.py +++ b/cannula/api.py @@ -16,6 +16,7 @@ from graphql import ( DocumentNode, + GraphQLError, GraphQLObjectType, execute, ExecutionResult, @@ -246,6 +247,10 @@ def _merge_registry(self, registry: dict): for type_name, value in registry.items(): self.registry[type_name].update(value) + @functools.lru_cache(maxsize=128) + def validate(self, document: DocumentNode) -> typing.List[GraphQLError]: + return validate(self.schema, document) + async def call( self, document: DocumentNode, @@ -257,8 +262,7 @@ async def call( This is meant to be called in an asyncio.loop, if you are using a web framework that is synchronous use the `call_sync` method. """ - validation_errors = validate(self.schema, document) - if validation_errors: + if validation_errors := self.validate(document): return ExecutionResult(data=None, errors=validation_errors) context = self.get_context(request) diff --git a/performance/test_performance.py b/performance/test_performance.py index f131da5..c43d599 100755 --- a/performance/test_performance.py +++ b/performance/test_performance.py @@ -11,6 +11,8 @@ import fastapi import pydantic +NUM_RUNS = 1000 + class Widget(pydantic.BaseModel): name: str @@ -76,6 +78,7 @@ def resolve_get_widgets(_, _info, use: str) -> typing.List[dict]: exe_schema = ariadne.make_executable_schema(schema, query) ariadne_app = ariadne.asgi.GraphQL(exe_schema) cannula_app = cannula.API(__name__, schema=[schema]) +document_ast = cannula.gql(document) @cannula_app.resolver("Query") @@ -91,7 +94,7 @@ async def get_ariadne_app(request: fastapi.Request) -> typing.Any: @api.get("/api/cannula") async def get_cannula_app(request: fastapi.Request) -> typing.Any: results = await cannula_app.call( - cannula.gql(document), request, variables={"use": "tighten"} + document_ast, request, variables={"use": "tighten"} ) return {"data": results.data, "errors": results.errors} @@ -100,9 +103,14 @@ async def test_performance(): client = httpx.AsyncClient(app=api, base_url="http://localhost") start = time.perf_counter() - for x in range(1000): + for x in range(NUM_RUNS): resp = await client.get("/api/fastapi?use=tighten") assert resp.status_code == 200 + assert resp.json() == [ + {"name": "screw driver", "quantity": 10, "use": "tighten"}, + {"name": "wrench", "quantity": 20, "use": "tighten"}, + ] + stop = time.perf_counter() fast_results = stop - start @@ -110,13 +118,17 @@ async def test_performance(): print(f"fastapi: {fast_results}") start = time.perf_counter() - for _x in range(1000): + for _x in range(NUM_RUNS): resp = await client.post( "/api/ariadne", json={"query": document, "variables": {"use": "tighten"}}, headers={"Content-Type": "application/json"}, ) assert resp.status_code == 200 + assert resp.json()["data"]["get_widgets"] == [ + {"name": "screw driver", "quantity": 10, "use": "tighten"}, + {"name": "wrench", "quantity": 20, "use": "tighten"}, + ] stop = time.perf_counter() ariadne_results = stop - start @@ -124,11 +136,15 @@ async def test_performance(): print(f"ariadne results: {ariadne_results}") start = time.perf_counter() - for _x in range(1000): + for _x in range(NUM_RUNS): resp = await client.get( "/api/cannula", ) assert resp.status_code == 200 + assert resp.json()["data"]["get_widgets"] == [ + {"name": "screw driver", "quantity": 10, "use": "tighten"}, + {"name": "wrench", "quantity": 20, "use": "tighten"}, + ] stop = time.perf_counter() cannula_results = stop - start From 358492be4a1c7f7b59e154a1c00a072102b32e72 Mon Sep 17 00:00:00 2001 From: Robert Myers Date: Thu, 28 Dec 2023 11:24:32 -0600 Subject: [PATCH 2/2] Adding better performance for invalid queries --- cannula/api.py | 24 +++++- cannula/contrib/__init__.py | 0 cannula/contrib/asgi.py | 9 +++ performance/Makefile | 2 +- performance/test_performance.py | 139 +++++++++++++++++++++++++++++--- 5 files changed, 162 insertions(+), 12 deletions(-) create mode 100644 cannula/contrib/__init__.py create mode 100644 cannula/contrib/asgi.py diff --git a/cannula/api.py b/cannula/api.py index 669c750..871762e 100644 --- a/cannula/api.py +++ b/cannula/api.py @@ -38,6 +38,11 @@ LOG = logging.getLogger(__name__) +class ParseResults(typing.NamedTuple): + document_ast: DocumentNode + errors: typing.List[GraphQLError] = [] + + class Resolver: """Resolver Registry @@ -131,7 +136,7 @@ def load_query(self, query_name: str) -> DocumentNode: path = os.path.join(self.query_directory, f"{query_name}.graphql") assert os.path.isfile(path), f"No query found for {query_name}" - with open(path) as query: + with open(path, "r") as query: return parse(query.read()) def resolver(self, type_name: str = "Query") -> typing.Any: @@ -249,11 +254,21 @@ def _merge_registry(self, registry: dict): @functools.lru_cache(maxsize=128) def validate(self, document: DocumentNode) -> typing.List[GraphQLError]: + """Validate the document against the schema and store results in lru_cache.""" return validate(self.schema, document) + @functools.lru_cache(maxsize=128) + def parse_document(self, document: str) -> ParseResults: + """Parse and store the document in lru_cache.""" + try: + document_ast = parse(document) + return ParseResults(document_ast, []) + except GraphQLError as err: + return ParseResults(DocumentNode(), [err]) + async def call( self, - document: DocumentNode, + document: typing.Union[DocumentNode, str], request: typing.Any = None, variables: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> ExecutionResult: @@ -262,6 +277,11 @@ async def call( This is meant to be called in an asyncio.loop, if you are using a web framework that is synchronous use the `call_sync` method. """ + if isinstance(document, str): + document, errors = self.parse_document(document) + if errors: + return ExecutionResult(data=None, errors=errors) + if validation_errors := self.validate(document): return ExecutionResult(data=None, errors=validation_errors) diff --git a/cannula/contrib/__init__.py b/cannula/contrib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cannula/contrib/asgi.py b/cannula/contrib/asgi.py new file mode 100644 index 0000000..c04782b --- /dev/null +++ b/cannula/contrib/asgi.py @@ -0,0 +1,9 @@ +import typing + +import pydantic + + +class GraphQLPayload(pydantic.BaseModel): + query: str + variables: typing.Optional[typing.Dict[str, typing.Any]] = None + operation: typing.Optional[str] = None diff --git a/performance/Makefile b/performance/Makefile index f4d7ba2..8a19def 100644 --- a/performance/Makefile +++ b/performance/Makefile @@ -25,7 +25,7 @@ setup: $(VIRTUAL_ENV) $(VIRTUAL_ENV)/.requirements-installed ## Setup local envi clean: ## Clean your local workspace rm -rf $(VIRTUAL_ENV) -test: ## run performance test +test: setup ## run performance test $(VIRTUAL_ENV)/bin/pytest --no-cov -s test_performance.py #% Available Commands: diff --git a/performance/test_performance.py b/performance/test_performance.py index c43d599..ab33010 100755 --- a/performance/test_performance.py +++ b/performance/test_performance.py @@ -7,6 +7,7 @@ import ariadne import ariadne.asgi import cannula +import cannula.contrib.asgi import httpx import fastapi import pydantic @@ -50,6 +51,18 @@ class Widget(pydantic.BaseModel): } """ +invalid_document = """ + query blah ( { } +""" + +invalid_query = """ + query widgets ($use: String) { + get_nonexistent(use: $use) { + foo + } + } +""" + query = ariadne.QueryType() api = fastapi.FastAPI() @@ -78,7 +91,6 @@ def resolve_get_widgets(_, _info, use: str) -> typing.List[dict]: exe_schema = ariadne.make_executable_schema(schema, query) ariadne_app = ariadne.asgi.GraphQL(exe_schema) cannula_app = cannula.API(__name__, schema=[schema]) -document_ast = cannula.gql(document) @cannula_app.resolver("Query") @@ -91,12 +103,15 @@ async def get_ariadne_app(request: fastapi.Request) -> typing.Any: return await ariadne_app.handle_request(request) -@api.get("/api/cannula") -async def get_cannula_app(request: fastapi.Request) -> typing.Any: +@api.post("/api/cannula") +async def get_cannula_app( + request: fastapi.Request, payload: cannula.contrib.asgi.GraphQLPayload +) -> typing.Any: results = await cannula_app.call( - document_ast, request, variables={"use": "tighten"} + payload.query, request, variables=payload.variables ) - return {"data": results.data, "errors": results.errors} + errors = [e.formatted for e in results.errors] if results.errors else None + return {"data": results.data, "errors": errors} async def test_performance(): @@ -105,7 +120,7 @@ async def test_performance(): start = time.perf_counter() for x in range(NUM_RUNS): resp = await client.get("/api/fastapi?use=tighten") - assert resp.status_code == 200 + assert resp.status_code == 200, resp.text assert resp.json() == [ {"name": "screw driver", "quantity": 10, "use": "tighten"}, {"name": "wrench", "quantity": 20, "use": "tighten"}, @@ -124,7 +139,7 @@ async def test_performance(): json={"query": document, "variables": {"use": "tighten"}}, headers={"Content-Type": "application/json"}, ) - assert resp.status_code == 200 + assert resp.status_code == 200, resp.text assert resp.json()["data"]["get_widgets"] == [ {"name": "screw driver", "quantity": 10, "use": "tighten"}, {"name": "wrench", "quantity": 20, "use": "tighten"}, @@ -137,10 +152,12 @@ async def test_performance(): start = time.perf_counter() for _x in range(NUM_RUNS): - resp = await client.get( + resp = await client.post( "/api/cannula", + json={"query": document, "variables": {"use": "tighten"}}, + headers={"Content-Type": "application/json"}, ) - assert resp.status_code == 200 + assert resp.status_code == 200, resp.text assert resp.json()["data"]["get_widgets"] == [ {"name": "screw driver", "quantity": 10, "use": "tighten"}, {"name": "wrench", "quantity": 20, "use": "tighten"}, @@ -150,3 +167,107 @@ async def test_performance(): cannula_results = stop - start print(f"cannula results: {cannula_results}") + + +async def test_performance_invalid_request(): + client = httpx.AsyncClient(app=api, base_url="http://localhost") + + start = time.perf_counter() + for x in range(NUM_RUNS): + resp = await client.get("/api/fastapi") + assert resp.status_code == 422, resp.text + + stop = time.perf_counter() + fast_results = stop - start + + print("\nperformance test results:") + print(f"fastapi: {fast_results}") + + start = time.perf_counter() + for _x in range(NUM_RUNS): + resp = await client.post( + "/api/ariadne", + json={"query": invalid_document, "variables": {"use": "tighten"}}, + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400, resp.text + errors = resp.json()["errors"] + assert len(errors) == 1 + assert errors[0]["message"] == "Syntax Error: Expected '$', found '{'." + + stop = time.perf_counter() + ariadne_results = stop - start + + print(f"ariadne results: {ariadne_results}") + + start = time.perf_counter() + for _x in range(NUM_RUNS): + resp = await client.post( + "/api/cannula", + json={"query": invalid_document, "variables": {"use": "tighten"}}, + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200, resp.text + errors = resp.json()["errors"] + assert len(errors) == 1 + assert errors[0]["message"] == "Syntax Error: Expected '$', found '{'." + + stop = time.perf_counter() + cannula_results = stop - start + + print(f"cannula results: {cannula_results}") + + +async def test_performance_invalid_query(): + client = httpx.AsyncClient(app=api, base_url="http://localhost") + + start = time.perf_counter() + for x in range(NUM_RUNS): + resp = await client.get("/api/fastapi") + assert resp.status_code == 422, resp.text + + stop = time.perf_counter() + fast_results = stop - start + + print("\nperformance test results:") + print(f"fastapi: {fast_results}") + + start = time.perf_counter() + for _x in range(NUM_RUNS): + resp = await client.post( + "/api/ariadne", + json={"query": invalid_query, "variables": {"use": "tighten"}}, + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400, resp.text + errors = resp.json()["errors"] + assert len(errors) == 1 + assert ( + errors[0]["message"] + == "Cannot query field 'get_nonexistent' on type 'Query'." + ) + + stop = time.perf_counter() + ariadne_results = stop - start + + print(f"ariadne results: {ariadne_results}") + + start = time.perf_counter() + for _x in range(NUM_RUNS): + resp = await client.post( + "/api/cannula", + json={"query": invalid_query, "variables": {"use": "tighten"}}, + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200, resp.text + errors = resp.json()["errors"] + assert len(errors) == 1 + assert ( + errors[0]["message"] + == "Cannot query field 'get_nonexistent' on type 'Query'." + ) + + stop = time.perf_counter() + cannula_results = stop - start + + print(f"cannula results: {cannula_results}")