diff --git a/docs/providers/context_local_resource.rst b/docs/providers/context_local_resource.rst new file mode 100644 index 00000000..c169b6a1 --- /dev/null +++ b/docs/providers/context_local_resource.rst @@ -0,0 +1,32 @@ +.. _context-local-resource-provider: + +Context Local Resource provider +================================ + +.. meta:: + :keywords: Python,DI,Dependency injection,IoC,Inversion of Control,Resource,Context Local, + Context Variables,Singleton,Per-context + :description: Context Local Resource provider provides a component with initialization and shutdown + that is scoped to execution context using contextvars. This page demonstrates how to + use context local resource provider. + +.. currentmodule:: dependency_injector.providers + +``ContextLocalResource`` inherits from :ref:`resource-provider` and uses the same initialization and shutdown logic +as the standard ``Resource`` provider. +It extends it with context-local storage using Python's ``contextvars`` module. +This means that objects are context local singletons - the same context will +receive the same instance, but different execution contexts will have their own separate instances. + +This is particularly useful in asynchronous applications where you need per-request resource instances +(such as database sessions) that are automatically cleaned up when the request context ends. +Example: + +.. literalinclude:: ../../examples/providers/context_local_resource.py + :language: python + :lines: 3- + + + +.. disqus:: + diff --git a/docs/providers/index.rst b/docs/providers/index.rst index 3edbf127..0dacb826 100644 --- a/docs/providers/index.rst +++ b/docs/providers/index.rst @@ -46,6 +46,7 @@ Providers module API docs - :py:mod:`dependency_injector.providers` dict configuration resource + context_local_resource aggregate selector dependency diff --git a/docs/providers/resource.rst b/docs/providers/resource.rst index b07c2db0..02863a47 100644 --- a/docs/providers/resource.rst +++ b/docs/providers/resource.rst @@ -21,6 +21,9 @@ Resource provider Resource providers help to initialize and configure logging, event loop, thread or process pool, etc. Resource provider is similar to ``Singleton``. Resource initialization happens only once. +If you need a context local singleton (where each execution context has its own instance), +see :ref:`context-local-resource-provider`. + You can make injections and use provided instance the same way like you do with any other provider. .. code-block:: python diff --git a/examples/providers/context_local_resource.py b/examples/providers/context_local_resource.py new file mode 100644 index 00000000..87af2f9d --- /dev/null +++ b/examples/providers/context_local_resource.py @@ -0,0 +1,50 @@ +from uuid import uuid4 + +from fastapi import Depends, FastAPI + +from dependency_injector import containers, providers +from dependency_injector.wiring import Closing, Provide, inject + +global_list = [] + + +class AsyncSessionLocal: + def __init__(self): + self.id = uuid4() + + async def __aenter__(self): + print("Entering session !") + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + print("Closing session !") + + async def execute(self, user_input): + return f"Executing {user_input} in session {self.id}" + + +app = FastAPI() + + +class Container(containers.DeclarativeContainer): + db_session = providers.ContextLocalResource(AsyncSessionLocal) + + +@app.get("/") +@inject +async def index(db: AsyncSessionLocal = Depends(Closing[Provide["db_session"]])): + global global_list + if db.id in global_list: + raise Exception("The db session is already used") # never reaches here + global_list.append(db.id) + res = await db.execute("SELECT 1") + return str(res) + + +if __name__ == "__main__": + import uvicorn + + container = Container() + container.wire(modules=["__main__"]) + uvicorn.run(app, host="localhost", port=8000) + container.unwire() diff --git a/src/dependency_injector/providers.pxd b/src/dependency_injector/providers.pxd index 21ed7f22..beb35718 100644 --- a/src/dependency_injector/providers.pxd +++ b/src/dependency_injector/providers.pxd @@ -226,9 +226,9 @@ cdef class Dict(Provider): cdef class Resource(Provider): cdef object _provides - cdef bint _initialized - cdef object _shutdowner - cdef object _resource + cdef bint __initialized + cdef object __shutdowner + cdef object __resource cdef tuple _args cdef int _args_len @@ -239,6 +239,12 @@ cdef class Resource(Provider): cpdef object _provide(self, tuple args, dict kwargs) +cdef class ContextLocalResource(Resource): + cdef object _resource_context_var + cdef object _initialized_context_var + cdef object _shutdowner_context_var + + cdef class Container(Provider): cdef object _container_cls cdef dict _overriding_providers diff --git a/src/dependency_injector/providers.pyi b/src/dependency_injector/providers.pyi index 8f9b525a..d6168d64 100644 --- a/src/dependency_injector/providers.pyi +++ b/src/dependency_injector/providers.pyi @@ -525,6 +525,8 @@ class Resource(Provider[T]): def init(self) -> Optional[Awaitable[T]]: ... def shutdown(self) -> Optional[Awaitable]: ... +class ContextLocalResource(Resource[T]):... + class Container(Provider[T]): def __init__( self, diff --git a/src/dependency_injector/providers.pyx b/src/dependency_injector/providers.pyx index d8a8ab35..ac1a4804 100644 --- a/src/dependency_injector/providers.pyx +++ b/src/dependency_injector/providers.pyx @@ -3186,7 +3186,7 @@ cdef class ThreadLocalSingleton(BaseSingleton): return future_result self._storage.instance = instance - + return instance def _async_init_instance(self, future_result, result): @@ -3620,9 +3620,9 @@ cdef class Resource(Provider): self._provides = None self.set_provides(provides) - self._initialized = False - self._resource = None - self._shutdowner = None + self.__initialized = False + self.__resource = None + self.__shutdowner = None self._args = tuple() self._args_len = 0 @@ -3760,6 +3760,36 @@ cdef class Resource(Provider): self._kwargs_len = len(self._kwargs) return self + @property + def _initialized(self): + """Get initialized state.""" + return self.__initialized + + @_initialized.setter + def _initialized(self, value): + """Set initialized state.""" + self.__initialized = value + + @property + def _resource(self): + """Get resource.""" + return self.__resource + + @_resource.setter + def _resource(self, value): + """Set resource.""" + self.__resource = value + + @property + def _shutdowner(self): + """Get shutdowner.""" + return self.__shutdowner + + @_shutdowner.setter + def _shutdowner(self, value): + """Set shutdowner.""" + self.__shutdowner = value + @property def initialized(self): """Check if resource is initialized.""" @@ -3771,24 +3801,27 @@ cdef class Resource(Provider): def shutdown(self): """Shutdown resource.""" - if not self._initialized: + if not self._initialized : + self._reset_all_contex_vars() if self._async_mode == ASYNC_MODE_ENABLED: return NULL_AWAITABLE return if self._shutdowner: future = self._shutdowner(None, None, None) - if __is_future_or_coroutine(future): - return ensure_future(self._shutdown_async(future)) - - self._resource = None - self._initialized = False - self._shutdowner = None + self._reset_all_contex_vars() + return ensure_future(future) + self._reset_all_contex_vars() if self._async_mode == ASYNC_MODE_ENABLED: return NULL_AWAITABLE + def _reset_all_contex_vars(self): + self._initialized = False + self._resource = None + self._shutdowner = None + @property def related(self): """Return related providers generator.""" @@ -3797,41 +3830,28 @@ cdef class Resource(Provider): yield from filter(is_provider, self.kwargs.values()) yield from super().related - async def _shutdown_async(self, future) -> None: - try: - await future - finally: - self._resource = None - self._initialized = False - self._shutdowner = None - async def _handle_async_cm(self, obj) -> None: try: - self._resource = resource = await obj.__aenter__() - self._shutdowner = obj.__aexit__ + resource = await obj.__aenter__() return resource except: self._initialized = False raise - async def _provide_async(self, future) -> None: - try: - obj = await future - - if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): - self._resource = await obj.__aenter__() - self._shutdowner = obj.__aexit__ - elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): - self._resource = obj.__enter__() - self._shutdowner = obj.__exit__ - else: - self._resource = obj - self._shutdowner = None + async def _provide_async(self, future): + obj = await future - return self._resource - except: - self._initialized = False - raise + if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): + resource = await obj.__aenter__() + shutdowner = obj.__aexit__ + elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): + resource = obj.__enter__() + shutdowner = obj.__exit__ + else: + resource = obj + shutdowner = None + + return resource, shutdowner cpdef object _provide(self, tuple args, dict kwargs): if self._initialized: @@ -3850,14 +3870,18 @@ cdef class Resource(Provider): if __is_future_or_coroutine(obj): self._initialized = True - self._resource = resource = ensure_future(self._provide_async(obj)) - return resource + future_result = asyncio.Future() + future = ensure_future(self._provide_async(obj)) + future.add_done_callback(functools.partial(self._async_init_instance, future_result)) + self._resource = future_result + return self._resource elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): self._resource = obj.__enter__() self._shutdowner = obj.__exit__ elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): self._initialized = True self._resource = resource = ensure_future(self._handle_async_cm(obj)) + self._shutdowner = obj.__aexit__ return resource else: self._resource = obj @@ -3866,6 +3890,57 @@ cdef class Resource(Provider): self._initialized = True return self._resource + def _async_init_instance(self, future_result, result): + try: + resource, shutdowner = result.result() + except Exception as exception: + self._resource = None + self._shutdowner = None + self._initialized = False + future_result.set_exception(exception) + else: + self._resource = resource + self._shutdowner = shutdowner + future_result.set_result(resource) + + +cdef class ContextLocalResource(Resource): + def __init__(self, provides=None, *args, **kwargs): + self._initialized_context_var = ContextVar("_initialized_context_var", default=False) + self._resource_context_var = ContextVar("_resource_context_var", default=None) + self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=None) + super().__init__(provides, *args, **kwargs) + + @property + def _initialized(self): + """Get initialized state.""" + return self._initialized_context_var.get() + + @_initialized.setter + def _initialized(self, value): + """Set initialized state.""" + self._initialized_context_var.set(value) + + @property + def _resource(self): + """Get resource.""" + return self._resource_context_var.get() + + @_resource.setter + def _resource(self, value): + """Set resource.""" + self._resource_context_var.set(value) + + @property + def _shutdowner(self): + """Get shutdowner.""" + return self._shutdowner_context_var.get() + + @_shutdowner.setter + def _shutdowner(self, value): + """Set shutdowner.""" + self._shutdowner_context_var.set(value) + cdef class Container(Provider): """Container provider provides an instance of declarative container. diff --git a/tests/unit/providers/resource/test_context_local_resource_py38.py b/tests/unit/providers/resource/test_context_local_resource_py38.py new file mode 100644 index 00000000..3a0452b9 --- /dev/null +++ b/tests/unit/providers/resource/test_context_local_resource_py38.py @@ -0,0 +1,492 @@ +"""Resource provider tests.""" + +import asyncio +import decimal +import sys +from contextlib import contextmanager + +from pytest import mark, raises + +from dependency_injector import containers, errors, providers, resources + + +def init_fn(*args, **kwargs): + return args, kwargs + + +def test_is_provider(): + assert providers.is_provider(providers.ContextLocalResource(init_fn)) is True + + +def test_init_optional_provides(): + provider = providers.ContextLocalResource() + provider.set_provides(init_fn) + assert provider.provides is init_fn + assert provider() == (tuple(), dict()) + + +def test_set_provides_returns_(): + provider = providers.ContextLocalResource() + assert provider.set_provides(init_fn) is provider + + +@mark.parametrize( + "str_name,cls", + [ + ("dependency_injector.providers.Factory", providers.Factory), + ("decimal.Decimal", decimal.Decimal), + ("list", list), + (".test_context_local_resource_py38.test_is_provider", test_is_provider), + ("test_is_provider", test_is_provider), + ], +) +def test_set_provides_string_imports(str_name, cls): + assert providers.ContextLocalResource(str_name).provides is cls + + +def test_provided_instance_provider(): + provider = providers.ContextLocalResource(init_fn) + assert isinstance(provider.provided, providers.ProvidedInstance) + + +def test_injection(): + resource = object() + + def _init(): + _init.counter += 1 + return resource + + _init.counter = 0 + + class Container(containers.DeclarativeContainer): + context_local_resource = providers.ContextLocalResource(_init) + dependency1 = providers.List(context_local_resource) + dependency2 = providers.List(context_local_resource) + + container = Container() + list1 = container.dependency1() + list2 = container.dependency2() + + assert list1 == [resource] + assert list1[0] is resource + + assert list2 == [resource] + assert list2[0] is resource + + assert _init.counter == 1 + + +@mark.asyncio +async def test_injection_in_different_context(): + def _init(): + return object() + + async def _async_init(): + return object() + + class Container(containers.DeclarativeContainer): + context_local_resource = providers.ContextLocalResource(_init) + async_context_local_resource = providers.ContextLocalResource(_async_init) + + async def run_in_context(): + obj = await container.async_context_local_resource() + return obj + + container = Container() + + obj1, obj2 = await asyncio.gather(run_in_context(), run_in_context()) + assert obj1 != obj2 + + obj3 = await container.async_context_local_resource() + obj4 = await container.async_context_local_resource() + assert obj3 == obj4 + + obj5, obj6 = await asyncio.gather(run_in_context(), run_in_context()) + assert obj5 == obj6 # as context is copied from the current one where async_context_local_resource was initialized + + obj7 = container.context_local_resource() + obj8 = container.context_local_resource() + + assert obj7 == obj8 + + +def test_init_function(): + def _init(): + _init.counter += 1 + + _init.counter = 0 + + provider = providers.ContextLocalResource(_init) + + result1 = provider() + assert result1 is None + assert _init.counter == 1 + + result2 = provider() + assert result2 is None + assert _init.counter == 1 + + provider.shutdown() + + +def test_init_generator_in_one_context(): + def _init(): + _init.init_counter += 1 + yield object() + _init.shutdown_counter += 1 + + _init.init_counter = 0 + _init.shutdown_counter = 0 + + provider = providers.ContextLocalResource(_init) + + result1 = provider() + result2 = provider() + + assert result1 == result2 + + assert _init.init_counter == 1 + assert _init.shutdown_counter == 0 + + provider.shutdown() + assert _init.init_counter == 1 + assert _init.shutdown_counter == 1 + + provider.shutdown() + assert _init.init_counter == 1 + assert _init.shutdown_counter == 1 + + +def test_init_context_manager_in_one_context() -> None: + init_counter, shutdown_counter = 0, 0 + + @contextmanager + def _init(): + nonlocal init_counter, shutdown_counter + + init_counter += 1 + yield object() + shutdown_counter += 1 + + init_counter = 0 + shutdown_counter = 0 + + provider = providers.ContextLocalResource(_init) + + result1 = provider() + result2 = provider() + assert result1 == result2 + + assert init_counter == 1 + assert shutdown_counter == 0 + + provider.shutdown() + + assert init_counter == 1 + assert shutdown_counter == 1 + + provider.shutdown() + assert init_counter == 1 + assert shutdown_counter == 1 + + +@mark.asyncio +async def test_async_init_context_manager_in_different_contexts() -> None: + init_counter, shutdown_counter = 0, 0 + + async def _init(): + nonlocal init_counter, shutdown_counter + init_counter += 1 + yield object() + shutdown_counter += 1 + + init_counter = 0 + shutdown_counter = 0 + + provider = providers.ContextLocalResource(_init) + + async def run_in_context(): + resource = await provider() + await provider.shutdown() + return resource + + result1, result2 = await asyncio.gather(run_in_context(), run_in_context()) + + assert result1 != result2 + assert init_counter == 2 + assert shutdown_counter == 2 + + +@mark.asyncio +async def test_async_init_context_manager_in_one_context() -> None: + init_counter, shutdown_counter = 0, 0 + + async def _init(): + nonlocal init_counter, shutdown_counter + init_counter += 1 + yield object() + shutdown_counter += 1 + + init_counter = 0 + shutdown_counter = 0 + + provider = providers.ContextLocalResource(_init) + + async def run_in_context(): + resource_1 = await provider() + resource_2 = await provider() + await provider.shutdown() + return resource_1, resource_2 + + result1, result2 = await run_in_context() + + assert result1 == result2 + assert init_counter == 1 + assert shutdown_counter == 1 + + +def test_init_class(): + class TestResource(resources.Resource): + init_counter = 0 + shutdown_counter = 0 + + def init(self): + self.__class__.init_counter += 1 + + def shutdown(self, _): + self.__class__.shutdown_counter += 1 + + provider = providers.ContextLocalResource(TestResource) + + result1 = provider() + assert result1 is None + assert TestResource.init_counter == 1 + assert TestResource.shutdown_counter == 0 + + provider.shutdown() + assert TestResource.init_counter == 1 + assert TestResource.shutdown_counter == 1 + + result2 = provider() + assert result2 is None + assert TestResource.init_counter == 2 + assert TestResource.shutdown_counter == 1 + + provider.shutdown() + assert TestResource.init_counter == 2 + assert TestResource.shutdown_counter == 2 + + +def test_init_not_callable(): + provider = providers.ContextLocalResource(1) + with raises(TypeError, match=r"object is not callable"): + provider.init() + + +def test_init_and_shutdown(): + def _init(): + _init.init_counter += 1 + yield + _init.shutdown_counter += 1 + + _init.init_counter = 0 + _init.shutdown_counter = 0 + + provider = providers.ContextLocalResource(_init) + + result1 = provider.init() + assert result1 is None + assert _init.init_counter == 1 + assert _init.shutdown_counter == 0 + + provider.shutdown() + assert _init.init_counter == 1 + assert _init.shutdown_counter == 1 + + result2 = provider.init() + assert result2 is None + assert _init.init_counter == 2 + assert _init.shutdown_counter == 1 + + provider.shutdown() + assert _init.init_counter == 2 + assert _init.shutdown_counter == 2 + + +def test_shutdown_of_not_initialized(): + def _init(): + yield + + provider = providers.ContextLocalResource(_init) + + result = provider.shutdown() + assert result is None + + +def test_initialized(): + provider = providers.ContextLocalResource(init_fn) + assert provider.initialized is False + + provider.init() + assert provider.initialized is True + + provider.shutdown() + assert provider.initialized is False + + +def test_call_with_context_args(): + provider = providers.ContextLocalResource(init_fn, "i1", "i2") + assert provider("i3", i4=4) == (("i1", "i2", "i3"), {"i4": 4}) + + +def test_fluent_interface(): + provider = providers.ContextLocalResource(init_fn).add_args(1, 2).add_kwargs(a3=3, a4=4) + assert provider() == ((1, 2), {"a3": 3, "a4": 4}) + + +def test_set_args(): + provider = providers.ContextLocalResource(init_fn).add_args(1, 2).set_args(3, 4) + assert provider.args == (3, 4) + + +def test_clear_args(): + provider = providers.ContextLocalResource(init_fn).add_args(1, 2).clear_args() + assert provider.args == tuple() + + +def test_set_kwargs(): + provider = providers.ContextLocalResource(init_fn).add_kwargs(a1="i1", a2="i2").set_kwargs(a3="i3", a4="i4") + assert provider.kwargs == {"a3": "i3", "a4": "i4"} + + +def test_clear_kwargs(): + provider = providers.ContextLocalResource(init_fn).add_kwargs(a1="i1", a2="i2").clear_kwargs() + assert provider.kwargs == {} + + +def test_call_overridden(): + provider = providers.ContextLocalResource(init_fn, 1) + overriding_provider1 = providers.ContextLocalResource(init_fn, 2) + overriding_provider2 = providers.ContextLocalResource(init_fn, 3) + + provider.override(overriding_provider1) + provider.override(overriding_provider2) + + instance1 = provider() + instance2 = provider() + + assert instance1 is instance2 + assert instance1 == ((3,), {}) + assert instance2 == ((3,), {}) + + +def test_deepcopy(): + provider = providers.ContextLocalResource(init_fn, 1, 2, a3=3, a4=4) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert provider.args == provider_copy.args + assert provider.kwargs == provider_copy.kwargs + assert isinstance(provider, providers.ContextLocalResource) + + +def test_deepcopy_initialized(): + provider = providers.ContextLocalResource(init_fn) + provider.init() + + with raises(errors.Error): + providers.deepcopy(provider) + + +def test_deepcopy_from_memo(): + provider = providers.ContextLocalResource(init_fn) + provider_copy_memo = providers.ContextLocalResource(init_fn) + + provider_copy = providers.deepcopy( + provider, + memo={id(provider): provider_copy_memo}, + ) + + assert provider_copy is provider_copy_memo + + +def test_deepcopy_args(): + provider = providers.ContextLocalResource(init_fn) + dependent_provider1 = providers.Factory(list) + dependent_provider2 = providers.Factory(dict) + + provider.add_args(dependent_provider1, dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.args[0] + dependent_provider_copy2 = provider_copy.args[1] + + assert provider.args != provider_copy.args + + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_kwargs(): + provider = providers.ContextLocalResource(init_fn) + dependent_provider1 = providers.Factory(list) + dependent_provider2 = providers.Factory(dict) + + provider.add_kwargs(d1=dependent_provider1, d2=dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.kwargs["d1"] + dependent_provider_copy2 = provider_copy.kwargs["d2"] + + assert provider.kwargs != provider_copy.kwargs + + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_overridden(): + provider = providers.ContextLocalResource(init_fn) + object_provider = providers.Object(object()) + + provider.override(object_provider) + + provider_copy = providers.deepcopy(provider) + object_provider_copy = provider_copy.overridden[0] + + assert provider is not provider_copy + assert provider.args == provider_copy.args + assert isinstance(provider, providers.ContextLocalResource) + + assert object_provider is not object_provider_copy + assert isinstance(object_provider_copy, providers.Object) + + +def test_deepcopy_with_sys_streams(): + provider = providers.ContextLocalResource(init_fn) + provider.add_args(sys.stdin, sys.stdout, sys.stderr) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert isinstance(provider_copy, providers.ContextLocalResource) + assert provider.args[0] is sys.stdin + assert provider.args[1] is sys.stdout + assert provider.args[2] is sys.stderr + + +def test_repr(): + provider = providers.ContextLocalResource(init_fn) + + assert repr(provider) == ( + "".format( + repr(init_fn), + hex(id(provider)), + ) + ) diff --git a/tests/unit/samples/wiring/asyncinjections.py b/tests/unit/samples/wiring/asyncinjections.py index e0861017..befd59b0 100644 --- a/tests/unit/samples/wiring/asyncinjections.py +++ b/tests/unit/samples/wiring/asyncinjections.py @@ -18,6 +18,7 @@ def reset_counters(self): resource1 = TestResource() resource2 = TestResource() +resource3 = TestResource() async def async_resource(resource): @@ -34,6 +35,8 @@ class Container(containers.DeclarativeContainer): resource1 = providers.Resource(async_resource, providers.Object(resource1)) resource2 = providers.Resource(async_resource, providers.Object(resource2)) + context_local_resource = providers.ContextLocalResource(async_resource, providers.Object(resource3)) + context_local_resource_with_factory_object = providers.ContextLocalResource(async_resource, providers.Factory(TestResource)) @inject @@ -57,5 +60,13 @@ async def async_generator_injection( async def async_injection_with_closing( resource1: object = Closing[Provide[Container.resource1]], resource2: object = Closing[Provide[Container.resource2]], + context_local_resource: object = Closing[Provide[Container.context_local_resource]], ): - return resource1, resource2 + return resource1, resource2, context_local_resource + + +@inject +async def async_injection_with_closing_context_local_resources( + context_local_resource1: object = Closing[Provide[Container.context_local_resource_with_factory_object]], +): + return context_local_resource1 diff --git a/tests/unit/samples/wiringstringids/asyncinjections.py b/tests/unit/samples/wiringstringids/asyncinjections.py index 41529379..514b455a 100644 --- a/tests/unit/samples/wiringstringids/asyncinjections.py +++ b/tests/unit/samples/wiringstringids/asyncinjections.py @@ -16,6 +16,7 @@ def reset_counters(self): resource1 = TestResource() resource2 = TestResource() +resource3 = TestResource() async def async_resource(resource): @@ -32,6 +33,8 @@ class Container(containers.DeclarativeContainer): resource1 = providers.Resource(async_resource, providers.Object(resource1)) resource2 = providers.Resource(async_resource, providers.Object(resource2)) + context_local_resource = providers.ContextLocalResource(async_resource, providers.Object(resource3)) + context_local_resource_with_factory_object = providers.ContextLocalResource(async_resource, providers.Factory(TestResource)) @inject @@ -46,5 +49,13 @@ async def async_injection( async def async_injection_with_closing( resource1: object = Closing[Provide["resource1"]], resource2: object = Closing[Provide["resource2"]], + context_local_resource: object = Closing[Provide["context_local_resource"]], ): - return resource1, resource2 + return resource1, resource2, context_local_resource + + +@inject +async def async_injection_with_closing_context_local_resources( + context_local_resource1: object = Closing[Provide["context_local_resource_with_factory_object"]] +): + return context_local_resource1 diff --git a/tests/unit/wiring/provider_ids/test_async_injections_py36.py b/tests/unit/wiring/provider_ids/test_async_injections_py36.py index 70f9eb17..4c5ec12f 100644 --- a/tests/unit/wiring/provider_ids/test_async_injections_py36.py +++ b/tests/unit/wiring/provider_ids/test_async_injections_py36.py @@ -1,7 +1,8 @@ """Async injection tests.""" -from pytest import fixture, mark +import asyncio +from pytest import fixture, mark from samples.wiring import asyncinjections @@ -51,7 +52,7 @@ async def test_async_generator_injections() -> None: @mark.asyncio async def test_async_injections_with_closing(): - resource1, resource2 = await asyncinjections.async_injection_with_closing() + resource1, resource2, context_local_resource = await asyncinjections.async_injection_with_closing() assert resource1 is asyncinjections.resource1 assert asyncinjections.resource1.init_counter == 1 @@ -61,7 +62,11 @@ async def test_async_injections_with_closing(): assert asyncinjections.resource2.init_counter == 1 assert asyncinjections.resource2.shutdown_counter == 1 - resource1, resource2 = await asyncinjections.async_injection_with_closing() + assert context_local_resource is asyncinjections.resource3 + assert asyncinjections.resource3.init_counter == 1 + assert asyncinjections.resource3.shutdown_counter == 1 + + resource1, resource2, context_local_resource = await asyncinjections.async_injection_with_closing() assert resource1 is asyncinjections.resource1 assert asyncinjections.resource1.init_counter == 2 @@ -70,3 +75,19 @@ async def test_async_injections_with_closing(): assert resource2 is asyncinjections.resource2 assert asyncinjections.resource2.init_counter == 2 assert asyncinjections.resource2.shutdown_counter == 2 + + assert context_local_resource is asyncinjections.resource3 + assert asyncinjections.resource3.init_counter == 2 + assert asyncinjections.resource3.shutdown_counter == 2 + + +@mark.asyncio +async def test_async_injections_with_closing_concurrently(): + resource1, resource2 = await asyncio.gather(asyncinjections.async_injection_with_closing_context_local_resources(), + asyncinjections.async_injection_with_closing_context_local_resources()) + assert resource1 != resource2 + + resource1 = await asyncinjections.Container.context_local_resource_with_factory_object() + resource2 = await asyncinjections.Container.context_local_resource_with_factory_object() + + assert resource1 == resource2 diff --git a/tests/unit/wiring/string_ids/test_async_injections_py36.py b/tests/unit/wiring/string_ids/test_async_injections_py36.py index cff13ce5..bdf6a2ab 100644 --- a/tests/unit/wiring/string_ids/test_async_injections_py36.py +++ b/tests/unit/wiring/string_ids/test_async_injections_py36.py @@ -1,7 +1,8 @@ """Async injection tests.""" -from pytest import fixture, mark +import asyncio +from pytest import fixture, mark from samples.wiringstringids import asyncinjections @@ -34,7 +35,7 @@ async def test_async_injections(): @mark.asyncio async def test_async_injections_with_closing(): - resource1, resource2 = await asyncinjections.async_injection_with_closing() + resource1, resource2, context_local_resource = await asyncinjections.async_injection_with_closing() assert resource1 is asyncinjections.resource1 assert asyncinjections.resource1.init_counter == 1 @@ -44,7 +45,11 @@ async def test_async_injections_with_closing(): assert asyncinjections.resource2.init_counter == 1 assert asyncinjections.resource2.shutdown_counter == 1 - resource1, resource2 = await asyncinjections.async_injection_with_closing() + assert context_local_resource is asyncinjections.resource3 + assert asyncinjections.resource3.init_counter == 1 + assert asyncinjections.resource3.shutdown_counter == 1 + + resource1, resource2, context_local_resource = await asyncinjections.async_injection_with_closing() assert resource1 is asyncinjections.resource1 assert asyncinjections.resource1.init_counter == 2 @@ -53,3 +58,19 @@ async def test_async_injections_with_closing(): assert resource2 is asyncinjections.resource2 assert asyncinjections.resource2.init_counter == 2 assert asyncinjections.resource2.shutdown_counter == 2 + + assert context_local_resource is asyncinjections.resource3 + assert asyncinjections.resource3.init_counter == 2 + assert asyncinjections.resource3.shutdown_counter == 2 + + +@mark.asyncio +async def test_async_injections_with_closing_concurrently(): + resource1, resource2 = await asyncio.gather(asyncinjections.async_injection_with_closing_context_local_resources(), + asyncinjections.async_injection_with_closing_context_local_resources()) + assert resource1 != resource2 + + resource1 = await asyncinjections.Container.context_local_resource_with_factory_object() + resource2 = await asyncinjections.Container.context_local_resource_with_factory_object() + + assert resource1 == resource2