@@ -3260,8 +3260,20 @@ def check_assignment(
32603260 lvalue .items , rvalue , rvalue , infer_lvalue_type
32613261 )
32623262 else :
3263- self .try_infer_partial_generic_type_from_assignment (lvalue , rvalue , "=" )
3263+ # Resolve partial generic targets before we analyze the lvalue, to mirror the
3264+ # legacy flow where this happened pre-check_lvalue.
3265+ self .try_resolve_partial_type_from_assignment (
3266+ lvalue , None , rvalue , "=" , pre_check = True
3267+ )
32643268 lvalue_type , index_lvalue , inferred = self .check_lvalue (lvalue , rvalue )
3269+ (
3270+ lvalue_type ,
3271+ override_rvalue_type ,
3272+ should_return ,
3273+ skip_simple_assignment ,
3274+ ) = self .try_resolve_partial_type_from_assignment (lvalue , lvalue_type , rvalue , "=" )
3275+ if should_return :
3276+ return
32653277 # If we're assigning to __getattr__ or similar methods, check that the signature is
32663278 # valid.
32673279 if isinstance (lvalue , NameExpr ) and lvalue .node :
@@ -3293,43 +3305,8 @@ def check_assignment(
32933305 self .fail (message_registry .CANNOT_MODIFY_MATCH_ARGS , lvalue )
32943306
32953307 if lvalue_type :
3296- if isinstance (lvalue_type , PartialType ) and lvalue_type .type is None :
3297- # Try to infer a proper type for a variable with a partial None type.
3298- rvalue_type = self .expr_checker .accept (rvalue )
3299- if isinstance (get_proper_type (rvalue_type ), NoneType ):
3300- # This doesn't actually provide any additional information -- multiple
3301- # None initializers preserve the partial None type.
3302- return
3303-
3304- var = lvalue_type .var
3305- if is_valid_inferred_type (
3306- rvalue_type , self .options , is_lvalue_final = var .is_final
3307- ):
3308- partial_types = self .find_partial_types (var )
3309- if partial_types is not None :
3310- if not self .current_node_deferred :
3311- # Partial type can't be final, so strip any literal values.
3312- rvalue_type = remove_instance_last_known_values (rvalue_type )
3313- inferred_type = make_simplified_union ([rvalue_type , NoneType ()])
3314- self .set_inferred_type (var , lvalue , inferred_type )
3315- else :
3316- var .type = None
3317- del partial_types [var ]
3318- lvalue_type = var .type
3319- else :
3320- # Try to infer a partial type.
3321- if not self .infer_partial_type (var , lvalue , rvalue_type ):
3322- # If that also failed, give up and let the caller know that we
3323- # cannot read their mind. The definition site will be reported later.
3324- # Calling .put() directly because the newly inferred type is
3325- # not a subtype of None - we are not looking for narrowing
3326- fallback = self .inference_error_fallback_type (rvalue_type )
3327- self .binder .put (lvalue , fallback )
3328- # Same as self.set_inference_error_fallback_type but inlined
3329- # to avoid computing fallback twice.
3330- # We are replacing partial<None> now, so the variable type
3331- # should remain optional.
3332- self .set_inferred_type (var , lvalue , make_optional_type (fallback ))
3308+ if skip_simple_assignment :
3309+ rvalue_type = override_rvalue_type
33333310 elif (
33343311 is_literal_none (rvalue )
33353312 and isinstance (lvalue , NameExpr )
@@ -3483,54 +3460,6 @@ def get_variable_type_context(self, inferred: Var, rvalue: Expression) -> Type |
34833460 return None
34843461 return candidate
34853462
3486- def try_infer_partial_generic_type_from_assignment (
3487- self , lvalue : Lvalue , rvalue : Expression , op : str
3488- ) -> None :
3489- """Try to infer a precise type for partial generic type from assignment.
3490-
3491- 'op' is '=' for normal assignment and a binary operator ('+', ...) for
3492- augmented assignment.
3493-
3494- Example where this happens:
3495-
3496- x = []
3497- if foo():
3498- x = [1] # Infer List[int] as type of 'x'
3499- """
3500- var = None
3501- if (
3502- isinstance (lvalue , NameExpr )
3503- and isinstance (lvalue .node , Var )
3504- and isinstance (lvalue .node .type , PartialType )
3505- ):
3506- var = lvalue .node
3507- elif isinstance (lvalue , MemberExpr ):
3508- var = self .expr_checker .get_partial_self_var (lvalue )
3509- if var is not None :
3510- typ = var .type
3511- assert isinstance (typ , PartialType )
3512- if typ .type is None :
3513- return
3514- # Return if this is an unsupported augmented assignment.
3515- if op != "=" and (typ .type .fullname , op ) not in self .partial_type_augmented_ops :
3516- return
3517- # TODO: some logic here duplicates the None partial type counterpart
3518- # inlined in check_assignment(), see #8043.
3519- partial_types = self .find_partial_types (var )
3520- if partial_types is None :
3521- return
3522- rvalue_type = self .expr_checker .accept (rvalue )
3523- rvalue_type = get_proper_type (rvalue_type )
3524- if isinstance (rvalue_type , Instance ):
3525- if rvalue_type .type == typ .type and is_valid_inferred_type (
3526- rvalue_type , self .options
3527- ):
3528- var .type = rvalue_type
3529- del partial_types [var ]
3530- elif isinstance (rvalue_type , AnyType ):
3531- var .type = fill_typevars_with_any (typ .type )
3532- del partial_types [var ]
3533-
35343463 def check_compatibility_all_supers (self , lvalue : RefExpr , rvalue : Expression ) -> None :
35353464 lvalue_node = lvalue .node
35363465 # Check if we are a class variable with at least one base class
@@ -4801,6 +4730,104 @@ def check_indexed_assignment(
48014730 if isinstance (res_type , UninhabitedType ) and not res_type .ambiguous :
48024731 self .binder .unreachable ()
48034732
4733+ def try_resolve_partial_type_from_assignment (
4734+ self ,
4735+ lvalue : Lvalue ,
4736+ lvalue_type : Type | None ,
4737+ rvalue : Expression ,
4738+ op : str ,
4739+ * ,
4740+ pre_check : bool = False ,
4741+ ) -> tuple [Type | None , Type | None , bool , bool ]:
4742+ """Refine partial types assigned to by this assignment.
4743+
4744+ Returns (new_lvalue_type, rvalue_type_override, should_return, skip_simple_assignment).
4745+ If skip_simple_assignment is True, the caller should avoid performing the regular
4746+ check_simple_assignment logic, since we've already handled the partial type case.
4747+ """
4748+ target = self ._get_partial_assignment_target (lvalue )
4749+ if target is None :
4750+ return lvalue_type , None , False , False
4751+
4752+ var , partial = target
4753+ if partial .type is None :
4754+ if pre_check :
4755+ return lvalue_type , None , False , False
4756+ rvalue_type = self .expr_checker .accept (rvalue )
4757+ proper_rvalue = get_proper_type (rvalue_type )
4758+ if isinstance (proper_rvalue , NoneType ):
4759+ # This doesn't actually provide any additional information -- multiple
4760+ # None initializers preserve the partial None type.
4761+ return lvalue_type , rvalue_type , True , True
4762+
4763+ if is_valid_inferred_type (
4764+ rvalue_type , self .options , is_lvalue_final = var .is_final
4765+ ):
4766+ partial_types = self .find_partial_types (var )
4767+ if partial_types is None :
4768+ return lvalue_type , rvalue_type , False , True
4769+ if not self .current_node_deferred :
4770+ # Partial type can't be final, so strip any literal values.
4771+ rvalue_type = remove_instance_last_known_values (rvalue_type )
4772+ inferred_type = make_simplified_union ([rvalue_type , NoneType ()])
4773+ self .set_inferred_type (var , lvalue , inferred_type )
4774+ else :
4775+ var .type = None
4776+ del partial_types [var ]
4777+ return var .type , rvalue_type , False , True
4778+
4779+ # Try to infer a partial type.
4780+ if not self .infer_partial_type (var , lvalue , rvalue_type ):
4781+ # If that also failed, give up and let the caller know that we
4782+ # cannot read their mind. The definition site will be reported later.
4783+ # Calling .put() directly because the newly inferred type is
4784+ # not a subtype of None - we are not looking for narrowing
4785+ fallback = self .inference_error_fallback_type (rvalue_type )
4786+ self .binder .put (lvalue , fallback )
4787+ # Same as self.set_inference_error_fallback_type but inlined
4788+ # to avoid computing fallback twice.
4789+ # We are replacing partial<None> now, so the variable type
4790+ # should remain optional.
4791+ self .set_inferred_type (var , lvalue , make_optional_type (fallback ))
4792+ return lvalue_type , fallback , False , True
4793+ return var .type , rvalue_type , False , True
4794+
4795+ type_info = partial .type
4796+ assert type_info is not None
4797+ if op != "=" and (type_info .fullname , op ) not in self .partial_type_augmented_ops :
4798+ return lvalue_type , None , False , False
4799+
4800+ partial_types = self .find_partial_types (var )
4801+ if partial_types is None :
4802+ return lvalue_type , None , False , False
4803+
4804+ rvalue_type = self .expr_checker .accept (rvalue )
4805+ rvalue_type = get_proper_type (rvalue_type )
4806+ if isinstance (rvalue_type , Instance ):
4807+ if rvalue_type .type == type_info and is_valid_inferred_type (rvalue_type , self .options ):
4808+ self .replace_partial_type (var , rvalue_type , partial_types )
4809+ return var .type , rvalue_type , False , False
4810+ elif isinstance (rvalue_type , AnyType ):
4811+ self .replace_partial_type (var , fill_typevars_with_any (type_info ), partial_types )
4812+ return var .type , rvalue_type , False , False
4813+
4814+ return lvalue_type , None , False , False
4815+
4816+ def _get_partial_assignment_target (self , expr : Expression ) -> tuple [Var , PartialType ] | None :
4817+ var : Var | None
4818+ if isinstance (expr , NameExpr ) and isinstance (expr .node , Var ):
4819+ var = expr .node
4820+ elif isinstance (expr , MemberExpr ):
4821+ var = self .expr_checker .get_partial_self_var (expr )
4822+ elif isinstance (expr , RefExpr ) and isinstance (expr .node , Var ):
4823+ var = expr .node
4824+ else :
4825+ var = None
4826+
4827+ if isinstance (var , Var ) and isinstance (var .type , PartialType ):
4828+ return var , var .type
4829+ return None
4830+
48044831 def replace_partial_type (
48054832 self , var : Var , new_type : Type , partial_types : dict [Var , Context ]
48064833 ) -> None :
@@ -4819,41 +4846,37 @@ def replace_partial_type(
48194846 def try_infer_partial_type_from_indexed_assignment (
48204847 self , lvalue : IndexExpr , rvalue : Expression
48214848 ) -> None :
4822- # TODO: Should we share some of this with try_infer_partial_type?
4823- var = None
4824- if isinstance (lvalue .base , RefExpr ) and isinstance (lvalue .base .node , Var ):
4825- var = lvalue .base .node
4826- elif isinstance (lvalue .base , MemberExpr ):
4827- var = self .expr_checker .get_partial_self_var (lvalue .base )
4828- if isinstance (var , Var ):
4829- if isinstance (var .type , PartialType ):
4830- type_type = var .type .type
4831- if type_type is None :
4832- return # The partial type is None.
4833- partial_types = self .find_partial_types (var )
4834- if partial_types is None :
4835- return
4836- typename = type_type .fullname
4837- if (
4838- typename == "builtins.dict"
4839- or typename == "collections.OrderedDict"
4840- or typename == "collections.defaultdict"
4841- ):
4842- # TODO: Don't infer things twice.
4843- key_type = self .expr_checker .accept (lvalue .index )
4844- value_type = self .expr_checker .accept (rvalue )
4845- if (
4846- is_valid_inferred_type (key_type , self .options )
4847- and is_valid_inferred_type (value_type , self .options )
4848- and not self .current_node_deferred
4849- and not (
4850- typename == "collections.defaultdict"
4851- and var .type .value_type is not None
4852- and not is_equivalent (value_type , var .type .value_type )
4853- )
4854- ):
4855- new_type = self .named_generic_type (typename , [key_type , value_type ])
4856- self .replace_partial_type (var , new_type , partial_types )
4849+ target = self ._get_partial_assignment_target (lvalue .base )
4850+ if target is None :
4851+ return
4852+ var , partial = target
4853+ type_info = partial .type
4854+ if type_info is None :
4855+ return # The partial type is None.
4856+ partial_types = self .find_partial_types (var )
4857+ if partial_types is None :
4858+ return
4859+ typename = type_info .fullname
4860+ if typename not in (
4861+ "builtins.dict" ,
4862+ "collections.OrderedDict" ,
4863+ "collections.defaultdict" ,
4864+ ):
4865+ return
4866+ # TODO: Don't infer things twice.
4867+ key_type = self .expr_checker .accept (lvalue .index )
4868+ value_type = self .expr_checker .accept (rvalue )
4869+ if not (
4870+ is_valid_inferred_type (key_type , self .options )
4871+ and is_valid_inferred_type (value_type , self .options )
4872+ and not self .current_node_deferred
4873+ ):
4874+ return
4875+ if typename == "collections.defaultdict" and partial .value_type is not None :
4876+ if not is_equivalent (value_type , partial .value_type ):
4877+ return
4878+ new_type = self .named_generic_type (typename , [key_type , value_type ])
4879+ self .replace_partial_type (var , new_type , partial_types )
48574880
48584881 def type_requires_usage (self , typ : Type ) -> tuple [str , ErrorCode ] | None :
48594882 """Some types require usage in all cases. The classic example is
@@ -5069,13 +5092,26 @@ def visit_while_stmt(self, s: WhileStmt) -> None:
50695092
50705093 def visit_operator_assignment_stmt (self , s : OperatorAssignmentStmt ) -> None :
50715094 """Type check an operator assignment statement, e.g. x += 1."""
5072- self .try_infer_partial_generic_type_from_assignment (s .lvalue , s .rvalue , s .op )
5095+ self .try_resolve_partial_type_from_assignment (
5096+ s .lvalue , None , s .rvalue , s .op , pre_check = True
5097+ )
50735098 if isinstance (s .lvalue , MemberExpr ):
50745099 # Special case, some additional errors may be given for
50755100 # assignments to read-only or final attributes.
50765101 lvalue_type = self .expr_checker .visit_member_expr (s .lvalue , True )
50775102 else :
50785103 lvalue_type = self .expr_checker .accept (s .lvalue )
5104+ (
5105+ resolved_lvalue_type ,
5106+ _ ,
5107+ should_return ,
5108+ _ ,
5109+ ) = self .try_resolve_partial_type_from_assignment (s .lvalue , lvalue_type , s .rvalue , s .op )
5110+ if resolved_lvalue_type is not None :
5111+ lvalue_type = resolved_lvalue_type
5112+ if should_return :
5113+ self .check_final (s )
5114+ return
50795115 inplace , method = infer_operator_assignment_method (lvalue_type , s .op )
50805116 if inplace :
50815117 # There is __ifoo__, treat as x = x.__ifoo__(y)
0 commit comments