Skip to content

Commit a5756b5

Browse files
fix shutdowner default none value, add more tests
1 parent 806bedb commit a5756b5

File tree

2 files changed

+77
-30
lines changed

2 files changed

+77
-30
lines changed

src/dependency_injector/providers.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3905,7 +3905,7 @@ cdef class ContextLocalResource(Resource):
39053905
if self._async_mode == ASYNC_MODE_ENABLED:
39063906
return NULL_AWAITABLE
39073907
return
3908-
if self._shutdowner_context_var.get():
3908+
if self._shutdowner_context_var.get() != self._none:
39093909
future = self._shutdowner_context_var.get()(None, None, None)
39103910
if __is_future_or_coroutine(future):
39113911
self._reset_all_contex_vars()
@@ -3977,7 +3977,7 @@ cdef class ContextLocalResource(Resource):
39773977
return resource
39783978
else:
39793979
self._resource_context_var.set(obj)
3980-
self._shutdowner_context_var.set(None)
3980+
self._shutdowner_context_var.set(self._none)
39813981

39823982
return self._resource_context_var.get()
39833983

tests/unit/providers/resource/test_context_local_resource_py38.py

Lines changed: 75 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import decimal
55
import sys
66
from contextlib import contextmanager
7-
from typing import Any
87

98
from pytest import mark, raises
109

1110
from dependency_injector import containers, errors, providers, resources
1211

12+
1313
def init_fn(*args, **kwargs):
1414
return args, kwargs
1515

@@ -76,30 +76,27 @@ class Container(containers.DeclarativeContainer):
7676
assert _init.counter == 1
7777

7878

79-
def test_injection_in_different_context():
79+
@mark.asyncio
80+
async def test_injection_in_different_context():
8081
def _init():
8182
return object()
8283

8384
async def _async_init():
8485
return object()
8586

86-
8787
class Container(containers.DeclarativeContainer):
8888
context_local_resource = providers.ContextLocalResource(_init)
8989
async_context_local_resource = providers.ContextLocalResource(_async_init)
9090

91-
loop = asyncio.get_event_loop()
9291
container = Container()
93-
obj1 = loop.run_until_complete(container.async_context_local_resource())
94-
obj2 = loop.run_until_complete(container.async_context_local_resource())
95-
assert obj1!=obj2
92+
obj1 = await container.async_context_local_resource()
93+
obj2 = await container.async_context_local_resource()
94+
assert obj1 != obj2
9695

9796
obj3 = container.context_local_resource()
9897
obj4 = container.context_local_resource()
9998

100-
assert obj3==obj4
101-
102-
99+
assert obj3 == obj4
103100

104101

105102
def test_init_function():
@@ -121,10 +118,10 @@ def _init():
121118
provider.shutdown()
122119

123120

124-
def test_init_generator():
121+
def test_init_generator_in_one_context():
125122
def _init():
126123
_init.init_counter += 1
127-
yield
124+
yield object()
128125
_init.shutdown_counter += 1
129126

130127
_init.init_counter = 0
@@ -133,33 +130,31 @@ def _init():
133130
provider = providers.ContextLocalResource(_init)
134131

135132
result1 = provider()
136-
assert result1 is None
133+
result2 = provider()
134+
135+
assert result1 == result2
136+
137137
assert _init.init_counter == 1
138138
assert _init.shutdown_counter == 0
139139

140140
provider.shutdown()
141141
assert _init.init_counter == 1
142142
assert _init.shutdown_counter == 1
143143

144-
result2 = provider()
145-
assert result2 is None
146-
assert _init.init_counter == 2
147-
assert _init.shutdown_counter == 1
148-
149144
provider.shutdown()
150-
assert _init.init_counter == 2
151-
assert _init.shutdown_counter == 2
145+
assert _init.init_counter == 1
146+
assert _init.shutdown_counter == 1
152147

153148

154-
def test_init_context_manager() -> None:
149+
def test_init_context_manager_in_one_context() -> None:
155150
init_counter, shutdown_counter = 0, 0
156151

157152
@contextmanager
158153
def _init():
159154
nonlocal init_counter, shutdown_counter
160155

161156
init_counter += 1
162-
yield
157+
yield object()
163158
shutdown_counter += 1
164159

165160
init_counter = 0
@@ -168,24 +163,77 @@ def _init():
168163
provider = providers.ContextLocalResource(_init)
169164

170165
result1 = provider()
171-
assert result1 is None
166+
result2 = provider()
167+
assert result1 == result2
168+
172169
assert init_counter == 1
173170
assert shutdown_counter == 0
174171

175172
provider.shutdown()
173+
176174
assert init_counter == 1
177175
assert shutdown_counter == 1
178176

179-
result2 = provider()
180-
assert result2 is None
181-
assert init_counter == 2
177+
provider.shutdown()
178+
assert init_counter == 1
182179
assert shutdown_counter == 1
183180

184-
provider.shutdown()
181+
182+
@mark.asyncio
183+
async def test_async_init_context_manager_in_different_contexts() -> None:
184+
init_counter, shutdown_counter = 0, 0
185+
186+
async def _init():
187+
nonlocal init_counter, shutdown_counter
188+
init_counter += 1
189+
yield object()
190+
shutdown_counter += 1
191+
192+
init_counter = 0
193+
shutdown_counter = 0
194+
195+
provider = providers.ContextLocalResource(_init)
196+
197+
async def run_in_context():
198+
resource = await provider()
199+
await provider.shutdown()
200+
return resource
201+
202+
result1, result2 = await asyncio.gather(run_in_context(), run_in_context())
203+
204+
assert result1 != result2
185205
assert init_counter == 2
186206
assert shutdown_counter == 2
187207

188208

209+
@mark.asyncio
210+
async def test_async_init_context_manager_in_one_context() -> None:
211+
init_counter, shutdown_counter = 0, 0
212+
213+
async def _init():
214+
nonlocal init_counter, shutdown_counter
215+
init_counter += 1
216+
yield object()
217+
shutdown_counter += 1
218+
219+
init_counter = 0
220+
shutdown_counter = 0
221+
222+
provider = providers.ContextLocalResource(_init)
223+
224+
async def run_in_context():
225+
resource_1 = await provider()
226+
resource_2 = await provider()
227+
await provider.shutdown()
228+
return resource_1, resource_2
229+
230+
result1, result2 = await run_in_context()
231+
232+
assert result1 == result2
233+
assert init_counter == 1
234+
assert shutdown_counter == 1
235+
236+
189237
def test_init_class():
190238
class TestResource(resources.Resource):
191239
init_counter = 0
@@ -218,7 +266,6 @@ def shutdown(self, _):
218266
assert TestResource.shutdown_counter == 2
219267

220268

221-
222269
def test_init_not_callable():
223270
provider = providers.ContextLocalResource(1)
224271
with raises(TypeError, match=r"object is not callable"):

0 commit comments

Comments
 (0)