+import json
 from collections.abc import AsyncGenerator
 from typing import Optional, Union

+from agents import (
+    Agent,
+    ItemHelpers,
+    ModelSettings,
+    OpenAIChatCompletionsModel,
+    Runner,
+    ToolCallOutputItem,
+    function_tool,
+    set_tracing_disabled,
+)
 from openai import AsyncAzureOpenAI, AsyncOpenAI
-from openai.types.chat import ChatCompletionMessageParam
-from pydantic_ai import Agent, RunContext
-from pydantic_ai.messages import ModelMessagesTypeAdapter
-from pydantic_ai.models.openai import OpenAIModel
-from pydantic_ai.providers.openai import OpenAIProvider
-from pydantic_ai.settings import ModelSettings
+from openai.types.responses import EasyInputMessageParam, ResponseInputItemParam, ResponseTextDeltaEvent

 from fastapi_app.api_models import (
     AIChatRoles,
     ThoughtStep,
 )
 from fastapi_app.postgres_searcher import PostgresSearcher
-from fastapi_app.rag_base import ChatParams, RAGChatBase
+from fastapi_app.rag_base import RAGChatBase
+
+set_tracing_disabled(disabled=True)


 class AdvancedRAGChat(RAGChatBase):
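
A note for orientation before the hunks below: this commit swaps pydantic-ai's Agent/RunContext for the OpenAI Agents SDK (the `agents` package). A minimal, hedged sketch of the SDK surface the new code relies on; the tool and agent here are illustrative, not part of this repo:

    import asyncio

    from agents import Agent, Runner, function_tool, set_tracing_disabled

    set_tracing_disabled(disabled=True)  # opt out of uploading traces to OpenAI

    @function_tool  # exposes a plain async function to the model as a tool
    async def get_weather(city: str) -> str:
        return f"Weather in {city}: sunny"

    agent = Agent(name="Assistant", instructions="Answer briefly.", tools=[get_weather])

    async def main() -> None:
        result = await Runner.run(agent, input="What is the weather in Paris?")
        print(result.final_output)  # the final text answer

    asyncio.run(main())
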
@@ -34,7 +42,7 @@ class AdvancedRAGChat(RAGChatBase):
     def __init__(
         self,
         *,
-        messages: list[ChatCompletionMessageParam],
+        messages: list[ResponseInputItemParam],
         overrides: ChatRequestOverrides,
         searcher: PostgresSearcher,
         openai_chat_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
@@ -46,34 +54,29 @@ def __init__(
         self.model_for_thoughts = (
             {"model": chat_model, "deployment": chat_deployment} if chat_deployment else {"model": chat_model}
         )
-        pydantic_chat_model = OpenAIModel(
-            chat_model if chat_deployment is None else chat_deployment,
-            provider=OpenAIProvider(openai_client=openai_chat_client),
+        openai_agents_model = OpenAIChatCompletionsModel(
+            model=chat_model if chat_deployment is None else chat_deployment, openai_client=openai_chat_client
         )
-        self.search_agent = Agent[ChatParams, SearchResults](
-            pydantic_chat_model,
-            model_settings=ModelSettings(
-                temperature=0.0,
-                max_tokens=500,
-                **({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
-            ),
-            system_prompt=self.query_prompt_template,
-            tools=[self.search_database],
-            output_type=SearchResults,
+        self.search_agent = Agent(
+            name="Searcher",
+            instructions=self.query_prompt_template,
+            tools=[function_tool(self.search_database)],
+            tool_use_behavior="stop_on_first_tool",
+            model=openai_agents_model,
         )
         self.answer_agent = Agent(
-            pydantic_chat_model,
-            system_prompt=self.answer_prompt_template,
+            name="Answerer",
+            instructions=self.answer_prompt_template,
+            model=openai_agents_model,
             model_settings=ModelSettings(
                 temperature=self.chat_params.temperature,
                 max_tokens=self.chat_params.response_token_limit,
-                **({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
+                extra_body={"seed": self.chat_params.seed} if self.chat_params.seed is not None else {},
             ),
         )

     async def search_database(
         self,
-        ctx: RunContext[ChatParams],
         search_query: str,
         price_filter: Optional[PriceFilter] = None,
         brand_filter: Optional[BrandFilter] = None,
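
Two behavioral notes on the `__init__` hunk above. `tool_use_behavior="stop_on_first_tool"` ends the Searcher's run as soon as its first tool call returns, so the tool's return value (not model-generated text) becomes the run's final output; this replaces pydantic-ai's `output_type=SearchResults`. And because the Agents SDK `ModelSettings` exposes no `seed` field, the seed is forwarded through `extra_body` instead. A sketch of the first point, with a hypothetical `search_db` tool standing in for the real one:

    from agents import Agent, Runner, function_tool

    @function_tool
    async def search_db(search_query: str) -> dict:
        # hypothetical stand-in for PostgresSearcher.search_and_embed
        return {"query": search_query, "items": []}

    searcher_agent = Agent(
        name="Searcher",
        instructions="Rewrite the question into a search query and call search_db.",
        tools=[search_db],
        tool_use_behavior="stop_on_first_tool",  # halt after the first tool result
    )

    async def demo() -> None:
        result = await Runner.run(searcher_agent, input="climbing shoes under $50")
        print(result.final_output)  # the dict returned by search_db, verbatim
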
@@ -97,66 +100,73 @@ async def search_database(
             filters.append(brand_filter)
         results = await self.searcher.search_and_embed(
             search_query,
-            top=ctx.deps.top,
-            enable_vector_search=ctx.deps.enable_vector_search,
-            enable_text_search=ctx.deps.enable_text_search,
+            top=self.chat_params.top,
+            enable_vector_search=self.chat_params.enable_vector_search,
+            enable_text_search=self.chat_params.enable_text_search,
             filters=filters,
         )
         return SearchResults(
             query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters
         )

     async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
-        few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots)
+        few_shots: list[ResponseInputItemParam] = json.loads(self.query_fewshots)
         user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
-        results = await self.search_agent.run(
-            user_query,
-            message_history=few_shots + self.chat_params.past_messages,
-            deps=self.chat_params,
-        )
-        items = results.output.items
+        new_user_message = EasyInputMessageParam(role="user", content=user_query)
+        all_messages = few_shots + self.chat_params.past_messages + [new_user_message]
+
+        run_results = await Runner.run(self.search_agent, input=all_messages)
+        most_recent_response = run_results.new_items[-1]
+        if isinstance(most_recent_response, ToolCallOutputItem):
+            search_results = most_recent_response.output
+        else:
+            raise ValueError("Error retrieving search results, model did not call tool properly")
+
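
The `json.loads` swap in `prepare_context` above implies that `query_fewshots` is now stored as plain Responses-style input items rather than a pydantic-ai message dump. A hypothetical example of the JSON shape this expects (the actual few-shot file ships with the repo):

    import json

    query_fewshots = """
    [
        {"role": "user", "content": "Find search results for user query: climbing gear"},
        {"role": "assistant", "content": "Searching for: climbing gear"}
    ]
    """
    few_shots: list[dict] = json.loads(query_fewshots)  # prepend directly to the run input
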
         thoughts = [
             ThoughtStep(
                 title="Prompt to generate search arguments",
-                description=results.all_messages(),
+                description=[{"content": self.query_prompt_template}]
+                + ItemHelpers.input_to_new_input_list(run_results.input),
                 props=self.model_for_thoughts,
             ),
             ThoughtStep(
                 title="Search using generated search arguments",
-                description=results.output.query,
+                description=search_results.query,
                 props={
                     "top": self.chat_params.top,
                     "vector_search": self.chat_params.enable_vector_search,
                     "text_search": self.chat_params.enable_text_search,
-                    "filters": results.output.filters,
+                    "filters": search_results.filters,
                 },
             ),
             ThoughtStep(
                 title="Search results",
-                description=items,
+                description=search_results.items,
             ),
         ]
-        return items, thoughts
+        return search_results.items, thoughts
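
The ThoughtStep descriptions above rebuild the effective prompt by prepending the instructions to a normalized copy of the run input. As far as the SDK's documented behavior goes, `ItemHelpers.input_to_new_input_list` wraps a string input as a single user message and deep-copies a list input, so logging the result cannot mutate the live message list:

    from agents import ItemHelpers

    assert ItemHelpers.input_to_new_input_list("hi") == [{"content": "hi", "role": "user"}]

    history = [{"role": "user", "content": "hello"}]
    logged = ItemHelpers.input_to_new_input_list(history)
    assert logged == history and logged is not history  # a defensive copy
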

     async def answer(
         self,
         items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> RetrievalResponse:
-        response = await self.answer_agent.run(
-            user_prompt=self.prepare_rag_request(self.chat_params.original_user_query, items),
-            message_history=self.chat_params.past_messages,
+        run_results = await Runner.run(
+            self.answer_agent,
+            input=self.chat_params.past_messages
+            + [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}],
         )

         return RetrievalResponse(
-            message=Message(content=str(response.output), role=AIChatRoles.ASSISTANT),
+            message=Message(content=str(run_results.final_output), role=AIChatRoles.ASSISTANT),
             context=RAGContext(
                 data_points={item.id: item for item in items},
                 thoughts=earlier_thoughts
                 + [
                     ThoughtStep(
                         title="Prompt to generate answer",
-                        description=response.all_messages(),
+                        description=[{"content": self.answer_prompt_template}]
+                        + ItemHelpers.input_to_new_input_list(run_results.input),
                         props=self.model_for_thoughts,
                     ),
                 ],
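
Where pydantic-ai exposed `response.output` and a separate `message_history` parameter, `Runner.run` takes one combined `input` list and returns a result whose `final_output` is a plain `str` for agents without an `output_type`. A sketch, reusing the `answer_agent` built above:

    history = [
        {"role": "user", "content": "Do you sell climbing shoes?"},
        {"role": "assistant", "content": "Yes, several models."},
    ]
    new_turn = {"role": "user", "content": "Which is cheapest?"}

    async def demo() -> None:
        run = await Runner.run(answer_agent, input=history + [new_turn])
        print(run.final_output)  # str, since answer_agent declares no output_type
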
@@ -168,24 +178,28 @@ async def answer_stream(
         items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> AsyncGenerator[RetrievalResponseDelta, None]:
-        async with self.answer_agent.run_stream(
-            self.prepare_rag_request(self.chat_params.original_user_query, items),
-            message_history=self.chat_params.past_messages,
-        ) as agent_stream_runner:
-            yield RetrievalResponseDelta(
-                context=RAGContext(
-                    data_points={item.id: item for item in items},
-                    thoughts=earlier_thoughts
-                    + [
-                        ThoughtStep(
-                            title="Prompt to generate answer",
-                            description=agent_stream_runner.all_messages(),
-                            props=self.model_for_thoughts,
-                        ),
-                    ],
-                ),
-            )
-
-            async for message in agent_stream_runner.stream_text(delta=True, debounce_by=None):
-                yield RetrievalResponseDelta(delta=Message(content=str(message), role=AIChatRoles.ASSISTANT))
-            return
+        run_results = Runner.run_streamed(
+            self.answer_agent,
+            input=self.chat_params.past_messages
+            + [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}],  # noqa
+        )
+
+        yield RetrievalResponseDelta(
+            context=RAGContext(
+                data_points={item.id: item for item in items},
+                thoughts=earlier_thoughts
+                + [
+                    ThoughtStep(
+                        title="Prompt to generate answer",
+                        description=[{"content": self.answer_prompt_template}]
+                        + ItemHelpers.input_to_new_input_list(run_results.input),
+                        props=self.model_for_thoughts,
+                    ),
+                ],
+            ),
+        )
+
+        async for event in run_results.stream_events():
+            if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
+                yield RetrievalResponseDelta(delta=Message(content=str(event.data.delta), role=AIChatRoles.ASSISTANT))
+        return
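
`Runner.run_streamed` returns a streaming handle immediately; `stream_events()` then yields semantic events plus `raw_response_event` wrappers around the raw Responses-API stream, and filtering for `ResponseTextDeltaEvent` recovers token-level text deltas, as the hunk above does. A condensed sketch, again assuming the `answer_agent` from the constructor:

    from openai.types.responses import ResponseTextDeltaEvent

    from agents import Runner

    async def stream_demo() -> None:
        run = Runner.run_streamed(answer_agent, input="Summarize the return policy.")
        async for event in run.stream_events():
            if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
                print(event.data.delta, end="", flush=True)  # incremental answer text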