Skip to content

Commit

Permalink
feat: Avoid calling Type.get_queryset method more than once
Browse files Browse the repository at this point in the history
This change avoids calling Type.get_queryset method more than once
when the type defines one. This used to happen specially for connections

This change also paves the way for the upcoming nested list/connection
optimizations that we want to do.

Fix #528
  • Loading branch information
bellini666 committed May 11, 2024
1 parent 0efbc1d commit 6905513
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 21 deletions.
8 changes: 2 additions & 6 deletions strawberry_django/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from strawberry_django.ordering import ORDER_ARG, StrawberryDjangoFieldOrdering
from strawberry_django.pagination import StrawberryDjangoPagination
from strawberry_django.permissions import filter_with_perms
from strawberry_django.queryset import run_type_get_queryset
from strawberry_django.relay import resolve_model_nodes
from strawberry_django.resolvers import (
default_qs_hook,
Expand Down Expand Up @@ -278,12 +279,7 @@ def qs_hook(qs: models.QuerySet):
return qs_hook

def get_queryset(self, queryset, info, **kwargs):
type_ = self.django_type

get_queryset = getattr(type_, "get_queryset", None)
if get_queryset:
queryset = get_queryset(queryset, info, **kwargs)

queryset = run_type_get_queryset(queryset, self.django_type, info)
queryset = super().get_queryset(
filter_with_perms(queryset, info), info, **kwargs
)
Expand Down
16 changes: 10 additions & 6 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from typing_extensions import assert_never, assert_type, get_args

from strawberry_django.fields.types import resolve_model_field_name
from strawberry_django.queryset import get_queryset_config, run_type_get_queryset
from strawberry_django.resolvers import django_fetch

from .descriptors import ModelProperty
Expand Down Expand Up @@ -355,10 +356,12 @@ def _get_prefetch_queryset(
else:
remote_type = remote_type_defs[0]

if get_queryset := getattr(remote_type, "get_queryset", None):
return get_queryset(qs, info)

return qs
return run_type_get_queryset(
qs,
remote_type,
# FIXME: Find out if the fact that info can be a GraphQLResolveInfo is a problem
info=info, # type: ignore
)


def _get_model_hints(
Expand Down Expand Up @@ -693,7 +696,7 @@ def optimize(

# Avoid optimizing twice and also modify an already resolved queryset
if (
getattr(qs, "_gql_optimized", False) or qs._result_cache is not None # type: ignore
get_queryset_config(qs).optimized or qs._result_cache is not None # type: ignore
):
return qs

Expand Down Expand Up @@ -749,7 +752,8 @@ def optimize(

if store:
qs = store.apply(qs, info=info, config=config)
qs._gql_optimized = True # type: ignore
qs_config = get_queryset_config(qs)
qs_config.optimized = True

return qs

Expand Down
51 changes: 51 additions & 0 deletions strawberry_django/queryset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import dataclasses
from typing import Any, Optional, TypeVar

from django.db.models import Model
from django.db.models.query import QuerySet
from strawberry import Info

_M = TypeVar("_M", bound=Model)

CONFIG_KEY = "_strawberry_django_config"


@dataclasses.dataclass
class StrawberryDjangoQuerySetConfig:
optimized: bool = False
type_get_queryset_did_run: bool = False


def get_queryset_config(queryset: QuerySet) -> StrawberryDjangoQuerySetConfig:
return getattr(queryset, CONFIG_KEY, None) or StrawberryDjangoQuerySetConfig()


def run_type_get_queryset(
qs: QuerySet[_M],
origin: Any,
info: Optional[Info] = None,
) -> QuerySet[_M]:
config = get_queryset_config(qs)
get_queryset = getattr(origin, "get_queryset", None)

if get_queryset and not config.type_get_queryset_did_run:
qs = get_queryset(qs, info)
new_config = get_queryset_config(qs)
new_config.type_get_queryset_did_run = True

return qs


_original_clone = QuerySet._clone # type: ignore


def _qs_clone(self):
config = get_queryset_config(self)
cloned = _original_clone(self)
setattr(cloned, CONFIG_KEY, dataclasses.replace(config))
return cloned


# Monkey patch the QuerySet._clone method to make sure our config is copied
# to the new QuerySet instance once it is cloned.
QuerySet._clone = _qs_clone # type: ignore
11 changes: 3 additions & 8 deletions strawberry_django/relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from strawberry.utils.inspect import in_async_context
from typing_extensions import Literal, Self

from strawberry_django.queryset import run_type_get_queryset
from strawberry_django.resolvers import django_getattr, django_resolver
from strawberry_django.utils.typing import (
WithStrawberryDjangoObjectDefinition,
Expand Down Expand Up @@ -242,10 +243,7 @@ def resolve_model_nodes(
source = cast(Type[_M], django_type.model)

qs = cast(models.QuerySet[_M], source._default_manager.all())

get_queryset = getattr(origin, "get_queryset", None)
if get_queryset:
qs = get_queryset(qs, info)
qs = run_type_get_queryset(qs, origin, info)

id_attr = cast(relay.Node, origin).resolve_id_attr()
if node_ids is not None:
Expand Down Expand Up @@ -376,10 +374,7 @@ def resolve_model_node(

id_attr = cast(relay.Node, origin).resolve_id_attr()
qs = source._default_manager.all()

get_queryset = getattr(origin, "get_queryset", None)
if get_queryset:
qs = get_queryset(qs, info)
qs = run_type_get_queryset(qs, origin, info)

qs = qs.filter(**{id_attr: node_id})

Expand Down
Empty file.
12 changes: 11 additions & 1 deletion tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import strawberry
from django.db.models import Prefetch
from django.utils import timezone
from pytest_mock import MockerFixture
from strawberry.relay import to_base64
from strawberry.types import ExecutionResult

import strawberry_django
from strawberry_django.optimizer import DjangoOptimizerExtension
from tests.projects.schema import StaffType

from . import utils
from .projects.faker import (
Expand Down Expand Up @@ -56,7 +58,14 @@ def test_user_query(db, gql_client: GraphQLTestClient):


@pytest.mark.django_db(transaction=True)
def test_staff_query(db, gql_client: GraphQLTestClient):
def test_staff_query(db, gql_client: GraphQLTestClient, mocker: MockerFixture):
staff_type_get_queryset = StaffType.get_queryset
mock_staff_type_get_queryset = mocker.patch(
"tests.projects.schema.StaffType.get_queryset",
autospec=True,
side_effect=staff_type_get_queryset,
)

query = """
query TestQuery {
staffConn {
Expand All @@ -81,6 +90,7 @@ def test_staff_query(db, gql_client: GraphQLTestClient):
],
},
}
mock_staff_type_get_queryset.assert_called_once()


@pytest.mark.django_db(transaction=True)
Expand Down

0 comments on commit 6905513

Please sign in to comment.