Skip to content

Commit d6e0ea4

Browse files
authored
Support injecting annotated types with Inject[] (#279)
* Support injecting annotated types with `Inject[]` This was missing in #263. * Extend the tests with two annotated types To avoid the test from passing by just binding the origin type. * Disable branch coverage `__metadata__` is never empty.
1 parent 9e68690 commit d6e0ea4

File tree

2 files changed

+148
-2
lines changed

2 files changed

+148
-2
lines changed

injector/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,16 @@ def _is_injection_annotation(annotation: Any) -> bool:
12131213
_inject_marker in annotation.__metadata__ or _noinject_marker in annotation.__metadata__
12141214
)
12151215

1216+
def _recreate_annotated_origin(annotated_type: Any) -> Any:
1217+
# Creates `Annotated[type, annotation]` from `Inject[Annotated[type, annotation]]`,
1218+
# to support the injection of annotated types with the `Inject[]` annotation.
1219+
origin = annotated_type.__origin__
1220+
for metadata in annotated_type.__metadata__: # pragma: no branch
1221+
if metadata in (_inject_marker, _noinject_marker):
1222+
break
1223+
origin = Annotated[origin, metadata]
1224+
return origin
1225+
12161226
spec = inspect.getfullargspec(callable)
12171227

12181228
try:
@@ -1245,7 +1255,7 @@ def _is_injection_annotation(annotation: Any) -> bool:
12451255
for k, v in list(bindings.items()):
12461256
# extract metadata only from Inject and NonInject
12471257
if _is_injection_annotation(v):
1248-
v, metadata = v.__origin__, v.__metadata__
1258+
v, metadata = _recreate_annotated_origin(v), v.__metadata__
12491259
bindings[k] = v
12501260
else:
12511261
metadata = tuple()

injector_test.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1700,6 +1700,67 @@ def function(a: int) -> 'InvalidForwardReference':
17001700
assert get_bindings(function) == {'a': int}
17011701

17021702

1703+
def test_gets_bindings_for_annotated_type_with_inject_decorator() -> None:
1704+
UserID = Annotated[int, 'user_id']
1705+
1706+
@inject
1707+
def function(a: UserID, b: str) -> None:
1708+
pass
1709+
1710+
assert get_bindings(function) == {'a': UserID, 'b': str}
1711+
1712+
1713+
def test_gets_bindings_of_annotated_type_with_inject_annotation() -> None:
1714+
UserID = Annotated[int, 'user_id']
1715+
1716+
def function(a: Inject[UserID], b: Inject[str]) -> None:
1717+
pass
1718+
1719+
assert get_bindings(function) == {'a': UserID, 'b': str}
1720+
1721+
1722+
def test_gets_bindings_of_new_type_with_inject_annotation() -> None:
1723+
Name = NewType('Name', str)
1724+
1725+
@inject
1726+
def function(a: Name, b: str) -> None:
1727+
pass
1728+
1729+
assert get_bindings(function) == {'a': Name, 'b': str}
1730+
1731+
1732+
def test_gets_bindings_of_inject_annotation_with_new_type() -> None:
1733+
def function(a: Inject[Name], b: str) -> None:
1734+
pass
1735+
1736+
assert get_bindings(function) == {'a': Name}
1737+
1738+
1739+
def test_get_bindings_of_nested_noinject_inject_annotation() -> None:
1740+
# This is not how this is intended to be used
1741+
def function(a: Inject[NoInject[int]], b: NoInject[Inject[str]]) -> None:
1742+
pass
1743+
1744+
assert get_bindings(function) == {}
1745+
1746+
1747+
def test_get_bindings_of_nested_noinject_inject_annotation_and_inject_decorator() -> None:
1748+
# This is not how this is intended to be used
1749+
@inject
1750+
def function(a: Inject[NoInject[int]], b: NoInject[Inject[str]]) -> None:
1751+
pass
1752+
1753+
assert get_bindings(function) == {}
1754+
1755+
1756+
def test_get_bindings_of_nested_inject_annotations() -> None:
1757+
# This is not how this is intended to be used
1758+
def function(a: Inject[Inject[int]]) -> None:
1759+
pass
1760+
1761+
assert get_bindings(function) == {'a': int}
1762+
1763+
17031764
# Tests https://github.com/alecthomas/injector/issues/202
17041765
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10+")
17051766
def test_get_bindings_for_pep_604():
@@ -1785,21 +1846,80 @@ def configure(binder):
17851846

17861847
def test_annotated_integration_with_annotated():
17871848
UserID = Annotated[int, 'user_id']
1849+
UserAge = Annotated[int, 'user_age']
17881850

17891851
@inject
17901852
class TestClass:
1791-
def __init__(self, user_id: UserID):
1853+
def __init__(self, user_id: UserID, user_age: UserAge):
1854+
self.user_id = user_id
1855+
self.user_age = user_age
1856+
1857+
def configure(binder):
1858+
binder.bind(UserID, to=123)
1859+
binder.bind(UserAge, to=32)
1860+
1861+
injector = Injector([configure])
1862+
1863+
test_class = injector.get(TestClass)
1864+
assert test_class.user_id == 123
1865+
assert test_class.user_age == 32
1866+
1867+
1868+
def test_inject_annotation_with_annotated_type():
1869+
UserID = Annotated[int, 'user_id']
1870+
UserAge = Annotated[int, 'user_age']
1871+
1872+
class TestClass:
1873+
def __init__(self, user_id: Inject[UserID], user_age: Inject[UserAge]):
17921874
self.user_id = user_id
1875+
self.user_age = user_age
17931876

17941877
def configure(binder):
17951878
binder.bind(UserID, to=123)
1879+
binder.bind(UserAge, to=32)
1880+
binder.bind(int, to=456)
1881+
1882+
injector = Injector([configure])
1883+
1884+
test_class = injector.get(TestClass)
1885+
assert test_class.user_id == 123
1886+
assert test_class.user_age == 32
1887+
1888+
1889+
def test_inject_annotation_with_nested_annotated_type():
1890+
UserID = Annotated[int, 'user_id']
1891+
SpecialUserID = Annotated[UserID, 'special_user_id']
1892+
1893+
class TestClass:
1894+
def __init__(self, user_id: Inject[SpecialUserID]):
1895+
self.user_id = user_id
1896+
1897+
def configure(binder):
1898+
binder.bind(SpecialUserID, to=123)
17961899

17971900
injector = Injector([configure])
17981901

17991902
test_class = injector.get(TestClass)
18001903
assert test_class.user_id == 123
18011904

18021905

1906+
def test_noinject_annotation_with_annotated_type():
1907+
UserID = Annotated[int, 'user_id']
1908+
1909+
@inject
1910+
class TestClass:
1911+
def __init__(self, user_id: NoInject[UserID] = None):
1912+
self.user_id = user_id
1913+
1914+
def configure(binder):
1915+
binder.bind(UserID, to=123)
1916+
1917+
injector = Injector([configure])
1918+
1919+
test_class = injector.get(TestClass)
1920+
assert test_class.user_id is None
1921+
1922+
18031923
def test_newtype_integration_with_annotated():
18041924
UserID = NewType('UserID', int)
18051925

@@ -1817,6 +1937,22 @@ def configure(binder):
18171937
assert test_class.user_id == 123
18181938

18191939

1940+
def test_newtype_with_injection_annotation():
1941+
UserID = NewType('UserID', int)
1942+
1943+
class TestClass:
1944+
def __init__(self, user_id: Inject[UserID]):
1945+
self.user_id = user_id
1946+
1947+
def configure(binder):
1948+
binder.bind(UserID, to=123)
1949+
1950+
injector = Injector([configure])
1951+
1952+
test_class = injector.get(TestClass)
1953+
assert test_class.user_id == 123
1954+
1955+
18201956
def test_dataclass_annotated_parameter():
18211957
Foo = Annotated[int, object()]
18221958

0 commit comments

Comments
 (0)