diff --git a/mypy/checker.py b/mypy/checker.py index 63e128f78310..0089ffd035fb 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6609,6 +6609,28 @@ def equality_type_narrowing_helper( narrowable_operand_index_to_hash: dict[int, tuple[Key, ...]], ) -> tuple[TypeMap, TypeMap]: """Calculate type maps for '==', '!=', 'is' or 'is not' expression.""" + # If we haven't been able to narrow types yet, we might be dealing with a + # explicit type(x) == some_type check + if_map, else_map = self.narrow_type_by_equality( + operator, + operands, + operand_types, + expr_indices, + narrowable_operand_index_to_hash.keys(), + ) + if if_map == {} and else_map == {} and node is not None: + if_map, else_map = self.find_type_equals_check(node, expr_indices) + return if_map, else_map + + def narrow_type_by_equality( + self, + operator: str, + operands: list[Expression], + operand_types: list[Type], + expr_indices: list[int], + narrowable_indices: AbstractSet[int], + ) -> tuple[TypeMap, TypeMap]: + """Calculate type maps for '==', '!=', 'is' or 'is not' expression, ignoring `type(x)` checks.""" # is_valid_target: # Controls which types we're allowed to narrow exprs to. Note that # we cannot use 'is_literal_type_like' in both cases since doing @@ -6654,20 +6676,15 @@ def has_no_custom_eq_checks(t: Type) -> bool: operands, operand_types, expr_indices, - narrowable_operand_index_to_hash.keys(), + narrowable_indices, is_valid_target, coerce_only_in_literal_context, ) if if_map == {} and else_map == {}: if_map, else_map = self.refine_away_none_in_comparison( - operands, operand_types, expr_indices, narrowable_operand_index_to_hash.keys() + operands, operand_types, expr_indices, narrowable_indices ) - - # If we haven't been able to narrow types yet, we might be dealing with a - # explicit type(x) == some_type check - if if_map == {} and else_map == {}: - if_map, else_map = self.find_type_equals_check(node, expr_indices) return if_map, else_map def propagate_up_typemap_info(self, new_types: TypeMap) -> TypeMap: @@ -6902,6 +6919,11 @@ def should_coerce_inner(typ: Type) -> bool: for i in chain_indices: expr_type = operand_types[i] if should_coerce: + # TODO: doing this prevents narrowing a single-member Enum to literal + # of its member, because we expand it here and then refuse to add equal + # types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow + # `x` to `Literal[Foo.A]` iff `Foo` has exactly one member. + # See testMatchEnumSingleChoice expr_type = coerce_to_literal(expr_type) if not is_valid_target(get_proper_type(expr_type)): continue diff --git a/mypy/checker_shared.py b/mypy/checker_shared.py index 0014d2c6fc88..6aac66904bde 100644 --- a/mypy/checker_shared.py +++ b/mypy/checker_shared.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import abstractmethod -from collections.abc import Iterator, Sequence +from collections.abc import Iterator, Sequence, Set as AbstractSet from contextlib import contextmanager from typing import NamedTuple, overload @@ -245,6 +245,17 @@ def conditional_types_with_intersection( ) -> tuple[Type | None, Type | None]: raise NotImplementedError + @abstractmethod + def narrow_type_by_equality( + self, + operator: str, + operands: list[Expression], + operand_types: list[Type], + expr_indices: list[int], + narrowable_indices: AbstractSet[int], + ) -> tuple[dict[Expression, Type] | None, dict[Expression, Type] | None]: + raise NotImplementedError + @abstractmethod def check_deprecated(self, node: Node | None, context: Context) -> None: raise NotImplementedError diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 3c51c4106909..9b1bae1ef6d1 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -14,7 +14,7 @@ from mypy.maptype import map_instance_to_supertype from mypy.meet import narrow_declared_type from mypy.messages import MessageBuilder -from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, Var +from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TempNode, TypeAlias, Var from mypy.options import Options from mypy.patterns import ( AsPattern, @@ -39,7 +39,6 @@ AnyType, FunctionLike, Instance, - LiteralType, NoneType, ProperType, TupleType, @@ -205,12 +204,15 @@ def visit_value_pattern(self, o: ValuePattern) -> PatternType: current_type = self.type_context[-1] typ = self.chk.expr_checker.accept(o.expr) typ = coerce_to_literal(typ) - narrowed_type, rest_type = self.chk.conditional_types_with_intersection( - current_type, [get_type_range(typ)], o, default=get_proper_type(typ) + node = TempNode(current_type) + # Value patterns are essentially a syntactic sugar on top of `if x == Value`. + # They should be treated equivalently. + ok_map, rest_map = self.chk.narrow_type_by_equality( + "==", [node, TempNode(typ)], [current_type, typ], [0, 1], {0} ) - if not isinstance(get_proper_type(narrowed_type), (LiteralType, UninhabitedType)): - return PatternType(narrowed_type, UnionType.make_union([narrowed_type, rest_type]), {}) - return PatternType(narrowed_type, rest_type, {}) + ok_type = ok_map.get(node, current_type) if ok_map is not None else UninhabitedType() + rest_type = rest_map.get(node, current_type) if rest_map is not None else UninhabitedType() + return PatternType(ok_type, rest_type, {}) def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType: current_type = self.type_context[-1] diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 1e27e30d4b04..1bd2c82fb95f 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -30,9 +30,9 @@ m: Any match m: case 1: - reveal_type(m) # N: Revealed type is "Literal[1]" + reveal_type(m) # N: Revealed type is "Any" case 2: - reveal_type(m) # N: Revealed type is "Literal[2]" + reveal_type(m) # N: Revealed type is "Any" case other: reveal_type(other) # N: Revealed type is "Any" @@ -61,7 +61,7 @@ m: object match m: case b.b: - reveal_type(m) # N: Revealed type is "builtins.int" + reveal_type(m) # N: Revealed type is "builtins.object" [file b.py] b: int @@ -83,7 +83,7 @@ m: A match m: case b.b: - reveal_type(m) # N: Revealed type is "__main__." + reveal_type(m) # N: Revealed type is "__main__.A" [file b.py] class B: ... b: B @@ -96,7 +96,7 @@ m: int match m: case b.b: - reveal_type(m) + reveal_type(m) # N: Revealed type is "builtins.int" [file b.py] b: str [builtins fixtures/primitives.pyi] @@ -1742,14 +1742,15 @@ from typing import NoReturn def assert_never(x: NoReturn) -> None: ... class Medal(Enum): - gold = 1 + GOLD = 1 def f(m: Medal) -> None: always_assigned: int | None = None match m: - case Medal.gold: + case Medal.GOLD: always_assigned = 1 - reveal_type(m) # N: Revealed type is "Literal[__main__.Medal.gold]" + # This should narrow to literal, see TODO in checker::refine_identity_comparison_expression + reveal_type(m) # N: Revealed type is "__main__.Medal" case _: assert_never(m) @@ -1785,6 +1786,34 @@ def g(m: Medal) -> int: return 2 [builtins fixtures/enum.pyi] +[case testMatchLiteralOrValuePattern] +# flags: --warn-unreachable +from typing import Literal + +def test1(x: Literal[1,2,3]) -> None: + match x: + case 1: + reveal_type(x) # N: Revealed type is "Literal[1]" + case other: + reveal_type(x) # N: Revealed type is "Union[Literal[2], Literal[3]]" + +def test2(x: Literal[1,2,3]) -> None: + match x: + case 1: + reveal_type(x) # N: Revealed type is "Literal[1]" + case 2: + reveal_type(x) # N: Revealed type is "Literal[2]" + case 3: + reveal_type(x) # N: Revealed type is "Literal[3]" + case other: + 1 # E: Statement is unreachable + +def test3(x: Literal[1,2,3]) -> None: + match x: + case 1 | 3: + reveal_type(x) # N: Revealed type is "Union[Literal[1], Literal[3]]" + case other: + reveal_type(x) # N: Revealed type is "Literal[2]" [case testMatchLiteralPatternEnumWithTypedAttribute] from enum import Enum @@ -2813,7 +2842,7 @@ match A().foo: def int_literal() -> None: match 12: case 1 as s: - reveal_type(s) # N: Revealed type is "Literal[1]" + reveal_type(s) # E: Statement is unreachable case int(i): reveal_type(i) # N: Revealed type is "Literal[12]?" case other: @@ -2822,7 +2851,7 @@ def int_literal() -> None: def str_literal() -> None: match 'foo': case 'a' as s: - reveal_type(s) # N: Revealed type is "Literal['a']" + reveal_type(s) # E: Statement is unreachable case str(i): reveal_type(i) # N: Revealed type is "Literal['foo']?" case other: @@ -2909,9 +2938,9 @@ T_Choice = TypeVar("T_Choice", bound=b.One | b.Two) def switch(choice: type[T_Choice]) -> None: match choice: case b.One: - reveal_type(choice) # N: Revealed type is "def () -> b.One" + reveal_type(choice) # N: Revealed type is "type[T_Choice`-1]" case b.Two: - reveal_type(choice) # N: Revealed type is "def () -> b.Two" + reveal_type(choice) # N: Revealed type is "type[T_Choice`-1]" case _: reveal_type(choice) # N: Revealed type is "type[T_Choice`-1]"