Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6609,6 +6609,28 @@ def equality_type_narrowing_helper(
narrowable_operand_index_to_hash: dict[int, tuple[Key, ...]],
) -> tuple[TypeMap, TypeMap]:
"""Calculate type maps for '==', '!=', 'is' or 'is not' expression."""
# If we haven't been able to narrow types yet, we might be dealing with a
# explicit type(x) == some_type check
if_map, else_map = self.narrow_type_by_equality(
operator,
operands,
operand_types,
expr_indices,
narrowable_operand_index_to_hash.keys(),
)
if if_map == {} and else_map == {} and node is not None:
if_map, else_map = self.find_type_equals_check(node, expr_indices)
return if_map, else_map

def narrow_type_by_equality(
self,
operator: str,
operands: list[Expression],
operand_types: list[Type],
expr_indices: list[int],
narrowable_indices: AbstractSet[int],
) -> tuple[TypeMap, TypeMap]:
"""Calculate type maps for '==', '!=', 'is' or 'is not' expression, ignoring `type(x)` checks."""
# is_valid_target:
# Controls which types we're allowed to narrow exprs to. Note that
# we cannot use 'is_literal_type_like' in both cases since doing
Expand Down Expand Up @@ -6654,20 +6676,15 @@ def has_no_custom_eq_checks(t: Type) -> bool:
operands,
operand_types,
expr_indices,
narrowable_operand_index_to_hash.keys(),
narrowable_indices,
is_valid_target,
coerce_only_in_literal_context,
)

if if_map == {} and else_map == {}:
if_map, else_map = self.refine_away_none_in_comparison(
operands, operand_types, expr_indices, narrowable_operand_index_to_hash.keys()
operands, operand_types, expr_indices, narrowable_indices
)

# If we haven't been able to narrow types yet, we might be dealing with a
# explicit type(x) == some_type check
if if_map == {} and else_map == {}:
if_map, else_map = self.find_type_equals_check(node, expr_indices)
return if_map, else_map

def propagate_up_typemap_info(self, new_types: TypeMap) -> TypeMap:
Expand Down Expand Up @@ -6902,6 +6919,11 @@ def should_coerce_inner(typ: Type) -> bool:
for i in chain_indices:
expr_type = operand_types[i]
if should_coerce:
# TODO: doing this prevents narrowing a single-member Enum to literal
# of its member, because we expand it here and then refuse to add equal
# types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
# `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
# See testMatchEnumSingleChoice
expr_type = coerce_to_literal(expr_type)
if not is_valid_target(get_proper_type(expr_type)):
continue
Expand Down
13 changes: 12 additions & 1 deletion mypy/checker_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from abc import abstractmethod
from collections.abc import Iterator, Sequence
from collections.abc import Iterator, Sequence, Set as AbstractSet
from contextlib import contextmanager
from typing import NamedTuple, overload

Expand Down Expand Up @@ -245,6 +245,17 @@ def conditional_types_with_intersection(
) -> tuple[Type | None, Type | None]:
raise NotImplementedError

@abstractmethod
def narrow_type_by_equality(
self,
operator: str,
operands: list[Expression],
operand_types: list[Type],
expr_indices: list[int],
narrowable_indices: AbstractSet[int],
) -> tuple[dict[Expression, Type] | None, dict[Expression, Type] | None]:
raise NotImplementedError

@abstractmethod
def check_deprecated(self, node: Node | None, context: Context) -> None:
raise NotImplementedError
Expand Down
16 changes: 9 additions & 7 deletions mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from mypy.maptype import map_instance_to_supertype
from mypy.meet import narrow_declared_type
from mypy.messages import MessageBuilder
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, Var
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TempNode, TypeAlias, Var
from mypy.options import Options
from mypy.patterns import (
AsPattern,
Expand All @@ -39,7 +39,6 @@
AnyType,
FunctionLike,
Instance,
LiteralType,
NoneType,
ProperType,
TupleType,
Expand Down Expand Up @@ -205,12 +204,15 @@ def visit_value_pattern(self, o: ValuePattern) -> PatternType:
current_type = self.type_context[-1]
typ = self.chk.expr_checker.accept(o.expr)
typ = coerce_to_literal(typ)
narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
current_type, [get_type_range(typ)], o, default=get_proper_type(typ)
node = TempNode(current_type)
# Value patterns are essentially a syntactic sugar on top of `if x == Value`.
# They should be treated equivalently.
ok_map, rest_map = self.chk.narrow_type_by_equality(
"==", [node, TempNode(typ)], [current_type, typ], [0, 1], {0}
)
if not isinstance(get_proper_type(narrowed_type), (LiteralType, UninhabitedType)):
return PatternType(narrowed_type, UnionType.make_union([narrowed_type, rest_type]), {})
return PatternType(narrowed_type, rest_type, {})
ok_type = ok_map.get(node, current_type) if ok_map is not None else UninhabitedType()
rest_type = rest_map.get(node, current_type) if rest_map is not None else UninhabitedType()
return PatternType(ok_type, rest_type, {})

def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType:
current_type = self.type_context[-1]
Expand Down
53 changes: 41 additions & 12 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ m: Any

match m:
case 1:
reveal_type(m) # N: Revealed type is "Literal[1]"
reveal_type(m) # N: Revealed type is "Any"
case 2:
reveal_type(m) # N: Revealed type is "Literal[2]"
reveal_type(m) # N: Revealed type is "Any"
case other:
reveal_type(other) # N: Revealed type is "Any"

Expand Down Expand Up @@ -61,7 +61,7 @@ m: object

match m:
case b.b:
reveal_type(m) # N: Revealed type is "builtins.int"
reveal_type(m) # N: Revealed type is "builtins.object"
[file b.py]
b: int

Expand All @@ -83,7 +83,7 @@ m: A

match m:
case b.b:
reveal_type(m) # N: Revealed type is "__main__.<subclass of "__main__.A" and "b.B">"
reveal_type(m) # N: Revealed type is "__main__.A"
[file b.py]
class B: ...
b: B
Expand All @@ -96,7 +96,7 @@ m: int

match m:
case b.b:
reveal_type(m)
reveal_type(m) # N: Revealed type is "builtins.int"
[file b.py]
b: str
[builtins fixtures/primitives.pyi]
Expand Down Expand Up @@ -1742,14 +1742,15 @@ from typing import NoReturn
def assert_never(x: NoReturn) -> None: ...

class Medal(Enum):
gold = 1
GOLD = 1

def f(m: Medal) -> None:
always_assigned: int | None = None
match m:
case Medal.gold:
case Medal.GOLD:
always_assigned = 1
reveal_type(m) # N: Revealed type is "Literal[__main__.Medal.gold]"
# This should narrow to literal, see TODO in checker::refine_identity_comparison_expression
reveal_type(m) # N: Revealed type is "__main__.Medal"
case _:
assert_never(m)

Expand Down Expand Up @@ -1785,6 +1786,34 @@ def g(m: Medal) -> int:
return 2
[builtins fixtures/enum.pyi]

[case testMatchLiteralOrValuePattern]
# flags: --warn-unreachable
from typing import Literal

def test1(x: Literal[1,2,3]) -> None:
match x:
case 1:
reveal_type(x) # N: Revealed type is "Literal[1]"
case other:
reveal_type(x) # N: Revealed type is "Union[Literal[2], Literal[3]]"

def test2(x: Literal[1,2,3]) -> None:
match x:
case 1:
reveal_type(x) # N: Revealed type is "Literal[1]"
case 2:
reveal_type(x) # N: Revealed type is "Literal[2]"
case 3:
reveal_type(x) # N: Revealed type is "Literal[3]"
case other:
1 # E: Statement is unreachable

def test3(x: Literal[1,2,3]) -> None:
match x:
case 1 | 3:
reveal_type(x) # N: Revealed type is "Union[Literal[1], Literal[3]]"
case other:
reveal_type(x) # N: Revealed type is "Literal[2]"

[case testMatchLiteralPatternEnumWithTypedAttribute]
from enum import Enum
Expand Down Expand Up @@ -2813,7 +2842,7 @@ match A().foo:
def int_literal() -> None:
match 12:
case 1 as s:
reveal_type(s) # N: Revealed type is "Literal[1]"
reveal_type(s) # E: Statement is unreachable
case int(i):
reveal_type(i) # N: Revealed type is "Literal[12]?"
case other:
Expand All @@ -2822,7 +2851,7 @@ def int_literal() -> None:
def str_literal() -> None:
match 'foo':
case 'a' as s:
reveal_type(s) # N: Revealed type is "Literal['a']"
reveal_type(s) # E: Statement is unreachable
case str(i):
reveal_type(i) # N: Revealed type is "Literal['foo']?"
case other:
Expand Down Expand Up @@ -2909,9 +2938,9 @@ T_Choice = TypeVar("T_Choice", bound=b.One | b.Two)
def switch(choice: type[T_Choice]) -> None:
match choice:
case b.One:
reveal_type(choice) # N: Revealed type is "def () -> b.One"
reveal_type(choice) # N: Revealed type is "type[T_Choice`-1]"
case b.Two:
reveal_type(choice) # N: Revealed type is "def () -> b.Two"
reveal_type(choice) # N: Revealed type is "type[T_Choice`-1]"
case _:
reveal_type(choice) # N: Revealed type is "type[T_Choice`-1]"

Expand Down
Loading