diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..9d36db038a --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,36 @@ +Release type: minor + +This release adds support for schema-extensions in subscriptions. + +Here's a small example of how to use them (they work the same way as query and +mutation extensions): + +```python +import asyncio +from typing import AsyncIterator + +import strawberry +from strawberry.extensions.base_extension import SchemaExtension + + +@strawberry.type +class Subscription: + @strawberry.subscription + async def notifications(self, info: strawberry.Info) -> AsyncIterator[str]: + for _ in range(3): + yield "Hello" + + +class MyExtension(SchemaExtension): + async def on_operation(self): + # This would run when the subscription starts + print("Subscription started") + yield + # The subscription has ended + print("Subscription ended") + + +schema = strawberry.Schema( + query=Query, subscription=Subscription, extensions=[MyExtension] +) +``` diff --git a/TWEET.md b/TWEET.md new file mode 100644 index 0000000000..d3c718c4cd --- /dev/null +++ b/TWEET.md @@ -0,0 +1,5 @@ +🆕 Release $version is out! Thanks to $contributor for the PR 👏 + +This release adds supports for schema extensions to subscriptions! + +Get it here 👉 $release_url diff --git a/docs/breaking-changes.md b/docs/breaking-changes.md index 772c9ef8c5..9f0ed20a80 100644 --- a/docs/breaking-changes.md +++ b/docs/breaking-changes.md @@ -4,6 +4,7 @@ title: List of breaking changes and deprecations # List of breaking changes and deprecations +- [Version 0.240.0 - 10 September 2024](./breaking-changes/0.240.0.md) - [Version 0.236.0 - 17 July 2024](./breaking-changes/0.236.0.md) - [Version 0.233.0 - 29 May 2024](./breaking-changes/0.233.0.md) - [Version 0.217.0 - 18 December 2023](./breaking-changes/0.217.0.md) diff --git a/docs/breaking-changes/0.240.0.md b/docs/breaking-changes/0.240.0.md new file mode 100644 index 0000000000..886ce28873 --- /dev/null +++ b/docs/breaking-changes/0.240.0.md @@ -0,0 +1,36 @@ +--- +title: 0.240.0 Breaking Changes +slug: breaking-changes/0.240.0 +--- + +# v0.240.0 updates `Schema.subscribe`'s signature + +In order to support schema extensions in subscriptions and errors that can be +raised before the execution of the subscription, we had to update the signature +of `Schema.subscribe`. + +Previously it was: + +```python +async def subscribe( + self, + query: str, + variable_values: Optional[Dict[str, Any]] = None, + context_value: Optional[Any] = None, + root_value: Optional[Any] = None, + operation_name: Optional[str] = None, +) -> Union[AsyncIterator[GraphQLExecutionResult], GraphQLExecutionResult]: +``` + +Now it is: + +```python +async def subscribe( + self, + query: Optional[str], + variable_values: Optional[Dict[str, Any]] = None, + context_value: Optional[Any] = None, + root_value: Optional[Any] = None, + operation_name: Optional[str] = None, +) -> Union[AsyncGenerator[ExecutionResult, None], PreExecutionError]: +``` diff --git a/noxfile.py b/noxfile.py index 66f8c200fa..b638508d77 100644 --- a/noxfile.py +++ b/noxfile.py @@ -11,7 +11,7 @@ PYTHON_VERSIONS = ["3.12", "3.11", "3.10", "3.9", "3.8"] GQL_CORE_VERSIONS = [ "3.2.3", - "3.3.0", + "3.3.0a6", ] COMMON_PYTEST_OPTIONS = [ @@ -44,12 +44,7 @@ def _install_gql_core(session: Session, version: str) -> None: - # hack for better workflow names # noqa: FIX004 - if version == "3.2.3": - session._session.install(f"graphql-core=={version}") # type: ignore - session._session.install( - "https://github.com/graphql-python/graphql-core/archive/876aef67b6f1e1f21b3b5db94c7ff03726cb6bdf.zip" - ) # type: ignore + session._session.install(f"graphql-core=={version}") gql_core_parametrize = nox.parametrize( diff --git a/strawberry/channels/testing.py b/strawberry/channels/testing.py index db3e78e097..890c7147d2 100644 --- a/strawberry/channels/testing.py +++ b/strawberry/channels/testing.py @@ -144,9 +144,9 @@ async def subscribe( message_type = response["type"] if message_type == NextMessage.type: payload = NextMessage(**response).payload - ret = ExecutionResult(payload["data"], None) + ret = ExecutionResult(payload.get("data"), None) if "errors" in payload: - ret.errors = self.process_errors(payload["errors"]) + ret.errors = self.process_errors(payload.get("errors") or []) ret.extensions = payload.get("extensions", None) yield ret elif message_type == ErrorMessage.type: diff --git a/strawberry/extensions/base_extension.py b/strawberry/extensions/base_extension.py index 92160315f0..ff8d75d7ce 100644 --- a/strawberry/extensions/base_extension.py +++ b/strawberry/extensions/base_extension.py @@ -21,9 +21,11 @@ class LifecycleStep(Enum): class SchemaExtension: execution_context: ExecutionContext - def __init__(self, *, execution_context: ExecutionContext) -> None: - self.execution_context = execution_context - + # to support extensions that still use the old signature + # we have an optional argument here for ease of initialization. + def __init__( + self, *, execution_context: ExecutionContext | None = None + ) -> None: ... def on_operation( # type: ignore self, ) -> AsyncIteratorOrIterator[None]: # pragma: no cover @@ -61,6 +63,11 @@ def resolve( def get_results(self) -> AwaitableOrValue[Dict[str, Any]]: return {} + @classmethod + def _implements_resolve(cls) -> bool: + """Whether the extension implements the resolve method.""" + return cls.resolve is not SchemaExtension.resolve + Hook = Callable[[SchemaExtension], AsyncIteratorOrIterator[None]] diff --git a/strawberry/extensions/runner.py b/strawberry/extensions/runner.py index 44d12ac009..1e249fc1e8 100644 --- a/strawberry/extensions/runner.py +++ b/strawberry/extensions/runner.py @@ -1,9 +1,7 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union - -from graphql import MiddlewareManager +from typing import TYPE_CHECKING, Any, Dict, List, Optional from strawberry.extensions.context import ( ExecutingContextManager, @@ -13,11 +11,11 @@ ) from strawberry.utils.await_maybe import await_maybe -from . import SchemaExtension - if TYPE_CHECKING: from strawberry.types import ExecutionContext + from . import SchemaExtension + class SchemaExtensionsRunner: extensions: List[SchemaExtension] @@ -25,27 +23,10 @@ class SchemaExtensionsRunner: def __init__( self, execution_context: ExecutionContext, - extensions: Optional[ - List[Union[Type[SchemaExtension], SchemaExtension]] - ] = None, + extensions: Optional[List[SchemaExtension]] = None, ) -> None: self.execution_context = execution_context - - if not extensions: - extensions = [] - - init_extensions: List[SchemaExtension] = [] - - for extension in extensions: - # If the extension has already been instantiated then set the - # `execution_context` attribute - if isinstance(extension, SchemaExtension): - extension.execution_context = execution_context - init_extensions.append(extension) - else: - init_extensions.append(extension(execution_context=execution_context)) - - self.extensions = init_extensions + self.extensions = extensions or [] def operation(self) -> OperationContextManager: return OperationContextManager(self.extensions) @@ -61,29 +42,19 @@ def executing(self) -> ExecutingContextManager: def get_extensions_results_sync(self) -> Dict[str, Any]: data: Dict[str, Any] = {} - for extension in self.extensions: if inspect.iscoroutinefunction(extension.get_results): msg = "Cannot use async extension hook during sync execution" raise RuntimeError(msg) - data.update(extension.get_results()) # type: ignore return data - async def get_extensions_results(self) -> Dict[str, Any]: + async def get_extensions_results(self, ctx: ExecutionContext) -> Dict[str, Any]: data: Dict[str, Any] = {} for extension in self.extensions: - results = await await_maybe(extension.get_results()) - data.update(results) + data.update(await await_maybe(extension.get_results())) + data.update(ctx.extensions_results) return data - - def as_middleware_manager(self, *additional_middlewares: Any) -> MiddlewareManager: - middlewares = tuple(self.extensions) + additional_middlewares - - return MiddlewareManager(*middlewares) - - -__all__ = ["SchemaExtensionsRunner"] diff --git a/strawberry/http/__init__.py b/strawberry/http/__init__.py index b5fb6f5066..722f82560f 100644 --- a/strawberry/http/__init__.py +++ b/strawberry/http/__init__.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional -from typing_extensions import TypedDict +from typing_extensions import Literal, TypedDict if TYPE_CHECKING: from strawberry.types import ExecutionResult @@ -33,6 +33,7 @@ class GraphQLRequestData: query: Optional[str] variables: Optional[Dict[str, Any]] operation_name: Optional[str] + protocol: Literal["http", "multipart-subscription"] = "http" def parse_query_params(params: Dict[str, str]) -> Dict[str, Any]: diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 86993e6f59..a9997bcc7f 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -14,6 +14,7 @@ Tuple, Union, ) +from typing_extensions import Literal from graphql import GraphQLError @@ -121,6 +122,15 @@ async def execute_operation( assert self.schema + if request_data.protocol == "multipart-subscription": + return await self.schema.subscribe( + request_data.query, # type: ignore + variable_values=request_data.variables, + context_value=context, + root_value=root_value, + operation_name=request_data.operation_name, + ) + return await self.schema.execute( request_data.query, root_value=root_value, @@ -312,14 +322,19 @@ async def parse_http_body( ) -> GraphQLRequestData: content_type, params = parse_content_type(request.content_type or "") + protocol: Literal["http", "multipart-subscription"] = "http" + if request.method == "GET": data = self.parse_query_params(request.query_params) + if self._is_multipart_subscriptions(content_type, params): + protocol = "multipart-subscription" elif "application/json" in content_type: data = self.parse_json(await request.get_body()) elif content_type == "multipart/form-data": data = await self.parse_multipart(request) elif self._is_multipart_subscriptions(content_type, params): data = await self.parse_multipart_subscriptions(request) + protocol = "multipart-subscription" else: raise HTTPException(400, "Unsupported content type") @@ -327,6 +342,7 @@ async def parse_http_body( query=data.get("query"), variables=data.get("variables"), operation_name=data.get("operationName"), + protocol=protocol, ) async def process_result( diff --git a/strawberry/schema/base.py b/strawberry/schema/base.py index 9e040d3856..f34fc98632 100644 --- a/strawberry/schema/base.py +++ b/strawberry/schema/base.py @@ -15,7 +15,6 @@ from strawberry.types import ( ExecutionContext, ExecutionResult, - SubscriptionExecutionResult, ) from strawberry.types.base import StrawberryObjectDefinition from strawberry.types.enum import EnumDefinition @@ -24,6 +23,7 @@ from strawberry.types.union import StrawberryUnion from .config import StrawberryConfig + from .subscribe import SubscriptionResult class BaseSchema(Protocol): @@ -43,7 +43,7 @@ async def execute( root_value: Optional[Any] = None, operation_name: Optional[str] = None, allowed_operation_types: Optional[Iterable[OperationType]] = None, - ) -> Union[ExecutionResult, SubscriptionExecutionResult]: + ) -> ExecutionResult: raise NotImplementedError @abstractmethod @@ -66,7 +66,7 @@ async def subscribe( context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, - ) -> Any: + ) -> SubscriptionResult: raise NotImplementedError @abstractmethod diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index 036f50c540..4090b8e3b6 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -4,40 +4,43 @@ from inspect import isawaitable from typing import ( TYPE_CHECKING, + Awaitable, Callable, Iterable, List, Optional, - Sequence, Tuple, Type, TypedDict, Union, + cast, ) -from graphql import GraphQLError, parse, subscribe +from graphql import ExecutionResult as GraphQLExecutionResult +from graphql import GraphQLError, parse from graphql import execute as original_execute from graphql.validation import validate from strawberry.exceptions import MissingQueryError -from strawberry.extensions.runner import SchemaExtensionsRunner from strawberry.schema.validation_rules.one_of import OneOfInputValidationRule from strawberry.types import ExecutionResult -from strawberry.types.graphql import OperationType +from strawberry.types.execution import PreExecutionError +from strawberry.utils.await_maybe import await_maybe from .exceptions import InvalidOperationTypeError if TYPE_CHECKING: - from typing_extensions import NotRequired, Unpack + from typing_extensions import NotRequired, TypeAlias, Unpack from graphql import ExecutionContext as GraphQLExecutionContext from graphql import GraphQLSchema + from graphql.execution.middleware import MiddlewareManager from graphql.language import DocumentNode from graphql.validation import ASTValidationRule - from strawberry.extensions import SchemaExtension + from strawberry.extensions.runner import SchemaExtensionsRunner from strawberry.types import ExecutionContext - from strawberry.types.execution import SubscriptionExecutionResult + from strawberry.types.graphql import OperationType # duplicated because of https://github.com/mkdocstrings/griffe-typingdoc/issues/7 @@ -45,6 +48,11 @@ class ParseOptions(TypedDict): max_tokens: NotRequired[int] +ProcessErrors: TypeAlias = ( + "Callable[[List[GraphQLError], Optional[ExecutionContext]], None]" +) + + def parse_document(query: str, **kwargs: Unpack[ParseOptions]) -> DocumentNode: return parse(query, **kwargs) @@ -77,112 +85,120 @@ def _run_validation(execution_context: ExecutionContext) -> None: ) +async def _parse_and_validate_async( + context: ExecutionContext, extensions_runner: SchemaExtensionsRunner +) -> Optional[PreExecutionError]: + if not context.query: + raise MissingQueryError() + + async with extensions_runner.parsing(): + try: + if not context.graphql_document: + context.graphql_document = parse_document(context.query) + + except GraphQLError as error: + context.errors = [error] + return PreExecutionError(data=None, errors=[error]) + + except Exception as error: + error = GraphQLError(str(error), original_error=error) + context.errors = [error] + return PreExecutionError(data=None, errors=[error]) + + if context.operation_type not in context.allowed_operations: + raise InvalidOperationTypeError(context.operation_type) + + async with extensions_runner.validation(): + _run_validation(context) + if context.errors: + return PreExecutionError( + data=None, + errors=context.errors, + ) + + return None + + +async def _handle_execution_result( + context: ExecutionContext, + result: Union[GraphQLExecutionResult, ExecutionResult], + extensions_runner: SchemaExtensionsRunner, + process_errors: ProcessErrors, +) -> ExecutionResult: + # Set errors on the context so that it's easier + # to access in extensions + if result.errors: + context.errors = result.errors + + # Run the `Schema.process_errors` function here before + # extensions have a chance to modify them (see the MaskErrors + # extension). That way we can log the original errors but + # only return a sanitised version to the client. + process_errors(result.errors, context) + if isinstance(result, GraphQLExecutionResult): + result = ExecutionResult(data=result.data, errors=result.errors) + result.extensions = await extensions_runner.get_extensions_results(context) + context.result = result # type: ignore # mypy failed to deduce correct type. + return result + + +def _coerce_error(error: Union[GraphQLError, Exception]) -> GraphQLError: + if isinstance(error, GraphQLError): + return error + return GraphQLError(str(error), original_error=error) + + async def execute( schema: GraphQLSchema, - *, - allowed_operation_types: Iterable[OperationType], - extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]], execution_context: ExecutionContext, + extensions_runner: SchemaExtensionsRunner, + process_errors: ProcessErrors, + middleware_manager: MiddlewareManager, execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, - process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], -) -> Union[ExecutionResult, SubscriptionExecutionResult]: - extensions_runner = SchemaExtensionsRunner( - execution_context=execution_context, - extensions=list(extensions), - ) - +) -> ExecutionResult | PreExecutionError: try: async with extensions_runner.operation(): # Note: In graphql-core the schema would be validated here but in # Strawberry we are validating it at initialisation time instead - if not execution_context.query: - raise MissingQueryError() - async with extensions_runner.parsing(): - try: - if not execution_context.graphql_document: - execution_context.graphql_document = parse_document( - execution_context.query, **execution_context.parse_options - ) - - except GraphQLError as exc: - execution_context.errors = [exc] - process_errors([exc], execution_context) - return ExecutionResult( - data=None, - errors=[exc], - extensions=await extensions_runner.get_extensions_results(), - ) - - if execution_context.operation_type not in allowed_operation_types: - raise InvalidOperationTypeError(execution_context.operation_type) - - async with extensions_runner.validation(): - _run_validation(execution_context) - if execution_context.errors: - process_errors(execution_context.errors, execution_context) - return ExecutionResult(data=None, errors=execution_context.errors) + if errors := await _parse_and_validate_async( + execution_context, extensions_runner + ): + return await _handle_execution_result( + execution_context, errors, extensions_runner, process_errors + ) + assert execution_context.graphql_document async with extensions_runner.executing(): if not execution_context.result: - if execution_context.operation_type == OperationType.SUBSCRIPTION: - # TODO: should we process errors here? - # TODO: make our own wrapper? - return await subscribe( # type: ignore - schema, - execution_context.graphql_document, - root_value=execution_context.root_value, - context_value=execution_context.context, - variable_values=execution_context.variables, - operation_name=execution_context.operation_name, - ) - else: - result = original_execute( + res = await await_maybe( + original_execute( schema, execution_context.graphql_document, root_value=execution_context.root_value, - middleware=extensions_runner.as_middleware_manager(), + middleware=middleware_manager, variable_values=execution_context.variables, operation_name=execution_context.operation_name, context_value=execution_context.context, execution_context_class=execution_context_class, ) + ) - if isawaitable(result): - result = await result - - execution_context.result = result - # Also set errors on the execution_context so that it's easier - # to access in extensions - if result.errors: - execution_context.errors = result.errors - - # Run the `Schema.process_errors` function here before - # extensions have a chance to modify them (see the MaskErrors - # extension). That way we can log the original errors but - # only return a sanitised version to the client. - process_errors(result.errors, execution_context) - + else: + res = execution_context.result except (MissingQueryError, InvalidOperationTypeError) as e: raise e except Exception as exc: - error = ( - exc - if isinstance(exc, GraphQLError) - else GraphQLError(str(exc), original_error=exc) - ) - execution_context.errors = [error] - process_errors([error], execution_context) - return ExecutionResult( - data=None, - errors=[error], - extensions=await extensions_runner.get_extensions_results(), + return await _handle_execution_result( + execution_context, + PreExecutionError(data=None, errors=[_coerce_error(exc)]), + extensions_runner, + process_errors, ) - return ExecutionResult( - data=execution_context.result.data, - errors=execution_context.result.errors, - extensions=await extensions_runner.get_extensions_results(), + # return results after all the operation completed. + return await _handle_execution_result( + execution_context, res, extensions_runner, process_errors ) @@ -190,16 +206,12 @@ def execute_sync( schema: GraphQLSchema, *, allowed_operation_types: Iterable[OperationType], - extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]], + extensions_runner: SchemaExtensionsRunner, execution_context: ExecutionContext, execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, - process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], + process_errors: ProcessErrors, + middleware_manager: MiddlewareManager, ) -> ExecutionResult: - extensions_runner = SchemaExtensionsRunner( - execution_context=execution_context, - extensions=list(extensions), - ) - try: with extensions_runner.operation(): # Note: In graphql-core the schema would be validated here but in @@ -214,12 +226,12 @@ def execute_sync( execution_context.query, **execution_context.parse_options ) - except GraphQLError as exc: - execution_context.errors = [exc] - process_errors([exc], execution_context) + except GraphQLError as error: + execution_context.errors = [error] + process_errors([error], execution_context) return ExecutionResult( data=None, - errors=[exc], + errors=[error], extensions=extensions_runner.get_extensions_results_sync(), ) @@ -230,7 +242,11 @@ def execute_sync( _run_validation(execution_context) if execution_context.errors: process_errors(execution_context.errors, execution_context) - return ExecutionResult(data=None, errors=execution_context.errors) + return ExecutionResult( + data=None, + errors=execution_context.errors, + extensions=extensions_runner.get_extensions_results_sync(), + ) with extensions_runner.executing(): if not execution_context.result: @@ -238,7 +254,7 @@ def execute_sync( schema, execution_context.graphql_document, root_value=execution_context.root_value, - middleware=extensions_runner.as_middleware_manager(), + middleware=middleware_manager, variable_values=execution_context.variables, operation_name=execution_context.operation_name, context_value=execution_context.context, @@ -246,13 +262,15 @@ def execute_sync( ) if isawaitable(result): + result = cast(Awaitable[GraphQLExecutionResult], result) # type: ignore[redundant-cast] ensure_future(result).cancel() raise RuntimeError( "GraphQL execution failed to complete synchronously." ) + result = cast(GraphQLExecutionResult, result) # type: ignore[redundant-cast] execution_context.result = result - # Also set errors on the execution_context so that it's easier + # Also set errors on the context so that it's easier # to access in extensions if result.errors: execution_context.errors = result.errors @@ -262,23 +280,17 @@ def execute_sync( # extension). That way we can log the original errors but # only return a sanitised version to the client. process_errors(result.errors, execution_context) - except (MissingQueryError, InvalidOperationTypeError) as e: raise e except Exception as exc: - error = ( - exc - if isinstance(exc, GraphQLError) - else GraphQLError(str(exc), original_error=exc) - ) - execution_context.errors = [error] - process_errors([error], execution_context) + errors = [_coerce_error(exc)] + execution_context.errors = errors + process_errors(errors, execution_context) return ExecutionResult( data=None, - errors=[error], + errors=errors, extensions=extensions_runner.get_extensions_results_sync(), ) - return ExecutionResult( data=execution_context.result.data, errors=execution_context.result.errors, diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index ec4bf0d64a..a8fce095b5 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -1,11 +1,10 @@ from __future__ import annotations import warnings -from functools import lru_cache +from functools import cached_property, lru_cache from typing import ( TYPE_CHECKING, Any, - AsyncIterator, Dict, Iterable, List, @@ -22,18 +21,19 @@ GraphQLNonNull, GraphQLSchema, get_introspection_query, - parse, validate_schema, ) -from graphql.execution import subscribe +from graphql.execution.middleware import MiddlewareManager from graphql.type.directives import specified_directives from strawberry import relay from strawberry.annotation import StrawberryAnnotation +from strawberry.extensions import SchemaExtension from strawberry.extensions.directives import ( DirectivesExtension, DirectivesExtensionSync, ) +from strawberry.extensions.runner import SchemaExtensionsRunner from strawberry.schema.schema_converter import GraphQLCoreConverter from strawberry.schema.types.scalar import DEFAULT_SCALAR_REGISTRY from strawberry.types import ExecutionContext @@ -41,19 +41,17 @@ from strawberry.types.graphql import OperationType from ..printer import print_schema -from ..utils.await_maybe import await_maybe from . import compat from .base import BaseSchema from .config import StrawberryConfig from .execute import execute, execute_sync +from .subscribe import SubscriptionResult, subscribe if TYPE_CHECKING: from graphql import ExecutionContext as GraphQLExecutionContext - from graphql import ExecutionResult as GraphQLExecutionResult from strawberry.directive import StrawberryDirective - from strawberry.extensions import SchemaExtension - from strawberry.types import ExecutionResult, SubscriptionExecutionResult + from strawberry.types import ExecutionResult from strawberry.types.base import StrawberryType from strawberry.types.enum import EnumDefinition from strawberry.types.field import StrawberryField @@ -85,7 +83,7 @@ def __init__( ] = None, schema_directives: Iterable[object] = (), ) -> None: - """Default Schema to be to be used in a Strawberry application. + """Default Schema to be used in a Strawberry application. A GraphQL Schema class used to define the structure and configuration of GraphQL queries, mutations, and subscriptions. @@ -125,6 +123,7 @@ class Query: self.subscription = subscription self.extensions = extensions + self._cached_middleware_manager: MiddlewareManager | None = None self.execution_context_class = execution_context_class self.config = config or StrawberryConfig() @@ -214,15 +213,63 @@ class Query: formatted_errors = "\n\n".join(f"❌ {error.message}" for error in errors) raise ValueError(f"Invalid Schema. Errors:\n\n{formatted_errors}") - def get_extensions( - self, sync: bool = False - ) -> List[Union[Type[SchemaExtension], SchemaExtension]]: - extensions = list(self.extensions) - + def get_extensions(self, sync: bool = False) -> List[SchemaExtension]: + extensions = [] if self.directives: - extensions.append(DirectivesExtensionSync if sync else DirectivesExtension) + extensions = [ + *self.extensions, + DirectivesExtensionSync if sync else DirectivesExtension, + ] + extensions.extend(self.extensions) + return [ + ext if isinstance(ext, SchemaExtension) else ext(execution_context=None) + for ext in extensions + ] + + @cached_property + def _sync_extensions(self) -> List[SchemaExtension]: + return self.get_extensions(sync=True) + + @cached_property + def _async_extensions(self) -> List[SchemaExtension]: + return self.get_extensions(sync=False) + + def create_extensions_runner( + self, execution_context: ExecutionContext, extensions: list[SchemaExtension] + ) -> SchemaExtensionsRunner: + return SchemaExtensionsRunner( + execution_context=execution_context, + extensions=extensions, + ) - return extensions + def _get_middleware_manager( + self, extensions: list[SchemaExtension] + ) -> MiddlewareManager: + # create a middleware manager with all the extensions that implement resolve + if not self._cached_middleware_manager: + self._cached_middleware_manager = MiddlewareManager( + *(ext for ext in extensions if ext._implements_resolve()) + ) + return self._cached_middleware_manager + + def _create_execution_context( + self, + query: Optional[str], + allowed_operation_types: Iterable[OperationType], + variable_values: Optional[Dict[str, Any]] = None, + context_value: Optional[Any] = None, + root_value: Optional[Any] = None, + operation_name: Optional[str] = None, + ) -> ExecutionContext: + return ExecutionContext( + query=query, + schema=self, + allowed_operations=allowed_operation_types, + context=context_value, + root_value=root_value, + variables=variable_values, + provided_operation_name=operation_name, + ) @lru_cache def get_type_by_name( @@ -284,31 +331,33 @@ async def execute( root_value: Optional[Any] = None, operation_name: Optional[str] = None, allowed_operation_types: Optional[Iterable[OperationType]] = None, - ) -> Union[ExecutionResult, SubscriptionExecutionResult]: + ) -> ExecutionResult: if allowed_operation_types is None: allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES - # Create execution context - execution_context = ExecutionContext( + execution_context = self._create_execution_context( query=query, - schema=self, - context=context_value, + allowed_operation_types=allowed_operation_types, + variable_values=variable_values, + context_value=context_value, root_value=root_value, - variables=variable_values, - provided_operation_name=operation_name, + operation_name=operation_name, ) - - result = await execute( + extensions = self.get_extensions() + # TODO (#3571): remove this when we implement execution context as parameter. + for extension in extensions: + extension.execution_context = execution_context + return await execute( self._schema, - extensions=self.get_extensions(), - execution_context_class=self.execution_context_class, execution_context=execution_context, - allowed_operation_types=allowed_operation_types, + extensions_runner=self.create_extensions_runner( + execution_context, extensions + ), process_errors=self._process_errors, + middleware_manager=self._get_middleware_manager(extensions), + execution_context_class=self.execution_context_class, ) - return result - def execute_sync( self, query: Optional[str], @@ -321,44 +370,59 @@ def execute_sync( if allowed_operation_types is None: allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES - execution_context = ExecutionContext( + execution_context = self._create_execution_context( query=query, - schema=self, - context=context_value, + allowed_operation_types=allowed_operation_types, + variable_values=variable_values, + context_value=context_value, root_value=root_value, - variables=variable_values, - provided_operation_name=operation_name, + operation_name=operation_name, ) - - result = execute_sync( + extensions = self._sync_extensions + # TODO (#3571): remove this when we implement execution context as parameter. + for extension in extensions: + extension.execution_context = execution_context + return execute_sync( self._schema, - extensions=self.get_extensions(sync=True), - execution_context_class=self.execution_context_class, execution_context=execution_context, + extensions_runner=self.create_extensions_runner( + execution_context, extensions + ), + execution_context_class=self.execution_context_class, allowed_operation_types=allowed_operation_types, process_errors=self._process_errors, + middleware_manager=self._get_middleware_manager(extensions), ) - return result - async def subscribe( self, - # TODO: make this optional when we support extensions - query: str, + query: Optional[str], variable_values: Optional[Dict[str, Any]] = None, context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, - ) -> Union[AsyncIterator[GraphQLExecutionResult], GraphQLExecutionResult]: - return await await_maybe( - subscribe( - self._schema, - parse(query), - root_value=root_value, - context_value=context_value, - variable_values=variable_values, - operation_name=operation_name, - ) + ) -> SubscriptionResult: + execution_context = self._create_execution_context( + query=query, + allowed_operation_types=(OperationType.SUBSCRIPTION,), + variable_values=variable_values, + context_value=context_value, + root_value=root_value, + operation_name=operation_name, + ) + extensions = self._async_extensions + # TODO (#3571): remove this when we implement execution context as parameter. + for extension in extensions: + extension.execution_context = execution_context + return await subscribe( + self._schema, + execution_context=execution_context, + extensions_runner=self.create_extensions_runner( + execution_context, extensions + ), + process_errors=self._process_errors, + middleware_manager=self._get_middleware_manager(extensions), + execution_context_class=self.execution_context_class, ) def _resolve_node_ids(self) -> None: diff --git a/strawberry/schema/subscribe.py b/strawberry/schema/subscribe.py new file mode 100644 index 0000000000..be8b783cb2 --- /dev/null +++ b/strawberry/schema/subscribe.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Optional, Type, Union + +from graphql import ( + ExecutionResult as OriginalExecutionResult, +) +from graphql.execution import ExecutionContext as GraphQLExecutionContext +from graphql.execution import subscribe as original_subscribe + +from strawberry.types import ExecutionResult +from strawberry.types.execution import ExecutionContext, PreExecutionError +from strawberry.utils import IS_GQL_32 +from strawberry.utils.await_maybe import await_maybe + +from .execute import ( + ProcessErrors, + _coerce_error, + _handle_execution_result, + _parse_and_validate_async, +) + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from graphql.execution.middleware import MiddlewareManager + from graphql.type.schema import GraphQLSchema + + from ..extensions.runner import SchemaExtensionsRunner + +SubscriptionResult: TypeAlias = Union[ + PreExecutionError, AsyncGenerator[ExecutionResult, None] +] + +OriginSubscriptionResult = Union[ + OriginalExecutionResult, + AsyncIterator[OriginalExecutionResult], +] + + +async def _subscribe( + schema: GraphQLSchema, + execution_context: ExecutionContext, + extensions_runner: SchemaExtensionsRunner, + process_errors: ProcessErrors, + middleware_manager: MiddlewareManager, + execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, +) -> AsyncGenerator[Union[PreExecutionError, ExecutionResult], None]: + async with extensions_runner.operation(): + if initial_error := await _parse_and_validate_async( + context=execution_context, + extensions_runner=extensions_runner, + ): + initial_error.extensions = await extensions_runner.get_extensions_results( + execution_context + ) + yield await _handle_execution_result( + execution_context, initial_error, extensions_runner, process_errors + ) + try: + async with extensions_runner.executing(): + assert execution_context.graphql_document is not None + gql_33_kwargs = { + "middleware": middleware_manager, + "execution_context_class": execution_context_class, + } + try: + # Might not be awaitable for pre-execution errors. + aiter_or_result: OriginSubscriptionResult = await await_maybe( + original_subscribe( + schema, + execution_context.graphql_document, + root_value=execution_context.root_value, + variable_values=execution_context.variables, + operation_name=execution_context.operation_name, + context_value=execution_context.context, + **{} if IS_GQL_32 else gql_33_kwargs, # type: ignore[arg-type] + ) + ) + # graphql-core 3.2 doesn't handle some of the pre-execution errors. + # see `test_subscription_immediate_error` + except Exception as exc: + aiter_or_result = OriginalExecutionResult( + data=None, errors=[_coerce_error(exc)] + ) + + # Handle pre-execution errors. + if isinstance(aiter_or_result, OriginalExecutionResult): + yield await _handle_execution_result( + execution_context, + PreExecutionError(data=None, errors=aiter_or_result.errors), + extensions_runner, + process_errors, + ) + else: + try: + async for result in aiter_or_result: + yield await _handle_execution_result( + execution_context, + result, + extensions_runner, + process_errors, + ) + # graphql-core doesn't handle exceptions raised while executing. + except Exception as exc: + yield await _handle_execution_result( + execution_context, + OriginalExecutionResult(data=None, errors=[_coerce_error(exc)]), + extensions_runner, + process_errors, + ) + # catch exceptions raised in `on_execute` hook. + except Exception as exc: + origin_result = OriginalExecutionResult( + data=None, errors=[_coerce_error(exc)] + ) + yield await _handle_execution_result( + execution_context, + origin_result, + extensions_runner, + process_errors, + ) + + +async def subscribe( + schema: GraphQLSchema, + execution_context: ExecutionContext, + extensions_runner: SchemaExtensionsRunner, + process_errors: ProcessErrors, + middleware_manager: MiddlewareManager, + execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, +) -> SubscriptionResult: + asyncgen = _subscribe( + schema, + execution_context, + extensions_runner, + process_errors, + middleware_manager, + execution_context_class, + ) + # GrapQL-core might return an initial error result instead of an async iterator. + # This happens when "there was an immediate error" i.e resolver is not an async iterator. + # To overcome this while maintaining the extension contexts we do this trick. + first = await asyncgen.__anext__() + if isinstance(first, PreExecutionError): + await asyncgen.aclose() + return first + + async def _wrapper() -> AsyncGenerator[ExecutionResult, None]: + yield first + async for result in asyncgen: + yield result + + return _wrapper() diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 1275ecf304..7d19db8e98 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -7,15 +7,13 @@ from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, - AsyncIterator, + Awaitable, Callable, Dict, List, Optional, ) -from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLError, GraphQLSyntaxError, parse from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( @@ -24,11 +22,14 @@ ConnectionInitMessage, ErrorMessage, NextMessage, + NextPayload, PingMessage, PongMessage, SubscribeMessage, SubscribeMessagePayload, ) +from strawberry.types import ExecutionResult +from strawberry.types.execution import PreExecutionError from strawberry.types.graphql import OperationType from strawberry.types.unset import UNSET from strawberry.utils.debug import pretty_print_graphql_operation @@ -38,10 +39,10 @@ from datetime import timedelta from strawberry.schema import BaseSchema + from strawberry.schema.subscribe import SubscriptionResult from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( GraphQLTransportMessage, ) - from strawberry.types import ExecutionResult class BaseGraphQLTransportWSHandler(ABC): @@ -243,10 +244,10 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: if isinstance(context, dict): context["connection_params"] = self.connection_params root_value = await self.get_root_value() - + result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult] # Get an AsyncGenerator yielding the results if operation_type == OperationType.SUBSCRIPTION: - result_source = await self.schema.subscribe( + result_source = self.schema.subscribe( query=message.payload.query, variable_values=message.payload.variables, operation_name=message.payload.operationName, @@ -254,29 +255,16 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: root_value=root_value, ) else: - # create AsyncGenerator returning a single result - async def get_result_source() -> AsyncIterator[ExecutionResult]: - yield await self.schema.execute( # type: ignore - query=message.payload.query, - variable_values=message.payload.variables, - context_value=context, - root_value=root_value, - operation_name=message.payload.operationName, - ) - - result_source = get_result_source() + result_source = self.schema.execute( + query=message.payload.query, + variable_values=message.payload.variables, + context_value=context, + root_value=root_value, + operation_name=message.payload.operationName, + ) operation = Operation(self, message.id, operation_type) - # Handle initial validation errors - if isinstance(result_source, GraphQLExecutionResult): - assert operation_type == OperationType.SUBSCRIPTION - assert result_source.errors - payload = [err.formatted for err in result_source.errors] - await self.send_message(ErrorMessage(id=message.id, payload=payload)) - self.schema.process_errors(result_source.errors) - return - # Create task to handle this subscription, reserve the operation ID operation.task = asyncio.create_task( self.operation_task(result_source, operation) @@ -284,65 +272,37 @@ async def get_result_source() -> AsyncIterator[ExecutionResult]: self.operations[message.id] = operation async def operation_task( - self, result_source: AsyncGenerator, operation: Operation + self, + result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult], + operation: Operation, ) -> None: """The operation task's top level method. Cleans-up and de-registers the operation once it is done.""" # TODO: Handle errors in this method using self.handle_task_exception() try: - await self.handle_async_results(result_source, operation) - except BaseException: # pragma: no cover - # cleanup in case of something really unexpected - # wait for generator to be closed to ensure that any existing - # 'finally' statement is called - with suppress(RuntimeError): - await result_source.aclose() - if operation.id in self.operations: - del self.operations[operation.id] + first_res_or_agen = await result_source + # that's an immediate error we should end the operation + # without a COMPLETE message + if isinstance(first_res_or_agen, PreExecutionError): + assert first_res_or_agen.errors + await operation.send_initial_errors(first_res_or_agen.errors) + # that's a mutation / query result + elif isinstance(first_res_or_agen, ExecutionResult): + await operation.send_next(first_res_or_agen) + await operation.send_message(CompleteMessage(id=operation.id)) + else: + async for result in first_res_or_agen: + await operation.send_next(result) + await operation.send_message(CompleteMessage(id=operation.id)) + + except BaseException as e: # pragma: no cover + self.operations.pop(operation.id, None) raise - else: - await operation.send_message(CompleteMessage(id=operation.id)) finally: # add this task to a list to be reaped later task = asyncio.current_task() assert task is not None self.completed_tasks.append(task) - async def handle_async_results( - self, - result_source: AsyncGenerator, - operation: Operation, - ) -> None: - try: - async for result in result_source: - if ( - result.errors - and operation.operation_type != OperationType.SUBSCRIPTION - ): - error_payload = [err.formatted for err in result.errors] - error_message = ErrorMessage(id=operation.id, payload=error_payload) - await operation.send_message(error_message) - # don't need to call schema.process_errors() here because - # it was already done by schema.execute() - return - else: - next_payload = {"data": result.data} - if result.errors: - self.schema.process_errors(result.errors) - next_payload["errors"] = [ - err.formatted for err in result.errors - ] - next_message = NextMessage(id=operation.id, payload=next_payload) - await operation.send_message(next_message) - except Exception as error: - # GraphQLErrors are handled by graphql-core and included in the - # ExecutionResult - error = GraphQLError(str(error), original_error=error) - error_payload = [error.formatted] - error_message = ErrorMessage(id=operation.id, payload=error_payload) - await operation.send_message(error_message) - self.schema.process_errors([error]) - return - def forget_id(self, id: str) -> None: # de-register the operation id making it immediately available # for re-use @@ -401,5 +361,21 @@ async def send_message(self, message: GraphQLTransportMessage) -> None: self.handler.forget_id(self.id) await self.handler.send_message(message) + async def send_initial_errors(self, errors: list[GraphQLError]) -> None: + # Initial errors see https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#error + # "This can occur before execution starts, + # usually due to validation errors, or during the execution of the request" + await self.send_message( + ErrorMessage(id=self.id, payload=[err.formatted for err in errors]) + ) + + async def send_next(self, execution_result: ExecutionResult) -> None: + next_payload: NextPayload = {"data": execution_result.data} + if execution_result.errors: + next_payload["errors"] = [err.formatted for err in execution_result.errors] + if execution_result.extensions: + next_payload["extensions"] = execution_result.extensions + await self.send_message(NextMessage(id=self.id, payload=next_payload)) + __all__ = ["BaseGraphQLTransportWSHandler", "Operation"] diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/types.py b/strawberry/subscriptions/protocols/graphql_transport_ws/types.py index 76260fabee..300f9204a7 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/types.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/types.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import asdict, dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict from strawberry.types.unset import UNSET @@ -68,12 +68,20 @@ class SubscribeMessage(GraphQLTransportMessage): type: str = "subscribe" +class NextPayload(TypedDict, total=False): + data: Any + + # Optional list of formatted graphql.GraphQLError objects + errors: Optional[List[GraphQLFormattedError]] + extensions: Optional[Dict[str, Any]] + + @dataclass class NextMessage(GraphQLTransportMessage): """Direction: Server -> Client.""" id: str - payload: Dict[str, Any] # TODO: shape like FormattedExecutionResult + payload: NextPayload type: str = "next" def as_dict(self) -> dict: diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 382167304f..0451e0934b 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -3,10 +3,15 @@ import asyncio from abc import ABC, abstractmethod from contextlib import suppress -from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, cast - -from graphql import ExecutionResult as GraphQLExecutionResult -from graphql import GraphQLError +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Awaitable, + Dict, + Optional, + cast, +) from strawberry.subscriptions.protocols.graphql_ws import ( GQL_COMPLETE, @@ -20,12 +25,15 @@ GQL_START, GQL_STOP, ) +from strawberry.types.execution import ExecutionResult, PreExecutionError from strawberry.utils.debug import pretty_print_graphql_operation if TYPE_CHECKING: from strawberry.schema import BaseSchema + from strawberry.schema.subscribe import SubscriptionResult from strawberry.subscriptions.protocols.graphql_ws.types import ( ConnectionInitPayload, + DataPayload, OperationMessage, OperationMessagePayload, StartPayload, @@ -123,28 +131,14 @@ async def handle_start(self, message: OperationMessage) -> None: if self.debug: pretty_print_graphql_operation(operation_name, query, variables) - try: - result_source = await self.schema.subscribe( - query=query, - variable_values=variables, - operation_name=operation_name, - context_value=context, - root_value=root_value, - ) - except GraphQLError as error: - error_payload = error.formatted - await self.send_message(GQL_ERROR, operation_id, error_payload) - self.schema.process_errors([error]) - return - - if isinstance(result_source, GraphQLExecutionResult): - assert result_source.errors - error_payload = result_source.errors[0].formatted - await self.send_message(GQL_ERROR, operation_id, error_payload) - self.schema.process_errors(result_source.errors) - return + result_source = self.schema.subscribe( + query=query, + variable_values=variables, + operation_name=operation_name, + context_value=context, + root_value=root_value, + ) - self.subscriptions[operation_id] = result_source result_handler = self.handle_async_results(result_source, operation_id) self.tasks[operation_id] = asyncio.create_task(result_handler) @@ -160,39 +154,28 @@ async def handle_keep_alive(self) -> None: async def handle_async_results( self, - result_source: AsyncGenerator, + result_source: Awaitable[SubscriptionResult], operation_id: str, ) -> None: try: - async for result in result_source: - payload = {"data": result.data} - if result.errors: - payload["errors"] = [err.formatted for err in result.errors] - await self.send_message(GQL_DATA, operation_id, payload) - # log errors after send_message to prevent potential - # slowdown of sending result - if result.errors: - self.schema.process_errors(result.errors) + agen_or_err = await result_source + if isinstance(agen_or_err, PreExecutionError): + assert agen_or_err.errors + error_payload = agen_or_err.errors[0].formatted + await self.send_message(GQL_ERROR, operation_id, error_payload) + else: + self.subscriptions[operation_id] = agen_or_err + async for result in agen_or_err: + await self.send_data(result, operation_id) + await self.send_message(GQL_COMPLETE, operation_id, None) except asyncio.CancelledError: - # CancelledErrors are expected during task cleanup. - pass - except Exception as error: - # GraphQLErrors are handled by graphql-core and included in the - # ExecutionResult - error = GraphQLError(str(error), original_error=error) - await self.send_message( - GQL_DATA, - operation_id, - {"data": None, "errors": [error.formatted]}, - ) - self.schema.process_errors([error]) - - await self.send_message(GQL_COMPLETE, operation_id, None) + await self.send_message(GQL_COMPLETE, operation_id, None) async def cleanup_operation(self, operation_id: str) -> None: - with suppress(RuntimeError): - await self.subscriptions[operation_id].aclose() - del self.subscriptions[operation_id] + if operation_id in self.subscriptions: + with suppress(RuntimeError): + await self.subscriptions[operation_id].aclose() + del self.subscriptions[operation_id] self.tasks[operation_id].cancel() with suppress(BaseException): @@ -210,5 +193,15 @@ async def send_message( data["payload"] = payload await self.send_json(data) + async def send_data( + self, execution_result: ExecutionResult, operation_id: str + ) -> None: + payload: DataPayload = {"data": execution_result.data} + if execution_result.errors: + payload["errors"] = [err.formatted for err in execution_result.errors] + if execution_result.extensions: + payload["extensions"] = execution_result.extensions + await self.send_message(GQL_DATA, operation_id, payload) + __all__ = ["BaseGraphQLWSHandler"] diff --git a/strawberry/subscriptions/protocols/graphql_ws/types.py b/strawberry/subscriptions/protocols/graphql_ws/types.py index 4b7d35d278..5ada0b100a 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/types.py +++ b/strawberry/subscriptions/protocols/graphql_ws/types.py @@ -20,6 +20,7 @@ class DataPayload(TypedDict, total=False): # Optional list of formatted graphql.GraphQLError objects errors: Optional[List[GraphQLFormattedError]] + extensions: Optional[Dict[str, Any]] ErrorPayload = GraphQLFormattedError diff --git a/strawberry/types/execution.py b/strawberry/types/execution.py index e1f88dbf21..94f3982a6f 100644 --- a/strawberry/types/execution.py +++ b/strawberry/types/execution.py @@ -5,6 +5,7 @@ TYPE_CHECKING, Any, Dict, + Iterable, List, Optional, Tuple, @@ -34,6 +35,7 @@ class ExecutionContext: query: Optional[str] schema: Schema + allowed_operations: Iterable[OperationType] context: Any = None variables: Optional[Dict[str, Any]] = None parse_options: ParseOptions = dataclasses.field( @@ -52,6 +54,7 @@ class ExecutionContext: graphql_document: Optional[DocumentNode] = None errors: Optional[List[GraphQLError]] = None result: Optional[GraphQLExecutionResult] = None + extensions_results: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self, provided_operation_name: str | None) -> None: self._provided_operation_name = provided_operation_name @@ -93,6 +96,18 @@ class ExecutionResult: extensions: Optional[Dict[str, Any]] = None +@dataclasses.dataclass +class PreExecutionError(ExecutionResult): + """Differentiate between a normal execution result and an immediate error. + + Immediate errors are errors that occur before the execution phase i.e validation errors, + or any other error that occur before we interact with resolvers. + + These errors are required by `graphql-ws-transport` protocol in order to close the operation + right away once the error is encountered. + """ + + class ParseOptions(TypedDict): max_tokens: NotRequired[int] diff --git a/strawberry/utils/__init__.py b/strawberry/utils/__init__.py index e69de29bb2..fa67e9654e 100644 --- a/strawberry/utils/__init__.py +++ b/strawberry/utils/__init__.py @@ -0,0 +1,5 @@ +import graphql +from packaging.version import Version + +IS_GQL_33 = Version(graphql.__version__) >= Version("3.3.0a") +IS_GQL_32 = not IS_GQL_33 diff --git a/tests/channels/test_layers.py b/tests/channels/test_layers.py index b868a0826b..40a3eed7b6 100644 --- a/tests/channels/test_layers.py +++ b/tests/channels/test_layers.py @@ -94,7 +94,11 @@ async def test_channel_listen(ws: WebsocketCommunicator): assert ( response == NextMessage( - id="sub1", payload={"data": {"listener": "Hello there!"}} + id="sub1", + payload={ + "data": {"listener": "Hello there!"}, + "extensions": {"example": "example"}, + }, ).as_dict() ) @@ -140,7 +144,11 @@ async def test_channel_listen_with_confirmation(ws: WebsocketCommunicator): assert ( response == NextMessage( - id="sub1", payload={"data": {"listenerWithConfirmation": "Hello there!"}} + id="sub1", + payload={ + "data": {"listenerWithConfirmation": "Hello there!"}, + "extensions": {"example": "example"}, + }, ).as_dict() ) @@ -318,7 +326,11 @@ async def test_channel_listen_group(ws: WebsocketCommunicator): assert ( response == NextMessage( - id="sub1", payload={"data": {"listener": "Hello there!"}} + id="sub1", + payload={ + "data": {"listener": "Hello there!"}, + "extensions": {"example": "example"}, + }, ).as_dict() ) @@ -334,7 +346,11 @@ async def test_channel_listen_group(ws: WebsocketCommunicator): assert ( response == NextMessage( - id="sub1", payload={"data": {"listener": "Hello there!"}} + id="sub1", + payload={ + "data": {"listener": "Hello there!"}, + "extensions": {"example": "example"}, + }, ).as_dict() ) @@ -380,7 +396,11 @@ async def test_channel_listen_group_cm(ws: WebsocketCommunicator): assert ( response == NextMessage( - id="sub1", payload={"data": {"listenerWithConfirmation": "Hello there!"}} + id="sub1", + payload={ + "data": {"listenerWithConfirmation": "Hello there!"}, + "extensions": {"example": "example"}, + }, ).as_dict() ) @@ -396,7 +416,11 @@ async def test_channel_listen_group_cm(ws: WebsocketCommunicator): assert ( response == NextMessage( - id="sub1", payload={"data": {"listenerWithConfirmation": "Hello there!"}} + id="sub1", + payload={ + "data": {"listenerWithConfirmation": "Hello there!"}, + "extensions": {"example": "example"}, + }, ).as_dict() ) diff --git a/tests/conftest.py b/tests/conftest.py index ab49ccbec5..6e1187b042 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,9 +2,10 @@ import sys from typing import Any, List, Tuple -import graphql import pytest +from strawberry.utils import IS_GQL_32 + def pytest_emoji_xfailed(config: pytest.Config) -> Tuple[str, str]: return "🤷‍♂️ ", "XFAIL 🤷‍♂️ " @@ -51,9 +52,6 @@ def pytest_ignore_collect( return True -IS_GQL_32 = "3.3" not in graphql.__version__ - - def skip_if_gql_32(reason: str) -> pytest.MarkDecorator: return pytest.mark.skipif( IS_GQL_32, diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index 52e08c6f2d..9f6f5f18a7 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -70,6 +70,13 @@ async def test_multipart_subscription( data = [d async for d in response.streaming_json()] - assert data == [{"payload": {"data": {"echo": "Hello world"}}}] + assert data == [ + { + "payload": { + "data": {"echo": "Hello world"}, + "extensions": {"example": "example"}, + } + } + ] assert response.status_code == 200 diff --git a/tests/http/test_query.py b/tests/http/test_query.py index 85a9f46889..68b7f0e87f 100644 --- a/tests/http/test_query.py +++ b/tests/http/test_query.py @@ -54,8 +54,9 @@ async def test_calls_handle_errors( { "message": "Cannot query field 'hey' on type 'Query'.", "locations": [{"line": 1, "column": 3}], - } + }, ], + "extensions": {"example": "example"}, } call_args = async_mock.call_args[0] if async_mock.called else sync_mock.call_args[0] diff --git a/tests/schema/extensions/schema_extensions/__init__.py b/tests/schema/extensions/schema_extensions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/schema/extensions/schema_extensions/conftest.py b/tests/schema/extensions/schema_extensions/conftest.py new file mode 100644 index 0000000000..15023cea80 --- /dev/null +++ b/tests/schema/extensions/schema_extensions/conftest.py @@ -0,0 +1,123 @@ +import contextlib +import dataclasses +import enum +from typing import Any, AsyncGenerator, List, Type + +import pytest + +import strawberry +from strawberry.extensions import SchemaExtension + + +@dataclasses.dataclass +class SchemaHelper: + query_type: type + subscription_type: type + query: str + subscription: str + + +class ExampleExtension(SchemaExtension): + def __init_subclass__(cls, **kwargs: Any): + super().__init_subclass__(**kwargs) + cls.called_hooks = [] + + expected = [ + "on_operation Entered", + "on_parse Entered", + "on_parse Exited", + "on_validate Entered", + "on_validate Exited", + "on_execute Entered", + "resolve", + "resolve", + "on_execute Exited", + "on_operation Exited", + "get_results", + ] + called_hooks: List[str] + + @classmethod + def assert_expected(cls) -> None: + assert cls.called_hooks == cls.expected + + +@pytest.fixture() +def default_query_types_and_query() -> SchemaHelper: + @strawberry.type + class Person: + name: str = "Jess" + + @strawberry.type + class Query: + @strawberry.field + def person(self) -> Person: + return Person() + + @strawberry.type + class Subscription: + @strawberry.subscription + async def count(self) -> AsyncGenerator[int, None]: + for i in range(5): + yield i + + subscription = "subscription TestSubscribe { count }" + query = "query TestQuery { person { name } }" + return SchemaHelper( + query_type=Query, + query=query, + subscription_type=Subscription, + subscription=subscription, + ) + + +class ExecType(enum.Enum): + SYNC = enum.auto() + ASYNC = enum.auto() + + def is_async(self) -> bool: + return self == ExecType.ASYNC + + +@pytest.fixture(params=[ExecType.ASYNC, ExecType.SYNC]) +def exec_type(request: pytest.FixtureRequest) -> ExecType: + return request.param + + +@contextlib.contextmanager +def hook_wrap(list_: List[str], hook_name: str): + list_.append(f"{hook_name} Entered") + try: + yield + finally: + list_.append(f"{hook_name} Exited") + + +@pytest.fixture() +def async_extension() -> Type[ExampleExtension]: + class MyExtension(ExampleExtension): + async def on_operation(self): + with hook_wrap(self.called_hooks, SchemaExtension.on_operation.__name__): + yield + + async def on_validate(self): + with hook_wrap(self.called_hooks, SchemaExtension.on_validate.__name__): + yield + + async def on_parse(self): + with hook_wrap(self.called_hooks, SchemaExtension.on_parse.__name__): + yield + + async def on_execute(self): + with hook_wrap(self.called_hooks, SchemaExtension.on_execute.__name__): + yield + + async def get_results(self): + self.called_hooks.append("get_results") + return {"example": "example"} + + async def resolve(self, _next, root, info, *args: str, **kwargs: Any): + self.called_hooks.append("resolve") + return _next(root, info, *args, **kwargs) + + return MyExtension diff --git a/tests/schema/extensions/test_extensions.py b/tests/schema/extensions/schema_extensions/test_extensions.py similarity index 85% rename from tests/schema/extensions/test_extensions.py rename to tests/schema/extensions/schema_extensions/test_extensions.py index fba3d400b4..648fa016fc 100644 --- a/tests/schema/extensions/test_extensions.py +++ b/tests/schema/extensions/schema_extensions/test_extensions.py @@ -1,8 +1,7 @@ import contextlib -import dataclasses import json import warnings -from typing import Any, List, Optional, Set, Type +from typing import Any, List, Optional, Type from unittest.mock import patch import pytest @@ -14,6 +13,8 @@ from strawberry.exceptions import StrawberryGraphQLError from strawberry.extensions import SchemaExtension +from .conftest import ExampleExtension, ExecType, SchemaHelper, hook_wrap + def test_base_extension(): @strawberry.type @@ -184,41 +185,6 @@ def hi(self) -> str: assert root_value == "ROOT" -@dataclasses.dataclass -class DefaultSchemaQuery: - query_type: type - query: str - - -class ExampleExtension(SchemaExtension): - def __init_subclass__(cls, **kwargs: Any): - super().__init_subclass__(**kwargs) - cls.called_hooks = set() - - expected = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - called_hooks: Set[int] - - @classmethod - def perform_test(cls) -> None: - assert cls.called_hooks == cls.expected - - -@pytest.fixture() -def default_query_types_and_query() -> DefaultSchemaQuery: - @strawberry.type - class Person: - name: str = "Jess" - - @strawberry.type - class Query: - @strawberry.field - def person(self) -> Person: - return Person() - - query = "query TestQuery { person { name } }" - return DefaultSchemaQuery(query_type=Query, query=query) - - def test_can_initialize_extension(default_query_types_and_query): class CustomizableExtension(SchemaExtension): def __init__(self, arg: int): @@ -239,76 +205,40 @@ def on_operation(self): assert res.data == {"override": 20} -@pytest.fixture() -def async_extension() -> Type[ExampleExtension]: - class MyExtension(ExampleExtension): - async def on_operation(self): - self.called_hooks.add(1) - yield - self.called_hooks.add(2) - - async def on_validate(self): - self.called_hooks.add(3) - yield - self.called_hooks.add(4) - - async def on_parse(self): - self.called_hooks.add(5) - yield - self.called_hooks.add(6) - - async def on_execute(self): - self.called_hooks.add(7) - yield - self.called_hooks.add(8) - - async def get_results(self): - self.called_hooks.add(9) - return {"example": "example"} - - async def resolve(self, _next, root, info, *args: str, **kwargs: Any): - self.called_hooks.add(10) - return _next(root, info, *args, **kwargs) - - return MyExtension - - @pytest.fixture() def sync_extension() -> Type[ExampleExtension]: class MyExtension(ExampleExtension): def on_operation(self): - self.called_hooks.add(1) - yield - self.called_hooks.add(2) + with hook_wrap(self.called_hooks, SchemaExtension.on_operation.__name__): + yield def on_validate(self): - self.called_hooks.add(3) - yield - self.called_hooks.add(4) + with hook_wrap(self.called_hooks, SchemaExtension.on_validate.__name__): + yield def on_parse(self): - self.called_hooks.add(5) - yield - self.called_hooks.add(6) + with hook_wrap(self.called_hooks, SchemaExtension.on_parse.__name__): + yield def on_execute(self): - self.called_hooks.add(7) - yield - self.called_hooks.add(8) + with hook_wrap(self.called_hooks, SchemaExtension.on_execute.__name__): + yield def get_results(self): - self.called_hooks.add(9) + self.called_hooks.append("get_results") return {"example": "example"} def resolve(self, _next, root, info, *args: str, **kwargs: Any): - self.called_hooks.add(10) + self.called_hooks.append("resolve") return _next(root, info, *args, **kwargs) return MyExtension @pytest.mark.asyncio -async def test_async_extension_hooks(default_query_types_and_query, async_extension): +async def test_async_extension_hooks( + default_query_types_and_query: SchemaHelper, async_extension: Type[ExampleExtension] +): schema = strawberry.Schema( query=default_query_types_and_query.query_type, extensions=[async_extension] ) @@ -316,7 +246,7 @@ async def test_async_extension_hooks(default_query_types_and_query, async_extens result = await schema.execute(default_query_types_and_query.query) assert result.errors is None - async_extension.perform_test() + async_extension.assert_expected() @pytest.mark.asyncio @@ -325,14 +255,12 @@ async def test_mixed_sync_and_async_extension_hooks( ): class MyExtension(sync_extension): async def on_operation(self): - self.called_hooks.add(1) - yield - self.called_hooks.add(2) + with hook_wrap(self.called_hooks, SchemaExtension.on_operation.__name__): + yield async def on_parse(self): - self.called_hooks.add(5) - yield - self.called_hooks.add(6) + with hook_wrap(self.called_hooks, SchemaExtension.on_parse.__name__): + yield @strawberry.type class Person: @@ -349,7 +277,7 @@ def person(self) -> Person: ) result = await schema.execute(default_query_types_and_query.query) assert result.errors is None - MyExtension.perform_test() + MyExtension.assert_expected() async def test_execution_order(default_query_types_and_query): @@ -432,18 +360,21 @@ async def test_sync_extension_hooks(default_query_types_and_query, sync_extensio result = schema.execute_sync(default_query_types_and_query.query) assert result.errors is None - sync_extension.perform_test() + sync_extension.assert_expected() async def test_extension_no_yield(default_query_types_and_query): class SyncExt(ExampleExtension): - expected = {1, 2} + expected = [ + f"{SchemaExtension.on_operation.__name__} Entered", + f"{SchemaExtension.on_parse.__name__} Entered", + ] def on_operation(self): - self.called_hooks.add(1) + self.called_hooks.append(self.__class__.expected[0]) async def on_parse(self): - self.called_hooks.add(2) + self.called_hooks.append(self.__class__.expected[1]) schema = strawberry.Schema( query=default_query_types_and_query.query_type, extensions=[SyncExt] @@ -452,7 +383,7 @@ async def on_parse(self): result = await schema.execute(default_query_types_and_query.query) assert result.errors is None - SyncExt.perform_test() + SyncExt.assert_expected() def test_raise_if_defined_both_legacy_and_new_style(default_query_types_and_query): @@ -465,7 +396,6 @@ def on_executing_start(self): ... schema = strawberry.Schema( query=default_query_types_and_query.query_type, extensions=[WrongUsageExtension] ) - result = schema.execute_sync(default_query_types_and_query.query) assert len(result.errors) == 1 assert isinstance(result.errors[0].original_error, ValueError) @@ -481,28 +411,40 @@ async def test_legacy_extension_supported(): class CompatExtension(ExampleExtension): async def on_request_start(self): - self.called_hooks.add(1) + self.called_hooks.append( + f"{SchemaExtension.on_operation.__name__} Entered" + ) async def on_request_end(self): - self.called_hooks.add(2) + self.called_hooks.append( + f"{SchemaExtension.on_operation.__name__} Exited" + ) async def on_validation_start(self): - self.called_hooks.add(3) + self.called_hooks.append( + f"{SchemaExtension.on_validate.__name__} Entered" + ) async def on_validation_end(self): - self.called_hooks.add(4) + self.called_hooks.append( + f"{SchemaExtension.on_validate.__name__} Exited" + ) async def on_parsing_start(self): - self.called_hooks.add(5) + self.called_hooks.append(f"{SchemaExtension.on_parse.__name__} Entered") async def on_parsing_end(self): - self.called_hooks.add(6) + self.called_hooks.append(f"{SchemaExtension.on_parse.__name__} Exited") def on_executing_start(self): - self.called_hooks.add(7) + self.called_hooks.append( + f"{SchemaExtension.on_execute.__name__} Entered" + ) def on_executing_end(self): - self.called_hooks.add(8) + self.called_hooks.append( + f"{SchemaExtension.on_execute.__name__} Exited" + ) @strawberry.type class Person: @@ -520,7 +462,9 @@ def person(self) -> Person: result = await schema.execute(query) assert result.errors is None - assert CompatExtension.called_hooks == {1, 2, 3, 4, 5, 6, 7, 8} + assert CompatExtension.called_hooks == list( + filter(lambda x: x.startswith("on_"), ExampleExtension.expected) + ) assert "Event driven styled extensions for" in w[0].message.args[0] @@ -533,19 +477,27 @@ async def test_legacy_only_start(): ) class CompatExtension(ExampleExtension): - expected = {1, 2, 3, 4} + expected = list( + filter(lambda x: x.endswith(" Entered"), ExampleExtension.expected) + ) async def on_request_start(self): - self.called_hooks.add(1) + self.called_hooks.append( + f"{SchemaExtension.on_operation.__name__} Entered" + ) async def on_validation_start(self): - self.called_hooks.add(2) + self.called_hooks.append( + f"{SchemaExtension.on_validate.__name__} Entered" + ) async def on_parsing_start(self): - self.called_hooks.add(3) + self.called_hooks.append(f"{SchemaExtension.on_parse.__name__} Entered") def on_executing_start(self): - self.called_hooks.add(4) + self.called_hooks.append( + f"{SchemaExtension.on_execute.__name__} Entered" + ) @strawberry.type class Person: @@ -563,7 +515,7 @@ def person(self) -> Person: result = await schema.execute(query) assert result.errors is None - assert CompatExtension.called_hooks == {1, 2, 3, 4} + CompatExtension.assert_expected() assert "Event driven styled extensions for" in w[0].message.args[0] @@ -576,17 +528,27 @@ async def test_legacy_only_end(): ) class CompatExtension(ExampleExtension): + expected = list( + filter(lambda x: x.endswith(" Exited"), ExampleExtension.expected) + ) + async def on_request_end(self): - self.called_hooks.add(1) + self.called_hooks.append( + f"{SchemaExtension.on_operation.__name__} Exited" + ) async def on_validation_end(self): - self.called_hooks.add(2) + self.called_hooks.append( + f"{SchemaExtension.on_validate.__name__} Exited" + ) async def on_parsing_end(self): - self.called_hooks.add(3) + self.called_hooks.append(f"{SchemaExtension.on_parse.__name__} Exited") def on_executing_end(self): - self.called_hooks.add(4) + self.called_hooks.append( + f"{SchemaExtension.on_execute.__name__} Exited" + ) @strawberry.type class Person: @@ -604,7 +566,7 @@ def person(self) -> Person: result = await schema.execute(query) assert result.errors is None - assert CompatExtension.called_hooks == {1, 2, 3, 4} + CompatExtension.assert_expected() assert "Event driven styled extensions for" in w[0].message.args[0] @@ -754,7 +716,9 @@ def ping(self) -> str: assert extension.called_hooks == expected_hooks -async def test_generic_exceptions_get_wrapped_in_a_graphql_error(): +async def test_generic_exceptions_get_wrapped_in_a_graphql_error( + exec_type: ExecType, +) -> None: exception = Exception("This should be wrapped in a GraphQL error") class MyExtension(SchemaExtension): @@ -767,19 +731,20 @@ class Query: schema = strawberry.Schema(query=Query, extensions=[MyExtension]) query = "query { ping }" + if exec_type.is_async(): + res = await schema.execute(query) + else: + res = schema.execute_sync(query) - sync_result = schema.execute_sync(query) - assert len(sync_result.errors) == 1 - assert isinstance(sync_result.errors[0], GraphQLError) - assert sync_result.errors[0].original_error == exception - - async_result = await schema.execute(query) - assert len(async_result.errors) == 1 - assert isinstance(async_result.errors[0], GraphQLError) - assert async_result.errors[0].original_error == exception + res = await schema.execute(query) + assert len(res.errors) == 1 + assert isinstance(res.errors[0], GraphQLError) + assert res.errors[0].original_error == exception -async def test_graphql_errors_get_not_wrapped_in_a_graphql_error(): +async def test_graphql_errors_get_not_wrapped_in_a_graphql_error( + exec_type: ExecType, +) -> None: exception = GraphQLError("This should not be wrapped in a GraphQL error") class MyExtension(SchemaExtension): @@ -792,20 +757,17 @@ class Query: schema = strawberry.Schema(query=Query, extensions=[MyExtension]) query = "query { ping }" - - sync_result = schema.execute_sync(query) - assert len(sync_result.errors) == 1 - assert sync_result.errors[0] == exception - assert sync_result.errors[0].original_error is None - - async_result = await schema.execute(query) - assert len(async_result.errors) == 1 - assert async_result.errors[0] == exception - assert async_result.errors[0].original_error is None + if exec_type.is_async(): + res = await schema.execute(query) + else: + res = schema.execute_sync(query) + assert len(res.errors) == 1 + assert res.errors[0] == exception + assert res.errors[0].original_error is None @pytest.mark.asyncio -async def test_non_parsing_errors_are_not_swallowed_by_parsing_hooks(): +async def test_dont_swallow_errors_in_parsing_hooks(): class MyExtension(SchemaExtension): def on_parse(self): raise Exception("This shouldn't be swallowed") @@ -1185,22 +1147,14 @@ async def hi(self) -> str: assert result.data == {"hi": "👋"} -def test_raise_if_hook_is_not_callable(): +def test_raise_if_hook_is_not_callable(default_query_types_and_query: SchemaHelper): class MyExtension(SchemaExtension): on_operation = "ABC" # type: ignore - @strawberry.type - class Query: - @strawberry.field - def hi(self) -> str: - return "👋" - - schema = strawberry.Schema(query=Query, extensions=[MyExtension]) - - # Query not set on input - query = "{ hi }" - - result = schema.execute_sync(query) + schema = strawberry.Schema( + query=default_query_types_and_query.query_type, extensions=[MyExtension] + ) + result = schema.execute_sync(default_query_types_and_query.query) assert len(result.errors) == 1 assert isinstance(result.errors[0].original_error, ValueError) assert result.errors[0].message.startswith("Hook on_operation on <") diff --git a/tests/schema/extensions/schema_extensions/test_subscription.py b/tests/schema/extensions/schema_extensions/test_subscription.py new file mode 100644 index 0000000000..bd20dc6c83 --- /dev/null +++ b/tests/schema/extensions/schema_extensions/test_subscription.py @@ -0,0 +1,174 @@ +from typing import AsyncGenerator, Type + +import pytest + +import strawberry +from strawberry.extensions import SchemaExtension +from strawberry.types.execution import ExecutionResult +from tests.conftest import skip_if_gql_32 + +from .conftest import ExampleExtension, SchemaHelper + +pytestmark = skip_if_gql_32( + "We only fully support schema extensions in graphql-core 3.3+" +) + + +def assert_agen(obj) -> AsyncGenerator[ExecutionResult, None]: + assert isinstance(obj, AsyncGenerator) + return obj + + +async def test_subscription_success_many_fields( + default_query_types_and_query: SchemaHelper, async_extension: Type[ExampleExtension] +) -> None: + schema = strawberry.Schema( + query=default_query_types_and_query.query_type, + subscription=default_query_types_and_query.subscription_type, + extensions=[async_extension], + ) + subscription_per_yield_hooks_exp = [] + for _ in range(5): # number of yields in the subscription + subscription_per_yield_hooks_exp.extend(["resolve", "get_results"]) + + async_extension.expected = [ + "on_operation Entered", + "on_parse Entered", + "on_parse Exited", + "on_validate Entered", + "on_validate Exited", + "on_execute Entered", + "on_execute Exited", + *subscription_per_yield_hooks_exp, + # last one doesn't call the "resolve" / "get_results" hooks because + # the subscription is done + "on_operation Exited", + ] + async for res in assert_agen( + await schema.subscribe(default_query_types_and_query.subscription) + ): + assert res.data + assert not res.errors + + async_extension.assert_expected() + + +async def test_subscription_extension_handles_immediate_errors( + default_query_types_and_query: SchemaHelper, async_extension: Type[ExampleExtension] +) -> None: + @strawberry.type() + class Subscription: + @strawberry.subscription() + async def count(self) -> AsyncGenerator[int, None]: + raise ValueError("This is an error") + + schema = strawberry.Schema( + query=default_query_types_and_query.query_type, + subscription=Subscription, + extensions=[async_extension], + ) + async_extension.expected = [ + "on_operation Entered", + "on_parse Entered", + "on_parse Exited", + "on_validate Entered", + "on_validate Exited", + "on_execute Entered", + "on_execute Exited", + "get_results", + "on_operation Exited", + ] + + res = await schema.subscribe(default_query_types_and_query.subscription) + assert res.errors + + async_extension.assert_expected() + + +async def test_error_after_first_yield_in_subscription( + default_query_types_and_query: SchemaHelper, async_extension: Type[ExampleExtension] +) -> None: + @strawberry.type() + class Subscription: + @strawberry.subscription() + async def count(self) -> AsyncGenerator[int, None]: + yield 1 + raise ValueError("This is an error") + + schema = strawberry.Schema( + query=default_query_types_and_query.query_type, + subscription=Subscription, + extensions=[async_extension], + ) + + agen = await schema.subscribe(default_query_types_and_query.subscription) + assert isinstance(agen, AsyncGenerator) + res1 = await agen.__anext__() + assert res1.data + assert not res1.errors + res2 = await agen.__anext__() + assert not res2.data + assert res2.errors + # close the generator + with pytest.raises(StopAsyncIteration): + await agen.__anext__() + async_extension.expected = [ + "on_operation Entered", + "on_parse Entered", + "on_parse Exited", + "on_validate Entered", + "on_validate Exited", + "on_execute Entered", + "on_execute Exited", + "resolve", + "get_results", + "get_results", + "on_operation Exited", + ] + async_extension.assert_expected() + + +async def test_extensions_results_are_cleared_between_subscription_yields( + default_query_types_and_query: SchemaHelper, +) -> None: + class MyExtension(SchemaExtension): + execution_number = 0 + + def get_results(self): + self.execution_number += 1 + return {str(self.execution_number): self.execution_number} + + schema = strawberry.Schema( + query=default_query_types_and_query.query_type, + subscription=default_query_types_and_query.subscription_type, + extensions=[MyExtension], + ) + + res_num = 1 + + async for res in assert_agen( + await schema.subscribe(default_query_types_and_query.subscription) + ): + assert res.extensions == {str(res_num): res_num} + assert not res.errors + res_num += 1 + + +async def test_subscription_catches_extension_errors( + default_query_types_and_query: SchemaHelper, +) -> None: + class MyExtension(SchemaExtension): + def on_execute(self): + raise ValueError("This is an error") + + schema = strawberry.Schema( + query=default_query_types_and_query.query_type, + subscription=default_query_types_and_query.subscription_type, + extensions=[MyExtension], + ) + async for res in assert_agen( + await schema.subscribe(default_query_types_and_query.subscription) + ): + assert res.errors + assert not res.data + assert res.errors[0].message == "This is an error" diff --git a/tests/schema/test_get_extensions.py b/tests/schema/test_get_extensions.py index 4a9d2ced67..aeabab3ec7 100644 --- a/tests/schema/test_get_extensions.py +++ b/tests/schema/test_get_extensions.py @@ -29,32 +29,39 @@ def test_returns_empty_list_when_no_custom_directives(): def test_returns_extension_passed_by_user(): schema = strawberry.Schema(query=Query, extensions=[MyExtension]) - assert schema.get_extensions() == [MyExtension] + assert len(schema.get_extensions()) == 1 + assert isinstance(schema.get_extensions()[0], MyExtension) def test_returns_directives_extension_when_passing_directives(): schema = strawberry.Schema(query=Query, directives=[uppercase]) - assert schema.get_extensions() == [DirectivesExtension] + assert len(schema.get_extensions()) == 1 + assert isinstance(schema.get_extensions()[0], DirectivesExtension) def test_returns_extension_passed_by_user_and_directives_extension(): schema = strawberry.Schema( query=Query, extensions=[MyExtension], directives=[uppercase] ) - - assert schema.get_extensions() == [MyExtension, DirectivesExtension] + for ext, ext_cls in zip( + schema.get_extensions(), [MyExtension, DirectivesExtension] + ): + assert isinstance(ext, ext_cls) def test_returns_directives_extension_when_passing_directives_sync(): schema = strawberry.Schema(query=Query, directives=[uppercase]) - assert schema.get_extensions(sync=True) == [DirectivesExtensionSync] + assert len(schema.get_extensions(sync=True)) == 1 + assert isinstance(schema.get_extensions(sync=True)[0], DirectivesExtensionSync) def test_returns_extension_passed_by_user_and_directives_extension_sync(): schema = strawberry.Schema( query=Query, extensions=[MyExtension], directives=[uppercase] ) - - assert schema.get_extensions(sync=True) == [MyExtension, DirectivesExtensionSync] + for ext, ext_cls in zip( + schema.get_extensions(sync=True), [MyExtension, DirectivesExtensionSync] + ): + assert isinstance(ext, ext_cls) diff --git a/tests/schema/test_subscription.py b/tests/schema/test_subscription.py index b63af92dcc..14a2781042 100644 --- a/tests/schema/test_subscription.py +++ b/tests/schema/test_subscription.py @@ -1,6 +1,7 @@ # ruff: noqa: F821 from __future__ import annotations +import inspect import sys from collections import abc # noqa: F401 from typing import ( # noqa: F401 @@ -15,6 +16,7 @@ import pytest import strawberry +from strawberry.types.execution import PreExecutionError @pytest.mark.asyncio @@ -236,3 +238,51 @@ async def example( assert not result.errors assert result.data["example"] == "Hi" + + +async def test_subscription_immediate_error(): + @strawberry.type + class Query: + x: str = "Hello" + + @strawberry.type + class Subscription: + @strawberry.subscription() + async def example(self) -> AsyncGenerator[str, None]: + return "fds" + + schema = strawberry.Schema(query=Query, subscription=Subscription) + + query = """#graphql + subscription { example } + """ + res_or_agen = await schema.subscribe(query) + assert isinstance(res_or_agen, PreExecutionError) + assert res_or_agen.errors + + +async def test_worng_opeartion_variables(): + @strawberry.type + class Query: + x: str = "Hello" + + @strawberry.type + class Subscription: + @strawberry.subscription + async def example(self, name: str) -> AsyncGenerator[str, None]: + yield f"Hi {name}" # pragma: no cover + + schema = strawberry.Schema(query=Query, subscription=Subscription) + + query = """#graphql + subscription subOp($opVar: String!){ example(name: $opVar) } + """ + + result = await schema.subscribe(query) + assert not inspect.isasyncgen(result) + + assert result.errors + assert ( + result.errors[0].message + == "Variable '$opVar' of required type 'String!' was not provided." + ) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index f3fd4b74b8..8062f25a5c 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -5,13 +5,8 @@ import json import time from datetime import timedelta -from typing import TYPE_CHECKING, Any, AsyncGenerator, Type -from unittest.mock import Mock, patch - -try: - from unittest.mock import AsyncMock -except ImportError: - AsyncMock = None +from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, Type +from unittest.mock import AsyncMock, Mock, patch import pytest import pytest_asyncio @@ -30,7 +25,7 @@ SubscribeMessagePayload, ) from tests.http.clients.base import DebuggableGraphQLTransportWSMixin -from tests.views.schema import Schema +from tests.views.schema import MyExtension, Schema if TYPE_CHECKING: from ..http.clients.base import HttpClient, WebSocketClient @@ -54,6 +49,25 @@ async def ws(ws_raw: WebSocketClient) -> WebSocketClient: return ws_raw +def assert_next( + response: dict[str, Any], + id: str, + data: Dict[str, Any], + extensions: Optional[Dict[str, Any]] = None, +): + """ + Assert that the NextMessage payload contains the provided data. + If extensions is provided, it will also assert that the + extensions are present + """ + assert response["type"] == "next" + assert response["id"] == id + assert set(response["payload"].keys()) <= {"data", "errors", "extensions"} + assert response["payload"]["data"] == data + if extensions is not None: + assert response["payload"]["extensions"] == extensions + + async def test_unknown_message_type(ws_raw: WebSocketClient): ws = ws_raw @@ -158,13 +172,7 @@ async def test_connection_init_timeout_cancellation( ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", - payload={"data": {"debug": {"isConnectionInitTimeoutTaskDone": True}}}, - ).as_dict() - ) + assert_next(response, "sub1", {"debug": {"isConnectionInitTimeoutTaskDone": True}}) @pytest.mark.xfail(reason="This test is flaky") @@ -242,7 +250,7 @@ async def test_can_send_payload_with_additional_things(ws_raw: WebSocketClient): assert json.loads(data.data) == { "type": "next", "id": "1", - "payload": {"data": {"echo": "Hi"}}, + "payload": {"data": {"echo": "Hi"}, "extensions": {"example": "example"}}, } @@ -260,11 +268,7 @@ async def test_server_sent_ping(ws: WebSocketClient): await ws.send_json(PongMessage().as_dict()) response = await ws.receive_json() - assert ( - response - == NextMessage(id="sub1", payload={"data": {"requestPing": True}}).as_dict() - ) - + assert_next(response, "sub1", {"requestPing": True}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -326,9 +330,7 @@ async def test_reused_operation_ids(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response == NextMessage(id="sub1", payload={"data": {"echo": "Hi"}}).as_dict() - ) + assert_next(response, "sub1", {"echo": "Hi"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -345,9 +347,7 @@ async def test_reused_operation_ids(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response == NextMessage(id="sub1", payload={"data": {"echo": "Hi"}}).as_dict() - ) + assert_next(response, "sub1", {"echo": "Hi"}) async def test_simple_subscription(ws: WebSocketClient): @@ -361,10 +361,7 @@ async def test_simple_subscription(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response == NextMessage(id="sub1", payload={"data": {"echo": "Hi"}}).as_dict() - ) - + assert_next(response, "sub1", {"echo": "Hi"}) await ws.send_json(CompleteMessage(id="sub1").as_dict()) @@ -402,7 +399,7 @@ async def test_subscription_field_errors(ws: WebSocketClient): assert response["payload"][0]["locations"] == [{"line": 1, "column": 16}] assert ( response["payload"][0]["message"] - == "The subscription field 'notASubscriptionField' is not defined." + == "Cannot query field 'notASubscriptionField' on type 'Subscription'." ) process_errors.assert_called_once() @@ -427,13 +424,7 @@ async def test_subscription_cancellation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub2", payload={"data": {"debug": {"numActiveResultHandlers": 2}}} - ).as_dict() - ) - + assert_next(response, "sub2", {"debug": {"numActiveResultHandlers": 2}}) response = await ws.receive_json() assert response == CompleteMessage(id="sub2").as_dict() @@ -449,12 +440,7 @@ async def test_subscription_cancellation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub3", payload={"data": {"debug": {"numActiveResultHandlers": 1}}} - ).as_dict() - ) + assert_next(response, "sub3", {"debug": {"numActiveResultHandlers": 1}}) response = await ws.receive_json() assert response == CompleteMessage(id="sub3").as_dict() @@ -480,12 +466,14 @@ async def test_subscription_errors(ws: WebSocketClient): async def test_operation_error_no_complete(ws: WebSocketClient): """Test that an "error" message is not followed by "complete".""" - # get an "error" message + # Since we don't include the operation variables, + # the subscription will fail immediately. + # see https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#error await ws.send_json( SubscribeMessage( id="sub1", payload=SubscribeMessagePayload( - query='query { error(message: "TEST ERR") }', + query="subscription Foo($bar: String!){ exception(message: $bar) }", ), ).as_dict() ) @@ -496,17 +484,9 @@ async def test_operation_error_no_complete(ws: WebSocketClient): # after an "error" message, there should be nothing more # sent regarding "sub1", not even a "complete". - await ws.send_json( - SubscribeMessage( - id="sub2", - payload=SubscribeMessagePayload( - query='query { error(message: "TEST ERR") }', - ), - ).as_dict() - ) - response = await ws.receive_json() - assert response["type"] == ErrorMessage.type - assert response["id"] == "sub2" + await ws.send_json(PingMessage().as_dict()) + data = await ws.receive_json(timeout=1) + assert data == PongMessage().as_dict() async def test_subscription_exceptions(ws: WebSocketClient): @@ -522,12 +502,9 @@ async def test_subscription_exceptions(ws: WebSocketClient): ) response = await ws.receive_json() - assert response["type"] == ErrorMessage.type + assert response["type"] == NextMessage.type assert response["id"] == "sub1" - assert len(response["payload"]) == 1 - assert response["payload"][0].get("path") is None - assert response["payload"][0].get("locations") is None - assert response["payload"][0]["message"] == "TEST EXC" + assert response["payload"]["errors"] == [{"message": "TEST EXC"}] process_errors.assert_called_once() @@ -540,11 +517,7 @@ async def test_single_result_query_operation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage(id="sub1", payload={"data": {"hello": "Hello world"}}).as_dict() - ) - + assert_next(response, "sub1", {"hello": "Hello world"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -564,12 +537,7 @@ async def test_single_result_query_operation_async(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"asyncHello": "Hello Dolly"}} - ).as_dict() - ) + assert_next(response, "sub1", {"asyncHello": "Hello Dolly"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -602,12 +570,7 @@ async def test_single_result_query_operation_overlapped(ws: WebSocketClient): # we expect the response to the second query to arrive first response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub2", payload={"data": {"asyncHello": "Hello Dolly"}} - ).as_dict() - ) + assert_next(response, "sub2", {"asyncHello": "Hello Dolly"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub2").as_dict() @@ -621,10 +584,7 @@ async def test_single_result_mutation_operation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage(id="sub1", payload={"data": {"hello": "strawberry"}}).as_dict() - ) + assert_next(response, "sub1", {"hello": "strawberry"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -648,13 +608,7 @@ async def test_single_result_operation_selection(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"hello": "Hello Strawberry"}} - ).as_dict() - ) - + assert_next(response, "sub1", {"hello": "Hello Strawberry"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -679,7 +633,7 @@ async def test_single_result_invalid_operation_selection(ws: WebSocketClient): ws.assert_reason("Can't get GraphQL operation type") -async def test_single_result_operation_error(ws: WebSocketClient): +async def test_single_result_execution_error(ws: WebSocketClient): process_errors = Mock() with patch.object(Schema, "process_errors", process_errors): await ws.send_json( @@ -692,14 +646,17 @@ async def test_single_result_operation_error(ws: WebSocketClient): ) response = await ws.receive_json() - assert response["type"] == ErrorMessage.type + assert response["type"] == NextMessage.type assert response["id"] == "sub1" - assert len(response["payload"]) == 1 - assert response["payload"][0]["message"] == "You are not authorized" + errs = response["payload"]["errors"] + assert len(errs) == 1 + assert errs[0]["path"] == ["alwaysFail"] + assert errs[0]["message"] == "You are not authorized" + process_errors.assert_called_once() -async def test_single_result_operation_exception(ws: WebSocketClient): +async def test_single_result_pre_execution_error(ws: WebSocketClient): """Test that single-result-operations which raise exceptions behave in the same way as streaming operations. """ @@ -709,7 +666,7 @@ async def test_single_result_operation_exception(ws: WebSocketClient): SubscribeMessage( id="sub1", payload=SubscribeMessagePayload( - query='query { exception(message: "bummer") }', + query="query { IDontExist }", ), ).as_dict() ) @@ -718,8 +675,10 @@ async def test_single_result_operation_exception(ws: WebSocketClient): assert response["type"] == ErrorMessage.type assert response["id"] == "sub1" assert len(response["payload"]) == 1 - assert response["payload"][0].get("path") == ["exception"] - assert response["payload"][0]["message"] == "bummer" + assert ( + response["payload"][0]["message"] + == "Cannot query field 'IDontExist' on type 'Query'." + ) process_errors.assert_called_once() @@ -800,12 +759,7 @@ async def test_injects_connection_params(ws_raw: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"connectionParams": "rocks"}} - ).as_dict() - ) + assert_next(response, "sub1", {"connectionParams": "rocks"}) await ws.send_json(CompleteMessage(id="sub1").as_dict()) @@ -854,12 +808,7 @@ async def test_subsciption_cancel_finalization_delay(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"longFinalizer": "hello"}} - ).as_dict() - ) + assert_next(response, "sub1", {"longFinalizer": "hello"}) # now cancel the stubscription and send a new query. We expect the response # to the new query to arrive immediately, without waiting for the finalizer @@ -962,3 +911,22 @@ async def test_subscription_errors_continue(ws: WebSocketClient): response = await ws.receive_json() assert response["type"] == CompleteMessage.type assert response["id"] == "sub1" + + +@patch.object(MyExtension, MyExtension.get_results.__name__, return_value={}) +async def test_no_extensions_results_wont_send_extensions_in_payload( + mock: Mock, ws: WebSocketClient +): + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query='subscription { echo(message: "Hi") }' + ), + ).as_dict() + ) + + response = await ws.receive_json() + mock.assert_called_once() + assert_next(response, "sub1", {"echo": "Hi"}) + assert "extensions" not in response["payload"] diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index c872c3361c..b16a8d12db 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -2,6 +2,7 @@ import asyncio from typing import TYPE_CHECKING, AsyncGenerator +from unittest import mock import pytest import pytest_asyncio @@ -19,6 +20,7 @@ GQL_START, GQL_STOP, ) +from tests.views.schema import MyExtension if TYPE_CHECKING: from ..http.clients.aiohttp import HttpClient, WebSocketClient @@ -254,7 +256,9 @@ async def test_subscription_field_error(ws: WebSocketClient): assert response["id"] == "invalid-field" assert response["payload"] == { "locations": [{"line": 1, "column": 16}], - "message": ("The subscription field 'notASubscriptionField' is not defined."), + "message": ( + "Cannot query field 'notASubscriptionField' on type 'Subscription'." + ), } @@ -557,3 +561,34 @@ async def test_rejects_connection_params(aiohttp_app_client: HttpClient): # make sure the WebSocket is disconnected now await ws.receive(timeout=2) # receive close assert ws.closed + + +@mock.patch.object(MyExtension, MyExtension.get_results.__name__, return_value={}) +async def test_no_extensions_results_wont_send_extensions_in_payload( + mock: mock.MagicMock, aiohttp_app_client: HttpClient +): + async with aiohttp_app_client.ws_connect( + "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] + ) as ws: + await ws.send_json({"type": GQL_CONNECTION_INIT}) + await ws.send_json( + { + "type": GQL_START, + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } + ) + + response = await ws.receive_json() + assert response["type"] == GQL_CONNECTION_ACK + + response = await ws.receive_json() + mock.assert_called_once() + assert response["type"] == GQL_DATA + assert response["id"] == "demo" + assert "extensions" not in response["payload"] + + await ws.send_json({"type": GQL_STOP, "id": "demo"}) + response = await ws.receive_json()