Skip to content

Commit

Permalink
feat: make get_origin_and_cast
Browse files Browse the repository at this point in the history
This fixes the type-checker error raised while accessing
RequestBuilder[T].__origin__
  • Loading branch information
anand2312 committed Sep 19, 2023
1 parent 2231d2d commit ab2256f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 27 deletions.
20 changes: 7 additions & 13 deletions postgrest/_async/request_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from ..exceptions import APIError, generate_default_error_message
from ..types import ReturnMethod
from ..utils import AsyncClient
from ..utils import AsyncClient, get_origin_and_cast

_ReturnT = TypeVar("_ReturnT")

Expand Down Expand Up @@ -154,13 +154,10 @@ def __init__(
params: QueryParams,
json: dict,
) -> None:
# Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__
# tries to call _GenericAlias.__init__ - which is the wrong method
# The __origin__ attribute of the _GenericAlias is the actual class
BaseFilterRequestBuilder[_ReturnT].__origin__.__init__(
get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__(
self, session, headers, params
)
AsyncQueryRequestBuilder[_ReturnT].__origin__.__init__(
get_origin_and_cast(AsyncQueryRequestBuilder[_ReturnT]).__init__(
self, session, path, http_method, headers, params, json
)

Expand All @@ -178,10 +175,10 @@ def __init__(
params: QueryParams,
json: dict,
) -> None:
BaseFilterRequestBuilder[_ReturnT].__origin__.__init__(
get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__(
self, session, headers, params
)
AsyncSingleRequestBuilder[_ReturnT].__origin__.__init__(
get_origin_and_cast(AsyncSingleRequestBuilder[_ReturnT]).__init__(
self, session, path, http_method, headers, params, json
)

Expand All @@ -197,13 +194,10 @@ def __init__(
params: QueryParams,
json: dict,
) -> None:
# Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__
# tries to call _GenericAlias.__init__ - which is the wrong method
# The __origin__ attribute of the _GenericAlias is the actual class
BaseSelectRequestBuilder[_ReturnT].__origin__.__init__(
get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT]).__init__(
self, session, headers, params
)
AsyncQueryRequestBuilder[_ReturnT].__origin__.__init__(
get_origin_and_cast(AsyncQueryRequestBuilder[_ReturnT]).__init__(
self, session, path, http_method, headers, params, json
)

Expand Down
20 changes: 7 additions & 13 deletions postgrest/_sync/request_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from ..exceptions import APIError, generate_default_error_message
from ..types import ReturnMethod
from ..utils import SyncClient
from ..utils import SyncClient, get_origin_and_cast

_ReturnT = TypeVar("_ReturnT")

Expand Down Expand Up @@ -154,13 +154,10 @@ def __init__(
params: QueryParams,
json: dict,
) -> None:
# Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__
# tries to call _GenericAlias.__init__ - which is the wrong method
# The __origin__ attribute of the _GenericAlias is the actual class
BaseFilterRequestBuilder[_ReturnT].__origin__.__init__(
get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__(
self, session, headers, params
)
SyncQueryRequestBuilder[_ReturnT].__origin__.__init__(
get_origin_and_cast(SyncQueryRequestBuilder[_ReturnT]).__init__(
self, session, path, http_method, headers, params, json
)

Expand All @@ -178,10 +175,10 @@ def __init__(
params: QueryParams,
json: dict,
) -> None:
BaseFilterRequestBuilder[_ReturnT].__origin__.__init__(
get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__(
self, session, headers, params
)
SyncSingleRequestBuilder[_ReturnT].__origin__.__init__(
get_origin_and_cast(SyncSingleRequestBuilder[_ReturnT]).__init__(
self, session, path, http_method, headers, params, json
)

Expand All @@ -197,13 +194,10 @@ def __init__(
params: QueryParams,
json: dict,
) -> None:
# Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__
# tries to call _GenericAlias.__init__ - which is the wrong method
# The __origin__ attribute of the _GenericAlias is the actual class
BaseSelectRequestBuilder[_ReturnT].__origin__.__init__(
get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT]).__init__(
self, session, headers, params
)
SyncQueryRequestBuilder[_ReturnT].__origin__.__init__(
get_origin_and_cast(SyncQueryRequestBuilder[_ReturnT]).__init__(
self, session, path, http_method, headers, params, json
)

Expand Down
16 changes: 15 additions & 1 deletion postgrest/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any
from typing import Any, TypeVar, cast, get_origin

from httpx import AsyncClient # noqa: F401
from httpx import Client as BaseClient # noqa: F401
Expand All @@ -21,3 +21,17 @@ def sanitize_param(param: Any) -> str:

def sanitize_pattern_param(pattern: str) -> str:
return sanitize_param(pattern.replace("%", "*"))


_T = TypeVar("_T")


def get_origin_and_cast(typ: type[type[_T]]) -> type[_T]:
# Base[T] is an instance of typing._GenericAlias, so doing Base[T].__init__
# tries to call _GenericAlias.__init__ - which is the wrong method
# get_origin(Base[T]) returns Base
# This function casts Base back to Base[T] to maintain type-safety
# while still allowing us to access the methods of `Base` at runtime
# See: definitions of request builders that use multiple-inheritance
# like AsyncFilterRequestBuilder
return cast(type[_T], get_origin(typ))

0 comments on commit ab2256f

Please sign in to comment.