diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index ff71bc7e..d68074a8 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -39,6 +39,7 @@ from strawberry_django.ordering import ORDER_ARG, StrawberryDjangoFieldOrdering from strawberry_django.pagination import StrawberryDjangoPagination from strawberry_django.permissions import filter_with_perms +from strawberry_django.queryset import run_type_get_queryset from strawberry_django.relay import resolve_model_nodes from strawberry_django.resolvers import ( default_qs_hook, @@ -278,12 +279,7 @@ def qs_hook(qs: models.QuerySet): return qs_hook def get_queryset(self, queryset, info, **kwargs): - type_ = self.django_type - - get_queryset = getattr(type_, "get_queryset", None) - if get_queryset: - queryset = get_queryset(queryset, info, **kwargs) - + queryset = run_type_get_queryset(queryset, self.django_type, info) queryset = super().get_queryset( filter_with_perms(queryset, info), info, **kwargs ) diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 27b31cbc..7497aad6 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -42,6 +42,7 @@ from typing_extensions import assert_never, assert_type, get_args from strawberry_django.fields.types import resolve_model_field_name +from strawberry_django.queryset import get_queryset_config, run_type_get_queryset from strawberry_django.resolvers import django_fetch from .descriptors import ModelProperty @@ -355,10 +356,12 @@ def _get_prefetch_queryset( else: remote_type = remote_type_defs[0] - if get_queryset := getattr(remote_type, "get_queryset", None): - return get_queryset(qs, info) - - return qs + return run_type_get_queryset( + qs, + remote_type, + # FIXME: Find out if the fact that info can be a GraphQLResolveInfo is a problem + info=info, # type: ignore + ) def _get_model_hints( @@ -693,7 +696,7 @@ def optimize( # Avoid optimizing twice and also modify an already resolved queryset if ( - getattr(qs, "_gql_optimized", False) or qs._result_cache is not None # type: ignore + get_queryset_config(qs).optimized or qs._result_cache is not None # type: ignore ): return qs @@ -749,7 +752,8 @@ def optimize( if store: qs = store.apply(qs, info=info, config=config) - qs._gql_optimized = True # type: ignore + qs_config = get_queryset_config(qs) + qs_config.optimized = True return qs diff --git a/strawberry_django/queryset.py b/strawberry_django/queryset.py new file mode 100644 index 00000000..d2f5b6ac --- /dev/null +++ b/strawberry_django/queryset.py @@ -0,0 +1,51 @@ +import dataclasses +from typing import Any, Optional, TypeVar + +from django.db.models import Model +from django.db.models.query import QuerySet +from strawberry import Info + +_M = TypeVar("_M", bound=Model) + +CONFIG_KEY = "_strawberry_django_config" + + +@dataclasses.dataclass +class StrawberryDjangoQuerySetConfig: + optimized: bool = False + type_get_queryset_did_run: bool = False + + +def get_queryset_config(queryset: QuerySet) -> StrawberryDjangoQuerySetConfig: + return getattr(queryset, CONFIG_KEY, None) or StrawberryDjangoQuerySetConfig() + + +def run_type_get_queryset( + qs: QuerySet[_M], + origin: Any, + info: Optional[Info] = None, +) -> QuerySet[_M]: + config = get_queryset_config(qs) + get_queryset = getattr(origin, "get_queryset", None) + + if get_queryset and not config.type_get_queryset_did_run: + qs = get_queryset(qs, info) + new_config = get_queryset_config(qs) + new_config.type_get_queryset_did_run = True + + return qs + + +_original_clone = QuerySet._clone # type: ignore + + +def _qs_clone(self): + config = get_queryset_config(self) + cloned = _original_clone(self) + setattr(cloned, CONFIG_KEY, dataclasses.replace(config)) + return cloned + + +# Monkey patch the QuerySet._clone method to make sure our config is copied +# to the new QuerySet instance once it is cloned. +QuerySet._clone = _qs_clone # type: ignore diff --git a/strawberry_django/relay.py b/strawberry_django/relay.py index b84edfc7..c2260b1b 100644 --- a/strawberry_django/relay.py +++ b/strawberry_django/relay.py @@ -24,6 +24,7 @@ from strawberry.utils.inspect import in_async_context from typing_extensions import Literal, Self +from strawberry_django.queryset import run_type_get_queryset from strawberry_django.resolvers import django_getattr, django_resolver from strawberry_django.utils.typing import ( WithStrawberryDjangoObjectDefinition, @@ -242,10 +243,7 @@ def resolve_model_nodes( source = cast(Type[_M], django_type.model) qs = cast(models.QuerySet[_M], source._default_manager.all()) - - get_queryset = getattr(origin, "get_queryset", None) - if get_queryset: - qs = get_queryset(qs, info) + qs = run_type_get_queryset(qs, origin, info) id_attr = cast(relay.Node, origin).resolve_id_attr() if node_ids is not None: @@ -376,10 +374,7 @@ def resolve_model_node( id_attr = cast(relay.Node, origin).resolve_id_attr() qs = source._default_manager.all() - - get_queryset = getattr(origin, "get_queryset", None) - if get_queryset: - qs = get_queryset(qs, info) + qs = run_type_get_queryset(qs, origin, info) qs = qs.filter(**{id_attr: node_id}) diff --git a/strawberry_django/utils/patches.py b/strawberry_django/utils/patches.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index dab1890e..28ed4056 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -5,11 +5,13 @@ import strawberry from django.db.models import Prefetch from django.utils import timezone +from pytest_mock import MockerFixture from strawberry.relay import to_base64 from strawberry.types import ExecutionResult import strawberry_django from strawberry_django.optimizer import DjangoOptimizerExtension +from tests.projects.schema import StaffType from . import utils from .projects.faker import ( @@ -56,7 +58,14 @@ def test_user_query(db, gql_client: GraphQLTestClient): @pytest.mark.django_db(transaction=True) -def test_staff_query(db, gql_client: GraphQLTestClient): +def test_staff_query(db, gql_client: GraphQLTestClient, mocker: MockerFixture): + staff_type_get_queryset = StaffType.get_queryset + mock_staff_type_get_queryset = mocker.patch( + "tests.projects.schema.StaffType.get_queryset", + autospec=True, + side_effect=staff_type_get_queryset, + ) + query = """ query TestQuery { staffConn { @@ -81,6 +90,7 @@ def test_staff_query(db, gql_client: GraphQLTestClient): ], }, } + mock_staff_type_get_queryset.assert_called_once() @pytest.mark.django_db(transaction=True)