Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for cached queries and validation #17

Merged
merged 2 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions cannula/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from graphql import (
DocumentNode,
GraphQLError,
GraphQLObjectType,
execute,
ExecutionResult,
Expand All @@ -37,6 +38,11 @@
LOG = logging.getLogger(__name__)


class ParseResults(typing.NamedTuple):
document_ast: DocumentNode
errors: typing.List[GraphQLError] = []


class Resolver:
"""Resolver Registry

Expand Down Expand Up @@ -130,7 +136,7 @@ def load_query(self, query_name: str) -> DocumentNode:
path = os.path.join(self.query_directory, f"{query_name}.graphql")
assert os.path.isfile(path), f"No query found for {query_name}"

with open(path) as query:
with open(path, "r") as query:
return parse(query.read())

def resolver(self, type_name: str = "Query") -> typing.Any:
Expand Down Expand Up @@ -246,9 +252,23 @@ def _merge_registry(self, registry: dict):
for type_name, value in registry.items():
self.registry[type_name].update(value)

@functools.lru_cache(maxsize=128)
def validate(self, document: DocumentNode) -> typing.List[GraphQLError]:
"""Validate the document against the schema and store results in lru_cache."""
return validate(self.schema, document)

@functools.lru_cache(maxsize=128)
def parse_document(self, document: str) -> ParseResults:
"""Parse and store the document in lru_cache."""
try:
document_ast = parse(document)
return ParseResults(document_ast, [])
except GraphQLError as err:
return ParseResults(DocumentNode(), [err])

async def call(
self,
document: DocumentNode,
document: typing.Union[DocumentNode, str],
request: typing.Any = None,
variables: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> ExecutionResult:
Expand All @@ -257,8 +277,12 @@ async def call(
This is meant to be called in an asyncio.loop, if you are using a
web framework that is synchronous use the `call_sync` method.
"""
validation_errors = validate(self.schema, document)
if validation_errors:
if isinstance(document, str):
document, errors = self.parse_document(document)
if errors:
return ExecutionResult(data=None, errors=errors)

if validation_errors := self.validate(document):
return ExecutionResult(data=None, errors=validation_errors)

context = self.get_context(request)
Expand Down
Empty file added cannula/contrib/__init__.py
Empty file.
9 changes: 9 additions & 0 deletions cannula/contrib/asgi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import typing

import pydantic


class GraphQLPayload(pydantic.BaseModel):
query: str
variables: typing.Optional[typing.Dict[str, typing.Any]] = None
operation: typing.Optional[str] = None
2 changes: 1 addition & 1 deletion performance/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ setup: $(VIRTUAL_ENV) $(VIRTUAL_ENV)/.requirements-installed ## Setup local envi
clean: ## Clean your local workspace
rm -rf $(VIRTUAL_ENV)

test: ## run performance test
test: setup ## run performance test
$(VIRTUAL_ENV)/bin/pytest --no-cov -s test_performance.py

#% Available Commands:
Expand Down
159 changes: 148 additions & 11 deletions performance/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
import ariadne
import ariadne.asgi
import cannula
import cannula.contrib.asgi
import httpx
import fastapi
import pydantic

NUM_RUNS = 1000


class Widget(pydantic.BaseModel):
name: str
Expand Down Expand Up @@ -48,6 +51,18 @@ class Widget(pydantic.BaseModel):
}
"""

invalid_document = """
query blah ( { }
"""

invalid_query = """
query widgets ($use: String) {
get_nonexistent(use: $use) {
foo
}
}
"""


query = ariadne.QueryType()
api = fastapi.FastAPI()
Expand Down Expand Up @@ -88,47 +103,169 @@ async def get_ariadne_app(request: fastapi.Request) -> typing.Any:
return await ariadne_app.handle_request(request)


@api.get("/api/cannula")
async def get_cannula_app(request: fastapi.Request) -> typing.Any:
@api.post("/api/cannula")
async def get_cannula_app(
request: fastapi.Request, payload: cannula.contrib.asgi.GraphQLPayload
) -> typing.Any:
results = await cannula_app.call(
cannula.gql(document), request, variables={"use": "tighten"}
payload.query, request, variables=payload.variables
)
return {"data": results.data, "errors": results.errors}
errors = [e.formatted for e in results.errors] if results.errors else None
return {"data": results.data, "errors": errors}


async def test_performance():
client = httpx.AsyncClient(app=api, base_url="http://localhost")

start = time.perf_counter()
for x in range(1000):
for x in range(NUM_RUNS):
resp = await client.get("/api/fastapi?use=tighten")
assert resp.status_code == 200
assert resp.status_code == 200, resp.text
assert resp.json() == [
{"name": "screw driver", "quantity": 10, "use": "tighten"},
{"name": "wrench", "quantity": 20, "use": "tighten"},
]

stop = time.perf_counter()
fast_results = stop - start

print("\nperformance test results:")
print(f"fastapi: {fast_results}")

start = time.perf_counter()
for _x in range(1000):
for _x in range(NUM_RUNS):
resp = await client.post(
"/api/ariadne",
json={"query": document, "variables": {"use": "tighten"}},
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 200
assert resp.status_code == 200, resp.text
assert resp.json()["data"]["get_widgets"] == [
{"name": "screw driver", "quantity": 10, "use": "tighten"},
{"name": "wrench", "quantity": 20, "use": "tighten"},
]

stop = time.perf_counter()
ariadne_results = stop - start

print(f"ariadne results: {ariadne_results}")

start = time.perf_counter()
for _x in range(NUM_RUNS):
resp = await client.post(
"/api/cannula",
json={"query": document, "variables": {"use": "tighten"}},
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 200, resp.text
assert resp.json()["data"]["get_widgets"] == [
{"name": "screw driver", "quantity": 10, "use": "tighten"},
{"name": "wrench", "quantity": 20, "use": "tighten"},
]

stop = time.perf_counter()
cannula_results = stop - start

print(f"cannula results: {cannula_results}")


async def test_performance_invalid_request():
client = httpx.AsyncClient(app=api, base_url="http://localhost")

start = time.perf_counter()
for x in range(NUM_RUNS):
resp = await client.get("/api/fastapi")
assert resp.status_code == 422, resp.text

stop = time.perf_counter()
fast_results = stop - start

print("\nperformance test results:")
print(f"fastapi: {fast_results}")

start = time.perf_counter()
for _x in range(NUM_RUNS):
resp = await client.post(
"/api/ariadne",
json={"query": invalid_document, "variables": {"use": "tighten"}},
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 400, resp.text
errors = resp.json()["errors"]
assert len(errors) == 1
assert errors[0]["message"] == "Syntax Error: Expected '$', found '{'."

stop = time.perf_counter()
ariadne_results = stop - start

print(f"ariadne results: {ariadne_results}")

start = time.perf_counter()
for _x in range(NUM_RUNS):
resp = await client.post(
"/api/cannula",
json={"query": invalid_document, "variables": {"use": "tighten"}},
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 200, resp.text
errors = resp.json()["errors"]
assert len(errors) == 1
assert errors[0]["message"] == "Syntax Error: Expected '$', found '{'."

stop = time.perf_counter()
cannula_results = stop - start

print(f"cannula results: {cannula_results}")


async def test_performance_invalid_query():
client = httpx.AsyncClient(app=api, base_url="http://localhost")

start = time.perf_counter()
for x in range(NUM_RUNS):
resp = await client.get("/api/fastapi")
assert resp.status_code == 422, resp.text

stop = time.perf_counter()
fast_results = stop - start

print("\nperformance test results:")
print(f"fastapi: {fast_results}")

start = time.perf_counter()
for _x in range(NUM_RUNS):
resp = await client.post(
"/api/ariadne",
json={"query": invalid_query, "variables": {"use": "tighten"}},
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 400, resp.text
errors = resp.json()["errors"]
assert len(errors) == 1
assert (
errors[0]["message"]
== "Cannot query field 'get_nonexistent' on type 'Query'."
)

stop = time.perf_counter()
ariadne_results = stop - start

print(f"ariadne results: {ariadne_results}")

start = time.perf_counter()
for _x in range(1000):
resp = await client.get(
for _x in range(NUM_RUNS):
resp = await client.post(
"/api/cannula",
json={"query": invalid_query, "variables": {"use": "tighten"}},
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 200, resp.text
errors = resp.json()["errors"]
assert len(errors) == 1
assert (
errors[0]["message"]
== "Cannot query field 'get_nonexistent' on type 'Query'."
)
assert resp.status_code == 200

stop = time.perf_counter()
cannula_results = stop - start
Expand Down