11import argparse
22import asyncio
3+ import datetime
34import logging
45import os
56import pathlib
67import sys
8+ from typing import Optional
79
810import requests
911from azure .ai .evaluation import AzureAIProject
@@ -52,7 +54,7 @@ async def callback(
5254 return {"messages" : messages + [message ]}
5355
5456
55- async def run_simulator (target_url : str , max_simulations : int ):
57+ async def run_simulator (target_url : str , max_simulations : int , scan_name : Optional [ str ] = None ):
5658 credential = get_azure_credential ()
5759 azure_ai_project : AzureAIProject = {
5860 "subscription_id" : os .getenv ("AZURE_SUBSCRIPTION_ID" ),
@@ -64,26 +66,25 @@ async def run_simulator(target_url: str, max_simulations: int):
6466 credential = credential ,
6567 risk_categories = [
6668 RiskCategory .Violence ,
67- # RiskCategory.HateUnfairness,
68- # RiskCategory.Sexual,
69- # RiskCategory.SelfHarm,
69+ RiskCategory .HateUnfairness ,
70+ RiskCategory .Sexual ,
71+ RiskCategory .SelfHarm ,
7072 ],
7173 num_objectives = 1 ,
7274 )
75+ if scan_name is None :
76+ timestamp = datetime .datetime .now ().strftime ("%Y-%m-%d_%H-%M-%S" )
77+ scan_name = f"Safety evaluation { timestamp } "
7378 await model_red_team .scan (
7479 target = lambda messages , stream = False , session_state = None , context = None : callback (messages , target_url ),
75- scan_name = "Advanced-Callback-Scan" ,
80+ scan_name = scan_name ,
7681 attack_strategies = [
77- AttackStrategy .EASY , # Group of easy complexity attacks
78- # AttackStrategy.MODERATE, # Group of moderate complexity attacks
79- # AttackStrategy.CharacterSpace, # Add character spaces
80- # AttackStrategy.ROT13, # Use ROT13 encoding
81- # AttackStrategy.UnicodeConfusable, # Use confusable Unicode characters
82- # AttackStrategy.CharSwap, # Swap characters in prompts
83- # AttackStrategy.Morse, # Encode prompts in Morse code
84- # AttackStrategy.Leetspeak, # Use Leetspeak
85- # AttackStrategy.Url, # Use URLs in prompts
86- # AttackStrategy.Binary, # Encode prompts in binary
82+ AttackStrategy .DIFFICULT ,
83+ AttackStrategy .Baseline ,
84+ AttackStrategy .UnicodeConfusable , # Use confusable Unicode characters
85+ AttackStrategy .Morse , # Encode prompts in Morse code
86+ AttackStrategy .Leetspeak , # Use Leetspeak
87+ AttackStrategy .Url , # Use URLs in prompts
8788 ],
8889 output_path = "Advanced-Callback-Scan.json" ,
8990 )
@@ -97,28 +98,29 @@ async def run_simulator(target_url: str, max_simulations: int):
9798 parser .add_argument (
9899 "--max_simulations" , type = int , default = 200 , help = "Maximum number of simulations (question/response pairs)."
99100 )
101+ # argument for the name
102+ parser .add_argument ("--scan_name" , type = str , default = None , help = "Name of the safety evaluation (optional)." )
100103 args = parser .parse_args ()
101104
102105 # Configure logging to show tracebacks for warnings and above
103106 logging .basicConfig (
104- level = logging .DEBUG ,
107+ level = logging .WARNING ,
105108 format = "%(message)s" ,
106109 datefmt = "[%X]" ,
107110 handlers = [RichHandler (rich_tracebacks = False , show_path = True )],
108111 )
109112
110113 # Set urllib3 and azure libraries to WARNING level to see connection issues
111114 logging .getLogger ("urllib3" ).setLevel (logging .WARNING )
112- logging .getLogger ("azure" ).setLevel (logging .DEBUG )
113- logging .getLogger ("RedTeamLogger" ).setLevel (logging .DEBUG )
115+ logging .getLogger ("azure" ).setLevel (logging .WARNING )
114116
115117 # Set our application logger to INFO level
116118 logger .setLevel (logging .INFO )
117119
118120 load_azd_env ()
119121
120122 try :
121- asyncio .run (run_simulator (args .target_url , args .max_simulations ))
123+ asyncio .run (run_simulator (args .target_url , args .max_simulations , args . scan_name ))
122124 except Exception :
123125 logging .exception ("Unhandled exception in safety evaluation" )
124126 sys .exit (1 )
0 commit comments