diff --git a/injection/contrib/pep690.py b/injection/contrib/pep690.py index de9be2c..68c0c33 100644 --- a/injection/contrib/pep690.py +++ b/injection/contrib/pep690.py @@ -3,30 +3,40 @@ from __future__ import annotations import sys -import types -from collections.abc import Callable -from contextlib import suppress +from collections.abc import Generator +from contextlib import contextmanager, suppress from contextvars import ContextVar +from copy import copy from dataclasses import dataclass from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Generic, Literal, TypedDict, overload +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload + +from injection.main import peek_or_inject if TYPE_CHECKING: - from _typeshed.importlib import MetaPathFinderProtocol, PathEntryFinderProtocol - from typing_extensions import Never, TypeVar + from typing_extensions import TypeAlias from injection.main import Injection -T = TypeVar("T", default=None) +T = TypeVar("T") Obj = TypeVar("Obj") +InjectedAttributeStash: TypeAlias = "dict[Injection[Obj], T]" -class SysActions(Enum): +class StateActionType(Enum): PERSIST = auto() + """Copy state visible now and expose it to the original thread on future request.""" + FUTURE = auto() + """ + Allow the state to evolve naturally at runtime. + + Rely on that future version of the state when it's requested. + """ + CONSTANT = auto() - SPECIFIED = auto() + """Define one state forever (like PERSIST, but with custom value).""" class StateAction(Generic[T]): @@ -35,61 +45,75 @@ class StateAction(Generic[T]): @overload def __init__( self, - action: Literal[SysActions.PERSIST, SysActions.FUTURE], + action_type: Literal[StateActionType.PERSIST, StateActionType.FUTURE], data: None = None, ) -> None: ... @overload def __init__( self, - action: Literal[SysActions.CONSTANT], + action_type: Literal[StateActionType.CONSTANT], data: T, ) -> None: ... - def __init__(self, action: SysActions, data: T | None = None) -> None: - self.action = action + def __init__( + self, + action_type: StateActionType, + data: T | None = None, + ) -> None: + self.action_type = action_type self.data = data -PERSIST: StateAction = StateAction(SysActions.PERSIST) -FUTURE: StateAction = StateAction(SysActions.FUTURE) +PERSIST: StateAction[None] = StateAction(StateActionType.PERSIST) +FUTURE: StateAction[None] = StateAction(StateActionType.FUTURE) injection_var: ContextVar[Injection[Any]] = ContextVar("injection") -class AttributeMappings(TypedDict, Generic[Obj]): - path: dict[Injection[Obj], list[str]] - path_hooks: dict[Injection[Obj], list[Callable[[str], PathEntryFinderProtocol]]] - meta_path: dict[Injection[Obj], list[MetaPathFinderProtocol]] - - @dataclass -class _LazyImportsSys(types.ModuleType, Generic[Obj]): - attribute_mappings: AttributeMappings[Obj] +class SysAttributeGetter: + attribute_name: str + mainstream_value: Any + stash: InjectedAttributeStash[Any, Any] - def __getattr__(self, name: str) -> Any: + def __call__(self) -> Any: with suppress(LookupError): injection = injection_var.get() - mapping = self.attribute_mappings[name] # type: ignore[literal-required] + mapping = self.stash[self.attribute_name] # type: ignore[literal-required] return mapping[injection] - return getattr(sys, name) - - -@dataclass -class LazyImportBuiltin: - def __call__(self, *args: Any, **kwds: Any) -> Any: - pass + return self.mainstream_value +@contextmanager def lazy_imports( *, - sys_path: StateAction = PERSIST, - sys_meta_path: StateAction = PERSIST, - sys_path_hooks: StateAction = PERSIST, -) -> None: - pass - - -def type_imports() -> Never: - raise NotImplementedError + sys_path: StateAction[Any] = PERSIST, + sys_meta_path: StateAction[Any] = PERSIST, + sys_path_hooks: StateAction[Any] = PERSIST, +) -> Generator[None]: + stash = {} + + for attribute_name, action in ( + ("path", sys_path), + ("meta_path", sys_meta_path), + ("path_hooks", sys_path_hooks), + ): + mainstream_value = getattr(sys, attribute_name) + if action.action_type is StateActionType.PERSIST: + action.data = copy(mainstream_value) + action.action_type = StateActionType.CONSTANT + + peek_or_inject( + vars(sys), + attribute_name, + SysAttributeGetter( + attribute_name=attribute_name, + mainstream_value=mainstream_value, + stash=stash, + ), + ) + vars(sys)[attribute_name] + + yield diff --git a/injection/main.py b/injection/main.py index ac8b67d..6314135 100644 --- a/injection/main.py +++ b/injection/main.py @@ -1,9 +1,20 @@ from __future__ import annotations from contextlib import suppress +from contextvars import ContextVar, copy_context from dataclasses import dataclass from threading import Lock, RLock, get_ident -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, overload +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + Literal, + NamedTuple, + TypeVar, + overload, +) +from weakref import WeakSet from injection.compat import get_frame @@ -28,6 +39,10 @@ Object_co = TypeVar("Object_co", covariant=True) +PEEK_MUTEX = RLock() +peeking_var: ContextVar[bool] = ContextVar("peeking", default=False) +peeked_early_var: ContextVar[EarlyObject] = ContextVar("peeked_early") + class InjectionKey(str): __slots__ = ("origin", "hash", "reset", "early") @@ -49,11 +64,19 @@ def __eq__(self, other: object) -> bool: self.reset = False return True - caller_locals = get_frame(1).f_locals + try: + caller_locals = get_frame(1).f_locals + except ValueError: + # can happen if we patch sys + return True if caller_locals.get("__injection_recursive_guard__"): return True + if peeking_var.get(): + peeked_early_var.set(self.early) + return True + with self.early.__mutex__: __injection_recursive_guard__ = True # noqa: F841 self.early.__inject__() @@ -73,9 +96,19 @@ def strict_recursion_guard(early: EarlyObject[object]) -> Never: raise RecursionError(msg) +class InjectionFactoryWrapper(NamedTuple): + actual_factory: Any + pass_scope: bool + + def __call__(self, scope: Locals) -> Any: + if self.pass_scope: + return self.actual_factory(scope) + return self.actual_factory() + + @dataclass class Injection(Generic[Object_co]): - factory: Callable[..., Object_co] + actual_factory: Callable[..., Object_co] pass_scope: bool = False cache: bool = False cache_per_alias: bool = False @@ -84,32 +117,34 @@ class Injection(Generic[Object_co]): _reassignment_lock: ClassVar[Lock] = Lock() - def _call_factory(self, scope: Locals) -> Object_co: - if self.pass_scope: - return self.factory(scope) - return self.factory() + @property + def factory(self) -> InjectionFactoryWrapper: + return InjectionFactoryWrapper( + actual_factory=self.actual_factory, + pass_scope=self.pass_scope, + ) def __post_init__(self) -> None: if self.debug_info is None: - factory, cache, cache_per_alias = ( - self.factory, + actual_factory, cache, cache_per_alias = ( + self.actual_factory, self.cache, self.cache_per_alias, ) - init_opts = f"{factory=!r}, {cache=!r}, {cache_per_alias=!r}" + init_opts = f"{actual_factory=!r}, {cache=!r}, {cache_per_alias=!r}" include = "" if debug_info := self.debug_info: include = f", {debug_info}" self.debug_info = f"" - def assign_to(self, *aliases: str, scope: Locals) -> None: + def assign_to(self, *aliases: str, scope: Locals) -> WeakSet[EarlyObject]: if not aliases: msg = f"expected at least one alias in Injection.assign_to() ({self!r})" raise ValueError(msg) state = ObjectState( cache=self.cache, - factory=self._call_factory, + factory=self.factory, recursion_guard=self.recursion_guard, debug_info=self.debug_info, scope=scope, @@ -117,19 +152,24 @@ def assign_to(self, *aliases: str, scope: Locals) -> None: cache_per_alias = self.cache_per_alias + early_objects = WeakSet() + for alias in aliases: debug_info = f"{alias!r} from {self.debug_info}" - early = EarlyObject( + early_object = EarlyObject( alias=alias, state=state, cache_per_alias=cache_per_alias, debug_info=debug_info, ) - key = early.__key__ + early_objects.add(early_object) + key = early_object.__key__ with self._reassignment_lock: scope.pop(key, None) - scope[key] = early + scope[key] = early_object + + return early_objects SENTINEL = object() @@ -286,7 +326,7 @@ def inject( # noqa: PLR0913 """ inj = Injection( - factory=factory, + actual_factory=factory, pass_scope=pass_scope, cache_per_alias=cache_per_alias, cache=cache, @@ -295,3 +335,49 @@ def inject( # noqa: PLR0913 ) if into is not None and aliases: inj.assign_to(*aliases, scope=into) + + +def peek(scope: Locals, alias: str) -> EarlyObject | None: + """Safely get early object from a scope without triggering injection behavior.""" + peeking_context = copy_context() + peeking_context.run(peeking_var.set, True) # noqa: FBT003 + with suppress(KeyError): + peeking_context.run(scope.__getitem__, alias) + return peeking_context.get(peeked_early_var) + + +def peek_or_inject( # noqa: PLR0913 + scope: Locals, + alias: str, + *, + factory: Callable[[], Object_co] | Callable[[Locals], Object_co], + pass_scope: bool = False, + cache: bool = False, + cache_per_alias: bool = False, + recursion_guard: Callable[[EarlyObject[Any]], object] = strict_recursion_guard, + debug_info: str | None = None, +) -> EarlyObject: + """ + Peek or inject as necessary in a thread-safe manner. + + If an injection is present, return the existing early object. + If it is not present, create a new injection, inject it and return an early object. + + This function works only for one alias at a time. + """ + with PEEK_MUTEX: + metadata = peek(scope, alias) + if metadata is None: + return next( + iter( + Injection( + actual_factory=factory, + pass_scope=pass_scope, + cache=cache, + cache_per_alias=cache_per_alias, + recursion_guard=recursion_guard, + debug_info=debug_info, + ).assign_to(alias, scope=scope) + ) + ) + return metadata diff --git a/tests/unit_tests/test_main.py b/tests/unit_tests/test_main.py index 51ce1f5..7554a05 100644 --- a/tests/unit_tests/test_main.py +++ b/tests/unit_tests/test_main.py @@ -60,7 +60,7 @@ def factory() -> str: factory_called = True return "injected_object" - inj = Injection(factory=factory) + inj = Injection(actual_factory=factory) inj.assign_to("alias1", "alias2", scope=scope) obj1 = scope["alias1"] @@ -82,7 +82,7 @@ def factory() -> str: call_count += 1 return f"injected_object_{call_count}" - inj = Injection(factory=factory, cache=False, cache_per_alias=False) + inj = Injection(actual_factory=factory, cache=False, cache_per_alias=False) scope1: dict[str, str] = {} scope2: dict[str, str] = {} @@ -172,7 +172,7 @@ def factory() -> str: call_count += 1 return f"injected_object_{call_count}" - inj = Injection(factory=factory) + inj = Injection(actual_factory=factory) scope: dict[str, str] = {} @@ -230,6 +230,6 @@ def test_injection_with_no_aliases() -> None: def factory() -> str: return "injected_object" - inj = Injection(factory=factory) + inj = Injection(actual_factory=factory) with pytest.raises(ValueError, match="expected at least one alias"): inj.assign_to(scope=scope)