11from enum import Enum
2- from typing import Any , Optional
2+ from typing import Any , Optional , Union
33
44from openai .types .chat import ChatCompletionMessageParam
5- from pydantic import BaseModel
5+ from pydantic import BaseModel , Field
6+ from pydantic_ai .messages import ModelRequest , ModelResponse
67
78
89class AIChatRoles (str , Enum ):
@@ -40,6 +41,30 @@ class ChatRequest(BaseModel):
4041 context : ChatRequestContext
4142 sessionState : Optional [Any ] = None
4243
44+
45+ class ItemPublic (BaseModel ):
46+ id : int
47+ name : str
48+ location : str
49+ cuisine : str
50+ rating : int
51+ price_level : int
52+ review_count : int
53+ hours : int
54+ tags : str
55+ description : str
56+ menu_summary : str
57+ top_reviews : str
58+ vibe : str
59+
60+
61+ class ItemWithDistance (ItemPublic ):
62+ distance : float
63+
64+ def __init__ (self , ** data ):
65+ super ().__init__ (** data )
66+ self .distance = round (self .distance , 2 )
67+
4368
4469class ThoughtStep (BaseModel ):
4570 title : str
@@ -48,7 +73,7 @@ class ThoughtStep(BaseModel):
4873
4974
5075class RAGContext (BaseModel ):
51- data_points : dict [int , dict [ str , Any ] ]
76+ data_points : dict [int , ItemPublic ]
5277 thoughts : list [ThoughtStep ]
5378 followup_questions : Optional [list [str ]] = None
5479
@@ -69,34 +94,39 @@ class RetrievalResponseDelta(BaseModel):
6994 sessionState : Optional [Any ] = None
7095
7196
72- class ItemPublic (BaseModel ):
73- id : int
74- name : str
75- location : str
76- cuisine : str
77- rating : int
78- price_level : int
79- review_count : int
80- hours : int
81- tags : str
82- description : str
83- menu_summary : str
84- top_reviews : str
85- vibe : str
86-
87-
88- class ItemWithDistance (ItemPublic ):
89- distance : float
90-
91- def __init__ (self , ** data ):
92- super ().__init__ (** data )
93- self .distance = round (self .distance , 2 )
94-
95-
9697class ChatParams (ChatRequestOverrides ):
9798 prompt_template : str
9899 response_token_limit : int = 1024
99100 enable_text_search : bool
100101 enable_vector_search : bool
101102 original_user_query : str
102- past_messages : list [ChatCompletionMessageParam ]
103+ past_messages : list [Union [ModelRequest , ModelResponse ]]
104+
105+
106+ class Filter (BaseModel ):
107+ column : str
108+ comparison_operator : str
109+ value : Any
110+
111+
112+ class PriceLevelFilter (Filter ):
113+ column : str = Field (default = "price_level" , description = "The column to filter on (always 'price_level' for this filter)" )
114+ comparison_operator : str = Field (description = "The operator for price level comparison ('>', '<', '>=', '<=', '=')" )
115+ value : float = Field (description = "Value to compare against, either 1, 2, 3, 4" )
116+
117+
118+ class RatingFilter (Filter ):
119+ column : str = Field (default = "rating" , description = "The column to filter on (always 'rating' for this filter)" )
120+ comparison_operator : str = Field (description = "The operator for rating comparison ('>', '<', '>=', '<=', '=')" )
121+ value : str = Field (description = "Value to compare against, either 0 1 2 3 4" )
122+
123+
124+ class SearchResults (BaseModel ):
125+ query : str
126+ """The original search query"""
127+
128+ items : list [ItemPublic ]
129+ """List of items that match the search query and filters"""
130+
131+ filters : list [Filter ]
132+ """List of filters applied to the search results"""
0 commit comments