diff --git a/mypy/checker.py b/mypy/checker.py index 9f8299e6805d..7d8277b091c8 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3260,8 +3260,17 @@ def check_assignment( lvalue.items, rvalue, rvalue, infer_lvalue_type ) else: - self.try_infer_partial_generic_type_from_assignment(lvalue, rvalue, "=") + # Resolve partial generic targets before we analyze the lvalue, to mirror the + # legacy flow where this happened pre-check_lvalue. + self.try_resolve_partial_type_from_assignment( + lvalue, None, rvalue, "=", pre_check=True + ) lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue, rvalue) + (lvalue_type, override_rvalue_type, should_return, skip_simple_assignment) = ( + self.try_resolve_partial_type_from_assignment(lvalue, lvalue_type, rvalue, "=") + ) + if should_return: + return # If we're assigning to __getattr__ or similar methods, check that the signature is # valid. if isinstance(lvalue, NameExpr) and lvalue.node: @@ -3293,43 +3302,8 @@ def check_assignment( self.fail(message_registry.CANNOT_MODIFY_MATCH_ARGS, lvalue) if lvalue_type: - if isinstance(lvalue_type, PartialType) and lvalue_type.type is None: - # Try to infer a proper type for a variable with a partial None type. - rvalue_type = self.expr_checker.accept(rvalue) - if isinstance(get_proper_type(rvalue_type), NoneType): - # This doesn't actually provide any additional information -- multiple - # None initializers preserve the partial None type. - return - - var = lvalue_type.var - if is_valid_inferred_type( - rvalue_type, self.options, is_lvalue_final=var.is_final - ): - partial_types = self.find_partial_types(var) - if partial_types is not None: - if not self.current_node_deferred: - # Partial type can't be final, so strip any literal values. - rvalue_type = remove_instance_last_known_values(rvalue_type) - inferred_type = make_simplified_union([rvalue_type, NoneType()]) - self.set_inferred_type(var, lvalue, inferred_type) - else: - var.type = None - del partial_types[var] - lvalue_type = var.type - else: - # Try to infer a partial type. - if not self.infer_partial_type(var, lvalue, rvalue_type): - # If that also failed, give up and let the caller know that we - # cannot read their mind. The definition site will be reported later. - # Calling .put() directly because the newly inferred type is - # not a subtype of None - we are not looking for narrowing - fallback = self.inference_error_fallback_type(rvalue_type) - self.binder.put(lvalue, fallback) - # Same as self.set_inference_error_fallback_type but inlined - # to avoid computing fallback twice. - # We are replacing partial now, so the variable type - # should remain optional. - self.set_inferred_type(var, lvalue, make_optional_type(fallback)) + if skip_simple_assignment: + rvalue_type = override_rvalue_type elif ( is_literal_none(rvalue) and isinstance(lvalue, NameExpr) @@ -3483,54 +3457,6 @@ def get_variable_type_context(self, inferred: Var, rvalue: Expression) -> Type | return None return candidate - def try_infer_partial_generic_type_from_assignment( - self, lvalue: Lvalue, rvalue: Expression, op: str - ) -> None: - """Try to infer a precise type for partial generic type from assignment. - - 'op' is '=' for normal assignment and a binary operator ('+', ...) for - augmented assignment. - - Example where this happens: - - x = [] - if foo(): - x = [1] # Infer List[int] as type of 'x' - """ - var = None - if ( - isinstance(lvalue, NameExpr) - and isinstance(lvalue.node, Var) - and isinstance(lvalue.node.type, PartialType) - ): - var = lvalue.node - elif isinstance(lvalue, MemberExpr): - var = self.expr_checker.get_partial_self_var(lvalue) - if var is not None: - typ = var.type - assert isinstance(typ, PartialType) - if typ.type is None: - return - # Return if this is an unsupported augmented assignment. - if op != "=" and (typ.type.fullname, op) not in self.partial_type_augmented_ops: - return - # TODO: some logic here duplicates the None partial type counterpart - # inlined in check_assignment(), see #8043. - partial_types = self.find_partial_types(var) - if partial_types is None: - return - rvalue_type = self.expr_checker.accept(rvalue) - rvalue_type = get_proper_type(rvalue_type) - if isinstance(rvalue_type, Instance): - if rvalue_type.type == typ.type and is_valid_inferred_type( - rvalue_type, self.options - ): - var.type = rvalue_type - del partial_types[var] - elif isinstance(rvalue_type, AnyType): - var.type = fill_typevars_with_any(typ.type) - del partial_types[var] - def check_compatibility_all_supers(self, lvalue: RefExpr, rvalue: Expression) -> None: lvalue_node = lvalue.node # Check if we are a class variable with at least one base class @@ -4801,6 +4727,102 @@ def check_indexed_assignment( if isinstance(res_type, UninhabitedType) and not res_type.ambiguous: self.binder.unreachable() + def try_resolve_partial_type_from_assignment( + self, + lvalue: Lvalue, + lvalue_type: Type | None, + rvalue: Expression, + op: str, + *, + pre_check: bool = False, + ) -> tuple[Type | None, Type | None, bool, bool]: + """Refine partial types assigned to by this assignment. + + Returns (new_lvalue_type, rvalue_type_override, should_return, skip_simple_assignment). + If skip_simple_assignment is True, the caller should avoid performing the regular + check_simple_assignment logic, since we've already handled the partial type case. + """ + target = self._get_partial_assignment_target(lvalue) + if target is None: + return lvalue_type, None, False, False + + var, partial = target + if partial.type is None: + if pre_check: + return lvalue_type, None, False, False + rvalue_type = self.expr_checker.accept(rvalue) + proper_rvalue = get_proper_type(rvalue_type) + if isinstance(proper_rvalue, NoneType): + # This doesn't actually provide any additional information -- multiple + # None initializers preserve the partial None type. + return lvalue_type, rvalue_type, True, True + + if is_valid_inferred_type(rvalue_type, self.options, is_lvalue_final=var.is_final): + partial_types = self.find_partial_types(var) + if partial_types is None: + return lvalue_type, rvalue_type, False, True + if not self.current_node_deferred: + # Partial type can't be final, so strip any literal values. + rvalue_type = remove_instance_last_known_values(rvalue_type) + inferred_type = make_simplified_union([rvalue_type, NoneType()]) + self.set_inferred_type(var, lvalue, inferred_type) + else: + var.type = None + del partial_types[var] + return var.type, rvalue_type, False, True + + # Try to infer a partial type. + if not self.infer_partial_type(var, lvalue, rvalue_type): + # If that also failed, give up and let the caller know that we + # cannot read their mind. The definition site will be reported later. + # Calling .put() directly because the newly inferred type is + # not a subtype of None - we are not looking for narrowing + fallback = self.inference_error_fallback_type(rvalue_type) + self.binder.put(lvalue, fallback) + # Same as self.set_inference_error_fallback_type but inlined + # to avoid computing fallback twice. + # We are replacing partial now, so the variable type + # should remain optional. + self.set_inferred_type(var, lvalue, make_optional_type(fallback)) + return lvalue_type, fallback, False, True + return var.type, rvalue_type, False, True + + type_info = partial.type + assert type_info is not None + if op != "=" and (type_info.fullname, op) not in self.partial_type_augmented_ops: + return lvalue_type, None, False, False + + partial_types = self.find_partial_types(var) + if partial_types is None: + return lvalue_type, None, False, False + + rvalue_type = self.expr_checker.accept(rvalue) + rvalue_type = get_proper_type(rvalue_type) + if isinstance(rvalue_type, Instance): + if rvalue_type.type == type_info and is_valid_inferred_type(rvalue_type, self.options): + self.replace_partial_type(var, rvalue_type, partial_types) + return var.type, rvalue_type, False, False + elif isinstance(rvalue_type, AnyType): + self.replace_partial_type(var, fill_typevars_with_any(type_info), partial_types) + return var.type, rvalue_type, False, False + + return lvalue_type, None, False, False + + def _get_partial_assignment_target(self, expr: Expression) -> tuple[Var, PartialType] | None: + var: Var | None + if isinstance(expr, NameExpr) and isinstance(expr.node, Var): + var = expr.node + elif isinstance(expr, MemberExpr): + var = self.expr_checker.get_partial_self_var(expr) + elif isinstance(expr, RefExpr) and isinstance(expr.node, Var): + var = expr.node + else: + var = None + + if isinstance(var, Var) and isinstance(var.type, PartialType): + return var, var.type + return None + def replace_partial_type( self, var: Var, new_type: Type, partial_types: dict[Var, Context] ) -> None: @@ -4819,41 +4841,33 @@ def replace_partial_type( def try_infer_partial_type_from_indexed_assignment( self, lvalue: IndexExpr, rvalue: Expression ) -> None: - # TODO: Should we share some of this with try_infer_partial_type? - var = None - if isinstance(lvalue.base, RefExpr) and isinstance(lvalue.base.node, Var): - var = lvalue.base.node - elif isinstance(lvalue.base, MemberExpr): - var = self.expr_checker.get_partial_self_var(lvalue.base) - if isinstance(var, Var): - if isinstance(var.type, PartialType): - type_type = var.type.type - if type_type is None: - return # The partial type is None. - partial_types = self.find_partial_types(var) - if partial_types is None: - return - typename = type_type.fullname - if ( - typename == "builtins.dict" - or typename == "collections.OrderedDict" - or typename == "collections.defaultdict" - ): - # TODO: Don't infer things twice. - key_type = self.expr_checker.accept(lvalue.index) - value_type = self.expr_checker.accept(rvalue) - if ( - is_valid_inferred_type(key_type, self.options) - and is_valid_inferred_type(value_type, self.options) - and not self.current_node_deferred - and not ( - typename == "collections.defaultdict" - and var.type.value_type is not None - and not is_equivalent(value_type, var.type.value_type) - ) - ): - new_type = self.named_generic_type(typename, [key_type, value_type]) - self.replace_partial_type(var, new_type, partial_types) + target = self._get_partial_assignment_target(lvalue.base) + if target is None: + return + var, partial = target + type_info = partial.type + if type_info is None: + return # The partial type is None. + partial_types = self.find_partial_types(var) + if partial_types is None: + return + typename = type_info.fullname + if typename not in ("builtins.dict", "collections.OrderedDict", "collections.defaultdict"): + return + # TODO: Don't infer things twice. + key_type = self.expr_checker.accept(lvalue.index) + value_type = self.expr_checker.accept(rvalue) + if not ( + is_valid_inferred_type(key_type, self.options) + and is_valid_inferred_type(value_type, self.options) + and not self.current_node_deferred + ): + return + if typename == "collections.defaultdict" and partial.value_type is not None: + if not is_equivalent(value_type, partial.value_type): + return + new_type = self.named_generic_type(typename, [key_type, value_type]) + self.replace_partial_type(var, new_type, partial_types) def type_requires_usage(self, typ: Type) -> tuple[str, ErrorCode] | None: """Some types require usage in all cases. The classic example is @@ -5069,13 +5083,23 @@ def visit_while_stmt(self, s: WhileStmt) -> None: def visit_operator_assignment_stmt(self, s: OperatorAssignmentStmt) -> None: """Type check an operator assignment statement, e.g. x += 1.""" - self.try_infer_partial_generic_type_from_assignment(s.lvalue, s.rvalue, s.op) + self.try_resolve_partial_type_from_assignment( + s.lvalue, None, s.rvalue, s.op, pre_check=True + ) if isinstance(s.lvalue, MemberExpr): # Special case, some additional errors may be given for # assignments to read-only or final attributes. lvalue_type = self.expr_checker.visit_member_expr(s.lvalue, True) else: lvalue_type = self.expr_checker.accept(s.lvalue) + (resolved_lvalue_type, _, should_return, _) = ( + self.try_resolve_partial_type_from_assignment(s.lvalue, lvalue_type, s.rvalue, s.op) + ) + if resolved_lvalue_type is not None: + lvalue_type = resolved_lvalue_type + if should_return: + self.check_final(s) + return inplace, method = infer_operator_assignment_method(lvalue_type, s.op) if inplace: # There is __ifoo__, treat as x = x.__ifoo__(y)