Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 146 additions & 122 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3260,8 +3260,17 @@ def check_assignment(
lvalue.items, rvalue, rvalue, infer_lvalue_type
)
else:
self.try_infer_partial_generic_type_from_assignment(lvalue, rvalue, "=")
# Resolve partial generic targets before we analyze the lvalue, to mirror the
# legacy flow where this happened pre-check_lvalue.
self.try_resolve_partial_type_from_assignment(
lvalue, None, rvalue, "=", pre_check=True
)
lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue, rvalue)
(lvalue_type, override_rvalue_type, should_return, skip_simple_assignment) = (
self.try_resolve_partial_type_from_assignment(lvalue, lvalue_type, rvalue, "=")
)
if should_return:
return
# If we're assigning to __getattr__ or similar methods, check that the signature is
# valid.
if isinstance(lvalue, NameExpr) and lvalue.node:
Expand Down Expand Up @@ -3293,43 +3302,8 @@ def check_assignment(
self.fail(message_registry.CANNOT_MODIFY_MATCH_ARGS, lvalue)

if lvalue_type:
if isinstance(lvalue_type, PartialType) and lvalue_type.type is None:
# Try to infer a proper type for a variable with a partial None type.
rvalue_type = self.expr_checker.accept(rvalue)
if isinstance(get_proper_type(rvalue_type), NoneType):
# This doesn't actually provide any additional information -- multiple
# None initializers preserve the partial None type.
return

var = lvalue_type.var
if is_valid_inferred_type(
rvalue_type, self.options, is_lvalue_final=var.is_final
):
partial_types = self.find_partial_types(var)
if partial_types is not None:
if not self.current_node_deferred:
# Partial type can't be final, so strip any literal values.
rvalue_type = remove_instance_last_known_values(rvalue_type)
inferred_type = make_simplified_union([rvalue_type, NoneType()])
self.set_inferred_type(var, lvalue, inferred_type)
else:
var.type = None
del partial_types[var]
lvalue_type = var.type
else:
# Try to infer a partial type.
if not self.infer_partial_type(var, lvalue, rvalue_type):
# If that also failed, give up and let the caller know that we
# cannot read their mind. The definition site will be reported later.
# Calling .put() directly because the newly inferred type is
# not a subtype of None - we are not looking for narrowing
fallback = self.inference_error_fallback_type(rvalue_type)
self.binder.put(lvalue, fallback)
# Same as self.set_inference_error_fallback_type but inlined
# to avoid computing fallback twice.
# We are replacing partial<None> now, so the variable type
# should remain optional.
self.set_inferred_type(var, lvalue, make_optional_type(fallback))
if skip_simple_assignment:
rvalue_type = override_rvalue_type
elif (
is_literal_none(rvalue)
and isinstance(lvalue, NameExpr)
Expand Down Expand Up @@ -3483,54 +3457,6 @@ def get_variable_type_context(self, inferred: Var, rvalue: Expression) -> Type |
return None
return candidate

def try_infer_partial_generic_type_from_assignment(
self, lvalue: Lvalue, rvalue: Expression, op: str
) -> None:
"""Try to infer a precise type for partial generic type from assignment.

'op' is '=' for normal assignment and a binary operator ('+', ...) for
augmented assignment.

Example where this happens:

x = []
if foo():
x = [1] # Infer List[int] as type of 'x'
"""
var = None
if (
isinstance(lvalue, NameExpr)
and isinstance(lvalue.node, Var)
and isinstance(lvalue.node.type, PartialType)
):
var = lvalue.node
elif isinstance(lvalue, MemberExpr):
var = self.expr_checker.get_partial_self_var(lvalue)
if var is not None:
typ = var.type
assert isinstance(typ, PartialType)
if typ.type is None:
return
# Return if this is an unsupported augmented assignment.
if op != "=" and (typ.type.fullname, op) not in self.partial_type_augmented_ops:
return
# TODO: some logic here duplicates the None partial type counterpart
# inlined in check_assignment(), see #8043.
partial_types = self.find_partial_types(var)
if partial_types is None:
return
rvalue_type = self.expr_checker.accept(rvalue)
rvalue_type = get_proper_type(rvalue_type)
if isinstance(rvalue_type, Instance):
if rvalue_type.type == typ.type and is_valid_inferred_type(
rvalue_type, self.options
):
var.type = rvalue_type
del partial_types[var]
elif isinstance(rvalue_type, AnyType):
var.type = fill_typevars_with_any(typ.type)
del partial_types[var]

def check_compatibility_all_supers(self, lvalue: RefExpr, rvalue: Expression) -> None:
lvalue_node = lvalue.node
# Check if we are a class variable with at least one base class
Expand Down Expand Up @@ -4801,6 +4727,102 @@ def check_indexed_assignment(
if isinstance(res_type, UninhabitedType) and not res_type.ambiguous:
self.binder.unreachable()

def try_resolve_partial_type_from_assignment(
self,
lvalue: Lvalue,
lvalue_type: Type | None,
rvalue: Expression,
op: str,
*,
pre_check: bool = False,
) -> tuple[Type | None, Type | None, bool, bool]:
"""Refine partial types assigned to by this assignment.

Returns (new_lvalue_type, rvalue_type_override, should_return, skip_simple_assignment).
If skip_simple_assignment is True, the caller should avoid performing the regular
check_simple_assignment logic, since we've already handled the partial type case.
"""
target = self._get_partial_assignment_target(lvalue)
if target is None:
return lvalue_type, None, False, False

var, partial = target
if partial.type is None:
if pre_check:
return lvalue_type, None, False, False
rvalue_type = self.expr_checker.accept(rvalue)
proper_rvalue = get_proper_type(rvalue_type)
if isinstance(proper_rvalue, NoneType):
# This doesn't actually provide any additional information -- multiple
# None initializers preserve the partial None type.
return lvalue_type, rvalue_type, True, True

if is_valid_inferred_type(rvalue_type, self.options, is_lvalue_final=var.is_final):
partial_types = self.find_partial_types(var)
if partial_types is None:
return lvalue_type, rvalue_type, False, True
if not self.current_node_deferred:
# Partial type can't be final, so strip any literal values.
rvalue_type = remove_instance_last_known_values(rvalue_type)
inferred_type = make_simplified_union([rvalue_type, NoneType()])
self.set_inferred_type(var, lvalue, inferred_type)
else:
var.type = None
del partial_types[var]
return var.type, rvalue_type, False, True

# Try to infer a partial type.
if not self.infer_partial_type(var, lvalue, rvalue_type):
# If that also failed, give up and let the caller know that we
# cannot read their mind. The definition site will be reported later.
# Calling .put() directly because the newly inferred type is
# not a subtype of None - we are not looking for narrowing
fallback = self.inference_error_fallback_type(rvalue_type)
self.binder.put(lvalue, fallback)
# Same as self.set_inference_error_fallback_type but inlined
# to avoid computing fallback twice.
# We are replacing partial<None> now, so the variable type
# should remain optional.
self.set_inferred_type(var, lvalue, make_optional_type(fallback))
return lvalue_type, fallback, False, True
return var.type, rvalue_type, False, True

type_info = partial.type
assert type_info is not None
if op != "=" and (type_info.fullname, op) not in self.partial_type_augmented_ops:
return lvalue_type, None, False, False

partial_types = self.find_partial_types(var)
if partial_types is None:
return lvalue_type, None, False, False

rvalue_type = self.expr_checker.accept(rvalue)
rvalue_type = get_proper_type(rvalue_type)
if isinstance(rvalue_type, Instance):
if rvalue_type.type == type_info and is_valid_inferred_type(rvalue_type, self.options):
self.replace_partial_type(var, rvalue_type, partial_types)
return var.type, rvalue_type, False, False
elif isinstance(rvalue_type, AnyType):
self.replace_partial_type(var, fill_typevars_with_any(type_info), partial_types)
return var.type, rvalue_type, False, False

return lvalue_type, None, False, False

def _get_partial_assignment_target(self, expr: Expression) -> tuple[Var, PartialType] | None:
var: Var | None
if isinstance(expr, NameExpr) and isinstance(expr.node, Var):
var = expr.node
elif isinstance(expr, MemberExpr):
var = self.expr_checker.get_partial_self_var(expr)
elif isinstance(expr, RefExpr) and isinstance(expr.node, Var):
var = expr.node
else:
var = None

if isinstance(var, Var) and isinstance(var.type, PartialType):
return var, var.type
return None

def replace_partial_type(
self, var: Var, new_type: Type, partial_types: dict[Var, Context]
) -> None:
Expand All @@ -4819,41 +4841,33 @@ def replace_partial_type(
def try_infer_partial_type_from_indexed_assignment(
self, lvalue: IndexExpr, rvalue: Expression
) -> None:
# TODO: Should we share some of this with try_infer_partial_type?
var = None
if isinstance(lvalue.base, RefExpr) and isinstance(lvalue.base.node, Var):
var = lvalue.base.node
elif isinstance(lvalue.base, MemberExpr):
var = self.expr_checker.get_partial_self_var(lvalue.base)
if isinstance(var, Var):
if isinstance(var.type, PartialType):
type_type = var.type.type
if type_type is None:
return # The partial type is None.
partial_types = self.find_partial_types(var)
if partial_types is None:
return
typename = type_type.fullname
if (
typename == "builtins.dict"
or typename == "collections.OrderedDict"
or typename == "collections.defaultdict"
):
# TODO: Don't infer things twice.
key_type = self.expr_checker.accept(lvalue.index)
value_type = self.expr_checker.accept(rvalue)
if (
is_valid_inferred_type(key_type, self.options)
and is_valid_inferred_type(value_type, self.options)
and not self.current_node_deferred
and not (
typename == "collections.defaultdict"
and var.type.value_type is not None
and not is_equivalent(value_type, var.type.value_type)
)
):
new_type = self.named_generic_type(typename, [key_type, value_type])
self.replace_partial_type(var, new_type, partial_types)
target = self._get_partial_assignment_target(lvalue.base)
if target is None:
return
var, partial = target
type_info = partial.type
if type_info is None:
return # The partial type is None.
partial_types = self.find_partial_types(var)
if partial_types is None:
return
typename = type_info.fullname
if typename not in ("builtins.dict", "collections.OrderedDict", "collections.defaultdict"):
return
# TODO: Don't infer things twice.
key_type = self.expr_checker.accept(lvalue.index)
value_type = self.expr_checker.accept(rvalue)
if not (
is_valid_inferred_type(key_type, self.options)
and is_valid_inferred_type(value_type, self.options)
and not self.current_node_deferred
):
return
if typename == "collections.defaultdict" and partial.value_type is not None:
if not is_equivalent(value_type, partial.value_type):
return
new_type = self.named_generic_type(typename, [key_type, value_type])
self.replace_partial_type(var, new_type, partial_types)

def type_requires_usage(self, typ: Type) -> tuple[str, ErrorCode] | None:
"""Some types require usage in all cases. The classic example is
Expand Down Expand Up @@ -5069,13 +5083,23 @@ def visit_while_stmt(self, s: WhileStmt) -> None:

def visit_operator_assignment_stmt(self, s: OperatorAssignmentStmt) -> None:
"""Type check an operator assignment statement, e.g. x += 1."""
self.try_infer_partial_generic_type_from_assignment(s.lvalue, s.rvalue, s.op)
self.try_resolve_partial_type_from_assignment(
s.lvalue, None, s.rvalue, s.op, pre_check=True
)
if isinstance(s.lvalue, MemberExpr):
# Special case, some additional errors may be given for
# assignments to read-only or final attributes.
lvalue_type = self.expr_checker.visit_member_expr(s.lvalue, True)
else:
lvalue_type = self.expr_checker.accept(s.lvalue)
(resolved_lvalue_type, _, should_return, _) = (
self.try_resolve_partial_type_from_assignment(s.lvalue, lvalue_type, s.rvalue, s.op)
)
if resolved_lvalue_type is not None:
lvalue_type = resolved_lvalue_type
if should_return:
self.check_final(s)
return
inplace, method = infer_operator_assignment_method(lvalue_type, s.op)
if inplace:
# There is __ifoo__, treat as x = x.__ifoo__(y)
Expand Down