88import sys
99from types import ModuleType
1010from typing import (
11+ TYPE_CHECKING ,
1112 Any ,
1213 Callable ,
1314 Dict ,
14- Generic ,
1515 Iterable ,
1616 Iterator ,
1717 Optional ,
18+ Protocol ,
1819 Set ,
1920 Tuple ,
2021 Type ,
2324 cast ,
2425)
2526
27+ from typing_extensions import Self
2628
2729# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362
2830if sys .version_info >= (3 , 9 ):
@@ -66,7 +68,6 @@ def get_origin(tp):
6668
6769from . import providers
6870
69-
7071__all__ = (
7172 "wire" ,
7273 "unwire" ,
@@ -89,7 +90,11 @@ def get_origin(tp):
8990
9091T = TypeVar ("T" )
9192F = TypeVar ("F" , bound = Callable [..., Any ])
92- Container = Any
93+
94+ if TYPE_CHECKING :
95+ from .containers import Container
96+ else :
97+ Container = Any
9398
9499
95100class PatchedRegistry :
@@ -777,15 +782,15 @@ class RequiredModifier(Modifier):
777782 def __init__ (self ) -> None :
778783 self .type_modifier = None
779784
780- def as_int (self ) -> "RequiredModifier" :
785+ def as_int (self ) -> Self :
781786 self .type_modifier = TypeModifier (int )
782787 return self
783788
784- def as_float (self ) -> "RequiredModifier" :
789+ def as_float (self ) -> Self :
785790 self .type_modifier = TypeModifier (float )
786791 return self
787792
788- def as_ (self , type_ : Type ) -> "RequiredModifier" :
793+ def as_ (self , type_ : Type ) -> Self :
789794 self .type_modifier = TypeModifier (type_ )
790795 return self
791796
@@ -833,15 +838,15 @@ class ProvidedInstance(Modifier):
833838 def __init__ (self ) -> None :
834839 self .segments = []
835840
836- def __getattr__ (self , item ) :
841+ def __getattr__ (self , item : str ) -> Self :
837842 self .segments .append ((self .TYPE_ATTRIBUTE , item ))
838843 return self
839844
840- def __getitem__ (self , item ):
845+ def __getitem__ (self , item ) -> Self :
841846 self .segments .append ((self .TYPE_ITEM , item ))
842847 return self
843848
844- def call (self ):
849+ def call (self ) -> Self :
845850 self .segments .append ((self .TYPE_CALL , None ))
846851 return self
847852
@@ -866,36 +871,56 @@ def provided() -> ProvidedInstance:
866871 return ProvidedInstance ()
867872
868873
869- class _Marker (Generic [T ]):
874+ MarkerItem = Union [
875+ str ,
876+ providers .Provider [Any ],
877+ Tuple [str , TypeModifier ],
878+ Type [Container ],
879+ "_Marker" ,
880+ ]
870881
871- __IS_MARKER__ = True
872882
873- def __init__ (
874- self ,
875- provider : Union [providers .Provider , Container , str ],
876- modifier : Optional [Modifier ] = None ,
877- ) -> None :
878- if _is_declarative_container (provider ):
879- provider = provider .__self__
880- self .provider = provider
881- self .modifier = modifier
883+ if TYPE_CHECKING :
882884
883- def __class_getitem__ (cls , item ) -> T :
884- if isinstance (item , tuple ):
885- return cls (* item )
886- return cls (item )
885+ class _Marker (Protocol ):
886+ __IS_MARKER__ : bool
887887
888- def __call__ (self ) -> T :
889- return self
888+ def __call__ (self ) -> Self : ...
889+ def __getattr__ (self , item : str ) -> Self : ...
890+ def __getitem__ (self , item : Any ) -> Any : ...
891+
892+ Provide : _Marker
893+ Provider : _Marker
894+ Closing : _Marker
895+ else :
890896
897+ class _Marker :
891898
892- class Provide ( _Marker ): ...
899+ __IS_MARKER__ = True
893900
901+ def __init__ (
902+ self ,
903+ provider : Union [providers .Provider , Container , str ],
904+ modifier : Optional [Modifier ] = None ,
905+ ) -> None :
906+ if _is_declarative_container (provider ):
907+ provider = provider .__self__
908+ self .provider = provider
909+ self .modifier = modifier
894910
895- class Provider (_Marker ): ...
911+ def __class_getitem__ (cls , item : MarkerItem ) -> Self :
912+ if isinstance (item , tuple ):
913+ return cls (* item )
914+ return cls (item )
896915
916+ def __call__ (self ) -> Self :
917+ return self
897918
898- class Closing (_Marker ): ...
919+ class Provide (_Marker ): ...
920+
921+ class Provider (_Marker ): ...
922+
923+ class Closing (_Marker ): ...
899924
900925
901926class AutoLoader :
@@ -998,8 +1023,8 @@ def is_loader_installed() -> bool:
9981023_loader = AutoLoader ()
9991024
10001025# Optimizations
1001- from ._cwiring import _sync_inject # noqa
10021026from ._cwiring import _async_inject # noqa
1027+ from ._cwiring import _sync_inject # noqa
10031028
10041029
10051030# Wiring uses the following Python wrapper because there is
@@ -1028,13 +1053,17 @@ def _patched(*args, **kwargs):
10281053 patched .injections ,
10291054 patched .closing ,
10301055 )
1056+
10311057 return cast (F , _patched )
10321058
10331059
10341060if sys .version_info >= (3 , 10 ):
1061+
10351062 def _get_annotations (obj : Any ) -> Dict [str , Any ]:
10361063 return inspect .get_annotations (obj )
1064+
10371065else :
1066+
10381067 def _get_annotations (obj : Any ) -> Dict [str , Any ]:
10391068 return getattr (obj , "__annotations__" , {})
10401069
0 commit comments