Skip to content

Commit

Permalink
add sync resolving
Browse files Browse the repository at this point in the history
  • Loading branch information
lesnik512 committed May 15, 2024
1 parent d2973c2 commit db14150
Show file tree
Hide file tree
Showing 9 changed files with 177 additions and 45 deletions.
8 changes: 6 additions & 2 deletions tests/test_context_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ async def test_context_resource_without_context_init(
context_resource: providers.AbstractResource[datetime.datetime],
) -> None:
with pytest.raises(RuntimeError, match="Context is not set. Use container_context"):
await context_resource()
await context_resource.async_resolve()

with pytest.raises(RuntimeError, match="Context is not set. Use container_context"):
context_resource.sync_resolve()


@container_context()
Expand Down Expand Up @@ -79,7 +82,8 @@ async def test_context_resources_overriding(context_resource: providers.Abstract
context_resource.override(context_resource_mock)

context_resource_result = await context_resource()
assert context_resource_result is context_resource_mock
context_resource_result2 = context_resource.sync_resolve()
assert context_resource_result is context_resource_result2 is context_resource_mock

DIContainer.reset_override()
with pytest.raises(RuntimeError, match="Context is not set. Use container_context"):
Expand Down
52 changes: 49 additions & 3 deletions tests/test_main_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from tests import container
from tests.container import DIContainer
from that_depends import inject, providers
from that_depends import providers


async def test_factory_providers() -> None:
Expand All @@ -15,11 +15,29 @@ async def test_factory_providers() -> None:
async_resource = await DIContainer.async_resource()

assert dependent_factory.simple_factory is not simple_factory
assert DIContainer.simple_factory.sync_resolve() is not simple_factory
assert dependent_factory.sync_resource == sync_resource
assert dependent_factory.async_resource == async_resource
assert isinstance(async_factory, datetime.datetime)


async def test_async_resource_provider() -> None:
async_resource = await DIContainer.async_resource()

assert DIContainer.async_resource.sync_resolve() is async_resource


def test_failed_sync_resolve() -> None:
with pytest.raises(RuntimeError, match="AsyncFactory cannot be resolved synchronously"):
DIContainer.async_factory.sync_resolve()

with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved synchronously"):
DIContainer.async_resource.sync_resolve()

with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved synchronously"):
DIContainer.sequence.sync_resolve()


async def test_list_provider() -> None:
sequence = await DIContainer.sequence()
sync_resource = await DIContainer.sync_resource()
Expand All @@ -31,11 +49,14 @@ async def test_list_provider() -> None:
async def test_singleton_provider() -> None:
singleton1 = await DIContainer.singleton()
singleton2 = await DIContainer.singleton()
singleton3 = DIContainer.singleton.sync_resolve()
await DIContainer.singleton.tear_down()
singleton4 = DIContainer.singleton.sync_resolve()

assert singleton1 is singleton2
assert singleton1 is singleton2 is singleton3
assert singleton4 is not singleton1


@inject
async def test_providers_overriding() -> None:
async_resource_mock = datetime.datetime.fromisoformat("2023-01-01")
sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01")
Expand Down Expand Up @@ -64,6 +85,31 @@ async def test_providers_overriding() -> None:
assert (await container.DIContainer.async_resource()) != async_resource_mock


async def test_providers_overriding_sync_resolve() -> None:
async_resource_mock = datetime.datetime.fromisoformat("2023-01-01")
sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01")
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)
singleton_mock = container.SingletonFactory(dep1=False)
container.DIContainer.async_resource.override(async_resource_mock)
container.DIContainer.sync_resource.override(sync_resource_mock)
container.DIContainer.simple_factory.override(simple_factory_mock)
container.DIContainer.singleton.override(singleton_mock)

container.DIContainer.simple_factory.sync_resolve()
await container.DIContainer.async_resource.async_resolve()
dependent_factory = container.DIContainer.dependent_factory.sync_resolve()
singleton = container.DIContainer.singleton.sync_resolve()

assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1
assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2
assert dependent_factory.sync_resource == sync_resource_mock
assert dependent_factory.async_resource == async_resource_mock
assert singleton is singleton_mock

container.DIContainer.reset_override()
assert container.DIContainer.sync_resource.sync_resolve() != sync_resource_mock


def test_wrong_providers_init() -> None:
with pytest.raises(RuntimeError, match="Resource must be generator function"):
providers.Resource(lambda: None) # type: ignore[arg-type,return-value]
Expand Down
2 changes: 1 addition & 1 deletion that_depends/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,6 @@ async def resolve(cls, object_to_resolve: type[T] | typing.Callable[P, T]) -> T:
msg = f"Provider is not found, {field_name=}"
raise RuntimeError(msg)

kwargs[field_name] = await providers[field_name].resolve()
kwargs[field_name] = await providers[field_name].async_resolve()

return object_to_resolve(**kwargs)
10 changes: 7 additions & 3 deletions that_depends/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@ class AbstractProvider(typing.Generic[T], abc.ABC):
"""Abstract Provider Class."""

@abc.abstractmethod
async def resolve(self) -> T:
"""Resolve dependency."""
async def async_resolve(self) -> T:
"""Resolve dependency asynchronously."""

@abc.abstractmethod
def sync_resolve(self) -> T:
"""Resolve dependency synchronously."""

async def __call__(self) -> T:
return await self.resolve()
return await self.async_resolve()

def override(self, mock: object) -> None:
self._override = mock
Expand Down
9 changes: 6 additions & 3 deletions that_depends/providers/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ class List(AbstractProvider[T]):
def __init__(self, *providers: AbstractProvider[T]) -> None:
self._providers = providers

async def resolve(self) -> list[T]: # type: ignore[override]
return [await x.resolve() for x in self._providers]
async def async_resolve(self) -> list[T]: # type: ignore[override]
return [await x.async_resolve() for x in self._providers]

def sync_resolve(self) -> list[T]: # type: ignore[override]
return [x.sync_resolve() for x in self._providers]

async def __call__(self) -> list[T]: # type: ignore[override]
return await self.resolve()
return await self.async_resolve()
56 changes: 36 additions & 20 deletions that_depends/providers/context_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return await self.app(scope, receive, send)


def _get_context() -> dict[str, AbstractResource[typing.Any]]:
try:
return context.get()
except LookupError as exc:
msg = "Context is not set. Use container_context"
raise RuntimeError(msg) from exc


class ContextResource(AbstractProvider[T]):
def __init__(
self,
Expand All @@ -64,21 +72,25 @@ def __init__(
self._override = None
self._internal_name = f"{type(self).__name__}-{uuid.uuid4()}"

async def resolve(self) -> T:
def _get_or_create_resource(self) -> AbstractResource[T]:
context_obj = _get_context()
if not (resource := context_obj.get(self._internal_name)):
resource = Resource(self._creator, *self._args, **self._kwargs)
context_obj[self._internal_name] = resource

return resource

async def async_resolve(self) -> T:
if self._override:
return typing.cast(T, self._override)

try:
context_obj = context.get()
except LookupError as exc:
msg = "Context is not set. Use container_context"
raise RuntimeError(msg) from exc
return await self._get_or_create_resource().async_resolve()

if not (_resource := context_obj.get(self._internal_name)):
_resource = Resource(self._creator, *self._args, **self._kwargs)
context_obj[self._internal_name] = _resource
def sync_resolve(self) -> T:
if self._override:
return typing.cast(T, self._override)

return typing.cast(T, await _resource.resolve())
return self._get_or_create_resource().sync_resolve()


class AsyncContextResource(AbstractProvider[T]):
Expand All @@ -98,18 +110,22 @@ def __init__(
self._override = None
self._internal_name = f"{type(self).__name__}-{uuid.uuid4()}"

async def resolve(self) -> T:
def _get_or_create_resource(self) -> AbstractResource[T]:
context_obj = _get_context()
if not (resource := context_obj.get(self._internal_name)):
resource = AsyncResource(self._creator, *self._args, **self._kwargs)
context_obj[self._internal_name] = resource

return resource

async def async_resolve(self) -> T:
if self._override:
return typing.cast(T, self._override)

try:
context_obj = context.get()
except LookupError as exc:
msg = "Context is not set. Use container_context"
raise RuntimeError(msg) from exc
return await self._get_or_create_resource().async_resolve()

if not (_resource := context_obj.get(self._internal_name)):
_resource = AsyncResource(self._creator, *self._args, **self._kwargs)
context_obj[self._internal_name] = _resource
def sync_resolve(self) -> T:
if self._override:
return typing.cast(T, self._override)

return typing.cast(T, await _resource.resolve())
return self._get_or_create_resource().sync_resolve()
25 changes: 19 additions & 6 deletions that_depends/providers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,22 @@ def __init__(self, factory: type[T] | typing.Callable[P, T], *args: P.args, **kw
self._kwargs = kwargs
self._override = None

async def resolve(self) -> T:
async def async_resolve(self) -> T:
if self._override:
return typing.cast(T, self._override)

return self._factory(
*[await x() if isinstance(x, AbstractProvider) else x for x in self._args],
**{k: await v() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
*[await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args],
**{k: await v.async_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
)

def sync_resolve(self) -> T:
if self._override:
return typing.cast(T, self._override)

return self._factory(
*[x.sync_resolve() if isinstance(x, AbstractProvider) else x for x in self._args],
**{k: v.sync_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
)


Expand All @@ -31,11 +40,15 @@ def __init__(self, factory: typing.Callable[P, typing.Awaitable[T]], *args: P.ar
self._kwargs = kwargs
self._override = None

async def resolve(self) -> T:
async def async_resolve(self) -> T:
if self._override:
return typing.cast(T, self._override)

return await self._factory(
*[await x() if isinstance(x, AbstractProvider) else x for x in self._args],
**{k: await v() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
*[await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args],
**{k: await v.async_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
)

def sync_resolve(self) -> T:
msg = "AsyncFactory cannot be resolved synchronously"
raise RuntimeError(msg)
40 changes: 36 additions & 4 deletions that_depends/providers/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def tear_down(self) -> None:
if self._instance is not None:
self._instance = None

async def resolve(self) -> T:
async def async_resolve(self) -> T:
if self._override:
return typing.cast(T, self._override)

Expand All @@ -42,8 +42,30 @@ async def resolve(self) -> T:
T,
self._context_stack.enter_context(
contextlib.contextmanager(self._creator)(
*[await x() if isinstance(x, AbstractProvider) else x for x in self._args],
**{k: await v() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
*[await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args],
**{
k: await v.async_resolve() if isinstance(v, AbstractProvider) else v
for k, v in self._kwargs.items()
},
),
),
)
return self._instance

def sync_resolve(self) -> T:
if self._override:
return typing.cast(T, self._override)

if self._instance is None:
self._instance = typing.cast(
T,
self._context_stack.enter_context(
contextlib.contextmanager(self._creator)(
*[x.sync_resolve() if isinstance(x, AbstractProvider) else x for x in self._args],
**{
k: v.sync_resolve() if isinstance(v, AbstractProvider) else v
for k, v in self._kwargs.items()
},
),
),
)
Expand Down Expand Up @@ -74,7 +96,7 @@ async def tear_down(self) -> None:
if self._instance is not None:
self._instance = None

async def resolve(self) -> T:
async def async_resolve(self) -> T:
if self._override:
return typing.cast(T, self._override)

Expand All @@ -89,3 +111,13 @@ async def resolve(self) -> T:
),
)
return self._instance

def sync_resolve(self) -> T:
if self._override:
return typing.cast(T, self._override)

if self._instance is None:
msg = "AsyncResource cannot be resolved synchronously"
raise RuntimeError(msg)

return self._instance
20 changes: 17 additions & 3 deletions that_depends/providers/singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,28 @@ def __init__(self, factory: type[T] | typing.Callable[P, T], *args: P.args, **kw
self._override = None
self._instance: T | None = None

async def resolve(self) -> T:
async def async_resolve(self) -> T:
if self._override is not None:
return typing.cast(T, self._override)

if self._instance is None:
self._instance = self._factory(
*[await x() if isinstance(x, AbstractProvider) else x for x in self._args],
**{k: await v() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
*[await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args],
**{
k: await v.async_resolve() if isinstance(v, AbstractProvider) else v
for k, v in self._kwargs.items()
},
)
return self._instance

def sync_resolve(self) -> T:
if self._override is not None:
return typing.cast(T, self._override)

if self._instance is None:
self._instance = self._factory(
*[x.sync_resolve() if isinstance(x, AbstractProvider) else x for x in self._args],
**{k: v.sync_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
)
return self._instance

Expand Down

0 comments on commit db14150

Please sign in to comment.