Skip to content

Commit 31a915a

Browse files
authored
Use dummy concrete type instead of Any when checking protocol variance (#20110)
Fixes #20108. Variance checks for protocols follow a procedure roughly equivalent to that described in [typing.python.org - Variance Inference](https://typing.python.org/en/latest/spec/generics.html#variance-inference). A major difference in mypy's current implementation is in Step 3: > Create two specialized versions of the class. We’ll refer to these as `upper` and `lower` specializations. In both of these specializations, replace all type parameters other than the one being inferred by a dummy type instance (a concrete anonymous class that is assumed to meet the bounds or constraints of the type parameter). Mypy currently uses `Any` rather than a concrete dummy type. This causes issues during overload subtype checks in the example reported in the original issue, as the specialisations when checking variance suitability of `_T2_contra` look like: ```python from typing import TypeVar, Protocol, overload _T1_contra = TypeVar("_T1_contra", contravariant=True) _T2_contra = TypeVar("_T2_contra", contravariant=True) class A(Protocol[<_T1_contra=Any>, _T2_contra]): @overload def method(self, a: <_T1_contra=Any>) -> None: ... @overload def method(self, a: _T2_contra) -> None: ... ``` This PR replaces the use of `Any` with a dummy concrete type in the entire protocol variance check to more closely follow the variance inference algorithm in the spec and fixes this overload issue.
1 parent 28536b5 commit 31a915a

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

mypy/checker.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,8 @@ class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi):
395395

396396
# A helper state to produce unique temporary names on demand.
397397
_unique_id: int
398+
# Fake concrete type used when checking variance
399+
_variance_dummy_type: Instance | None
398400

399401
def __init__(
400402
self,
@@ -469,6 +471,7 @@ def __init__(
469471

470472
self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options)
471473
self._unique_id = 0
474+
self._variance_dummy_type = None
472475

473476
@property
474477
def expr_checker(self) -> mypy.checkexpr.ExpressionChecker:
@@ -2918,17 +2921,19 @@ def check_protocol_variance(self, defn: ClassDef) -> None:
29182921
info = defn.info
29192922
object_type = Instance(info.mro[-1], [])
29202923
tvars = info.defn.type_vars
2924+
if self._variance_dummy_type is None:
2925+
_, dummy_info = self.make_fake_typeinfo("<dummy>", "Dummy", "Dummy", [])
2926+
self._variance_dummy_type = Instance(dummy_info, [])
2927+
dummy = self._variance_dummy_type
29212928
for i, tvar in enumerate(tvars):
29222929
if not isinstance(tvar, TypeVarType):
29232930
# Variance of TypeVarTuple and ParamSpec is underspecified by PEPs.
29242931
continue
29252932
up_args: list[Type] = [
2926-
object_type if i == j else AnyType(TypeOfAny.special_form)
2927-
for j, _ in enumerate(tvars)
2933+
object_type if i == j else dummy.copy_modified() for j, _ in enumerate(tvars)
29282934
]
29292935
down_args: list[Type] = [
2930-
UninhabitedType() if i == j else AnyType(TypeOfAny.special_form)
2931-
for j, _ in enumerate(tvars)
2936+
UninhabitedType() if i == j else dummy.copy_modified() for j, _ in enumerate(tvars)
29322937
]
29332938
up, down = Instance(info, up_args), Instance(info, down_args)
29342939
# TODO: add advanced variance checks for recursive protocols

test-data/unit/check-protocols.test

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,19 @@ main:16: note: def meth(self, x: int) -> int
13731373
main:16: note: @overload
13741374
main:16: note: def meth(self, x: bytes) -> str
13751375

1376+
[case testProtocolWithMultiContravariantTypeVarOverloads]
1377+
from typing import overload, Protocol, TypeVar
1378+
1379+
T1 = TypeVar("T1", contravariant=True)
1380+
T2 = TypeVar("T2", contravariant=True)
1381+
1382+
class A(Protocol[T1, T2]):
1383+
@overload
1384+
def method(self, a: T1) -> None: ...
1385+
@overload
1386+
def method(self, a: T2) -> None: ...
1387+
1388+
13761389
-- Join and meet with protocol types
13771390
-- ---------------------------------
13781391

0 commit comments

Comments
 (0)