@@ -698,22 +698,41 @@ def test_query_list_eq_numeric_comparison() -> None:
698698 ) # Should not be equal since values are different
699699
700700
701- def test_keygetter_nested_objects () -> None :
702- """Test keygetter function with nested objects."""
701+ @dataclasses .dataclass
702+ class Food (t .Mapping [str , t .Any ]):
703+ fruit : list [str ] = dataclasses .field (default_factory = list )
704+ breakfast : str | None = None
703705
704- @dataclasses .dataclass
705- class Food :
706- fruit : list [str ] = dataclasses .field (default_factory = list )
707- breakfast : str | None = None
706+ def __getitem__ (self , key : str ) -> t .Any :
707+ return getattr (self , key )
708708
709- @dataclasses .dataclass
710- class Restaurant :
711- place : str
712- city : str
713- state : str
714- food : Food = dataclasses .field (default_factory = Food )
709+ def __iter__ (self ) -> t .Iterator [str ]:
710+ return iter (self .__dataclass_fields__ )
711+
712+ def __len__ (self ) -> int :
713+ return len (self .__dataclass_fields__ )
714+
715+
716+ @dataclasses .dataclass
717+ class Restaurant (t .Mapping [str , t .Any ]):
718+ place : str
719+ city : str
720+ state : str
721+ food : Food = dataclasses .field (default_factory = Food )
722+
723+ def __getitem__ (self , key : str ) -> t .Any :
724+ return getattr (self , key )
715725
716- # Test with nested dataclass
726+ def __iter__ (self ) -> t .Iterator [str ]:
727+ return iter (self .__dataclass_fields__ )
728+
729+ def __len__ (self ) -> int :
730+ return len (self .__dataclass_fields__ )
731+
732+
733+ def test_keygetter_nested_objects () -> None :
734+ """Test keygetter function with nested objects."""
735+ # Test with nested dataclass that implements Mapping protocol
717736 restaurant = Restaurant (
718737 place = "Largo" ,
719738 city = "Tampa" ,
@@ -736,7 +755,9 @@ class Restaurant:
736755
737756 # Test with non-mapping object (returns the object itself)
738757 non_mapping = "not a mapping"
739- assert keygetter (non_mapping , "any_key" ) == non_mapping # type: ignore
758+ assert (
759+ keygetter (t .cast (t .Mapping [str , t .Any ], non_mapping ), "any_key" ) == non_mapping
760+ )
740761
741762
742763def test_query_list_slicing () -> None :
@@ -773,24 +794,33 @@ def test_query_list_attributes() -> None:
773794
774795 # Test pk_key attribute with objects
775796 @dataclasses .dataclass
776- class Item :
797+ class Item ( t . Mapping [ str , t . Any ]) :
777798 id : str
778799 value : int
779800
801+ def __getitem__ (self , key : str ) -> t .Any :
802+ return getattr (self , key )
803+
804+ def __iter__ (self ) -> t .Iterator [str ]:
805+ return iter (self .__dataclass_fields__ )
806+
807+ def __len__ (self ) -> int :
808+ return len (self .__dataclass_fields__ )
809+
780810 items = [Item ("1" , 1 ), Item ("2" , 2 )]
781- ql = QueryList (items )
782- ql .pk_key = "id"
783- assert ql .items () == [("1" , items [0 ]), ("2" , items [1 ])]
811+ ql_items : QueryList [ t . Any ] = QueryList (items )
812+ ql_items .pk_key = "id"
813+ assert list ( ql_items .items () ) == [("1" , items [0 ]), ("2" , items [1 ])]
784814
785815 # Test pk_key with non-existent attribute
786- ql .pk_key = "nonexistent"
816+ ql_items .pk_key = "nonexistent"
787817 with pytest .raises (AttributeError ):
788- ql .items ()
818+ ql_items .items ()
789819
790820 # Test pk_key with None
791- ql .pk_key = None
821+ ql_items .pk_key = None
792822 with pytest .raises (PKRequiredException ):
793- ql .items ()
823+ ql_items .items ()
794824
795825
796826def test_lookup_name_map () -> None :
0 commit comments