@@ -12,39 +12,39 @@ def build_search_function() -> list[ChatCompletionToolParam]:
1212 "type" : "function" ,
1313 "function" : {
1414 "name" : "search_database" ,
15- "description" : "Search PostgreSQL database for relevant products based on user query" ,
15+ "description" : "Search PostgreSQL database for relevant restaurants based on user query" ,
1616 "parameters" : {
1717 "type" : "object" ,
1818 "properties" : {
1919 "search_query" : {
2020 "type" : "string" ,
2121 "description" : "Query string to use for full text search, e.g. 'red shoes'" ,
2222 },
23- "price_filter " : {
23+ "price_level_filter " : {
2424 "type" : "object" ,
25- "description" : "Filter search results based on price of the product" ,
25+ "description" : "Filter search results to a certain price level (from 1 $ to 4 $$$$, with 4 being most costly)" , # noqa: E501
2626 "properties" : {
2727 "comparison_operator" : {
2828 "type" : "string" ,
29- "description" : "Operator to compare the column value, either '>', '<', '>=', '<=', '='" , # noqa
29+ "description" : "Operator to compare the column value, either '>', '<', '>=', '<=', '='" , # noqa: E501
3030 },
3131 "value" : {
3232 "type" : "number" ,
33- "description" : "Value to compare against, e.g. 30 " ,
33+ "description" : "Value to compare against, either 1, 2, 3, 4 " ,
3434 },
3535 },
3636 },
37- "brand_filter " : {
37+ "rating_filter " : {
3838 "type" : "object" ,
39- "description" : "Filter search results based on brand of the product" ,
39+ "description" : "Filter search results based on ratings of restaurant (from 1 to 5 stars, with 5 the best)" , # noqa: E501
4040 "properties" : {
4141 "comparison_operator" : {
4242 "type" : "string" ,
43- "description" : "Operator to compare the column value, either '=' or '! ='" ,
43+ "description" : "Operator to compare the column value, either '>', '<', '>=', '<=', ' ='" , # noqa: E501
4444 },
4545 "value" : {
4646 "type" : "string" ,
47- "description" : "Value to compare against, e.g. AirStrider " ,
47+ "description" : "Value to compare against, either 0 1 2 3 4 5 " ,
4848 },
4949 },
5050 },
@@ -69,22 +69,26 @@ def extract_search_arguments(original_user_query: str, chat_completion: ChatComp
6969 arg = json .loads (function .arguments )
7070 # Even though its required, search_query is not always specified
7171 search_query = arg .get ("search_query" , original_user_query )
72- if "price_filter" in arg and arg ["price_filter" ] and isinstance (arg ["price_filter" ], dict ):
73- price_filter = arg ["price_filter" ]
72+ if (
73+ "price_level_filter" in arg
74+ and arg ["price_level_filter" ]
75+ and isinstance (arg ["price_level_filter" ], dict )
76+ ):
77+ price_level_filter = arg ["price_level_filter" ]
7478 filters .append (
7579 {
76- "column" : "price " ,
77- "comparison_operator" : price_filter ["comparison_operator" ],
78- "value" : price_filter ["value" ],
80+ "column" : "price_level " ,
81+ "comparison_operator" : price_level_filter ["comparison_operator" ],
82+ "value" : price_level_filter ["value" ],
7983 }
8084 )
81- if "brand_filter " in arg and arg ["brand_filter " ] and isinstance (arg ["brand_filter " ], dict ):
82- brand_filter = arg ["brand_filter " ]
85+ if "rating_filter " in arg and arg ["rating_filter " ] and isinstance (arg ["rating_filter " ], dict ):
86+ rating_filter = arg ["rating_filter " ]
8387 filters .append (
8488 {
85- "column" : "brand " ,
86- "comparison_operator" : brand_filter ["comparison_operator" ],
87- "value" : brand_filter ["value" ],
89+ "column" : "rating " ,
90+ "comparison_operator" : rating_filter ["comparison_operator" ],
91+ "value" : rating_filter ["value" ],
8892 }
8993 )
9094 elif query_text := response_message .content :
0 commit comments