11from collections .abc import AsyncGenerator
2- from typing import Optional , TypedDict , Union
2+ from typing import Optional , Union
33
44from openai import AsyncAzureOpenAI , AsyncOpenAI
55from openai .types .chat import ChatCompletionMessageParam
1111
1212from fastapi_app .api_models import (
1313 AIChatRoles ,
14+ BrandFilter ,
1415 ChatRequestOverrides ,
16+ Filter ,
1517 ItemPublic ,
1618 Message ,
19+ PriceFilter ,
1720 RAGContext ,
1821 RetrievalResponse ,
1922 RetrievalResponseDelta ,
23+ SearchResults ,
2024 ThoughtStep ,
2125)
2226from fastapi_app .postgres_searcher import PostgresSearcher
2327from fastapi_app .rag_base import ChatParams , RAGChatBase
2428
2529
26- class PriceFilter (TypedDict ):
27- column : str = "price"
28- """The column to filter on (always 'price' for this filter)"""
29-
30- comparison_operator : str
31- """The operator for price comparison ('>', '<', '>=', '<=', '=')"""
32-
33- value : float
34- """ The price value to compare against (e.g., 30.00) """
35-
36-
37- class BrandFilter (TypedDict ):
38- column : str = "brand"
39- """The column to filter on (always 'brand' for this filter)"""
40-
41- comparison_operator : str
42- """The operator for brand comparison ('=' or '!=')"""
43-
44- value : str
45- """The brand name to compare against (e.g., 'AirStrider')"""
46-
47-
48- class SearchResults (TypedDict ):
49- query : str
50- """The original search query"""
51-
52- items : list [ItemPublic ]
53- """List of items that match the search query and filters"""
54-
55- filters : list [Union [PriceFilter , BrandFilter ]]
56- """List of filters applied to the search results"""
57-
58-
5930class AdvancedRAGChat (RAGChatBase ):
6031 query_prompt_template = open (RAGChatBase .prompts_dir / "query.txt" ).read ()
6132 query_fewshots = open (RAGChatBase .prompts_dir / "query_fewshots.json" ).read ()
@@ -79,9 +50,13 @@ def __init__(
7950 chat_model if chat_deployment is None else chat_deployment ,
8051 provider = OpenAIProvider (openai_client = openai_chat_client ),
8152 )
82- self .search_agent = Agent (
53+ self .search_agent = Agent [ ChatParams , SearchResults ] (
8354 pydantic_chat_model ,
84- model_settings = ModelSettings (temperature = 0.0 , max_tokens = 500 , seed = self .chat_params .seed ),
55+ model_settings = ModelSettings (
56+ temperature = 0.0 ,
57+ max_tokens = 500 ,
58+ ** ({"seed" : self .chat_params .seed } if self .chat_params .seed is not None else {}),
59+ ),
8560 system_prompt = self .query_prompt_template ,
8661 tools = [self .search_database ],
8762 output_type = SearchResults ,
@@ -92,7 +67,7 @@ def __init__(
9267 model_settings = ModelSettings (
9368 temperature = self .chat_params .temperature ,
9469 max_tokens = self .chat_params .response_token_limit ,
95- seed = self .chat_params .seed ,
70+ ** ({ " seed" : self .chat_params .seed } if self . chat_params . seed is not None else {}) ,
9671 ),
9772 )
9873
@@ -115,7 +90,7 @@ async def search_database(
11590 List of formatted items that match the search query and filters
11691 """
11792 # Only send non-None filters
118- filters = []
93+ filters : list [ Filter ] = []
11994 if price_filter :
12095 filters .append (price_filter )
12196 if brand_filter :
@@ -134,12 +109,12 @@ async def search_database(
134109 async def prepare_context (self ) -> tuple [list [ItemPublic ], list [ThoughtStep ]]:
135110 few_shots = ModelMessagesTypeAdapter .validate_json (self .query_fewshots )
136111 user_query = f"Find search results for user query: { self .chat_params .original_user_query } "
137- results = await self .search_agent .run (
112+ results = await self .search_agent .run ( # type: ignore[call-overload]
138113 user_query ,
139114 message_history = few_shots + self .chat_params .past_messages ,
140115 deps = self .chat_params ,
141116 )
142- items = results .output [ " items" ]
117+ items = results .output . items
143118 thoughts = [
144119 ThoughtStep (
145120 title = "Prompt to generate search arguments" ,
@@ -148,12 +123,12 @@ async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
148123 ),
149124 ThoughtStep (
150125 title = "Search using generated search arguments" ,
151- description = results .output [ " query" ] ,
126+ description = results .output . query ,
152127 props = {
153128 "top" : self .chat_params .top ,
154129 "vector_search" : self .chat_params .enable_vector_search ,
155130 "text_search" : self .chat_params .enable_text_search ,
156- "filters" : results .output [ " filters" ] ,
131+ "filters" : results .output . filters ,
157132 },
158133 ),
159134 ThoughtStep (
0 commit comments