 import argparse
 import asyncio
+import json
 import logging
 import os
 import pathlib
 from enum import Enum
-from typing import Any, Optional

 import requests
 from azure.ai.evaluation import ContentSafetyEvaluator
@@ -48,36 +48,15 @@ def get_azure_credential():

 async def callback(
     messages: list[dict],
-    stream: bool = False,
-    session_state: Any = None,
-    context: Optional[dict[str, Any]] = None,
     target_url: str = "http://127.0.0.1:8000/chat",
 ):
     messages_list = messages["messages"]
-    latest_message = messages_list[-1]
-    query = latest_message["content"]
+    query = messages_list[-1]["content"]
     headers = {"Content-Type": "application/json"}
     body = {
         "messages": [{"content": query, "role": "user"}],
-        "stream": stream,
-        "context": {
-            "overrides": {
-                "top": 3,
-                "temperature": 0.3,
-                "minimum_reranker_score": 0,
-                "minimum_search_score": 0,
-                "retrieval_mode": "hybrid",
-                "semantic_ranker": True,
-                "semantic_captions": False,
-                "suggest_followup_questions": False,
-                "use_oid_security_filter": False,
-                "use_groups_security_filter": False,
-                "vector_fields": ["embedding"],
-                "use_gpt4v": False,
-                "gpt4v_input": "textAndImages",
-                "seed": 1,
-            }
-        },
+        "stream": False,
+        "context": {"overrides": {"use_advanced_flow": True, "top": 3, "retrieval_mode": "hybrid", "temperature": 0.3}},
     }
     url = target_url
     r = requests.post(url, headers=headers, json=body)
@@ -86,8 +65,7 @@ async def callback(
         message = {"content": response["error"], "role": "assistant"}
     else:
         message = response["message"]
-    response["messages"] = messages_list + [message]
-    return response
+    return {"messages": messages_list + [message]}


 async def run_simulator(target_url: str, max_simulations: int):
@@ -104,9 +82,7 @@ async def run_simulator(target_url: str, max_simulations: int):

     outputs = await adversarial_simulator(
         scenario=scenario,
-        target=lambda messages, stream=False, session_state=None, context=None: callback(
-            messages, stream, session_state, context, target_url
-        ),
+        target=lambda messages, stream=False, session_state=None, context=None: callback(messages, target_url),
         max_simulation_results=max_simulations,
         language=SupportedLanguages.English,  # Match this to your app language
         randomization_seed=1,  # For more consistent results, use a fixed seed
@@ -139,10 +115,9 @@ async def run_simulator(target_url: str, max_simulations: int):
         else:
             summary_scores[evaluator]["mean_score"] = 0
             summary_scores[evaluator]["low_rate"] = 0
+
     # Save summary scores
     with open(root_dir / "safety_results.json", "w") as f:
-        import json
-
         json.dump(summary_scores, f, indent=2)

