@@ -9,23 +9,27 @@ import types
99from . import providers
1010from .wiring import _Marker
1111
12+ from .providers cimport Provider
13+
1214
1315def _get_sync_patched (fn ):
1416 @ functools.wraps (fn)
1517 def _patched (*args , **kwargs ):
1618 cdef object result
1719 cdef dict to_inject
20+ cdef object arg_key
21+ cdef Provider provider
1822
1923 to_inject = kwargs.copy()
20- for injection , provider in _patched.__injections__.items():
21- if injection not in kwargs or isinstance (kwargs[injection ], _Marker):
22- to_inject[injection ] = provider()
24+ for arg_key , provider in _patched.__injections__.items():
25+ if arg_key not in kwargs or isinstance (kwargs[arg_key ], _Marker):
26+ to_inject[arg_key ] = provider()
2327
2428 result = fn(* args, ** to_inject)
2529
2630 if _patched.__closing__:
27- for injection , provider in _patched.__closing__.items():
28- if injection in kwargs and not isinstance (kwargs[injection ], _Marker):
31+ for arg_key , provider in _patched.__closing__.items():
32+ if arg_key in kwargs and not isinstance (kwargs[arg_key ], _Marker):
2933 continue
3034 if not isinstance (provider, providers.Resource):
3135 continue
@@ -35,49 +39,45 @@ def _get_sync_patched(fn):
3539 return _patched
3640
3741
38- def _get_async_patched (fn ):
39- @ functools.wraps (fn)
40- async def _patched(* args, ** kwargs):
41- cdef object result
42- cdef dict to_inject
43- cdef list to_inject_await = []
44- cdef list to_close_await = []
45-
46- to_inject = kwargs.copy()
47- for injection, provider in _patched.__injections__.items():
48- if injection not in kwargs or isinstance (kwargs[injection], _Marker):
49- provide = provider()
50- if _isawaitable(provide):
51- to_inject_await.append((injection, provide))
52- else :
53- to_inject[injection] = provide
54-
55- if to_inject_await:
56- async_to_inject = await asyncio.gather(* (provide for _, provide in to_inject_await))
57- for provide, (injection, _) in zip (async_to_inject, to_inject_await):
58- to_inject[injection] = provide
59-
60- result = await fn(* args, ** to_inject)
61-
62- if _patched.__closing__:
63- for injection, provider in _patched.__closing__.items():
64- if injection in kwargs \
65- and isinstance (kwargs[injection], _Marker):
66- continue
67- if not isinstance (provider, providers.Resource):
68- continue
69- shutdown = provider.shutdown()
70- if _isawaitable(shutdown):
71- to_close_await.append(shutdown)
72-
73- await asyncio.gather(* to_close_await)
74-
75- return result
76-
77- # Hotfix for iscoroutinefunction() for Cython < 3.0.0; can be removed after migration to Cython 3.0.0+
78- _patched._is_coroutine = asyncio.coroutines._is_coroutine
79-
80- return _patched
42+ async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings):
43+ cdef object result
44+ cdef dict to_inject
45+ cdef list to_inject_await = []
46+ cdef list to_close_await = []
47+ cdef object arg_key
48+ cdef Provider provider
49+
50+ to_inject = kwargs.copy()
51+ for arg_key, provider in injections.items():
52+ if arg_key not in kwargs or isinstance (kwargs[arg_key], _Marker):
53+ provide = provider()
54+ if provider.is_async_mode_enabled():
55+ to_inject_await.append((arg_key, provide))
56+ elif _isawaitable(provide):
57+ to_inject_await.append((arg_key, provide))
58+ else :
59+ to_inject[arg_key] = provide
60+
61+ if to_inject_await:
62+ async_to_inject = await asyncio.gather(* (provide for _, provide in to_inject_await))
63+ for provide, (injection, _) in zip (async_to_inject, to_inject_await):
64+ to_inject[injection] = provide
65+
66+ result = await fn(* args, ** to_inject)
67+
68+ if closings:
69+ for arg_key, provider in closings.items():
70+ if arg_key in kwargs and isinstance (kwargs[arg_key], _Marker):
71+ continue
72+ if not isinstance (provider, providers.Resource):
73+ continue
74+ shutdown = provider.shutdown()
75+ if _isawaitable(shutdown):
76+ to_close_await.append(shutdown)
77+
78+ await asyncio.gather(* to_close_await)
79+
80+ return result
8181
8282
8383cdef bint _isawaitable(object instance):
0 commit comments