44import decimal
55import sys
66from contextlib import contextmanager
7- from typing import Any
87
98from pytest import mark , raises
109
1110from dependency_injector import containers , errors , providers , resources
1211
12+
1313def init_fn (* args , ** kwargs ):
1414 return args , kwargs
1515
@@ -76,30 +76,27 @@ class Container(containers.DeclarativeContainer):
7676 assert _init .counter == 1
7777
7878
79- def test_injection_in_different_context ():
79+ @mark .asyncio
80+ async def test_injection_in_different_context ():
8081 def _init ():
8182 return object ()
8283
8384 async def _async_init ():
8485 return object ()
8586
86-
8787 class Container (containers .DeclarativeContainer ):
8888 context_local_resource = providers .ContextLocalResource (_init )
8989 async_context_local_resource = providers .ContextLocalResource (_async_init )
9090
91- loop = asyncio .get_event_loop ()
9291 container = Container ()
93- obj1 = loop . run_until_complete ( container .async_context_local_resource () )
94- obj2 = loop . run_until_complete ( container .async_context_local_resource () )
95- assert obj1 != obj2
92+ obj1 = await container .async_context_local_resource ()
93+ obj2 = await container .async_context_local_resource ()
94+ assert obj1 != obj2
9695
9796 obj3 = container .context_local_resource ()
9897 obj4 = container .context_local_resource ()
9998
100- assert obj3 == obj4
101-
102-
99+ assert obj3 == obj4
103100
104101
105102def test_init_function ():
@@ -121,10 +118,10 @@ def _init():
121118 provider .shutdown ()
122119
123120
124- def test_init_generator ():
121+ def test_init_generator_in_one_context ():
125122 def _init ():
126123 _init .init_counter += 1
127- yield
124+ yield object ()
128125 _init .shutdown_counter += 1
129126
130127 _init .init_counter = 0
@@ -133,33 +130,31 @@ def _init():
133130 provider = providers .ContextLocalResource (_init )
134131
135132 result1 = provider ()
136- assert result1 is None
133+ result2 = provider ()
134+
135+ assert result1 == result2
136+
137137 assert _init .init_counter == 1
138138 assert _init .shutdown_counter == 0
139139
140140 provider .shutdown ()
141141 assert _init .init_counter == 1
142142 assert _init .shutdown_counter == 1
143143
144- result2 = provider ()
145- assert result2 is None
146- assert _init .init_counter == 2
147- assert _init .shutdown_counter == 1
148-
149144 provider .shutdown ()
150- assert _init .init_counter == 2
151- assert _init .shutdown_counter == 2
145+ assert _init .init_counter == 1
146+ assert _init .shutdown_counter == 1
152147
153148
154- def test_init_context_manager () -> None :
149+ def test_init_context_manager_in_one_context () -> None :
155150 init_counter , shutdown_counter = 0 , 0
156151
157152 @contextmanager
158153 def _init ():
159154 nonlocal init_counter , shutdown_counter
160155
161156 init_counter += 1
162- yield
157+ yield object ()
163158 shutdown_counter += 1
164159
165160 init_counter = 0
@@ -168,24 +163,77 @@ def _init():
168163 provider = providers .ContextLocalResource (_init )
169164
170165 result1 = provider ()
171- assert result1 is None
166+ result2 = provider ()
167+ assert result1 == result2
168+
172169 assert init_counter == 1
173170 assert shutdown_counter == 0
174171
175172 provider .shutdown ()
173+
176174 assert init_counter == 1
177175 assert shutdown_counter == 1
178176
179- result2 = provider ()
180- assert result2 is None
181- assert init_counter == 2
177+ provider .shutdown ()
178+ assert init_counter == 1
182179 assert shutdown_counter == 1
183180
184- provider .shutdown ()
181+
182+ @mark .asyncio
183+ async def test_async_init_context_manager_in_different_contexts () -> None :
184+ init_counter , shutdown_counter = 0 , 0
185+
186+ async def _init ():
187+ nonlocal init_counter , shutdown_counter
188+ init_counter += 1
189+ yield object ()
190+ shutdown_counter += 1
191+
192+ init_counter = 0
193+ shutdown_counter = 0
194+
195+ provider = providers .ContextLocalResource (_init )
196+
197+ async def run_in_context ():
198+ resource = await provider ()
199+ await provider .shutdown ()
200+ return resource
201+
202+ result1 , result2 = await asyncio .gather (run_in_context (), run_in_context ())
203+
204+ assert result1 != result2
185205 assert init_counter == 2
186206 assert shutdown_counter == 2
187207
188208
209+ @mark .asyncio
210+ async def test_async_init_context_manager_in_one_context () -> None :
211+ init_counter , shutdown_counter = 0 , 0
212+
213+ async def _init ():
214+ nonlocal init_counter , shutdown_counter
215+ init_counter += 1
216+ yield object ()
217+ shutdown_counter += 1
218+
219+ init_counter = 0
220+ shutdown_counter = 0
221+
222+ provider = providers .ContextLocalResource (_init )
223+
224+ async def run_in_context ():
225+ resource_1 = await provider ()
226+ resource_2 = await provider ()
227+ await provider .shutdown ()
228+ return resource_1 , resource_2
229+
230+ result1 , result2 = await run_in_context ()
231+
232+ assert result1 == result2
233+ assert init_counter == 1
234+ assert shutdown_counter == 1
235+
236+
189237def test_init_class ():
190238 class TestResource (resources .Resource ):
191239 init_counter = 0
@@ -218,7 +266,6 @@ def shutdown(self, _):
218266 assert TestResource .shutdown_counter == 2
219267
220268
221-
222269def test_init_not_callable ():
223270 provider = providers .ContextLocalResource (1 )
224271 with raises (TypeError , match = r"object is not callable" ):
0 commit comments