Skip to content

Commit dc75e68

Browse files
Refactor partial type inference to reuse shared logic
1 parent 45aa599 commit dc75e68

File tree

1 file changed

+158
-122
lines changed

1 file changed

+158
-122
lines changed

mypy/checker.py

Lines changed: 158 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -3260,8 +3260,20 @@ def check_assignment(
32603260
lvalue.items, rvalue, rvalue, infer_lvalue_type
32613261
)
32623262
else:
3263-
self.try_infer_partial_generic_type_from_assignment(lvalue, rvalue, "=")
3263+
# Resolve partial generic targets before we analyze the lvalue, to mirror the
3264+
# legacy flow where this happened pre-check_lvalue.
3265+
self.try_resolve_partial_type_from_assignment(
3266+
lvalue, None, rvalue, "=", pre_check=True
3267+
)
32643268
lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue, rvalue)
3269+
(
3270+
lvalue_type,
3271+
override_rvalue_type,
3272+
should_return,
3273+
skip_simple_assignment,
3274+
) = self.try_resolve_partial_type_from_assignment(lvalue, lvalue_type, rvalue, "=")
3275+
if should_return:
3276+
return
32653277
# If we're assigning to __getattr__ or similar methods, check that the signature is
32663278
# valid.
32673279
if isinstance(lvalue, NameExpr) and lvalue.node:
@@ -3293,43 +3305,8 @@ def check_assignment(
32933305
self.fail(message_registry.CANNOT_MODIFY_MATCH_ARGS, lvalue)
32943306

32953307
if lvalue_type:
3296-
if isinstance(lvalue_type, PartialType) and lvalue_type.type is None:
3297-
# Try to infer a proper type for a variable with a partial None type.
3298-
rvalue_type = self.expr_checker.accept(rvalue)
3299-
if isinstance(get_proper_type(rvalue_type), NoneType):
3300-
# This doesn't actually provide any additional information -- multiple
3301-
# None initializers preserve the partial None type.
3302-
return
3303-
3304-
var = lvalue_type.var
3305-
if is_valid_inferred_type(
3306-
rvalue_type, self.options, is_lvalue_final=var.is_final
3307-
):
3308-
partial_types = self.find_partial_types(var)
3309-
if partial_types is not None:
3310-
if not self.current_node_deferred:
3311-
# Partial type can't be final, so strip any literal values.
3312-
rvalue_type = remove_instance_last_known_values(rvalue_type)
3313-
inferred_type = make_simplified_union([rvalue_type, NoneType()])
3314-
self.set_inferred_type(var, lvalue, inferred_type)
3315-
else:
3316-
var.type = None
3317-
del partial_types[var]
3318-
lvalue_type = var.type
3319-
else:
3320-
# Try to infer a partial type.
3321-
if not self.infer_partial_type(var, lvalue, rvalue_type):
3322-
# If that also failed, give up and let the caller know that we
3323-
# cannot read their mind. The definition site will be reported later.
3324-
# Calling .put() directly because the newly inferred type is
3325-
# not a subtype of None - we are not looking for narrowing
3326-
fallback = self.inference_error_fallback_type(rvalue_type)
3327-
self.binder.put(lvalue, fallback)
3328-
# Same as self.set_inference_error_fallback_type but inlined
3329-
# to avoid computing fallback twice.
3330-
# We are replacing partial<None> now, so the variable type
3331-
# should remain optional.
3332-
self.set_inferred_type(var, lvalue, make_optional_type(fallback))
3308+
if skip_simple_assignment:
3309+
rvalue_type = override_rvalue_type
33333310
elif (
33343311
is_literal_none(rvalue)
33353312
and isinstance(lvalue, NameExpr)
@@ -3483,54 +3460,6 @@ def get_variable_type_context(self, inferred: Var, rvalue: Expression) -> Type |
34833460
return None
34843461
return candidate
34853462

3486-
def try_infer_partial_generic_type_from_assignment(
3487-
self, lvalue: Lvalue, rvalue: Expression, op: str
3488-
) -> None:
3489-
"""Try to infer a precise type for partial generic type from assignment.
3490-
3491-
'op' is '=' for normal assignment and a binary operator ('+', ...) for
3492-
augmented assignment.
3493-
3494-
Example where this happens:
3495-
3496-
x = []
3497-
if foo():
3498-
x = [1] # Infer List[int] as type of 'x'
3499-
"""
3500-
var = None
3501-
if (
3502-
isinstance(lvalue, NameExpr)
3503-
and isinstance(lvalue.node, Var)
3504-
and isinstance(lvalue.node.type, PartialType)
3505-
):
3506-
var = lvalue.node
3507-
elif isinstance(lvalue, MemberExpr):
3508-
var = self.expr_checker.get_partial_self_var(lvalue)
3509-
if var is not None:
3510-
typ = var.type
3511-
assert isinstance(typ, PartialType)
3512-
if typ.type is None:
3513-
return
3514-
# Return if this is an unsupported augmented assignment.
3515-
if op != "=" and (typ.type.fullname, op) not in self.partial_type_augmented_ops:
3516-
return
3517-
# TODO: some logic here duplicates the None partial type counterpart
3518-
# inlined in check_assignment(), see #8043.
3519-
partial_types = self.find_partial_types(var)
3520-
if partial_types is None:
3521-
return
3522-
rvalue_type = self.expr_checker.accept(rvalue)
3523-
rvalue_type = get_proper_type(rvalue_type)
3524-
if isinstance(rvalue_type, Instance):
3525-
if rvalue_type.type == typ.type and is_valid_inferred_type(
3526-
rvalue_type, self.options
3527-
):
3528-
var.type = rvalue_type
3529-
del partial_types[var]
3530-
elif isinstance(rvalue_type, AnyType):
3531-
var.type = fill_typevars_with_any(typ.type)
3532-
del partial_types[var]
3533-
35343463
def check_compatibility_all_supers(self, lvalue: RefExpr, rvalue: Expression) -> None:
35353464
lvalue_node = lvalue.node
35363465
# Check if we are a class variable with at least one base class
@@ -4801,6 +4730,104 @@ def check_indexed_assignment(
48014730
if isinstance(res_type, UninhabitedType) and not res_type.ambiguous:
48024731
self.binder.unreachable()
48034732

4733+
def try_resolve_partial_type_from_assignment(
4734+
self,
4735+
lvalue: Lvalue,
4736+
lvalue_type: Type | None,
4737+
rvalue: Expression,
4738+
op: str,
4739+
*,
4740+
pre_check: bool = False,
4741+
) -> tuple[Type | None, Type | None, bool, bool]:
4742+
"""Refine partial types assigned to by this assignment.
4743+
4744+
Returns (new_lvalue_type, rvalue_type_override, should_return, skip_simple_assignment).
4745+
If skip_simple_assignment is True, the caller should avoid performing the regular
4746+
check_simple_assignment logic, since we've already handled the partial type case.
4747+
"""
4748+
target = self._get_partial_assignment_target(lvalue)
4749+
if target is None:
4750+
return lvalue_type, None, False, False
4751+
4752+
var, partial = target
4753+
if partial.type is None:
4754+
if pre_check:
4755+
return lvalue_type, None, False, False
4756+
rvalue_type = self.expr_checker.accept(rvalue)
4757+
proper_rvalue = get_proper_type(rvalue_type)
4758+
if isinstance(proper_rvalue, NoneType):
4759+
# This doesn't actually provide any additional information -- multiple
4760+
# None initializers preserve the partial None type.
4761+
return lvalue_type, rvalue_type, True, True
4762+
4763+
if is_valid_inferred_type(
4764+
rvalue_type, self.options, is_lvalue_final=var.is_final
4765+
):
4766+
partial_types = self.find_partial_types(var)
4767+
if partial_types is None:
4768+
return lvalue_type, rvalue_type, False, True
4769+
if not self.current_node_deferred:
4770+
# Partial type can't be final, so strip any literal values.
4771+
rvalue_type = remove_instance_last_known_values(rvalue_type)
4772+
inferred_type = make_simplified_union([rvalue_type, NoneType()])
4773+
self.set_inferred_type(var, lvalue, inferred_type)
4774+
else:
4775+
var.type = None
4776+
del partial_types[var]
4777+
return var.type, rvalue_type, False, True
4778+
4779+
# Try to infer a partial type.
4780+
if not self.infer_partial_type(var, lvalue, rvalue_type):
4781+
# If that also failed, give up and let the caller know that we
4782+
# cannot read their mind. The definition site will be reported later.
4783+
# Calling .put() directly because the newly inferred type is
4784+
# not a subtype of None - we are not looking for narrowing
4785+
fallback = self.inference_error_fallback_type(rvalue_type)
4786+
self.binder.put(lvalue, fallback)
4787+
# Same as self.set_inference_error_fallback_type but inlined
4788+
# to avoid computing fallback twice.
4789+
# We are replacing partial<None> now, so the variable type
4790+
# should remain optional.
4791+
self.set_inferred_type(var, lvalue, make_optional_type(fallback))
4792+
return lvalue_type, fallback, False, True
4793+
return var.type, rvalue_type, False, True
4794+
4795+
type_info = partial.type
4796+
assert type_info is not None
4797+
if op != "=" and (type_info.fullname, op) not in self.partial_type_augmented_ops:
4798+
return lvalue_type, None, False, False
4799+
4800+
partial_types = self.find_partial_types(var)
4801+
if partial_types is None:
4802+
return lvalue_type, None, False, False
4803+
4804+
rvalue_type = self.expr_checker.accept(rvalue)
4805+
rvalue_type = get_proper_type(rvalue_type)
4806+
if isinstance(rvalue_type, Instance):
4807+
if rvalue_type.type == type_info and is_valid_inferred_type(rvalue_type, self.options):
4808+
self.replace_partial_type(var, rvalue_type, partial_types)
4809+
return var.type, rvalue_type, False, False
4810+
elif isinstance(rvalue_type, AnyType):
4811+
self.replace_partial_type(var, fill_typevars_with_any(type_info), partial_types)
4812+
return var.type, rvalue_type, False, False
4813+
4814+
return lvalue_type, None, False, False
4815+
4816+
def _get_partial_assignment_target(self, expr: Expression) -> tuple[Var, PartialType] | None:
4817+
var: Var | None
4818+
if isinstance(expr, NameExpr) and isinstance(expr.node, Var):
4819+
var = expr.node
4820+
elif isinstance(expr, MemberExpr):
4821+
var = self.expr_checker.get_partial_self_var(expr)
4822+
elif isinstance(expr, RefExpr) and isinstance(expr.node, Var):
4823+
var = expr.node
4824+
else:
4825+
var = None
4826+
4827+
if isinstance(var, Var) and isinstance(var.type, PartialType):
4828+
return var, var.type
4829+
return None
4830+
48044831
def replace_partial_type(
48054832
self, var: Var, new_type: Type, partial_types: dict[Var, Context]
48064833
) -> None:
@@ -4819,41 +4846,37 @@ def replace_partial_type(
48194846
def try_infer_partial_type_from_indexed_assignment(
48204847
self, lvalue: IndexExpr, rvalue: Expression
48214848
) -> None:
4822-
# TODO: Should we share some of this with try_infer_partial_type?
4823-
var = None
4824-
if isinstance(lvalue.base, RefExpr) and isinstance(lvalue.base.node, Var):
4825-
var = lvalue.base.node
4826-
elif isinstance(lvalue.base, MemberExpr):
4827-
var = self.expr_checker.get_partial_self_var(lvalue.base)
4828-
if isinstance(var, Var):
4829-
if isinstance(var.type, PartialType):
4830-
type_type = var.type.type
4831-
if type_type is None:
4832-
return # The partial type is None.
4833-
partial_types = self.find_partial_types(var)
4834-
if partial_types is None:
4835-
return
4836-
typename = type_type.fullname
4837-
if (
4838-
typename == "builtins.dict"
4839-
or typename == "collections.OrderedDict"
4840-
or typename == "collections.defaultdict"
4841-
):
4842-
# TODO: Don't infer things twice.
4843-
key_type = self.expr_checker.accept(lvalue.index)
4844-
value_type = self.expr_checker.accept(rvalue)
4845-
if (
4846-
is_valid_inferred_type(key_type, self.options)
4847-
and is_valid_inferred_type(value_type, self.options)
4848-
and not self.current_node_deferred
4849-
and not (
4850-
typename == "collections.defaultdict"
4851-
and var.type.value_type is not None
4852-
and not is_equivalent(value_type, var.type.value_type)
4853-
)
4854-
):
4855-
new_type = self.named_generic_type(typename, [key_type, value_type])
4856-
self.replace_partial_type(var, new_type, partial_types)
4849+
target = self._get_partial_assignment_target(lvalue.base)
4850+
if target is None:
4851+
return
4852+
var, partial = target
4853+
type_info = partial.type
4854+
if type_info is None:
4855+
return # The partial type is None.
4856+
partial_types = self.find_partial_types(var)
4857+
if partial_types is None:
4858+
return
4859+
typename = type_info.fullname
4860+
if typename not in (
4861+
"builtins.dict",
4862+
"collections.OrderedDict",
4863+
"collections.defaultdict",
4864+
):
4865+
return
4866+
# TODO: Don't infer things twice.
4867+
key_type = self.expr_checker.accept(lvalue.index)
4868+
value_type = self.expr_checker.accept(rvalue)
4869+
if not (
4870+
is_valid_inferred_type(key_type, self.options)
4871+
and is_valid_inferred_type(value_type, self.options)
4872+
and not self.current_node_deferred
4873+
):
4874+
return
4875+
if typename == "collections.defaultdict" and partial.value_type is not None:
4876+
if not is_equivalent(value_type, partial.value_type):
4877+
return
4878+
new_type = self.named_generic_type(typename, [key_type, value_type])
4879+
self.replace_partial_type(var, new_type, partial_types)
48574880

48584881
def type_requires_usage(self, typ: Type) -> tuple[str, ErrorCode] | None:
48594882
"""Some types require usage in all cases. The classic example is
@@ -5069,13 +5092,26 @@ def visit_while_stmt(self, s: WhileStmt) -> None:
50695092

50705093
def visit_operator_assignment_stmt(self, s: OperatorAssignmentStmt) -> None:
50715094
"""Type check an operator assignment statement, e.g. x += 1."""
5072-
self.try_infer_partial_generic_type_from_assignment(s.lvalue, s.rvalue, s.op)
5095+
self.try_resolve_partial_type_from_assignment(
5096+
s.lvalue, None, s.rvalue, s.op, pre_check=True
5097+
)
50735098
if isinstance(s.lvalue, MemberExpr):
50745099
# Special case, some additional errors may be given for
50755100
# assignments to read-only or final attributes.
50765101
lvalue_type = self.expr_checker.visit_member_expr(s.lvalue, True)
50775102
else:
50785103
lvalue_type = self.expr_checker.accept(s.lvalue)
5104+
(
5105+
resolved_lvalue_type,
5106+
_,
5107+
should_return,
5108+
_,
5109+
) = self.try_resolve_partial_type_from_assignment(s.lvalue, lvalue_type, s.rvalue, s.op)
5110+
if resolved_lvalue_type is not None:
5111+
lvalue_type = resolved_lvalue_type
5112+
if should_return:
5113+
self.check_final(s)
5114+
return
50795115
inplace, method = infer_operator_assignment_method(lvalue_type, s.op)
50805116
if inplace:
50815117
# There is __ifoo__, treat as x = x.__ifoo__(y)

0 commit comments

Comments
 (0)