-import os
 from collections.abc import AsyncGenerator
 from typing import Optional, TypedDict, Union
 
-from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream
-from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
-from openai_messages_token_helper import get_token_limit
+from openai import AsyncAzureOpenAI, AsyncOpenAI
+from openai.types.chat import ChatCompletionMessageParam
 from pydantic_ai import Agent, RunContext
 from pydantic_ai.messages import ModelMessagesTypeAdapter
 from pydantic_ai.models.openai import OpenAIModel
 from pydantic_ai.providers.openai import OpenAIProvider
 from pydantic_ai.settings import ModelSettings
 
 from fastapi_app.api_models import (
     AIChatRoles,
+    ChatRequestOverrides,
     ItemPublic,
     Message,
     RAGContext,
     RetrievalResponse,
     RetrievalResponseDelta,
     ThoughtStep,
 )
-from fastapi_app.postgres_models import Item
 from fastapi_app.postgres_searcher import PostgresSearcher
 from fastapi_app.rag_base import ChatParams, RAGChatBase
 
-# Experiment #1: Annotated did not work!
-# Experiment #2: Function-level docstring, Inline docstrings next to attributes
-# Function-level docstring leads to XML like this: <summary>Search ...
-# Experiment #3: Move the docstrings below the attributes in triple-quoted strings - SUCCESS!!!
-
 
 class PriceFilter(TypedDict):
     column: str = "price"
@@ -64,19 +57,44 @@ class SearchResults(TypedDict):
 
 
 class AdvancedRAGChat(RAGChatBase):
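+    # Load the query rewriting prompt and few-shot examples once, at class definition time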
+    query_prompt_template = open(RAGChatBase.prompts_dir / "query.txt").read()
+    query_fewshots = open(RAGChatBase.prompts_dir / "query_fewshots.json").read()
+
     def __init__(
         self,
         *,
+        messages: list[ChatCompletionMessageParam],
+        overrides: ChatRequestOverrides,
         searcher: PostgresSearcher,
         openai_chat_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
         chat_model: str,
         chat_deployment: Optional[str],  # Not needed for non-Azure OpenAI
     ):
         self.searcher = searcher
-        self.openai_chat_client = openai_chat_client
-        self.chat_model = chat_model
-        self.chat_deployment = chat_deployment
-        self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
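+        # Resolve ChatParams (user query, past messages, temperature, token limit, seed) from the request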
+        self.chat_params = self.get_chat_params(messages, overrides)
+        self.model_for_thoughts = (
+            {"model": chat_model, "deployment": chat_deployment} if chat_deployment else {"model": chat_model}
+        )
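+        # Azure OpenAI takes the deployment name as the model name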
+        pydantic_chat_model = OpenAIModel(
+            chat_model if chat_deployment is None else chat_deployment,
+            provider=OpenAIProvider(openai_client=openai_chat_client),
+        )
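+        # Agent that rewrites the user question into search arguments and invokes the search_database tool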
+        self.search_agent = Agent(
+            pydantic_chat_model,
+            model_settings=ModelSettings(temperature=0.0, max_tokens=500, seed=self.chat_params.seed),
+            system_prompt=self.query_prompt_template,
+            tools=[self.search_database],
+            output_type=SearchResults,
+        )
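+        # Agent that generates the final answer from the retrieved sources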
+        self.answer_agent = Agent(
+            pydantic_chat_model,
+            system_prompt=self.answer_prompt_template,
+            model_settings=ModelSettings(
+                temperature=self.chat_params.temperature,
+                max_tokens=self.chat_params.response_token_limit,
+                seed=self.chat_params.seed,
+            ),
+        )
 
     async def search_database(
         self,
@@ -113,42 +131,28 @@ async def search_database(
             query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters
         )
 
-    async def prepare_context(self, chat_params: ChatParams) -> tuple[list[ItemPublic], list[ThoughtStep]]:
-        model = OpenAIModel(
-            os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=self.openai_chat_client)
-        )
-        agent = Agent(
-            model,
-            model_settings=ModelSettings(temperature=0.0, max_tokens=500, seed=chat_params.seed),
-            system_prompt=self.query_prompt_template,
-            tools=[self.search_database],
-            output_type=SearchResults,
-        )
+    async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
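+        # Few-shot examples are prepended to the message history to guide the generated search arguments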
         few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots)
-        user_query = f"Find search results for user query: {chat_params.original_user_query}"
-        results = await agent.run(
+        user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
+        results = await self.search_agent.run(
             user_query,
-            message_history=few_shots + chat_params.past_messages,
-            deps=chat_params,
+            message_history=few_shots + self.chat_params.past_messages,
+            deps=self.chat_params,
         )
         items = results.output["items"]
         thoughts = [
             ThoughtStep(
                 title="Prompt to generate search arguments",
                 description=results.all_messages(),
-                props=(
-                    {"model": self.chat_model, "deployment": self.chat_deployment}
-                    if self.chat_deployment
-                    else {"model": self.chat_model}  # TODO
-                ),
+                props=self.model_for_thoughts,
             ),
             ThoughtStep(
                 title="Search using generated search arguments",
                 description=results.output["query"],
                 props={
-                    "top": chat_params.top,
-                    "vector_search": chat_params.enable_vector_search,
-                    "text_search": chat_params.enable_text_search,
+                    "top": self.chat_params.top,
+                    "vector_search": self.chat_params.enable_vector_search,
+                    "text_search": self.chat_params.enable_text_search,
                     "filters": results.output["filters"],
                 },
             ),
@@ -161,25 +165,12 @@ async def prepare_context(self, chat_params: ChatParams) -> tuple[list[ItemPubli
 
     async def answer(
         self,
-        chat_params: ChatParams,
         items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> RetrievalResponse:
-        agent = Agent(
-            OpenAIModel(
-                os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
-                provider=OpenAIProvider(openai_client=self.openai_chat_client),
-            ),
-            system_prompt=self.answer_prompt_template,
-            model_settings=ModelSettings(
-                temperature=chat_params.temperature, max_tokens=chat_params.response_token_limit, seed=chat_params.seed
-            ),
-        )
-
-        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in items]
-        response = await agent.run(
-            user_prompt=chat_params.original_user_query + "Sources:\n" + "\n".join(sources_content),
-            message_history=chat_params.past_messages,
+        response = await self.answer_agent.run(
+            user_prompt=self.prepare_rag_request(self.chat_params.original_user_query, items),
+            message_history=self.chat_params.past_messages,
         )
 
         return RetrievalResponse(
@@ -191,57 +182,35 @@ async def answer(
                     ThoughtStep(
                         title="Prompt to generate answer",
                         description=response.all_messages(),
-                        props=(
-                            {"model": self.chat_model, "deployment": self.chat_deployment}
-                            if self.chat_deployment
-                            else {"model": self.chat_model}
-                        ),
+                        props=self.model_for_thoughts,
                     ),
                 ],
             ),
         )
 
     async def answer_stream(
         self,
-        chat_params: ChatParams,
-        contextual_messages: list[ChatCompletionMessageParam],
-        results: list[Item],
+        items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> AsyncGenerator[RetrievalResponseDelta, None]:
-        chat_completion_async_stream: AsyncStream[
-            ChatCompletionChunk
-        ] = await self.openai_chat_client.chat.completions.create(
-            # Azure OpenAI takes the deployment name as the model name
-            model=self.chat_deployment if self.chat_deployment else self.chat_model,
-            messages=contextual_messages,
-            temperature=chat_params.temperature,
-            max_tokens=chat_params.response_token_limit,
-            n=1,
-            stream=True,
-        )
-
-        yield RetrievalResponseDelta(
-            context=RAGContext(
-                data_points={item.id: item.to_dict() for item in results},
-                thoughts=earlier_thoughts
-                + [
-                    ThoughtStep(
-                        title="Prompt to generate answer",
-                        description=contextual_messages,
-                        props=(
-                            {"model": self.chat_model, "deployment": self.chat_deployment}
-                            if self.chat_deployment
-                            else {"model": self.chat_model}
+        async with self.answer_agent.run_stream(
+            self.prepare_rag_request(self.chat_params.original_user_query, items),
+            message_history=self.chat_params.past_messages,
+        ) as agent_stream_runner:
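+            # Yield the RAG context (data points and thought steps) before streaming the answer text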
+            yield RetrievalResponseDelta(
+                context=RAGContext(
+                    data_points={item.id: item for item in items},
+                    thoughts=earlier_thoughts
+                    + [
+                        ThoughtStep(
+                            title="Prompt to generate answer",
+                            description=agent_stream_runner.all_messages(),
+                            props=self.model_for_thoughts,
                         ),
-                    ),
-                ],
-            ),
-        )
+                    ],
+                ),
+            )
 
-        async for response_chunk in chat_completion_async_stream:
-            # first response has empty choices and last response has empty content
-            if response_chunk.choices and response_chunk.choices[0].delta.content:
-                yield RetrievalResponseDelta(
-                    delta=Message(content=str(response_chunk.choices[0].delta.content), role=AIChatRoles.ASSISTANT)
-                )
-        return
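+            # stream_text(delta=True) yields only newly generated text; debounce_by=None disables chunk batching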
+            async for message in agent_stream_runner.stream_text(delta=True, debounce_by=None):
+                yield RetrievalResponseDelta(delta=Message(content=str(message), role=AIChatRoles.ASSISTANT))
+            return