From ad55308299e2550d1c58be9f7c53665b39cc0193 Mon Sep 17 00:00:00 2001 From: wubinbin Date: Wed, 10 Jul 2024 11:16:59 +0800 Subject: [PATCH 001/130] set port 3000 --- application/docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/application/docker-compose.yml b/application/docker-compose.yml index b09ef55..a19b13a 100644 --- a/application/docker-compose.yml +++ b/application/docker-compose.yml @@ -101,7 +101,7 @@ services: dockerfile: Dockerfile restart: always ports: - - "80:80" + - "3000:80" expose: - "80" networks: From 0d19d0389341cb85a4e3dc0b3ebd6b9ff9998d0e Mon Sep 17 00:00:00 2001 From: keithyt06 Date: Wed, 10 Jul 2024 08:04:12 +0000 Subject: [PATCH 002/130] Add new feature: support for StarRocks --- application/nlq/data_access/database.py | 9 ++++++++- .../2_\360\237\252\231_Data_Connection_Management.py" | 1 + application/requirements.txt | 3 ++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/application/nlq/data_access/database.py b/application/nlq/data_access/database.py index 653343b..ad250a5 100644 --- a/application/nlq/data_access/database.py +++ b/application/nlq/data_access/database.py @@ -12,7 +12,8 @@ class RelationDatabase(): db_mapping = { 'mysql': 'mysql+pymysql', 'postgresql': 'postgresql+psycopg2', - 'redshift': 'postgresql+psycopg2' + 'redshift': 'postgresql+psycopg2', + 'starrocks': 'starrocks' # Add more mappings here for other databases } @@ -72,6 +73,12 @@ def get_all_schema_names_by_connection(cls, connection: ConnectConfigEntity): engine = db.create_engine(db_url) database_connect = sqlalchemy.inspect(engine) schemas = database_connect.get_schema_names() + elif connection.db_type == 'starrocks': + db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host, + connection.db_port, connection.db_name) + engine = db.create_engine(db_url) + database_connect = sqlalchemy.inspect(engine) + schemas = database_connect.get_schema_names() return schemas @classmethod diff --git "a/application/pages/2_\360\237\252\231_Data_Connection_Management.py" "b/application/pages/2_\360\237\252\231_Data_Connection_Management.py" index 42fe47a..54a6868 100644 --- "a/application/pages/2_\360\237\252\231_Data_Connection_Management.py" +++ "b/application/pages/2_\360\237\252\231_Data_Connection_Management.py" @@ -12,6 +12,7 @@ 'mysql': 'MySQL', 'postgresql': 'PostgreSQL', 'redshift': 'Redshift', + 'starrocks': 'StarRocks', } diff --git a/application/requirements.txt b/application/requirements.txt index 7c6efac..df4aabc 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -13,4 +13,5 @@ langchain-core~=0.1.30 sqlparse~=0.4.2 debugpy pandas==2.0.3 -openpyxl \ No newline at end of file +openpyxl +starrocks \ No newline at end of file From cd9d824063341f2aeb6909e38093e5a5ff438641 Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 11 Jul 2024 09:10:02 +0800 Subject: [PATCH 003/130] add code for StarRocks --- application/api/main.py | 6 +++-- application/api/schemas.py | 2 ++ application/api/service.py | 27 ++++++++++--------- application/nlq/business/log_store.py | 4 +-- application/nlq/data_access/database.py | 9 ++++++- .../nlq/data_access/dynamo_query_log.py | 12 ++++++--- ...237\252\231_Data_Connection_Management.py" | 1 + application/requirements-api.txt | 3 ++- application/requirements.txt | 3 ++- application/utils/prompt.py | 5 ++++ application/utils/prompts/generate_prompt.py | 4 ++- 11 files changed, 52 insertions(+), 24 deletions(-) diff --git 
a/application/api/main.py b/application/api/main.py index a184f34..0b450a9 100644 --- a/application/api/main.py +++ b/application/api/main.py @@ -41,12 +41,14 @@ def ask(question: Question): @router.post("/user_feedback") def user_feedback(input_data: FeedBackInput): feedback_type = input_data.feedback_type + user_id = input_data.user_id + session_id = input_data.session_id if feedback_type == "upvote": - upvote_res = service.user_feedback_upvote(input_data.data_profiles, input_data.query, + upvote_res = service.user_feedback_upvote(input_data.data_profiles, user_id, session_id, input_data.query, input_data.query_intent, input_data.query_answer) return upvote_res else: - downvote_res = service.user_feedback_downvote(input_data.data_profiles, input_data.query, + downvote_res = service.user_feedback_downvote(input_data.data_profiles, user_id, session_id, input_data.query, input_data.query_intent, input_data.query_answer) return downvote_res diff --git a/application/api/schemas.py b/application/api/schemas.py index ff209fe..c2a4a80 100644 --- a/application/api/schemas.py +++ b/application/api/schemas.py @@ -39,6 +39,8 @@ class FeedBackInput(BaseModel): query: str query_intent: str query_answer: str + session_id: str = "-1" + user_id: str = "admin" class Option(BaseModel): diff --git a/application/api/service.py b/application/api/service.py index 4394424..b449eee 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -111,6 +111,8 @@ def get_result_from_llm(question: Question, current_nlq_chain: NLQChain, with_re def ask(question: Question) -> Answer: logger.debug(question) verify_parameters(question) + user_id = question.user_id + session_id =question.session_id intent_ner_recognition_flag = question.intent_ner_recognition_flag agent_cot_flag = question.agent_cot_flag @@ -193,7 +195,7 @@ def ask(question: Question) -> Answer: answer = Answer(query=search_box, query_intent="reject_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[]) - LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql="", query=search_box, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql="", query=search_box, intent="reject_search", log_info="", time_str=current_time) return answer elif search_intent_flag: @@ -210,7 +212,7 @@ def ask(question: Question) -> Answer: sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[]) - LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql="", query=search_box, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql="", query=search_box, intent="knowledge_search", log_info=knowledge_search_result.knowledge_response, time_str=current_time) @@ -272,7 +274,7 @@ def ask(question: Question) -> Answer: sql_search_result.data_show_type = model_select_type log_info = str(search_intent_result["error_info"]) + ";" + sql_search_result.data_analyse - LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql=sql_search_result.sql, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql=sql_search_result.sql, query=search_box, intent="normal_search", log_info=log_info, @@ -318,7 +320,7 @@ def ask(question: Question) -> Answer: 
else: log_info = agent_search_result[i]["query"] + "The SQL error Info: " log_id = generate_log_id() - LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql=each_task_res["sql"], + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql=each_task_res["sql"], query=search_box + "; The sub task is " + agent_search_result[i]["query"], intent="agent_search", log_info=log_info, @@ -340,6 +342,7 @@ def ask(question: Question) -> Answer: async def ask_websocket(websocket: WebSocket, question : Question): logger.info(question) session_id = question.session_id + user_id = question.user_id intent_ner_recognition_flag = question.intent_ner_recognition_flag agent_cot_flag = question.agent_cot_flag @@ -424,7 +427,7 @@ async def ask_websocket(websocket: WebSocket, question : Question): answer = Answer(query=search_box, query_intent="reject_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[]) - LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql="", query=search_box, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql="", query=search_box, intent="reject_search", log_info="", time_str=current_time) return answer elif search_intent_flag: @@ -441,7 +444,7 @@ async def ask_websocket(websocket: WebSocket, question : Question): sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[]) - LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql="", query=search_box, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql="", query=search_box, intent="knowledge_search", log_info=knowledge_search_result.knowledge_response, time_str=current_time) @@ -511,7 +514,7 @@ async def ask_websocket(websocket: WebSocket, question : Question): sql_search_result.data_show_type = model_select_type log_info = str(search_intent_result["error_info"]) + ";" + sql_search_result.data_analyse - LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql=sql_search_result.sql, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql=sql_search_result.sql, query=search_box, intent="normal_search", log_info=log_info, @@ -557,7 +560,7 @@ async def ask_websocket(websocket: WebSocket, question : Question): else: log_info = agent_search_result[i]["query"] + "The SQL error Info: " log_id = generate_log_id() - LogManagement.add_log_to_database(log_id=log_id, profile_name=selected_profile, sql=each_task_res["sql"], + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql=each_task_res["sql"], query=search_box + "; The sub task is " + agent_search_result[i]["query"], intent="agent_search", log_info=log_info, @@ -576,7 +579,7 @@ async def ask_websocket(websocket: WebSocket, question : Question): return answer -def user_feedback_upvote(data_profiles: str, query: str, query_intent: str, query_answer): +def user_feedback_upvote(data_profiles: str, user_id : str, session_id : str, query: str, query_intent: str, query_answer): try: if query_intent == "normal_search": VectorStore.add_sample(data_profiles, query, query_answer) @@ -588,12 
+591,12 @@ def user_feedback_upvote(data_profiles: str, query: str, query_intent: str, quer return False -def user_feedback_downvote(data_profiles: str, query: str, query_intent: str, query_answer): +def user_feedback_downvote(data_profiles: str, user_id : str, session_id : str, query: str, query_intent: str, query_answer): try: if query_intent == "normal_search": log_id = generate_log_id() current_time = get_current_time() - LogManagement.add_log_to_database(log_id=log_id, profile_name=data_profiles, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=data_profiles, sql=query_answer, query=query, intent="normal_search_user_downvote", log_info="", @@ -601,7 +604,7 @@ def user_feedback_downvote(data_profiles: str, query: str, query_intent: str, qu elif query_intent == "agent_search": log_id = generate_log_id() current_time = get_current_time() - LogManagement.add_log_to_database(log_id=log_id, profile_name=data_profiles, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=data_profiles, sql=query_answer, query=query, intent="agent_search_user_downvote", log_info="", diff --git a/application/nlq/business/log_store.py b/application/nlq/business/log_store.py index e33316a..891364a 100644 --- a/application/nlq/business/log_store.py +++ b/application/nlq/business/log_store.py @@ -9,5 +9,5 @@ class LogManagement: query_log_dao = DynamoQueryLogDao() @classmethod - def add_log_to_database(cls, log_id, profile_name, sql, query, intent, log_info, time_str): - cls.query_log_dao.add_log(log_id, profile_name, sql, query, intent, log_info, time_str) + def add_log_to_database(cls, log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str): + cls.query_log_dao.add_log(log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str) diff --git a/application/nlq/data_access/database.py b/application/nlq/data_access/database.py index 653343b..ad250a5 100644 --- a/application/nlq/data_access/database.py +++ b/application/nlq/data_access/database.py @@ -12,7 +12,8 @@ class RelationDatabase(): db_mapping = { 'mysql': 'mysql+pymysql', 'postgresql': 'postgresql+psycopg2', - 'redshift': 'postgresql+psycopg2' + 'redshift': 'postgresql+psycopg2', + 'starrocks': 'starrocks' # Add more mappings here for other databases } @@ -72,6 +73,12 @@ def get_all_schema_names_by_connection(cls, connection: ConnectConfigEntity): engine = db.create_engine(db_url) database_connect = sqlalchemy.inspect(engine) schemas = database_connect.get_schema_names() + elif connection.db_type == 'starrocks': + db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host, + connection.db_port, connection.db_name) + engine = db.create_engine(db_url) + database_connect = sqlalchemy.inspect(engine) + schemas = database_connect.get_schema_names() return schemas @classmethod diff --git a/application/nlq/data_access/dynamo_query_log.py b/application/nlq/data_access/dynamo_query_log.py index 02060d6..ed4a48a 100644 --- a/application/nlq/data_access/dynamo_query_log.py +++ b/application/nlq/data_access/dynamo_query_log.py @@ -12,9 +12,11 @@ class DynamoQueryLogEntity: - def __init__(self, log_id, profile_name, sql, query, intent, log_info, time_str): + def __init__(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str): self.log_id = log_id self.profile_name = profile_name + self.user_id = user_id + self.session_id = session_id self.sql 
= sql self.query = query self.intent = intent @@ -26,6 +28,8 @@ def to_dict(self): return { 'log_id': self.log_id, 'profile_name': self.profile_name, + 'user_id': self.user_id, + 'session_id': self.session_id, 'sql': self.sql, 'query': self.query, 'intent': self.intent, @@ -104,11 +108,11 @@ def add(self, entity): try: self.table.put_item(Item=entity.to_dict()) except Exception as e: - logger.error("add log entity is error {}",e) + logger.error("add log entity is error {}", e) def update(self, entity): self.table.put_item(Item=entity.to_dict()) - def add_log(self, log_id, profile_name, sql, query, intent, log_info, time_str): - entity = DynamoQueryLogEntity(log_id, profile_name, sql, query, intent, log_info, time_str) + def add_log(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str): + entity = DynamoQueryLogEntity(log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str) self.add(entity) diff --git "a/application/pages/2_\360\237\252\231_Data_Connection_Management.py" "b/application/pages/2_\360\237\252\231_Data_Connection_Management.py" index 42fe47a..dfa2971 100644 --- "a/application/pages/2_\360\237\252\231_Data_Connection_Management.py" +++ "b/application/pages/2_\360\237\252\231_Data_Connection_Management.py" @@ -12,6 +12,7 @@ 'mysql': 'MySQL', 'postgresql': 'PostgreSQL', 'redshift': 'Redshift', + 'starrocks': 'StarRocks' } diff --git a/application/requirements-api.txt b/application/requirements-api.txt index 300a02d..0b25975 100644 --- a/application/requirements-api.txt +++ b/application/requirements-api.txt @@ -14,4 +14,5 @@ langchain~=0.1.11 langchain-core~=0.1.30 sqlparse~=0.4.2 pandas==2.0.3 -openpyxl \ No newline at end of file +openpyxl +starrocks==1.0.6 \ No newline at end of file diff --git a/application/requirements.txt b/application/requirements.txt index 7c6efac..81cef76 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -13,4 +13,5 @@ langchain-core~=0.1.30 sqlparse~=0.4.2 debugpy pandas==2.0.3 -openpyxl \ No newline at end of file +openpyxl +starrocks==1.0.6 \ No newline at end of file diff --git a/application/utils/prompt.py b/application/utils/prompt.py index b6bf0d8..61ca564 100644 --- a/application/utils/prompt.py +++ b/application/utils/prompt.py @@ -17,6 +17,11 @@ Pay attention to use CURDATE() function to get the current date, if the question involves "today". In the process of generating SQL statements, please do not use aliases. Aside from giving the SQL answer, concisely explain yourself after giving the answer in the same language as the question.""".format(top_k=TOP_K) +STARROCKS_DIALECT_PROMPT_CLAUDE3=""" +You are a data analysis expert and proficient in StarRocks. Given an input question, first create a syntactically correct StarRocks SQL query to run, then look at the results of the query and return the answer to the input +question.When generating SQL, do not add double quotes or single quotes around table names. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per StarRocks SQL. +Never query for all columns from a table.""".format(top_k=TOP_K) + AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3 = """You are a Amazon Redshift expert. Given an input question, first create a syntactically correct Redshift query to run, then look at the results of the query and return the answer to the input question.When generating SQL, do not add double quotes or single quotes around table names. 
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. diff --git a/application/utils/prompts/generate_prompt.py b/application/utils/prompts/generate_prompt.py index f99c43c..9c46fa9 100644 --- a/application/utils/prompts/generate_prompt.py +++ b/application/utils/prompts/generate_prompt.py @@ -1,5 +1,5 @@ from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \ - DEFAULT_DIALECT_PROMPT, AGENT_COT_EXAMPLE, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3 + DEFAULT_DIALECT_PROMPT, AGENT_COT_EXAMPLE, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3, STARROCKS_DIALECT_PROMPT_CLAUDE3 from utils.prompts import guidance_prompt from utils.prompts import table_prompt import logging @@ -1907,6 +1907,8 @@ def generate_llm_prompt(ddl, hints, prompt_map, search_box, sql_examples=None, n dialect_prompt = MYSQL_DIALECT_PROMPT_CLAUDE3 elif dialect == 'redshift': dialect_prompt = AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3 + elif dialect == 'starrocks': + dialect_prompt = STARROCKS_DIALECT_PROMPT_CLAUDE3 else: dialect_prompt = DEFAULT_DIALECT_PROMPT From 6fa26bd03ed4ab132a3adf809f2bd574e4af0aa7 Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 11 Jul 2024 11:06:17 +0800 Subject: [PATCH 004/130] fix docker build --- application/Dockerfile | 15 ++++++++++++++- application/Dockerfile-api | 14 +++++++++++++- application/api/service.py | 4 +++- 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/application/Dockerfile b/application/Dockerfile index 989d5fe..bbd5809 100644 --- a/application/Dockerfile +++ b/application/Dockerfile @@ -7,7 +7,20 @@ RUN adduser --disabled-password --gecos '' appuser WORKDIR /app COPY requirements.txt /app/ -RUN pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple + +ARG AWS_REGION=us-east-1 +ENV AWS_REGION=${AWS_REGION} + +# Print the AWS_REGION for verification +RUN echo "Current AWS Region: $AWS_REGION" + +# Install dependencies using the appropriate PyPI source based on AWS region +RUN if [ "$AWS_REGION" = "cn-north-1" ] || [ "$AWS_REGION" = "cn-northwest-1" ]; then \ + pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple; \ + else \ + pip3 install -r requirements.txt; \ + fi + COPY . /app/ diff --git a/application/Dockerfile-api b/application/Dockerfile-api index c66d8eb..876eb68 100644 --- a/application/Dockerfile-api +++ b/application/Dockerfile-api @@ -3,7 +3,19 @@ FROM public.ecr.aws/docker/library/python:3.10-slim WORKDIR /app COPY . 
/app/ -RUN pip3 install -r requirements-api.txt -i https://pypi.tuna.tsinghua.edu.cn/simple + +ARG AWS_REGION=us-east-1 +ENV AWS_REGION=${AWS_REGION} + +# Print the AWS_REGION for verification +RUN echo "Current AWS Region: $AWS_REGION" + +# Install dependencies using the appropriate PyPI source based on AWS region +RUN if [ "$AWS_REGION" = "cn-north-1" ] || [ "$AWS_REGION" = "cn-northwest-1" ]; then \ + pip3 install -r requirements-api.txt -i https://pypi.tuna.tsinghua.edu.cn/simple; \ + else \ + pip3 install -r requirements-api.txt; \ + fi EXPOSE 8000 diff --git a/application/api/service.py b/application/api/service.py index b449eee..b5edda7 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -29,10 +29,10 @@ logger = logging.getLogger(__name__) load_dotenv() -all_profiles = ProfileManagement.get_all_profiles_with_info() def get_option() -> Option: + all_profiles = ProfileManagement.get_all_profiles_with_info() option = Option( data_profiles=all_profiles.keys(), bedrock_model_ids=BEDROCK_MODEL_IDS, @@ -62,6 +62,7 @@ def get_result_from_llm(question: Question, current_nlq_chain: NLQChain, with_re logger.info('try to get generated sql from LLM') entity_slot_retrieve = [] + all_profiles = ProfileManagement.get_all_profiles_with_info() database_profile = all_profiles[question.profile_name] if question.intent_ner_recognition: intent_response = get_query_intent(question.bedrock_model_id, question.keywords, database_profile['prompt_map']) @@ -627,6 +628,7 @@ def explain_with_response_stream(current_nlq_chain: NLQChain) -> dict: def get_executed_result(current_nlq_chain: NLQChain) -> str: + all_profiles = ProfileManagement.get_all_profiles_with_info() sql_query_result = current_nlq_chain.get_executed_result_df(all_profiles[current_nlq_chain.profile]) final_sql_query_result = sql_query_result.to_markdown() return final_sql_query_result From e9136f04c0bed643d13678e05e1c19e666b3775d Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 11 Jul 2024 11:24:38 +0800 Subject: [PATCH 005/130] fix log issue --- application/api/service.py | 109 +++++++++++------- .../nlq/data_access/dynamo_query_log.py | 2 +- 2 files changed, 68 insertions(+), 43 deletions(-) diff --git a/application/api/service.py b/application/api/service.py index b5edda7..6a8678e 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -72,7 +72,8 @@ def get_result_from_llm(question: Question, current_nlq_chain: NLQChain, with_re entity_slot = intent_response.get("slot", []) if entity_slot: for each_entity in entity_slot: - entity_retrieve = get_retrieve_opensearch(opensearch_info, each_entity, "ner", question.profile_name, 1, 0.7) + entity_retrieve = get_retrieve_opensearch(opensearch_info, each_entity, "ner", question.profile_name, 1, + 0.7) if entity_retrieve: entity_slot_retrieve.extend(entity_retrieve) @@ -113,7 +114,7 @@ def ask(question: Question) -> Answer: logger.debug(question) verify_parameters(question) user_id = question.user_id - session_id =question.session_id + session_id = question.session_id intent_ner_recognition_flag = question.intent_ner_recognition_flag agent_cot_flag = question.agent_cot_flag @@ -196,7 +197,8 @@ def ask(question: Question) -> Answer: answer = Answer(query=search_box, query_intent="reject_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[]) - LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, 
profile_name=selected_profile, sql="", query=search_box, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, + profile_name=selected_profile, sql="", query=search_box, intent="reject_search", log_info="", time_str=current_time) return answer elif search_intent_flag: @@ -213,7 +215,8 @@ def ask(question: Question) -> Answer: sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[]) - LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql="", query=search_box, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, + profile_name=selected_profile, sql="", query=search_box, intent="knowledge_search", log_info=knowledge_search_result.knowledge_response, time_str=current_time) @@ -261,9 +264,12 @@ def ask(question: Question) -> Answer: sql_search_result.data_analyse = search_intent_analyse_result - model_select_type, show_select_data, select_chart_type, show_chart_data = data_visualization(model_type, search_box, - search_intent_result["data"], - database_profile['prompt_map']) + model_select_type, show_select_data, select_chart_type, show_chart_data = data_visualization(model_type, + search_box, + search_intent_result[ + "data"], + database_profile[ + 'prompt_map']) if select_chart_type != "-1": sql_chart_data = ChartEntity(chart_type="", chart_data=[]) @@ -275,7 +281,8 @@ def ask(question: Question) -> Answer: sql_search_result.data_show_type = model_select_type log_info = str(search_intent_result["error_info"]) + ";" + sql_search_result.data_analyse - LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql=sql_search_result.sql, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, + profile_name=selected_profile, sql=sql_search_result.sql, query=search_box, intent="normal_search", log_info=log_info, @@ -295,11 +302,13 @@ def ask(question: Question) -> Answer: each_task_sql_res = [list(each_task_res["data"].columns)] + each_task_res["data"].values.tolist() model_select_type, show_select_data, select_chart_type, show_chart_data = data_visualization(model_type, - agent_search_result[i][ - "query"], - each_task_res["data"], - database_profile[ - 'prompt_map']) + agent_search_result[ + i][ + "query"], + each_task_res[ + "data"], + database_profile[ + 'prompt_map']) each_task_sql_response = get_generated_sql_explain(agent_search_result[i]["response"]) sub_task_sql_result = SQLSearchResult(sql_data=show_select_data, sql=each_task_res["sql"], @@ -321,7 +330,8 @@ def ask(question: Question) -> Answer: else: log_info = agent_search_result[i]["query"] + "The SQL error Info: " log_id = generate_log_id() - LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql=each_task_res["sql"], + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, + profile_name=selected_profile, sql=each_task_res["sql"], query=search_box + "; The sub task is " + agent_search_result[i]["query"], intent="agent_search", log_info=log_info, @@ -340,7 +350,7 @@ def ask(question: Question) -> Answer: return answer -async def ask_websocket(websocket: WebSocket, question : Question): +async def ask_websocket(websocket: WebSocket, question: Question): logger.info(question) session_id = question.session_id user_id = question.user_id @@ -397,8 
+407,7 @@ async def ask_websocket(websocket: WebSocket, question : Question): prompt_map = database_profile['prompt_map'] entity_slot = [] - # 通过标志位控制后续的逻辑 - # 主要的意图有4个, 拒绝, 查询, 思维链, 知识问答 + if intent_ner_recognition_flag: await response_websocket(websocket, session_id, "Query Intent Analyse", ContentEnum.STATE, "start") intent_response = get_query_intent(model_type, search_box, prompt_map) @@ -428,14 +437,15 @@ async def ask_websocket(websocket: WebSocket, question : Question): answer = Answer(query=search_box, query_intent="reject_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[]) - LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql="", query=search_box, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, + profile_name=selected_profile, sql="", query=search_box, intent="reject_search", log_info="", time_str=current_time) return answer elif search_intent_flag: normal_search_result = await normal_text_search_websocket(websocket, session_id, search_box, model_type, - database_profile, - entity_slot, opensearch_info, - selected_profile, use_rag_flag) + database_profile, + entity_slot, opensearch_info, + selected_profile, use_rag_flag) elif knowledge_search_flag: response = knowledge_search(search_box=search_box, model_id=model_type, prompt_map=prompt_map) @@ -445,7 +455,8 @@ async def ask_websocket(websocket: WebSocket, question : Question): sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[]) - LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql="", query=search_box, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, + profile_name=selected_profile, sql="", query=search_box, intent="knowledge_search", log_info=knowledge_search_result.knowledge_response, time_str=current_time) @@ -469,7 +480,7 @@ async def ask_websocket(websocket: WebSocket, question : Question): split_strings = generated_sq.split("[generate]") generate_suggested_question_list = [s.strip() for s in split_strings if s.strip()] - # 连接数据库,执行SQL, 记录历史记录并展示 + if search_intent_flag: if normal_search_result.sql != "": current_nlq_chain.set_generated_sql(normal_search_result.sql) @@ -492,18 +503,23 @@ async def ask_websocket(websocket: WebSocket, question : Question): else: if search_intent_result["data"] is not None and len(search_intent_result["data"]) > 0: if answer_with_insights: - await response_websocket(websocket, session_id, "Generating Data Insights", ContentEnum.STATE, "start") + await response_websocket(websocket, session_id, "Generating Data Insights", ContentEnum.STATE, + "start") search_intent_analyse_result = data_analyse_tool(model_type, prompt_map, search_box, search_intent_result["data"].to_json( orient='records', force_ascii=False), "query") - await response_websocket(websocket, session_id, "Generating Data Insights", ContentEnum.STATE, "end") + await response_websocket(websocket, session_id, "Generating Data Insights", ContentEnum.STATE, + "end") sql_search_result.data_analyse = search_intent_analyse_result - model_select_type, show_select_data, select_chart_type, show_chart_data = data_visualization(model_type, search_box, - search_intent_result["data"], - database_profile['prompt_map']) + model_select_type, show_select_data, 
select_chart_type, show_chart_data = data_visualization(model_type, + search_box, + search_intent_result[ + "data"], + database_profile[ + 'prompt_map']) if select_chart_type != "-1": sql_chart_data = ChartEntity(chart_type="", chart_data=[]) @@ -515,7 +531,8 @@ async def ask_websocket(websocket: WebSocket, question : Question): sql_search_result.data_show_type = model_select_type log_info = str(search_intent_result["error_info"]) + ";" + sql_search_result.data_analyse - LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql=sql_search_result.sql, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, + profile_name=selected_profile, sql=sql_search_result.sql, query=search_box, intent="normal_search", log_info=log_info, @@ -535,11 +552,13 @@ async def ask_websocket(websocket: WebSocket, question : Question): each_task_sql_res = [list(each_task_res["data"].columns)] + each_task_res["data"].values.tolist() model_select_type, show_select_data, select_chart_type, show_chart_data = data_visualization(model_type, - agent_search_result[i][ - "query"], - each_task_res["data"], - database_profile[ - 'prompt_map']) + agent_search_result[ + i][ + "query"], + each_task_res[ + "data"], + database_profile[ + 'prompt_map']) each_task_sql_response = get_generated_sql_explain(agent_search_result[i]["response"]) sub_task_sql_result = SQLSearchResult(sql_data=show_select_data, sql=each_task_res["sql"], @@ -561,7 +580,8 @@ async def ask_websocket(websocket: WebSocket, question : Question): else: log_info = agent_search_result[i]["query"] + "The SQL error Info: " log_id = generate_log_id() - LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=selected_profile, sql=each_task_res["sql"], + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, + profile_name=selected_profile, sql=each_task_res["sql"], query=search_box + "; The sub task is " + agent_search_result[i]["query"], intent="agent_search", log_info=log_info, @@ -580,7 +600,8 @@ async def ask_websocket(websocket: WebSocket, question : Question): return answer -def user_feedback_upvote(data_profiles: str, user_id : str, session_id : str, query: str, query_intent: str, query_answer): +def user_feedback_upvote(data_profiles: str, user_id: str, session_id: str, query: str, query_intent: str, + query_answer): try: if query_intent == "normal_search": VectorStore.add_sample(data_profiles, query, query_answer) @@ -592,12 +613,14 @@ def user_feedback_upvote(data_profiles: str, user_id : str, session_id : str, qu return False -def user_feedback_downvote(data_profiles: str, user_id : str, session_id : str, query: str, query_intent: str, query_answer): +def user_feedback_downvote(data_profiles: str, user_id: str, session_id: str, query: str, query_intent: str, + query_answer): try: if query_intent == "normal_search": log_id = generate_log_id() current_time = get_current_time() - LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=data_profiles, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, + profile_name=data_profiles, sql=query_answer, query=query, intent="normal_search_user_downvote", log_info="", @@ -605,7 +628,8 @@ def user_feedback_downvote(data_profiles: str, user_id : str, session_id : str, elif query_intent == "agent_search": log_id = generate_log_id() current_time = 
get_current_time() - LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id= session_id, profile_name=data_profiles, + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, + profile_name=data_profiles, sql=query_answer, query=query, intent="agent_search_user_downvote", log_info="", @@ -614,6 +638,7 @@ def user_feedback_downvote(data_profiles: str, user_id : str, session_id : str, except Exception as e: return False + def ask_with_response_stream(question: Question, current_nlq_chain: NLQChain) -> dict: logger.info('try to get generated sql from LLM') response = get_result_from_llm(question, current_nlq_chain, True) @@ -659,7 +684,6 @@ async def normal_text_search_websocket(websocket: WebSocket, session_id: str, se entity_slot_retrieve.extend(entity_retrieve) await response_websocket(websocket, session_id, "Entity Info Retrieval", ContentEnum.STATE, "end") - if use_rag: await response_websocket(websocket, session_id, "QA Info Retrieval", ContentEnum.STATE, "start") retrieve_result = get_retrieve_opensearch(opensearch_info, search_box, "query", @@ -692,7 +716,8 @@ async def normal_text_search_websocket(websocket: WebSocket, session_id: str, se async def response_websocket(websocket: WebSocket, session_id: str, content, - content_type: ContentEnum = ContentEnum.COMMON, status: str = "-1", user_id: str = "admin"): + content_type: ContentEnum = ContentEnum.COMMON, status: str = "-1", + user_id: str = "admin"): if content_type == ContentEnum.STATE: content_json = { "text": content, @@ -708,4 +733,4 @@ async def response_websocket(websocket: WebSocket, session_id: str, content, } logger.info(content_obj) final_content = json.dumps(content_obj) - await websocket.send_text(final_content) \ No newline at end of file + await websocket.send_text(final_content) diff --git a/application/nlq/data_access/dynamo_query_log.py b/application/nlq/data_access/dynamo_query_log.py index ed4a48a..9f6ef68 100644 --- a/application/nlq/data_access/dynamo_query_log.py +++ b/application/nlq/data_access/dynamo_query_log.py @@ -114,5 +114,5 @@ def update(self, entity): self.table.put_item(Item=entity.to_dict()) def add_log(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str): - entity = DynamoQueryLogEntity(log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str) + entity = DynamoQueryLogEntity(log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str) self.add(entity) From 2551c9d3b911880128bc67e517e5aa11276122af Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 11 Jul 2024 11:40:53 +0800 Subject: [PATCH 006/130] remove some code and replace space in SQL --- application/api/main.py | 22 ---------------------- application/api/service.py | 9 ++++----- 2 files changed, 4 insertions(+), 27 deletions(-) diff --git a/application/api/main.py b/application/api/main.py index 0b450a9..c06f71e 100644 --- a/application/api/main.py +++ b/application/api/main.py @@ -65,28 +65,6 @@ async def websocket_endpoint(websocket: WebSocket): session_id = question.session_id ask_result = await ask_websocket(websocket, question) logger.info(ask_result) - - - # current_nlq_chain = service.get_nlq_chain(question) - # if question.use_rag: - # examples = service.get_example(current_nlq_chain) - # await response_websocket(websocket, session_id, "Examples:\n```json\n") - # await response_websocket(websocket, session_id, str(examples)) - # await response_websocket(websocket, session_id, "\n```\n") - # response 
= service.ask_with_response_stream(question, current_nlq_chain) - # if os.getenv('SAGEMAKER_ENDPOINT_SQL', ''): - # await response_sagemaker_sql(websocket, session_id, response, current_nlq_chain) - # await response_websocket(websocket, session_id, "\n") - # explain_response = service.explain_with_response_stream(current_nlq_chain) - # await response_sagemaker_explain(websocket, session_id, explain_response) - # else: - # await response_bedrock(websocket, session_id, response, current_nlq_chain) - # - # if question.query_result: - # final_sql_query_result = service.get_executed_result(current_nlq_chain) - # await response_websocket(websocket, session_id, "\n\nQuery result: \n") - # await response_websocket(websocket, session_id, final_sql_query_result) - # await response_websocket(websocket, session_id, "\n") await response_websocket(websocket, session_id, ask_result.dict(), ContentEnum.END) except Exception: msg = traceback.format_exc() diff --git a/application/api/service.py b/application/api/service.py index 6a8678e..d3eb836 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -240,14 +240,13 @@ def ask(question: Question) -> Answer: split_strings = generated_sq.split("[generate]") generate_suggested_question_list = [s.strip() for s in split_strings if s.strip()] - # 连接数据库,执行SQL, 记录历史记录并展示 if search_intent_flag: if normal_search_result.sql != "": current_nlq_chain.set_generated_sql(normal_search_result.sql) - sql_search_result.sql = normal_search_result.sql + sql_search_result.sql = normal_search_result.sql.strip() current_nlq_chain.set_generated_sql_response(normal_search_result.response) if explain_gen_process_flag: - sql_search_result.sql_gen_process = current_nlq_chain.get_generated_sql_explain() + sql_search_result.sql_gen_process = current_nlq_chain.get_generated_sql_explain().strip() else: sql_search_result.sql = "-1" @@ -484,10 +483,10 @@ async def ask_websocket(websocket: WebSocket, question: Question): if search_intent_flag: if normal_search_result.sql != "": current_nlq_chain.set_generated_sql(normal_search_result.sql) - sql_search_result.sql = normal_search_result.sql + sql_search_result.sql = normal_search_result.sql.strip() current_nlq_chain.set_generated_sql_response(normal_search_result.response) if explain_gen_process_flag: - sql_search_result.sql_gen_process = current_nlq_chain.get_generated_sql_explain() + sql_search_result.sql_gen_process = current_nlq_chain.get_generated_sql_explain().strip() else: sql_search_result.sql = "-1" From 72d0260cc318fddc959e1e5f88b2f343337315ee Mon Sep 17 00:00:00 2001 From: Zhoutong Wang Date: Thu, 11 Jul 2024 06:46:11 +0000 Subject: [PATCH 007/130] dockerfile update --- application/Dockerfile | 4 ++-- application/Dockerfile-api | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/application/Dockerfile b/application/Dockerfile index 989d5fe..b492ca1 100644 --- a/application/Dockerfile +++ b/application/Dockerfile @@ -7,7 +7,7 @@ RUN adduser --disabled-password --gecos '' appuser WORKDIR /app COPY requirements.txt /app/ -RUN pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple +RUN pip3 install -r requirements.txt COPY . 
/app/ @@ -25,4 +25,4 @@ USER appuser HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health -ENTRYPOINT ["streamlit", "run", "Index.py", "--server.port=8501", "--server.address=0.0.0.0"] \ No newline at end of file +ENTRYPOINT ["streamlit", "run", "Index.py", "--server.port=8501", "--server.address=0.0.0.0"] diff --git a/application/Dockerfile-api b/application/Dockerfile-api index c66d8eb..0dcdb8f 100644 --- a/application/Dockerfile-api +++ b/application/Dockerfile-api @@ -3,8 +3,8 @@ FROM public.ecr.aws/docker/library/python:3.10-slim WORKDIR /app COPY . /app/ -RUN pip3 install -r requirements-api.txt -i https://pypi.tuna.tsinghua.edu.cn/simple +RUN pip3 install -r requirements-api.txt EXPOSE 8000 -ENTRYPOINT ["uvicorn", "main:app", "--host", "0.0.0.0"] \ No newline at end of file +ENTRYPOINT ["uvicorn", "main:app", "--host", "0.0.0.0"] From 5189965934911e2a8a9d9884c8bb708eac6bcd7a Mon Sep 17 00:00:00 2001 From: Zhoutong Wang Date: Thu, 11 Jul 2024 07:18:04 +0000 Subject: [PATCH 008/130] dynamo bug fixed --- application/nlq/data_access/dynamo_suggested_question.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/application/nlq/data_access/dynamo_suggested_question.py b/application/nlq/data_access/dynamo_suggested_question.py index 7adedd2..f3abf32 100644 --- a/application/nlq/data_access/dynamo_suggested_question.py +++ b/application/nlq/data_access/dynamo_suggested_question.py @@ -26,7 +26,7 @@ def to_dict(self): class SuggestedQuestionDao: def __init__(self, table_name_prefix=''): - self.dynamodb = boto3.resource('dynamodb') + self.dynamodb = boto3.resource('dynamodb', region_name=os.getenv("DYNAMODB_AWS_REGION")) self.table_name = table_name_prefix + PROFILE_QUESTION_TABLE_NAME if not self.exists(): self.create_table() From 012472501111ebba01cd0c1db12bf2a66ab5cab8 Mon Sep 17 00:00:00 2001 From: Zhoutong Wang Date: Thu, 11 Jul 2024 07:21:48 +0000 Subject: [PATCH 009/130] conflict --- application/Dockerfile-api | 4 ---- 1 file changed, 4 deletions(-) diff --git a/application/Dockerfile-api b/application/Dockerfile-api index ba657d7..aa52257 100644 --- a/application/Dockerfile-api +++ b/application/Dockerfile-api @@ -3,9 +3,6 @@ FROM public.ecr.aws/docker/library/python:3.10-slim WORKDIR /app COPY . 
/app/ -<<<<<<< HEAD -RUN pip3 install -r requirements-api.txt -======= ARG AWS_REGION=us-east-1 ENV AWS_REGION=${AWS_REGION} @@ -19,7 +16,6 @@ RUN if [ "$AWS_REGION" = "cn-north-1" ] || [ "$AWS_REGION" = "cn-northwest-1" ]; else \ pip3 install -r requirements-api.txt; \ fi ->>>>>>> origin EXPOSE 8000 From d583b07785945088eae7a3424e1b29d883672a91 Mon Sep 17 00:00:00 2001 From: Zhoutong Wang Date: Thu, 11 Jul 2024 16:19:57 +0800 Subject: [PATCH 010/130] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1c59f22..977798c 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ The following table provides a sample cost breakdown for deploying this Guidance | ----------- | ------------ | ------------ | | Amazon ECS | v0.75 CPU 5GB | $804.1 | | Amazon DynamoDB | 25 provisioned write & read capacity units per month | $ 14.04 | -| Amazon Bedrock | 2000 requests per month, with each request consuming 10000 input tokens and 1000 output tokens | $ 416.00 | +| Amazon Bedrock | 2000 requests per month, with each request consuming 10000 input tokens and 1000 output tokens | $ 90.00 | | Amazon OpenSearch Service | 1 domain with m5.large.search | $ 103.66 | ## Prerequisites From 6dca7fd02b8faefcc4f632935972ca3574a7de89 Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 11 Jul 2024 16:54:55 +0800 Subject: [PATCH 011/130] change expiry_days --- application/config_files/stauth_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/application/config_files/stauth_config.yaml b/application/config_files/stauth_config.yaml index 5a2fc4d..53029e7 100644 --- a/application/config_files/stauth_config.yaml +++ b/application/config_files/stauth_config.yaml @@ -7,7 +7,7 @@ credentials: name: AWS password: $2b$12$NDQv5NLaWiVlNuzQYHwAo.tv.f.TuX1nbdoUZi44/Y3xv4I4QAfjy # Set the password following instructions in README cookie: - expiry_days: 30 + expiry_days: 2 key: some_signature_key # Must be string name: some_cookie_name pre-authorized: From 7903be1b1d73d866c4b509888c0f07db069048c1 Mon Sep 17 00:00:00 2001 From: Zhoutong Wang Date: Thu, 11 Jul 2024 09:10:40 +0000 Subject: [PATCH 012/130] import error --- application/nlq/data_access/dynamo_suggested_question.py | 1 + 1 file changed, 1 insertion(+) diff --git a/application/nlq/data_access/dynamo_suggested_question.py b/application/nlq/data_access/dynamo_suggested_question.py index f3abf32..f71654e 100644 --- a/application/nlq/data_access/dynamo_suggested_question.py +++ b/application/nlq/data_access/dynamo_suggested_question.py @@ -2,6 +2,7 @@ from utils.prompt import SUGGESTED_QUESTION_PROMPT_CLAUDE3 import boto3 import logging +import os from botocore.exceptions import ClientError from utils.constant import PROFILE_QUESTION_TABLE_NAME, ACTIVE_PROMPT_NAME, DEFAULT_PROMPT_NAME From 794af4c6338aec284824291ece1e2e9948a26d10 Mon Sep 17 00:00:00 2001 From: Zhoutong Wang Date: Thu, 11 Jul 2024 14:01:29 +0000 Subject: [PATCH 013/130] auto correcting --- ...0\237\214\215_Generative_BI_Playground.py" | 35 ++++++++++++++++++- application/utils/llm.py | 4 +-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" index c3cbd55..ef72b79 100644 --- "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" +++ "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" @@ -301,6 +301,7 @@ def main(): explain_gen_process_flag = 
st.checkbox("Explain Generation Process", True) data_with_analyse = st.checkbox("Answer With Insights", False) gen_suggested_question_flag = st.checkbox("Generate Suggested Questions", False) + auto_correction_flag = st.checkbox("Auto Correcting SQL", False) context_window = st.slider("Multiple Rounds of Context Window", 0, 10, 0) clean_history = st.button("clean history", on_click=clean_st_history, args=[selected_profile]) @@ -503,7 +504,39 @@ def main(): if search_intent_result["status_code"] == 500: with st.expander("The SQL Error Info"): st.markdown(search_intent_result["error_info"]) - else: + + if auto_correction_flag: + with st.status("Regenerating SQL") as status_text: + response = text_to_sql(database_profile['tables_info'], + database_profile['hints'], + database_profile['prompt_map'], + search_box, + model_id=model_type, + sql_examples=normal_search_result.retrieve_result, + ner_example=normal_search_result.entity_slot_retrieve, + dialect=database_profile['db_type'], + model_provider=None, + additional_info='''\n NOTE: when I try to write a SQL {sql_statement}, I got an error {error}. Please consider and avoid this problem. '''.format(sql_statement=current_nlq_chain.get_generated_sql(), error=search_intent_result["error_info"])) + + regen_sql = get_generated_sql(response) + + st.code(regen_sql, language="sql") + + status_text.update( + label=f"Generating SQL Done", + state="complete", expanded=True) + + with st.spinner('Executing query...'): + search_intent_result = get_sql_result_tool( + st.session_state['profiles'][current_nlq_chain.profile], + regen_sql) + + if search_intent_result["status_code"] == 500: + with st.expander("The SQL Error Info"): + st.markdown(search_intent_result["error_info"]) + + if search_intent_result["status_code"] != 500: + # else: if search_intent_result["data"] is not None and len( search_intent_result["data"]) > 0 and data_with_analyse: with st.spinner('Generating data summarize...'): diff --git a/application/utils/llm.py b/application/utils/llm.py index 9213af4..b8dc3a3 100644 --- a/application/utils/llm.py +++ b/application/utils/llm.py @@ -275,11 +275,11 @@ def invoke_llm_model(model_id, system_prompt, user_prompt, max_tokens=2048, with def text_to_sql(ddl, hints, prompt_map, search_box, sql_examples=None, ner_example=None, model_id=None, dialect='mysql', - model_provider=None, with_response_stream=False): + model_provider=None, with_response_stream=False, additional_info=''): user_prompt, system_prompt = generate_llm_prompt(ddl, hints, prompt_map, search_box, sql_examples, ner_example, model_id, dialect=dialect) max_tokens = 4096 - response = invoke_llm_model(model_id, system_prompt, user_prompt, max_tokens, with_response_stream) + response = invoke_llm_model(model_id, system_prompt, user_prompt + additional_info, max_tokens, with_response_stream) return response From c6f44ec0bf57a794783f83930b2728377679bdc1 Mon Sep 17 00:00:00 2001 From: supinyu Date: Fri, 12 Jul 2024 08:05:35 +0800 Subject: [PATCH 014/130] add force_set_cookie --- application/Dockerfile | 2 +- application/Dockerfile-api | 2 +- application/Index.py | 3 ++- .../5_\360\237\252\231_Prompt_Management.py" | 6 ++---- application/utils/navigation.py | 18 ++++++++++++++++-- 5 files changed, 22 insertions(+), 9 deletions(-) diff --git a/application/Dockerfile b/application/Dockerfile index 8cc14dc..e7229a3 100644 --- a/application/Dockerfile +++ b/application/Dockerfile @@ -8,7 +8,7 @@ WORKDIR /app COPY requirements.txt /app/ -ARG AWS_REGION=us-east-1 +#ARG AWS_REGION=us-east-1 ENV 
AWS_REGION=${AWS_REGION} # Print the AWS_REGION for verification diff --git a/application/Dockerfile-api b/application/Dockerfile-api index aa52257..cb4759d 100644 --- a/application/Dockerfile-api +++ b/application/Dockerfile-api @@ -4,7 +4,7 @@ WORKDIR /app COPY . /app/ -ARG AWS_REGION=us-east-1 +#ARG AWS_REGION=us-east-1 ENV AWS_REGION=${AWS_REGION} # Print the AWS_REGION for verification diff --git a/application/Index.py b/application/Index.py index 8ecda2e..dca3964 100644 --- a/application/Index.py +++ b/application/Index.py @@ -1,5 +1,5 @@ import streamlit as st -from utils.navigation import get_authenticator +from utils.navigation import get_authenticator, force_set_cookie st.set_page_config( page_title="Intelligent BI", @@ -10,6 +10,7 @@ name, authentication_status, username = authenticator.login('main') if authentication_status: + force_set_cookie(authenticator) st.switch_page("pages/mainpage.py") elif authentication_status == False: st.error('Username/password is incorrect') diff --git "a/application/pages/5_\360\237\252\231_Prompt_Management.py" "b/application/pages/5_\360\237\252\231_Prompt_Management.py" index 0b5e74c..2b354ee 100644 --- "a/application/pages/5_\360\237\252\231_Prompt_Management.py" +++ "b/application/pages/5_\360\237\252\231_Prompt_Management.py" @@ -36,10 +36,6 @@ def main(): prompt_type_selected_table = st.selectbox("Prompt Type", prompt_map.keys(), index=None, format_func=lambda x: prompt_map[x].get('title'), placeholder="Please select a prompt type") - - profile_detail = ProfileManagement.get_profile_by_name(current_profile) - prompt_map = profile_detail.prompt_map - if prompt_type_selected_table is not None: single_type_prompt_map = prompt_map.get(prompt_type_selected_table) system_prompt = single_type_prompt_map.get('system_prompt') @@ -48,6 +44,8 @@ def main(): placeholder="Please select a model") if model_selected_table is not None: + profile_detail = ProfileManagement.get_profile_by_name(current_profile) + prompt_map = profile_detail.prompt_map system_prompt_input = st.text_area('System Prompt', system_prompt[model_selected_table], height=300) user_prompt_input = st.text_area('User Prompt', user_prompt[model_selected_table], height=500) diff --git a/application/utils/navigation.py b/application/utils/navigation.py index e99dc86..5ae4cdb 100644 --- a/application/utils/navigation.py +++ b/application/utils/navigation.py @@ -20,6 +20,18 @@ def get_authenticator(): ) +def force_set_cookie(authenticator): + """ + Force the cookie + :param authenticator: + :return: + """ + try: + authenticator.cookie_handler.set_cookie() + except: + pass + + def get_current_page_name(): ctx = get_script_run_ctx() if ctx is None: @@ -35,12 +47,14 @@ def make_sidebar(): if st.session_state.get('authentication_status'): st.page_link("pages/mainpage.py", label="Index") st.page_link("pages/1_🌍_Generative_BI_Playground.py", label="Generative BI Playground", icon="🌍") - st.markdown(":gray[Data Customization Management]", help='Add your own datasources and customize description for LLM to better understand them') + st.markdown(":gray[Data Customization Management]", + help='Add your own datasources and customize description for LLM to better understand them') st.page_link("pages/2_🪙_Data_Connection_Management.py", label="Data Connection Management", icon="🪙") st.page_link("pages/3_🪙_Data_Profile_Management.py", label="Data Profile Management", icon="🪙") st.page_link("pages/4_🪙_Schema_Description_Management.py", label="Schema Description Management", icon="🪙") 
st.page_link("pages/5_🪙_Prompt_Management.py", label="Prompt Management", icon="🪙") - st.markdown(":gray[Performance Enhancement]", help='Optimize your LLM for better performance by adding RAG or agent') + st.markdown(":gray[Performance Enhancement]", + help='Optimize your LLM for better performance by adding RAG or agent') st.page_link("pages/6_📚_Index_Management.py", label="Index Management", icon="📚") st.page_link("pages/7_📚_Entity_Management.py", label="Entity Management", icon="📚") st.page_link("pages/8_📚_Agent_Cot_Management.py", label="Agent Cot Management", icon="📚") From d9e3dd2d1f36658ede61bc89267e8ea91badde5a Mon Sep 17 00:00:00 2001 From: supinyu Date: Fri, 12 Jul 2024 08:39:55 +0800 Subject: [PATCH 015/130] fix prompt --- "application/pages/5_\360\237\252\231_Prompt_Management.py" | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git "a/application/pages/5_\360\237\252\231_Prompt_Management.py" "b/application/pages/5_\360\237\252\231_Prompt_Management.py" index 2b354ee..21910bb 100644 --- "a/application/pages/5_\360\237\252\231_Prompt_Management.py" +++ "b/application/pages/5_\360\237\252\231_Prompt_Management.py" @@ -39,13 +39,15 @@ def main(): if prompt_type_selected_table is not None: single_type_prompt_map = prompt_map.get(prompt_type_selected_table) system_prompt = single_type_prompt_map.get('system_prompt') - user_prompt = single_type_prompt_map.get('user_prompt') model_selected_table = st.selectbox("LLM Model", system_prompt.keys(), index=None, placeholder="Please select a model") if model_selected_table is not None: profile_detail = ProfileManagement.get_profile_by_name(current_profile) prompt_map = profile_detail.prompt_map + single_type_prompt_map = prompt_map.get(prompt_type_selected_table) + system_prompt = single_type_prompt_map.get('system_prompt') + user_prompt = single_type_prompt_map.get('user_prompt') system_prompt_input = st.text_area('System Prompt', system_prompt[model_selected_table], height=300) user_prompt_input = st.text_area('User Prompt', user_prompt[model_selected_table], height=500) From 83a1fa6f19b898d9fff8b7080c83c971ca4f8ffb Mon Sep 17 00:00:00 2001 From: supinyu Date: Fri, 12 Jul 2024 10:46:48 +0800 Subject: [PATCH 016/130] add cdk config --- application/.env.template | 2 ++ application/nlq/business/vector_store.py | 5 +++++ source/resources/bin/main.ts | 15 +++++++++++++-- source/resources/dk-config.json | 15 +++++++++++++++ 4 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 source/resources/dk-config.json diff --git a/application/.env.template b/application/.env.template index a474d53..b000648 100644 --- a/application/.env.template +++ b/application/.env.template @@ -31,3 +31,5 @@ BEDROCK_SECRETS_AK_SK= OPENSEARCH_SECRETS_URL_HOST=opensearch-host-url OPENSEARCH_SECRETS_USERNAME_PASSWORD=opensearch-master-user + +SAGEMAKER_ENDPOINT_EMBEDDING= \ No newline at end of file diff --git a/application/nlq/business/vector_store.py b/application/nlq/business/vector_store.py index 0c4195f..4b3830a 100644 --- a/application/nlq/business/vector_store.py +++ b/application/nlq/business/vector_store.py @@ -117,6 +117,11 @@ def create_vector_embedding_with_bedrock(cls, text): return embedding + @classmethod + def create_vector_embedding_with_sagemaker(cls): + # to do + pass + @classmethod def delete_sample(cls, profile_name, doc_id): logger.info(f'delete sample question id: {doc_id} from profile {profile_name}') diff --git a/source/resources/bin/main.ts b/source/resources/bin/main.ts index fd5c887..f8534b6 100644 --- 
a/source/resources/bin/main.ts +++ b/source/resources/bin/main.ts @@ -1,13 +1,24 @@ import * as cdk from 'aws-cdk-lib'; import { MainStack } from '../lib/main-stack'; +import * as fs from 'fs'; +import * as path from 'path'; const devEnv = { account: process.env.CDK_DEFAULT_ACCOUNT, region: process.env.CDK_DEFAULT_REGION, }; +const configPath = path.join(__dirname, '..', 'cdk-config.json'); +const config = JSON.parse(fs.readFileSync(configPath, 'utf8')); + const app = new cdk.App(); -const deployRds = process.argv.includes('--deploy-rds'); // Check if --deploy-rds flag is present -new MainStack(app, 'GenBiMainStack', { env: devEnv, deployRds }); // Pass deployRDS flag to MainStack constructor +const rds = config.rds + +const cdkConfig = { + env: devEnv, + deployRds: rds.deployRds +}; + +new MainStack(app, 'GenBiMainStack', cdkConfig); // Pass deployRDS flag to MainStack constructor app.synth(); diff --git a/source/resources/dk-config.json b/source/resources/dk-config.json new file mode 100644 index 0000000..9ab742e --- /dev/null +++ b/source/resources/dk-config.json @@ -0,0 +1,15 @@ +{ + "rds": { + "deploy": false + }, + "embedding": { + "bedrock_embedding_name": "", + "embedding_dimension": 1536 + }, + "segamaker": { + "endpoint_name" : "" + }, + "opensearch": { + "sql_index" : "genbi_sql_index" + } +} \ No newline at end of file From 82fa51c550d217cf895e376f006c3aa0e3f390ce Mon Sep 17 00:00:00 2001 From: supinyu Date: Fri, 12 Jul 2024 11:14:46 +0800 Subject: [PATCH 017/130] add cdk config --- source/resources/bin/main.ts | 2 +- source/resources/cdk-config.json | 18 ++++++++++++++++++ source/resources/dk-config.json | 15 --------------- 3 files changed, 19 insertions(+), 16 deletions(-) create mode 100644 source/resources/cdk-config.json delete mode 100644 source/resources/dk-config.json diff --git a/source/resources/bin/main.ts b/source/resources/bin/main.ts index f8534b6..36f44f9 100644 --- a/source/resources/bin/main.ts +++ b/source/resources/bin/main.ts @@ -17,7 +17,7 @@ const rds = config.rds const cdkConfig = { env: devEnv, - deployRds: rds.deployRds + deployRds: rds.deploy }; new MainStack(app, 'GenBiMainStack', cdkConfig); // Pass deployRDS flag to MainStack constructor diff --git a/source/resources/cdk-config.json b/source/resources/cdk-config.json new file mode 100644 index 0000000..c48fa55 --- /dev/null +++ b/source/resources/cdk-config.json @@ -0,0 +1,18 @@ +{ + "rds": { + "deploy": false + }, + "embedding": { + "bedrock_embedding_name": "amazon.titan-embed-text-v1", + "embedding_dimension": 1536, + "segamaker_embedding_name" : "" + }, + "segamaker": { + "endpoint_name" : "" + }, + "opensearch": { + "sql_index" : "genbi_sql_index", + "ner_index" : "genbi_ner_index", + "cot_index" : "genbi_cot_index" + } +} \ No newline at end of file diff --git a/source/resources/dk-config.json b/source/resources/dk-config.json deleted file mode 100644 index 9ab742e..0000000 --- a/source/resources/dk-config.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "rds": { - "deploy": false - }, - "embedding": { - "bedrock_embedding_name": "", - "embedding_dimension": 1536 - }, - "segamaker": { - "endpoint_name" : "" - }, - "opensearch": { - "sql_index" : "genbi_sql_index" - } -} \ No newline at end of file From 096803b911b24ee4213d39bffe834e5694821abd Mon Sep 17 00:00:00 2001 From: supinyu Date: Fri, 12 Jul 2024 11:19:19 +0800 Subject: [PATCH 018/130] add cdk config --- source/resources/lib/main-stack.ts | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git 
a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index 68922a8..6310483 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -31,13 +31,29 @@ export class MainStack extends cdk.Stack { // default: "not-set" // }); - // ======== Step 2. Define the AOSStack ========= + // ======== Step 2. Define the AOSStack ========= + const selectedSubnets = _VpcStack.vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS }); + const _AosStack = new AOSStack(this, 'aos-Stack', { env: props.env, vpc: _VpcStack.vpc, - subnets: _VpcStack.vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS }).subnets, + subnets: selectedSubnets.subnets, + }); + + // Print the selected private subnet info + console.log('Selected PRIVATE_WITH_EGRESS subnets:'); + selectedSubnets.subnets.forEach((subnet, index) => { + console.log(`Subnet ${index + 1}:`); + console.log(` ID: ${subnet.subnetId}`); + console.log(` Availability Zone: ${subnet.availabilityZone}`); + console.log(` CIDR: ${subnet.ipv4CidrBlock}`); }); + // Print the number of selected subnets + console.log(`Total number of selected subnets: ${selectedSubnets.subnets.length}`); + + + const aosEndpoint = _AosStack.endpoint; // ======== Step 3. Define the RDSStack ========= From 1d569c0537532fc655eed9cf2f8f783c3d32fd08 Mon Sep 17 00:00:00 2001 From: supinyu Date: Mon, 15 Jul 2024 10:08:24 +0800 Subject: [PATCH 019/130] change index --- source/resources/cdk-config.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/source/resources/cdk-config.json b/source/resources/cdk-config.json index c48fa55..cf36c20 100644 --- a/source/resources/cdk-config.json +++ b/source/resources/cdk-config.json @@ -11,8 +11,8 @@ "endpoint_name" : "" }, "opensearch": { - "sql_index" : "genbi_sql_index", - "ner_index" : "genbi_ner_index", - "cot_index" : "genbi_cot_index" + "sql_index" : "ubs", + "ner_index" : "uba_ner", + "cot_index" : "uba_agent" } } \ No newline at end of file From 13a0c086fae717b8bdd98ab27ccb9c8cf15f17dd Mon Sep 17 00:00:00 2001 From: Zhoutong Wang Date: Mon, 15 Jul 2024 10:10:52 +0800 Subject: [PATCH 020/130] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 977798c..92ecc12 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ The following table provides a sample cost breakdown for deploying this Guidance | AWS service | Dimensions | Cost [USD] per Month | | ----------- | ------------ | ------------ | -| Amazon ECS | v0.75 CPU 5GB | $804.1 | +| Amazon ECS | v0.75 CPU 5GB | $11.51 | | Amazon DynamoDB | 25 provisioned write & read capacity units per month | $ 14.04 | | Amazon Bedrock | 2000 requests per month, with each request consuming 10000 input tokens and 1000 output tokens | $ 90.00 | | Amazon OpenSearch Service | 1 domain with m5.large.search | $ 103.66 | From 113c7178c962db2f64ca3c485be3dfc200b1fcbd Mon Sep 17 00:00:00 2001 From: supinyu Date: Mon, 15 Jul 2024 11:10:39 +0800 Subject: [PATCH 021/130] change index --- source/resources/cdk-config.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/resources/cdk-config.json b/source/resources/cdk-config.json index cf36c20..9053f52 100644 --- a/source/resources/cdk-config.json +++ b/source/resources/cdk-config.json @@ -11,7 +11,7 @@ "endpoint_name" : "" }, "opensearch": { - "sql_index" : "ubs", + "sql_index" : "uba", "ner_index" : "uba_ner", "cot_index" : "uba_agent" } } \ No newline at end of file From 937d136cfe1e9a626f690c1c194b7a0e1bd70ca3 Mon Sep 17 00:00:00 2001 From: supinyu Date:
Mon, 15 Jul 2024 16:18:25 +0800 Subject: [PATCH 022/130] add rds cdk fix --- source/resources/lib/main-stack.ts | 3 +++ source/resources/lib/rds/rds-stack.ts | 7 +++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index 6310483..23f3332 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -58,8 +58,11 @@ export class MainStack extends cdk.Stack { // ======== Step 3. Define the RDSStack ========= if (_deployRds) { + const rdsSubnets = _VpcStack.vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_ISOLATED }); + const _RdsStack = new RDSStack(this, 'rds-Stack', { env: props.env, + subnets: rdsSubnets.subnets }); new cdk.CfnOutput(this, 'RDSEndpoint', { value: _RdsStack.endpoint, diff --git a/source/resources/lib/rds/rds-stack.ts b/source/resources/lib/rds/rds-stack.ts index 1240e02..eb7126a 100644 --- a/source/resources/lib/rds/rds-stack.ts +++ b/source/resources/lib/rds/rds-stack.ts @@ -9,9 +9,8 @@ import * as secretsmanager from 'aws-cdk-lib/aws-secretsmanager'; export class RDSStack extends cdk.Stack { _vpc; public readonly endpoint: string; - constructor(scope: Construct, id: string, props?: cdk.StackProps) { + constructor(scope: Construct, id: string, props?: cdk.StackProps & { subnets: cdk.aws_ec2.ISubnet[] }) { super(scope, id, props); - this._vpc = ec2.Vpc.fromLookup(this, "VPC", { isDefault: true, }); @@ -34,9 +33,9 @@ export class RDSStack extends cdk.Stack { instanceType: ec2.InstanceType.of(InstanceClass.T3, InstanceSize.MICRO), vpc: this._vpc, vpcSubnets: { - subnetType: ec2.SubnetType.PRIVATE_ISOLATED + subnetType: props?.subnets }, - publiclyAccessible: true, + publiclyAccessible: false, databaseName: 'GenBIDB', credentials: rds.Credentials.fromSecret(templatedSecret), }); From d2bbee63ec2a34a62384938e82f1990352b4bf35 Mon Sep 17 00:00:00 2001 From: supinyu Date: Mon, 15 Jul 2024 16:44:01 +0800 Subject: [PATCH 023/130] add rds cdk fix --- source/resources/lib/main-stack.ts | 5 +++-- source/resources/lib/rds/rds-stack.ts | 18 +++++++++--------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index 23f3332..a94b06b 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -58,11 +58,12 @@ export class MainStack extends cdk.Stack { // ======== Step 3. 
Define the RDSStack ========= if (_deployRds) { - const rdsSubnets = _VpcStack.vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_ISOLATED }); + const rdsSubnets = _VpcStack.vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS }); const _RdsStack = new RDSStack(this, 'rds-Stack', { env: props.env, - subnets: rdsSubnets.subnets + subnets: rdsSubnets, + vpcId : _VpcStack.vpc.vpcId }); new cdk.CfnOutput(this, 'RDSEndpoint', { value: _RdsStack.endpoint, diff --git a/source/resources/lib/rds/rds-stack.ts b/source/resources/lib/rds/rds-stack.ts index eb7126a..728517b 100644 --- a/source/resources/lib/rds/rds-stack.ts +++ b/source/resources/lib/rds/rds-stack.ts @@ -5,15 +5,16 @@ import { Construct } from 'constructs'; import { InstanceClass, InstanceSize, InstanceType, Port, SubnetType, Vpc } from 'aws-cdk-lib/aws-ec2' import * as secretsmanager from 'aws-cdk-lib/aws-secretsmanager'; +interface RDSStackProps extends cdk.StackProps { + subnets?: ec2.SubnetSelection; + vpcId?: string; +} // add rds stack export class RDSStack extends cdk.Stack { - _vpc; public readonly endpoint: string; - constructor(scope: Construct, id: string, props?: cdk.StackProps & { subnets: cdk.aws_ec2.ISubnet[] }) { + constructor(scope: Construct, id: string, props?: RDSStackProps) { super(scope, id, props); - this._vpc = ec2.Vpc.fromLookup(this, "VPC", { - isDefault: true, - }); + const vpc = props?.vpcId ? ec2.Vpc.fromLookup(this, "VPC", { vpcId: props.vpcId }) : ec2.Vpc.fromLookup(this, "VPC", { isDefault: true }); const templatedSecret = new secretsmanager.Secret(this, 'TemplatedSecret', { description: 'Templated secret used for RDS password', @@ -31,10 +32,8 @@ export class RDSStack extends cdk.Stack { const database = new rds.DatabaseInstance(this, 'Database', { engine: rds.DatabaseInstanceEngine.mysql({ version: rds.MysqlEngineVersion.VER_8_0 }), instanceType: ec2.InstanceType.of(InstanceClass.T3, InstanceSize.MICRO), - vpc: this._vpc, - vpcSubnets: { - subnetType: props?.subnets - }, + vpc: vpc, + vpcSubnets: props?.subnets || { subnetType: SubnetType.PRIVATE_WITH_EGRESS }, publiclyAccessible: false, databaseName: 'GenBIDB', credentials: rds.Credentials.fromSecret(templatedSecret), @@ -43,6 +42,7 @@ export class RDSStack extends cdk.Stack { // Output the database endpoint new cdk.CfnOutput(this, 'RDSEndpoint', { value: database.instanceEndpoint.hostname, + description: 'The endpoint of the RDS instance', }); } } \ No newline at end of file From 3be8ef4f56fd4d2b9f6944446d64e9fe6e4c3c41 Mon Sep 17 00:00:00 2001 From: supinyu Date: Mon, 15 Jul 2024 16:59:17 +0800 Subject: [PATCH 024/130] print ecs log info --- source/resources/lib/main-stack.ts | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index a94b06b..a64e6e7 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -32,25 +32,25 @@ export class MainStack extends cdk.Stack { // }); // ======== Step 2. 
Define the AOSStack ========= - const selectedSubnets = _VpcStack.vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS }); + const aosSubnets = _VpcStack.vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS }); const _AosStack = new AOSStack(this, 'aos-Stack', { env: props.env, vpc: _VpcStack.vpc, - subnets: selectedSubnets.subnets, + subnets: aosSubnets.subnets, }); - // Print the selected private subnet info - console.log('Selected PRIVATE_WITH_EGRESS subnets:'); - selectedSubnets.subnets.forEach((subnet, index) => { + // print AOS subnet Info + console.log('AOS subnets Info:'); + aosSubnets.subnets.forEach((subnet, index) => { console.log(`Subnet ${index + 1}:`); console.log(` ID: ${subnet.subnetId}`); console.log(` Availability Zone: ${subnet.availabilityZone}`); console.log(` CIDR: ${subnet.ipv4CidrBlock}`); }); - // Print the number of selected subnets - console.log(`Total number of selected subnets: ${selectedSubnets.subnets.length}`); + // print AOS subnet length + console.log(`Total number of AOS subnets: ${aosSubnets.subnets.length}`); @@ -78,10 +78,24 @@ export class MainStack extends cdk.Stack { // ======== Step 5. Define the ECS ========= // pass the aosEndpoint and aosPassword to the ecs stack + const ecsSubnets = _VpcStack.vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS }); + + // print AOS subnet Info + console.log('ECS subnets Info:'); + ecsSubnets.subnets.forEach((subnet, index) => { + console.log(`Subnet ${index + 1}:`); + console.log(` ID: ${subnet.subnetId}`); + console.log(` Availability Zone: ${subnet.availabilityZone}`); + console.log(` CIDR: ${subnet.ipv4CidrBlock}`); + }); + + // print AOS subnet length + console.log(`Total number of ECS subnets: ${ecsSubnets.subnets.length}`); + const _EcsStack = new ECSStack(this, 'ecs-Stack', { env: props.env, vpc: _VpcStack.vpc, - subnets: _VpcStack.vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS }).subnets, + subnets: ecsSubnets.subnets, cognitoUserPoolId: _CognitoStack.userPoolId, cognitoUserPoolClientId: _CognitoStack.userPoolClientId, OSMasterUserSecretName: _AosStack.OSMasterUserSecretName, From c090514cae147c2090c2f71e87dabd2b53574240 Mon Sep 17 00:00:00 2001 From: supinyu Date: Mon, 15 Jul 2024 18:31:49 +0800 Subject: [PATCH 025/130] print ecs log info --- source/resources/lib/main-stack.ts | 2 +- source/resources/lib/rds/rds-stack.ts | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index a64e6e7..e3619f5 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -63,7 +63,7 @@ export class MainStack extends cdk.Stack { const _RdsStack = new RDSStack(this, 'rds-Stack', { env: props.env, subnets: rdsSubnets, - vpcId : _VpcStack.vpc.vpcId + vpc : _VpcStack.vpc }); new cdk.CfnOutput(this, 'RDSEndpoint', { value: _RdsStack.endpoint, diff --git a/source/resources/lib/rds/rds-stack.ts b/source/resources/lib/rds/rds-stack.ts index 728517b..19ff6d7 100644 --- a/source/resources/lib/rds/rds-stack.ts +++ b/source/resources/lib/rds/rds-stack.ts @@ -7,15 +7,14 @@ import * as secretsmanager from 'aws-cdk-lib/aws-secretsmanager'; interface RDSStackProps extends cdk.StackProps { subnets?: ec2.SubnetSelection; - vpcId?: string; + vpc:ec2.IVpc; } // add rds stack export class RDSStack extends cdk.Stack { public readonly endpoint: string; - constructor(scope: Construct, id: string, props?: RDSStackProps) { + constructor(scope: Construct, id: string, props: RDSStackProps) { super(scope, id,
props); - const vpc = props?.vpcId ? ec2.Vpc.fromLookup(this, "VPC", { vpcId: props.vpcId }) : ec2.Vpc.fromLookup(this, "VPC", { isDefault: true }); - + const templatedSecret = new secretsmanager.Secret(this, 'TemplatedSecret', { description: 'Templated secret used for RDS password', generateSecretString: { @@ -32,8 +31,8 @@ export class RDSStack extends cdk.Stack { const database = new rds.DatabaseInstance(this, 'Database', { engine: rds.DatabaseInstanceEngine.mysql({ version: rds.MysqlEngineVersion.VER_8_0 }), instanceType: ec2.InstanceType.of(InstanceClass.T3, InstanceSize.MICRO), - vpc: vpc, - vpcSubnets: props?.subnets || { subnetType: SubnetType.PRIVATE_WITH_EGRESS }, + vpc: props.vpc, + vpcSubnets: props.subnets || { subnetType: SubnetType.PRIVATE_WITH_EGRESS }, publiclyAccessible: false, databaseName: 'GenBIDB', credentials: rds.Credentials.fromSecret(templatedSecret), From e4a2d07f672ad1cc6a86e36fbb9bf076fbb5e7fc Mon Sep 17 00:00:00 2001 From: supinyu Date: Tue, 16 Jul 2024 09:33:08 +0800 Subject: [PATCH 026/130] change README_CN.md --- README_CN.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README_CN.md b/README_CN.md index 31afd75..e7c9e1a 100644 --- a/README_CN.md +++ b/README_CN.md @@ -62,9 +62,9 @@ | AWS service | Dimensions | Cost [USD] per Month | | ----------- | ------------ | ------------ | -| Amazon ECS | v0.75 CPU 5GB | $804.1 | +| Amazon ECS | v0.75 CPU 5GB | $11.51 | | Amazon DynamoDB | 25 provisioned write & read capacity units per month | $ 14.04 | -| Amazon Bedrock | 2000 requests per month, with each request consuming 10000 input tokens and 1000 output tokens | $ 416.00 | +| Amazon Bedrock | 2000 requests per month, with each request consuming 10000 input tokens and 1000 output tokens | $ 90.00 | | Amazon OpenSearch Service | 1 domain with m5.large.search | $ 103.66 | From e449ce8477cb2a356ae09771157a5fc129c1ff10 Mon Sep 17 00:00:00 2001 From: supinyu Date: Tue, 16 Jul 2024 18:04:58 +0800 Subject: [PATCH 027/130] change current profile --- .../3_\360\237\252\231_Data_Profile_Management.py" | 11 ++++++++--- ...\360\237\252\231_Schema_Description_Management.py" | 1 + .../pages/5_\360\237\252\231_Prompt_Management.py" | 1 + .../pages/6_\360\237\223\232_Index_Management.py" | 1 + .../pages/7_\360\237\223\232_Entity_Management.py" | 1 + .../pages/8_\360\237\223\232_Agent_Cot_Management.py" | 1 + 6 files changed, 13 insertions(+), 3 deletions(-) diff --git "a/application/pages/3_\360\237\252\231_Data_Profile_Management.py" "b/application/pages/3_\360\237\252\231_Data_Profile_Management.py" index b24dd1d..9ba64c6 100644 --- "a/application/pages/3_\360\237\252\231_Data_Profile_Management.py" +++ "b/application/pages/3_\360\237\252\231_Data_Profile_Management.py" @@ -27,9 +27,14 @@ def main(): with st.sidebar: st.title("Data Profile Management") - st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(), - index=None, - placeholder="Please select data profile...", key='current_profile_name') + all_profiles_list = ProfileManagement.get_all_profiles() + if st.session_state.current_profile != "" and st.session_state.current_profile in all_profiles_list: + profile_index = all_profiles_list.index(st.session_state.current_profile) + current_profile = st.selectbox("My Data Profiles", all_profiles_list, index=profile_index) + else: + current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(), + index=None, + placeholder="Please select data profile...", key='current_profile_name') if 
st.session_state.current_profile_name: st.session_state.profile_page_mode = 'update' diff --git "a/application/pages/4_\360\237\252\231_Schema_Description_Management.py" "b/application/pages/4_\360\237\252\231_Schema_Description_Management.py" index 569587c..ff3a4b1 100644 --- "a/application/pages/4_\360\237\252\231_Schema_Description_Management.py" +++ "b/application/pages/4_\360\237\252\231_Schema_Description_Management.py" @@ -27,6 +27,7 @@ def main(): placeholder="Please select data profile...", key='current_profile_name') if current_profile is not None: + st.session_state['current_profile'] = current_profile profile_detail = ProfileManagement.get_profile_by_name(current_profile) selected_table = st.selectbox("Tables", profile_detail.tables, index=None, placeholder="Please select a table") diff --git "a/application/pages/5_\360\237\252\231_Prompt_Management.py" "b/application/pages/5_\360\237\252\231_Prompt_Management.py" index 21910bb..abad222 100644 --- "a/application/pages/5_\360\237\252\231_Prompt_Management.py" +++ "b/application/pages/5_\360\237\252\231_Prompt_Management.py" @@ -29,6 +29,7 @@ def main(): placeholder="Please select data profile...", key='current_profile_name') if current_profile is not None: + st.session_state['current_profile'] = current_profile profile_detail = ProfileManagement.get_profile_by_name(current_profile) prompt_map = profile_detail.prompt_map diff --git "a/application/pages/6_\360\237\223\232_Index_Management.py" "b/application/pages/6_\360\237\223\232_Index_Management.py" index 220f76c..795c54e 100644 --- "a/application/pages/6_\360\237\223\232_Index_Management.py" +++ "b/application/pages/6_\360\237\223\232_Index_Management.py" @@ -63,6 +63,7 @@ def main(): tab_view, tab_add, tab_search, batch_insert = st.tabs(['View Samples', 'Add New Sample', 'Sample Search', 'Batch Insert Samples']) if current_profile is not None: + st.session_state['current_profile'] = current_profile with tab_view: if current_profile is not None: st.write("The display page can show a maximum of 5000 pieces of data") diff --git "a/application/pages/7_\360\237\223\232_Entity_Management.py" "b/application/pages/7_\360\237\223\232_Entity_Management.py" index f6fc0e4..b04d1cb 100644 --- "a/application/pages/7_\360\237\223\232_Entity_Management.py" +++ "b/application/pages/7_\360\237\223\232_Entity_Management.py" @@ -65,6 +65,7 @@ def main(): tab_view, tab_add, tab_search, batch_insert = st.tabs( ['View Samples', 'Add New Sample', 'Sample Search', 'Batch Insert Samples']) if current_profile is not None: + st.session_state['current_profile'] = current_profile with tab_view: if current_profile is not None: st.write("The display page can show a maximum of 5000 pieces of data") diff --git "a/application/pages/8_\360\237\223\232_Agent_Cot_Management.py" "b/application/pages/8_\360\237\223\232_Agent_Cot_Management.py" index 6c1631b..fa4e7ae 100644 --- "a/application/pages/8_\360\237\223\232_Agent_Cot_Management.py" +++ "b/application/pages/8_\360\237\223\232_Agent_Cot_Management.py" @@ -40,6 +40,7 @@ def main(): tab_view, tab_add, tab_search = st.tabs(['View Samples', 'Add New Sample', 'Sample Search']) if current_profile is not None: + st.session_state['current_profile'] = current_profile with tab_view: if current_profile is not None: st.write("The display page can show a maximum of 5000 pieces of data") From ba9eb22b1d99f76940642f04f88916d5a74f68f5 Mon Sep 17 00:00:00 2001 From: supinyu Date: Tue, 16 Jul 2024 21:39:55 +0800 Subject: [PATCH 028/130] fix update profile.py ---
application/nlq/business/profile.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/application/nlq/business/profile.py b/application/nlq/business/profile.py index 6eefceb..13b82c0 100644 --- a/application/nlq/business/profile.py +++ b/application/nlq/business/profile.py @@ -41,7 +41,9 @@ def get_profile_by_name(cls, profile_name): @classmethod def update_profile(cls, profile_name, conn_name, schemas, tables, comment, tables_info): - entity = ProfileConfigEntity(profile_name, conn_name, schemas, tables, comment, tables_info) + all_profiles = ProfileManagement.get_all_profiles_with_info() + prompt_map = all_profiles[profile_name]["prompt_map"] + entity = ProfileConfigEntity(profile_name, conn_name, schemas, tables, comment, tables_info, prompt_map) cls.profile_config_dao.update(entity) logger.info(f"Profile {profile_name} updated") From 4cfa4ff21f18ebcb24a1bbccfbb446189d8d75c6 Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 09:00:01 +0800 Subject: [PATCH 029/130] fix __dirname issue --- source/resources/bin/main.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/source/resources/bin/main.ts b/source/resources/bin/main.ts index 36f44f9..fa6e892 100644 --- a/source/resources/bin/main.ts +++ b/source/resources/bin/main.ts @@ -2,12 +2,16 @@ import * as cdk from 'aws-cdk-lib'; import { MainStack } from '../lib/main-stack'; import * as fs from 'fs'; import * as path from 'path'; +import { fileURLToPath } from 'url'; const devEnv = { account: process.env.CDK_DEFAULT_ACCOUNT, region: process.env.CDK_DEFAULT_REGION, }; +const __filename = fileURLToPath(import.meta.url); +const __dirname = path.dirname(__filename); + const configPath = path.join(__dirname, '..', 'cdk-config.json'); const config = JSON.parse(fs.readFileSync(configPath, 'utf8')); From da9f257a19985fe214cdffb8fde5828bca0faf24 Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 09:04:10 +0800 Subject: [PATCH 030/130] fix update profile --- .../pages/3_\360\237\252\231_Data_Profile_Management.py" | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git "a/application/pages/3_\360\237\252\231_Data_Profile_Management.py" "b/application/pages/3_\360\237\252\231_Data_Profile_Management.py" index 9ba64c6..9c9351d 100644 --- "a/application/pages/3_\360\237\252\231_Data_Profile_Management.py" +++ "b/application/pages/3_\360\237\252\231_Data_Profile_Management.py" @@ -25,6 +25,10 @@ def main(): if 'current_profile' not in st.session_state: st.session_state['current_profile'] = '' + if 'current_profile_name' not in st.session_state: + st.session_state['current_profile_name'] = '' + + with st.sidebar: st.title("Data Profile Management") all_profiles_list = ProfileManagement.get_all_profiles() @@ -35,7 +39,10 @@ def main(): current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(), index=None, placeholder="Please select data profile...", key='current_profile_name') - if st.session_state.current_profile_name: + + if current_profile is not None: + st.session_state.current_profile = current_profile + st.session_state.current_profile_name = current_profile st.session_state.profile_page_mode = 'update' st.button('Create new profile...', on_click=new_profile_clicked) From 0ac29ec6b90a1cdaf85b18de107c6a103765487c Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 09:13:52 +0800 Subject: [PATCH 031/130] fix update profile --- ...60\237\252\231_Data_Profile_Management.py" | 22 +++++-------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff 
--git "a/application/pages/3_\360\237\252\231_Data_Profile_Management.py" "b/application/pages/3_\360\237\252\231_Data_Profile_Management.py" index 9c9351d..d554bcb 100644 --- "a/application/pages/3_\360\237\252\231_Data_Profile_Management.py" +++ "b/application/pages/3_\360\237\252\231_Data_Profile_Management.py" @@ -1,7 +1,7 @@ import streamlit as st import sqlalchemy as db from dotenv import load_dotenv -import logging +import logging from nlq.business.connection import ConnectionManagement from nlq.business.profile import ProfileManagement from utils.navigation import make_sidebar @@ -25,24 +25,12 @@ def main(): if 'current_profile' not in st.session_state: st.session_state['current_profile'] = '' - if 'current_profile_name' not in st.session_state: - st.session_state['current_profile_name'] = '' - - with st.sidebar: st.title("Data Profile Management") - all_profiles_list = ProfileManagement.get_all_profiles() - if st.session_state.current_profile != "" and st.session_state.current_profile in all_profiles_list: - profile_index = all_profiles_list.index(st.session_state.current_profile) - current_profile = st.selectbox("My Data Profiles", all_profiles_list, index=profile_index) - else: - current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(), - index=None, - placeholder="Please select data profile...", key='current_profile_name') - - if current_profile is not None: - st.session_state.current_profile = current_profile - st.session_state.current_profile_name = current_profile + st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(), + index=None, + placeholder="Please select data profile...", key='current_profile_name') + if st.session_state.current_profile_name: st.session_state.profile_page_mode = 'update' st.button('Create new profile...', on_click=new_profile_clicked) From c40901a919e07f605087d21214fa5fdf2186c7c1 Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 09:23:21 +0800 Subject: [PATCH 032/130] fix batch upload slow --- .../pages/7_\360\237\223\232_Entity_Management.py" | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git "a/application/pages/7_\360\237\223\232_Entity_Management.py" "b/application/pages/7_\360\237\223\232_Entity_Management.py" index b04d1cb..35f826a 100644 --- "a/application/pages/7_\360\237\223\232_Entity_Management.py" +++ "b/application/pages/7_\360\237\223\232_Entity_Management.py" @@ -33,7 +33,7 @@ def read_file(uploaded_file): return None columns = list(uploaded_data.columns) if "entity" in columns and "comment" in columns: - return uploaded_data + return uploaded_data[["entity", "comment"]] else: st.error(f"The columns need contains entity and comment") return None @@ -121,9 +121,9 @@ def main(): status_text.text(f"Processing file {i + 1} of {len(uploaded_files)}: {uploaded_file.name}") each_upload_data = read_file(uploaded_file) if each_upload_data is not None: - for index, item in each_upload_data.iterrows(): - entity = str(item["entity"]) - comment = str(item["comment"]) + for item in each_upload_data.itertuples(): + entity = str(item.entity) + comment = str(item.comment) VectorStore.add_entity_sample(current_profile, entity, comment) progress_bar.progress((i + 1) / len(uploaded_files)) From 8445fde952443f3af9eaeec95f8dc913b1b03eb4 Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 09:43:35 +0800 Subject: [PATCH 033/130] fix updata entity --- .../pages/7_\360\237\223\232_Entity_Management.py" | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff 
--git "a/application/pages/7_\360\237\223\232_Entity_Management.py" "b/application/pages/7_\360\237\223\232_Entity_Management.py" index 35f826a..0c5521d 100644 --- "a/application/pages/7_\360\237\223\232_Entity_Management.py" +++ "b/application/pages/7_\360\237\223\232_Entity_Management.py" @@ -115,20 +115,21 @@ def main(): uploaded_files = st.file_uploader("Choose CSV or Excel files", accept_multiple_files=True, type=['csv', 'xls', 'xlsx']) if uploaded_files: - progress_bar = st.progress(0) - status_text = st.empty() for i, uploaded_file in enumerate(uploaded_files): + status_text = st.empty() status_text.text(f"Processing file {i + 1} of {len(uploaded_files)}: {uploaded_file.name}") each_upload_data = read_file(uploaded_file) if each_upload_data is not None: + progress_bar = st.progress(0) + progress_text = "batch insert {} entity in progress. Please wait.".format(uploaded_file.name) for item in each_upload_data.itertuples(): entity = str(item.entity) comment = str(item.comment) VectorStore.add_entity_sample(current_profile, entity, comment) - progress_bar.progress((i + 1) / len(uploaded_files)) + progress_bar.progress((i + 1) / len(each_upload_data), text=progress_text) st.success("{uploaded_file} uploaded successfully!".format(uploaded_file=uploaded_file.name)) - progress_bar.empty() + progress_bar.empty() else: st.info('Please select data profile in the left sidebar.') From 77f44184fe2afaf29e60e632efc2b183b0f0b46b Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 09:52:25 +0800 Subject: [PATCH 034/130] fix updata entity --- "application/pages/7_\360\237\223\232_Entity_Management.py" | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git "a/application/pages/7_\360\237\223\232_Entity_Management.py" "b/application/pages/7_\360\237\223\232_Entity_Management.py" index 0c5521d..d54cea9 100644 --- "a/application/pages/7_\360\237\223\232_Entity_Management.py" +++ "b/application/pages/7_\360\237\223\232_Entity_Management.py" @@ -127,9 +127,9 @@ def main(): comment = str(item.comment) VectorStore.add_entity_sample(current_profile, entity, comment) progress_bar.progress((i + 1) / len(each_upload_data), text=progress_text) - + progress_bar.empty() st.success("{uploaded_file} uploaded successfully!".format(uploaded_file=uploaded_file.name)) - progress_bar.empty() + else: st.info('Please select data profile in the left sidebar.') From 6f4d85cb60247d6488aae91b564629668ca54546 Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 10:00:07 +0800 Subject: [PATCH 035/130] fix updata entity --- "application/pages/7_\360\237\223\232_Entity_Management.py" | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git "a/application/pages/7_\360\237\223\232_Entity_Management.py" "b/application/pages/7_\360\237\223\232_Entity_Management.py" index d54cea9..7811556 100644 --- "a/application/pages/7_\360\237\223\232_Entity_Management.py" +++ "b/application/pages/7_\360\237\223\232_Entity_Management.py" @@ -120,13 +120,15 @@ def main(): status_text.text(f"Processing file {i + 1} of {len(uploaded_files)}: {uploaded_file.name}") each_upload_data = read_file(uploaded_file) if each_upload_data is not None: + total_rows = len(each_upload_data) progress_bar = st.progress(0) progress_text = "batch insert {} entity in progress. 
Please wait.".format(uploaded_file.name) - for item in each_upload_data.itertuples(): + for j, item in enumerate(each_upload_data.itertuples(), 1): entity = str(item.entity) comment = str(item.comment) VectorStore.add_entity_sample(current_profile, entity, comment) - progress_bar.progress((i + 1) / len(each_upload_data), text=progress_text) + progress = (j * 1.0) / total_rows + progress_bar.progress(progress, text=progress_text) progress_bar.empty() st.success("{uploaded_file} uploaded successfully!".format(uploaded_file=uploaded_file.name)) From b08073d4bf7d8b5b839fbae48a2db0d000a5193a Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 10:23:07 +0800 Subject: [PATCH 036/130] fix batch entity insert --- .../6_\360\237\223\232_Index_Management.py" | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git "a/application/pages/6_\360\237\223\232_Index_Management.py" "b/application/pages/6_\360\237\223\232_Index_Management.py" index 795c54e..6f8a7a3 100644 --- "a/application/pages/6_\360\237\223\232_Index_Management.py" +++ "b/application/pages/6_\360\237\223\232_Index_Management.py" @@ -105,22 +105,25 @@ def main(): st.write("This page support CSV or Excel files batch insert sql samples.") st.write("**The Column Name need contain 'question' and 'sql'**") uploaded_files = st.file_uploader("Choose CSV or Excel files", accept_multiple_files=True, - type=['csv', 'xls', 'xlsx']) + type=['csv', 'xls', 'xlsx']) if uploaded_files: - progress_bar = st.progress(0) - status_text = st.empty() for i, uploaded_file in enumerate(uploaded_files): + status_text = st.empty() status_text.text(f"Processing file {i + 1} of {len(uploaded_files)}: {uploaded_file.name}") each_upload_data = read_file(uploaded_file) if each_upload_data is not None: - for index, item in each_upload_data.iterrows(): - question = str(item["question"]) - sql = str(item["sql"]) - VectorStore.add_sample(current_profile, question, sql) - progress_bar.progress((i + 1) / len(uploaded_files)) - + total_rows = len(each_upload_data) + progress_bar = st.progress(0) + progress_text = "batch insert {} entity in progress. 
Please wait.".format( + uploaded_file.name) + for j, item in enumerate(each_upload_data.itertuples(), 1): + question = str(item.question) + sql = str(item.sql) + VectorStore.add_entity_sample(current_profile, question, sql) + progress = (j * 1.0) / total_rows + progress_bar.progress(progress, text=progress_text) + progress_bar.empty() st.success("{uploaded_file} uploaded successfully!".format(uploaded_file=uploaded_file.name)) - progress_bar.empty() else: st.info('Please select data profile in the left sidebar.') From a2cc20b10933864391dd1a339c49b254d4120c61 Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 11:23:40 +0800 Subject: [PATCH 037/130] fix types node --- source/resources/tsconfig.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/resources/tsconfig.json b/source/resources/tsconfig.json index bc46f18..9ca35ff 100644 --- a/source/resources/tsconfig.json +++ b/source/resources/tsconfig.json @@ -23,7 +23,8 @@ "esModuleInterop": true, "typeRoots": [ "./node_modules/@types" - ] + ], + "types": ["node"] }, "exclude": [ "node_modules", From ea60e5b57c24db07c9feaa1619bba92f6d431772 Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 11:33:51 +0800 Subject: [PATCH 038/130] fix __dirname node --- source/resources/bin/main.ts | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/source/resources/bin/main.ts b/source/resources/bin/main.ts index fa6e892..20671f0 100644 --- a/source/resources/bin/main.ts +++ b/source/resources/bin/main.ts @@ -2,15 +2,13 @@ import * as cdk from 'aws-cdk-lib'; import { MainStack } from '../lib/main-stack'; import * as fs from 'fs'; import * as path from 'path'; -import { fileURLToPath } from 'url'; const devEnv = { account: process.env.CDK_DEFAULT_ACCOUNT, region: process.env.CDK_DEFAULT_REGION, }; -const __filename = fileURLToPath(import.meta.url); -const __dirname = path.dirname(__filename); +declare const __dirname: string; const configPath = path.join(__dirname, '..', 'cdk-config.json'); const config = JSON.parse(fs.readFileSync(configPath, 'utf8')); From 4e2454195291d0a646593cfde1e95fc1097bd040 Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 12:21:56 +0800 Subject: [PATCH 039/130] fix nan --- application/api/service.py | 1 + application/utils/llm.py | 1 + 2 files changed, 2 insertions(+) diff --git a/application/api/service.py b/application/api/service.py index d3eb836..6868edb 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -257,6 +257,7 @@ def ask(question: Question) -> Answer: else: if search_intent_result["data"] is not None and len(search_intent_result["data"]) > 0: if answer_with_insights: + search_intent_result["data"] = search_intent_result["data"].fillna("") search_intent_analyse_result = data_analyse_tool(model_type, prompt_map, search_box, search_intent_result["data"].to_json( orient='records', force_ascii=False), "query") diff --git a/application/utils/llm.py b/application/utils/llm.py index 9213af4..1b9127e 100644 --- a/application/utils/llm.py +++ b/application/utils/llm.py @@ -434,6 +434,7 @@ def select_data_visualization_type(model_id, search_box, search_data, prompt_map def data_visualization(model_id, search_box, search_data, prompt_map): + search_data = search_data.fillna("") columns = list(search_data.columns) data_list = search_data.values.tolist() all_columns_data = [columns] + data_list From ce8cf47bdbe40c2c24836b7ec98d56ba475eb0da Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 13:48:52 +0800 Subject: [PATCH 
040/130] change vpc name --- source/resources/lib/vpc/vpc-stack.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/resources/lib/vpc/vpc-stack.ts b/source/resources/lib/vpc/vpc-stack.ts index 4699cb1..6e99115 100644 --- a/source/resources/lib/vpc/vpc-stack.ts +++ b/source/resources/lib/vpc/vpc-stack.ts @@ -14,7 +14,7 @@ public readonly publicSubnets: ec2.ISubnet[]; constructor(scope: Construct, id: string, props: cdk.StackProps) { super(scope, id, props); // Create a VPC - const vpc = new ec2.Vpc(this, 'MyVpc', { + const vpc = new ec2.Vpc(this, 'GenBIVpc', { maxAzs: 3, // Default is all AZs in the region subnetConfiguration: [ { From 6a1c1fc632bb922ad61ea9bf66aa2e9e613ec2e1 Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 14:27:47 +0800 Subject: [PATCH 041/130] fix current profile and update profile --- application/api/service.py | 1 + application/nlq/business/profile.py | 4 +- ...\252\231_Schema_Description_Management.py" | 1 + .../5_\360\237\252\231_Prompt_Management.py" | 1 + .../6_\360\237\223\232_Index_Management.py" | 47 +++++++++++-------- .../7_\360\237\223\232_Entity_Management.py" | 24 ++++++---- ..._\360\237\223\232_Agent_Cot_Management.py" | 1 + application/utils/llm.py | 1 + 8 files changed, 50 insertions(+), 30 deletions(-) diff --git a/application/api/service.py b/application/api/service.py index d3eb836..6868edb 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -257,6 +257,7 @@ def ask(question: Question) -> Answer: else: if search_intent_result["data"] is not None and len(search_intent_result["data"]) > 0: if answer_with_insights: + search_intent_result["data"] = search_intent_result["data"].fillna("") search_intent_analyse_result = data_analyse_tool(model_type, prompt_map, search_box, search_intent_result["data"].to_json( orient='records', force_ascii=False), "query") diff --git a/application/nlq/business/profile.py b/application/nlq/business/profile.py index 6eefceb..13b82c0 100644 --- a/application/nlq/business/profile.py +++ b/application/nlq/business/profile.py @@ -41,7 +41,9 @@ def get_profile_by_name(cls, profile_name): @classmethod def update_profile(cls, profile_name, conn_name, schemas, tables, comment, tables_info): - entity = ProfileConfigEntity(profile_name, conn_name, schemas, tables, comment, tables_info) + all_profiles = ProfileManagement.get_all_profiles_with_info() + prompt_map = all_profiles[profile_name]["prompt_map"] + entity = ProfileConfigEntity(profile_name, conn_name, schemas, tables, comment, tables_info, prompt_map) cls.profile_config_dao.update(entity) logger.info(f"Profile {profile_name} updated") diff --git "a/application/pages/4_\360\237\252\231_Schema_Description_Management.py" "b/application/pages/4_\360\237\252\231_Schema_Description_Management.py" index 569587c..ff3a4b1 100644 --- "a/application/pages/4_\360\237\252\231_Schema_Description_Management.py" +++ "b/application/pages/4_\360\237\252\231_Schema_Description_Management.py" @@ -27,6 +27,7 @@ def main(): placeholder="Please select data profile...", key='current_profile_name') if current_profile is not None: + st.session_state['current_profile'] = current_profile profile_detail = ProfileManagement.get_profile_by_name(current_profile) selected_table = st.selectbox("Tables", profile_detail.tables, index=None, placeholder="Please select a table") diff --git "a/application/pages/5_\360\237\252\231_Prompt_Management.py" "b/application/pages/5_\360\237\252\231_Prompt_Management.py" index 21910bb..abad222 100644 --- 
"a/application/pages/5_\360\237\252\231_Prompt_Management.py" +++ "b/application/pages/5_\360\237\252\231_Prompt_Management.py" @@ -29,6 +29,7 @@ def main(): placeholder="Please select data profile...", key='current_profile_name') if current_profile is not None: + st.session_state['current_profile'] = current_profile profile_detail = ProfileManagement.get_profile_by_name(current_profile) prompt_map = profile_detail.prompt_map diff --git "a/application/pages/6_\360\237\223\232_Index_Management.py" "b/application/pages/6_\360\237\223\232_Index_Management.py" index 220f76c..c527410 100644 --- "a/application/pages/6_\360\237\223\232_Index_Management.py" +++ "b/application/pages/6_\360\237\223\232_Index_Management.py" @@ -11,10 +11,12 @@ logger = logging.getLogger(__name__) + def delete_sample(profile_name, id): VectorStore.delete_sample(profile_name, id) st.success(f'Sample {id} deleted.') + def read_file(uploaded_file): """ read upload csv file @@ -36,6 +38,7 @@ def read_file(uploaded_file): st.error(f"The columns need contains question and sql") return None + def main(): load_dotenv() logger.info('start index management') @@ -45,7 +48,6 @@ def main(): if 'profile_page_mode' not in st.session_state: st.session_state['index_mgt_mode'] = 'default' - if 'current_profile' not in st.session_state: st.session_state['current_profile'] = '' @@ -54,22 +56,25 @@ def main(): all_profiles_list = ProfileManagement.get_all_profiles() if st.session_state.current_profile != "" and st.session_state.current_profile in all_profiles_list: profile_index = all_profiles_list.index(st.session_state.current_profile) - current_profile = st.selectbox("My Data Profiles", all_profiles_list, index = profile_index) + current_profile = st.selectbox("My Data Profiles", all_profiles_list, index=profile_index) else: current_profile = st.selectbox("My Data Profiles", ProfileManagement.get_all_profiles(), - index=None, - placeholder="Please select data profile...", key='current_profile_name') + index=None, + placeholder="Please select data profile...", key='current_profile_name') - tab_view, tab_add, tab_search, batch_insert = st.tabs(['View Samples', 'Add New Sample', 'Sample Search', 'Batch Insert Samples']) + tab_view, tab_add, tab_search, batch_insert = st.tabs( + ['View Samples', 'Add New Sample', 'Sample Search', 'Batch Insert Samples']) if current_profile is not None: + st.session_state['current_profile'] = current_profile with tab_view: if current_profile is not None: st.write("The display page can show a maximum of 5000 pieces of data") for sample in VectorStore.get_all_samples(current_profile): with st.expander(sample['text']): st.code(sample['sql']) - st.button('Delete ' + sample['id'], on_click=delete_sample, args=[current_profile, sample['id']]) + st.button('Delete ' + sample['id'], on_click=delete_sample, + args=[current_profile, sample['id']]) with tab_add: if current_profile is not None: @@ -90,7 +95,8 @@ def main(): retrieve_number = st.slider("Question Retrieve Number", 0, 100, 10) if st.button('Search', type='primary'): if len(entity_search) > 0: - search_sample_result = VectorStore.search_sample(current_profile, retrieve_number, opensearch_info['sql_index'], + search_sample_result = VectorStore.search_sample(current_profile, retrieve_number, + opensearch_info['sql_index'], entity_search) for sample in search_sample_result: sample_res = {'Score': sample['_score'], @@ -104,25 +110,28 @@ def main(): st.write("This page support CSV or Excel files batch insert sql samples.") st.write("**The Column Name need 
contain 'question' and 'sql'**") uploaded_files = st.file_uploader("Choose CSV or Excel files", accept_multiple_files=True, - type=['csv', 'xls', 'xlsx']) + type=['csv', 'xls', 'xlsx']) if uploaded_files: - progress_bar = st.progress(0) - status_text = st.empty() for i, uploaded_file in enumerate(uploaded_files): + status_text = st.empty() status_text.text(f"Processing file {i + 1} of {len(uploaded_files)}: {uploaded_file.name}") each_upload_data = read_file(uploaded_file) if each_upload_data is not None: - for index, item in each_upload_data.iterrows(): - question = str(item["question"]) - sql = str(item["sql"]) - VectorStore.add_sample(current_profile, question, sql) - progress_bar.progress((i + 1) / len(uploaded_files)) - + total_rows = len(each_upload_data) + progress_bar = st.progress(0) + progress_text = "batch insert {} entity in progress. Please wait.".format( + uploaded_file.name) + for j, item in enumerate(each_upload_data.itertuples(), 1): + question = str(item.question) + sql = str(item.sql) + VectorStore.add_entity_sample(current_profile, question, sql) + progress = (j * 1.0) / total_rows + progress_bar.progress(progress, text=progress_text) + progress_bar.empty() st.success("{uploaded_file} uploaded successfully!".format(uploaded_file=uploaded_file.name)) - progress_bar.empty() else: st.info('Please select data profile in the left sidebar.') - + if __name__ == '__main__': - main() + main() \ No newline at end of file diff --git "a/application/pages/7_\360\237\223\232_Entity_Management.py" "b/application/pages/7_\360\237\223\232_Entity_Management.py" index f6fc0e4..db1eef5 100644 --- "a/application/pages/7_\360\237\223\232_Entity_Management.py" +++ "b/application/pages/7_\360\237\223\232_Entity_Management.py" @@ -33,7 +33,7 @@ def read_file(uploaded_file): return None columns = list(uploaded_data.columns) if "entity" in columns and "comment" in columns: - return uploaded_data + return uploaded_data[["entity", "comment"]] else: st.error(f"The columns need contains entity and comment") return None @@ -65,6 +65,7 @@ def main(): tab_view, tab_add, tab_search, batch_insert = st.tabs( ['View Samples', 'Add New Sample', 'Sample Search', 'Batch Insert Samples']) if current_profile is not None: + st.session_state['current_profile'] = current_profile with tab_view: if current_profile is not None: st.write("The display page can show a maximum of 5000 pieces of data") @@ -114,23 +115,26 @@ def main(): uploaded_files = st.file_uploader("Choose CSV or Excel files", accept_multiple_files=True, type=['csv', 'xls', 'xlsx']) if uploaded_files: - progress_bar = st.progress(0) - status_text = st.empty() for i, uploaded_file in enumerate(uploaded_files): + status_text = st.empty() status_text.text(f"Processing file {i + 1} of {len(uploaded_files)}: {uploaded_file.name}") each_upload_data = read_file(uploaded_file) if each_upload_data is not None: - for index, item in each_upload_data.iterrows(): - entity = str(item["entity"]) - comment = str(item["comment"]) + total_rows = len(each_upload_data) + progress_bar = st.progress(0) + progress_text = "batch insert {} entity in progress. 
Please wait.".format(uploaded_file.name) + for j, item in enumerate(each_upload_data.itertuples(), 1): + entity = str(item.entity) + comment = str(item.comment) VectorStore.add_entity_sample(current_profile, entity, comment) - progress_bar.progress((i + 1) / len(uploaded_files)) - + progress = (j * 1.0) / total_rows + progress_bar.progress(progress, text=progress_text) + progress_bar.empty() st.success("{uploaded_file} uploaded successfully!".format(uploaded_file=uploaded_file.name)) - progress_bar.empty() + else: st.info('Please select data profile in the left sidebar.') if __name__ == '__main__': - main() + main() \ No newline at end of file diff --git "a/application/pages/8_\360\237\223\232_Agent_Cot_Management.py" "b/application/pages/8_\360\237\223\232_Agent_Cot_Management.py" index 6c1631b..fa4e7ae 100644 --- "a/application/pages/8_\360\237\223\232_Agent_Cot_Management.py" +++ "b/application/pages/8_\360\237\223\232_Agent_Cot_Management.py" @@ -40,6 +40,7 @@ def main(): tab_view, tab_add, tab_search = st.tabs(['View Samples', 'Add New Sample', 'Sample Search']) if current_profile is not None: + st.session_state['current_profile'] = current_profile with tab_view: if current_profile is not None: st.write("The display page can show a maximum of 5000 pieces of data") diff --git a/application/utils/llm.py b/application/utils/llm.py index b8dc3a3..0518432 100644 --- a/application/utils/llm.py +++ b/application/utils/llm.py @@ -434,6 +434,7 @@ def select_data_visualization_type(model_id, search_box, search_data, prompt_map def data_visualization(model_id, search_box, search_data, prompt_map): + search_data = search_data.fillna("") columns = list(search_data.columns) data_list = search_data.values.tolist() all_columns_data = [columns] + data_list From b3c8d585ed054dcd5661ea4a02aa28ff65b21aaa Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 15:02:12 +0800 Subject: [PATCH 042/130] fix data_visualization nan --- application/api/service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/application/api/service.py b/application/api/service.py index 6868edb..101ba6e 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -502,6 +502,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): sql_search_result.data_analyse = "The query results are temporarily unavailable, please switch to debugging webpage to try the same query and check the log file for more information." 
else: if search_intent_result["data"] is not None and len(search_intent_result["data"]) > 0: + search_intent_result["data"] = search_intent_result["data"].fillna("") if answer_with_insights: await response_websocket(websocket, session_id, "Generating Data Insights", ContentEnum.STATE, "start") From ed05c63fed21e627750836667f11b0e06caeeaa9 Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 17 Jul 2024 17:41:31 +0800 Subject: [PATCH 043/130] fix log issue --- application/nlq/business/log_store.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/application/nlq/business/log_store.py b/application/nlq/business/log_store.py index 891364a..9ab480a 100644 --- a/application/nlq/business/log_store.py +++ b/application/nlq/business/log_store.py @@ -10,4 +10,5 @@ class LogManagement: @classmethod def add_log_to_database(cls, log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str): - cls.query_log_dao.add_log(log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str) + cls.query_log_dao.add_log(log_id=log_id, profile_name=profile_name, user_id=user_id, session_id=session_id, + sql=sql, query=query, intent=intent, log_info=log_info, time_str=time_str) From 2f9f0e57179edb617bcbb6d6570d5c9f5773d8d4 Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 18 Jul 2024 07:02:58 +0800 Subject: [PATCH 044/130] fix userid issue --- application/api/main.py | 11 ++++++----- application/api/service.py | 28 ++++++++++++++-------------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/application/api/main.py b/application/api/main.py index c06f71e..d785898 100644 --- a/application/api/main.py +++ b/application/api/main.py @@ -59,17 +59,18 @@ async def websocket_endpoint(websocket: WebSocket): try: while True: data = await websocket.receive_text() + question_json = json.loads(data) + question = Question(**question_json) + session_id = question.session_id + user_id = question.user_id try: - question_json = json.loads(data) - question = Question(**question_json) - session_id = question.session_id ask_result = await ask_websocket(websocket, question) logger.info(ask_result) - await response_websocket(websocket, session_id, ask_result.dict(), ContentEnum.END) + await response_websocket(websocket, session_id, ask_result.dict(), ContentEnum.END, user_id) except Exception: msg = traceback.format_exc() logger.exception(msg) - await response_websocket(websocket, session_id, msg, ContentEnum.EXCEPTION) + await response_websocket(websocket, session_id, msg, ContentEnum.EXCEPTION, user_id) except WebSocketDisconnect: logger.info(f"{websocket.client.host} disconnected.") diff --git a/application/api/service.py b/application/api/service.py index 101ba6e..0a50b25 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -409,9 +409,9 @@ async def ask_websocket(websocket: WebSocket, question: Question): entity_slot = [] if intent_ner_recognition_flag: - await response_websocket(websocket, session_id, "Query Intent Analyse", ContentEnum.STATE, "start") + await response_websocket(websocket, session_id, "Query Intent Analyse", ContentEnum.STATE, "start", user_id) intent_response = get_query_intent(model_type, search_box, prompt_map) - await response_websocket(websocket, session_id, "Query Intent Analyse", ContentEnum.STATE, "end") + await response_websocket(websocket, session_id, "Query Intent Analyse", ContentEnum.STATE, "end", user_id) intent = intent_response.get("intent", "normal_search") entity_slot = intent_response.get("slot", []) if 
intent == "reject_search": @@ -445,7 +445,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): normal_search_result = await normal_text_search_websocket(websocket, session_id, search_box, model_type, database_profile, entity_slot, opensearch_info, - selected_profile, use_rag_flag) + selected_profile, use_rag_flag, user_id) elif knowledge_search_flag: response = knowledge_search(search_box=search_box, model_id=model_type, prompt_map=prompt_map) @@ -491,12 +491,12 @@ async def ask_websocket(websocket: WebSocket, question: Question): else: sql_search_result.sql = "-1" - await response_websocket(websocket, session_id, "Database SQL Execution", ContentEnum.STATE, "start") + await response_websocket(websocket, session_id, "Database SQL Execution", ContentEnum.STATE, "start", user_id) search_intent_result = get_sql_result_tool(database_profile, current_nlq_chain.get_generated_sql()) - await response_websocket(websocket, session_id, "Database SQL Execution", ContentEnum.STATE, "end") + await response_websocket(websocket, session_id, "Database SQL Execution", ContentEnum.STATE, "end", user_id) if search_intent_result["status_code"] == 500: sql_search_result.data_analyse = "The query results are temporarily unavailable, please switch to debugging webpage to try the same query and check the log file for more information." @@ -505,13 +505,13 @@ async def ask_websocket(websocket: WebSocket, question: Question): search_intent_result["data"] = search_intent_result["data"].fillna("") if answer_with_insights: await response_websocket(websocket, session_id, "Generating Data Insights", ContentEnum.STATE, - "start") + "start", user_id) search_intent_analyse_result = data_analyse_tool(model_type, prompt_map, search_box, search_intent_result["data"].to_json( orient='records', force_ascii=False), "query") await response_websocket(websocket, session_id, "Generating Data Insights", ContentEnum.STATE, - "end") + "end", user_id) sql_search_result.data_analyse = search_intent_analyse_result @@ -661,7 +661,7 @@ def get_executed_result(current_nlq_chain: NLQChain) -> str: async def normal_text_search_websocket(websocket: WebSocket, session_id: str, search_box, model_type, database_profile, - entity_slot, opensearch_info, selected_profile, use_rag, + entity_slot, opensearch_info, selected_profile, use_rag,user_id, model_provider=None): entity_slot_retrieve = [] retrieve_result = [] @@ -677,21 +677,21 @@ async def normal_text_search_websocket(websocket: WebSocket, session_id: str, se database_profile['db_type'] = ConnectionManagement.get_db_type_by_name(conn_name) if len(entity_slot) > 0 and use_rag: - await response_websocket(websocket, session_id, "Entity Info Retrieval", ContentEnum.STATE, "start") + await response_websocket(websocket, session_id, "Entity Info Retrieval", ContentEnum.STATE, "start", user_id) for each_entity in entity_slot: entity_retrieve = get_retrieve_opensearch(opensearch_info, each_entity, "ner", selected_profile, 1, 0.7) if len(entity_retrieve) > 0: entity_slot_retrieve.extend(entity_retrieve) - await response_websocket(websocket, session_id, "Entity Info Retrieval", ContentEnum.STATE, "end") + await response_websocket(websocket, session_id, "Entity Info Retrieval", ContentEnum.STATE, "end", user_id) if use_rag: - await response_websocket(websocket, session_id, "QA Info Retrieval", ContentEnum.STATE, "start") + await response_websocket(websocket, session_id, "QA Info Retrieval", ContentEnum.STATE, "start", user_id) retrieve_result = get_retrieve_opensearch(opensearch_info, 
search_box, "query", selected_profile, 3, 0.5) - await response_websocket(websocket, session_id, "QA Info Retrieval", ContentEnum.STATE, "end") + await response_websocket(websocket, session_id, "QA Info Retrieval", ContentEnum.STATE, "end", user_id) - await response_websocket(websocket, session_id, "Generating SQL", ContentEnum.STATE, "start") + await response_websocket(websocket, session_id, "Generating SQL", ContentEnum.STATE, "start", user_id) response = text_to_sql(database_profile['tables_info'], database_profile['hints'], @@ -703,7 +703,7 @@ async def normal_text_search_websocket(websocket: WebSocket, session_id: str, se dialect=database_profile['db_type'], model_provider=model_provider) logger.info(f'{response=}') - await response_websocket(websocket, session_id, "Generating SQL", ContentEnum.STATE, "end") + await response_websocket(websocket, session_id, "Generating SQL", ContentEnum.STATE, "end", user_id) sql = get_generated_sql(response) search_result = SearchTextSqlResult(search_query=search_box, entity_slot_retrieve=entity_slot_retrieve, retrieve_result=retrieve_result, response=response, sql="") From 03afb03ee70067b66ae8af7e33dec93b6ee0d64e Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 18 Jul 2024 07:06:31 +0800 Subject: [PATCH 045/130] fix userid issue --- application/api/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/application/api/main.py b/application/api/main.py index d785898..96648c1 100644 --- a/application/api/main.py +++ b/application/api/main.py @@ -66,11 +66,11 @@ async def websocket_endpoint(websocket: WebSocket): try: ask_result = await ask_websocket(websocket, question) logger.info(ask_result) - await response_websocket(websocket, session_id, ask_result.dict(), ContentEnum.END, user_id) + await response_websocket(websocket=websocket, session_id=session_id, content=ask_result.dict(), content_type=ContentEnum.END, user_id=user_id) except Exception: msg = traceback.format_exc() logger.exception(msg) - await response_websocket(websocket, session_id, msg, ContentEnum.EXCEPTION, user_id) + await response_websocket(websocket=websocket, session_id=session_id, content=msg, content_type=ContentEnum.EXCEPTION, user_id=user_id) except WebSocketDisconnect: logger.info(f"{websocket.client.host} disconnected.") From 37358a1dbaefba4fdd70b81dcfb5ce0319b4d1b3 Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 18 Jul 2024 09:08:14 +0800 Subject: [PATCH 046/130] fix china deploy --- source/resources/lib/main-stack.ts | 243 +++++++++++++++-------------- 1 file changed, 128 insertions(+), 115 deletions(-) diff --git a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index e3619f5..057435f 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -1,127 +1,140 @@ -import { StackProps, CfnParameter, CfnOutput } from 'aws-cdk-lib'; +import {StackProps, CfnParameter, CfnOutput} from 'aws-cdk-lib'; import * as cdk from 'aws-cdk-lib'; -import { Construct } from 'constructs'; +import {Construct} from 'constructs'; import * as ec2 from 'aws-cdk-lib/aws-ec2'; -import { AOSStack } from './aos/aos-stack'; +import {AOSStack} from './aos/aos-stack'; // import { LLMStack } from './model/llm-stack'; -import { ECSStack } from './ecs/ecs-stack'; -import { CognitoStack } from './cognito/cognito-stack'; -import { RDSStack } from './rds/rds-stack'; -import { VPCStack } from './vpc/vpc-stack'; +import {ECSStack} from './ecs/ecs-stack'; +import {CognitoStack} from './cognito/cognito-stack'; +import {RDSStack} from 
'./rds/rds-stack'; +import {VPCStack} from './vpc/vpc-stack'; interface MainStackProps extends StackProps { - deployRds?: boolean; + deployRds?: boolean; } export class MainStack extends cdk.Stack { - constructor(scope: Construct, id: string, props: MainStackProps={ deployRds: false }) { - super(scope, id, props); - - const _deployRds = props.deployRds || false; - - // ======== Step 0. Define the VPC ========= - const _VpcStack = new VPCStack(this, 'vpc-Stack', { - env: props.env, - }); - - // ======== Step 1. Define the LLMStack ========= - // const s3ModelAssetsBucket = new CfnParameter(this, "S3ModelAssetsBucket", { - // type: "String", - // description: "S3 Bucket for model & code assets", - // default: "not-set" - // }); - - // ======== Step 2. Define the AOSStack ========= - const aosSubnets = _VpcStack.vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS }); - - const _AosStack = new AOSStack(this, 'aos-Stack', { - env: props.env, - vpc: _VpcStack.vpc, - subnets: aosSubnets.subnets, - }); - - // print AOS subnet Info - console.log('AOS subnets Info:'); - aosSubnets.subnets.forEach((subnet, index) => { - console.log(`Subnet ${index + 1}:`); - console.log(` ID: ${subnet.subnetId}`); - console.log(` Availability Zone: ${subnet.availabilityZone}`); - console.log(` CIDR: ${subnet.ipv4CidrBlock}`); - }); - - // print AOS subnet length - console.log(`Total number of AOS subnets: ${aosSubnets.subnets.length}`); - - - - const aosEndpoint = _AosStack.endpoint; - - // ======== Step 3. Define the RDSStack ========= - if (_deployRds) { - const rdsSubnets = _VpcStack.vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS }); - - const _RdsStack = new RDSStack(this, 'rds-Stack', { - env: props.env, - subnets: rdsSubnets, - vpc : _VpcStack.vpc - }); - new cdk.CfnOutput(this, 'RDSEndpoint', { - value: _RdsStack.endpoint, - description: 'The endpoint of the RDS instance', - }); - } + constructor(scope: Construct, id: string, props: MainStackProps = {deployRds: false}) { + super(scope, id, props); + + const _deployRds = props.deployRds || false; + + // ======== Step 0. Define the VPC ========= + const _VpcStack = new VPCStack(this, 'vpc-Stack', { + env: props.env, + }); + + // ======== Step 1. Define the LLMStack ========= + // const s3ModelAssetsBucket = new CfnParameter(this, "S3ModelAssetsBucket", { + // type: "String", + // description: "S3 Bucket for model & code assets", + // default: "not-set" + // }); - // ======== Step 4. Define Cognito ========= - const _CognitoStack = new CognitoStack(this, 'cognito-Stack', { - env: props.env - }); - - // ======== Step 5. Define the ECS ========= - // pass the aosEndpoint and aosPassword to the ecs stack - const ecsSubnets = _VpcStack.vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS }); + // ======== Step 2. 
Define the AOSStack ========= + const aosSubnets = _VpcStack.vpc.selectSubnets({subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS}); + + const _AosStack = new AOSStack(this, 'aos-Stack', { + env: props.env, + vpc: _VpcStack.vpc, + subnets: aosSubnets.subnets, + }); + + // print AOS subnet Info + console.log('AOS subnets Info:'); + aosSubnets.subnets.forEach((subnet, index) => { + console.log(`Subnet ${index + 1}:`); + console.log(` ID: ${subnet.subnetId}`); + console.log(` Availability Zone: ${subnet.availabilityZone}`); + console.log(` CIDR: ${subnet.ipv4CidrBlock}`); + }); + + // print AOS subnet length + console.log(`Total number of AOS subnets: ${aosSubnets.subnets.length}`); + + + const aosEndpoint = _AosStack.endpoint; + + // ======== Step 3. Define the RDSStack ========= + if (_deployRds) { + const rdsSubnets = _VpcStack.vpc.selectSubnets({subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS}); + + const _RdsStack = new RDSStack(this, 'rds-Stack', { + env: props.env, + subnets: rdsSubnets, + vpc: _VpcStack.vpc + }); + new cdk.CfnOutput(this, 'RDSEndpoint', { + value: _RdsStack.endpoint, + description: 'The endpoint of the RDS instance', + }); + } + + // ======== Step 4. Define Cognito ========= + const isChinaRegion = env?.region === "cn-north-1" || env?.region === "cn-northwest-1"; + + let cognitoStack: CognitoStack | undefined; + if (!isChinaRegion) { + const _CognitoStack = new CognitoStack(this, 'cognito-Stack', { + env: props.env + }); + } + + + // ======== Step 5. Define the ECS ========= + // pass the aosEndpoint and aosPassword to the ecs stack + const ecsSubnets = _VpcStack.vpc.selectSubnets({subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS}); // print AOS subnet Info - console.log('ECS subnets Info:'); - ecsSubnets.subnets.forEach((subnet, index) => { - console.log(`Subnet ${index + 1}:`); - console.log(` ID: ${subnet.subnetId}`); - console.log(` Availability Zone: ${subnet.availabilityZone}`); - console.log(` CIDR: ${subnet.ipv4CidrBlock}`); - }); - - // print AOS subnet length - console.log(`Total number of ECS subnets: ${ecsSubnets.subnets.length}`); - - const _EcsStack = new ECSStack(this, 'ecs-Stack', { - env: props.env, - vpc: _VpcStack.vpc, - subnets: ecsSubnets.subnets, - cognitoUserPoolId: _CognitoStack.userPoolId, - cognitoUserPoolClientId: _CognitoStack.userPoolClientId, - OSMasterUserSecretName: _AosStack.OSMasterUserSecretName, - OSHostSecretName: _AosStack.OSHostSecretName, - }); - _AosStack.addDependency(_VpcStack); - _EcsStack.addDependency(_AosStack); - _EcsStack.addDependency(_CognitoStack); - _EcsStack.addDependency(_VpcStack); - - new cdk.CfnOutput(this, 'AOSDomainEndpoint', { - value: aosEndpoint, - description: 'The endpoint of the OpenSearch domain' - }); - - new cdk.CfnOutput(this, 'StreamlitEndpoint', { - value: _EcsStack.streamlitEndpoint, - description: 'The endpoint of the Streamlit service' - }); - new cdk.CfnOutput(this, 'FrontendEndpoint', { - value: _EcsStack.frontendEndpoint, - description: 'The endpoint of the Frontend service' - }); - new cdk.CfnOutput(this, 'APIEndpoint', { - value: _EcsStack.apiEndpoint, - description: 'The endpoint of the API service' - }); - } + console.log('ECS subnets Info:'); + ecsSubnets.subnets.forEach((subnet, index) => { + console.log(`Subnet ${index + 1}:`); + console.log(` ID: ${subnet.subnetId}`); + console.log(` Availability Zone: ${subnet.availabilityZone}`); + console.log(` CIDR: ${subnet.ipv4CidrBlock}`); + }); + + // print AOS subnet length + console.log(`Total number of ECS subnets: 
${ecsSubnets.subnets.length}`); + + + const _EcsStack = new ECSStack(this, 'ecs-Stack', { + env: props.env, + vpc: _VpcStack.vpc, + subnets: ecsSubnets.subnets, + cognitoUserPoolId: _CognitoStack?.userPoolId ?? "", + cognitoUserPoolClientId: _CognitoStack?.userPoolClientId ?? "", + OSMasterUserSecretName: + _AosStack.OSMasterUserSecretName, + OSHostSecretName: + _AosStack.OSHostSecretName, + }) + ; + + + _AosStack.addDependency(_VpcStack); + _EcsStack.addDependency(_AosStack); + if (cognitoStack) { + ecsStack.addDependency(cognitoStack); + } + _EcsStack.addDependency(_VpcStack); + + new cdk.CfnOutput(this, 'AOSDomainEndpoint', { + value: aosEndpoint, + description: 'The endpoint of the OpenSearch domain' + }); + + new cdk.CfnOutput(this, 'StreamlitEndpoint', { + value: _EcsStack.streamlitEndpoint, + description: 'The endpoint of the Streamlit service' + }); + new cdk.CfnOutput(this, 'FrontendEndpoint', { + value: _EcsStack.frontendEndpoint, + description: 'The endpoint of the Frontend service' + }); + new cdk.CfnOutput(this, 'APIEndpoint', { + value: _EcsStack.apiEndpoint, + description: 'The endpoint of the API service' + }); + } } \ No newline at end of file From 9ac40c6c1508b985ee4aa2f31ee135c4ebff9640 Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 18 Jul 2024 09:22:54 +0800 Subject: [PATCH 047/130] fix china deploy --- source/resources/lib/main-stack.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index 057435f..fb46b76 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -71,9 +71,9 @@ export class MainStack extends cdk.Stack { } // ======== Step 4. Define Cognito ========= - const isChinaRegion = env?.region === "cn-north-1" || env?.region === "cn-northwest-1"; + const isChinaRegion = props.env?.region === "cn-north-1" || props.env?.region === "cn-northwest-1"; - let cognitoStack: CognitoStack | undefined; + let _CognitoStack: CognitoStack | undefined; if (!isChinaRegion) { const _CognitoStack = new CognitoStack(this, 'cognito-Stack', { env: props.env @@ -114,8 +114,8 @@ export class MainStack extends cdk.Stack { _AosStack.addDependency(_VpcStack); _EcsStack.addDependency(_AosStack); - if (cognitoStack) { - ecsStack.addDependency(cognitoStack); + if (_CognitoStack) { + _EcsStack.addDependency(_CognitoStack); } _EcsStack.addDependency(_VpcStack); From 4e0f387acfb5d3b336ae62174eddc81939cdf290 Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 18 Jul 2024 10:01:21 +0800 Subject: [PATCH 048/130] add docker aws region --- application/Dockerfile | 2 +- application/Dockerfile-api | 2 +- source/resources/lib/ecs/ecs-stack.ts | 36 ++++++++++++++++----------- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/application/Dockerfile b/application/Dockerfile index e7229a3..4479f1c 100644 --- a/application/Dockerfile +++ b/application/Dockerfile @@ -8,7 +8,7 @@ WORKDIR /app COPY requirements.txt /app/ -#ARG AWS_REGION=us-east-1 +ARG AWS_REGION ENV AWS_REGION=${AWS_REGION} # Print the AWS_REGION for verification diff --git a/application/Dockerfile-api b/application/Dockerfile-api index cb4759d..96aac5e 100644 --- a/application/Dockerfile-api +++ b/application/Dockerfile-api @@ -4,7 +4,7 @@ WORKDIR /app COPY . 
/app/ -#ARG AWS_REGION=us-east-1 +ARG AWS_REGION ENV AWS_REGION=${AWS_REGION} # Print the AWS_REGION for verification diff --git a/source/resources/lib/ecs/ecs-stack.ts b/source/resources/lib/ecs/ecs-stack.ts index 5d54938..3681f27 100644 --- a/source/resources/lib/ecs/ecs-stack.ts +++ b/source/resources/lib/ecs/ecs-stack.ts @@ -12,10 +12,10 @@ export class ECSStack extends cdk.Stack { public readonly streamlitEndpoint: string; public readonly frontendEndpoint: string; public readonly apiEndpoint: string; -constructor(scope: Construct, id: string, props: cdk.StackProps +constructor(scope: Construct, id: string, props: cdk.StackProps & { vpc: ec2.Vpc} - & { subnets: cdk.aws_ec2.ISubnet[] } & { cognitoUserPoolId: string} - & { cognitoUserPoolClientId: string} & {OSMasterUserSecretName: string} + & { subnets: cdk.aws_ec2.ISubnet[] } & { cognitoUserPoolId: string} + & { cognitoUserPoolClientId: string} & {OSMasterUserSecretName: string} & {OSHostSecretName: string}) { super(scope, id, props); @@ -32,16 +32,24 @@ constructor(scope: Construct, id: string, props: cdk.StackProps { name: 'genbi-frontend', dockerfile: 'Dockerfile', port: 80, dockerfileDirectory: path.join(__dirname, '../../../../report-front-end')}, ]; + const awsRegion = props.env?.region as string; + const GenBiStreamlitDockerImageAsset = {'dockerImageAsset': new DockerImageAsset(this, 'GenBiStreamlitDockerImage', { - directory: services[0].dockerfileDirectory, - file: services[0].dockerfile, + directory: services[0].dockerfileDirectory, + file: services[0].dockerfile, + buildArgs: { + AWS_REGION: awsRegion, // Pass the AWS region as a build argument + }, }), 'port': services[0].port}; - + const GenBiAPIDockerImageAsset = {'dockerImageAsset': new DockerImageAsset(this, 'GenBiAPIDockerImage', { - directory: services[1].dockerfileDirectory, - file: services[1].dockerfile, + directory: services[1].dockerfileDirectory, + file: services[1].dockerfile, + buildArgs : { + AWS_REGION: awsRegion, // Pass the AWS region as a build argument + } }), 'port': services[1].port}; - + // Create an ECS cluster const cluster = new ecs.Cluster(this, 'GenBiCluster', { vpc: props.vpc, @@ -108,7 +116,7 @@ constructor(scope: Construct, id: string, props: cdk.StackProps ] }); taskRole.addToPolicy(bedrockAccessPolicy); - } + } // Add SageMaker endpoint access policy const sageMakerEndpointAccessPolicy = new iam.PolicyStatement({ @@ -137,7 +145,7 @@ constructor(scope: Construct, id: string, props: cdk.StackProps ] }); taskRole.addToPolicy(cognitoAccessPolicy); - } + } // Create ECS services through Fargate // ======= 1. Streamlit Service ======= @@ -221,8 +229,8 @@ constructor(scope: Construct, id: string, props: cdk.StackProps // ======= 3. 
Frontend Service ======= const GenBiFrontendDockerImageAsset = {'dockerImageAsset': new DockerImageAsset(this, 'GenBiFrontendDockerImage', { - directory: services[2].dockerfileDirectory, - file: services[2].dockerfile, + directory: services[2].dockerfileDirectory, + file: services[2].dockerfile, }), 'port': services[2].port}; const taskDefinitionFrontend = new ecs.FargateTaskDefinition(this, 'GenBiTaskDefinitionFrontend', { @@ -231,7 +239,7 @@ constructor(scope: Construct, id: string, props: cdk.StackProps executionRole: taskExecutionRole, taskRole: taskRole }); - + const containerFrontend = taskDefinitionFrontend.addContainer('GenBiContainerFrontend', { image: ecs.ContainerImage.fromDockerImageAsset(GenBiFrontendDockerImageAsset.dockerImageAsset), memoryLimitMiB: 512, From b160ff330d652993514c86b3dd9d30bee50a3800 Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 18 Jul 2024 10:36:27 +0800 Subject: [PATCH 049/130] add docker region --- report-front-end/Dockerfile | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/report-front-end/Dockerfile b/report-front-end/Dockerfile index 5a1b763..169b149 100644 --- a/report-front-end/Dockerfile +++ b/report-front-end/Dockerfile @@ -1,7 +1,17 @@ FROM public.ecr.aws/docker/library/node:18.17.0 AS builder WORKDIR /frontend COPY package*.json ./ -RUN npm install + +ARG AWS_REGION + +RUN if [ "$AWS_REGION" = "cn-north-1" ] || [ "$AWS_REGION" = "cn-northwest-1" ]; then \ + npm config set registry https://registry.npmmirror.com && \ + npm install; \ + else \ + npm install; \ + fi + + COPY . . RUN npm run build COPY .env /frontend/.env From 906559b3295e57e52e4215c462d0f1b7e0c410f7 Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 18 Jul 2024 10:41:26 +0800 Subject: [PATCH 050/130] fix china cdk deploy npm source --- report-front-end/Dockerfile | 2 ++ source/resources/lib/ecs/ecs-stack.ts | 3 +++ 2 files changed, 5 insertions(+) diff --git a/report-front-end/Dockerfile b/report-front-end/Dockerfile index 169b149..0ddb0fc 100644 --- a/report-front-end/Dockerfile +++ b/report-front-end/Dockerfile @@ -4,6 +4,8 @@ COPY package*.json ./ ARG AWS_REGION +RUN echo "Current AWS Region: $AWS_REGION" + RUN if [ "$AWS_REGION" = "cn-north-1" ] || [ "$AWS_REGION" = "cn-northwest-1" ]; then \ npm config set registry https://registry.npmmirror.com && \ npm install; \ diff --git a/source/resources/lib/ecs/ecs-stack.ts b/source/resources/lib/ecs/ecs-stack.ts index 3681f27..cb9f5df 100644 --- a/source/resources/lib/ecs/ecs-stack.ts +++ b/source/resources/lib/ecs/ecs-stack.ts @@ -231,6 +231,9 @@ constructor(scope: Construct, id: string, props: cdk.StackProps const GenBiFrontendDockerImageAsset = {'dockerImageAsset': new DockerImageAsset(this, 'GenBiFrontendDockerImage', { directory: services[2].dockerfileDirectory, file: services[2].dockerfile, + buildArgs : { + AWS_REGION: awsRegion, // Pass the AWS region as a build argument + } }), 'port': services[2].port}; const taskDefinitionFrontend = new ecs.FargateTaskDefinition(this, 'GenBiTaskDefinitionFrontend', { From 41e5a328f091cebea9e32cfedf481f374d89d783 Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 18 Jul 2024 15:27:58 +0800 Subject: [PATCH 051/130] add replay ask --- ...0\237\214\215_Generative_BI_Playground.py" | 72 +++++++++++-------- application/utils/llm.py | 9 ++- 2 files changed, 48 insertions(+), 33 deletions(-) diff --git "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" index ef72b79..438cc61 100644 
--- "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" +++ "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" @@ -53,6 +53,7 @@ def upvote_agent_clicked(question, comment): def clean_st_history(selected_profile): st.session_state.messages[selected_profile] = [] + st.session_state.query_rewrite_history[selected_profile] = {} def get_user_history(selected_profile: str): @@ -61,12 +62,10 @@ def get_user_history(selected_profile: str): :param selected_profile: :return: history for selected profile list type """ - history_list = st.session_state.messages[selected_profile] + history_list = st.session_state.query_rewrite_history[selected_profile] history_query = [] for messages in history_list: - current_role = messages["role"] - if current_role == "user": - history_query.append(messages["content"]) + history_query.append(messages["role"] + ":" + messages["content"]) return history_query @@ -254,6 +253,9 @@ def main(): if "messages" not in st.session_state: st.session_state.messages = {} + if "query_rewrite_history" not in st.session_state: + st.session_state.query_rewrite_history = {} + if "current_sql_result" not in st.session_state: st.session_state.current_sql_result = {} @@ -365,6 +367,8 @@ def main(): current_nlq_chain.set_question(search_box) st.session_state.messages[selected_profile].append( {"role": "user", "content": search_box, "type": "text"}) + st.session_state.query_rewrite_history[selected_profile].append( + {"role": "user", "content": search_box}) st.markdown(current_nlq_chain.get_question()) with st.chat_message("assistant"): filter_deep_dive_sql_result = [] @@ -393,16 +397,26 @@ def main(): # Multiple rounds of dialogue, query rewriting user_query_history = get_user_history(selected_profile) - if len(user_query_history) > 0 and context_window > 0: + query_rewrite_result = {"original_problem": search_box} + if context_window > 0: with st.status("Query Context Understanding") as status_text: - user_query_history = user_query_history[-context_window:] - logger.info("The Chat history is {history}".format(history=",".join(user_query_history))) - new_search_box = get_query_rewrite(model_type, search_box, prompt_map, user_query_history) - logger.info("The Origin query is {query} query rewrite is {new_query}".format(query=search_box, - new_query=new_search_box)) - search_box = new_search_box + context_window_select = context_window * 2 + user_query_history = user_query_history[-context_window_select:] + logger.info("The Chat history is {history}".format(history="\n".join(user_query_history))) + query_rewrite_result = get_query_rewrite(model_type, search_box, prompt_map, user_query_history) + logger.info("The query_rewrite_result is {query_rewrite_result}".format( + query_rewrite_result=search_box)) + search_box = query_rewrite_result.get("query") + st.session_state.query_rewrite_history[selected_profile].append( + {"role": "assistant", "content": search_box}) + st.session_state.messages[selected_profile].append( + {"role": "assistant", "content": search_box, "type": "text"}) st.write(search_box) status_text.update(label=f"Query Context Rewrite Completed", state="complete", expanded=False) + + if "ask_in_reply" in query_rewrite_result: + return + intent_response = { "intent": "normal_search", "slot": [] @@ -504,39 +518,41 @@ def main(): if search_intent_result["status_code"] == 500: with st.expander("The SQL Error Info"): st.markdown(search_intent_result["error_info"]) - + if auto_correction_flag: with st.status("Regenerating SQL") as status_text: 
response = text_to_sql(database_profile['tables_info'], - database_profile['hints'], - database_profile['prompt_map'], - search_box, - model_id=model_type, - sql_examples=normal_search_result.retrieve_result, - ner_example=normal_search_result.entity_slot_retrieve, - dialect=database_profile['db_type'], - model_provider=None, - additional_info='''\n NOTE: when I try to write a SQL {sql_statement}, I got an error {error}. Please consider and avoid this problem. '''.format(sql_statement=current_nlq_chain.get_generated_sql(), error=search_intent_result["error_info"])) + database_profile['hints'], + database_profile['prompt_map'], + search_box, + model_id=model_type, + sql_examples=normal_search_result.retrieve_result, + ner_example=normal_search_result.entity_slot_retrieve, + dialect=database_profile['db_type'], + model_provider=None, + additional_info='''\n NOTE: when I try to write a SQL {sql_statement}, I got an error {error}. Please consider and avoid this problem. '''.format( + sql_statement=current_nlq_chain.get_generated_sql(), + error=search_intent_result["error_info"])) regen_sql = get_generated_sql(response) st.code(regen_sql, language="sql") - + status_text.update( label=f"Generating SQL Done", state="complete", expanded=True) - + with st.spinner('Executing query...'): search_intent_result = get_sql_result_tool( - st.session_state['profiles'][current_nlq_chain.profile], - regen_sql) - + st.session_state['profiles'][current_nlq_chain.profile], + regen_sql) + if search_intent_result["status_code"] == 500: with st.expander("The SQL Error Info"): st.markdown(search_intent_result["error_info"]) - + if search_intent_result["status_code"] != 500: - # else: + # else: if search_intent_result["data"] is not None and len( search_intent_result["data"]) > 0 and data_with_analyse: with st.spinner('Generating data summarize...'): diff --git a/application/utils/llm.py b/application/utils/llm.py index 0518432..28eeaaf 100644 --- a/application/utils/llm.py +++ b/application/utils/llm.py @@ -381,10 +381,8 @@ def get_query_intent(model_id, search_box, prompt_map): def get_query_rewrite(model_id, search_box, prompt_map, chat_history): - query_rewrite = {"query_rewrite": search_box} - history_query = "" - for item in chat_history: - history_query = history_query + "user : " + item + "\n" + query_rewrite = {"original_problem": search_box} + history_query = "\n".join(chat_history) try: intent_endpoint = os.getenv("SAGEMAKER_ENDPOINT_INTENT") if intent_endpoint: @@ -399,8 +397,9 @@ def get_query_rewrite(model_id, search_box, prompt_map, chat_history): user_prompt, system_prompt = generate_query_rewrite_prompt(prompt_map, search_box, model_id, history_query) max_tokens = 2048 final_response = invoke_llm_model(model_id, system_prompt, user_prompt, max_tokens, False) + query_rewrite_result = json_parse.parse(final_response) logger.info(f'{final_response=}') - return final_response + return query_rewrite_result except Exception as e: logger.error("get_query_rewrite is error:{}".format(e)) return query_rewrite From 9ecdf15b5610155c90e6ab6641b7a2000c84a446 Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 18 Jul 2024 16:23:06 +0800 Subject: [PATCH 052/130] add replay ask --- ...0\237\214\215_Generative_BI_Playground.py" | 443 +++++++++--------- 1 file changed, 225 insertions(+), 218 deletions(-) diff --git "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" index 438cc61..3131baa 100644 --- 
"a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" +++ "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" @@ -53,7 +53,7 @@ def upvote_agent_clicked(question, comment): def clean_st_history(selected_profile): st.session_state.messages[selected_profile] = [] - st.session_state.query_rewrite_history[selected_profile] = {} + st.session_state.query_rewrite_history[selected_profile] = [] def get_user_history(selected_profile: str): @@ -238,6 +238,9 @@ def main(): if 'selected_sample' not in st.session_state: st.session_state['selected_sample'] = '' + if 'ask_replay' not in st.session_state: + st.session_state.ask_replay = False + if 'current_profile' not in st.session_state: st.session_state['current_profile'] = '' @@ -288,6 +291,8 @@ def main(): st.session_state.current_profile = selected_profile if selected_profile not in st.session_state.messages: st.session_state.messages[selected_profile] = [] + if selected_profile not in st.session_state.query_rewrite_history: + st.session_state.query_rewrite_history[selected_profile] = [] st.session_state.nlq_chain = NLQChain(selected_profile) if st.session_state.current_model_id != "" and st.session_state.current_model_id in model_ids: @@ -354,6 +359,7 @@ def main(): search_box = st.session_state['selected_sample'] st.session_state['selected_sample'] = "" + st.session_state.ask_replay = False reject_intent_flag = False search_intent_flag = False agent_intent_flag = False @@ -405,7 +411,7 @@ def main(): logger.info("The Chat history is {history}".format(history="\n".join(user_query_history))) query_rewrite_result = get_query_rewrite(model_type, search_box, prompt_map, user_query_history) logger.info("The query_rewrite_result is {query_rewrite_result}".format( - query_rewrite_result=search_box)) + query_rewrite_result=query_rewrite_result)) search_box = query_rewrite_result.get("query") st.session_state.query_rewrite_history[selected_profile].append( {"role": "assistant", "content": search_box}) @@ -415,229 +421,230 @@ def main(): status_text.update(label=f"Query Context Rewrite Completed", state="complete", expanded=False) if "ask_in_reply" in query_rewrite_result: - return - - intent_response = { - "intent": "normal_search", - "slot": [] - } - - if intent_ner_recognition_flag: - with st.status("Performing intent recognition...") as status_text: - intent_response = get_query_intent(model_type, search_box, prompt_map) - intent = intent_response.get("intent", "normal_search") - entity_slot = intent_response.get("slot", []) - st.write(intent_response) - status_text.update(label=f"Intent Recognition Completed: This is a **{intent}** question", - state="complete", expanded=False) - if intent == "reject_search": - reject_intent_flag = True - search_intent_flag = False - elif intent == "agent_search": - agent_intent_flag = True - if agent_cot_flag: + st.session_state.ask_replay = True + + if not st.session_state.ask_replay: + intent_response = { + "intent": "normal_search", + "slot": [] + } + + if intent_ner_recognition_flag: + with st.status("Performing intent recognition...") as status_text: + intent_response = get_query_intent(model_type, search_box, prompt_map) + intent = intent_response.get("intent", "normal_search") + entity_slot = intent_response.get("slot", []) + st.write(intent_response) + status_text.update(label=f"Intent Recognition Completed: This is a **{intent}** question", + state="complete", expanded=False) + if intent == "reject_search": + reject_intent_flag = True + search_intent_flag = False + elif intent 
== "agent_search": + agent_intent_flag = True + if agent_cot_flag: + search_intent_flag = False + else: + search_intent_flag = True + agent_intent_flag = False + elif intent == "knowledge_search": + knowledge_search_flag = True search_intent_flag = False + agent_intent_flag = False else: search_intent_flag = True - agent_intent_flag = False - elif intent == "knowledge_search": - knowledge_search_flag = True - search_intent_flag = False - agent_intent_flag = False + else: + search_intent_flag = True + + if reject_intent_flag: + st.write("Your query statement is currently not supported by the system") + + elif search_intent_flag: + normal_search_result = normal_text_search_streamlit(search_box, model_type, + database_profile, + entity_slot, opensearch_info, + selected_profile, + explain_gen_process_flag, use_rag_flag) + elif knowledge_search_flag: + with st.spinner('Performing knowledge search...'): + response = knowledge_search(search_box=search_box, model_id=model_type, + prompt_map=prompt_map) + logger.info(f'got llm response for knowledge_search: {response}') + st.markdown(f'This is a knowledge search question.\n{response}') + + elif agent_intent_flag: + with st.spinner('Analysis Of Complex Problems'): + agent_cot_retrieve = get_retrieve_opensearch(opensearch_info, search_box, "agent", + selected_profile, 2, 0.5) + agent_cot_task_result = get_agent_cot_task(model_type, prompt_map, search_box, + database_profile['tables_info'], + agent_cot_retrieve) + with st.expander(f'Agent Query Retrieve : {len(agent_cot_retrieve)}'): + agent_examples = [] + for example in agent_cot_retrieve: + agent_examples.append({'Score': example['_score'], + 'Question': example['_source']['query'], + 'Answer': example['_source']['comment'].strip()}) + st.write(agent_examples) + with st.expander(f'Agent Task : {len(agent_cot_task_result)}'): + st.write(agent_cot_task_result) + + with st.spinner('Generate SQL For Multiple Sub Problems'): + agent_search_result = agent_text_search(search_box, model_type, + database_profile, + entity_slot, opensearch_info, + selected_profile, use_rag_flag, agent_cot_task_result) + else: + st.error("Intent recognition error") + + if search_intent_flag: + if normal_search_result.sql != "": + current_nlq_chain.set_generated_sql(normal_search_result.sql) + + current_nlq_chain.set_generated_sql_response(normal_search_result.response) + + if explain_gen_process_flag: + with st.status("Generating explanations...") as status_text: + st.markdown(current_nlq_chain.get_generated_sql_explain()) + status_text.update( + label=f"Generating explanations Done", + state="complete", expanded=False) + st.session_state.messages[selected_profile].append( + {"role": "assistant", "content": "SQL:" + normal_search_result.sql, "type": "sql"}) else: - search_intent_flag = True - else: - search_intent_flag = True - - if reject_intent_flag: - st.write("Your query statement is currently not supported by the system") - - elif search_intent_flag: - normal_search_result = normal_text_search_streamlit(search_box, model_type, - database_profile, - entity_slot, opensearch_info, - selected_profile, - explain_gen_process_flag, use_rag_flag) - elif knowledge_search_flag: - with st.spinner('Performing knowledge search...'): - response = knowledge_search(search_box=search_box, model_id=model_type, - prompt_map=prompt_map) - logger.info(f'got llm response for knowledge_search: {response}') - st.markdown(f'This is a knowledge search question.\n{response}') - - elif agent_intent_flag: - with st.spinner('Analysis Of Complex 
Problems'): - agent_cot_retrieve = get_retrieve_opensearch(opensearch_info, search_box, "agent", - selected_profile, 2, 0.5) - agent_cot_task_result = get_agent_cot_task(model_type, prompt_map, search_box, - database_profile['tables_info'], - agent_cot_retrieve) - with st.expander(f'Agent Query Retrieve : {len(agent_cot_retrieve)}'): - agent_examples = [] - for example in agent_cot_retrieve: - agent_examples.append({'Score': example['_score'], - 'Question': example['_source']['query'], - 'Answer': example['_source']['comment'].strip()}) - st.write(agent_examples) - with st.expander(f'Agent Task : {len(agent_cot_task_result)}'): - st.write(agent_cot_task_result) - - with st.spinner('Generate SQL For Multiple Sub Problems'): - agent_search_result = agent_text_search(search_box, model_type, - database_profile, - entity_slot, opensearch_info, - selected_profile, use_rag_flag, agent_cot_task_result) - else: - st.error("Intent recognition error") - - if search_intent_flag: - if normal_search_result.sql != "": - current_nlq_chain.set_generated_sql(normal_search_result.sql) - - current_nlq_chain.set_generated_sql_response(normal_search_result.response) - - if explain_gen_process_flag: - with st.status("Generating explanations...") as status_text: - st.markdown(current_nlq_chain.get_generated_sql_explain()) - status_text.update( - label=f"Generating explanations Done", - state="complete", expanded=False) + st.write("Unable to generate SQL at the moment, please provide more information") + elif agent_intent_flag: + with st.expander(f'Agent Task Result: {len(agent_search_result)}'): + st.write(agent_search_result) + + if search_intent_flag: + with st.spinner('Executing query...'): + search_intent_result = get_sql_result_tool( + st.session_state['profiles'][current_nlq_chain.profile], + current_nlq_chain.get_generated_sql()) + if search_intent_result["status_code"] == 500: + with st.expander("The SQL Error Info"): + st.markdown(search_intent_result["error_info"]) + + if auto_correction_flag: + with st.status("Regenerating SQL") as status_text: + response = text_to_sql(database_profile['tables_info'], + database_profile['hints'], + database_profile['prompt_map'], + search_box, + model_id=model_type, + sql_examples=normal_search_result.retrieve_result, + ner_example=normal_search_result.entity_slot_retrieve, + dialect=database_profile['db_type'], + model_provider=None, + additional_info='''\n NOTE: when I try to write a SQL {sql_statement}, I got an error {error}. Please consider and avoid this problem. 
'''.format( + sql_statement=current_nlq_chain.get_generated_sql(), + error=search_intent_result["error_info"])) + + regen_sql = get_generated_sql(response) + + st.code(regen_sql, language="sql") + + status_text.update( + label=f"Generating SQL Done", + state="complete", expanded=True) + + with st.spinner('Executing query...'): + search_intent_result = get_sql_result_tool( + st.session_state['profiles'][current_nlq_chain.profile], + regen_sql) + + if search_intent_result["status_code"] == 500: + with st.expander("The SQL Error Info"): + st.markdown(search_intent_result["error_info"]) + + if search_intent_result["status_code"] != 500: + # else: + if search_intent_result["data"] is not None and len( + search_intent_result["data"]) > 0 and data_with_analyse: + with st.spinner('Generating data summarize...'): + search_intent_analyse_result = data_analyse_tool(model_type, prompt_map, search_box, + search_intent_result["data"].to_json( + orient='records', + force_ascii=False), "query") + st.markdown(search_intent_analyse_result) + st.session_state.messages[selected_profile].append( + {"role": "assistant", "content": search_intent_analyse_result, "type": "text"}) + st.session_state.current_sql_result[selected_profile] = search_intent_result["data"] + + elif agent_intent_flag: + for i in range(len(agent_search_result)): + each_task_res = get_sql_result_tool( + st.session_state['profiles'][current_nlq_chain.profile], + agent_search_result[i]["sql"]) + if each_task_res["status_code"] == 200 and len(each_task_res["data"]) > 0: + agent_search_result[i]["data_result"] = each_task_res["data"].to_json( + orient='records') + filter_deep_dive_sql_result.append(agent_search_result[i]) + + agent_data_analyse_result = data_analyse_tool(model_type, prompt_map, search_box, + json.dumps(filter_deep_dive_sql_result, + ensure_ascii=False), "agent") + logger.info("agent_data_analyse_result") + logger.info(agent_data_analyse_result) st.session_state.messages[selected_profile].append( - {"role": "assistant", "content": "SQL:" + normal_search_result.sql, "type": "sql"}) - else: - st.write("Unable to generate SQL at the moment, please provide more information") - elif agent_intent_flag: - with st.expander(f'Agent Task Result: {len(agent_search_result)}'): - st.write(agent_search_result) - - if search_intent_flag: - with st.spinner('Executing query...'): - search_intent_result = get_sql_result_tool( - st.session_state['profiles'][current_nlq_chain.profile], - current_nlq_chain.get_generated_sql()) - if search_intent_result["status_code"] == 500: - with st.expander("The SQL Error Info"): - st.markdown(search_intent_result["error_info"]) - - if auto_correction_flag: - with st.status("Regenerating SQL") as status_text: - response = text_to_sql(database_profile['tables_info'], - database_profile['hints'], - database_profile['prompt_map'], - search_box, - model_id=model_type, - sql_examples=normal_search_result.retrieve_result, - ner_example=normal_search_result.entity_slot_retrieve, - dialect=database_profile['db_type'], - model_provider=None, - additional_info='''\n NOTE: when I try to write a SQL {sql_statement}, I got an error {error}. Please consider and avoid this problem. 
'''.format( - sql_statement=current_nlq_chain.get_generated_sql(), - error=search_intent_result["error_info"])) - - regen_sql = get_generated_sql(response) - - st.code(regen_sql, language="sql") - - status_text.update( - label=f"Generating SQL Done", - state="complete", expanded=True) - - with st.spinner('Executing query...'): - search_intent_result = get_sql_result_tool( - st.session_state['profiles'][current_nlq_chain.profile], - regen_sql) - - if search_intent_result["status_code"] == 500: - with st.expander("The SQL Error Info"): - st.markdown(search_intent_result["error_info"]) - - if search_intent_result["status_code"] != 500: - # else: - if search_intent_result["data"] is not None and len( - search_intent_result["data"]) > 0 and data_with_analyse: - with st.spinner('Generating data summarize...'): - search_intent_analyse_result = data_analyse_tool(model_type, prompt_map, search_box, - search_intent_result["data"].to_json( - orient='records', - force_ascii=False), "query") - st.markdown(search_intent_analyse_result) - st.session_state.messages[selected_profile].append( - {"role": "assistant", "content": search_intent_analyse_result, "type": "text"}) - st.session_state.current_sql_result[selected_profile] = search_intent_result["data"] - - elif agent_intent_flag: - for i in range(len(agent_search_result)): - each_task_res = get_sql_result_tool( - st.session_state['profiles'][current_nlq_chain.profile], - agent_search_result[i]["sql"]) - if each_task_res["status_code"] == 200 and len(each_task_res["data"]) > 0: - agent_search_result[i]["data_result"] = each_task_res["data"].to_json( - orient='records') - filter_deep_dive_sql_result.append(agent_search_result[i]) - - agent_data_analyse_result = data_analyse_tool(model_type, prompt_map, search_box, - json.dumps(filter_deep_dive_sql_result, - ensure_ascii=False), "agent") - logger.info("agent_data_analyse_result") - logger.info(agent_data_analyse_result) - st.session_state.messages[selected_profile].append( - {"role": "user", "content": search_box, "type": "text"}) - for i in range(len(filter_deep_dive_sql_result)): - st.write(filter_deep_dive_sql_result[i]["query"]) - st.dataframe(pd.read_json(filter_deep_dive_sql_result[i]["data_result"], - orient='records'), hide_index=True) - - st.session_state.messages[selected_profile].append( - {"role": "assistant", "content": filter_deep_dive_sql_result, "type": "pandas"}) - - st.markdown(agent_data_analyse_result) - current_nlq_chain.set_generated_sql_response(agent_data_analyse_result) - st.session_state.messages[selected_profile].append( - {"role": "assistant", "content": agent_data_analyse_result, "type": "text"}) - - st.markdown('You can provide feedback:') - - # add a upvote(green)/downvote button with logo - feedback = st.columns(2) - feedback[0].button('👍 Upvote (save as embedding for retrieval)', type='secondary', - use_container_width=True, - on_click=upvote_agent_clicked, - args=[current_nlq_chain.get_question(), - agent_cot_task_result]) - - if feedback[1].button('👎 Downvote', type='secondary', use_container_width=True): - # do something here - pass - - if visualize_results_flag and search_intent_flag: - current_search_sql_result = st.session_state.current_sql_result[selected_profile] - if current_search_sql_result is not None and len(current_search_sql_result) > 0: + {"role": "user", "content": search_box, "type": "text"}) + for i in range(len(filter_deep_dive_sql_result)): + st.write(filter_deep_dive_sql_result[i]["query"]) + 
st.dataframe(pd.read_json(filter_deep_dive_sql_result[i]["data_result"], + orient='records'), hide_index=True) + st.session_state.messages[selected_profile].append( - {"role": "assistant", "content": current_search_sql_result, "type": "pandas"}) + {"role": "assistant", "content": filter_deep_dive_sql_result, "type": "pandas"}) - do_visualize_results(current_nlq_chain, st.session_state.current_sql_result[selected_profile]) - else: - st.markdown("No relevant data found") - - if gen_suggested_question_flag and (search_intent_flag or agent_intent_flag): - st.markdown('You might want to further ask:') - with st.spinner('Generating suggested questions...'): - generated_sq = generate_suggested_question(prompt_map, search_box, model_id=model_type) - split_strings = generated_sq.split("[generate]") - gen_sq_list = [s.strip() for s in split_strings if s.strip()] - sq_result = st.columns(3) - sq_result[0].button(gen_sq_list[0], type='secondary', - use_container_width=True, - on_click=sample_question_clicked, - args=[gen_sq_list[0]]) - sq_result[1].button(gen_sq_list[1], type='secondary', - use_container_width=True, - on_click=sample_question_clicked, - args=[gen_sq_list[1]]) - sq_result[2].button(gen_sq_list[2], type='secondary', - use_container_width=True, - on_click=sample_question_clicked, - args=[gen_sq_list[2]]) + st.markdown(agent_data_analyse_result) + current_nlq_chain.set_generated_sql_response(agent_data_analyse_result) + st.session_state.messages[selected_profile].append( + {"role": "assistant", "content": agent_data_analyse_result, "type": "text"}) + + st.markdown('You can provide feedback:') + + # add a upvote(green)/downvote button with logo + feedback = st.columns(2) + feedback[0].button('👍 Upvote (save as embedding for retrieval)', type='secondary', + use_container_width=True, + on_click=upvote_agent_clicked, + args=[current_nlq_chain.get_question(), + agent_cot_task_result]) + + if feedback[1].button('👎 Downvote', type='secondary', use_container_width=True): + # do something here + pass + + if visualize_results_flag and search_intent_flag: + current_search_sql_result = st.session_state.current_sql_result[selected_profile] + if current_search_sql_result is not None and len(current_search_sql_result) > 0: + st.session_state.messages[selected_profile].append( + {"role": "assistant", "content": current_search_sql_result, "type": "pandas"}) + + do_visualize_results(current_nlq_chain, st.session_state.current_sql_result[selected_profile]) + else: + st.markdown("No relevant data found") + + if gen_suggested_question_flag and (search_intent_flag or agent_intent_flag): + st.markdown('You might want to further ask:') + with st.spinner('Generating suggested questions...'): + generated_sq = generate_suggested_question(prompt_map, search_box, model_id=model_type) + split_strings = generated_sq.split("[generate]") + gen_sq_list = [s.strip() for s in split_strings if s.strip()] + sq_result = st.columns(3) + sq_result[0].button(gen_sq_list[0], type='secondary', + use_container_width=True, + on_click=sample_question_clicked, + args=[gen_sq_list[0]]) + sq_result[1].button(gen_sq_list[1], type='secondary', + use_container_width=True, + on_click=sample_question_clicked, + args=[gen_sq_list[1]]) + sq_result[2].button(gen_sq_list[2], type='secondary', + use_container_width=True, + on_click=sample_question_clicked, + args=[gen_sq_list[2]]) else: if current_nlq_chain.is_visualization_config_changed(): From e59febf972efb56746bcbb3395f2d544abd79e20 Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 18 Jul 2024 
16:43:18 +0800 Subject: [PATCH 053/130] add replay ask --- .../pages/1_\360\237\214\215_Generative_BI_Playground.py" | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" index 3131baa..545df96 100644 --- "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" +++ "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" @@ -403,7 +403,7 @@ def main(): # Multiple rounds of dialogue, query rewriting user_query_history = get_user_history(selected_profile) - query_rewrite_result = {"original_problem": search_box} + query_rewrite_result = {"intent" : "original_problem", "query" :search_box} if context_window > 0: with st.status("Query Context Understanding") as status_text: context_window_select = context_window * 2 @@ -420,7 +420,9 @@ def main(): st.write(search_box) status_text.update(label=f"Query Context Rewrite Completed", state="complete", expanded=False) - if "ask_in_reply" in query_rewrite_result: + query_rewrite_intent = query_rewrite_result.get("intent") + if "ask_in_reply" == query_rewrite_intent: + st.write(query_rewrite_result.get("query")) st.session_state.ask_replay = True if not st.session_state.ask_replay: From 081cdd6f83e4890471de4fe80304fc1c835bebf8 Mon Sep 17 00:00:00 2001 From: supinyu Date: Thu, 18 Jul 2024 17:14:36 +0800 Subject: [PATCH 054/130] update prompt --- application/utils/prompts/generate_prompt.py | 375 +++++++++++++++++-- 1 file changed, 335 insertions(+), 40 deletions(-) diff --git a/application/utils/prompts/generate_prompt.py b/application/utils/prompts/generate_prompt.py index 9c46fa9..7f17737 100644 --- a/application/utils/prompts/generate_prompt.py +++ b/application/utils/prompts/generate_prompt.py @@ -101,95 +101,390 @@ } query_rewrite_system_prompt_dict['mixtral-8x7b-instruct-0'] = """ -You are a data product manager experienced in data requirements, and you need its user's historical chat query and please try understanding the semantics and rewrite a question. +You are an experienced data product manager specializing in data requirements. Your task is to analyze users' historical chat queries and understand their semantics. + +You have three possible actions. You must select one of the following intents: + + +- original_problem: If the current question has no semantic relationship with the previous conversation, input the current question directly without rewriting it. +- ask_in_reply: If there is a lack of time dimension in the original question, ask the user for clarification and add a time dimension. +- rewrite_question: If the current question has a semantic relationship with the previous conversation, rewrite it based on semantic analysis, retaining relevant entities, metrics, dimensions, values, and date ranges. + + +Guidelines for this task: + -- If the current question has no semantic relationship with the previous conversation, input the current question directly without rewriting it. -- Based on semantic analysis, keep relevant entities, metrics, dimensions, values and date ranges. -- The output language will be consistent with the language of the question. +- The output language should be consistent with the language of the question. +- Only output a JSON structure, where the keys are "intent" and "query". + +Examples will follow, where in the chat history, "User" represents the user's question, and "Assistant" represents the chatbot's answer. 
+ + + + +The Chat history is : +user: 上个月欧洲希尔顿酒店的销量是多少 +assistant: 查询上个月欧洲希尔顿酒店的销量 +user: 亚洲呢 +assistant: 查询上个月亚洲希尔顿酒店的销量 +user: 上上个月呢 + +answer: + +{ + "intent" : "rewrite_question", + "query": "查询上上个月亚洲希尔顿酒店的销量" +} + + + +The Chat history is : +user: 上个月欧洲希尔顿酒店的销量是多少。 +assistant: 查询上个月欧洲希尔顿酒店的销量。 + +The user question is : 对比欧洲和亚洲两个的订单量 + +answer: + +{ + "intent" : "original_problem", + "query": "对比欧洲和亚洲两个的订单量" +} + + + +The user question is : 查询万豪酒店的订单量 + +answer: + +{ + "intent" : "ask_in_reply", + "query": "请问您想查询的时间范围是多少呢" +} + + + + + """ query_rewrite_system_prompt_dict['llama3-70b-instruct-0'] = """ -You are a data product manager experienced in data requirements, and you need its user's historical chat query and please try understanding the semantics and rewrite a question. +You are an experienced data product manager specializing in data requirements. Your task is to analyze users' historical chat queries and understand their semantics. + +You have three possible actions. You must select one of the following intents: + + +- original_problem: If the current question has no semantic relationship with the previous conversation, input the current question directly without rewriting it. +- ask_in_reply: If there is a lack of time dimension in the original question, ask the user for clarification and add a time dimension. +- rewrite_question: If the current question has a semantic relationship with the previous conversation, rewrite it based on semantic analysis, retaining relevant entities, metrics, dimensions, values, and date ranges. + + +Guidelines for this task: + -- If the current question has no semantic relationship with the previous conversation, input the current question directly without rewriting it. -- Based on semantic analysis, keep relevant entities, metrics, dimensions, values and date ranges. -- The output language will be consistent with the language of the question. +- The output language should be consistent with the language of the question. +- Only output a JSON structure, where the keys are "intent" and "query". + +Examples will follow, where in the chat history, "User" represents the user's question, and "Assistant" represents the chatbot's answer. + + + + +The Chat history is : +user: 上个月欧洲希尔顿酒店的销量是多少 +assistant: 查询上个月欧洲希尔顿酒店的销量 +user: 亚洲呢 +assistant: 查询上个月亚洲希尔顿酒店的销量 +user: 上上个月呢 + +answer: + +{ + "intent" : "rewrite_question", + "query": "查询上上个月亚洲希尔顿酒店的销量" +} + + + +The Chat history is : +user: 上个月欧洲希尔顿酒店的销量是多少。 +assistant: 查询上个月欧洲希尔顿酒店的销量。 + +The user question is : 对比欧洲和亚洲两个的订单量 + +answer: + +{ + "intent" : "original_problem", + "query": "对比欧洲和亚洲两个的订单量" +} + + + +The user question is : 查询万豪酒店的订单量 + +answer: + +{ + "intent" : "ask_in_reply", + "query": "请问您想查询的时间范围是多少呢" +} + + + + + """ query_rewrite_system_prompt_dict['haiku-20240307v1-0'] = """ -You are a data product manager experienced in data requirements, and you need its user's historical chat query and please try understanding the semantics and rewrite a question. +You are an experienced data product manager specializing in data requirements. Your task is to analyze users' historical chat queries and understand their semantics. + +You have three possible actions. You must select one of the following intents: + + +- original_problem: If the current question has no semantic relationship with the previous conversation, input the current question directly without rewriting it. 
+- ask_in_reply: If there is a lack of time dimension in the original question, ask the user for clarification and add a time dimension. +- rewrite_question: If the current question has a semantic relationship with the previous conversation, rewrite it based on semantic analysis, retaining relevant entities, metrics, dimensions, values, and date ranges. + + +Guidelines for this task: + -- If the current question has no semantic relationship with the previous conversation, input the current question directly without rewriting it. -- Based on semantic analysis, keep relevant entities, metrics, dimensions, values and date ranges. -- The output language will be consistent with the language of the question. +- The output language should be consistent with the language of the question. +- Only output a JSON structure, where the keys are "intent" and "query". + +Examples will follow, where in the chat history, "User" represents the user's question, and "Assistant" represents the chatbot's answer. + + + + +The Chat history is : +user: 上个月欧洲希尔顿酒店的销量是多少 +assistant: 查询上个月欧洲希尔顿酒店的销量 +user: 亚洲呢 +assistant: 查询上个月亚洲希尔顿酒店的销量 +user: 上上个月呢 + +answer: + +{ + "intent" : "rewrite_question", + "query": "查询上上个月亚洲希尔顿酒店的销量" +} + + + +The Chat history is : +user: 上个月欧洲希尔顿酒店的销量是多少。 +assistant: 查询上个月欧洲希尔顿酒店的销量。 + +The user question is : 对比欧洲和亚洲两个的订单量 + +answer: + +{ + "intent" : "original_problem", + "query": "对比欧洲和亚洲两个的订单量" +} + + + +The user question is : 查询万豪酒店的订单量 + +answer: + +{ + "intent" : "ask_in_reply", + "query": "请问您想查询的时间范围是多少呢" +} + + + + + """ query_rewrite_system_prompt_dict['sonnet-20240229v1-0'] = """ -You are a data product manager experienced in data requirements, and you need its user's historical chat query and please try understanding the semantics and rewrite a question. +You are an experienced data product manager specializing in data requirements. Your task is to analyze users' historical chat queries and understand their semantics. + +You have three possible actions. You must select one of the following intents: + + +- original_problem: If the current question has no semantic relationship with the previous conversation, input the current question directly without rewriting it. +- ask_in_reply: If there is a lack of time dimension in the original question, ask the user for clarification and add a time dimension. +- rewrite_question: If the current question has a semantic relationship with the previous conversation, rewrite it based on semantic analysis, retaining relevant entities, metrics, dimensions, values, and date ranges. + + +Guidelines for this task: + -- If the current question has no semantic relationship with the previous conversation, input the current question directly without rewriting it. -- Based on semantic analysis, keep relevant entities, metrics, dimensions, values and date ranges. -- The output language will be consistent with the language of the question. +- The output language should be consistent with the language of the question. +- Only output a JSON structure, where the keys are "intent" and "query". + +Examples will follow, where in the chat history, "User" represents the user's question, and "Assistant" represents the chatbot's answer. 
+ + + + +The Chat history is : +user: 上个月欧洲希尔顿酒店的销量是多少 +assistant: 查询上个月欧洲希尔顿酒店的销量 +user: 亚洲呢 +assistant: 查询上个月亚洲希尔顿酒店的销量 +user: 上上个月呢 + +answer: + +{ + "intent" : "rewrite_question", + "query": "查询上上个月亚洲希尔顿酒店的销量" +} + + + +The Chat history is : +user: 上个月欧洲希尔顿酒店的销量是多少。 +assistant: 查询上个月欧洲希尔顿酒店的销量。 + +The user question is : 对比欧洲和亚洲两个的订单量 + +answer: + +{ + "intent" : "original_problem", + "query": "对比欧洲和亚洲两个的订单量" +} + + + +The user question is : 查询万豪酒店的订单量 + +answer: + +{ + "intent" : "ask_in_reply", + "query": "请问您想查询的时间范围是多少呢" +} + + + + + """ query_rewrite_system_prompt_dict['sonnet-3-5-20240620v1-0'] = """ -You are a data product manager experienced in data requirements, and you need its user's historical chat query and please try understanding the semantics and rewrite a question. +You are an experienced data product manager specializing in data requirements. Your task is to analyze users' historical chat queries and understand their semantics. + +You have three possible actions. You must select one of the following intents: + + +- original_problem: If the current question has no semantic relationship with the previous conversation, input the current question directly without rewriting it. +- ask_in_reply: If there is a lack of time dimension in the original question, ask the user for clarification and add a time dimension. +- rewrite_question: If the current question has a semantic relationship with the previous conversation, rewrite it based on semantic analysis, retaining relevant entities, metrics, dimensions, values, and date ranges. + + +Guidelines for this task: + -- If the current question has no semantic relationship with the previous conversation, input the current question directly without rewriting it. -- Based on semantic analysis, keep relevant entities, metrics, dimensions, values and date ranges. -- The output language will be consistent with the language of the question. +- The output language should be consistent with the language of the question. +- Only output a JSON structure, where the keys are "intent" and "query". + +Examples will follow, where in the chat history, "User" represents the user's question, and "Assistant" represents the chatbot's answer. + + + + +The Chat history is : +user: 上个月欧洲希尔顿酒店的销量是多少 +assistant: 查询上个月欧洲希尔顿酒店的销量 +user: 亚洲呢 +assistant: 查询上个月亚洲希尔顿酒店的销量 +user: 上上个月呢 + +answer: + +{ + "intent" : "rewrite_question", + "query": "查询上上个月亚洲希尔顿酒店的销量" +} + + + +The Chat history is : +user: 上个月欧洲希尔顿酒店的销量是多少。 +assistant: 查询上个月欧洲希尔顿酒店的销量。 + +The user question is : 对比欧洲和亚洲两个的订单量 + +answer: + +{ + "intent" : "original_problem", + "query": "对比欧洲和亚洲两个的订单量" +} + + + +The user question is : 查询万豪酒店的订单量 + +answer: + +{ + "intent" : "ask_in_reply", + "query": "请问您想查询的时间范围是多少呢" +} + + + + + """ query_rewrite_user_prompt_dict['mixtral-8x7b-instruct-0'] = """ -Given the following conversation and a follow up question. Just output the rewritten question without explanation. -Chat History: +The Chat History: {chat_history} -Follow Up Input: {question} -The rewrite question is : +======================== +The question is : {question} """ query_rewrite_user_prompt_dict['llama3-70b-instruct-0'] = """ -Given the following conversation and a follow up question. Just output the rewritten question without explanation. 
-Chat History: +The Chat History: {chat_history} -Follow Up Input: {question} -The rewrite question is : +======================== +The question is : {question} """ query_rewrite_user_prompt_dict['haiku-20240307v1-0'] = """ -Given the following conversation and a follow up question. Just output the rewritten question without explanation. -Chat History: +The Chat History: {chat_history} -Follow Up Input: {question} -The rewrite question is : +======================== +The question is : {question} """ query_rewrite_user_prompt_dict['sonnet-20240229v1-0'] = """ -Given the following conversation and a follow up question. Just output the rewritten question without explanation. -Chat History: +The Chat History: {chat_history} -Follow Up Input: {question} -The rewrite question is : +======================== +The question is : {question} """ query_rewrite_user_prompt_dict['sonnet-3-5-20240620v1-0'] = """ -Given the following conversation and a follow up question. Just output the rewritten question without explanation. -Chat History: +The Chat History: {chat_history} -Follow Up Input: {question} -The rewrite question is : +======================== +The question is : {question} """ From b0b2081c0d2f2c7094b81a3d7540c7f0ffa5aaa0 Mon Sep 17 00:00:00 2001 From: supinyu Date: Fri, 19 Jul 2024 10:03:43 +0800 Subject: [PATCH 055/130] change default falg --- .../pages/1_\360\237\214\215_Generative_BI_Playground.py" | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" index 545df96..bdce367 100644 --- "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" +++ "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" @@ -308,8 +308,8 @@ def main(): explain_gen_process_flag = st.checkbox("Explain Generation Process", True) data_with_analyse = st.checkbox("Answer With Insights", False) gen_suggested_question_flag = st.checkbox("Generate Suggested Questions", False) - auto_correction_flag = st.checkbox("Auto Correcting SQL", False) - context_window = st.slider("Multiple Rounds of Context Window", 0, 10, 0) + auto_correction_flag = st.checkbox("Auto Correcting SQL", True) + context_window = st.slider("Multiple Rounds of Context Window", 0, 10, 5) clean_history = st.button("clean history", on_click=clean_st_history, args=[selected_profile]) From 34370bb2e5418faa180987ceca8edea448658ea2 Mon Sep 17 00:00:00 2001 From: supinyu Date: Fri, 19 Jul 2024 14:41:36 +0800 Subject: [PATCH 056/130] add segamaker env --- application/nlq/business/vector_store.py | 33 +++++++++++++++++++----- application/utils/env_var.py | 7 +++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/application/nlq/business/vector_store.py b/application/nlq/business/vector_store.py index 4b3830a..3a118ba 100644 --- a/application/nlq/business/vector_store.py +++ b/application/nlq/business/vector_store.py @@ -3,8 +3,10 @@ import boto3 import json from nlq.data_access.opensearch import OpenSearchDao -from utils.env_var import BEDROCK_REGION, AOS_HOST, AOS_PORT, AOS_USER, AOS_PASSWORD, opensearch_info +from utils.env_var import BEDROCK_REGION, AOS_HOST, AOS_PORT, AOS_USER, AOS_PASSWORD, opensearch_info, \ + SAGEMAKER_ENDPOINT_EMBEDDING from utils.env_var import bedrock_ak_sk_info +from utils.llm import invoke_model_sagemaker_endpoint logger = logging.getLogger(__name__) @@ -73,7 +75,10 @@ def get_all_agent_cot_samples(cls, profile_name): @classmethod def 
add_sample(cls, profile_name, question, answer): logger.info(f'add sample question: {question} to profile {profile_name}') - embedding = cls.create_vector_embedding_with_bedrock(question) + if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "": + embedding = cls.create_vector_embedding_with_sagemaker(question) + else: + embedding = cls.create_vector_embedding_with_bedrock(question) has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['sql_index'], embedding) if has_same_sample: logger.info(f'delete sample sample entity: {question} to profile {profile_name}') @@ -83,7 +88,10 @@ def add_sample(cls, profile_name, question, answer): @classmethod def add_entity_sample(cls, profile_name, entity, comment): logger.info(f'add sample entity: {entity} to profile {profile_name}') - embedding = cls.create_vector_embedding_with_bedrock(entity) + if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "": + embedding = cls.create_vector_embedding_with_sagemaker(entity) + else: + embedding = cls.create_vector_embedding_with_bedrock(entity) has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding) if has_same_sample: logger.info(f'delete sample sample entity: {entity} to profile {profile_name}') @@ -93,7 +101,10 @@ def add_entity_sample(cls, profile_name, entity, comment): @classmethod def add_agent_cot_sample(cls, profile_name, entity, comment): logger.info(f'add agent sample query: {entity} to profile {profile_name}') - embedding = cls.create_vector_embedding_with_bedrock(entity) + if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "": + embedding = cls.create_vector_embedding_with_sagemaker(entity) + else: + embedding = cls.create_vector_embedding_with_bedrock(entity) has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['agent_index'], embedding) if has_same_sample: logger.info(f'delete agent sample sample query: {entity} to profile {profile_name}') @@ -118,9 +129,17 @@ def create_vector_embedding_with_bedrock(cls, text): return embedding @classmethod - def create_vector_embedding_with_sagemaker(cls): - # to do - pass + def create_vector_embedding_with_sagemaker(cls, text): + try: + model_kwargs = {} + model_kwargs["batch_size"] = 12 + model_kwargs["max_length"] = 512 + model_kwargs["return_type"] = "dense" + body = json.dumps({"inputs": [text], **model_kwargs}) + embeddings = invoke_model_sagemaker_endpoint(SAGEMAKER_ENDPOINT_EMBEDDING, body) + return embeddings + except Exception as e: + logger.error(f'create_vector_embedding_with_sagemaker is error {e}') @classmethod def delete_sample(cls, profile_name, doc_id): diff --git a/application/utils/env_var.py b/application/utils/env_var.py index f51c2b2..52ce794 100644 --- a/application/utils/env_var.py +++ b/application/utils/env_var.py @@ -46,6 +46,13 @@ SAGEMAKER_ENDPOINT_EMBEDDING = os.getenv('SAGEMAKER_ENDPOINT_EMBEDDING', '') +SAGEMAKER_ENDPOINT_SQL = os.getenv('SAGEMAKER_ENDPOINT_SQL', '') + +SAGEMAKER_EMBEDDING_REGION = os.getenv('SAGEMAKER_EMBEDDING_REGION', '') + +SAGEMAKER_SQL_REGION = os.getenv('SAGEMAKER_SQL_REGION', '') + + def get_opensearch_parameter(): try: session = boto3.session.Session() From 32578ceafabb5d515d70dc615e879e5270e77e0a Mon Sep 17 00:00:00 2001 From: Zhoutong Wang Date: Fri, 19 Jul 2024 09:46:44 +0000 Subject: [PATCH 057/130] cdk change for no cognito --- report-front-end/.env | 2 +- report-front-end/.env.template | 2 +- report-front-end/docker-entry.sh | 1 + 
source/resources/lib/ecs/ecs-stack.ts | 2 ++ source/resources/lib/main-stack.ts | 1 + 5 files changed, 6 insertions(+), 2 deletions(-) diff --git a/report-front-end/.env b/report-front-end/.env index 4eff791..05d174d 100644 --- a/report-front-end/.env +++ b/report-front-end/.env @@ -9,7 +9,7 @@ VITE_RIGHT_LOGO= # Login configuration, e.g. Cognito | None -VITE_LOGIN_TYPE=Cognito +VITE_LOGIN_TYPE=PLACEHOLDER_VITE_LOGIN_TYPE # KEEP the placeholder values if using CDK to deploy the backend! diff --git a/report-front-end/.env.template b/report-front-end/.env.template index 487a32a..8e578b3 100644 --- a/report-front-end/.env.template +++ b/report-front-end/.env.template @@ -9,7 +9,7 @@ VITE_RIGHT_LOGO= # Login configuration, e.g. Cognito | None -VITE_LOGIN_TYPE=Cognito +VITE_LOGIN_TYPE=PLACEHOLDER_VITE_LOGIN_TYPE # KEEP the placeholder values if using CDK to deploy the backend! diff --git a/report-front-end/docker-entry.sh b/report-front-end/docker-entry.sh index 45ca4fa..d578d4d 100644 --- a/report-front-end/docker-entry.sh +++ b/report-front-end/docker-entry.sh @@ -3,6 +3,7 @@ # Read variable names from the .env file env_file="/.env" vars="VITE_COGNITO_REGION +VITE_LOGIN_TYPE VITE_COGNITO_USER_POOL_WEB_CLIENT_ID VITE_COGNITO_USER_POOL_ID VITE_BACKEND_URL diff --git a/source/resources/lib/ecs/ecs-stack.ts b/source/resources/lib/ecs/ecs-stack.ts index cb9f5df..fc92348 100644 --- a/source/resources/lib/ecs/ecs-stack.ts +++ b/source/resources/lib/ecs/ecs-stack.ts @@ -15,6 +15,7 @@ export class ECSStack extends cdk.Stack { constructor(scope: Construct, id: string, props: cdk.StackProps & { vpc: ec2.Vpc} & { subnets: cdk.aws_ec2.ISubnet[] } & { cognitoUserPoolId: string} + & { authenticationType: string} & { cognitoUserPoolClientId: string} & {OSMasterUserSecretName: string} & {OSHostSecretName: string}) { super(scope, id, props); @@ -255,6 +256,7 @@ constructor(scope: Construct, id: string, props: cdk.StackProps containerFrontend.addEnvironment('VITE_TITLE', 'Guidance for Generative BI') containerFrontend.addEnvironment('VITE_LOGO', '/logo.png'); containerFrontend.addEnvironment('VITE_RIGHT_LOGO', ''); + containerFrontend.addEnvironment('VITE_LOGIN_TYPE', props.authenticationType); containerFrontend.addEnvironment('VITE_COGNITO_REGION', cdk.Aws.REGION); containerFrontend.addEnvironment('VITE_COGNITO_USER_POOL_ID', props.cognitoUserPoolId); containerFrontend.addEnvironment('VITE_COGNITO_USER_POOL_WEB_CLIENT_ID', props.cognitoUserPoolClientId); diff --git a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index fb46b76..2468c31 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -102,6 +102,7 @@ export class MainStack extends cdk.Stack { env: props.env, vpc: _VpcStack.vpc, subnets: ecsSubnets.subnets, + authenticationType: _CognitoStack ? "Cognito" : "None", cognitoUserPoolId: _CognitoStack?.userPoolId ?? "", cognitoUserPoolClientId: _CognitoStack?.userPoolClientId ?? 
"", OSMasterUserSecretName: From 25ddecb936efd8f0151cdc9d544634bfa9ded515 Mon Sep 17 00:00:00 2001 From: supinyu Date: Fri, 19 Jul 2024 17:48:24 +0800 Subject: [PATCH 058/130] change max read sql --- application/utils/llm.py | 123 +++++++++++++++++++++++++++++++++------ 1 file changed, 104 insertions(+), 19 deletions(-) diff --git a/application/utils/llm.py b/application/utils/llm.py index 1b9127e..f46b3d3 100644 --- a/application/utils/llm.py +++ b/application/utils/llm.py @@ -5,6 +5,8 @@ from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \ DEFAULT_DIALECT_PROMPT, SEARCH_INTENT_PROMPT_CLAUDE3, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3 import os +import sagemaker +from sagemaker import Model, image_uris, serializers, deserializers import logging from langchain_core.output_parsers import JsonOutputParser from utils.prompts.generate_prompt import generate_llm_prompt, generate_sagemaker_intent_prompt, \ @@ -13,7 +15,9 @@ generate_agent_analyse_prompt, generate_data_summary_prompt, generate_suggest_question_prompt, \ generate_query_rewrite_prompt -from utils.env_var import bedrock_ak_sk_info, BEDROCK_REGION, BEDROCK_EMBEDDING_MODEL +from utils.env_var import bedrock_ak_sk_info, BEDROCK_REGION, BEDROCK_EMBEDDING_MODEL, SAGEMAKER_EMBEDDING_REGION, \ + SAGEMAKER_SQL_REGION, SAGEMAKER_ENDPOINT_EMBEDDING + logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -31,6 +35,7 @@ bedrock = None json_parse = JsonOutputParser() +embedding_sagemaker_client = None sagemaker_client = None @@ -41,7 +46,7 @@ def get_bedrock_client(): bedrock = boto3.client(service_name='bedrock-runtime', config=config) else: bedrock = boto3.client( - service_name='bedrock-runtime', config=config, + service_name='bedrock-runtime', config=config, aws_access_key_id=bedrock_ak_sk_info['access_key_id'], aws_secret_access_key=bedrock_ak_sk_info['secret_access_key']) return bedrock @@ -105,6 +110,43 @@ def invoke_llama_70b(model_id, system_prompt, user_prompt, max_tokens, with_resp logger.error(e) +def invoke_mixtral_8x7b_sagemaker(model_id, system_prompt, messages, max_tokens, with_response_stream=False): + """ + Invokes the Mixtral 8c7B model to run an inference using the input + provided in the request body. + + :param prompt: The prompt that you want Mixtral to complete. + :return: List of inference responses from the model. 
+ """ + + try: + instruction = f"[INST] {system_prompt} \n The question you need to answer is: {messages[0]['content']} [/INST]" + + body = { + "inputs": instruction, + "parameters": { + "max_new_tokens": max_tokens, + "do_sample": True, + "temperature": 0.1, + "top_p": 0.95, + "top_k": 50, + "repetition_penalty": 1.0 + } + } + + response = invoke_model_sagemaker_endpoint( + endpoint_name=model_id, + body=json.dumps(body), + model_type="LLM", + with_response_stream=with_response_stream + ) + return response['generated_text'] + + except Exception as e: + logger.error("Couldn't invoke Mixtral 8x7B on SageMaker") + logger.error(e) + raise + def invoke_mixtral_8x7b(model_id, system_prompt, messages, max_tokens, with_response_stream=False): """ Invokes the Mixtral 8c7B model to run an inference using the input @@ -140,29 +182,60 @@ def invoke_mixtral_8x7b(model_id, system_prompt, messages, max_tokens, with_resp raise +def get_embedding_sagemaker_client(): + global embedding_sagemaker_client + if not embedding_sagemaker_client: + if SAGEMAKER_EMBEDDING_REGION is not None and SAGEMAKER_EMBEDDING_REGION != "": + embedding_sagemaker_client = boto3.client(service_name='sagemaker-runtime', + region_name=SAGEMAKER_EMBEDDING_REGION) + else: + embedding_sagemaker_client = boto3.client(service_name='sagemaker-runtime') + return embedding_sagemaker_client + + def get_sagemaker_client(): global sagemaker_client if not sagemaker_client: - sagemaker_client = boto3.client(service_name='sagemaker-runtime') + if SAGEMAKER_SQL_REGION is not None and SAGEMAKER_SQL_REGION != "": + sagemaker_client = boto3.client(service_name='sagemaker-runtime', + region_name=SAGEMAKER_SQL_REGION) + else: + sagemaker_client = boto3.client(service_name='sagemaker-runtime') return sagemaker_client - -def invoke_model_sagemaker_endpoint(endpoint_name, body, with_response_stream=False): +def invoke_model_sagemaker_endpoint(endpoint_name, body, model_type="LLM", with_response_stream=False): if with_response_stream: - response = get_sagemaker_client().invoke_endpoint_with_response_stream( - EndpointName=endpoint_name, - Body=body, - ContentType="application/json", - ) + if model_type == "LLM": + response = get_sagemaker_client().invoke_endpoint_with_response_stream( + EndpointName=endpoint_name, + Body=body, + ContentType="application/json", + ) + return response + else: + response = get_embedding_sagemaker_client().invoke_endpoint_with_response_stream( + EndpointName=endpoint_name, + Body=body, + ContentType="application/json", + ) return response else: - response = get_sagemaker_client().invoke_endpoint( - EndpointName=endpoint_name, - Body=body, - ContentType="application/json", - ) - response_body = json.loads(response.get('Body').read()) - return response_body + if model_type == "LLM": + response = get_sagemaker_client().invoke_endpoint( + EndpointName=endpoint_name, + Body=body, + ContentType="application/json", + ) + response_body = json.loads(response.get('Body').read()) + return response_body + else: + response = get_embedding_sagemaker_client().invoke_endpoint( + EndpointName=endpoint_name, + Body=body, + ContentType="application/json", + ) + response_body = json.loads(response.get('Body').read()) + return response_body def claude_select_table(): @@ -258,7 +331,10 @@ def invoke_llm_model(model_id, system_prompt, user_prompt, max_tokens=2048, with if model_id.startswith('anthropic.claude-3'): response = invoke_model_claude3(model_id, system_prompt, messages, max_tokens, with_response_stream) elif 
model_id.startswith('mistral.mixtral-8x7b'): - response = invoke_mixtral_8x7b(model_id, system_prompt, messages, max_tokens, with_response_stream) + if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "": + response = invoke_mixtral_8x7b_sagemaker(model_id, system_prompt, messages, max_tokens, with_response_stream) + else: + response = invoke_mixtral_8x7b(model_id, system_prompt, messages, max_tokens, with_response_stream) elif model_id.startswith('meta.llama3-70b'): response = invoke_llama_70b(model_id, system_prompt, user_prompt, max_tokens, with_response_stream) if with_response_stream: @@ -266,6 +342,15 @@ def invoke_llm_model(model_id, system_prompt, user_prompt, max_tokens=2048, with else: if model_id.startswith('meta.llama3-70b'): return response["generation"] + elif model_id.startswith('mistral.mixtral'): + if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "": + response = json.loads(response) + response = response['generated_text'] + response = response.replace("\\", "") + return response + else: + final_response = response.get("content")[0].get("text") + return final_response else: final_response = response.get("content")[0].get("text") return final_response @@ -493,7 +578,7 @@ def create_vector_embedding_with_sagemaker(endpoint_name, text, index_name): model_kwargs["max_length"] = 512 model_kwargs["return_type"] = "dense" body = json.dumps({"inputs": [text], **model_kwargs}) - response = invoke_model_sagemaker_endpoint(endpoint_name, body) + response = invoke_model_sagemaker_endpoint(endpoint_name, body, model_type="embedding") embeddings = response["sentence_embeddings"] return {"_index": index_name, "text": text, "vector_field": embeddings["dense_vecs"][0]} From 65b0799f22001f1c5c8626913d42a911804d89c1 Mon Sep 17 00:00:00 2001 From: Feng Xu Date: Fri, 19 Jul 2024 22:29:58 +0800 Subject: [PATCH 059/130] 1. support clickhouse db as datasource 2. 
return meaningful error when profile name not found for qa ask API --- application/api/enum.py | 1 + application/api/service.py | 2 + application/nlq/data_access/database.py | 52 ++++++------------- ...237\252\231_Data_Connection_Management.py" | 1 + application/utils/prompt.py | 5 ++ application/utils/prompts/generate_prompt.py | 4 +- 6 files changed, 27 insertions(+), 38 deletions(-) diff --git a/application/api/enum.py b/application/api/enum.py index 48e0f80..a72a81c 100644 --- a/application/api/enum.py +++ b/application/api/enum.py @@ -8,6 +8,7 @@ class ErrorEnum(Enum): NOT_SUPPORTED = {1001: "Your query statement is currently not supported by the system"} INVAILD_BEDROCK_MODEL_ID = {1002: f"Invalid bedrock model id.Vaild ids:{BEDROCK_MODEL_IDS}"} INVAILD_SESSION_ID = {1003: f"Invalid session id."} + PROFILE_NOT_FOUND = {1004: "Profile name not found."} UNKNOWN_ERROR = {9999: "Unknown error."} def get_code(self): diff --git a/application/api/service.py b/application/api/service.py index 0a50b25..101192b 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -142,6 +142,8 @@ def ask(question: Question) -> Answer: log_info = "" all_profiles = ProfileManagement.get_all_profiles_with_info() + if selected_profile not in all_profiles: + raise BizException(ErrorEnum.PROFILE_NOT_FOUND) database_profile = all_profiles[selected_profile] current_nlq_chain = NLQChain(selected_profile) diff --git a/application/nlq/data_access/database.py b/application/nlq/data_access/database.py index ad250a5..b3fd616 100644 --- a/application/nlq/data_access/database.py +++ b/application/nlq/data_access/database.py @@ -13,7 +13,8 @@ class RelationDatabase(): 'mysql': 'mysql+pymysql', 'postgresql': 'postgresql+psycopg2', 'redshift': 'postgresql+psycopg2', - 'starrocks': 'starrocks' + 'starrocks': 'starrocks', + 'clickhouse': 'clickhouse', # Add more mappings here for other databases } @@ -42,43 +43,20 @@ def test_connection(cls, db_type, user, password, host, port, db_name) -> bool: @classmethod def get_all_schema_names_by_connection(cls, connection: ConnectConfigEntity): - schemas = [] - if connection.db_type == 'postgresql': - db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host, - connection.db_port, connection.db_name) - engine = db.create_engine(db_url) - # with engine.connect() as conn: - # query = text(""" - # SELECT nspname AS schema_name - # FROM pg_catalog.pg_namespace - # WHERE nspname !~ '^pg_' AND nspname <> 'information_schema' AND nspname <> 'public' - # AND has_schema_privilege(nspname, 'USAGE'); - # """) - # - # # Executing the query - # result = conn.execute(query) - # schemas = [row['schema_name'] for row in result.mappings()] - # print(schemas) - inspector = sqlalchemy.inspect(engine) - schemas = inspector.get_schema_names() - elif connection.db_type == 'redshift': - db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host, - connection.db_port, connection.db_name) - engine = db.create_engine(db_url) - inspector = inspect(engine) + db_type = connection.db_type + db_url = cls.get_db_url(db_type, connection.db_user, connection.db_pwd, connection.db_host, connection.db_port, + connection.db_name) + engine = db.create_engine(db_url) + inspector = inspect(engine) + + if db_type == 'postgresql': + schemas = [schema for schema in inspector.get_schema_names() if + schema not in ('pg_catalog', 'information_schema', 'public')] + elif db_type in ('redshift', 'mysql', 'starrocks', 'clickhouse'): 
schemas = inspector.get_schema_names() - elif connection.db_type == 'mysql': - db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host, - connection.db_port, connection.db_name) - engine = db.create_engine(db_url) - database_connect = sqlalchemy.inspect(engine) - schemas = database_connect.get_schema_names() - elif connection.db_type == 'starrocks': - db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host, - connection.db_port, connection.db_name) - engine = db.create_engine(db_url) - database_connect = sqlalchemy.inspect(engine) - schemas = database_connect.get_schema_names() + else: + raise ValueError("Unsupported database type") + return schemas @classmethod diff --git "a/application/pages/2_\360\237\252\231_Data_Connection_Management.py" "b/application/pages/2_\360\237\252\231_Data_Connection_Management.py" index 54a6868..067fe17 100644 --- "a/application/pages/2_\360\237\252\231_Data_Connection_Management.py" +++ "b/application/pages/2_\360\237\252\231_Data_Connection_Management.py" @@ -13,6 +13,7 @@ 'postgresql': 'PostgreSQL', 'redshift': 'Redshift', 'starrocks': 'StarRocks', + 'clickhouse': 'Clickhouse', } diff --git a/application/utils/prompt.py b/application/utils/prompt.py index 61ca564..aaf16ad 100644 --- a/application/utils/prompt.py +++ b/application/utils/prompt.py @@ -22,6 +22,11 @@ question.When generating SQL, do not add double quotes or single quotes around table names. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per StarRocks SQL. Never query for all columns from a table.""".format(top_k=TOP_K) +CLICKHOUSE_DIALECT_PROMPT_CLAUDE3=""" +You are a data analysis expert and proficient in Clickhouse. Given an input question, first create a syntactically correct Clickhouse query to run, then look at the results of the query and return the answer to the input question. +Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per ClickHouse. You can order the results to return the most informative data in the database. +Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. +Pay attention to use today() function to get the current date, if the question involves "today". Pay attention to adapted to the table field type. Please follow the clickhouse syntax or function case specifications.If the field alias contains Chinese characters, please use double quotes to Wrap it.""".format(top_k=TOP_K) AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3 = """You are a Amazon Redshift expert. Given an input question, first create a syntactically correct Redshift query to run, then look at the results of the query and return the answer to the input question.When generating SQL, do not add double quotes or single quotes around table names. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. 
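Aside (illustrative, not part of this patch): the new 'clickhouse' entry in db_mapping is a SQLAlchemy dialect prefix, so ClickHouse schema discovery is expected to flow through the same URL-building path as the other engines. The sketch below assumes the usual "dialect://user:password@host:port/db" convention and the clickhouse-sqlalchemy driver that a later patch adds to the requirements files; the helper name and connection values are placeholders, not taken from the repository.

# Minimal sketch of how the 'clickhouse' mapping is expected to resolve to an engine URL.
# Connection values are placeholders; requires the clickhouse-sqlalchemy package.
from sqlalchemy import create_engine, inspect

db_mapping = {
    'mysql': 'mysql+pymysql',
    'postgresql': 'postgresql+psycopg2',
    'redshift': 'postgresql+psycopg2',
    'starrocks': 'starrocks',
    'clickhouse': 'clickhouse',
}

def build_db_url(db_type, user, password, host, port, db_name):
    # Standard SQLAlchemy URL layout; placeholder for the repository's own URL builder.
    return f"{db_mapping[db_type]}://{user}:{password}@{host}:{port}/{db_name}"

engine = create_engine(build_db_url('clickhouse', 'default', 'secret', 'ck-host', 8123, 'demo'))
print(inspect(engine).get_schema_names())  # same inspector call used for schema discovery above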
diff --git a/application/utils/prompts/generate_prompt.py b/application/utils/prompts/generate_prompt.py index 9c46fa9..f81b7da 100644 --- a/application/utils/prompts/generate_prompt.py +++ b/application/utils/prompts/generate_prompt.py @@ -1,5 +1,5 @@ from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \ - DEFAULT_DIALECT_PROMPT, AGENT_COT_EXAMPLE, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3, STARROCKS_DIALECT_PROMPT_CLAUDE3 + DEFAULT_DIALECT_PROMPT, AGENT_COT_EXAMPLE, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3, STARROCKS_DIALECT_PROMPT_CLAUDE3, CLICKHOUSE_DIALECT_PROMPT_CLAUDE3 from utils.prompts import guidance_prompt from utils.prompts import table_prompt import logging @@ -1909,6 +1909,8 @@ def generate_llm_prompt(ddl, hints, prompt_map, search_box, sql_examples=None, n dialect_prompt = AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3 elif dialect == 'starrocks': dialect_prompt = STARROCKS_DIALECT_PROMPT_CLAUDE3 + elif dialect == 'clickhouse': + dialect_prompt = CLICKHOUSE_DIALECT_PROMPT_CLAUDE3 else: dialect_prompt = DEFAULT_DIALECT_PROMPT From f752da4ea0e0f7ce4910ce6bf2ff4f402e80ae6a Mon Sep 17 00:00:00 2001 From: Feng Xu Date: Sat, 20 Jul 2024 00:14:14 +0800 Subject: [PATCH 060/130] add clickhouse pkg in requirements --- application/requirements-api.txt | 3 ++- application/requirements.txt | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/application/requirements-api.txt b/application/requirements-api.txt index 0b25975..8ba8268 100644 --- a/application/requirements-api.txt +++ b/application/requirements-api.txt @@ -15,4 +15,5 @@ langchain-core~=0.1.30 sqlparse~=0.4.2 pandas==2.0.3 openpyxl -starrocks==1.0.6 \ No newline at end of file +starrocks==1.0.6 +clickhouse-sqlalchemy==0.2.6 \ No newline at end of file diff --git a/application/requirements.txt b/application/requirements.txt index 81cef76..e14ce18 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -14,4 +14,5 @@ sqlparse~=0.4.2 debugpy pandas==2.0.3 openpyxl -starrocks==1.0.6 \ No newline at end of file +starrocks==1.0.6 +clickhouse-sqlalchemy==0.2.6 \ No newline at end of file From 4c6d72a57e5b9be3affda6394def176b9b73d766 Mon Sep 17 00:00:00 2001 From: supinyu Date: Sat, 20 Jul 2024 15:55:32 +0800 Subject: [PATCH 061/130] add ecsSecurityGroup to rds and ecs --- source/resources/lib/ecs/ecs-stack.ts | 608 ++++++++++++++------------ source/resources/lib/main-stack.ts | 37 +- source/resources/lib/rds/rds-stack.ts | 1 + 3 files changed, 342 insertions(+), 304 deletions(-) diff --git a/source/resources/lib/ecs/ecs-stack.ts b/source/resources/lib/ecs/ecs-stack.ts index fc92348..789bedd 100644 --- a/source/resources/lib/ecs/ecs-stack.ts +++ b/source/resources/lib/ecs/ecs-stack.ts @@ -1,301 +1,333 @@ import * as cdk from 'aws-cdk-lib'; -import { Construct } from 'constructs'; +import {Construct} from 'constructs'; import * as ec2 from 'aws-cdk-lib/aws-ec2'; import * as ecs from 'aws-cdk-lib/aws-ecs'; import * as ecr from 'aws-cdk-lib/aws-ecr'; import * as iam from 'aws-cdk-lib/aws-iam'; -import { DockerImageAsset } from 'aws-cdk-lib/aws-ecr-assets'; +import {DockerImageAsset} from 'aws-cdk-lib/aws-ecr-assets'; import * as ecs_patterns from 'aws-cdk-lib/aws-ecs-patterns'; import * as path from 'path'; export class ECSStack extends cdk.Stack { - public readonly streamlitEndpoint: string; - public readonly frontendEndpoint: string; - public readonly apiEndpoint: string; -constructor(scope: Construct, id: string, props: cdk.StackProps - & { vpc: ec2.Vpc} - & { subnets: 
cdk.aws_ec2.ISubnet[] } & { cognitoUserPoolId: string} - & { authenticationType: string} - & { cognitoUserPoolClientId: string} & {OSMasterUserSecretName: string} - & {OSHostSecretName: string}) { - super(scope, id, props); - - // const isolatedSubnets = this._vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_ISOLATED }).subnets; - // const privateSubnets = this._vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS }).subnets; - - // const nonPublicSubnets = [...isolatedSubnets, ...privateSubnets]; - // const subnets = this._vpc.selectSubnets().subnets; - - // Create ECR repositories and Docker image assets - const services = [ - { name: 'genbi-streamlit', dockerfile: 'Dockerfile', port: 8501, dockerfileDirectory: path.join(__dirname, '../../../../application')}, - { name: 'genbi-api', dockerfile: 'Dockerfile-api', port: 8000, dockerfileDirectory: path.join(__dirname, '../../../../application')}, - { name: 'genbi-frontend', dockerfile: 'Dockerfile', port: 80, dockerfileDirectory: path.join(__dirname, '../../../../report-front-end')}, - ]; - - const awsRegion = props.env?.region as string; - - const GenBiStreamlitDockerImageAsset = {'dockerImageAsset': new DockerImageAsset(this, 'GenBiStreamlitDockerImage', { - directory: services[0].dockerfileDirectory, - file: services[0].dockerfile, - buildArgs: { - AWS_REGION: awsRegion, // Pass the AWS region as a build argument - }, - }), 'port': services[0].port}; - - const GenBiAPIDockerImageAsset = {'dockerImageAsset': new DockerImageAsset(this, 'GenBiAPIDockerImage', { - directory: services[1].dockerfileDirectory, - file: services[1].dockerfile, - buildArgs : { - AWS_REGION: awsRegion, // Pass the AWS region as a build argument + public readonly streamlitEndpoint: string; + public readonly frontendEndpoint: string; + public readonly apiEndpoint: string; + public readonly ecsSecurityGroup: ec2.SecurityGroup; + + constructor(scope: Construct, id: string, props: cdk.StackProps + & { vpc: ec2.Vpc } + & { subnets: cdk.aws_ec2.ISubnet[] } & { cognitoUserPoolId: string } + & { authenticationType: string } + & { cognitoUserPoolClientId: string } & { OSMasterUserSecretName: string } + & { OSHostSecretName: string }) { + super(scope, id, props); + + // const isolatedSubnets = this._vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_ISOLATED }).subnets; + // const privateSubnets = this._vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS }).subnets; + + // const nonPublicSubnets = [...isolatedSubnets, ...privateSubnets]; + // const subnets = this._vpc.selectSubnets().subnets; + + // Create ECR repositories and Docker image assets + const services = [ + { + name: 'genbi-streamlit', + dockerfile: 'Dockerfile', + port: 8501, + dockerfileDirectory: path.join(__dirname, '../../../../application') + }, + { + name: 'genbi-api', + dockerfile: 'Dockerfile-api', + port: 8000, + dockerfileDirectory: path.join(__dirname, '../../../../application') + }, + { + name: 'genbi-frontend', + dockerfile: 'Dockerfile', + port: 80, + dockerfileDirectory: path.join(__dirname, '../../../../report-front-end') + }, + ]; + + const awsRegion = props.env?.region as string; + + const GenBiStreamlitDockerImageAsset = { + 'dockerImageAsset': new DockerImageAsset(this, 'GenBiStreamlitDockerImage', { + directory: services[0].dockerfileDirectory, + file: services[0].dockerfile, + buildArgs: { + AWS_REGION: awsRegion, // Pass the AWS region as a build argument + }, + }), 'port': services[0].port + }; + + const GenBiAPIDockerImageAsset = { + 
'dockerImageAsset': new DockerImageAsset(this, 'GenBiAPIDockerImage', { + directory: services[1].dockerfileDirectory, + file: services[1].dockerfile, + buildArgs: { + AWS_REGION: awsRegion, // Pass the AWS region as a build argument + } + }), 'port': services[1].port + }; + + // Create an ECS cluster + const cluster = new ecs.Cluster(this, 'GenBiCluster', { + vpc: props.vpc, + }); + + const taskExecutionRole = new iam.Role(this, 'TaskExecutionRole', { + assumedBy: new iam.ServicePrincipal('ecs-tasks.amazonaws.com'), + }); + + const taskRole = new iam.Role(this, 'TaskRole', { + assumedBy: new iam.ServicePrincipal('ecs-tasks.amazonaws.com'), + }); + + // Add OpenSearch access policy + const openSearchAccessPolicy = new iam.PolicyStatement({ + actions: [ + "es:ESHttpGet", + "es:ESHttpHead", + "es:ESHttpPut", + "es:ESHttpPost", + "es:ESHttpDelete" + ], + resources: [ + `arn:${this.partition}:es:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:domain/*` + ] + }); + taskRole.addToPolicy(openSearchAccessPolicy); + + // Add DynamoDB access policy + const dynamoDBAccessPolicy = new iam.PolicyStatement({ + actions: [ + "dynamodb:*Table", + "dynamodb:*Item", + "dynamodb:Scan", + "dynamodb:Query" + ], + resources: [ + `arn:${this.partition}:dynamodb:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:table/*`, + ] + }); + taskRole.addToPolicy(dynamoDBAccessPolicy); + + // Add secrets manager access policy + const opensearchHostUrlSecretAccessPolicy = new iam.PolicyStatement({ + actions: [ + "secretsmanager:GetSecretValue" + ], + resources: [ + `arn:${this.partition}:secretsmanager:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:secret:opensearch-host-url*`, + `arn:${this.partition}:secretsmanager:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:secret:opensearch-master-user*` + ] + }); + taskRole.addToPolicy(opensearchHostUrlSecretAccessPolicy); + + // Add Bedrock access policy + if (props.env?.region !== "cn-north-1" && props.env?.region !== "cn-northwest-1") { + const bedrockAccessPolicy = new iam.PolicyStatement({ + actions: [ + "bedrock:InvokeModel", + "bedrock:InvokeModelWithResponseStream" + ], + resources: [ + `arn:${this.partition}:bedrock:${cdk.Aws.REGION}::foundation-model/*` + ] + }); + taskRole.addToPolicy(bedrockAccessPolicy); } - }), 'port': services[1].port}; - - // Create an ECS cluster - const cluster = new ecs.Cluster(this, 'GenBiCluster', { - vpc: props.vpc, - }); - - const taskExecutionRole = new iam.Role(this, 'TaskExecutionRole', { - assumedBy: new iam.ServicePrincipal('ecs-tasks.amazonaws.com'), - }); - - const taskRole = new iam.Role(this, 'TaskRole', { - assumedBy: new iam.ServicePrincipal('ecs-tasks.amazonaws.com'), - }); - - // Add OpenSearch access policy - const openSearchAccessPolicy = new iam.PolicyStatement({ - actions: [ - "es:ESHttpGet", - "es:ESHttpHead", - "es:ESHttpPut", - "es:ESHttpPost", - "es:ESHttpDelete" - ], - resources: [ - `arn:${this.partition}:es:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:domain/*` - ] - }); - taskRole.addToPolicy(openSearchAccessPolicy); - - // Add DynamoDB access policy - const dynamoDBAccessPolicy = new iam.PolicyStatement({ - actions: [ - "dynamodb:*Table", - "dynamodb:*Item", - "dynamodb:Scan", - "dynamodb:Query" - ], - resources: [ - `arn:${this.partition}:dynamodb:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:table/*`, - ] - }); - taskRole.addToPolicy(dynamoDBAccessPolicy); - - // Add secrets manager access policy - const opensearchHostUrlSecretAccessPolicy = new iam.PolicyStatement({ - actions: [ - "secretsmanager:GetSecretValue" - ], - resources: [ - 
`arn:${this.partition}:secretsmanager:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:secret:opensearch-host-url*`, - `arn:${this.partition}:secretsmanager:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:secret:opensearch-master-user*` - ] - }); - taskRole.addToPolicy(opensearchHostUrlSecretAccessPolicy); - - // Add Bedrock access policy - if (props.env?.region !== "cn-north-1" && props.env?.region !== "cn-northwest-1") { - const bedrockAccessPolicy = new iam.PolicyStatement({ - actions: [ - "bedrock:InvokeModel", - "bedrock:InvokeModelWithResponseStream" - ], - resources: [ - `arn:${this.partition}:bedrock:${cdk.Aws.REGION}::foundation-model/*` - ] - }); - taskRole.addToPolicy(bedrockAccessPolicy); - } - // Add SageMaker endpoint access policy - const sageMakerEndpointAccessPolicy = new iam.PolicyStatement({ - actions: [ - "sagemaker:InvokeEndpoint", - "sagemaker:DescribeEndpoint", - "sagemaker:ListEndpoints" - ], - resources: [ - `arn:${this.partition}:sagemaker:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:endpoint/*` - ] - }); - taskRole.addToPolicy(sageMakerEndpointAccessPolicy); - - - // Add Cognito all access policy - if (props.env?.region !== "cn-north-1" && props.env?.region !== "cn-northwest-1") { - const cognitoAccessPolicy = new iam.PolicyStatement({ - actions: [ - "cognito-identity:*", - "cognito-idp:*" - ], - resources: [ - `arn:${this.partition}:cognito-idp:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:userpool/*`, - `arn:${this.partition}:cognito-identity:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:identitypool/*` - ] - }); - taskRole.addToPolicy(cognitoAccessPolicy); - } + // Add SageMaker endpoint access policy + const sageMakerEndpointAccessPolicy = new iam.PolicyStatement({ + actions: [ + "sagemaker:InvokeEndpoint", + "sagemaker:DescribeEndpoint", + "sagemaker:ListEndpoints" + ], + resources: [ + `arn:${this.partition}:sagemaker:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:endpoint/*` + ] + }); + taskRole.addToPolicy(sageMakerEndpointAccessPolicy); + + + // Add Cognito all access policy + if (props.env?.region !== "cn-north-1" && props.env?.region !== "cn-northwest-1") { + const cognitoAccessPolicy = new iam.PolicyStatement({ + actions: [ + "cognito-identity:*", + "cognito-idp:*" + ], + resources: [ + `arn:${this.partition}:cognito-idp:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:userpool/*`, + `arn:${this.partition}:cognito-identity:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:identitypool/*` + ] + }); + taskRole.addToPolicy(cognitoAccessPolicy); + } - // Create ECS services through Fargate - // ======= 1. 
Streamlit Service ======= - const taskDefinitionStreamlit = new ecs.FargateTaskDefinition(this, 'GenBiTaskDefinitionStreamlit', { - memoryLimitMiB: 512, - cpu: 256, - executionRole: taskExecutionRole, - taskRole: taskRole - }); - - const containerStreamlit = taskDefinitionStreamlit.addContainer('GenBiContainerStreamlit', { - image: ecs.ContainerImage.fromDockerImageAsset(GenBiStreamlitDockerImageAsset.dockerImageAsset), - memoryLimitMiB: 512, - cpu: 256, - logging: new ecs.AwsLogDriver({ - streamPrefix: 'GenBiStreamlit', - }), - }); - - containerStreamlit.addEnvironment('OPENSEARCH_TYPE', 'service'); - containerStreamlit.addEnvironment('AOS_INDEX', 'uba'); - containerStreamlit.addEnvironment('AOS_INDEX_NER', 'uba_ner'); - containerStreamlit.addEnvironment('AOS_INDEX_AGENT', 'uba_agent'); - containerStreamlit.addEnvironment('BEDROCK_REGION', cdk.Aws.REGION); - containerStreamlit.addEnvironment('RDS_REGION_NAME', cdk.Aws.REGION); - containerStreamlit.addEnvironment('AWS_DEFAULT_REGION', cdk.Aws.REGION); - containerStreamlit.addEnvironment('DYNAMODB_AWS_REGION', cdk.Aws.REGION); - containerStreamlit.addEnvironment('OPENSEARCH_SECRETS_URL_HOST', props.OSHostSecretName) - containerStreamlit.addEnvironment('OPENSEARCH_SECRETS_USERNAME_PASSWORD', props.OSMasterUserSecretName) - containerStreamlit.addPortMappings({ - containerPort: GenBiStreamlitDockerImageAsset.port, - }); - - const fargateServiceStreamlit = new ecs_patterns.ApplicationLoadBalancedFargateService(this, 'GenBiFargateServiceStreamlit', { - cluster: cluster, - taskDefinition: taskDefinitionStreamlit, - publicLoadBalancer: true, - taskSubnets: { subnets: props.subnets }, - assignPublicIp: true - }); - - // ======= 2. API Service ======= - const taskDefinitionAPI = new ecs.FargateTaskDefinition(this, 'GenBiTaskDefinitionAPI', { - memoryLimitMiB: 512, - cpu: 256, - executionRole: taskExecutionRole, - taskRole: taskRole - }); - - const containerAPI = taskDefinitionAPI.addContainer('GenBiContainerAPI', { - image: ecs.ContainerImage.fromDockerImageAsset(GenBiAPIDockerImageAsset.dockerImageAsset), - memoryLimitMiB: 512, - cpu: 256, - logging: new ecs.AwsLogDriver({ - streamPrefix: 'GenBiAPI', - }), - }); - - containerAPI.addEnvironment('OPENSEARCH_TYPE', 'service'); - containerAPI.addEnvironment('AOS_INDEX', 'uba'); - containerAPI.addEnvironment('AOS_INDEX_NER', 'uba_ner'); - containerAPI.addEnvironment('AOS_INDEX_AGENT', 'uba_agent'); - containerAPI.addEnvironment('BEDROCK_REGION', cdk.Aws.REGION); - containerAPI.addEnvironment('RDS_REGION_NAME', cdk.Aws.REGION); - containerAPI.addEnvironment('AWS_DEFAULT_REGION', cdk.Aws.REGION); - containerAPI.addEnvironment('DYNAMODB_AWS_REGION', cdk.Aws.REGION); - containerAPI.addEnvironment('OPENSEARCH_SECRETS_URL_HOST', props.OSHostSecretName) - containerAPI.addEnvironment('OPENSEARCH_SECRETS_USERNAME_PASSWORD', props.OSMasterUserSecretName) - - containerAPI.addPortMappings({ - containerPort: GenBiAPIDockerImageAsset.port, - }); - - const fargateServiceAPI = new ecs_patterns.ApplicationLoadBalancedFargateService(this, 'GenBiFargateServiceAPI', { - cluster: cluster, - taskDefinition: taskDefinitionAPI, - publicLoadBalancer: true, - taskSubnets: { subnets: props.subnets }, - assignPublicIp: true - }); - - // ======= 3. 
Frontend Service ======= - const GenBiFrontendDockerImageAsset = {'dockerImageAsset': new DockerImageAsset(this, 'GenBiFrontendDockerImage', { - directory: services[2].dockerfileDirectory, - file: services[2].dockerfile, - buildArgs : { - AWS_REGION: awsRegion, // Pass the AWS region as a build argument - } - }), 'port': services[2].port}; - - const taskDefinitionFrontend = new ecs.FargateTaskDefinition(this, 'GenBiTaskDefinitionFrontend', { - memoryLimitMiB: 512, - cpu: 256, - executionRole: taskExecutionRole, - taskRole: taskRole - }); - - const containerFrontend = taskDefinitionFrontend.addContainer('GenBiContainerFrontend', { - image: ecs.ContainerImage.fromDockerImageAsset(GenBiFrontendDockerImageAsset.dockerImageAsset), - memoryLimitMiB: 512, - cpu: 256, - logging: new ecs.AwsLogDriver({ - streamPrefix: 'GenBiFrontend', - }), - }); - - containerFrontend.addEnvironment('VITE_TITLE', 'Guidance for Generative BI') - containerFrontend.addEnvironment('VITE_LOGO', '/logo.png'); - containerFrontend.addEnvironment('VITE_RIGHT_LOGO', ''); - containerFrontend.addEnvironment('VITE_LOGIN_TYPE', props.authenticationType); - containerFrontend.addEnvironment('VITE_COGNITO_REGION', cdk.Aws.REGION); - containerFrontend.addEnvironment('VITE_COGNITO_USER_POOL_ID', props.cognitoUserPoolId); - containerFrontend.addEnvironment('VITE_COGNITO_USER_POOL_WEB_CLIENT_ID', props.cognitoUserPoolClientId); - containerFrontend.addEnvironment('VITE_COGNITO_IDENTITY_POOL_ID', ''); - containerFrontend.addEnvironment('VITE_SQL_DISPLAY', 'yes'); - containerFrontend.addEnvironment('VITE_BACKEND_URL', `http://${fargateServiceAPI.loadBalancer.loadBalancerDnsName}/`); - containerFrontend.addEnvironment('VITE_WEBSOCKET_URL', `ws://${fargateServiceAPI.loadBalancer.loadBalancerDnsName}/qa/ws`); - containerFrontend.addEnvironment('VITE_LOGIN_TYPE', 'Cognito'); - - containerFrontend.addPortMappings({ - containerPort: GenBiFrontendDockerImageAsset.port, - }); - - const fargateServiceFrontend = new ecs_patterns.ApplicationLoadBalancedFargateService(this, 'GenBiFargateServiceFrontend', { - cluster: cluster, - taskDefinition: taskDefinitionFrontend, - publicLoadBalancer: true, - // taskSubnets: { subnetType: ec2.SubnetType.PUBLIC }, - taskSubnets: { subnets: props.subnets }, - assignPublicIp: true - }); - - this.streamlitEndpoint = fargateServiceStreamlit.loadBalancer.loadBalancerDnsName; - this.apiEndpoint = fargateServiceAPI.loadBalancer.loadBalancerDnsName; - this.frontendEndpoint = fargateServiceFrontend.loadBalancer.loadBalancerDnsName; - - new cdk.CfnOutput(this, 'StreamlitEndpoint', { - value: fargateServiceStreamlit.loadBalancer.loadBalancerDnsName, - description: 'The endpoint of the Streamlit service' - }); - - new cdk.CfnOutput(this, 'APIEndpoint', { - value: fargateServiceAPI.loadBalancer.loadBalancerDnsName, - description: 'The endpoint of the API service' - }); - - new cdk.CfnOutput(this, 'FrontendEndpoint', { - value: fargateServiceFrontend.loadBalancer.loadBalancerDnsName, - description: 'The endpoint of the Frontend service' - }); - } + // Create ECS services through Fargate + // ======= 1. 
Streamlit Service ======= + const taskDefinitionStreamlit = new ecs.FargateTaskDefinition(this, 'GenBiTaskDefinitionStreamlit', { + memoryLimitMiB: 512, + cpu: 256, + executionRole: taskExecutionRole, + taskRole: taskRole + }); + + const containerStreamlit = taskDefinitionStreamlit.addContainer('GenBiContainerStreamlit', { + image: ecs.ContainerImage.fromDockerImageAsset(GenBiStreamlitDockerImageAsset.dockerImageAsset), + memoryLimitMiB: 512, + cpu: 256, + logging: new ecs.AwsLogDriver({ + streamPrefix: 'GenBiStreamlit', + }), + }); + + containerStreamlit.addEnvironment('OPENSEARCH_TYPE', 'service'); + containerStreamlit.addEnvironment('AOS_INDEX', 'uba'); + containerStreamlit.addEnvironment('AOS_INDEX_NER', 'uba_ner'); + containerStreamlit.addEnvironment('AOS_INDEX_AGENT', 'uba_agent'); + containerStreamlit.addEnvironment('BEDROCK_REGION', cdk.Aws.REGION); + containerStreamlit.addEnvironment('RDS_REGION_NAME', cdk.Aws.REGION); + containerStreamlit.addEnvironment('AWS_DEFAULT_REGION', cdk.Aws.REGION); + containerStreamlit.addEnvironment('DYNAMODB_AWS_REGION', cdk.Aws.REGION); + containerStreamlit.addEnvironment('OPENSEARCH_SECRETS_URL_HOST', props.OSHostSecretName) + containerStreamlit.addEnvironment('OPENSEARCH_SECRETS_USERNAME_PASSWORD', props.OSMasterUserSecretName) + containerStreamlit.addPortMappings({ + containerPort: GenBiStreamlitDockerImageAsset.port, + }); + + this.ecsSecurityGroup = new ec2.SecurityGroup(this, 'GenBIECSSecurityGroup', { + vpc: props.vpc, + allowAllOutbound: true, + description: 'Security group for ECS tasks', + }); + + const fargateServiceStreamlit = new ecs_patterns.ApplicationLoadBalancedFargateService(this, 'GenBiFargateServiceStreamlit', { + cluster: cluster, + taskDefinition: taskDefinitionStreamlit, + publicLoadBalancer: true, + taskSubnets: {subnets: props.subnets}, + assignPublicIp: true, + securityGroups: [this.ecsSecurityGroup], + }); + + // ======= 2. API Service ======= + const taskDefinitionAPI = new ecs.FargateTaskDefinition(this, 'GenBiTaskDefinitionAPI', { + memoryLimitMiB: 512, + cpu: 256, + executionRole: taskExecutionRole, + taskRole: taskRole + }); + + const containerAPI = taskDefinitionAPI.addContainer('GenBiContainerAPI', { + image: ecs.ContainerImage.fromDockerImageAsset(GenBiAPIDockerImageAsset.dockerImageAsset), + memoryLimitMiB: 512, + cpu: 256, + logging: new ecs.AwsLogDriver({ + streamPrefix: 'GenBiAPI', + }), + }); + + containerAPI.addEnvironment('OPENSEARCH_TYPE', 'service'); + containerAPI.addEnvironment('AOS_INDEX', 'uba'); + containerAPI.addEnvironment('AOS_INDEX_NER', 'uba_ner'); + containerAPI.addEnvironment('AOS_INDEX_AGENT', 'uba_agent'); + containerAPI.addEnvironment('BEDROCK_REGION', cdk.Aws.REGION); + containerAPI.addEnvironment('RDS_REGION_NAME', cdk.Aws.REGION); + containerAPI.addEnvironment('AWS_DEFAULT_REGION', cdk.Aws.REGION); + containerAPI.addEnvironment('DYNAMODB_AWS_REGION', cdk.Aws.REGION); + containerAPI.addEnvironment('OPENSEARCH_SECRETS_URL_HOST', props.OSHostSecretName) + containerAPI.addEnvironment('OPENSEARCH_SECRETS_USERNAME_PASSWORD', props.OSMasterUserSecretName) + + containerAPI.addPortMappings({ + containerPort: GenBiAPIDockerImageAsset.port, + }); + + const fargateServiceAPI = new ecs_patterns.ApplicationLoadBalancedFargateService(this, 'GenBiFargateServiceAPI', { + cluster: cluster, + taskDefinition: taskDefinitionAPI, + publicLoadBalancer: true, + taskSubnets: {subnets: props.subnets}, + assignPublicIp: true, + securityGroups: [this.ecsSecurityGroup], + }); + + // ======= 3. 
Frontend Service ======= + const GenBiFrontendDockerImageAsset = { + 'dockerImageAsset': new DockerImageAsset(this, 'GenBiFrontendDockerImage', { + directory: services[2].dockerfileDirectory, + file: services[2].dockerfile, + buildArgs: { + AWS_REGION: awsRegion, // Pass the AWS region as a build argument + } + }), 'port': services[2].port + }; + + const taskDefinitionFrontend = new ecs.FargateTaskDefinition(this, 'GenBiTaskDefinitionFrontend', { + memoryLimitMiB: 512, + cpu: 256, + executionRole: taskExecutionRole, + taskRole: taskRole + }); + + const containerFrontend = taskDefinitionFrontend.addContainer('GenBiContainerFrontend', { + image: ecs.ContainerImage.fromDockerImageAsset(GenBiFrontendDockerImageAsset.dockerImageAsset), + memoryLimitMiB: 512, + cpu: 256, + logging: new ecs.AwsLogDriver({ + streamPrefix: 'GenBiFrontend', + }), + }); + + containerFrontend.addEnvironment('VITE_TITLE', 'Guidance for Generative BI') + containerFrontend.addEnvironment('VITE_LOGO', '/logo.png'); + containerFrontend.addEnvironment('VITE_RIGHT_LOGO', ''); + containerFrontend.addEnvironment('VITE_LOGIN_TYPE', props.authenticationType); + containerFrontend.addEnvironment('VITE_COGNITO_REGION', cdk.Aws.REGION); + containerFrontend.addEnvironment('VITE_COGNITO_USER_POOL_ID', props.cognitoUserPoolId); + containerFrontend.addEnvironment('VITE_COGNITO_USER_POOL_WEB_CLIENT_ID', props.cognitoUserPoolClientId); + containerFrontend.addEnvironment('VITE_COGNITO_IDENTITY_POOL_ID', ''); + containerFrontend.addEnvironment('VITE_SQL_DISPLAY', 'yes'); + containerFrontend.addEnvironment('VITE_BACKEND_URL', `http://${fargateServiceAPI.loadBalancer.loadBalancerDnsName}/`); + containerFrontend.addEnvironment('VITE_WEBSOCKET_URL', `ws://${fargateServiceAPI.loadBalancer.loadBalancerDnsName}/qa/ws`); + containerFrontend.addEnvironment('VITE_LOGIN_TYPE', 'Cognito'); + + containerFrontend.addPortMappings({ + containerPort: GenBiFrontendDockerImageAsset.port, + }); + + const fargateServiceFrontend = new ecs_patterns.ApplicationLoadBalancedFargateService(this, 'GenBiFargateServiceFrontend', { + cluster: cluster, + taskDefinition: taskDefinitionFrontend, + publicLoadBalancer: true, + // taskSubnets: { subnetType: ec2.SubnetType.PUBLIC }, + taskSubnets: {subnets: props.subnets}, + assignPublicIp: true, + securityGroups: [this.ecsSecurityGroup], + }); + + this.streamlitEndpoint = fargateServiceStreamlit.loadBalancer.loadBalancerDnsName; + this.apiEndpoint = fargateServiceAPI.loadBalancer.loadBalancerDnsName; + this.frontendEndpoint = fargateServiceFrontend.loadBalancer.loadBalancerDnsName; + + new cdk.CfnOutput(this, 'StreamlitEndpoint', { + value: fargateServiceStreamlit.loadBalancer.loadBalancerDnsName, + description: 'The endpoint of the Streamlit service' + }); + + new cdk.CfnOutput(this, 'APIEndpoint', { + value: fargateServiceAPI.loadBalancer.loadBalancerDnsName, + description: 'The endpoint of the API service' + }); + + new cdk.CfnOutput(this, 'FrontendEndpoint', { + value: fargateServiceFrontend.loadBalancer.loadBalancerDnsName, + description: 'The endpoint of the Frontend service' + }); + } } \ No newline at end of file diff --git a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index 2468c31..472895c 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -55,21 +55,6 @@ export class MainStack extends cdk.Stack { const aosEndpoint = _AosStack.endpoint; - // ======== Step 3. 
Define the RDSStack ========= - if (_deployRds) { - const rdsSubnets = _VpcStack.vpc.selectSubnets({subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS}); - - const _RdsStack = new RDSStack(this, 'rds-Stack', { - env: props.env, - subnets: rdsSubnets, - vpc: _VpcStack.vpc - }); - new cdk.CfnOutput(this, 'RDSEndpoint', { - value: _RdsStack.endpoint, - description: 'The endpoint of the RDS instance', - }); - } - // ======== Step 4. Define Cognito ========= const isChinaRegion = props.env?.region === "cn-north-1" || props.env?.region === "cn-northwest-1"; @@ -112,7 +97,6 @@ export class MainStack extends cdk.Stack { }) ; - _AosStack.addDependency(_VpcStack); _EcsStack.addDependency(_AosStack); if (_CognitoStack) { @@ -120,6 +104,27 @@ export class MainStack extends cdk.Stack { } _EcsStack.addDependency(_VpcStack); + // ======== Step 3. Define the RDSStack ========= + if (_deployRds) { + const rdsSubnets = _VpcStack.vpc.selectSubnets({subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS}); + + const _RdsStack = new RDSStack(this, 'rds-Stack', { + env: props.env, + subnets: rdsSubnets, + vpc: _VpcStack.vpc + }); + + _RdsStack.rdsSecurityGroup.addIngressRule( + _EcsStack.ecsSecurityGroup, + ec2.Port.tcp(3306), + 'Allow inbound traffic from ECS on port 3306' + ); + new cdk.CfnOutput(this, 'RDSEndpoint', { + value: _RdsStack.endpoint, + description: 'The endpoint of the RDS instance', + }); + } + new cdk.CfnOutput(this, 'AOSDomainEndpoint', { value: aosEndpoint, description: 'The endpoint of the OpenSearch domain' diff --git a/source/resources/lib/rds/rds-stack.ts b/source/resources/lib/rds/rds-stack.ts index 19ff6d7..a912f86 100644 --- a/source/resources/lib/rds/rds-stack.ts +++ b/source/resources/lib/rds/rds-stack.ts @@ -12,6 +12,7 @@ interface RDSStackProps extends cdk.StackProps { // add rds stack export class RDSStack extends cdk.Stack { public readonly endpoint: string; + public readonly rdsSecurityGroup: ec2.SecurityGroup; constructor(scope: Construct, id: string, props: RDSStackProps) { super(scope, id, props); From 60b8ed7e7c28568133e18871e9ead9b46b9897c5 Mon Sep 17 00:00:00 2001 From: supinyu Date: Sun, 21 Jul 2024 20:19:52 +0800 Subject: [PATCH 062/130] change ecs input --- source/resources/lib/ecs/ecs-stack.ts | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/source/resources/lib/ecs/ecs-stack.ts b/source/resources/lib/ecs/ecs-stack.ts index 789bedd..c95ae03 100644 --- a/source/resources/lib/ecs/ecs-stack.ts +++ b/source/resources/lib/ecs/ecs-stack.ts @@ -8,18 +8,23 @@ import {DockerImageAsset} from 'aws-cdk-lib/aws-ecr-assets'; import * as ecs_patterns from 'aws-cdk-lib/aws-ecs-patterns'; import * as path from 'path'; +interface ECSStackProps extends cdk.StackProps { + vpc: ec2.Vpc; + subnets: ec2.ISubnet[]; + cognitoUserPoolId: string; + authenticationType: string; + cognitoUserPoolClientId: string; + OSMasterUserSecretName: string; + OSHostSecretName: string; +} + export class ECSStack extends cdk.Stack { public readonly streamlitEndpoint: string; public readonly frontendEndpoint: string; public readonly apiEndpoint: string; public readonly ecsSecurityGroup: ec2.SecurityGroup; - constructor(scope: Construct, id: string, props: cdk.StackProps - & { vpc: ec2.Vpc } - & { subnets: cdk.aws_ec2.ISubnet[] } & { cognitoUserPoolId: string } - & { authenticationType: string } - & { cognitoUserPoolClientId: string } & { OSMasterUserSecretName: string } - & { OSHostSecretName: string }) { + constructor(scope: Construct, id: string, props: ECSStackProps) { super(scope, 
id, props); // const isolatedSubnets = this._vpc.selectSubnets({ subnetType: ec2.SubnetType.PRIVATE_ISOLATED }).subnets; From 8166c7c99c56ad9230591ff9d225cfd5f41de4dc Mon Sep 17 00:00:00 2001 From: supinyu Date: Sun, 21 Jul 2024 22:04:45 +0800 Subject: [PATCH 063/130] fix some rdsSecurityGroup --- source/resources/lib/main-stack.ts | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index 472895c..f8cda00 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -114,11 +114,13 @@ export class MainStack extends cdk.Stack { vpc: _VpcStack.vpc }); - _RdsStack.rdsSecurityGroup.addIngressRule( - _EcsStack.ecsSecurityGroup, - ec2.Port.tcp(3306), - 'Allow inbound traffic from ECS on port 3306' - ); + if (_RdsStack.rdsSecurityGroup && _EcsStack.ecsSecurityGroup) { + _RdsStack.rdsSecurityGroup.addIngressRule( + _EcsStack.ecsSecurityGroup, + ec2.Port.tcp(3306), + 'Allow inbound traffic from ECS on port 3306' + ); + } new cdk.CfnOutput(this, 'RDSEndpoint', { value: _RdsStack.endpoint, description: 'The endpoint of the RDS instance', From 6df4c57a15242439cccd5888c2c00de9aff90efb Mon Sep 17 00:00:00 2001 From: supinyu Date: Mon, 22 Jul 2024 09:53:40 +0800 Subject: [PATCH 064/130] add rds addDependency --- source/resources/lib/main-stack.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index f8cda00..f3602a8 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -114,6 +114,8 @@ export class MainStack extends cdk.Stack { vpc: _VpcStack.vpc }); + _RdsStack.addDependency(_EcsStack); + if (_RdsStack.rdsSecurityGroup && _EcsStack.ecsSecurityGroup) { _RdsStack.rdsSecurityGroup.addIngressRule( _EcsStack.ecsSecurityGroup, From 4da52383577391f4b1ce75d8d790d7458f37557e Mon Sep 17 00:00:00 2001 From: supinyu Date: Mon, 22 Jul 2024 10:18:19 +0800 Subject: [PATCH 065/130] add segamaker region --- source/resources/lib/ecs/ecs-stack.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/source/resources/lib/ecs/ecs-stack.ts b/source/resources/lib/ecs/ecs-stack.ts index c95ae03..eaae0f7 100644 --- a/source/resources/lib/ecs/ecs-stack.ts +++ b/source/resources/lib/ecs/ecs-stack.ts @@ -196,6 +196,8 @@ export class ECSStack extends cdk.Stack { containerStreamlit.addEnvironment('AOS_INDEX', 'uba'); containerStreamlit.addEnvironment('AOS_INDEX_NER', 'uba_ner'); containerStreamlit.addEnvironment('AOS_INDEX_AGENT', 'uba_agent'); + // containerStreamlit.addEnvironment('SAGEMAKER_EMBEDDING_REGION', cdk.Aws.REGION); + // containerStreamlit.addEnvironment('SAGEMAKER_SQL_REGION', cdk.Aws.REGION); containerStreamlit.addEnvironment('BEDROCK_REGION', cdk.Aws.REGION); containerStreamlit.addEnvironment('RDS_REGION_NAME', cdk.Aws.REGION); containerStreamlit.addEnvironment('AWS_DEFAULT_REGION', cdk.Aws.REGION); @@ -242,6 +244,8 @@ export class ECSStack extends cdk.Stack { containerAPI.addEnvironment('AOS_INDEX', 'uba'); containerAPI.addEnvironment('AOS_INDEX_NER', 'uba_ner'); containerAPI.addEnvironment('AOS_INDEX_AGENT', 'uba_agent'); + // containerAPI.addEnvironment('SAGEMAKER_EMBEDDING_REGION', cdk.Aws.REGION); + // containerAPI.addEnvironment('SAGEMAKER_SQL_REGION', cdk.Aws.REGION); containerAPI.addEnvironment('BEDROCK_REGION', cdk.Aws.REGION); containerAPI.addEnvironment('RDS_REGION_NAME', cdk.Aws.REGION); containerAPI.addEnvironment('AWS_DEFAULT_REGION', cdk.Aws.REGION); From 
b5655f3d2e69c6cb2b9eb685aea88f33d84f306c Mon Sep 17 00:00:00 2001 From: supinyu Date: Tue, 23 Jul 2024 16:13:56 +0800 Subject: [PATCH 066/130] add segamaker region --- application/requirements-api.txt | 3 ++- application/requirements.txt | 3 ++- source/resources/bin/main.ts | 14 +++++++++++++- source/resources/cdk-config.json | 6 ++++-- 4 files changed, 21 insertions(+), 5 deletions(-) diff --git a/application/requirements-api.txt b/application/requirements-api.txt index 0b25975..8e70cf7 100644 --- a/application/requirements-api.txt +++ b/application/requirements-api.txt @@ -15,4 +15,5 @@ langchain-core~=0.1.30 sqlparse~=0.4.2 pandas==2.0.3 openpyxl -starrocks==1.0.6 \ No newline at end of file +starrocks==1.0.6 +sagemaker \ No newline at end of file diff --git a/application/requirements.txt b/application/requirements.txt index 81cef76..38f5798 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -14,4 +14,5 @@ sqlparse~=0.4.2 debugpy pandas==2.0.3 openpyxl -starrocks==1.0.6 \ No newline at end of file +starrocks==1.0.6 +sagemaker \ No newline at end of file diff --git a/source/resources/bin/main.ts b/source/resources/bin/main.ts index 20671f0..d36b6cf 100644 --- a/source/resources/bin/main.ts +++ b/source/resources/bin/main.ts @@ -17,9 +17,21 @@ const app = new cdk.App(); const rds = config.rds +const embedding = config.embedding + +const opensearch = config.opensearch + +const vpc = config.vpc + const cdkConfig = { env: devEnv, - deployRds: rds.deploy + deployRds: rds.deploy, + bedrock_embedding_name: embedding.bedrock_embedding_name, + embedding_dimension: embedding.embedding_dimension, + opensearch_sql_index : opensearch.sql_index, + opensearch_ner_index : opensearch.ner_index, + opensearch_cot_index : opensearch.cot_index, + vpc_id : vpc.id }; new MainStack(app, 'GenBiMainStack', cdkConfig); // Pass deployRDS flag to MainStack constructor diff --git a/source/resources/cdk-config.json b/source/resources/cdk-config.json index 9053f52..ae4b5bb 100644 --- a/source/resources/cdk-config.json +++ b/source/resources/cdk-config.json @@ -1,11 +1,13 @@ { + "vpc" : { + "vpc_id" : "" + }, "rds": { "deploy": false }, "embedding": { "bedrock_embedding_name": "amazon.titan-embed-text-v1", - "embedding_dimension": 1536, - "segamaker_embedding_name" : "" + "embedding_dimension": 1536 }, "segamaker": { "endpoint_name" : "" From a4c0d467aae0675dd929146983a37f4a8f9915d5 Mon Sep 17 00:00:00 2001 From: supinyu Date: Tue, 23 Jul 2024 18:43:34 +0800 Subject: [PATCH 067/130] fix some bug --- .../pages/1_\360\237\214\215_Generative_BI_Playground.py" | 6 ++++++ 1 file changed, 6 insertions(+) diff --git "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" index c3cbd55..dfdc5af 100644 --- "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" +++ "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" @@ -287,6 +287,12 @@ def main(): if selected_profile not in st.session_state.messages: st.session_state.messages[selected_profile] = [] st.session_state.nlq_chain = NLQChain(selected_profile) + else: + if selected_profile not in st.session_state.messages: + st.session_state.messages[selected_profile] = [] + if selected_profile not in st.session_state.query_rewrite_history: + st.session_state.query_rewrite_history[selected_profile] = [] + st.session_state.nlq_chain = NLQChain(selected_profile) if st.session_state.current_model_id != "" and st.session_state.current_model_id in 
model_ids: model_index = model_ids.index(st.session_state.current_model_id) From 507d516a7e905c21e7e5397372a6330cbb830cb5 Mon Sep 17 00:00:00 2001 From: supinyu Date: Tue, 23 Jul 2024 18:47:43 +0800 Subject: [PATCH 068/130] fix segamaker error --- application/utils/llm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/application/utils/llm.py b/application/utils/llm.py index f46b3d3..3048f21 100644 --- a/application/utils/llm.py +++ b/application/utils/llm.py @@ -16,7 +16,7 @@ generate_query_rewrite_prompt from utils.env_var import bedrock_ak_sk_info, BEDROCK_REGION, BEDROCK_EMBEDDING_MODEL, SAGEMAKER_EMBEDDING_REGION, \ - SAGEMAKER_SQL_REGION, SAGEMAKER_ENDPOINT_EMBEDDING + SAGEMAKER_SQL_REGION, SAGEMAKER_ENDPOINT_EMBEDDING, SAGEMAKER_ENDPOINT_SQL logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -331,8 +331,8 @@ def invoke_llm_model(model_id, system_prompt, user_prompt, max_tokens=2048, with if model_id.startswith('anthropic.claude-3'): response = invoke_model_claude3(model_id, system_prompt, messages, max_tokens, with_response_stream) elif model_id.startswith('mistral.mixtral-8x7b'): - if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "": - response = invoke_mixtral_8x7b_sagemaker(model_id, system_prompt, messages, max_tokens, with_response_stream) + if SAGEMAKER_ENDPOINT_SQL is not None and SAGEMAKER_ENDPOINT_SQL != "": + response = invoke_mixtral_8x7b_sagemaker(SAGEMAKER_ENDPOINT_SQL, system_prompt, messages, max_tokens, with_response_stream) else: response = invoke_mixtral_8x7b(model_id, system_prompt, messages, max_tokens, with_response_stream) elif model_id.startswith('meta.llama3-70b'): @@ -343,7 +343,7 @@ def invoke_llm_model(model_id, system_prompt, user_prompt, max_tokens=2048, with if model_id.startswith('meta.llama3-70b'): return response["generation"] elif model_id.startswith('mistral.mixtral'): - if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "": + if SAGEMAKER_ENDPOINT_SQL is not None and SAGEMAKER_ENDPOINT_SQL != "": response = json.loads(response) response = response['generated_text'] response = response.replace("\\", "") From 63260c153c534ee390688e338ab1eb2e4ae51e11 Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 24 Jul 2024 09:43:12 +0800 Subject: [PATCH 069/130] remove some code --- application/utils/llm.py | 2 - source/resources/lib/vpc/vpc-stack.ts | 103 ++++++++++++++------------ 2 files changed, 57 insertions(+), 48 deletions(-) diff --git a/application/utils/llm.py b/application/utils/llm.py index 50b94ea..c86664a 100644 --- a/application/utils/llm.py +++ b/application/utils/llm.py @@ -5,8 +5,6 @@ from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \ DEFAULT_DIALECT_PROMPT, SEARCH_INTENT_PROMPT_CLAUDE3, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3 import os -import sagemaker -from sagemaker import Model, image_uris, serializers, deserializers import logging from langchain_core.output_parsers import JsonOutputParser from utils.prompts.generate_prompt import generate_llm_prompt, generate_sagemaker_intent_prompt, \ diff --git a/source/resources/lib/vpc/vpc-stack.ts b/source/resources/lib/vpc/vpc-stack.ts index 6e99115..f04c79c 100644 --- a/source/resources/lib/vpc/vpc-stack.ts +++ b/source/resources/lib/vpc/vpc-stack.ts @@ -1,53 +1,64 @@ import * as cdk from 'aws-cdk-lib'; -import { Construct } from 'constructs'; +import {Construct} from 'constructs'; import * as ec2 
from 'aws-cdk-lib/aws-ec2'; -import * as ecs from 'aws-cdk-lib/aws-ecs'; -import * as ecr from 'aws-cdk-lib/aws-ecr'; -import * as iam from 'aws-cdk-lib/aws-iam'; -import { DockerImageAsset } from 'aws-cdk-lib/aws-ecr-assets'; -import * as ecs_patterns from 'aws-cdk-lib/aws-ecs-patterns'; -import * as path from 'path'; + export class VPCStack extends cdk.Stack { -public readonly vpc: ec2.Vpc; -public readonly publicSubnets: ec2.ISubnet[]; -constructor(scope: Construct, id: string, props: cdk.StackProps) { - super(scope, id, props); - // Create a VPC - const vpc = new ec2.Vpc(this, 'GenBIVpc', { - maxAzs: 3, // Default is all AZs in the region - subnetConfiguration: [ - { - cidrMask: 24, - name: 'public-subnet', - subnetType: ec2.SubnetType.PUBLIC, - }, - { - cidrMask: 24, - name: 'private-subnet', - subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS, - }, - ], - }); - - this.vpc = vpc; - - // Output the VPC ID - new cdk.CfnOutput(this, 'VpcId', { - value: vpc.vpcId, - }); - - // Output the Subnet IDs - vpc.publicSubnets.forEach((subnet, index) => { - new cdk.CfnOutput(this, `PublicSubnet${index}Id`, { - value: subnet.subnetId, + public readonly vpc: ec2.Vpc; + public readonly publicSubnets: ec2.ISubnet[]; + + constructor(scope: Construct, id: string, props: cdk.StackProps) { + super(scope, id, props); + // Create a VPC + const vpc = new ec2.Vpc(this, 'GenBIVpc', { + maxAzs: 3, // Default is all AZs in the region + natGateways: 1, + subnetConfiguration: [ + { + cidrMask: 24, + name: 'public-subnet', + subnetType: ec2.SubnetType.PUBLIC, + }, + { + cidrMask: 24, + name: 'private-subnet', + subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS, + }, + ], + }); + + this.vpc = vpc; + + // Output the VPC ID + new cdk.CfnOutput(this, 'VpcId', { + value: vpc.vpcId, + }); + + // Output the Subnet IDs + vpc.publicSubnets.forEach((subnet, index) => { + new cdk.CfnOutput(this, `PublicSubnet${index}Id`, { + value: subnet.subnetId, + }); }); - }); - - vpc.privateSubnets.forEach((subnet, index) => { - new cdk.CfnOutput(this, `PrivateSubnet${index}Id`, { - value: subnet.subnetId, + + vpc.privateSubnets.forEach((subnet, index) => { + new cdk.CfnOutput(this, `PrivateSubnet${index}Id`, { + value: subnet.subnetId, + }); + }); + + // Output NatGatewayId ID + new cdk.CfnOutput(this, 'NatGatewayId', { + value: this.vpc.natGateways[0].gatewayId, + }); + + // Output RouteTable ID + new cdk.CfnOutput(this, 'PublicRouteTableId', { + value: this.vpc.publicSubnets[0].routeTable.routeTableId, + }); + + new cdk.CfnOutput(this, 'PrivateRouteTableId', { + value: this.vpc.privateSubnets[0].routeTable.routeTableId, }); - }); - } + } } \ No newline at end of file From c36a3ac90714e616b32e4cb6f8e547ef9ed0ff36 Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 24 Jul 2024 20:29:05 +0800 Subject: [PATCH 070/130] fix some bug --- ...1_\360\237\214\215_Generative_BI_Playground.py" | 3 ++- source/resources/cdk-config.json | 6 ++++-- source/resources/lib/rds/rds-stack.ts | 2 +- source/resources/lib/vpc/vpc-stack.ts | 14 -------------- 4 files changed, 7 insertions(+), 18 deletions(-) diff --git "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" index 338ad64..ca59147 100644 --- "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" +++ "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" @@ -65,7 +65,8 @@ def get_user_history(selected_profile: str): history_list = st.session_state.query_rewrite_history[selected_profile] 
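# A minimal sketch, assuming each history entry is a dict with "role" and
# "content" keys: this is the None-guarded assembly of "role:content"
# strings that the change just below introduces for the rewrite history.
def build_history_query(history_list):
    history_query = []
    for message in history_list:
        # Skip turns whose content was never filled in, so the prompt
        # history never contains the literal string "None".
        if message.get("content") is not None:
            history_query.append(message["role"] + ":" + message["content"])
    return history_query

# Example: the second turn is dropped by the guard.
print(build_history_query([
    {"role": "user", "content": "total sales last month"},
    {"role": "assistant", "content": None},
]))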
history_query = [] for messages in history_list: - history_query.append(messages["role"] + ":" + messages["content"]) + if messages["content"] is not None: + history_query.append(messages["role"] + ":" + messages["content"]) return history_query diff --git a/source/resources/cdk-config.json b/source/resources/cdk-config.json index ae4b5bb..58f1b94 100644 --- a/source/resources/cdk-config.json +++ b/source/resources/cdk-config.json @@ -9,8 +9,10 @@ "bedrock_embedding_name": "amazon.titan-embed-text-v1", "embedding_dimension": 1536 }, - "segamaker": { - "endpoint_name" : "" + "sagemaker": { + "segamaker_endpoint_embedding" : "", + "segamaker_endpoint_sql" : "", + "embedding_dimension": "" }, "opensearch": { "sql_index" : "uba", diff --git a/source/resources/lib/rds/rds-stack.ts b/source/resources/lib/rds/rds-stack.ts index a912f86..30308bb 100644 --- a/source/resources/lib/rds/rds-stack.ts +++ b/source/resources/lib/rds/rds-stack.ts @@ -16,7 +16,7 @@ export class RDSStack extends cdk.Stack { constructor(scope: Construct, id: string, props: RDSStackProps) { super(scope, id, props); - const templatedSecret = new secretsmanager.Secret(this, 'TemplatedSecret', { + const templatedSecret = new secretsmanager.Secret(this, 'GenBIRDSTemplatedSecret', { description: 'Templated secret used for RDS password', generateSecretString: { excludePunctuation: true, diff --git a/source/resources/lib/vpc/vpc-stack.ts b/source/resources/lib/vpc/vpc-stack.ts index f04c79c..76d08d7 100644 --- a/source/resources/lib/vpc/vpc-stack.ts +++ b/source/resources/lib/vpc/vpc-stack.ts @@ -46,19 +46,5 @@ export class VPCStack extends cdk.Stack { value: subnet.subnetId, }); }); - - // Output NatGatewayId ID - new cdk.CfnOutput(this, 'NatGatewayId', { - value: this.vpc.natGateways[0].gatewayId, - }); - - // Output RouteTable ID - new cdk.CfnOutput(this, 'PublicRouteTableId', { - value: this.vpc.publicSubnets[0].routeTable.routeTableId, - }); - - new cdk.CfnOutput(this, 'PrivateRouteTableId', { - value: this.vpc.privateSubnets[0].routeTable.routeTableId, - }); } } \ No newline at end of file From c0699a5dcf041153581d77b3bc405cf2946de96e Mon Sep 17 00:00:00 2001 From: supinyu Date: Wed, 24 Jul 2024 20:31:46 +0800 Subject: [PATCH 071/130] fix some bug --- application/utils/llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/application/utils/llm.py b/application/utils/llm.py index c86664a..9adecd5 100644 --- a/application/utils/llm.py +++ b/application/utils/llm.py @@ -138,12 +138,12 @@ def invoke_mixtral_8x7b_sagemaker(model_id, system_prompt, messages, max_tokens, model_type="LLM", with_response_stream=with_response_stream ) - return response['generated_text'] + response = str(response, 'utf-8') + return response except Exception as e: logger.error("Couldn't invoke Mixtral 8x7B on SageMaker") logger.error(e) - raise def invoke_mixtral_8x7b(model_id, system_prompt, messages, max_tokens, with_response_stream=False): """ From c79d31e6b0508c9471cbf38ea3630d2a40505048 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 11:53:04 +0000 Subject: [PATCH 072/130] Bump torch from 2.1.2 to 2.2.0 in /source/model/internlm/code Bumps [torch](https://github.com/pytorch/pytorch) from 2.1.2 to 2.2.0. 
- [Release notes](https://github.com/pytorch/pytorch/releases) - [Changelog](https://github.com/pytorch/pytorch/blob/main/RELEASE.md) - [Commits](https://github.com/pytorch/pytorch/compare/v2.1.2...v2.2.0) --- updated-dependencies: - dependency-name: torch dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- source/model/internlm/code/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/model/internlm/code/requirements.txt b/source/model/internlm/code/requirements.txt index 40c5321..ebdb182 100644 --- a/source/model/internlm/code/requirements.txt +++ b/source/model/internlm/code/requirements.txt @@ -1,5 +1,5 @@ exllamav2==0.0.14 -torch==2.1.2 +torch==2.2.0 sentencepiece==0.1.99 accelerate==0.25.0 bitsandbytes==0.41.1 From aa7c8362ae9664a4dd83ab772473fc2e8c07422a Mon Sep 17 00:00:00 2001 From: supinyu Date: Fri, 26 Jul 2024 13:42:17 +0800 Subject: [PATCH 073/130] change code for segamaker --- application/nlq/business/vector_store.py | 13 +++++++------ application/utils/llm.py | 15 ++++++++------- source/resources/lib/ecs/ecs-stack.ts | 4 ++++ 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/application/nlq/business/vector_store.py b/application/nlq/business/vector_store.py index 3a118ba..c57697f 100644 --- a/application/nlq/business/vector_store.py +++ b/application/nlq/business/vector_store.py @@ -131,12 +131,13 @@ def create_vector_embedding_with_bedrock(cls, text): @classmethod def create_vector_embedding_with_sagemaker(cls, text): try: - model_kwargs = {} - model_kwargs["batch_size"] = 12 - model_kwargs["max_length"] = 512 - model_kwargs["return_type"] = "dense" - body = json.dumps({"inputs": [text], **model_kwargs}) - embeddings = invoke_model_sagemaker_endpoint(SAGEMAKER_ENDPOINT_EMBEDDING, body) + body = json.dumps( + { + "inputs": text, + "is_query": True + } + ) + embeddings = invoke_model_sagemaker_endpoint(SAGEMAKER_ENDPOINT_EMBEDDING, body, model_type="embedding") return embeddings except Exception as e: logger.error(f'create_vector_embedding_with_sagemaker is error {e}') diff --git a/application/utils/llm.py b/application/utils/llm.py index 9adecd5..4a54cf5 100644 --- a/application/utils/llm.py +++ b/application/utils/llm.py @@ -570,14 +570,15 @@ def create_vector_embedding_with_bedrock(text, index_name): def create_vector_embedding_with_sagemaker(endpoint_name, text, index_name): - model_kwargs = {} - model_kwargs["batch_size"] = 12 - model_kwargs["max_length"] = 512 - model_kwargs["return_type"] = "dense" - body = json.dumps({"inputs": [text], **model_kwargs}) + body = json.dumps( + { + "inputs": text, + "is_query": True + } + ) response = invoke_model_sagemaker_endpoint(endpoint_name, body, model_type="embedding") - embeddings = response["sentence_embeddings"] - return {"_index": index_name, "text": text, "vector_field": embeddings["dense_vecs"][0]} + embeddings = response[0] + return {"_index": index_name, "text": text, "vector_field": embeddings} def generate_suggested_question(prompt_map, search_box, model_id=None): diff --git a/source/resources/lib/ecs/ecs-stack.ts b/source/resources/lib/ecs/ecs-stack.ts index eaae0f7..70ce5e2 100644 --- a/source/resources/lib/ecs/ecs-stack.ts +++ b/source/resources/lib/ecs/ecs-stack.ts @@ -198,6 +198,8 @@ export class ECSStack extends cdk.Stack { containerStreamlit.addEnvironment('AOS_INDEX_AGENT', 'uba_agent'); // containerStreamlit.addEnvironment('SAGEMAKER_EMBEDDING_REGION', cdk.Aws.REGION); // containerStreamlit.addEnvironment('SAGEMAKER_SQL_REGION', 
cdk.Aws.REGION); + // containerStreamlit.addEnvironment('SAGEMAKER_ENDPOINT_EMBEDDING', ''); + // containerStreamlit.addEnvironment('SAGEMAKER_ENDPOINT_SQL', ''); containerStreamlit.addEnvironment('BEDROCK_REGION', cdk.Aws.REGION); containerStreamlit.addEnvironment('RDS_REGION_NAME', cdk.Aws.REGION); containerStreamlit.addEnvironment('AWS_DEFAULT_REGION', cdk.Aws.REGION); @@ -246,6 +248,8 @@ export class ECSStack extends cdk.Stack { containerAPI.addEnvironment('AOS_INDEX_AGENT', 'uba_agent'); // containerAPI.addEnvironment('SAGEMAKER_EMBEDDING_REGION', cdk.Aws.REGION); // containerAPI.addEnvironment('SAGEMAKER_SQL_REGION', cdk.Aws.REGION); + // containerAPI.addEnvironment('SAGEMAKER_ENDPOINT_EMBEDDING', ''); + // containerAPI.addEnvironment('SAGEMAKER_ENDPOINT_SQL', ''); containerAPI.addEnvironment('BEDROCK_REGION', cdk.Aws.REGION); containerAPI.addEnvironment('RDS_REGION_NAME', cdk.Aws.REGION); containerAPI.addEnvironment('AWS_DEFAULT_REGION', cdk.Aws.REGION); From f47ee3bd4d1cb7f0722c7f17b26d482ac8241276 Mon Sep 17 00:00:00 2001 From: supinyu Date: Fri, 26 Jul 2024 15:08:18 +0800 Subject: [PATCH 074/130] change code for segamaker --- application/nlq/business/vector_store.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/application/nlq/business/vector_store.py b/application/nlq/business/vector_store.py index c57697f..b0f9577 100644 --- a/application/nlq/business/vector_store.py +++ b/application/nlq/business/vector_store.py @@ -137,7 +137,8 @@ def create_vector_embedding_with_sagemaker(cls, text): "is_query": True } ) - embeddings = invoke_model_sagemaker_endpoint(SAGEMAKER_ENDPOINT_EMBEDDING, body, model_type="embedding") + response = invoke_model_sagemaker_endpoint(SAGEMAKER_ENDPOINT_EMBEDDING, body, model_type="embedding") + embeddings = response[0] return embeddings except Exception as e: logger.error(f'create_vector_embedding_with_sagemaker is error {e}') From 44d51d213670dec7e4f9a6b72f7b26966fc34378 Mon Sep 17 00:00:00 2001 From: Zhoutong Wang Date: Fri, 26 Jul 2024 16:18:11 +0800 Subject: [PATCH 075/130] Update main-stack.ts to fix cognito bug --- source/resources/lib/main-stack.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index 2468c31..0e131e4 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -75,7 +75,7 @@ export class MainStack extends cdk.Stack { let _CognitoStack: CognitoStack | undefined; if (!isChinaRegion) { - const _CognitoStack = new CognitoStack(this, 'cognito-Stack', { + _CognitoStack = new CognitoStack(this, 'cognito-Stack', { env: props.env }); } @@ -138,4 +138,4 @@ export class MainStack extends cdk.Stack { description: 'The endpoint of the API service' }); } -} \ No newline at end of file +} From d23b4c888c226d4ee9ab48345aededda5339f933 Mon Sep 17 00:00:00 2001 From: wubinbin Date: Wed, 31 Jul 2024 14:55:03 +0800 Subject: [PATCH 076/130] import classnames library --- report-front-end/package.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/report-front-end/package.json b/report-front-end/package.json index 524601b..13a9e2b 100644 --- a/report-front-end/package.json +++ b/report-front-end/package.json @@ -53,6 +53,7 @@ "sass": "^1.65.1", "typescript": "^5.0.2", "vite": "^4.5.2", - "zen-observable": "^0.10.0" + "zen-observable": "^0.10.0", + "classnames": "^2.5.1" } } From 39ee6d746ee696f2050b77ddfd4ba86d4831e495 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Wed, 31 Jul 2024 
15:16:59 +0800 Subject: [PATCH 077/130] fix explain_gen_process_flag bug --- .../pages/1_\360\237\214\215_Generative_BI_Playground.py" | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" index ca59147..824494a 100644 --- "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" +++ "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" @@ -473,7 +473,7 @@ def main(): database_profile, entity_slot, opensearch_info, selected_profile, - explain_gen_process_flag, use_rag_flag) + use_rag_flag, use_rag_flag) elif knowledge_search_flag: with st.spinner('Performing knowledge search...'): response = knowledge_search(search_box=search_box, model_id=model_type, From c6f59cd8789849f58c67c1a36317d6c662e2597a Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Wed, 31 Jul 2024 15:20:16 +0800 Subject: [PATCH 078/130] change _CognitoStack --- source/resources/lib/main-stack.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index ec22a4a..82e7312 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -1,3 +1,4 @@ + import {StackProps, CfnParameter, CfnOutput} from 'aws-cdk-lib'; import * as cdk from 'aws-cdk-lib'; import {Construct} from 'constructs'; @@ -60,7 +61,7 @@ export class MainStack extends cdk.Stack { let _CognitoStack: CognitoStack | undefined; if (!isChinaRegion) { - _CognitoStack = new CognitoStack(this, 'cognito-Stack', { + const _CognitoStack = new CognitoStack(this, 'cognito-Stack', { env: props.env }); } From 33581058dc0277db76fbf9c0b7c44263caf4ed65 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Wed, 31 Jul 2024 15:23:49 +0800 Subject: [PATCH 079/130] change _CognitoStack --- source/resources/lib/main-stack.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index 82e7312..b5b291a 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -61,7 +61,7 @@ export class MainStack extends cdk.Stack { let _CognitoStack: CognitoStack | undefined; if (!isChinaRegion) { - const _CognitoStack = new CognitoStack(this, 'cognito-Stack', { + _CognitoStack = new CognitoStack(this, 'cognito-Stack', { env: props.env }); } From 77fccd53954b9757525f1bd29917c9b2219fee09 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Wed, 31 Jul 2024 15:36:39 +0800 Subject: [PATCH 080/130] remove some line --- source/resources/lib/main-stack.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/source/resources/lib/main-stack.ts b/source/resources/lib/main-stack.ts index b5b291a..ec22a4a 100644 --- a/source/resources/lib/main-stack.ts +++ b/source/resources/lib/main-stack.ts @@ -1,4 +1,3 @@ - import {StackProps, CfnParameter, CfnOutput} from 'aws-cdk-lib'; import * as cdk from 'aws-cdk-lib'; import {Construct} from 'constructs'; From 4bad91882c15bc64020aae401caef6c0fc0cac2d Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Wed, 31 Jul 2024 21:10:43 +0800 Subject: [PATCH 081/130] change Login configuration --- report-front-end/.env | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/report-front-end/.env b/report-front-end/.env index 05d174d..4eff791 100644 --- a/report-front-end/.env +++ b/report-front-end/.env @@ -9,7 +9,7 @@ VITE_RIGHT_LOGO= # Login configuration, e.g. 
Cognito | None -VITE_LOGIN_TYPE=PLACEHOLDER_VITE_LOGIN_TYPE +VITE_LOGIN_TYPE=Cognito # KEEP the placeholder values if using CDK to deploy the backend! From 3f3ec78efcc7ba4197adc40590041b76d1021860 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 1 Aug 2024 15:42:42 +0800 Subject: [PATCH 082/130] change Login configuration --- application/api/main.py | 9 +- application/api/schemas.py | 22 ++++- application/api/service.py | 90 ++++++++++++++++--- application/nlq/business/log_store.py | 11 ++- .../nlq/data_access/dynamo_query_log.py | 41 ++++++++- application/utils/tool.py | 10 +++ 6 files changed, 163 insertions(+), 20 deletions(-) diff --git a/application/api/main.py b/application/api/main.py index 96648c1..3351772 100644 --- a/application/api/main.py +++ b/application/api/main.py @@ -4,7 +4,7 @@ import logging from nlq.business.profile import ProfileManagement from .enum import ContentEnum -from .schemas import Question, Answer, Option, CustomQuestion, FeedBackInput +from .schemas import Question, Answer, Option, CustomQuestion, FeedBackInput, HistoryRequest from . import service from nlq.business.nlq_chain import NLQChain from dotenv import load_dotenv @@ -38,6 +38,13 @@ def ask(question: Question): return service.ask(question) +@router.post("/get_history_by_user_profile") +def get_history_by_user_profile(history_request : HistoryRequest): + user_id = history_request.user_id + profile_name = history_request.profile_name + return service.get_history_by_user_profile(user_id, profile_name) + + @router.post("/user_feedback") def user_feedback(input_data: FeedBackInput): feedback_type = input_data.feedback_type diff --git a/application/api/schemas.py b/application/api/schemas.py index c2a4a80..28d8597 100644 --- a/application/api/schemas.py +++ b/application/api/schemas.py @@ -17,7 +17,7 @@ class Question(BaseModel): top_p: float = 0.9 max_tokens: int = 2048 temperature: float = 0.01 - context_window: int = 3 + context_window: int = 5 session_id: str = "-1" user_id: str = "admin" @@ -28,6 +28,11 @@ class Example(BaseModel): answer: str +class HistoryRequest(BaseModel): + user_id: str + profile_name: str + + class QueryEntity(BaseModel): query: str sql: str @@ -80,10 +85,25 @@ class AgentSearchResult(BaseModel): agent_summary: str +class AskReplayResult(BaseModel): + query_rewrite: str + + class Answer(BaseModel): query: str query_intent: str knowledge_search_result: KnowledgeSearchResult sql_search_result: SQLSearchResult agent_search_result: AgentSearchResult + ask_rewrite_result: AskReplayResult suggested_question: list[str] + + +class HistoryMessage(BaseModel): + type: str + content: Answer + + +class ChatHistory(BaseModel): + session_id: str + messages: list[HistoryMessage] diff --git a/application/api/service.py b/application/api/service.py index 101192b..6b7e538 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -14,13 +14,14 @@ from utils.domain import SearchTextSqlResult from utils.llm import text_to_sql, get_query_intent, create_vector_embedding_with_sagemaker, \ sagemaker_to_sql, sagemaker_to_explain, knowledge_search, get_agent_cot_task, data_analyse_tool, \ - generate_suggested_question, data_visualization + generate_suggested_question, data_visualization, get_query_rewrite from utils.opensearch import get_retrieve_opensearch from utils.env_var import opensearch_info from utils.text_search import normal_text_search, agent_text_search -from utils.tool import generate_log_id, get_current_time, get_generated_sql_explain, get_generated_sql +from utils.tool 
import generate_log_id, get_current_time, get_generated_sql_explain, get_generated_sql, \ + change_class_to_str from .schemas import Question, Answer, Example, Option, SQLSearchResult, AgentSearchResult, KnowledgeSearchResult, \ - TaskSQLSearchResult, ChartEntity + TaskSQLSearchResult, ChartEntity, AskReplayResult from .exception_handler import BizException from utils.constant import BEDROCK_MODEL_IDS, ACTIVE_PROMPT_NAME from .enum import ErrorEnum, ContentEnum @@ -57,6 +58,11 @@ def get_example(current_nlq_chain: NLQChain) -> list[Example]: return examples +def get_history_by_user_profile(user_id: str, profile_name: str): + history_list = LogManagement.get_history(user_id, profile_name) + return history_list + + def get_result_from_llm(question: Question, current_nlq_chain: NLQChain, with_response_stream=False) -> Union[ str, dict]: logger.info('try to get generated sql from LLM') @@ -367,11 +373,13 @@ async def ask_websocket(websocket: WebSocket, question: Question): explain_gen_process_flag = question.explain_gen_process_flag gen_suggested_question_flag = question.gen_suggested_question_flag answer_with_insights = question.answer_with_insights + context_window = question.context_window reject_intent_flag = False search_intent_flag = False agent_intent_flag = False knowledge_search_flag = False + ask_replay_flag = False agent_search_result = [] normal_search_result = None @@ -389,6 +397,8 @@ async def ask_websocket(websocket: WebSocket, question: Question): sql_chart_data = ChartEntity(chart_type="", chart_data=[]) + ask_result = AskReplayResult(query_rewrite="") + sql_search_result = SQLSearchResult(sql_data=[], sql="", data_show_type="table", sql_gen_process="", data_analyse="", sql_data_chart=[]) @@ -410,6 +420,38 @@ async def ask_websocket(websocket: WebSocket, question: Question): entity_slot = [] + user_query_history = [] + query_rewrite_result = {"intent": "original_problem", "query": search_box} + if context_window > 0: + context_window_select = context_window * 2 + user_query_history = user_query_history[-context_window_select:] + logger.info("The Chat history is {history}".format(history="\n".join(user_query_history))) + query_rewrite_result = get_query_rewrite(model_type, search_box, prompt_map, user_query_history) + logger.info( + "The query_rewrite_result is {query_rewrite_result}".format(query_rewrite_result=query_rewrite_result)) + search_box = query_rewrite_result.get("query") + + query_rewrite_intent = query_rewrite_result.get("intent") + if "ask_in_reply" == query_rewrite_intent: + ask_replay_flag = True + + if ask_replay_flag: + ask_result.query_rewrite = query_rewrite_result.get("query") + + answer = Answer(query=search_box, query_intent="ask_in_reply", knowledge_search_result=knowledge_search_result, + sql_search_result=sql_search_result, agent_search_result=agent_search_response, + suggested_question=[], ask_rewrite_result=ask_result) + + ask_answer_info = change_class_to_str(answer) + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, + profile_name=selected_profile, sql="", query=search_box, + intent="ask_in_reply", + log_info=ask_answer_info, + log_type="chat_history", + time_str=current_time) + return answer + + if intent_ner_recognition_flag: await response_websocket(websocket, session_id, "Query Intent Analyse", ContentEnum.STATE, "start", user_id) intent_response = get_query_intent(model_type, search_box, prompt_map) @@ -438,10 +480,12 @@ async def ask_websocket(websocket: WebSocket, question: Question): if 
reject_intent_flag: answer = Answer(query=search_box, query_intent="reject_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, - suggested_question=[]) + suggested_question=[], ask_rewrite_result=ask_result) + reject_answer_info = change_class_to_str(answer) LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql="", query=search_box, - intent="reject_search", log_info="", time_str=current_time) + intent="reject_search", log_info=reject_answer_info, log_type="chat_history", + time_str=current_time) return answer elif search_intent_flag: normal_search_result = await normal_text_search_websocket(websocket, session_id, search_box, model_type, @@ -455,12 +499,13 @@ async def ask_websocket(websocket: WebSocket, question: Question): answer = Answer(query=search_box, query_intent="knowledge_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, - suggested_question=[]) - + suggested_question=[], ask_rewrite_result=ask_result) + knowledge_answer_info = change_class_to_str(answer) LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql="", query=search_box, intent="knowledge_search", - log_info=knowledge_search_result.knowledge_response, + log_info=knowledge_answer_info, + log_type="chat_history", time_str=current_time) return answer @@ -482,7 +527,6 @@ async def ask_websocket(websocket: WebSocket, question: Question): split_strings = generated_sq.split("[generate]") generate_suggested_question_list = [s.strip() for s in split_strings if s.strip()] - if search_intent_flag: if normal_search_result.sql != "": current_nlq_chain.set_generated_sql(normal_search_result.sql) @@ -539,10 +583,20 @@ async def ask_websocket(websocket: WebSocket, question: Question): query=search_box, intent="normal_search", log_info=log_info, + log_type="normal_log", time_str=current_time) answer = Answer(query=search_box, query_intent="normal_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, - suggested_question=generate_suggested_question_list) + suggested_question=generate_suggested_question_list, ask_rewrite_result=ask_result) + + intent_answer_info = change_class_to_str(answer) + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, + profile_name=selected_profile, sql=sql_search_result.sql, + query=search_box, + intent="normal_search", + log_info=intent_answer_info, + log_type="chat_history", + time_str=current_time) return answer else: sub_search_task = [] @@ -599,7 +653,16 @@ async def ask_websocket(websocket: WebSocket, question: Question): answer = Answer(query=search_box, query_intent="agent_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, - suggested_question=generate_suggested_question_list) + suggested_question=generate_suggested_question_list, ask_rewrite_result=ask_result) + + agent_answer_info = change_class_to_str(answer) + LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, + profile_name=selected_profile, sql="", + query=search_box, + intent="agent_search", + log_info=agent_answer_info, + log_type="chat_history", + time_str=current_time) return answer @@ -663,7 
+726,7 @@ def get_executed_result(current_nlq_chain: NLQChain) -> str: async def normal_text_search_websocket(websocket: WebSocket, session_id: str, search_box, model_type, database_profile, - entity_slot, opensearch_info, selected_profile, use_rag,user_id, + entity_slot, opensearch_info, selected_profile, use_rag, user_id, model_provider=None): entity_slot_retrieve = [] retrieve_result = [] @@ -679,7 +742,8 @@ async def normal_text_search_websocket(websocket: WebSocket, session_id: str, se database_profile['db_type'] = ConnectionManagement.get_db_type_by_name(conn_name) if len(entity_slot) > 0 and use_rag: - await response_websocket(websocket, session_id, "Entity Info Retrieval", ContentEnum.STATE, "start", user_id) + await response_websocket(websocket, session_id, "Entity Info Retrieval", ContentEnum.STATE, "start", + user_id) for each_entity in entity_slot: entity_retrieve = get_retrieve_opensearch(opensearch_info, each_entity, "ner", selected_profile, 1, 0.7) diff --git a/application/nlq/business/log_store.py b/application/nlq/business/log_store.py index 9ab480a..2b7951c 100644 --- a/application/nlq/business/log_store.py +++ b/application/nlq/business/log_store.py @@ -9,6 +9,13 @@ class LogManagement: query_log_dao = DynamoQueryLogDao() @classmethod - def add_log_to_database(cls, log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str): + def add_log_to_database(cls, log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str, + log_type="normal_log", ): cls.query_log_dao.add_log(log_id=log_id, profile_name=profile_name, user_id=user_id, session_id=session_id, - sql=sql, query=query, intent=intent, log_info=log_info, time_str=time_str) + sql=sql, query=query, intent=intent, log_info=log_info, log_type=log_type, + time_str=time_str) + + @classmethod + def get_history(cls, user_id, profile_name): + history_list = cls.query_log_dao.get_history_by_user_profile(user_id, profile_name) + return history_list diff --git a/application/nlq/data_access/dynamo_query_log.py b/application/nlq/data_access/dynamo_query_log.py index 9f6ef68..0b12fff 100644 --- a/application/nlq/data_access/dynamo_query_log.py +++ b/application/nlq/data_access/dynamo_query_log.py @@ -3,6 +3,7 @@ import boto3 from botocore.exceptions import ClientError +from boto3.dynamodb.conditions import Key logger = logging.getLogger(__name__) @@ -12,7 +13,7 @@ class DynamoQueryLogEntity: - def __init__(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str): + def __init__(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, log_type, time_str): self.log_id = log_id self.profile_name = profile_name self.user_id = user_id @@ -21,6 +22,7 @@ def __init__(self, log_id, profile_name, user_id, session_id, sql, query, intent self.query = query self.intent = intent self.log_info = log_info + self.log_type = log_type self.time_str = time_str def to_dict(self): @@ -34,6 +36,7 @@ def to_dict(self): 'query': self.query, 'intent': self.intent, 'log_info': self.log_info, + 'log_type': self.log_type, 'time_str': self.time_str } @@ -113,6 +116,38 @@ def add(self, entity): def update(self, entity): self.table.put_item(Item=entity.to_dict()) - def add_log(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str): - entity = DynamoQueryLogEntity(log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str) + def add_log(self, log_id, profile_name, user_id, session_id, sql, query, intent, 
log_info, log_type, time_str): + entity = DynamoQueryLogEntity(log_id, profile_name, user_id, session_id, sql, query, intent, log_info, log_type, time_str) self.add(entity) + + def get_history_by_user_profile(self, user_id, profile_name): + try: + # First, we need to scan the table to find all items for the user and profile + response = self.table.scan( + FilterExpression=Key('user_id').eq(user_id) & Key('profile_name').eq(profile_name) + ) + + items = response['Items'] + + # DynamoDB might not return all items in a single response if the data set is large + while 'LastEvaluatedKey' in response: + response = self.table.scan( + FilterExpression=Key('user_id').eq(user_id) & Key('profile_name').eq(profile_name), + ExclusiveStartKey=response['LastEvaluatedKey'] + ) + items.extend(response['Items']) + + # Sort the items by time_str to get them in chronological order + sorted_items = sorted(items, key=lambda x: x['time_str']) + + return sorted_items + + except ClientError as err: + logger.error( + "Couldn't get history for user %s and profile %s. Here's why: %s: %s", + user_id, + profile_name, + err.response["Error"]["Code"], + err.response["Error"]["Message"], + ) + return [] \ No newline at end of file diff --git a/application/utils/tool.py b/application/utils/tool.py index 3743dfa..9003484 100644 --- a/application/utils/tool.py +++ b/application/utils/tool.py @@ -1,3 +1,4 @@ +import json import logging import time import random @@ -36,3 +37,12 @@ def get_generated_sql_explain(generated_sql_response): return generated_sql_response[index + len(""):] else: return generated_sql_response + + +def change_class_to_str(result): + try: + log_info = json.dumps(result.dict()) + return log_info + except Exception as e: + logger.error(f"Error in changing class to string: {e}") + return "" From 027e53cd610213e8d4638fc6183daf491e40b811 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 1 Aug 2024 16:56:25 +0800 Subject: [PATCH 083/130] change Login configuration --- application/api/schemas.py | 12 +++++++---- application/api/service.py | 21 +++++++++++++++++-- .../nlq/data_access/dynamo_query_log.py | 2 +- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/application/api/schemas.py b/application/api/schemas.py index 28d8597..0f8d973 100644 --- a/application/api/schemas.py +++ b/application/api/schemas.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Union from pydantic import BaseModel @@ -99,11 +99,15 @@ class Answer(BaseModel): suggested_question: list[str] -class HistoryMessage(BaseModel): +class Message(BaseModel): type: str - content: Answer + content: Union[str, Answer] -class ChatHistory(BaseModel): +class HistoryMessage(BaseModel): session_id: str + messages: list[Message] + + +class ChatHistory(BaseModel): messages: list[HistoryMessage] diff --git a/application/api/service.py b/application/api/service.py index 6b7e538..41087ce 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -21,7 +21,7 @@ from utils.tool import generate_log_id, get_current_time, get_generated_sql_explain, get_generated_sql, \ change_class_to_str from .schemas import Question, Answer, Example, Option, SQLSearchResult, AgentSearchResult, KnowledgeSearchResult, \ - TaskSQLSearchResult, ChartEntity, AskReplayResult + TaskSQLSearchResult, ChartEntity, AskReplayResult, ChatHistory, Message, HistoryMessage from .exception_handler import BizException from utils.constant import BEDROCK_MODEL_IDS, ACTIVE_PROMPT_NAME from .enum import ErrorEnum, ContentEnum @@ -60,7 +60,23 @@ def 
get_example(current_nlq_chain: NLQChain) -> list[Example]: def get_history_by_user_profile(user_id: str, profile_name: str): history_list = LogManagement.get_history(user_id, profile_name) - return history_list + chat_history = [] + chat_history_session = [] + for item in history_list: + session_id = item['session_id'] + if session_id not in chat_history_session: + chat_history_session[session_id] = [] + log_info = item['log_info'] + query = item['query'] + human_message = Message(type="human", human_content=query) + bot_message = Message(type="AI", bot_content=json.loads(log_info)) + chat_history_session[session_id].append(human_message) + chat_history_session[session_id].append(bot_message) + + for key, value in chat_history_session.items(): + each_session_history = HistoryMessage(session_id=key, messages=value) + chat_history.append(each_session_history) + return chat_history def get_result_from_llm(question: Question, current_nlq_chain: NLQChain, with_response_stream=False) -> Union[ @@ -425,6 +441,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): if context_window > 0: context_window_select = context_window * 2 user_query_history = user_query_history[-context_window_select:] + user_query_history = ["user: " + search_box] logger.info("The Chat history is {history}".format(history="\n".join(user_query_history))) query_rewrite_result = get_query_rewrite(model_type, search_box, prompt_map, user_query_history) logger.info( diff --git a/application/nlq/data_access/dynamo_query_log.py b/application/nlq/data_access/dynamo_query_log.py index 0b12fff..69ef9c9 100644 --- a/application/nlq/data_access/dynamo_query_log.py +++ b/application/nlq/data_access/dynamo_query_log.py @@ -124,7 +124,7 @@ def get_history_by_user_profile(self, user_id, profile_name): try: # First, we need to scan the table to find all items for the user and profile response = self.table.scan( - FilterExpression=Key('user_id').eq(user_id) & Key('profile_name').eq(profile_name) + FilterExpression=Key('user_id').eq(user_id) & Key('profile_name').eq(profile_name) & Key('log_type').eq("chat_history") ) items = response['Items'] From ce87a302d750ebf509e853ada659edac70b05997 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 1 Aug 2024 17:23:44 +0800 Subject: [PATCH 084/130] change code for multi chat --- application/api/service.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/application/api/service.py b/application/api/service.py index 41087ce..02cc18b 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -61,15 +61,15 @@ def get_example(current_nlq_chain: NLQChain) -> list[Example]: def get_history_by_user_profile(user_id: str, profile_name: str): history_list = LogManagement.get_history(user_id, profile_name) chat_history = [] - chat_history_session = [] + chat_history_session = {} for item in history_list: session_id = item['session_id'] if session_id not in chat_history_session: chat_history_session[session_id] = [] log_info = item['log_info'] query = item['query'] - human_message = Message(type="human", human_content=query) - bot_message = Message(type="AI", bot_content=json.loads(log_info)) + human_message = Message(type="human", content=query) + bot_message = Message(type="AI", content=json.loads(log_info)) chat_history_session[session_id].append(human_message) chat_history_session[session_id].append(bot_message) From d7ffca6bae19055067eaae166ce7ff781a5a9160 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 1 Aug 2024 19:08:23 +0800 Subject: [PATCH 
085/130] change share data for chatbot --- application/api/schemas.py | 1 + application/api/service.py | 47 +++++++++++-------- application/main.py | 45 +++++++++++++++--- application/nlq/business/log_store.py | 5 ++ .../nlq/data_access/dynamo_query_log.py | 32 ++++++++++++- application/utils/tool.py | 15 ++++++ 6 files changed, 117 insertions(+), 28 deletions(-) diff --git a/application/api/schemas.py b/application/api/schemas.py index 0f8d973..e040978 100644 --- a/application/api/schemas.py +++ b/application/api/schemas.py @@ -91,6 +91,7 @@ class AskReplayResult(BaseModel): class Answer(BaseModel): query: str + query_rewrite: str = "" query_intent: str knowledge_search_result: KnowledgeSearchResult sql_search_result: SQLSearchResult diff --git a/application/api/service.py b/application/api/service.py index 02cc18b..8b89892 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -3,6 +3,8 @@ from typing import Union from dotenv import load_dotenv import logging + +from main import shared_data from nlq.business.connection import ConnectionManagement from nlq.business.nlq_chain import NLQChain from nlq.business.profile import ProfileManagement @@ -19,7 +21,7 @@ from utils.env_var import opensearch_info from utils.text_search import normal_text_search, agent_text_search from utils.tool import generate_log_id, get_current_time, get_generated_sql_explain, get_generated_sql, \ - change_class_to_str + change_class_to_str, get_window_history from .schemas import Question, Answer, Example, Option, SQLSearchResult, AgentSearchResult, KnowledgeSearchResult, \ TaskSQLSearchResult, ChartEntity, AskReplayResult, ChatHistory, Message, HistoryMessage from .exception_handler import BizException @@ -192,8 +194,7 @@ def ask(question: Question) -> Answer: prompt_map = database_profile['prompt_map'] entity_slot = [] - # 通过标志位控制后续的逻辑 - # 主要的意图有4个, 拒绝, 查询, 思维链, 知识问答 + if intent_ner_recognition_flag: intent_response = get_query_intent(model_type, search_box, prompt_map) intent = intent_response.get("intent", "normal_search") @@ -405,6 +406,9 @@ async def ask_websocket(websocket: WebSocket, question: Question): log_id = generate_log_id() current_time = get_current_time() log_info = "" + query_rewrite = "" + + shared_data["log_id"] = log_id all_profiles = ProfileManagement.get_all_profiles_with_info() database_profile = all_profiles[selected_profile] @@ -437,11 +441,14 @@ async def ask_websocket(websocket: WebSocket, question: Question): entity_slot = [] user_query_history = [] + if session_id in shared_data: + user_query_history = shared_data[session_id] query_rewrite_result = {"intent": "original_problem", "query": search_box} if context_window > 0: context_window_select = context_window * 2 - user_query_history = user_query_history[-context_window_select:] - user_query_history = ["user: " + search_box] + if len(user_query_history) > 0: + user_query_history = user_query_history[-context_window_select:] + user_query_history = get_window_history(user_query_history) logger.info("The Chat history is {history}".format(history="\n".join(user_query_history))) query_rewrite_result = get_query_rewrite(model_type, search_box, prompt_map, user_query_history) logger.info( @@ -449,13 +456,14 @@ async def ask_websocket(websocket: WebSocket, question: Question): search_box = query_rewrite_result.get("query") query_rewrite_intent = query_rewrite_result.get("intent") + query_rewrite = query_rewrite_result.get("query") if "ask_in_reply" == query_rewrite_intent: ask_replay_flag = True if ask_replay_flag: 
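# A hedged sketch of the multi-turn flow this hunk wires up; the helper
# names below (rewrite_query, answer_normally, ask_user_back) are
# placeholders for illustration, not the project's actual functions.
def handle_turn(question, history, context_window, rewrite_query,
                answer_normally, ask_user_back):
    # Keep only the last N question/answer pairs (two entries per turn).
    window = history[-(context_window * 2):] if context_window > 0 else []
    rewrite = rewrite_query(question, window)  # {"intent": ..., "query": ...}
    if rewrite.get("intent") == "ask_in_reply":
        # The rewrite step decided the question is ambiguous: reply with a
        # clarifying question instead of generating SQL.
        return ask_user_back(rewrite.get("query", question))
    return answer_normally(rewrite.get("query", question))

# Example with trivial stand-ins for the three callbacks.
print(handle_turn(
    "what about last month",
    ["user:sales by region", "assistant:sales by region for 2024"],
    context_window=5,
    rewrite_query=lambda q, h: {"intent": "original_problem", "query": q},
    answer_normally=lambda q: "answer(" + q + ")",
    ask_user_back=lambda q: "clarify(" + q + ")",
))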
ask_result.query_rewrite = query_rewrite_result.get("query") - answer = Answer(query=search_box, query_intent="ask_in_reply", knowledge_search_result=knowledge_search_result, + answer = Answer(query=search_box, query_rewrite=ask_result.query_rewrite, query_intent="ask_in_reply", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[], ask_rewrite_result=ask_result) @@ -468,10 +476,9 @@ async def ask_websocket(websocket: WebSocket, question: Question): time_str=current_time) return answer - if intent_ner_recognition_flag: await response_websocket(websocket, session_id, "Query Intent Analyse", ContentEnum.STATE, "start", user_id) - intent_response = get_query_intent(model_type, search_box, prompt_map) + intent_response = get_query_intent(model_type, query_rewrite, prompt_map) await response_websocket(websocket, session_id, "Query Intent Analyse", ContentEnum.STATE, "end", user_id) intent = intent_response.get("intent", "normal_search") entity_slot = intent_response.get("slot", []) @@ -495,7 +502,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): search_intent_flag = True if reject_intent_flag: - answer = Answer(query=search_box, query_intent="reject_search", knowledge_search_result=knowledge_search_result, + answer = Answer(query=search_box, query_rewrite=query_rewrite, query_intent="reject_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[], ask_rewrite_result=ask_result) reject_answer_info = change_class_to_str(answer) @@ -505,7 +512,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): time_str=current_time) return answer elif search_intent_flag: - normal_search_result = await normal_text_search_websocket(websocket, session_id, search_box, model_type, + normal_search_result = await normal_text_search_websocket(websocket, session_id, query_rewrite, model_type, database_profile, entity_slot, opensearch_info, selected_profile, use_rag_flag, user_id) @@ -513,7 +520,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): response = knowledge_search(search_box=search_box, model_id=model_type, prompt_map=prompt_map) knowledge_search_result.knowledge_response = response - answer = Answer(query=search_box, query_intent="knowledge_search", + answer = Answer(query=search_box, query_rewrite=query_rewrite, query_intent="knowledge_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[], ask_rewrite_result=ask_result) @@ -527,20 +534,20 @@ async def ask_websocket(websocket: WebSocket, question: Question): return answer else: - agent_cot_retrieve = get_retrieve_opensearch(opensearch_info, search_box, "agent", + agent_cot_retrieve = get_retrieve_opensearch(opensearch_info, query_rewrite, "agent", selected_profile, 2, 0.5) - agent_cot_task_result = get_agent_cot_task(model_type, prompt_map, search_box, + agent_cot_task_result = get_agent_cot_task(model_type, prompt_map, query_rewrite, database_profile['tables_info'], agent_cot_retrieve) - agent_search_result = agent_text_search(search_box, model_type, + agent_search_result = agent_text_search(query_rewrite, model_type, database_profile, entity_slot, opensearch_info, selected_profile, use_rag_flag, agent_cot_task_result) if gen_suggested_question_flag and (search_intent_flag or agent_intent_flag): 
active_prompt = sqm.get_prompt_by_name(ACTIVE_PROMPT_NAME).prompt - generated_sq = generate_suggested_question(prompt_map, search_box, model_id=model_type) + generated_sq = generate_suggested_question(prompt_map, query_rewrite, model_id=model_type) split_strings = generated_sq.split("[generate]") generate_suggested_question_list = [s.strip() for s in split_strings if s.strip()] @@ -569,7 +576,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): if answer_with_insights: await response_websocket(websocket, session_id, "Generating Data Insights", ContentEnum.STATE, "start", user_id) - search_intent_analyse_result = data_analyse_tool(model_type, prompt_map, search_box, + search_intent_analyse_result = data_analyse_tool(model_type, prompt_map, query_rewrite, search_intent_result["data"].to_json( orient='records', force_ascii=False), "query") @@ -579,7 +586,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): sql_search_result.data_analyse = search_intent_analyse_result model_select_type, show_select_data, select_chart_type, show_chart_data = data_visualization(model_type, - search_box, + query_rewrite, search_intent_result[ "data"], database_profile[ @@ -602,7 +609,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): log_info=log_info, log_type="normal_log", time_str=current_time) - answer = Answer(query=search_box, query_intent="normal_search", knowledge_search_result=knowledge_search_result, + answer = Answer(query=search_box, query_rewrite=query_rewrite, query_intent="normal_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=generate_suggested_question_list, ask_rewrite_result=ask_result) @@ -660,7 +667,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): intent="agent_search", log_info=log_info, time_str=current_time) - agent_data_analyse_result = data_analyse_tool(model_type, prompt_map, search_box, + agent_data_analyse_result = data_analyse_tool(model_type, prompt_map, query_rewrite, json.dumps(filter_deep_dive_sql_result, ensure_ascii=False), "agent") logger.info("agent_data_analyse_result") @@ -668,7 +675,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): agent_search_response.agent_summary = agent_data_analyse_result agent_search_response.agent_sql_search_result = agent_sql_search_result - answer = Answer(query=search_box, query_intent="agent_search", knowledge_search_result=knowledge_search_result, + answer = Answer(query=search_box, query_rewrite=query_rewrite, query_intent="agent_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=generate_suggested_question_list, ask_rewrite_result=ask_result) diff --git a/application/main.py b/application/main.py index 55d4f9a..03b7913 100644 --- a/application/main.py +++ b/application/main.py @@ -1,3 +1,6 @@ +import json +from multiprocessing import Manager + from fastapi import FastAPI, status from fastapi.staticfiles import StaticFiles from fastapi.responses import RedirectResponse @@ -5,17 +8,21 @@ from api.main import router from fastapi.middleware.cors import CORSMiddleware from api import service -from api.schemas import Option +from api.schemas import Option, Message +from nlq.business.log_store import LogManagement +MAX_CHAT_WINDOW_SIZE = 10 * 2 app = FastAPI(title='GenBI') -# 配置CORS中间件 +manager = Manager() +shared_data = manager.dict() 
# shared data between processes + app.add_middleware( CORSMiddleware, - allow_origins=['*'], # 允许所有源访问,可以根据需求进行修改 - allow_credentials=True, # 允许发送凭据(如Cookie) - allow_methods=['*'], # 允许所有HTTP方法 - allow_headers=['*'], # 允许所有请求头 + allow_origins=['*'], + allow_credentials=True, + allow_methods=['*'], + allow_headers=['*'], ) # Global exception capture @@ -23,16 +30,40 @@ app.mount("/static", StaticFiles(directory="static"), name="static") app.include_router(router) + # changed from "/" to "/test" to avoid health check fails in ECS @app.get("/test", status_code=status.HTTP_302_FOUND) def index(): return RedirectResponse("static/WebSocket.html") + # health check @app.get("/") def health(): return {"status": "ok"} + @app.get("/option", response_model=Option) def option(): - return service.get_option() \ No newline at end of file + return service.get_option() + + +@app.on_event("startup") +def set_history_in_share(): + global shared_data + history_list = LogManagement.get_all_history() + chat_history_session = {} + for item in history_list: + session_id = item['session_id'] + if session_id not in chat_history_session: + chat_history_session[session_id] = [] + log_info = item['log_info'] + query = item['query'] + human_message = Message(type="human", content=query) + bot_message = Message(type="AI", content=json.loads(log_info)) + chat_history_session[session_id].append(human_message) + chat_history_session[session_id].append(bot_message) + + for key, value in chat_history_session.items(): + value = value[-MAX_CHAT_WINDOW_SIZE:] + shared_data[key] = value diff --git a/application/nlq/business/log_store.py b/application/nlq/business/log_store.py index 2b7951c..7d8d8ef 100644 --- a/application/nlq/business/log_store.py +++ b/application/nlq/business/log_store.py @@ -19,3 +19,8 @@ def add_log_to_database(cls, log_id, user_id, session_id, profile_name, sql, que def get_history(cls, user_id, profile_name): history_list = cls.query_log_dao.get_history_by_user_profile(user_id, profile_name) return history_list + + @classmethod + def get_all_history(cls): + history_list = cls.query_log_dao.get_all_history() + return history_list diff --git a/application/nlq/data_access/dynamo_query_log.py b/application/nlq/data_access/dynamo_query_log.py index 69ef9c9..b543af2 100644 --- a/application/nlq/data_access/dynamo_query_log.py +++ b/application/nlq/data_access/dynamo_query_log.py @@ -132,7 +132,7 @@ def get_history_by_user_profile(self, user_id, profile_name): # DynamoDB might not return all items in a single response if the data set is large while 'LastEvaluatedKey' in response: response = self.table.scan( - FilterExpression=Key('user_id').eq(user_id) & Key('profile_name').eq(profile_name), + FilterExpression=Key('user_id').eq(user_id) & Key('profile_name').eq(profile_name) & Key('log_type').eq("chat_history"), ExclusiveStartKey=response['LastEvaluatedKey'] ) items.extend(response['Items']) @@ -150,4 +150,34 @@ def get_history_by_user_profile(self, user_id, profile_name): err.response["Error"]["Code"], err.response["Error"]["Message"], ) + return [] + + def get_all_history(self): + try: + # First, we need to scan the table to find all items for the user and profile + response = self.table.scan( + FilterExpression=Key('log_type').eq("chat_history") + ) + + items = response['Items'] + + # DynamoDB might not return all items in a single response if the data set is large + while 'LastEvaluatedKey' in response: + response = self.table.scan( + FilterExpression=Key('log_type').eq("chat_history"), + 
ExclusiveStartKey=response['LastEvaluatedKey'] + ) + items.extend(response['Items']) + + # Sort the items by time_str to get them in chronological order + sorted_items = sorted(items, key=lambda x: x['time_str']) + + return sorted_items + + except ClientError as err: + logger.error( + "Couldn't get history Here's why: %s: %s", + err.response["Error"]["Code"], + err.response["Error"]["Message"], + ) return [] \ No newline at end of file diff --git a/application/utils/tool.py b/application/utils/tool.py index 9003484..310d5ee 100644 --- a/application/utils/tool.py +++ b/application/utils/tool.py @@ -46,3 +46,18 @@ def change_class_to_str(result): except Exception as e: logger.error(f"Error in changing class to string: {e}") return "" + + +def get_window_history(user_query_history): + try: + history_list = [] + for item in user_query_history: + if item.type == "human": + history_list.append("user:" + str(item.content)) + else: + history_list.append("assistant:" + str(item.content["query_rewrite"])) + logger.info(f"history_list: {history_list}") + return history_list + except Exception as e: + logger.error(f"Error in getting window history: {e}") + return [] From e2581e9e3156aa7670ed5edf9546ef8958d60adf Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 1 Aug 2024 19:31:25 +0800 Subject: [PATCH 086/130] change share data for chatbot --- application/api/service.py | 6 ++---- application/main.py | 7 +++---- application/utils/tool.py | 14 ++++++++++++++ 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/application/api/service.py b/application/api/service.py index 8b89892..139a174 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -4,7 +4,6 @@ from dotenv import load_dotenv import logging -from main import shared_data from nlq.business.connection import ConnectionManagement from nlq.business.nlq_chain import NLQChain from nlq.business.profile import ProfileManagement @@ -21,7 +20,7 @@ from utils.env_var import opensearch_info from utils.text_search import normal_text_search, agent_text_search from utils.tool import generate_log_id, get_current_time, get_generated_sql_explain, get_generated_sql, \ - change_class_to_str, get_window_history + change_class_to_str, get_window_history, get_share_data from .schemas import Question, Answer, Example, Option, SQLSearchResult, AgentSearchResult, KnowledgeSearchResult, \ TaskSQLSearchResult, ChartEntity, AskReplayResult, ChatHistory, Message, HistoryMessage from .exception_handler import BizException @@ -408,8 +407,6 @@ async def ask_websocket(websocket: WebSocket, question: Question): log_info = "" query_rewrite = "" - shared_data["log_id"] = log_id - all_profiles = ProfileManagement.get_all_profiles_with_info() database_profile = all_profiles[selected_profile] @@ -441,6 +438,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): entity_slot = [] user_query_history = [] + shared_data = get_share_data() if session_id in shared_data: user_query_history = shared_data[session_id] query_rewrite_result = {"intent": "original_problem", "query": search_box} diff --git a/application/main.py b/application/main.py index 03b7913..bd3bf88 100644 --- a/application/main.py +++ b/application/main.py @@ -1,5 +1,4 @@ import json -from multiprocessing import Manager from fastapi import FastAPI, status from fastapi.staticfiles import StaticFiles @@ -10,12 +9,12 @@ from api import service from api.schemas import Option, Message from nlq.business.log_store import LogManagement +from utils.tool import set_share_data 
MAX_CHAT_WINDOW_SIZE = 10 * 2 app = FastAPI(title='GenBI') -manager = Manager() -shared_data = manager.dict() # shared data between processes + app.add_middleware( CORSMiddleware, @@ -66,4 +65,4 @@ def set_history_in_share(): for key, value in chat_history_session.items(): value = value[-MAX_CHAT_WINDOW_SIZE:] - shared_data[key] = value + set_share_data(key, value) \ No newline at end of file diff --git a/application/utils/tool.py b/application/utils/tool.py index 310d5ee..385d960 100644 --- a/application/utils/tool.py +++ b/application/utils/tool.py @@ -3,10 +3,14 @@ import time import random from datetime import datetime +from multiprocessing import Manager logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +manager = Manager() +shared_data = manager.dict() # shared data between processes + def get_generated_sql(generated_sql_response): sql = "" @@ -61,3 +65,13 @@ def get_window_history(user_query_history): except Exception as e: logger.error(f"Error in getting window history: {e}") return [] + + +def get_share_data(): + global shared_data + return shared_data + + +def set_share_data(session_id, value): + global shared_data + shared_data[session_id] = value From fa7ebf91be4a830060fc9022571798e31a8ae4a6 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 1 Aug 2024 19:41:54 +0800 Subject: [PATCH 087/130] change share data for chatbot --- application/api/service.py | 11 ++++++----- application/utils/llm.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/application/api/service.py b/application/api/service.py index 139a174..8154d74 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -447,14 +447,15 @@ async def ask_websocket(websocket: WebSocket, question: Question): if len(user_query_history) > 0: user_query_history = user_query_history[-context_window_select:] user_query_history = get_window_history(user_query_history) - logger.info("The Chat history is {history}".format(history="\n".join(user_query_history))) - query_rewrite_result = get_query_rewrite(model_type, search_box, prompt_map, user_query_history) - logger.info( + logger.info("The Chat history is {history}".format(history="\n".join(user_query_history))) + query_rewrite_result = get_query_rewrite(model_type, search_box, prompt_map, user_query_history) + logger.info( "The query_rewrite_result is {query_rewrite_result}".format(query_rewrite_result=query_rewrite_result)) - search_box = query_rewrite_result.get("query") + query_rewrite = query_rewrite_result.get("query") + else: + query_rewrite = search_box query_rewrite_intent = query_rewrite_result.get("intent") - query_rewrite = query_rewrite_result.get("query") if "ask_in_reply" == query_rewrite_intent: ask_replay_flag = True diff --git a/application/utils/llm.py b/application/utils/llm.py index 4a54cf5..9eb542c 100644 --- a/application/utils/llm.py +++ b/application/utils/llm.py @@ -464,7 +464,7 @@ def get_query_intent(model_id, search_box, prompt_map): def get_query_rewrite(model_id, search_box, prompt_map, chat_history): - query_rewrite = {"original_problem": search_box} + query_rewrite = {"intent": "original_problem", "query": search_box} history_query = "\n".join(chat_history) try: intent_endpoint = os.getenv("SAGEMAKER_ENDPOINT_INTENT") From 75e690b820e800fc18b5f078bb6bac6b8640b352 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 1 Aug 2024 20:01:34 +0800 Subject: [PATCH 088/130] change share data for chatbot --- application/main.py | 12 +++++++----- 
application/utils/tool.py | 9 +++------ 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/application/main.py b/application/main.py index bd3bf88..e54aa53 100644 --- a/application/main.py +++ b/application/main.py @@ -1,4 +1,5 @@ import json +import logging from fastapi import FastAPI, status from fastapi.staticfiles import StaticFiles @@ -9,13 +10,11 @@ from api import service from api.schemas import Option, Message from nlq.business.log_store import LogManagement -from utils.tool import set_share_data +from utils.tool import set_share_data, get_share_data MAX_CHAT_WINDOW_SIZE = 10 * 2 app = FastAPI(title='GenBI') - - app.add_middleware( CORSMiddleware, allow_origins=['*'], @@ -49,7 +48,8 @@ def option(): @app.on_event("startup") def set_history_in_share(): - global shared_data + logging.info("Setting history in share data") + share_data = get_share_data() history_list = LogManagement.get_all_history() chat_history_session = {} for item in history_list: @@ -65,4 +65,6 @@ def set_history_in_share(): for key, value in chat_history_session.items(): value = value[-MAX_CHAT_WINDOW_SIZE:] - set_share_data(key, value) \ No newline at end of file + set_share_data(share_data, key, value) + logging.info("Setting history in share data done") + logging.info(share_data) \ No newline at end of file diff --git a/application/utils/tool.py b/application/utils/tool.py index 385d960..7433189 100644 --- a/application/utils/tool.py +++ b/application/utils/tool.py @@ -8,9 +8,6 @@ logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -manager = Manager() -shared_data = manager.dict() # shared data between processes - def get_generated_sql(generated_sql_response): sql = "" @@ -68,10 +65,10 @@ def get_window_history(user_query_history): def get_share_data(): - global shared_data + with Manager() as manager: + shared_data = manager.dict() return shared_data -def set_share_data(session_id, value): - global shared_data +def set_share_data(shared_data, session_id, value): shared_data[session_id] = value From e9b12ee006dfbc23bcf630dde47e3d72712afcab Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 1 Aug 2024 20:06:23 +0800 Subject: [PATCH 089/130] change share data for chatbot --- application/utils/tool.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/application/utils/tool.py b/application/utils/tool.py index 7433189..add5be5 100644 --- a/application/utils/tool.py +++ b/application/utils/tool.py @@ -8,6 +8,9 @@ logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +manager = Manager() +shared_data = manager.dict() + def get_generated_sql(generated_sql_response): sql = "" @@ -64,11 +67,9 @@ def get_window_history(user_query_history): return [] -def get_share_data(): - with Manager() as manager: - shared_data = manager.dict() - return shared_data +def set_share_data(session_id, value): + shared_data[session_id] = value -def set_share_data(shared_data, session_id, value): - shared_data[session_id] = value +def get_share_data(session_id): + return shared_data.get(session_id) From 8aa37084ddbb035062bc03aae71303fae89460a8 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 1 Aug 2024 20:07:47 +0800 Subject: [PATCH 090/130] change share data for chatbot --- application/main.py | 6 ++---- application/utils/tool.py | 1 + 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/application/main.py b/application/main.py index 
e54aa53..2e02351 100644 --- a/application/main.py +++ b/application/main.py @@ -49,7 +49,6 @@ def option(): @app.on_event("startup") def set_history_in_share(): logging.info("Setting history in share data") - share_data = get_share_data() history_list = LogManagement.get_all_history() chat_history_session = {} for item in history_list: @@ -65,6 +64,5 @@ def set_history_in_share(): for key, value in chat_history_session.items(): value = value[-MAX_CHAT_WINDOW_SIZE:] - set_share_data(share_data, key, value) - logging.info("Setting history in share data done") - logging.info(share_data) \ No newline at end of file + set_share_data(key, value) + logging.info("Setting history in share data done") \ No newline at end of file diff --git a/application/utils/tool.py b/application/utils/tool.py index add5be5..97d90ae 100644 --- a/application/utils/tool.py +++ b/application/utils/tool.py @@ -69,6 +69,7 @@ def get_window_history(user_query_history): def set_share_data(session_id, value): shared_data[session_id] = value + logger.info("Set share data total session is : %s", str(len(shared_data))) def get_share_data(session_id): From ed6cf9d1e3048fb00ee5e13002149b1486d8b511 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 1 Aug 2024 20:15:26 +0800 Subject: [PATCH 091/130] change share data for chatbot --- application/api/service.py | 4 +--- application/utils/tool.py | 5 ++++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/application/api/service.py b/application/api/service.py index 8154d74..cfd6df1 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -438,9 +438,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): entity_slot = [] user_query_history = [] - shared_data = get_share_data() - if session_id in shared_data: - user_query_history = shared_data[session_id] + user_query_history = get_share_data(session_id) query_rewrite_result = {"intent": "original_problem", "query": search_box} if context_window > 0: context_window_select = context_window * 2 diff --git a/application/utils/tool.py b/application/utils/tool.py index 97d90ae..84171bf 100644 --- a/application/utils/tool.py +++ b/application/utils/tool.py @@ -73,4 +73,7 @@ def set_share_data(session_id, value): def get_share_data(session_id): - return shared_data.get(session_id) + if session_id in shared_data: + return shared_data.get(session_id) + else: + return [] From e5a2c926cfb6e5f378fb1215f079ce228e6a3959 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 1 Aug 2024 20:42:20 +0800 Subject: [PATCH 092/130] change share data for chatbot --- application/api/service.py | 18 ++++++++++-------- application/utils/tool.py | 11 +++++++++++ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/application/api/service.py b/application/api/service.py index cfd6df1..6f746d4 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -20,7 +20,7 @@ from utils.env_var import opensearch_info from utils.text_search import normal_text_search, agent_text_search from utils.tool import generate_log_id, get_current_time, get_generated_sql_explain, get_generated_sql, \ - change_class_to_str, get_window_history, get_share_data + change_class_to_str, get_window_history, get_share_data, update_share_data from .schemas import Question, Answer, Example, Option, SQLSearchResult, AgentSearchResult, KnowledgeSearchResult, \ TaskSQLSearchResult, ChartEntity, AskReplayResult, ChatHistory, Message, HistoryMessage from .exception_handler import BizException @@ -438,21 +438,20 @@ async def 
ask_websocket(websocket: WebSocket, question: Question): entity_slot = [] user_query_history = [] - user_query_history = get_share_data(session_id) + original_user_query_history = get_share_data(session_id) query_rewrite_result = {"intent": "original_problem", "query": search_box} if context_window > 0: context_window_select = context_window * 2 if len(user_query_history) > 0: - user_query_history = user_query_history[-context_window_select:] + user_query_history = original_user_query_history[-context_window_select:] user_query_history = get_window_history(user_query_history) logger.info("The Chat history is {history}".format(history="\n".join(user_query_history))) query_rewrite_result = get_query_rewrite(model_type, search_box, prompt_map, user_query_history) logger.info( "The query_rewrite_result is {query_rewrite_result}".format(query_rewrite_result=query_rewrite_result)) - - query_rewrite = query_rewrite_result.get("query") else: - query_rewrite = search_box + query_rewrite_result = get_query_rewrite(model_type, search_box, prompt_map, []) + query_rewrite = query_rewrite_result.get("query") query_rewrite_intent = query_rewrite_result.get("intent") if "ask_in_reply" == query_rewrite_intent: ask_replay_flag = True @@ -464,6 +463,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[], ask_rewrite_result=ask_result) + update_share_data(session_id, search_box, answer) ask_answer_info = change_class_to_str(answer) LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql="", query=search_box, @@ -502,6 +502,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): answer = Answer(query=search_box, query_rewrite=query_rewrite, query_intent="reject_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[], ask_rewrite_result=ask_result) + update_share_data(session_id, search_box, answer) reject_answer_info = change_class_to_str(answer) LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql="", query=search_box, @@ -521,6 +522,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=[], ask_rewrite_result=ask_result) + update_share_data(session_id, search_box, answer) knowledge_answer_info = change_class_to_str(answer) LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql="", query=search_box, @@ -609,7 +611,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): answer = Answer(query=search_box, query_rewrite=query_rewrite, query_intent="normal_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=generate_suggested_question_list, ask_rewrite_result=ask_result) - + update_share_data(session_id, search_box, answer) intent_answer_info = change_class_to_str(answer) LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql=sql_search_result.sql, @@ -675,7 +677,7 @@ async def ask_websocket(websocket: WebSocket, question: 
Question): answer = Answer(query=search_box, query_rewrite=query_rewrite, query_intent="agent_search", knowledge_search_result=knowledge_search_result, sql_search_result=sql_search_result, agent_search_result=agent_search_response, suggested_question=generate_suggested_question_list, ask_rewrite_result=ask_result) - + update_share_data(session_id, search_box, answer) agent_answer_info = change_class_to_str(answer) LogManagement.add_log_to_database(log_id=log_id, user_id=user_id, session_id=session_id, profile_name=selected_profile, sql="", diff --git a/application/utils/tool.py b/application/utils/tool.py index 84171bf..188145c 100644 --- a/application/utils/tool.py +++ b/application/utils/tool.py @@ -5,6 +5,8 @@ from datetime import datetime from multiprocessing import Manager +from api.schemas import Message + logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -77,3 +79,12 @@ def get_share_data(session_id): return shared_data.get(session_id) else: return [] + + +def update_share_data(session_id, search_box, answer): + if session_id not in shared_data: + shared_data[session_id] = [] + human_message = Message(type="human", content=search_box) + bot_message = Message(type="AI", content=answer) + shared_data[session_id].append(human_message) + shared_data[session_id].append(bot_message) \ No newline at end of file From 1b277b9b27f5d24dbfccc728feb0ac535087237a Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 1 Aug 2024 21:53:14 +0800 Subject: [PATCH 093/130] change share data for chatbot --- application/api/service.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/application/api/service.py b/application/api/service.py index 6f746d4..b6c1f18 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -445,11 +445,13 @@ async def ask_websocket(websocket: WebSocket, question: Question): if len(user_query_history) > 0: user_query_history = original_user_query_history[-context_window_select:] user_query_history = get_window_history(user_query_history) + user_query_history.append("user:" + search_box) logger.info("The Chat history is {history}".format(history="\n".join(user_query_history))) query_rewrite_result = get_query_rewrite(model_type, search_box, prompt_map, user_query_history) logger.info( "The query_rewrite_result is {query_rewrite_result}".format(query_rewrite_result=query_rewrite_result)) else: + user_query_history.append("user:" + search_box) query_rewrite_result = get_query_rewrite(model_type, search_box, prompt_map, []) query_rewrite = query_rewrite_result.get("query") query_rewrite_intent = query_rewrite_result.get("intent") From b70aeadc36e157ae906ab9fe5d5ede61654fa60f Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 1 Aug 2024 22:03:15 +0800 Subject: [PATCH 094/130] change share data for chatbot --- application/api/service.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/application/api/service.py b/application/api/service.py index b6c1f18..1a5c4c0 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -439,6 +439,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): user_query_history = [] original_user_query_history = get_share_data(session_id) + logger.info("The original_user_query_history is {original_user_query_history}".format(original_user_query_history=original_user_query_history)) query_rewrite_result = {"intent": "original_problem", "query": search_box} if context_window > 0: context_window_select = 
context_window * 2 @@ -452,7 +453,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): "The query_rewrite_result is {query_rewrite_result}".format(query_rewrite_result=query_rewrite_result)) else: user_query_history.append("user:" + search_box) - query_rewrite_result = get_query_rewrite(model_type, search_box, prompt_map, []) + query_rewrite_result = get_query_rewrite(model_type, search_box, prompt_map, user_query_history) query_rewrite = query_rewrite_result.get("query") query_rewrite_intent = query_rewrite_result.get("intent") if "ask_in_reply" == query_rewrite_intent: From 10acd3f48172e04765e648f4cbd9ec59719b7356 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 1 Aug 2024 22:28:57 +0800 Subject: [PATCH 095/130] change share data for chatbot --- application/utils/tool.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/application/utils/tool.py b/application/utils/tool.py index 188145c..f4d5893 100644 --- a/application/utils/tool.py +++ b/application/utils/tool.py @@ -87,4 +87,5 @@ def update_share_data(session_id, search_box, answer): human_message = Message(type="human", content=search_box) bot_message = Message(type="AI", content=answer) shared_data[session_id].append(human_message) - shared_data[session_id].append(bot_message) \ No newline at end of file + shared_data[session_id].append(bot_message) + logger.info("Update share data is : %s", shared_data) \ No newline at end of file From 9f3056565b739fa4d7c3d100d97acb2df57e5286 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 07:27:41 +0800 Subject: [PATCH 096/130] add log --- application/utils/tool.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/application/utils/tool.py b/application/utils/tool.py index f4d5893..4d1d844 100644 --- a/application/utils/tool.py +++ b/application/utils/tool.py @@ -84,8 +84,10 @@ def get_share_data(session_id): def update_share_data(session_id, search_box, answer): if session_id not in shared_data: shared_data[session_id] = [] + logger.info("session_id not in shared_data") + logger.info("Update session_id is : %s", session_id) human_message = Message(type="human", content=search_box) bot_message = Message(type="AI", content=answer) shared_data[session_id].append(human_message) shared_data[session_id].append(bot_message) - logger.info("Update share data is : %s", shared_data) \ No newline at end of file + logger.info("Update session is %s, share data is : %s", session_id, shared_data[session_id]) \ No newline at end of file From e657b66de117cf477ad36ab81bd1c9e3c690b7f6 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 07:36:17 +0800 Subject: [PATCH 097/130] add log --- application/utils/tool.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/application/utils/tool.py b/application/utils/tool.py index 4d1d844..456d39d 100644 --- a/application/utils/tool.py +++ b/application/utils/tool.py @@ -82,12 +82,22 @@ def get_share_data(session_id): def update_share_data(session_id, search_box, answer): + chat_list = [] if session_id not in shared_data: shared_data[session_id] = [] logger.info("session_id not in shared_data") logger.info("Update session_id is : %s", session_id) - human_message = Message(type="human", content=search_box) - bot_message = Message(type="AI", content=answer) - shared_data[session_id].append(human_message) - shared_data[session_id].append(bot_message) - logger.info("Update session is %s, share data is : %s", session_id, shared_data[session_id]) \ No newline at 
end of file + human_message = Message(type="human", content=search_box) + bot_message = Message(type="AI", content=answer) + chat_list.append(human_message) + chat_list.append(bot_message) + set_share_data(session_id, chat_list) + logger.info("not have session is %s, share data is : %s", session_id, shared_data[session_id]) + else: + chat_list = shared_data[session_id] + human_message = Message(type="human", content=search_box) + bot_message = Message(type="AI", content=answer) + chat_list.append(human_message) + chat_list.append(bot_message) + set_share_data(session_id, chat_list) + logger.info("have session is %s, share data is : %s", session_id, shared_data[session_id]) \ No newline at end of file From d9856ca8520a51d0413594f2f9cb3aadb314eb41 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 07:49:16 +0800 Subject: [PATCH 098/130] fix some issue --- application/api/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/application/api/service.py b/application/api/service.py index 1a5c4c0..c3aa1a1 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -443,7 +443,7 @@ async def ask_websocket(websocket: WebSocket, question: Question): query_rewrite_result = {"intent": "original_problem", "query": search_box} if context_window > 0: context_window_select = context_window * 2 - if len(user_query_history) > 0: + if len(original_user_query_history) > 0: user_query_history = original_user_query_history[-context_window_select:] user_query_history = get_window_history(user_query_history) user_query_history.append("user:" + search_box) From e554bd45bf52eae9a2808ed2a1b24bf8d9e0d033 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 08:20:53 +0800 Subject: [PATCH 099/130] fix some issue --- application/utils/tool.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/application/utils/tool.py b/application/utils/tool.py index 456d39d..03a17ac 100644 --- a/application/utils/tool.py +++ b/application/utils/tool.py @@ -61,7 +61,7 @@ def get_window_history(user_query_history): if item.type == "human": history_list.append("user:" + str(item.content)) else: - history_list.append("assistant:" + str(item.content["query_rewrite"])) + history_list.append("assistant:" + str(item.content.query_rewrite)) logger.info(f"history_list: {history_list}") return history_list except Exception as e: @@ -92,7 +92,7 @@ def update_share_data(session_id, search_box, answer): chat_list.append(human_message) chat_list.append(bot_message) set_share_data(session_id, chat_list) - logger.info("not have session is %s, share data is : %s", session_id, shared_data[session_id]) + logger.info("not have session is %s, share data length is : %s", session_id, len(shared_data[session_id])) else: chat_list = shared_data[session_id] human_message = Message(type="human", content=search_box) @@ -100,4 +100,4 @@ def update_share_data(session_id, search_box, answer): chat_list.append(human_message) chat_list.append(bot_message) set_share_data(session_id, chat_list) - logger.info("have session is %s, share data is : %s", session_id, shared_data[session_id]) \ No newline at end of file + logger.info("have session is %s, share data is : %s", session_id, len(shared_data[session_id])) \ No newline at end of file From fb80d9a80ee1517aba57afcad67c6a01773b4a35 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 09:01:59 +0800 Subject: [PATCH 100/130] update downvote main page --- application/api/service.py | 6 ++++-- 
.../1_\360\237\214\215_Generative_BI_Playground.py" | 11 +++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/application/api/service.py b/application/api/service.py index c3aa1a1..c945700 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -716,7 +716,8 @@ def user_feedback_downvote(data_profiles: str, user_id: str, session_id: str, qu sql=query_answer, query=query, intent="normal_search_user_downvote", log_info="", - time_str=current_time) + time_str=current_time, + log_type="feedback_downvote") elif query_intent == "agent_search": log_id = generate_log_id() current_time = get_current_time() @@ -725,7 +726,8 @@ def user_feedback_downvote(data_profiles: str, user_id: str, session_id: str, qu sql=query_answer, query=query, intent="agent_search_user_downvote", log_info="", - time_str=current_time) + time_str=current_time, + log_type="feedback_downvote") return True except Exception as e: return False diff --git "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" index 824494a..1848a8b 100644 --- "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" +++ "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" @@ -7,6 +7,7 @@ import logging import random +from api.service import user_feedback_downvote from nlq.business.connection import ConnectionManagement from nlq.business.nlq_chain import NLQChain from nlq.business.profile import ProfileManagement @@ -51,6 +52,16 @@ def upvote_agent_clicked(question, comment): logger.info(f'up voted "{question}" with sql "{comment}"') +def downvote_clicked(question, comment): + current_profile = st.session_state.current_profile + user_id = "admin" + session_id = "-1" + query = question + query_intent = "normal_search" + query_answer = str(comment) + user_feedback_downvote(current_profile, user_id, session_id, query, query_intent, query_answer) + + def clean_st_history(selected_profile): st.session_state.messages[selected_profile] = [] st.session_state.query_rewrite_history[selected_profile] = [] From ac494aeaa9cad811cc0b60f06e06f827546ac187 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 09:11:58 +0800 Subject: [PATCH 101/130] update downvote main page --- application/utils/opensearch.py | 39 ++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/application/utils/opensearch.py b/application/utils/opensearch.py index 229bff8..1d8f7c4 100644 --- a/application/utils/opensearch.py +++ b/application/utils/opensearch.py @@ -3,7 +3,7 @@ from opensearchpy.helpers import bulk import logging from utils.llm import create_vector_embedding_with_bedrock, create_vector_embedding_with_sagemaker -from utils.env_var import opensearch_info, SAGEMAKER_ENDPOINT_EMBEDDING +from utils.env_var import opensearch_info, SAGEMAKER_ENDPOINT_EMBEDDING, AOS_INDEX_NER logger = logging.getLogger(__name__) @@ -119,6 +119,29 @@ def create_index_mapping(opensearch_client, index_name, dimension): return bool(response['acknowledged']) +def check_field_exists(opensearch_client, index_name, field_name): + """ + Check if a field exists in the specified index + :param opensearch_client: OpenSearch client + :param index_name: Name of the index + :param field_name: Name of the field to check + :return: True if the field exists, False otherwise + """ + try: + # Get the mapping for the index + mapping = opensearch_client.indices.get_mapping(index=index_name) + + # Traverse the mapping to 
check if the field exists + if index_name in mapping: + properties = mapping[index_name]['mappings']['properties'] + if field_name in properties: + return True + except Exception as e: + logger.error(f"Error checking field {field_name}: {e}") + + return False + + def delete_opensearch_index(opensearch_client, index_name): """ Delete index @@ -145,7 +168,8 @@ def get_retrieve_opensearch(opensearch_info, query, search_type, selected_profil index_name = opensearch_info['agent_index'] if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "": - records_with_embedding = create_vector_embedding_with_sagemaker(SAGEMAKER_ENDPOINT_EMBEDDING, query, index_name=index_name) + records_with_embedding = create_vector_embedding_with_sagemaker(SAGEMAKER_ENDPOINT_EMBEDDING, query, + index_name=index_name) else: records_with_embedding = create_vector_embedding_with_bedrock(query, index_name=index_name) retrieve_result = retrieve_results_from_opensearch( @@ -169,7 +193,8 @@ def get_retrieve_opensearch(opensearch_info, query, search_type, selected_profil def retrieve_results_from_opensearch(index_name, region_name, domain, opensearch_user, opensearch_password, query_embedding, top_k=3, host='', port=443, profile_name=None): - opensearch_client = get_opensearch_cluster_client(domain, host, port, opensearch_user, opensearch_password, region_name) + opensearch_client = get_opensearch_cluster_client(domain, host, port, opensearch_user, opensearch_password, + region_name) search_query = { "size": top_k, # Adjust the size as needed to retrieve more or fewer results "query": { @@ -205,8 +230,8 @@ def retrieve_results_from_opensearch(index_name, region_name, domain, opensearch def upload_results_to_opensearch(region_name, domain, opensearch_user, opensearch_password, index_name, query, sql, host='', port=443): - - opensearch_client = get_opensearch_cluster_client(domain, host, port, opensearch_user, opensearch_password, region_name) + opensearch_client = get_opensearch_cluster_client(domain, host, port, opensearch_user, opensearch_password, + region_name) # Vector embedding using Amazon Bedrock Titan text embedding logger.info(f"Creating embeddings for records") @@ -252,6 +277,10 @@ def opensearch_index_init(): if success: logger.info( "Creating OpenSearch index mapping, index is {index_name}".format(index_name=index_name)) + if index_name == AOS_INDEX_NER: + check_flag = check_field_exists(opensearch_client, index_name, "ner_table_info") + logger.info(f"check index flag: {check_flag}") + success = create_index_mapping(opensearch_client, index_name, dimension) logger.info(f"OpenSearch Index mapping created") else: From e017e6c3d8dfd4fd9d9757eeec8567fe8f8489e8 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 09:44:07 +0800 Subject: [PATCH 102/130] update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 92ecc12..835d7b0 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ The deployment guide here is CDK only. For manual deployment or detailed guide, ## Introduction -A NLQ(Natural Language Query) demo using Amazon Bedrock, Amazon OpenSearch with RAG technique. +A Generative BI demo using Amazon Bedrock, Amazon OpenSearch with RAG technique. 
![Screenshot](./assets/aws_architecture.png) *Reference Architecture on AWS* From 58f4a1f553d3570ce3d0463fe75201e18bb76709 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 10:56:40 +0800 Subject: [PATCH 103/130] update README.md --- application/utils/opensearch.py | 56 ++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/application/utils/opensearch.py b/application/utils/opensearch.py index 1d8f7c4..4b3aed6 100644 --- a/application/utils/opensearch.py +++ b/application/utils/opensearch.py @@ -131,6 +131,7 @@ def check_field_exists(opensearch_client, index_name, field_name): # Get the mapping for the index mapping = opensearch_client.indices.get_mapping(index=index_name) + logger.info(mapping) # Traverse the mapping to check if the field exists if index_name in mapping: properties = mapping[index_name]['mappings']['properties'] @@ -141,6 +142,53 @@ def check_field_exists(opensearch_client, index_name, field_name): return False +def update_index_mapping(opensearch_client, index_name, dimension): + """ + Create index mapping + :param opensearch_client: + :param index_name: + :param dimension: + :return: + """ + response = opensearch_client.indices.put_mapping( + index=index_name, + body={ + "properties": { + "vector_field": { + "type": "knn_vector", + "dimension": dimension + }, + "text": { + "type": "keyword" + }, + "profile": { + "type": "keyword" + }, + "entity_type": { + "type": "keyword" + }, + "entity_same_count": { + "type": "integer" + }, + "entity_table_info": { + "type": "nested", + "properties": { + "table_name": { + "type": "keyword" + }, + "column_name": { + "type": "keyword" + }, + "value": { + "type": "text" + } + } + } + } + } + ) + return bool(response['acknowledged']) + def delete_opensearch_index(opensearch_client, index_name): """ @@ -277,14 +325,14 @@ def opensearch_index_init(): if success: logger.info( "Creating OpenSearch index mapping, index is {index_name}".format(index_name=index_name)) - if index_name == AOS_INDEX_NER: - check_flag = check_field_exists(opensearch_client, index_name, "ner_table_info") - logger.info(f"check index flag: {check_flag}") - success = create_index_mapping(opensearch_client, index_name, dimension) logger.info(f"OpenSearch Index mapping created") else: index_create_success = False + else: + if index_name == AOS_INDEX_NER: + check_flag = check_field_exists(opensearch_client, index_name, "ner_table_info") + logger.info(f"check index flag: {check_flag}") return index_create_success except Exception as e: logger.error("create index error") From ed07326fd7453d666d57604f873f4828b3e02641 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 11:04:45 +0800 Subject: [PATCH 104/130] update README.md --- application/utils/opensearch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/application/utils/opensearch.py b/application/utils/opensearch.py index 4b3aed6..1e72342 100644 --- a/application/utils/opensearch.py +++ b/application/utils/opensearch.py @@ -167,7 +167,7 @@ def update_index_mapping(opensearch_client, index_name, dimension): "entity_type": { "type": "keyword" }, - "entity_same_count": { + "entity_count": { "type": "integer" }, "entity_table_info": { From 9630b0fa3d0004a727c48ed88eb73e63d0081eff Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 12:14:55 +0800 Subject: [PATCH 105/130] update ner info --- application/nlq/business/vector_store.py | 9 ++++++-- application/nlq/data_access/opensearch.py | 8 +++++-- 
.../6_\360\237\223\232_Index_Management.py" | 2 +- .../7_\360\237\223\232_Entity_Management.py" | 23 +++++++++++++++++-- application/utils/opensearch.py | 21 +++++++++++++++++ 5 files changed, 56 insertions(+), 7 deletions(-) diff --git a/application/nlq/business/vector_store.py b/application/nlq/business/vector_store.py index b0f9577..644ee06 100644 --- a/application/nlq/business/vector_store.py +++ b/application/nlq/business/vector_store.py @@ -86,7 +86,11 @@ def add_sample(cls, profile_name, question, answer): logger.info('Sample added') @classmethod - def add_entity_sample(cls, profile_name, entity, comment): + def add_entity_sample(cls, profile_name, entity, comment, entity_type="metrics", entity_info_dict=None): + if entity_type == "metrics" or entity_info_dict is None: + entity_table_info = [] + else: + entity_table_info = [entity_info_dict] logger.info(f'add sample entity: {entity} to profile {profile_name}') if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "": embedding = cls.create_vector_embedding_with_sagemaker(entity) @@ -95,7 +99,8 @@ def add_entity_sample(cls, profile_name, entity, comment): has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding) if has_same_sample: logger.info(f'delete sample sample entity: {entity} to profile {profile_name}') - if cls.opensearch_dao.add_entity_sample(opensearch_info['ner_index'], profile_name, entity, comment, embedding): + if cls.opensearch_dao.add_entity_sample(opensearch_info['ner_index'], profile_name, entity, comment, embedding, + entity_type, entity_table_info): logger.info('Sample added') @classmethod diff --git a/application/nlq/data_access/opensearch.py b/application/nlq/data_access/opensearch.py index 69cbe08..9481214 100644 --- a/application/nlq/data_access/opensearch.py +++ b/application/nlq/data_access/opensearch.py @@ -162,13 +162,17 @@ def add_sample(self, index_name, profile_name, question, answer, embedding): success, failed = put_bulk_in_opensearch([record], self.opensearch_client) return success == 1 - def add_entity_sample(self, index_name, profile_name, entity, comment, embedding): + def add_entity_sample(self, index_name, profile_name, entity, comment, embedding, entity_type="", entity_table_info=[]): + entity_count = len(entity_table_info) record = { '_index': index_name, 'entity': entity, 'comment': comment, 'profile': profile_name, - 'vector_field': embedding + 'vector_field': embedding, + 'entity_type': entity_type, + 'entity_count': entity_count, + 'entity_table_info': entity_table_info } success, failed = put_bulk_in_opensearch([record], self.opensearch_client) diff --git "a/application/pages/6_\360\237\223\232_Index_Management.py" "b/application/pages/6_\360\237\223\232_Index_Management.py" index c527410..4fdaa85 100644 --- "a/application/pages/6_\360\237\223\232_Index_Management.py" +++ "b/application/pages/6_\360\237\223\232_Index_Management.py" @@ -124,7 +124,7 @@ def main(): for j, item in enumerate(each_upload_data.itertuples(), 1): question = str(item.question) sql = str(item.sql) - VectorStore.add_entity_sample(current_profile, question, sql) + VectorStore.add_sample(current_profile, question, sql) progress = (j * 1.0) / total_rows progress_bar.progress(progress, text=progress_text) progress_bar.empty() diff --git "a/application/pages/7_\360\237\223\232_Entity_Management.py" "b/application/pages/7_\360\237\223\232_Entity_Management.py" index db1eef5..29d1766 100644 --- "a/application/pages/7_\360\237\223\232_Entity_Management.py" 
+++ "b/application/pages/7_\360\237\223\232_Entity_Management.py" @@ -62,8 +62,8 @@ def main(): index=None, placeholder="Please select data profile...", key='current_profile_name') - tab_view, tab_add, tab_search, batch_insert = st.tabs( - ['View Samples', 'Add New Sample', 'Sample Search', 'Batch Insert Samples']) + tab_view, tab_add, tab_dimension, tab_search, batch_insert = st.tabs( + ['View Entity Info', 'Add Metrics Entity', 'Add Dimension Entity', 'Entity Search', 'Batch Insert Entity']) if current_profile is not None: st.session_state['current_profile'] = current_profile with tab_view: @@ -91,6 +91,25 @@ def main(): st.rerun() else: st.error('please input valid question and answer') + with tab_dimension: + if current_profile is not None: + entity = st.text_input('Entity', key='index_question') + table = st.text_input('Table', key='index_table') + column = st.text_input('Column', key='index_column') + value = st.text_input('Dimension value', key='index_value') + if st.button('Submit', type='primary'): + if len(entity) > 0 and len(table) > 0 and len(column) > 0 and len(value) > 0: + entity_item_table_info = {} + entity_item_table_info["table_name"] = table + entity_item_table_info["column_name"] = column + entity_item_table_info["value"] = value + VectorStore.add_entity_sample(current_profile, entity, comment, "dimension", entity_item_table_info) + st.success('Sample added') + time.sleep(2) + st.rerun() + else: + st.error('please input valid question and answer') + with tab_search: if current_profile is not None: entity_search = st.text_input('Entity Search', key='index_entity_search') diff --git a/application/utils/opensearch.py b/application/utils/opensearch.py index 1e72342..90f02e7 100644 --- a/application/utils/opensearch.py +++ b/application/utils/opensearch.py @@ -112,6 +112,26 @@ def create_index_mapping(opensearch_client, index_name, dimension): }, "profile": { "type": "keyword" + }, + "entity_type": { + "type": "keyword" + }, + "entity_count": { + "type": "integer" + }, + "entity_table_info": { + "type": "nested", + "properties": { + "table_name": { + "type": "keyword" + }, + "column_name": { + "type": "keyword" + }, + "value": { + "type": "text" + } + } } } } @@ -333,6 +353,7 @@ def opensearch_index_init(): if index_name == AOS_INDEX_NER: check_flag = check_field_exists(opensearch_client, index_name, "ner_table_info") logger.info(f"check index flag: {check_flag}") + update_index_mapping(opensearch_client, index_name, dimension) return index_create_success except Exception as e: logger.error("create index error") From 93b5d1a280148df34644394b197de44616dbebec Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 12:19:44 +0800 Subject: [PATCH 106/130] add fastapi --- application/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/application/requirements.txt b/application/requirements.txt index 228112f..138eea9 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -16,4 +16,5 @@ pandas==2.0.3 openpyxl starrocks==1.0.6 clickhouse-sqlalchemy==0.2.6 -sagemaker \ No newline at end of file +sagemaker +fastapi~=0.110.1 \ No newline at end of file From 166d506cf41263767ccd8dcb886acf7a85ccb54b Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 12:38:41 +0800 Subject: [PATCH 107/130] update prompt.py --- .../7_\360\237\223\232_Entity_Management.py" | 2 +- application/utils/prompt.py | 37 ++++++++++++++++++- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git 
"a/application/pages/7_\360\237\223\232_Entity_Management.py" "b/application/pages/7_\360\237\223\232_Entity_Management.py" index 29d1766..722d9bc 100644 --- "a/application/pages/7_\360\237\223\232_Entity_Management.py" +++ "b/application/pages/7_\360\237\223\232_Entity_Management.py" @@ -93,7 +93,7 @@ def main(): st.error('please input valid question and answer') with tab_dimension: if current_profile is not None: - entity = st.text_input('Entity', key='index_question') + entity = st.text_input('Entity', key='index_entity') table = st.text_input('Table', key='index_table') column = st.text_input('Column', key='index_column') value = st.text_input('Dimension value', key='index_value') diff --git a/application/utils/prompt.py b/application/utils/prompt.py index aaf16ad..0d507f3 100644 --- a/application/utils/prompt.py +++ b/application/utils/prompt.py @@ -30,7 +30,42 @@ AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3 = """You are a Amazon Redshift expert. Given an input question, first create a syntactically correct Redshift query to run, then look at the results of the query and return the answer to the input question.When generating SQL, do not add double quotes or single quotes around table names. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. -Never query for all columns from a table.""".format(top_k=TOP_K) +Never query for all columns from a table. +When generating SQL related to dates and times, please strictly use the Redshift SQL Functions listed in the following md tables contents in : + +| Function | Returns | +| --- | --- | +| + (Concatenation) operator | TIMESTAMP or TIMESTAMPZ | +| ADD_MONTHS | TIMESTAMP | +| AT TIME ZONE | TIMESTAMP or TIMESTAMPZ | +| CONVERT_TIMEZONE | TIMESTAMP | +| CURRENT_DATE | DATE | +| DATE_CMP | INTEGER | +| DATE_CMP_TIMESTAMP | INTEGER | +| DATE_CMP_TIMESTAMPTZ | INTEGER | +| DATE_PART_YEAR | INTEGER | +| DATEADD | TIMESTAMP or TIME or TIMETZ | +| DATEDIFF | BIGINT | +| DATE_PART | DOUBLE | +| DATE_TRUNC | TIMESTAMP | +| EXTRACT | INTEGER or DOUBLE | +| GETDATE | TIMESTAMP | +| INTERVAL_CMP | INTEGER | +| LAST_DAY | DATE | +| MONTHS_BETWEEN | FLOAT8 | +| NEXT_DAY | DATE | +| SYSDATE | TIMESTAMP | +| TIMEOFDAY | VARCHAR | +| TIMESTAMP_CMP | INTEGER | +| TIMESTAMP_CMP_DATE | INTEGER | +| TIMESTAMP_CMP_TIMESTAMPTZ | INTEGER | +| TIMESTAMPTZ_CMP | INTEGER | +| TIMESTAMPTZ_CMP_DATE | INTEGER | +| TIMESTAMPTZ_CMP_TIMESTAMP | INTEGER | +| TIMEZONE | TIMESTAMP or TIMESTAMPTZ | +| TO_TIMESTAMP | TIMESTAMPTZ | +| TRUNC | DATE | +""".format(top_k=TOP_K) SEARCH_INTENT_PROMPT_CLAUDE3 = """You are an intent classifier and entity extractor, and you need to perform intent classification and entity extraction on search queries. Background: I want to query data in the database, and you need to help me determine the user's relevant intent and extract the keywords from the query statement. Finally, return a JSON structure. 
From 9da914b39b81c2d88c89b4e4216bd36025f4ecd4 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 12:41:09 +0800 Subject: [PATCH 108/130] fix some error --- .../7_\360\237\223\232_Entity_Management.py" | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git "a/application/pages/7_\360\237\223\232_Entity_Management.py" "b/application/pages/7_\360\237\223\232_Entity_Management.py" index 722d9bc..c54738b 100644 --- "a/application/pages/7_\360\237\223\232_Entity_Management.py" +++ "b/application/pages/7_\360\237\223\232_Entity_Management.py" @@ -97,18 +97,18 @@ def main(): table = st.text_input('Table', key='index_table') column = st.text_input('Column', key='index_column') value = st.text_input('Dimension value', key='index_value') - if st.button('Submit', type='primary'): - if len(entity) > 0 and len(table) > 0 and len(column) > 0 and len(value) > 0: - entity_item_table_info = {} - entity_item_table_info["table_name"] = table - entity_item_table_info["column_name"] = column - entity_item_table_info["value"] = value - VectorStore.add_entity_sample(current_profile, entity, comment, "dimension", entity_item_table_info) - st.success('Sample added') - time.sleep(2) - st.rerun() - else: - st.error('please input valid question and answer') + if st.button('Submit', type='primary'): + if len(entity) > 0 and len(table) > 0 and len(column) > 0 and len(value) > 0: + entity_item_table_info = {} + entity_item_table_info["table_name"] = table + entity_item_table_info["column_name"] = column + entity_item_table_info["value"] = value + VectorStore.add_entity_sample(current_profile, entity, comment, "dimension", entity_item_table_info) + st.success('Sample added') + time.sleep(2) + st.rerun() + else: + st.error('please input valid question and answer') with tab_search: if current_profile is not None: From 1ec6a870f1dda43ea3822b0d61bbe5f1766686f6 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 13:03:15 +0800 Subject: [PATCH 109/130] fix some error --- "application/pages/7_\360\237\223\232_Entity_Management.py" | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git "a/application/pages/7_\360\237\223\232_Entity_Management.py" "b/application/pages/7_\360\237\223\232_Entity_Management.py" index c54738b..6bb5f4f 100644 --- "a/application/pages/7_\360\237\223\232_Entity_Management.py" +++ "b/application/pages/7_\360\237\223\232_Entity_Management.py" @@ -81,7 +81,7 @@ def main(): entity = st.text_input('Entity', key='index_question') comment = st.text_area('Comment', key='index_answer', height=300) - if st.button('Submit', type='primary'): + if st.button('Add Metrics Entity', type='primary'): if len(entity) > 0 and len(comment) > 0: VectorStore.add_entity_sample(current_profile, entity, comment) st.success('Sample added') @@ -97,7 +97,7 @@ def main(): table = st.text_input('Table', key='index_table') column = st.text_input('Column', key='index_column') value = st.text_input('Dimension value', key='index_value') - if st.button('Submit', type='primary'): + if st.button('Add Dimension Entity', type='primary'): if len(entity) > 0 and len(table) > 0 and len(column) > 0 and len(value) > 0: entity_item_table_info = {} entity_item_table_info["table_name"] = table From d5e5414f19f0e397425c614c528afcd4bd035cb4 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 13:33:35 +0800 Subject: [PATCH 110/130] fix some error --- application/nlq/business/vector_store.py | 30 ++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git 
a/application/nlq/business/vector_store.py b/application/nlq/business/vector_store.py index 644ee06..6779129 100644 --- a/application/nlq/business/vector_store.py +++ b/application/nlq/business/vector_store.py @@ -96,9 +96,18 @@ def add_entity_sample(cls, profile_name, entity, comment, entity_type="metrics", embedding = cls.create_vector_embedding_with_sagemaker(entity) else: embedding = cls.create_vector_embedding_with_bedrock(entity) - has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding) - if has_same_sample: - logger.info(f'delete sample sample entity: {entity} to profile {profile_name}') + if entity_type == "metrics": + has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding) + if has_same_sample: + logger.info(f'delete sample sample entity: {entity} to profile {profile_name}') + else: + same_dimension_value = cls.search_same_dimension_entity(profile_name, 1, opensearch_info['ner_index'], + embedding) + if len(same_dimension_value) > 0: + for item in same_dimension_value: + entity_table_info.append(item) + logger.info("entity_table_info: " + str(entity_table_info)) + has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding) if cls.opensearch_dao.add_entity_sample(opensearch_info['ner_index'], profile_name, entity, comment, embedding, entity_type, entity_table_info): logger.info('Sample added') @@ -113,7 +122,8 @@ def add_agent_cot_sample(cls, profile_name, entity, comment): has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['agent_index'], embedding) if has_same_sample: logger.info(f'delete agent sample sample query: {entity} to profile {profile_name}') - if cls.opensearch_dao.add_agent_cot_sample(opensearch_info['agent_index'], profile_name, entity, comment, embedding): + if cls.opensearch_dao.add_agent_cot_sample(opensearch_info['agent_index'], profile_name, entity, comment, + embedding): logger.info('Sample added') @classmethod @@ -197,3 +207,15 @@ def search_same_query(cls, profile_name, top_k, index_name, embedding): else: return False return False + + @classmethod + def search_same_dimension_entity(cls, profile_name, top_k, index_name, embedding): + search_res = cls.search_sample_with_embedding(profile_name, top_k, index_name, embedding) + same_dimension_value = [] + if len(search_res) > 0: + similarity_sample = search_res[0] + similarity_score = similarity_sample["_score"] + if similarity_score == 1.0: + if index_name == opensearch_info['ner_index']: + same_dimension_value = similarity_score["entity_table_info"] + return same_dimension_value From 214668df76598b53870ed0d37e0652a86b416331 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 13:40:06 +0800 Subject: [PATCH 111/130] fix some error --- application/nlq/business/vector_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/application/nlq/business/vector_store.py b/application/nlq/business/vector_store.py index 6779129..8d8c6de 100644 --- a/application/nlq/business/vector_store.py +++ b/application/nlq/business/vector_store.py @@ -217,5 +217,5 @@ def search_same_dimension_entity(cls, profile_name, top_k, index_name, embedding similarity_score = similarity_sample["_score"] if similarity_score == 1.0: if index_name == opensearch_info['ner_index']: - same_dimension_value = similarity_score["entity_table_info"] + same_dimension_value = similarity_sample["entity_table_info"] return same_dimension_value From 1c6bcc6f8c6bcd3fc53aec7bcb642a0669bdc30d Mon 
Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 13:48:10 +0800 Subject: [PATCH 112/130] fix some error --- application/nlq/business/vector_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/application/nlq/business/vector_store.py b/application/nlq/business/vector_store.py index 8d8c6de..a0da865 100644 --- a/application/nlq/business/vector_store.py +++ b/application/nlq/business/vector_store.py @@ -217,5 +217,5 @@ def search_same_dimension_entity(cls, profile_name, top_k, index_name, embedding similarity_score = similarity_sample["_score"] if similarity_score == 1.0: if index_name == opensearch_info['ner_index']: - same_dimension_value = similarity_sample["entity_table_info"] + same_dimension_value = similarity_sample["_source"]["entity_table_info"] return same_dimension_value From 976f428f98660257a2c57067b37b542a165ab249 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Fri, 2 Aug 2024 14:08:08 +0800 Subject: [PATCH 113/130] fix some error --- application/nlq/data_access/opensearch.py | 13 +++++++++++++ .../pages/7_\360\237\223\232_Entity_Management.py" | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/application/nlq/data_access/opensearch.py b/application/nlq/data_access/opensearch.py index 9481214..1fcbf42 100644 --- a/application/nlq/data_access/opensearch.py +++ b/application/nlq/data_access/opensearch.py @@ -164,6 +164,19 @@ def add_sample(self, index_name, profile_name, question, answer, embedding): def add_entity_sample(self, index_name, profile_name, entity, comment, embedding, entity_type="", entity_table_info=[]): entity_count = len(entity_table_info) + comment_value = [] + item_comment_format = "{entity} is located in table {table_name}, column {column_name}, the dimension value is {value}." 
+ if entity_type == "dimension": + if entity_count > 0: + for item in entity_table_info: + table_name = item["table_name"] + column_name = item["column_name"] + value = item["value"] + comment_format = item_comment_format.format(entity=entity, table_name=table_name, + column_name=column_name, value=value) + comment_value.append(comment_format) + comment = ";".join(comment_value) + record = { '_index': index_name, 'entity': entity, diff --git "a/application/pages/7_\360\237\223\232_Entity_Management.py" "b/application/pages/7_\360\237\223\232_Entity_Management.py" index 6bb5f4f..84f32f0 100644 --- "a/application/pages/7_\360\237\223\232_Entity_Management.py" +++ "b/application/pages/7_\360\237\223\232_Entity_Management.py" @@ -103,7 +103,7 @@ def main(): entity_item_table_info["table_name"] = table entity_item_table_info["column_name"] = column entity_item_table_info["value"] = value - VectorStore.add_entity_sample(current_profile, entity, comment, "dimension", entity_item_table_info) + VectorStore.add_entity_sample(current_profile, entity, "", "dimension", entity_item_table_info) st.success('Sample added') time.sleep(2) st.rerun() From 571b1dcbc9e9e737a197fe225f5385782f478be3 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Mon, 5 Aug 2024 09:13:20 +0800 Subject: [PATCH 114/130] add sql error Regenerating --- application/api/service.py | 66 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/application/api/service.py b/application/api/service.py index c945700..bda3f5e 100644 --- a/application/api/service.py +++ b/application/api/service.py @@ -570,6 +570,30 @@ async def ask_websocket(websocket: WebSocket, question: Question): await response_websocket(websocket, session_id, "Database SQL Execution", ContentEnum.STATE, "end", user_id) + if search_intent_result["status_code"] == 500: + await response_websocket(websocket, session_id, "Regenerating SQL ", ContentEnum.STATE, "start", user_id) + + additional_info = '''\n NOTE: when I try to write a SQL {sql_statement}, I got an error {error}. Please consider and avoid this problem. 
'''.format( + sql_statement=current_nlq_chain.get_generated_sql(), + error=search_intent_result["error_info"]) + normal_search_result = await normal_sql_regenerating_websocket(websocket=websocket, session_id=session_id, search_box=query_rewrite, + model_type=model_type, database_profile=database_profile, + entity_slot_retrieve=normal_search_result.entity_slot_retrieve, + retrieve_result=normal_search_result.retrieve_result, additional_info=additional_info) + + await response_websocket(websocket, session_id, "Regenerating SQL ", ContentEnum.STATE, "start", user_id) + if normal_search_result.sql != "": + current_nlq_chain.set_generated_sql(normal_search_result.sql) + sql_search_result.sql = normal_search_result.sql.strip() + current_nlq_chain.set_generated_sql_response(normal_search_result.response) + if explain_gen_process_flag: + sql_search_result.sql_gen_process = current_nlq_chain.get_generated_sql_explain().strip() + else: + sql_search_result.sql = "-1" + await response_websocket(websocket, session_id, "Database SQL Execution", ContentEnum.STATE, "start", user_id) + search_intent_result = get_sql_result_tool(database_profile, + current_nlq_chain.get_generated_sql()) + await response_websocket(websocket, session_id, "Database SQL Execution", ContentEnum.STATE, "end", user_id) if search_intent_result["status_code"] == 500: sql_search_result.data_analyse = "The query results are temporarily unavailable, please switch to debugging webpage to try the same query and check the log file for more information." else: @@ -587,12 +611,15 @@ async def ask_websocket(websocket: WebSocket, question: Question): sql_search_result.data_analyse = search_intent_analyse_result + await response_websocket(websocket, session_id, "Data Visualization", ContentEnum.STATE, "start", + user_id) model_select_type, show_select_data, select_chart_type, show_chart_data = data_visualization(model_type, query_rewrite, search_intent_result[ "data"], database_profile[ 'prompt_map']) + await response_websocket(websocket, session_id, "Data Visualization", ContentEnum.STATE, "end", user_id) if select_chart_type != "-1": sql_chart_data = ChartEntity(chart_type="", chart_data=[]) @@ -810,6 +837,45 @@ async def normal_text_search_websocket(websocket: WebSocket, session_id: str, se return search_result +async def normal_sql_regenerating_websocket(websocket: WebSocket, session_id: str, search_box, model_type, + database_profile, entity_slot_retrieve, retrieve_result, additional_info): + entity_slot_retrieve = entity_slot_retrieve + retrieve_result = retrieve_result + response = "" + sql = "" + search_result = SearchTextSqlResult(search_query=search_box, entity_slot_retrieve=entity_slot_retrieve, + retrieve_result=retrieve_result, response=response, sql=sql) + try: + if database_profile['db_url'] == '': + conn_name = database_profile['conn_name'] + db_url = ConnectionManagement.get_db_url_by_name(conn_name) + database_profile['db_url'] = db_url + database_profile['db_type'] = ConnectionManagement.get_db_type_by_name(conn_name) + + response = text_to_sql(database_profile['tables_info'], + database_profile['hints'], + database_profile['prompt_map'], + search_box, + model_id=model_type, + sql_examples=retrieve_result, + ner_example=entity_slot_retrieve, + dialect=database_profile['db_type'], + model_provider=None, + additional_info=additional_info) + logger.info("normal_sql_regenerating_websocket") + logger.info(f'{response=}') + sql = get_generated_sql(response) + search_result = SearchTextSqlResult(search_query=search_box, 
entity_slot_retrieve=entity_slot_retrieve, + retrieve_result=retrieve_result, response=response, sql="") + search_result.entity_slot_retrieve = entity_slot_retrieve + search_result.retrieve_result = retrieve_result + search_result.response = response + search_result.sql = sql + except Exception as e: + logger.error(e) + return search_result + + async def response_websocket(websocket: WebSocket, session_id: str, content, content_type: ContentEnum = ContentEnum.COMMON, status: str = "-1", user_id: str = "admin"): From 1484c696f26f317923819df93e789166d508ef01 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Mon, 5 Aug 2024 16:28:33 +0800 Subject: [PATCH 115/130] fix some SAGEMAKER_ENDPOINT_EMBEDDING error --- application/.env.template | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/application/.env.template b/application/.env.template index b000648..cb60672 100644 --- a/application/.env.template +++ b/application/.env.template @@ -32,4 +32,4 @@ BEDROCK_SECRETS_AK_SK= OPENSEARCH_SECRETS_URL_HOST=opensearch-host-url OPENSEARCH_SECRETS_USERNAME_PASSWORD=opensearch-master-user -SAGEMAKER_ENDPOINT_EMBEDDING= \ No newline at end of file +# SAGEMAKER_ENDPOINT_EMBEDDING= \ No newline at end of file From a160dd7237d263285839fe3ccb194555514346f1 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Tue, 6 Aug 2024 11:44:33 +0800 Subject: [PATCH 116/130] fix some issue and rollback related code --- application/nlq/business/vector_store.py | 41 ++----- .../7_\360\237\223\232_Entity_Management.py" | 25 +--- application/utils/opensearch.py | 110 +----------------- 3 files changed, 16 insertions(+), 160 deletions(-) diff --git a/application/nlq/business/vector_store.py b/application/nlq/business/vector_store.py index a0da865..bd5cd28 100644 --- a/application/nlq/business/vector_store.py +++ b/application/nlq/business/vector_store.py @@ -86,30 +86,16 @@ def add_sample(cls, profile_name, question, answer): logger.info('Sample added') @classmethod - def add_entity_sample(cls, profile_name, entity, comment, entity_type="metrics", entity_info_dict=None): - if entity_type == "metrics" or entity_info_dict is None: - entity_table_info = [] - else: - entity_table_info = [entity_info_dict] + def add_entity_sample(cls, profile_name, entity, comment): logger.info(f'add sample entity: {entity} to profile {profile_name}') if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "": embedding = cls.create_vector_embedding_with_sagemaker(entity) else: embedding = cls.create_vector_embedding_with_bedrock(entity) - if entity_type == "metrics": - has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding) - if has_same_sample: - logger.info(f'delete sample sample entity: {entity} to profile {profile_name}') - else: - same_dimension_value = cls.search_same_dimension_entity(profile_name, 1, opensearch_info['ner_index'], - embedding) - if len(same_dimension_value) > 0: - for item in same_dimension_value: - entity_table_info.append(item) - logger.info("entity_table_info: " + str(entity_table_info)) - has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding) - if cls.opensearch_dao.add_entity_sample(opensearch_info['ner_index'], profile_name, entity, comment, embedding, - entity_type, entity_table_info): + has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding) + if has_same_sample: + logger.info(f'delete sample sample entity: {entity} to profile {profile_name}') + if 
cls.opensearch_dao.add_entity_sample(opensearch_info['ner_index'], profile_name, entity, comment, embedding): logger.info('Sample added') @classmethod @@ -122,8 +108,7 @@ def add_agent_cot_sample(cls, profile_name, entity, comment): has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['agent_index'], embedding) if has_same_sample: logger.info(f'delete agent sample sample query: {entity} to profile {profile_name}') - if cls.opensearch_dao.add_agent_cot_sample(opensearch_info['agent_index'], profile_name, entity, comment, - embedding): + if cls.opensearch_dao.add_agent_cot_sample(opensearch_info['agent_index'], profile_name, entity, comment, embedding): logger.info('Sample added') @classmethod @@ -206,16 +191,4 @@ def search_same_query(cls, profile_name, top_k, index_name, embedding): return True else: return False - return False - - @classmethod - def search_same_dimension_entity(cls, profile_name, top_k, index_name, embedding): - search_res = cls.search_sample_with_embedding(profile_name, top_k, index_name, embedding) - same_dimension_value = [] - if len(search_res) > 0: - similarity_sample = search_res[0] - similarity_score = similarity_sample["_score"] - if similarity_score == 1.0: - if index_name == opensearch_info['ner_index']: - same_dimension_value = similarity_sample["_source"]["entity_table_info"] - return same_dimension_value + return False \ No newline at end of file diff --git "a/application/pages/7_\360\237\223\232_Entity_Management.py" "b/application/pages/7_\360\237\223\232_Entity_Management.py" index 84f32f0..db1eef5 100644 --- "a/application/pages/7_\360\237\223\232_Entity_Management.py" +++ "b/application/pages/7_\360\237\223\232_Entity_Management.py" @@ -62,8 +62,8 @@ def main(): index=None, placeholder="Please select data profile...", key='current_profile_name') - tab_view, tab_add, tab_dimension, tab_search, batch_insert = st.tabs( - ['View Entity Info', 'Add Metrics Entity', 'Add Dimension Entity', 'Entity Search', 'Batch Insert Entity']) + tab_view, tab_add, tab_search, batch_insert = st.tabs( + ['View Samples', 'Add New Sample', 'Sample Search', 'Batch Insert Samples']) if current_profile is not None: st.session_state['current_profile'] = current_profile with tab_view: @@ -81,7 +81,7 @@ def main(): entity = st.text_input('Entity', key='index_question') comment = st.text_area('Comment', key='index_answer', height=300) - if st.button('Add Metrics Entity', type='primary'): + if st.button('Submit', type='primary'): if len(entity) > 0 and len(comment) > 0: VectorStore.add_entity_sample(current_profile, entity, comment) st.success('Sample added') @@ -91,25 +91,6 @@ def main(): st.rerun() else: st.error('please input valid question and answer') - with tab_dimension: - if current_profile is not None: - entity = st.text_input('Entity', key='index_entity') - table = st.text_input('Table', key='index_table') - column = st.text_input('Column', key='index_column') - value = st.text_input('Dimension value', key='index_value') - if st.button('Add Dimension Entity', type='primary'): - if len(entity) > 0 and len(table) > 0 and len(column) > 0 and len(value) > 0: - entity_item_table_info = {} - entity_item_table_info["table_name"] = table - entity_item_table_info["column_name"] = column - entity_item_table_info["value"] = value - VectorStore.add_entity_sample(current_profile, entity, "", "dimension", entity_item_table_info) - st.success('Sample added') - time.sleep(2) - st.rerun() - else: - st.error('please input valid question and answer') - with tab_search: if 
current_profile is not None: entity_search = st.text_input('Entity Search', key='index_entity_search') diff --git a/application/utils/opensearch.py b/application/utils/opensearch.py index 90f02e7..4775a7f 100644 --- a/application/utils/opensearch.py +++ b/application/utils/opensearch.py @@ -3,7 +3,7 @@ from opensearchpy.helpers import bulk import logging from utils.llm import create_vector_embedding_with_bedrock, create_vector_embedding_with_sagemaker -from utils.env_var import opensearch_info, SAGEMAKER_ENDPOINT_EMBEDDING, AOS_INDEX_NER +from utils.env_var import opensearch_info, SAGEMAKER_ENDPOINT_EMBEDDING logger = logging.getLogger(__name__) @@ -112,97 +112,6 @@ def create_index_mapping(opensearch_client, index_name, dimension): }, "profile": { "type": "keyword" - }, - "entity_type": { - "type": "keyword" - }, - "entity_count": { - "type": "integer" - }, - "entity_table_info": { - "type": "nested", - "properties": { - "table_name": { - "type": "keyword" - }, - "column_name": { - "type": "keyword" - }, - "value": { - "type": "text" - } - } - } - } - } - ) - return bool(response['acknowledged']) - - -def check_field_exists(opensearch_client, index_name, field_name): - """ - Check if a field exists in the specified index - :param opensearch_client: OpenSearch client - :param index_name: Name of the index - :param field_name: Name of the field to check - :return: True if the field exists, False otherwise - """ - try: - # Get the mapping for the index - mapping = opensearch_client.indices.get_mapping(index=index_name) - - logger.info(mapping) - # Traverse the mapping to check if the field exists - if index_name in mapping: - properties = mapping[index_name]['mappings']['properties'] - if field_name in properties: - return True - except Exception as e: - logger.error(f"Error checking field {field_name}: {e}") - - return False - -def update_index_mapping(opensearch_client, index_name, dimension): - """ - Create index mapping - :param opensearch_client: - :param index_name: - :param dimension: - :return: - """ - response = opensearch_client.indices.put_mapping( - index=index_name, - body={ - "properties": { - "vector_field": { - "type": "knn_vector", - "dimension": dimension - }, - "text": { - "type": "keyword" - }, - "profile": { - "type": "keyword" - }, - "entity_type": { - "type": "keyword" - }, - "entity_count": { - "type": "integer" - }, - "entity_table_info": { - "type": "nested", - "properties": { - "table_name": { - "type": "keyword" - }, - "column_name": { - "type": "keyword" - }, - "value": { - "type": "text" - } - } } } } @@ -236,8 +145,7 @@ def get_retrieve_opensearch(opensearch_info, query, search_type, selected_profil index_name = opensearch_info['agent_index'] if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "": - records_with_embedding = create_vector_embedding_with_sagemaker(SAGEMAKER_ENDPOINT_EMBEDDING, query, - index_name=index_name) + records_with_embedding = create_vector_embedding_with_sagemaker(SAGEMAKER_ENDPOINT_EMBEDDING, query, index_name=index_name) else: records_with_embedding = create_vector_embedding_with_bedrock(query, index_name=index_name) retrieve_result = retrieve_results_from_opensearch( @@ -261,8 +169,7 @@ def get_retrieve_opensearch(opensearch_info, query, search_type, selected_profil def retrieve_results_from_opensearch(index_name, region_name, domain, opensearch_user, opensearch_password, query_embedding, top_k=3, host='', port=443, profile_name=None): - opensearch_client = get_opensearch_cluster_client(domain, host, port, 
opensearch_user, opensearch_password, - region_name) + opensearch_client = get_opensearch_cluster_client(domain, host, port, opensearch_user, opensearch_password, region_name) search_query = { "size": top_k, # Adjust the size as needed to retrieve more or fewer results "query": { @@ -298,8 +205,8 @@ def retrieve_results_from_opensearch(index_name, region_name, domain, opensearch def upload_results_to_opensearch(region_name, domain, opensearch_user, opensearch_password, index_name, query, sql, host='', port=443): - opensearch_client = get_opensearch_cluster_client(domain, host, port, opensearch_user, opensearch_password, - region_name) + + opensearch_client = get_opensearch_cluster_client(domain, host, port, opensearch_user, opensearch_password, region_name) # Vector embedding using Amazon Bedrock Titan text embedding logger.info(f"Creating embeddings for records") @@ -349,13 +256,8 @@ def opensearch_index_init(): logger.info(f"OpenSearch Index mapping created") else: index_create_success = False - else: - if index_name == AOS_INDEX_NER: - check_flag = check_field_exists(opensearch_client, index_name, "ner_table_info") - logger.info(f"check index flag: {check_flag}") - update_index_mapping(opensearch_client, index_name, dimension) return index_create_success except Exception as e: logger.error("create index error") logger.error(e) - return False + return False \ No newline at end of file From bd8ea4f24d646c453b6e081c1506861ee379d758 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Tue, 6 Aug 2024 13:51:54 +0800 Subject: [PATCH 117/130] change title name --- application/Index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/application/Index.py b/application/Index.py index dca3964..b27fa05 100644 --- a/application/Index.py +++ b/application/Index.py @@ -2,7 +2,7 @@ from utils.navigation import get_authenticator, force_set_cookie st.set_page_config( - page_title="Intelligent BI", + page_title="Generative BI", page_icon="👋", ) From c966261fa433f288847e340e39ce37a886265e8a Mon Sep 17 00:00:00 2001 From: wubinbin Date: Wed, 7 Aug 2024 13:12:03 +0800 Subject: [PATCH 118/130] feat: bundle issue --- report-front-end/.env | 2 +- report-front-end/Dockerfile | 8 +++++--- .../components/chatbot-panel/chat-message.tsx | 19 ------------------- 3 files changed, 6 insertions(+), 23 deletions(-) diff --git a/report-front-end/.env b/report-front-end/.env index 4eff791..05d174d 100644 --- a/report-front-end/.env +++ b/report-front-end/.env @@ -9,7 +9,7 @@ VITE_RIGHT_LOGO= # Login configuration, e.g. Cognito | None -VITE_LOGIN_TYPE=Cognito +VITE_LOGIN_TYPE=PLACEHOLDER_VITE_LOGIN_TYPE # KEEP the placeholder values if using CDK to deploy the backend! diff --git a/report-front-end/Dockerfile b/report-front-end/Dockerfile index 0ddb0fc..2511ab7 100644 --- a/report-front-end/Dockerfile +++ b/report-front-end/Dockerfile @@ -1,22 +1,24 @@ FROM public.ecr.aws/docker/library/node:18.17.0 AS builder WORKDIR /frontend COPY package*.json ./ +COPY . . +COPY .env /frontend/.env + ARG AWS_REGION RUN echo "Current AWS Region: $AWS_REGION" RUN if [ "$AWS_REGION" = "cn-north-1" ] || [ "$AWS_REGION" = "cn-northwest-1" ]; then \ + sed -i "s/PLACEHOLDER_VITE_LOGIN_TYPE/None/g" .env && \ npm config set registry https://registry.npmmirror.com && \ npm install; \ else \ + sed -i "s/PLACEHOLDER_VITE_LOGIN_TYPE/Cognito/g" .env && \ npm install; \ fi - -COPY . . 
RUN npm run build -COPY .env /frontend/.env FROM public.ecr.aws/docker/library/nginx:1.23-alpine COPY --from=builder /frontend/dist/ /usr/share/nginx/html/ diff --git a/report-front-end/src/components/chatbot-panel/chat-message.tsx b/report-front-end/src/components/chatbot-panel/chat-message.tsx index f250499..5ab15f1 100644 --- a/report-front-end/src/components/chatbot-panel/chat-message.tsx +++ b/report-front-end/src/components/chatbot-panel/chat-message.tsx @@ -355,25 +355,6 @@ const DataTable = (props: { distributions: []; header: [] }) => { filteringPlaceholder="Search" /> } -/* preferences={ - - }*/ /> setVisible(false)} From 05c1b78f69633e70b3e9ce75eab50eaa001dd6ae Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Thu, 8 Aug 2024 15:27:41 +0800 Subject: [PATCH 119/130] remove segamaker --- source/model/embedding/code/model.py | 104 ----------- source/model/embedding/code/requirements.txt | 1 - .../model/embedding/code/serving.properties | 5 - source/model/embedding/model/bge-m3_model.py | 104 ----------- source/model/embedding/model/model.sh | 81 --------- .../model/embedding/model/serving.properties | 5 - source/model/internlm/code/model.py | 168 ------------------ source/model/internlm/code/requirements.txt | 7 - source/model/internlm/code/serving.properties | 3 - .../internlm/model/internlm2-chat-7b_model.py | 168 ------------------ source/model/internlm/model/model.sh | 81 --------- .../model/internlm/model/serving.properties | 3 - source/model/prepare_model.sh | 68 ------- source/model/sqlcoder/code/model.py | 137 -------------- source/model/sqlcoder/code/requirements.txt | 2 - source/model/sqlcoder/code/serving.properties | 5 - source/model/sqlcoder/model/model.sh | 81 --------- .../model/sqlcoder/model/serving.properties | 5 - .../sqlcoder/model/sqlcoder-7b-2_model.py | 137 -------------- 19 files changed, 1165 deletions(-) delete mode 100644 source/model/embedding/code/model.py delete mode 100644 source/model/embedding/code/requirements.txt delete mode 100644 source/model/embedding/code/serving.properties delete mode 100644 source/model/embedding/model/bge-m3_model.py delete mode 100755 source/model/embedding/model/model.sh delete mode 100644 source/model/embedding/model/serving.properties delete mode 100644 source/model/internlm/code/model.py delete mode 100644 source/model/internlm/code/requirements.txt delete mode 100644 source/model/internlm/code/serving.properties delete mode 100644 source/model/internlm/model/internlm2-chat-7b_model.py delete mode 100755 source/model/internlm/model/model.sh delete mode 100644 source/model/internlm/model/serving.properties delete mode 100644 source/model/prepare_model.sh delete mode 100644 source/model/sqlcoder/code/model.py delete mode 100644 source/model/sqlcoder/code/requirements.txt delete mode 100644 source/model/sqlcoder/code/serving.properties delete mode 100755 source/model/sqlcoder/model/model.sh delete mode 100644 source/model/sqlcoder/model/serving.properties delete mode 100644 source/model/sqlcoder/model/sqlcoder-7b-2_model.py diff --git a/source/model/embedding/code/model.py b/source/model/embedding/code/model.py deleted file mode 100644 index eae95d6..0000000 --- a/source/model/embedding/code/model.py +++ /dev/null @@ -1,104 +0,0 @@ -import logging -import math -import os - -import torch -from djl_python import Input, Output -from FlagEmbedding import BGEM3FlagModel -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, pipeline - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 
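For reference, the DJL handler in this removed model.py (see its handle() further down) reads a JSON body with the fields inputs, batch_size, max_length and return_type, where return_type selects dense, sparse, colbert or all embeddings. A minimal sketch of such a request body, with made-up example values:

    payload = {
        "inputs": ["show me total sales by region"],  # hypothetical sentence(s) to embed
        "batch_size": 12,                             # hypothetical batch size
        "max_length": 512,                            # hypothetical max token length
        "return_type": "dense",                       # one of "dense", "sparse", "colbert", "all"
    }
    # The handler below responds with {"sentence_embeddings": ...}.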
-print(f"--device={device}") - - -def load_model(properties): - # tensor_parallel = properties["tensor_parallel_degree"] - model_location = properties["model_dir"] - if "model_id" in properties: - model_location = properties["model_id"] - logging.info(f"Loading model in {model_location}") - - # tokenizer = AutoTokenizer.from_pretrained(model_location, trust_remote_code=True) - # tokenizer.padding_side = 'right' - # model = AutoModel.from_pretrained( - # model_location, - # # device_map="balanced_low_0", - # trust_remote_code=True - # ).half() - # # load the model on GPU - # model.to(device) - # model.requires_grad_(False) - # model.eval() - - model = BGEM3FlagModel( - model_location, use_fp16=True - ) # Setting use_fp16 to True speeds up computation with a slight performance degradation - - return model - - -model = None -tokenizer = None -generator = None - - -def mean_pooling(model_output, attention_mask): - token_embeddings = model_output[0].to( - device - ) # First element of model_output contains all token embeddings - input_mask_expanded = ( - attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float().to(device) - ) - return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( - input_mask_expanded.sum(1), min=1e-9 - ) - - -def handle(inputs: Input): - global model - if not model: - model = load_model(inputs.get_properties()) - - if inputs.is_empty(): - return None - data = inputs.get_as_json() - - input_sentences = data["inputs"] - batch_size = data["batch_size"] - max_length = data["max_length"] - return_type = data["return_type"] - - logging.info(f"inputs: {input_sentences}") - - if return_type == "dense": - encoding_results = model.encode( - input_sentences, batch_size=batch_size, max_length=max_length - ) - elif return_type == "sparse": - encoding_results = model.encode( - input_sentences, - return_dense=False, - return_sparse=True, - return_colbert_vecs=False, - ) - elif return_type == "colbert": - encoding_results = model.encode( - input_sentences, - return_dense=False, - return_sparse=False, - return_colbert_vecs=True, - ) - elif return_type == "all": - encoding_results = model.encode( - input_sentences, - batch_size=batch_size, - max_length=max_length, - return_dense=True, - return_sparse=True, - return_colbert_vecs=True, - ) - - # encoding_results = [encoding_results] - - result = {"sentence_embeddings": encoding_results} - return Output().add_as_json(result) diff --git a/source/model/embedding/code/requirements.txt b/source/model/embedding/code/requirements.txt deleted file mode 100644 index c6bfd06..0000000 --- a/source/model/embedding/code/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -FlagEmbedding==1.2.5 diff --git a/source/model/embedding/code/serving.properties b/source/model/embedding/code/serving.properties deleted file mode 100644 index 59769a8..0000000 --- a/source/model/embedding/code/serving.properties +++ /dev/null @@ -1,5 +0,0 @@ -engine=Python -option.tensor_parallel_degree=1 -# update according to your own path -# option.s3url = s3://<_S3ModelAssets>/<_AssetsStack._embeddingModelPrefix> -option.s3url = s3://llm-bot-models-256374081253-cn-north-1/bge-m3/ \ No newline at end of file diff --git a/source/model/embedding/model/bge-m3_model.py b/source/model/embedding/model/bge-m3_model.py deleted file mode 100644 index eae95d6..0000000 --- a/source/model/embedding/model/bge-m3_model.py +++ /dev/null @@ -1,104 +0,0 @@ -import logging -import math -import os - -import torch -from djl_python import Input, Output -from FlagEmbedding import 
BGEM3FlagModel -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, pipeline - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -print(f"--device={device}") - - -def load_model(properties): - # tensor_parallel = properties["tensor_parallel_degree"] - model_location = properties["model_dir"] - if "model_id" in properties: - model_location = properties["model_id"] - logging.info(f"Loading model in {model_location}") - - # tokenizer = AutoTokenizer.from_pretrained(model_location, trust_remote_code=True) - # tokenizer.padding_side = 'right' - # model = AutoModel.from_pretrained( - # model_location, - # # device_map="balanced_low_0", - # trust_remote_code=True - # ).half() - # # load the model on GPU - # model.to(device) - # model.requires_grad_(False) - # model.eval() - - model = BGEM3FlagModel( - model_location, use_fp16=True - ) # Setting use_fp16 to True speeds up computation with a slight performance degradation - - return model - - -model = None -tokenizer = None -generator = None - - -def mean_pooling(model_output, attention_mask): - token_embeddings = model_output[0].to( - device - ) # First element of model_output contains all token embeddings - input_mask_expanded = ( - attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float().to(device) - ) - return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( - input_mask_expanded.sum(1), min=1e-9 - ) - - -def handle(inputs: Input): - global model - if not model: - model = load_model(inputs.get_properties()) - - if inputs.is_empty(): - return None - data = inputs.get_as_json() - - input_sentences = data["inputs"] - batch_size = data["batch_size"] - max_length = data["max_length"] - return_type = data["return_type"] - - logging.info(f"inputs: {input_sentences}") - - if return_type == "dense": - encoding_results = model.encode( - input_sentences, batch_size=batch_size, max_length=max_length - ) - elif return_type == "sparse": - encoding_results = model.encode( - input_sentences, - return_dense=False, - return_sparse=True, - return_colbert_vecs=False, - ) - elif return_type == "colbert": - encoding_results = model.encode( - input_sentences, - return_dense=False, - return_sparse=False, - return_colbert_vecs=True, - ) - elif return_type == "all": - encoding_results = model.encode( - input_sentences, - batch_size=batch_size, - max_length=max_length, - return_dense=True, - return_sparse=True, - return_colbert_vecs=True, - ) - - # encoding_results = [encoding_results] - - result = {"sentence_embeddings": encoding_results} - return Output().add_as_json(result) diff --git a/source/model/embedding/model/model.sh b/source/model/embedding/model/model.sh deleted file mode 100755 index 7ac66ed..0000000 --- a/source/model/embedding/model/model.sh +++ /dev/null @@ -1,81 +0,0 @@ -function usage { - echo "Make sure python3 installed properly. 
Usage: $0 -t TOKEN [-m MODEL_NAME] [-c COMMIT_HASH] [-s S3_BUCKET_NAME]" - echo " -t TOKEN Hugging Face token " - echo " -h Hugging Face Repo Name Hugging Face repo " - echo " -m MODEL_NAME Model name (default: csdc-atl/buffer-cross-001)" - echo " -c COMMIT_HASH Commit hash (default: 46d270928463db49b317e5ea469a8ac8152f4a13)" - echo " -p Tensor Parrallel degree Parameters in serving.properties " - echo " -s S3_BUCKET_NAME S3 bucket name to upload the model (default: llm-rag)" - exit 1 -} - -# Default values -model_name="csdc-atl/buffer-cross-001" -commit_hash="46d270928463db49b317e5ea469a8ac8152f4a13" -s3_bucket_name="llm-rag" # Default S3 bucket name - -# Parse command-line options -while getopts ":t:h:m:c:p:s:" opt; do - case $opt in - t) hf_token="$OPTARG" ;; - h) hf_name="$OPTARG" ;; - m) model_name="$OPTARG" ;; - c) commit_hash="$OPTARG" ;; - p) tensor_parallel_degree="$OPTARG" ;; - s) s3_bucket_name="$OPTARG" ;; - \?) echo "Invalid option: -$OPTARG" >&2; usage ;; - :) echo "Option -$OPTARG requires an argument." >&2; usage ;; - esac -done - - -# # Validate the hf_token and python3 interpreter exist -# if [ -z "$hf_token" ] || ! command -v python3 &> /dev/null; then -# usage -# fi - -# # Install necessary packages -pip install huggingface-hub -Uqq -pip install -U sagemaker - -# Define local model path -local_model_path="./${model_name}" - -# Uncomment the line below if you want to create a specific directory for the model -# mkdir -p $local_model_path - -# Download model snapshot in current folder without model prefix added -# python3 -c "from huggingface_hub import snapshot_download; from pathlib import Path; snapshot_download(repo_id='$model_name', revision='$commit_hash', cache_dir=Path('.'), token='$hf_token')" -python3 -c "from huggingface_hub import snapshot_download; from pathlib import Path; snapshot_download(repo_id='$hf_name', revision='$commit_hash', cache_dir='$local_model_path')" - -# Find model snapshot path with the first search result -model_snapshot_path=$(find $local_model_path -path '*/snapshots/*' -type d -print -quit) -echo "Model snapshot path: $model_snapshot_path" - -# s3:/// -aws s3 cp --recursive $model_snapshot_path s3://$s3_bucket_name/$model_name - -# Prepare model.py files according to model name -model_inference_file="./${model_name}_model.py" -cp $model_inference_file ../code/model.py - -# Modify the content of serving.properties and re-tar the model -cp serving.properties ../code/serving.properties -cd ../code -file_path="serving.properties" -os_type=$(uname -s) - -if [ "$os_type" == "Darwin" ]; then - sed -i "" "s|option.s3url = S3PATH|option.s3url = s3://$s3_bucket_name/$model_name/|g" $file_path - sed -i "" "s|option.tensor_parallel_degree=tpd|option.tensor_parallel_degree=$tensor_parallel_degree|g" $file_path -else - sed -i "s|option.s3url = S3PATH|option.s3url = s3://$s3_bucket_name/$model_name/|g" $file_path - sed -i "s|option.tensor_parallel_degree=tpd|option.tensor_parallel_degree=$tensor_parallel_degree|g" $file_path -fi - - -rm model.tar.gz -tar czvf model.tar.gz * - -code_path="${model_name}_deploy_code" -aws s3 cp model.tar.gz s3://$s3_bucket_name/$code_path/model.tar.gz diff --git a/source/model/embedding/model/serving.properties b/source/model/embedding/model/serving.properties deleted file mode 100644 index 09bf82f..0000000 --- a/source/model/embedding/model/serving.properties +++ /dev/null @@ -1,5 +0,0 @@ -engine=Python -option.tensor_parallel_degree=tpd -# update according to your own path -# option.s3url = 
s3://<_S3ModelAssets>/<_AssetsStack._embeddingModelPrefix> -option.s3url = S3PATH \ No newline at end of file diff --git a/source/model/internlm/code/model.py b/source/model/internlm/code/model.py deleted file mode 100644 index e946207..0000000 --- a/source/model/internlm/code/model.py +++ /dev/null @@ -1,168 +0,0 @@ -import time -import sys, os -os.environ['PYTHONUNBUFFERED'] = "1" -import traceback -import sys -import torch -import gc -from typing import List,Tuple -import logging -try: - from transformers.generation.streamers import BaseStreamer -except: # noqa # pylint: disable=bare-except - BaseStreamer = None -import queue -import threading -import time -from queue import Empty -from djl_python import Input, Output -import torch -import json -import types -import threading -from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer -# from transformers.generation.utils import GenerationConfig -import traceback -from transformers import AutoTokenizer,GPTQConfig,AutoModelForCausalLM - -from exllamav2 import ( - ExLlamaV2, - ExLlamaV2Config, - ExLlamaV2Cache, - ExLlamaV2Tokenizer, -) - -from exllamav2.generator import ( - ExLlamaV2StreamingGenerator, - ExLlamaV2Sampler -) -handle_lock = threading.Lock() -logger = logging.getLogger("sagemaker-inference") -logger.info(f'logger handlers: {logger.handlers}') - -generator = None -tokenizer = None - - -def new_decode(self, ids, decode_special_tokens = False): - ori_decode = tokenizer.decode - return ori_decode(ids, decode_special_tokens = True) - -def get_model(properties): - model_dir = properties['model_dir'] - model_path = os.path.join(model_dir, 'hf_model/') - if "model_id" in properties: - model_path = properties['model_id'] - logger.info(f'properties: {properties}') - logger.info(f'model_path: {model_path}') - # local_rank = int(os.getenv('LOCAL_RANK', '0')) - model_directory = model_path - - config = ExLlamaV2Config() - config.model_dir = model_directory - config.prepare() - - model = ExLlamaV2(config) - logger.info("Loading model: " + model_directory) - - cache = ExLlamaV2Cache(model, lazy = True) - model.load_autosplit(cache) - - tokenizer = ExLlamaV2Tokenizer(config) - - generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer) - - return tokenizer,generator - -def _default_stream_output_formatter(token_texts): - if isinstance(token_texts,Exception): - token_texts = {'error_msg':str(token_texts)} - else: - token_texts = {"outputs": token_texts} - json_encoded_str = json.dumps(token_texts) + "\n" - return bytearray(json_encoded_str.encode("utf-8")) - -def generate(**body): - query = body.pop('query') - stream = body.pop('stream',False) - stop_words = body.pop('stop_tokens',None) - - stop_token_ids = [ - tokenizer.eos_token_id, - tokenizer.encode('<|im_end|>',encode_special_tokens=True).tolist()[0][0] - ] - - if stop_words: - assert isinstance(stop_words,list), stop_words - for stop_word in stop_words: - stop_token_ids.append(tokenizer.encode(stop_word,encode_special_tokens=True).tolist()[0][0]) - - # body.update({"do_preprocess": False}) - timeout = body.pop('timeout',60) - settings = ExLlamaV2Sampler.Settings() - settings.temperature = body.get('temperature',0.1) - settings.top_k = body.get('top_k',50) - settings.top_p = body.get('top_p',0.8) - settings.top_a = body.get('top_a',0.0) - settings.token_repetition_penalty = 1.0 - # tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0] - # settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id]) - max_new_tokens = body.get('max_new_tokens',500) - input_ids = 
tokenizer.encode(query,encode_special_tokens=True,add_bos=True) - prompt_tokens = input_ids.shape[-1] - generator.warmup() - generator.set_stop_conditions(stop_token_ids) - - generator.begin_stream(input_ids, settings) - - def _generator_helper(): - try: - generated_tokens = 0 - while True: - chunk, eos, _ = generator.stream() - generated_tokens += 1 - yield chunk - if eos or generated_tokens == max_new_tokens: break - - finally: - traceback.clear_frames(sys.exc_info()[2]) - gc.collect() - torch.cuda.empty_cache() - - stream_generator = _generator_helper() - if stream: - return stream_generator - r = "" - for i in stream_generator: - r += i - return r - - -def _handle(inputs) -> None: - if inputs.is_empty(): - logger.info('inputs is empty') - # Model server makes an empty call to warmup the model on startup - return None - torch.cuda.empty_cache() - body = inputs.get_as_json() - stream = body.get('stream',False) - logger.info(f'body: {body}') - response = generate(**body) - if stream: - return Output().add_stream_content(response,output_formatter=_default_stream_output_formatter) - else: - return Output().add_as_json(response) - - -def handle(inputs: Input) -> None: - task_request_time = time.time() - logger.info(f'recieve request task: {task_request_time},{inputs}') - with handle_lock: - global generator,tokenizer - if generator is None: - tokenizer, generator = get_model(inputs.get_properties()) - tokenizer.decode = types.MethodType(new_decode, tokenizer) - logger.info(f'executing request task, wait time: {time.time()-task_request_time}s') - return _handle(inputs) - - \ No newline at end of file diff --git a/source/model/internlm/code/requirements.txt b/source/model/internlm/code/requirements.txt deleted file mode 100644 index ebdb182..0000000 --- a/source/model/internlm/code/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -exllamav2==0.0.14 -torch==2.2.0 -sentencepiece==0.1.99 -accelerate==0.25.0 -bitsandbytes==0.41.1 -transformers==4.38.0 -einops==0.7.0 \ No newline at end of file diff --git a/source/model/internlm/code/serving.properties b/source/model/internlm/code/serving.properties deleted file mode 100644 index b3fb498..0000000 --- a/source/model/internlm/code/serving.properties +++ /dev/null @@ -1,3 +0,0 @@ -engine=Python -option.enable_streaming=true -option.s3url = s3://llm-bot-models-256374081253-cn-north-1/internlm2-chat-7b/ \ No newline at end of file diff --git a/source/model/internlm/model/internlm2-chat-7b_model.py b/source/model/internlm/model/internlm2-chat-7b_model.py deleted file mode 100644 index e946207..0000000 --- a/source/model/internlm/model/internlm2-chat-7b_model.py +++ /dev/null @@ -1,168 +0,0 @@ -import time -import sys, os -os.environ['PYTHONUNBUFFERED'] = "1" -import traceback -import sys -import torch -import gc -from typing import List,Tuple -import logging -try: - from transformers.generation.streamers import BaseStreamer -except: # noqa # pylint: disable=bare-except - BaseStreamer = None -import queue -import threading -import time -from queue import Empty -from djl_python import Input, Output -import torch -import json -import types -import threading -from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer -# from transformers.generation.utils import GenerationConfig -import traceback -from transformers import AutoTokenizer,GPTQConfig,AutoModelForCausalLM - -from exllamav2 import ( - ExLlamaV2, - ExLlamaV2Config, - ExLlamaV2Cache, - ExLlamaV2Tokenizer, -) - -from exllamav2.generator import ( - ExLlamaV2StreamingGenerator, - 
ExLlamaV2Sampler -) -handle_lock = threading.Lock() -logger = logging.getLogger("sagemaker-inference") -logger.info(f'logger handlers: {logger.handlers}') - -generator = None -tokenizer = None - - -def new_decode(self, ids, decode_special_tokens = False): - ori_decode = tokenizer.decode - return ori_decode(ids, decode_special_tokens = True) - -def get_model(properties): - model_dir = properties['model_dir'] - model_path = os.path.join(model_dir, 'hf_model/') - if "model_id" in properties: - model_path = properties['model_id'] - logger.info(f'properties: {properties}') - logger.info(f'model_path: {model_path}') - # local_rank = int(os.getenv('LOCAL_RANK', '0')) - model_directory = model_path - - config = ExLlamaV2Config() - config.model_dir = model_directory - config.prepare() - - model = ExLlamaV2(config) - logger.info("Loading model: " + model_directory) - - cache = ExLlamaV2Cache(model, lazy = True) - model.load_autosplit(cache) - - tokenizer = ExLlamaV2Tokenizer(config) - - generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer) - - return tokenizer,generator - -def _default_stream_output_formatter(token_texts): - if isinstance(token_texts,Exception): - token_texts = {'error_msg':str(token_texts)} - else: - token_texts = {"outputs": token_texts} - json_encoded_str = json.dumps(token_texts) + "\n" - return bytearray(json_encoded_str.encode("utf-8")) - -def generate(**body): - query = body.pop('query') - stream = body.pop('stream',False) - stop_words = body.pop('stop_tokens',None) - - stop_token_ids = [ - tokenizer.eos_token_id, - tokenizer.encode('<|im_end|>',encode_special_tokens=True).tolist()[0][0] - ] - - if stop_words: - assert isinstance(stop_words,list), stop_words - for stop_word in stop_words: - stop_token_ids.append(tokenizer.encode(stop_word,encode_special_tokens=True).tolist()[0][0]) - - # body.update({"do_preprocess": False}) - timeout = body.pop('timeout',60) - settings = ExLlamaV2Sampler.Settings() - settings.temperature = body.get('temperature',0.1) - settings.top_k = body.get('top_k',50) - settings.top_p = body.get('top_p',0.8) - settings.top_a = body.get('top_a',0.0) - settings.token_repetition_penalty = 1.0 - # tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0] - # settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id]) - max_new_tokens = body.get('max_new_tokens',500) - input_ids = tokenizer.encode(query,encode_special_tokens=True,add_bos=True) - prompt_tokens = input_ids.shape[-1] - generator.warmup() - generator.set_stop_conditions(stop_token_ids) - - generator.begin_stream(input_ids, settings) - - def _generator_helper(): - try: - generated_tokens = 0 - while True: - chunk, eos, _ = generator.stream() - generated_tokens += 1 - yield chunk - if eos or generated_tokens == max_new_tokens: break - - finally: - traceback.clear_frames(sys.exc_info()[2]) - gc.collect() - torch.cuda.empty_cache() - - stream_generator = _generator_helper() - if stream: - return stream_generator - r = "" - for i in stream_generator: - r += i - return r - - -def _handle(inputs) -> None: - if inputs.is_empty(): - logger.info('inputs is empty') - # Model server makes an empty call to warmup the model on startup - return None - torch.cuda.empty_cache() - body = inputs.get_as_json() - stream = body.get('stream',False) - logger.info(f'body: {body}') - response = generate(**body) - if stream: - return Output().add_stream_content(response,output_formatter=_default_stream_output_formatter) - else: - return Output().add_as_json(response) - - -def handle(inputs: Input) -> None: - 
task_request_time = time.time() - logger.info(f'recieve request task: {task_request_time},{inputs}') - with handle_lock: - global generator,tokenizer - if generator is None: - tokenizer, generator = get_model(inputs.get_properties()) - tokenizer.decode = types.MethodType(new_decode, tokenizer) - logger.info(f'executing request task, wait time: {time.time()-task_request_time}s') - return _handle(inputs) - - \ No newline at end of file diff --git a/source/model/internlm/model/model.sh b/source/model/internlm/model/model.sh deleted file mode 100755 index d8d196b..0000000 --- a/source/model/internlm/model/model.sh +++ /dev/null @@ -1,81 +0,0 @@ -function usage { - echo "Make sure python3 installed properly. Usage: $0 -t TOKEN [-m MODEL_NAME] [-c COMMIT_HASH] [-s S3_BUCKET_NAME]" - echo " -t TOKEN Hugging Face token " - echo " -h Hugging Face Repo Name Hugging Face repo " - echo " -m MODEL_NAME Model name (default: csdc-atl/buffer-cross-001)" - echo " -c COMMIT_HASH Commit hash (default: 46d270928463db49b317e5ea469a8ac8152f4a13)" - echo " -p Tensor Parrallel degree Parameters in serving.properties " - echo " -s S3_BUCKET_NAME S3 bucket name to upload the model (default: llm-rag)" - exit 1 -} - -# Default values -model_name="" -commit_hash="" -s3_bucket_name="" # Default S3 bucket name - -# Parse command-line options -while getopts ":t:h:m:c:p:s:" opt; do - case $opt in - t) hf_token="$OPTARG" ;; - h) hf_name="$OPTARG" ;; - m) model_name="$OPTARG" ;; - c) commit_hash="$OPTARG" ;; - p) tensor_parallel_degree="$OPTARG" ;; - s) s3_bucket_name="$OPTARG" ;; - \?) echo "Invalid option: -$OPTARG" >&2; usage ;; - :) echo "Option -$OPTARG requires an argument." >&2; usage ;; - esac -done - - -# # Validate the hf_token and python3 interpreter exist -# if [ -z "$hf_token" ] || ! 
command -v python3 &> /dev/null; then -# usage -# fi - -# # Install necessary packages -pip install huggingface-hub -Uqq -pip install -U sagemaker - -# Define local model path -local_model_path="./${model_name}" - -# Uncomment the line below if you want to create a specific directory for the model -# mkdir -p $local_model_path - -# Download model snapshot in current folder without model prefix added -# python3 -c "from huggingface_hub import snapshot_download; from pathlib import Path; snapshot_download(repo_id='$model_name', revision='$commit_hash', cache_dir=Path('.'), token='$hf_token')" -python3 -c "from huggingface_hub import snapshot_download; from pathlib import Path; snapshot_download(repo_id='$hf_name', revision='$commit_hash', cache_dir='$local_model_path')" - -# Find model snapshot path with the first search result -model_snapshot_path=$(find $local_model_path -path '*/snapshots/*' -type d -print -quit) -echo "Model snapshot path: $model_snapshot_path" - -# s3:/// -aws s3 cp --recursive $model_snapshot_path s3://$s3_bucket_name/$model_name - -# Prepare model.py files according to model name -model_inference_file="./${model_name}_model.py" -cp $model_inference_file ../code/model.py - -# Modify the content of serving.properties and re-tar the model -cp serving.properties ../code/serving.properties -cd ../code -file_path="serving.properties" -os_type=$(uname -s) - -if [ "$os_type" == "Darwin" ]; then - sed -i "" "s|option.s3url = S3PATH|option.s3url = s3://$s3_bucket_name/$model_name/|g" $file_path - sed -i "" "s|option.tensor_parallel_degree=tpd|option.tensor_parallel_degree=$tensor_parallel_degree|g" $file_path -else - sed -i "s|option.s3url = S3PATH|option.s3url = s3://$s3_bucket_name/$model_name/|g" $file_path - sed -i "s|option.tensor_parallel_degree=tpd|option.tensor_parallel_degree=$tensor_parallel_degree|g" $file_path -fi - - -rm model.tar.gz -tar czvf model.tar.gz * - -code_path="${model_name}_deploy_code" -aws s3 cp model.tar.gz s3://$s3_bucket_name/$code_path/model.tar.gz diff --git a/source/model/internlm/model/serving.properties b/source/model/internlm/model/serving.properties deleted file mode 100644 index dfed737..0000000 --- a/source/model/internlm/model/serving.properties +++ /dev/null @@ -1,3 +0,0 @@ -engine=Python -option.enable_streaming=true -option.s3url = S3PATH \ No newline at end of file diff --git a/source/model/prepare_model.sh b/source/model/prepare_model.sh deleted file mode 100644 index c67be7b..0000000 --- a/source/model/prepare_model.sh +++ /dev/null @@ -1,68 +0,0 @@ -function usage { - echo "Make sure python3 installed properly. Usage: $0 -s S3_BUCKET_NAME" - echo " -s S3_BUCKET_NAME S3 bucket name to upload the model" - exit 1 -} - -# Parse command-line options -while getopts ":s:" opt; do - case $opt in - s) s3_bucket_name="$OPTARG" ;; - \?) echo "Invalid option: -$OPTARG" >&2; usage ;; - :) echo "Option -$OPTARG requires an argument." >&2; usage ;; - esac -done - -# Validate the hf_token and python3 interpreter exist -if [ -z "$s3_bucket_name" ] || ! 
command -v python3 &> /dev/null; then - usage -fi - -cd internlm/model -hf_names=("bartowski/internlm2-chat-7b-llama-exl2") -model_names=("internlm2-chat-7b") -commit_hashs=("54a594b0be43065e7b7674d0f236911cd7c465ab") -tensor_parallel_degree=(1) - -for index in "${!model_names[@]}"; do - hf_name="${hf_names[$index]}" - model_name="${model_names[$index]}" - commit_hash="${commit_hashs[$index]}" - tp="${tensor_parallel_degree[$index]}" - echo "model name $model_name" - echo "commit hash $commit_hash" - ./model.sh -h $hf_name -m $model_name -c $commit_hash -p $tp -s $s3_bucket_name -done - - -cd ../../sqlcoder/model -hf_names=("defog/sqlcoder-7b-2") -model_names=("sqlcoder-7b-2") -commit_hashs=("7e5b6f7981c0aa7d143f6bec6fa26625bdfcbe66") -tensor_parallel_degree=(1) - -for index in "${!model_names[@]}"; do - hf_name="${hf_names[$index]}" - model_name="${model_names[$index]}" - commit_hash="${commit_hashs[$index]}" - tp="${tensor_parallel_degree[$index]}" - echo "model name $model_name" - echo "commit hash $commit_hash" - ./model.sh -h $hf_name -m $model_name -c $commit_hash -p $tp -s $s3_bucket_name -done - -cd ../../embedding/model -hf_names=("BAAI/bge-m3") -model_names=("bge-m3") -commit_hashs=("3ab7155aa9b89ac532b2f2efcc3f136766b91025") -tensor_parallel_degree=(1) - -for index in "${!model_names[@]}"; do - hf_name="${hf_names[$index]}" - model_name="${model_names[$index]}" - commit_hash="${commit_hashs[$index]}" - tp="${tensor_parallel_degree[$index]}" - echo "model name $model_name" - echo "commit hash $commit_hash" - ./model.sh -h $hf_name -m $model_name -c $commit_hash -p $tp -s $s3_bucket_name -done \ No newline at end of file diff --git a/source/model/sqlcoder/code/model.py b/source/model/sqlcoder/code/model.py deleted file mode 100644 index a3f2342..0000000 --- a/source/model/sqlcoder/code/model.py +++ /dev/null @@ -1,137 +0,0 @@ -import logging -import math -import os -from threading import Thread - -import sqlparse -import torch -import transformers -from djl_python import Input, Output -from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline - -device = "cuda" - - -def load_model(properties): - model_location = properties["model_dir"] - if "model_id" in properties: - model_location = properties["model_id"] - logging.info(f"Loading model in {model_location}") - tokenizer = AutoTokenizer.from_pretrained(model_location, trust_remote_code=True) - model = AutoModelForCausalLM.from_pretrained( - model_location, - trust_remote_code=True, - load_in_8bit=True, - device_map="auto", - use_cache=True, - ) - return tokenizer, model - - -model = None -tokenizer = None - - -def generate_prompt(question): - prompt = """### Task -Generate a SQL query to answer [QUESTION]{question}[/QUESTION] - -### Instructions -- If you cannot answer the question with the available database schema, return 'I do not know' -- Remember that revenue is price multiplied by quantity -- Remember that cost is supply_price multiplied by quantity - -### Database Schema -This query will run on a database whose schema is represented in this string: -CREATE TABLE products ( - product_id INTEGER PRIMARY KEY, -- Unique ID for each product - name VARCHAR(50), -- Name of the product - price DECIMAL(10,2), -- Price of each unit of the product - quantity INTEGER -- Current quantity in stock -); - -CREATE TABLE customers ( - customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer - name VARCHAR(50), -- Name of the customer - address VARCHAR(100) -- Mailing address of the customer -); - -CREATE TABLE salespeople 
( - salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson - name VARCHAR(50), -- Name of the salesperson - region VARCHAR(50) -- Geographic sales region -); - -CREATE TABLE sales ( - sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale - product_id INTEGER, -- ID of product sold - customer_id INTEGER, -- ID of customer who made purchase - salesperson_id INTEGER, -- ID of salesperson who made the sale - sale_date DATE, -- Date the sale occurred - quantity INTEGER -- Quantity of product sold -); - -CREATE TABLE product_suppliers ( - supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier - product_id INTEGER, -- Product ID supplied - supply_price DECIMAL(10,2) -- Unit price charged by supplier -); - --- sales.product_id can be joined with products.product_id --- sales.customer_id can be joined with customers.customer_id --- sales.salesperson_id can be joined with salespeople.salesperson_id --- product_suppliers.product_id can be joined with products.product_id - -### Answer -Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION] -[SQL] -""" - prompt = prompt.format(question=question) - return prompt - - -def stream_items(sql_query): - chunks = sql_query.split("\n") - for chunk in chunks: - stream_buffer = chunk + "\n" - logging.info(f"Stream buffer: {stream_buffer}") - yield stream_buffer - - -def handle(inputs: Input): - global tokenizer, model - if not model: - tokenizer, model = load_model(inputs.get_properties()) - - if inputs.is_empty(): - return None - data = inputs.get_as_json() - - prompt = data["prompt"] - stream = data.get("stream", False) - - # updated_prompt = generate_prompt(prompt) - updated_prompt = prompt - - inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda") - generated_ids = model.generate( - **inputs, - num_return_sequences=1, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.eos_token_id, - max_new_tokens=400, - do_sample=False, - num_beams=1, - ) - decoded_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - sql_query = sqlparse.format(decoded_outputs[0].split("[SQL]")[-1], reindent=True) - logging.info(f"SQL Query: {sql_query}") - - outputs = Output() - # split SQL query every into chunks containing 10 characters - if stream: - outputs.add_stream_content(stream_items(sql_query), output_formatter=None) - else: - outputs.add_as_json({"outputs": sql_query}) - - return outputs diff --git a/source/model/sqlcoder/code/requirements.txt b/source/model/sqlcoder/code/requirements.txt deleted file mode 100644 index f54d051..0000000 --- a/source/model/sqlcoder/code/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -accelerate==0.29.2 -sqlparse==0.5.0 \ No newline at end of file diff --git a/source/model/sqlcoder/code/serving.properties b/source/model/sqlcoder/code/serving.properties deleted file mode 100644 index a434cc0..0000000 --- a/source/model/sqlcoder/code/serving.properties +++ /dev/null @@ -1,5 +0,0 @@ -engine=Python -option.enable_streaming=true -# update according to your own path -# option.s3url = s3://<_S3ModelAssets>/<_AssetsStack._embeddingModelPrefix> -option.s3url = s3://llm-bot-models-256374081253-cn-north-1/sqlcoder-7b-2/ \ No newline at end of file diff --git a/source/model/sqlcoder/model/model.sh b/source/model/sqlcoder/model/model.sh deleted file mode 100755 index d8d196b..0000000 --- a/source/model/sqlcoder/model/model.sh +++ /dev/null @@ -1,81 +0,0 @@ -function usage { - echo "Make sure python3 installed properly. 
Usage: $0 -t TOKEN [-m MODEL_NAME] [-c COMMIT_HASH] [-s S3_BUCKET_NAME]" - echo " -t TOKEN Hugging Face token " - echo " -h Hugging Face Repo Name Hugging Face repo " - echo " -m MODEL_NAME Model name (default: csdc-atl/buffer-cross-001)" - echo " -c COMMIT_HASH Commit hash (default: 46d270928463db49b317e5ea469a8ac8152f4a13)" - echo " -p Tensor Parrallel degree Parameters in serving.properties " - echo " -s S3_BUCKET_NAME S3 bucket name to upload the model (default: llm-rag)" - exit 1 -} - -# Default values -model_name="" -commit_hash="" -s3_bucket_name="" # Default S3 bucket name - -# Parse command-line options -while getopts ":t:h:m:c:p:s:" opt; do - case $opt in - t) hf_token="$OPTARG" ;; - h) hf_name="$OPTARG" ;; - m) model_name="$OPTARG" ;; - c) commit_hash="$OPTARG" ;; - p) tensor_parallel_degree="$OPTARG" ;; - s) s3_bucket_name="$OPTARG" ;; - \?) echo "Invalid option: -$OPTARG" >&2; usage ;; - :) echo "Option -$OPTARG requires an argument." >&2; usage ;; - esac -done - - -# # Validate the hf_token and python3 interpreter exist -# if [ -z "$hf_token" ] || ! command -v python3 &> /dev/null; then -# usage -# fi - -# # Install necessary packages -pip install huggingface-hub -Uqq -pip install -U sagemaker - -# Define local model path -local_model_path="./${model_name}" - -# Uncomment the line below if you want to create a specific directory for the model -# mkdir -p $local_model_path - -# Download model snapshot in current folder without model prefix added -# python3 -c "from huggingface_hub import snapshot_download; from pathlib import Path; snapshot_download(repo_id='$model_name', revision='$commit_hash', cache_dir=Path('.'), token='$hf_token')" -python3 -c "from huggingface_hub import snapshot_download; from pathlib import Path; snapshot_download(repo_id='$hf_name', revision='$commit_hash', cache_dir='$local_model_path')" - -# Find model snapshot path with the first search result -model_snapshot_path=$(find $local_model_path -path '*/snapshots/*' -type d -print -quit) -echo "Model snapshot path: $model_snapshot_path" - -# s3:/// -aws s3 cp --recursive $model_snapshot_path s3://$s3_bucket_name/$model_name - -# Prepare model.py files according to model name -model_inference_file="./${model_name}_model.py" -cp $model_inference_file ../code/model.py - -# Modify the content of serving.properties and re-tar the model -cp serving.properties ../code/serving.properties -cd ../code -file_path="serving.properties" -os_type=$(uname -s) - -if [ "$os_type" == "Darwin" ]; then - sed -i "" "s|option.s3url = S3PATH|option.s3url = s3://$s3_bucket_name/$model_name/|g" $file_path - sed -i "" "s|option.tensor_parallel_degree=tpd|option.tensor_parallel_degree=$tensor_parallel_degree|g" $file_path -else - sed -i "s|option.s3url = S3PATH|option.s3url = s3://$s3_bucket_name/$model_name/|g" $file_path - sed -i "s|option.tensor_parallel_degree=tpd|option.tensor_parallel_degree=$tensor_parallel_degree|g" $file_path -fi - - -rm model.tar.gz -tar czvf model.tar.gz * - -code_path="${model_name}_deploy_code" -aws s3 cp model.tar.gz s3://$s3_bucket_name/$code_path/model.tar.gz diff --git a/source/model/sqlcoder/model/serving.properties b/source/model/sqlcoder/model/serving.properties deleted file mode 100644 index bb6b097..0000000 --- a/source/model/sqlcoder/model/serving.properties +++ /dev/null @@ -1,5 +0,0 @@ -engine=Python -option.enable_streaming=true -# update according to your own path -# option.s3url = s3://<_S3ModelAssets>/<_AssetsStack._embeddingModelPrefix> -option.s3url = S3PATH \ No newline at end of 
file diff --git a/source/model/sqlcoder/model/sqlcoder-7b-2_model.py b/source/model/sqlcoder/model/sqlcoder-7b-2_model.py deleted file mode 100644 index a3f2342..0000000 --- a/source/model/sqlcoder/model/sqlcoder-7b-2_model.py +++ /dev/null @@ -1,137 +0,0 @@ -import logging -import math -import os -from threading import Thread - -import sqlparse -import torch -import transformers -from djl_python import Input, Output -from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline - -device = "cuda" - - -def load_model(properties): - model_location = properties["model_dir"] - if "model_id" in properties: - model_location = properties["model_id"] - logging.info(f"Loading model in {model_location}") - tokenizer = AutoTokenizer.from_pretrained(model_location, trust_remote_code=True) - model = AutoModelForCausalLM.from_pretrained( - model_location, - trust_remote_code=True, - load_in_8bit=True, - device_map="auto", - use_cache=True, - ) - return tokenizer, model - - -model = None -tokenizer = None - - -def generate_prompt(question): - prompt = """### Task -Generate a SQL query to answer [QUESTION]{question}[/QUESTION] - -### Instructions -- If you cannot answer the question with the available database schema, return 'I do not know' -- Remember that revenue is price multiplied by quantity -- Remember that cost is supply_price multiplied by quantity - -### Database Schema -This query will run on a database whose schema is represented in this string: -CREATE TABLE products ( - product_id INTEGER PRIMARY KEY, -- Unique ID for each product - name VARCHAR(50), -- Name of the product - price DECIMAL(10,2), -- Price of each unit of the product - quantity INTEGER -- Current quantity in stock -); - -CREATE TABLE customers ( - customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer - name VARCHAR(50), -- Name of the customer - address VARCHAR(100) -- Mailing address of the customer -); - -CREATE TABLE salespeople ( - salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson - name VARCHAR(50), -- Name of the salesperson - region VARCHAR(50) -- Geographic sales region -); - -CREATE TABLE sales ( - sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale - product_id INTEGER, -- ID of product sold - customer_id INTEGER, -- ID of customer who made purchase - salesperson_id INTEGER, -- ID of salesperson who made the sale - sale_date DATE, -- Date the sale occurred - quantity INTEGER -- Quantity of product sold -); - -CREATE TABLE product_suppliers ( - supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier - product_id INTEGER, -- Product ID supplied - supply_price DECIMAL(10,2) -- Unit price charged by supplier -); - --- sales.product_id can be joined with products.product_id --- sales.customer_id can be joined with customers.customer_id --- sales.salesperson_id can be joined with salespeople.salesperson_id --- product_suppliers.product_id can be joined with products.product_id - -### Answer -Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION] -[SQL] -""" - prompt = prompt.format(question=question) - return prompt - - -def stream_items(sql_query): - chunks = sql_query.split("\n") - for chunk in chunks: - stream_buffer = chunk + "\n" - logging.info(f"Stream buffer: {stream_buffer}") - yield stream_buffer - - -def handle(inputs: Input): - global tokenizer, model - if not model: - tokenizer, model = load_model(inputs.get_properties()) - - if inputs.is_empty(): - return None - data = inputs.get_as_json() - - prompt = data["prompt"] - 
stream = data.get("stream", False) - - # updated_prompt = generate_prompt(prompt) - updated_prompt = prompt - - inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda") - generated_ids = model.generate( - **inputs, - num_return_sequences=1, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.eos_token_id, - max_new_tokens=400, - do_sample=False, - num_beams=1, - ) - decoded_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - sql_query = sqlparse.format(decoded_outputs[0].split("[SQL]")[-1], reindent=True) - logging.info(f"SQL Query: {sql_query}") - - outputs = Output() - # split SQL query every into chunks containing 10 characters - if stream: - outputs.add_stream_content(stream_items(sql_query), output_formatter=None) - else: - outputs.add_as_json({"outputs": sql_query}) - - return outputs From 808ea9661041e3252b35d1a0ce3d2a563dc76907 Mon Sep 17 00:00:00 2001 From: wubinbin Date: Wed, 7 Aug 2024 16:29:42 +0800 Subject: [PATCH 120/130] feat: switch session --- report-front-end/src/common/api/API.ts | 83 +------------------ report-front-end/src/common/api/WebSocket.ts | 51 ++++++++---- .../src/common/constant/global.tsx | 3 - .../chatbot-panel/chat-input-panel.tsx | 18 ++-- .../components/chatbot-panel/chat-message.tsx | 7 ++ .../src/components/chatbot-panel/chat.tsx | 40 ++++----- .../chatbot-panel/custom-questions.tsx | 8 +- .../chatbot-panel/suggested-questions.tsx | 7 +- .../src/components/session-panel/session.tsx | 14 ++-- .../src/components/session-panel/sessions.tsx | 18 ++-- .../src/components/session-panel/style.scss | 2 +- .../src/components/side-navigation/index.tsx | 8 +- .../src/pages/chatbot-page/playground.tsx | 8 +- 13 files changed, 112 insertions(+), 155 deletions(-) diff --git a/report-front-end/src/common/api/API.ts b/report-front-end/src/common/api/API.ts index e3279f8..6a93e32 100644 --- a/report-front-end/src/common/api/API.ts +++ b/report-front-end/src/common/api/API.ts @@ -1,11 +1,5 @@ -import { - ChatBotHistoryItem, - ChatBotMessageType, - FeedBackItem, - SessionItem, -} from "../../components/chatbot-panel/types"; -import { Dispatch, SetStateAction } from "react"; -import { BACKEND_URL, DEFAULT_QUERY_CONFIG } from "../constant/constants"; +import { FeedBackItem, SessionItem } from "../../components/chatbot-panel/types"; +import { BACKEND_URL } from "../constant/constants"; import { alertMsg } from "../helpers/tools"; export async function getSelectData() { @@ -30,79 +24,6 @@ export async function getSelectData() { } } -export async function query(props: { - query: string; - setLoading: Dispatch>; - configuration: any; - setMessageHistory: Dispatch>; -}) { - props.setMessageHistory((history: ChatBotHistoryItem[]) => { - return [...history, { - type: ChatBotMessageType.Human, - content: props.query - }]; - }); - props.setLoading(true); - try { - const param = { - query: props.query, - bedrock_model_id: props.configuration.selectedLLM || DEFAULT_QUERY_CONFIG.selectedLLM, - use_rag_flag: true, - visualize_results_flag: true, - intent_ner_recognition_flag: props.configuration.intentChecked, - agent_cot_flag: props.configuration.complexChecked, - profile_name: props.configuration.selectedDataPro || DEFAULT_QUERY_CONFIG.selectedDataPro, - explain_gen_process_flag: true, - gen_suggested_question_flag: props.configuration.modelSuggestChecked, - answer_with_insights: props.configuration.answerInsightChecked || DEFAULT_QUERY_CONFIG.answerInsightChecked, - top_k: props.configuration.topK, - top_p: props.configuration.topP, - 
max_tokens: props.configuration.maxLength, - temperature: props.configuration.temperature - }; - const url = `${BACKEND_URL}qa/ask`; - const response = await fetch(url, { - headers: { - "Content-Type": "application/json" - }, - method: "POST", - body: JSON.stringify(param) - } - ); - if (!response.ok) { - console.error('Query error, ', response); - return; - } - const result = await response.json(); - console.log("response: ", result); - props.setLoading(false); - props.setMessageHistory((history: ChatBotHistoryItem[]) => { - return [...history, { - type: ChatBotMessageType.AI, - content: result - }]; - }); - } catch (err) { - props.setLoading(false); - const result = { - query: props.query, - query_intent: "Error", - knowledge_search_result: {}, - sql_search_result: [], - agent_search_result: {}, - suggested_question: [] - }; - props.setLoading(false); - props.setMessageHistory((history: any) => { - return [...history, { - type: ChatBotMessageType.AI, - content: result - }]; - }); - console.error('Query error, ', err); - } -} - export async function addUserFeedback(feedbackData: FeedBackItem) { // call api try { diff --git a/report-front-end/src/common/api/WebSocket.ts b/report-front-end/src/common/api/WebSocket.ts index 722577f..842e790 100644 --- a/report-front-end/src/common/api/WebSocket.ts +++ b/report-front-end/src/common/api/WebSocket.ts @@ -2,12 +2,12 @@ import useWebSocket from "react-use-websocket"; import { DEFAULT_QUERY_CONFIG } from "../constant/constants"; import { SendJsonMessage } from "react-use-websocket/src/lib/types"; import { Dispatch, SetStateAction } from "react"; -import { ChatBotHistoryItem, ChatBotMessageItem, ChatBotMessageType } from "../../components/chatbot-panel/types"; -import { Global } from "../constant/global"; +import { ChatBotMessageItem, ChatBotMessageType } from "../../components/chatbot-panel/types"; +import { Session } from "../../components/session-panel/types"; export function createWssClient( setStatusMessage: Dispatch>, - setMessageHistory: Dispatch> + setSessions: Dispatch> ) { const socketUrl = process.env.VITE_WEBSOCKET_URL as string; const {sendJsonMessage} @@ -29,12 +29,22 @@ export function createWssClient( [...historyMessage, messageJson]); } else { setStatusMessage([]); - setMessageHistory((history: ChatBotHistoryItem[]) => { - return [...history, { - type: ChatBotMessageType.AI, - content: messageJson.content - }]; + setSessions((prevState) => { + return prevState.map((session) => { + if (messageJson.session_id !== session.session_id) { + return session; + } else { + return { + session_id: messageJson.session_id, + messages: [...session.messages, { + type: ChatBotMessageType.AI, + content: messageJson.content + }] + } + } + }) }); + } }; @@ -45,15 +55,24 @@ export function queryWithWS(props: { query: string; configuration: any; sendMessage: SendJsonMessage; - setMessageHistory: Dispatch>; + setSessions: Dispatch>; userId: string; + sessionId: string; }) { - - props.setMessageHistory((history: ChatBotHistoryItem[]) => { - return [...history, { - type: ChatBotMessageType.Human, - content: props.query - }]; + props.setSessions((prevState) => { + return prevState.map((session) => { + if (props.sessionId !== session.session_id) { + return session; + } else { + return { + session_id: session.session_id, + messages: [...session.messages, { + type: ChatBotMessageType.Human, + content: props.query + }] + } + } + }) }); const param = { query: props.query, @@ -71,7 +90,7 @@ export function queryWithWS(props: { max_tokens: 
props.configuration.maxLength, temperature: props.configuration.temperature, context_window: props.configuration.contextWindow, - session_id: Global.sessionId, + session_id: props.sessionId, user_id: props.userId }; console.log("Send WebSocketMessage: ", param); diff --git a/report-front-end/src/common/constant/global.tsx b/report-front-end/src/common/constant/global.tsx index 5dfc74a..0d39e21 100644 --- a/report-front-end/src/common/constant/global.tsx +++ b/report-front-end/src/common/constant/global.tsx @@ -1,7 +1,4 @@ -import { v4 as uuid } from 'uuid'; - export class Global { - public static sessionId = uuid(); } diff --git a/report-front-end/src/components/chatbot-panel/chat-input-panel.tsx b/report-front-end/src/components/chatbot-panel/chat-input-panel.tsx index 542855a..95b262b 100644 --- a/report-front-end/src/components/chatbot-panel/chat-input-panel.tsx +++ b/report-front-end/src/components/chatbot-panel/chat-input-panel.tsx @@ -18,15 +18,18 @@ import { } from "./types"; import styles from "./chat.module.scss"; import CustomQuestions from "./custom-questions"; +import { Session } from "../session-panel/types"; export interface ChatInputPanelProps { setToolsHide: Dispatch>; setLoading: Dispatch>; messageHistory: ChatBotHistoryItem[]; setMessageHistory: Dispatch>; + setSessions: Dispatch>; setStatusMessage: Dispatch>; sendMessage: SendJsonMessage; toolsHide: boolean; + currSessionId: string; } export abstract class ChatScrollState { @@ -43,22 +46,15 @@ export default function ChatInputPanel(props: ChatInputPanelProps) { const handleSendMessage = () => { setTextValue({ value: "" }); - // Call Fast API - /* query({ - query: state.value, - setLoading: props.setLoading, - configuration: userState.queryConfig, - setMessageHistory: props.setMessageHistory, - }).then();*/ - if (state.value !== "") { // Call WebSocket API queryWithWS({ query: state.value, configuration: userState.queryConfig, sendMessage: props.sendMessage, - setMessageHistory: props.setMessageHistory, - userId: userState.userInfo.userId + setSessions: props.setSessions, + userId: userState.userInfo.userId, + sessionId: props.currSessionId }); } }; @@ -118,7 +114,9 @@ export default function ChatInputPanel(props: ChatInputPanelProps) { setTextValue={setTextValue} setLoading={props.setLoading} setMessageHistory={props.setMessageHistory} + setSessions={props.setSessions} sendMessage={props.sendMessage} + sessionId={props.currSessionId} />
{/* diff --git a/report-front-end/src/components/chatbot-panel/chat-message.tsx b/report-front-end/src/components/chatbot-panel/chat-message.tsx index 5ab15f1..1f86812 100644 --- a/report-front-end/src/components/chatbot-panel/chat-message.tsx +++ b/report-front-end/src/components/chatbot-panel/chat-message.tsx @@ -35,6 +35,7 @@ import { FeedBackType, SQLSearchResult, } from "./types"; +import { Session } from "../session-panel/types"; export interface ChartTypeProps { data_show_type: string; @@ -463,7 +464,9 @@ function AIChatMessage(props: ChatMessageProps) { questions={content.suggested_question} setLoading={props.setLoading} setMessageHistory={props.setMessageHistory} + setSessions={props.setSessions} sendMessage={props.sendMessage} + sessionId={props.sessionId} /> ) : null} @@ -476,7 +479,9 @@ export interface ChatMessageProps { message: ChatBotHistoryItem; setLoading: Dispatch>; setMessageHistory: Dispatch>; + setSessions: Dispatch>; sendMessage: SendJsonMessage; + sessionId: string; } export default function ChatMessage(props: ChatMessageProps) { @@ -492,7 +497,9 @@ export default function ChatMessage(props: ChatMessageProps) { message={props.message} setLoading={props.setLoading} setMessageHistory={props.setMessageHistory} + setSessions={props.setSessions} sendMessage={props.sendMessage} + sessionId={props.sessionId} /> )} diff --git a/report-front-end/src/components/chatbot-panel/chat.tsx b/report-front-end/src/components/chatbot-panel/chat.tsx index b441a69..5fb0806 100644 --- a/report-front-end/src/components/chatbot-panel/chat.tsx +++ b/report-front-end/src/components/chatbot-panel/chat.tsx @@ -15,7 +15,7 @@ export default function Chat(props: { toolsHide: boolean; sessions: Session[]; setSessions: Dispatch>; - currentSession: number; + currentSessionId: string; }) { const [messageHistory, setMessageHistory] = useState( [], @@ -23,7 +23,7 @@ export default function Chat(props: { const [statusMessage, setStatusMessage] = useState([]); const [loading, setLoading] = useState(false); - const sendJsonMessage = createWssClient(setStatusMessage, setMessageHistory); + const sendJsonMessage = createWssClient(setStatusMessage, props.setSessions); const dispatch = useDispatch(); const userState = useSelector((state) => state) as UserState; @@ -54,24 +54,21 @@ export default function Chat(props: { }, [userState.queryConfig]); useEffect(() => { - // console.log("current session index: ", props.currentSession); - setMessageHistory(props.sessions[props.currentSession].messages); - }, [props.currentSession]); + props.sessions.forEach((session) => { + if (session.session_id === props.currentSessionId) { + setMessageHistory(session.messages); + } + }); + }, [props.currentSessionId]); + // update history message useEffect(() => { - props.setSessions((prevState) => { - return prevState.map((session: Session, idx: number) => { - if (idx === props.currentSession) { - return { - session_id: session.session_id, - messages: messageHistory - }; - } else { - return session; - } - }); + props.sessions.forEach((session) => { + if (session.session_id === props.currentSessionId) { + setMessageHistory(session.messages); + } }); - }, [messageHistory]); + }, [props.sessions]); return (
@@ -86,15 +83,18 @@ export default function Chat(props: { setMessageHistory={( history: SetStateAction, ) => setMessageHistory(history)} + setSessions={props.setSessions} sendMessage={sendJsonMessage} + sessionId={props.currentSessionId} />
); })} - {statusMessage.length === 0 ? null : ( + {statusMessage.filter((status) => status.session_id === props.currentSessionId).length === 0 ? null : (
- {statusMessage.map((message, idx) => { + {statusMessage.filter((status) => status.session_id === props.currentSessionId) + .map((message, idx) => { const displayMessage = idx % 2 === 1 ? true @@ -138,10 +138,12 @@ export default function Chat(props: { setMessageHistory={(history: SetStateAction) => setMessageHistory(history) } + setSessions={props.setSessions} setStatusMessage={(message: SetStateAction) => setStatusMessage(message) } sendMessage={sendJsonMessage} + currSessionId={props.currentSessionId} />
diff --git a/report-front-end/src/components/chatbot-panel/custom-questions.tsx b/report-front-end/src/components/chatbot-panel/custom-questions.tsx index 53f37f1..259f067 100644 --- a/report-front-end/src/components/chatbot-panel/custom-questions.tsx +++ b/report-front-end/src/components/chatbot-panel/custom-questions.tsx @@ -8,12 +8,15 @@ import styles from "./chat.module.scss"; import { queryWithWS } from "../../common/api/WebSocket"; import { SendJsonMessage } from "react-use-websocket/src/lib/types"; import { UserState } from "../../common/helpers/types"; +import { Session } from "../session-panel/types"; export interface RecommendQuestionsProps { setTextValue: Dispatch>; setLoading: Dispatch>; setMessageHistory: Dispatch>; + setSessions: Dispatch>; sendMessage: SendJsonMessage; + sessionId: string; } export default function CustomQuestions(props: RecommendQuestionsProps) { @@ -62,8 +65,9 @@ export default function CustomQuestions(props: RecommendQuestionsProps) { query: question, configuration: userState.queryConfig, sendMessage: props.sendMessage, - setMessageHistory: props.setMessageHistory, - userId: userState.userInfo.userId + setSessions: props.setSessions, + userId: userState.userInfo.userId, + sessionId: props.sessionId }); }; diff --git a/report-front-end/src/components/chatbot-panel/suggested-questions.tsx b/report-front-end/src/components/chatbot-panel/suggested-questions.tsx index da6edee..e350cc8 100644 --- a/report-front-end/src/components/chatbot-panel/suggested-questions.tsx +++ b/report-front-end/src/components/chatbot-panel/suggested-questions.tsx @@ -6,12 +6,15 @@ import { useSelector } from "react-redux"; import { queryWithWS } from "../../common/api/WebSocket"; import { SendJsonMessage } from "react-use-websocket/src/lib/types"; import { UserState } from "../../common/helpers/types"; +import { Session } from "../session-panel/types"; export interface SuggestedQuestionsProps { questions: string[]; setLoading: Dispatch>; setMessageHistory: Dispatch>; + setSessions: Dispatch>; sendMessage: SendJsonMessage; + sessionId: string; } export default function SuggestedQuestions(props: SuggestedQuestionsProps) { @@ -32,7 +35,9 @@ export default function SuggestedQuestions(props: SuggestedQuestionsProps) { configuration: userState.queryConfig, sendMessage: props.sendMessage, setMessageHistory: props.setMessageHistory, - userId: userState.userInfo.userId + setSessions: props.setSessions, + userId: userState.userInfo.userId, + sessionId: props.sessionId }); }; diff --git a/report-front-end/src/components/session-panel/session.tsx b/report-front-end/src/components/session-panel/session.tsx index 847c2c0..64fb12e 100644 --- a/report-front-end/src/components/session-panel/session.tsx +++ b/report-front-end/src/components/session-panel/session.tsx @@ -1,28 +1,30 @@ import { Button } from "@cloudscape-design/components"; import "./style.scss"; import { Dispatch, SetStateAction } from "react"; +import { Session } from "./types"; export const SessionPanel = (props: { - session: any, + session: Session, index: number, - currSession: number, - setCurrSession: Dispatch>, + currSessionId: string, + setCurrSessionId: Dispatch>, setSessions: Dispatch>, }) => { const onClick = () => { - props.setCurrSession(props.index); + console.log("onClick, sessionId: ", props.session); + props.setCurrSessionId(props.session.session_id); }; return (
); diff --git a/report-front-end/src/components/session-panel/sessions.tsx b/report-front-end/src/components/session-panel/sessions.tsx index 22d36cd..e12da21 100644 --- a/report-front-end/src/components/session-panel/sessions.tsx +++ b/report-front-end/src/components/session-panel/sessions.tsx @@ -12,8 +12,8 @@ export const Sessions = ( props: { sessions: Session[]; setSessions: Dispatch>; - currentSession: number; - setCurrentSession: Dispatch>; + currentSessionId: string; + setCurrentSessionId: Dispatch>; }, ) => { @@ -27,22 +27,24 @@ export const Sessions = ( getSessions(sessionItem).then( response => { console.log("sessions: ", response); + const sessionId = uuid(); props.setSessions([ { - session_id: uuid(), + session_id: sessionId, messages: [], }, ...(response)]); - props.setCurrentSession(0); + props.setCurrentSessionId(sessionId); }); }, [userInfo.queryConfig.selectedDataPro]); const addNewSession = () => { + const sessionId = uuid(); props.setSessions([ { - session_id: uuid(), + session_id: sessionId, messages: [], }, ...props.sessions]); - props.setCurrentSession(0); + props.setCurrentSessionId(sessionId); }; return ( @@ -60,8 +62,8 @@ export const Sessions = ( ))} diff --git a/report-front-end/src/components/session-panel/style.scss b/report-front-end/src/components/session-panel/style.scss index 8c9c0cf..64e8ebd 100644 --- a/report-front-end/src/components/session-panel/style.scss +++ b/report-front-end/src/components/session-panel/style.scss @@ -18,7 +18,7 @@ font-weight: 500 !important; padding: 12px 12px !important; color: black !important; - width: 250px; + width: 100%; height: 100%; white-space: nowrap !important; text-overflow: ellipsis !important; diff --git a/report-front-end/src/components/side-navigation/index.tsx b/report-front-end/src/components/side-navigation/index.tsx index 2b9415d..cf16182 100644 --- a/report-front-end/src/components/side-navigation/index.tsx +++ b/report-front-end/src/components/side-navigation/index.tsx @@ -7,8 +7,8 @@ export default function NavigationPanel( props: { sessions: Session[]; setSessions: Dispatch>; - currentSession: number; - setCurrentSession: Dispatch>; + currentSessionId: string; + setCurrentSessionId: Dispatch>; }) { return ( @@ -25,8 +25,8 @@ export default function NavigationPanel( ); diff --git a/report-front-end/src/pages/chatbot-page/playground.tsx b/report-front-end/src/pages/chatbot-page/playground.tsx index 9938dff..af1d6c1 100644 --- a/report-front-end/src/pages/chatbot-page/playground.tsx +++ b/report-front-end/src/pages/chatbot-page/playground.tsx @@ -13,7 +13,7 @@ export default function Playground() { session_id: uuid(), messages: [], }]); - const [currentSession, setCurrentSession] = useState(0); + const [currentSessionId, setCurrentSessionId] = useState(sessions[0].session_id); return ( } content={ @@ -32,7 +32,7 @@ export default function Playground() { setToolsHide={setToolsHide} sessions={sessions} setSessions={setSessions} - currentSession={currentSession} + currentSessionId={currentSessionId} /> } toolsHide={toolsHide} From ff9ef4cbf8019d6e5e1d6bdf4b566faae99a67f8 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Tue, 13 Aug 2024 09:07:11 +0800 Subject: [PATCH 121/130] add convert_timestamps_to_str --- application/utils/llm.py | 2 ++ application/utils/tool.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/application/utils/llm.py b/application/utils/llm.py index 9eb542c..23533eb 100644 --- a/application/utils/llm.py +++ b/application/utils/llm.py @@ -15,6 +15,7 @@ from utils.env_var 
import bedrock_ak_sk_info, BEDROCK_REGION, BEDROCK_EMBEDDING_MODEL, SAGEMAKER_EMBEDDING_REGION, \ SAGEMAKER_SQL_REGION, SAGEMAKER_ENDPOINT_EMBEDDING, SAGEMAKER_ENDPOINT_SQL +from utils.tool import convert_timestamps_to_str logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -520,6 +521,7 @@ def data_visualization(model_id, search_box, search_data, prompt_map): columns = list(search_data.columns) data_list = search_data.values.tolist() all_columns_data = [columns] + data_list + all_columns_data = convert_timestamps_to_str(all_columns_data) try: if len(all_columns_data) < 1: return "table", all_columns_data, "-1", [] diff --git a/application/utils/tool.py b/application/utils/tool.py index 03a17ac..949dfc8 100644 --- a/application/utils/tool.py +++ b/application/utils/tool.py @@ -5,6 +5,8 @@ from datetime import datetime from multiprocessing import Manager +import pandas as pd + from api.schemas import Message logger = logging.getLogger(__name__) @@ -54,6 +56,24 @@ def change_class_to_str(result): return "" +def convert_timestamps_to_str(data): + # Convert all Timestamp objects in the data to strings + try: + converted_data = [] + for row in data: + new_row = [] + for item in row: + if isinstance(item, pd.Timestamp): + # Convert Timestamp to string + new_row.append(item.strftime('%Y-%m-%d %H:%M:%S')) + else: + new_row.append(item) + converted_data.append(new_row) + return converted_data + except Exception as e: + logger.error(f"Error in converting timestamps to strings: {e}") + return data + def get_window_history(user_query_history): try: history_list = [] From 92b80bb82ca97ca84365679e39fb35a49a678bd1 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Tue, 13 Aug 2024 10:00:21 +0800 Subject: [PATCH 122/130] add sqlalchemy-redshift --- application/nlq/data_access/database.py | 4 ++-- application/requirements-api.txt | 3 ++- application/requirements.txt | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/application/nlq/data_access/database.py b/application/nlq/data_access/database.py index b3fd616..039add2 100644 --- a/application/nlq/data_access/database.py +++ b/application/nlq/data_access/database.py @@ -12,7 +12,7 @@ class RelationDatabase(): db_mapping = { 'mysql': 'mysql+pymysql', 'postgresql': 'postgresql+psycopg2', - 'redshift': 'postgresql+psycopg2', + 'redshift': 'redshift+psycopg2', 'starrocks': 'starrocks', 'clickhouse': 'clickhouse', # Add more mappings here for other databases @@ -75,7 +75,7 @@ def get_metadata_by_connection(cls, connection, schemas): metadata = db.MetaData() for s in schemas: metadata.reflect(bind=engine, schema=s) - metadata.reflect(bind=engine) + # metadata.reflect(bind=engine) return metadata @classmethod diff --git a/application/requirements-api.txt b/application/requirements-api.txt index 065f388..1d553d0 100644 --- a/application/requirements-api.txt +++ b/application/requirements-api.txt @@ -17,4 +17,5 @@ pandas==2.0.3 openpyxl starrocks==1.0.6 clickhouse-sqlalchemy==0.2.6 -sagemaker \ No newline at end of file +sagemaker +sqlalchemy-redshift~=0.8.14 \ No newline at end of file diff --git a/application/requirements.txt b/application/requirements.txt index 138eea9..3a8ccfc 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -17,4 +17,5 @@ openpyxl starrocks==1.0.6 clickhouse-sqlalchemy==0.2.6 sagemaker -fastapi~=0.110.1 \ No newline at end of file +fastapi~=0.110.1 +sqlalchemy-redshift~=0.8.14 \ No newline at end of file From 
b2b82770849460217c48a5039b1b99f911e37f8e Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Tue, 13 Aug 2024 10:04:44 +0800 Subject: [PATCH 123/130] fix Cognito --- report-front-end/.env | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/report-front-end/.env b/report-front-end/.env index 05d174d..4eff791 100644 --- a/report-front-end/.env +++ b/report-front-end/.env @@ -9,7 +9,7 @@ VITE_RIGHT_LOGO= # Login configuration, e.g. Cognito | None -VITE_LOGIN_TYPE=PLACEHOLDER_VITE_LOGIN_TYPE +VITE_LOGIN_TYPE=Cognito # KEEP the placeholder values if using CDK to deploy the backend! From fe8c7a9f346a9470a9517c22f9aa8bb3d6586c35 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Tue, 13 Aug 2024 10:37:18 +0800 Subject: [PATCH 124/130] change prompt --- application/utils/prompts/generate_prompt.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/application/utils/prompts/generate_prompt.py b/application/utils/prompts/generate_prompt.py index 58c4700..7d872b8 100644 --- a/application/utils/prompts/generate_prompt.py +++ b/application/utils/prompts/generate_prompt.py @@ -1633,7 +1633,7 @@ {{ "show_type" : "pie", - "format_data" : [['gender', 'num_users'], ['F', 1906], ['M', 1788]] + "format_data" : [["gender", "num_users"], ["F", 1906], ["M", 1788]] }} ``` @@ -1670,7 +1670,7 @@ {{ "show_type" : "pie", - "format_data" : [['gender', 'num_users'], ['F', 1906], ['M', 1788]] + "format_data" : [["gender", "num_users"], ["F", 1906], ["M", 1788]] }} ``` @@ -1707,7 +1707,7 @@ {{ "show_type" : "pie", - "format_data" : [['gender', 'num_users'], ['F', 1906], ['M', 1788]] + "format_data" : [["gender", "num_users"], ["F", 1906], ["M", 1788]] }} ``` @@ -1744,7 +1744,7 @@ {{ "show_type" : "pie", - "format_data" : [['gender', 'num_users'], ['F', 1906], ['M', 1788]] + "format_data" : [["gender", "num_users"], ["F", 1906], ["M", 1788]] }} ``` @@ -1781,7 +1781,7 @@ {{ "show_type" : "pie", - "format_data" : [['gender', 'num_users'], ['F', 1906], ['M', 1788]] + "format_data" : [["gender", "num_users"], ["F", 1906], ["M", 1788]] }} ``` From e46cc2e53d3a818dcdd4c923cc1ce5faa0d0777c Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Tue, 13 Aug 2024 10:46:47 +0800 Subject: [PATCH 125/130] change VITE_LOGIN_TYPE --- report-front-end/.env | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/report-front-end/.env b/report-front-end/.env index 4eff791..05d174d 100644 --- a/report-front-end/.env +++ b/report-front-end/.env @@ -9,7 +9,7 @@ VITE_RIGHT_LOGO= # Login configuration, e.g. Cognito | None -VITE_LOGIN_TYPE=Cognito +VITE_LOGIN_TYPE=PLACEHOLDER_VITE_LOGIN_TYPE # KEEP the placeholder values if using CDK to deploy the backend! From d55dc09ad82ec2dcc5754bf6dd562d40428e5e50 Mon Sep 17 00:00:00 2001 From: wubinbin Date: Tue, 13 Aug 2024 11:57:56 +0800 Subject: [PATCH 126/130] feat: compute history length based on session id --- report-front-end/src/components/chatbot-panel/chat.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/report-front-end/src/components/chatbot-panel/chat.tsx b/report-front-end/src/components/chatbot-panel/chat.tsx index 5fb0806..3e4a635 100644 --- a/report-front-end/src/components/chatbot-panel/chat.tsx +++ b/report-front-end/src/components/chatbot-panel/chat.tsx @@ -126,7 +126,7 @@ export default function Chat(props: {
{messageHistory.length === 0 && - statusMessage.length === 0 && + statusMessage.filter((status) => status.session_id === props.currentSessionId).length === 0 && !loading &&
{"GenBI Chatbot"}
}
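One note on the prompt example change in PATCH 124/130 ("change prompt") above: switching the pie-chart format_data samples from single-quoted to double-quoted lists makes the few-shot payload valid JSON. That only matters if the model's reply is fed to a strict JSON parser, which is an assumption here, since the parsing code is not part of these patches; a minimal standard-library illustration:

    import json

    # The two list styles from the pie-chart examples in generate_prompt.py.
    single_quoted = "[['gender', 'num_users'], ['F', 1906], ['M', 1788]]"
    double_quoted = '[["gender", "num_users"], ["F", 1906], ["M", 1788]]'

    try:
        json.loads(single_quoted)
    except json.JSONDecodeError as err:
        # Strict JSON requires double-quoted strings, so this branch is taken.
        print("single-quoted form rejected:", err)

    # The double-quoted form parses cleanly into a list of lists.
    print(json.loads(double_quoted))
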
From fbe92d7e236914ab1ef6d9fa1bd7696816eb2f59 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Tue, 13 Aug 2024 16:23:18 +0800 Subject: [PATCH 127/130] add model id check --- application/pages/mainpage.py | 3 ++ application/utils/prompts/check_prompt.py | 52 +++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/application/pages/mainpage.py b/application/pages/mainpage.py index fa5179f..0810b86 100644 --- a/application/pages/mainpage.py +++ b/application/pages/mainpage.py @@ -3,6 +3,7 @@ from nlq.business.vector_store import VectorStore from utils.navigation import make_sidebar from utils.opensearch import opensearch_index_init +from utils.prompts.check_prompt import check_model_id_prompt st.set_page_config( page_title="Generative BI", @@ -34,3 +35,5 @@ entity = "Month on month ratio" comment = "The month on month growth rate refers to the growth rate compared to the previous period, and the calculation formula is: month on month growth rate=(current period number - previous period number)/previous period number x 100%" VectorStore.add_entity_sample(current_profile, entity, comment) + +check_model_id_prompt() \ No newline at end of file diff --git a/application/utils/prompts/check_prompt.py b/application/utils/prompts/check_prompt.py index a6da356..fb82f48 100644 --- a/application/utils/prompts/check_prompt.py +++ b/application/utils/prompts/check_prompt.py @@ -1,5 +1,8 @@ import logging +from nlq.business.profile import ProfileManagement +from utils.prompts.generate_prompt import support_model_ids_map, prompt_map_dict + logger = logging.getLogger(__name__) required_syntax_map = { @@ -257,6 +260,14 @@ def check_prompt_syntax(system_prompt, user_prompt, prompt_type, model_id): def find_missing_prompt_syntax(system_prompt, user_prompt, prompt_type, model_id): + """ + find missing prompt syntax + :param system_prompt: + :param user_prompt: + :param prompt_type: + :param model_id: + :return: + """ system_prompt_required_syntax = required_syntax_map.get(prompt_type, {}).get('system_prompt', {}).get(model_id) user_prompt_required_syntax = required_syntax_map.get(prompt_type, {}).get('user_prompt', {}).get(model_id) @@ -272,3 +283,44 @@ def find_missing_prompt_syntax(system_prompt, user_prompt, prompt_type, model_id missing_user_prompt_syntax.append(f'{{{user_syntax}}}') return missing_system_prompt_syntax, missing_user_prompt_syntax + + +def check_model_id_prompt(): + """ + check model id in prompt in dynamoDB + :return: + """ + try: + model_ids = [] + for key, value in support_model_ids_map.items(): + model_ids.append(value) + all_profiles = ProfileManagement.get_all_profiles_with_info() + for profile_name, profile_value_dict in all_profiles.items(): + prompt_map = profile_value_dict.get('prompt_map') + prompt_map_flag = False + for prompt_type in prompt_map_dict: + if prompt_type not in prompt_map: + prompt_map[prompt_type] = prompt_map_dict[prompt_type] + prompt_map_flag = True + + for prompt_type in prompt_map_dict: + origin_system_prompt = prompt_map_dict[prompt_type].get('system_prompt') + origin_user_prompt = prompt_map_dict[prompt_type].get('user_prompt') + + db_system_prompt = prompt_map[prompt_type].get('system_prompt', {}) + db_user_prompt = prompt_map[prompt_type].get('user_prompt', {}) + + for model_id in model_ids: + if model_id not in db_system_prompt: + prompt_map[prompt_type]['system_prompt'][model_id] = origin_system_prompt[model_id] + prompt_map_flag = True + logger.warning(f"Model ID {model_id} is missing in system prompt of {prompt_type}") + if model_id not in 
db_user_prompt: + prompt_map[prompt_type]['user_prompt'][model_id] = origin_user_prompt[model_id] + prompt_map_flag = True + logger.warning(f"Model ID {model_id} is missing in user prompt of {prompt_type}") + + if prompt_map_flag: + ProfileManagement.update_table_prompt_map(profile_name, prompt_map) + except Exception as e: + logger.error("check prompt is error %s", e) From 9daf1cb1eca8c9abcc10a1927d6f1e42828fb6d9 Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Mon, 19 Aug 2024 15:08:09 +0800 Subject: [PATCH 128/130] change do_visualize_results --- ...0\237\214\215_Generative_BI_Playground.py" | 79 +++++++++++-------- 1 file changed, 46 insertions(+), 33 deletions(-) diff --git "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" index 1848a8b..6e0f6e0 100644 --- "a/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" +++ "b/application/pages/1_\360\237\214\215_Generative_BI_Playground.py" @@ -1,11 +1,9 @@ import json -import os import streamlit as st import pandas as pd import plotly.express as px from dotenv import load_dotenv import logging -import random from api.service import user_feedback_downvote from nlq.business.connection import ConnectionManagement @@ -65,6 +63,8 @@ def downvote_clicked(question, comment): def clean_st_history(selected_profile): st.session_state.messages[selected_profile] = [] st.session_state.query_rewrite_history[selected_profile] = [] + st.session_state.current_sql_result[selected_profile] = None + def get_user_history(selected_profile: str): @@ -80,41 +80,50 @@ def get_user_history(selected_profile: str): history_query.append(messages["role"] + ":" + messages["content"]) return history_query +def set_vision_change(): + st.session_state.vision_change = True + -def do_visualize_results(nlq_chain, sql_result): - sql_query_result = sql_result +def do_visualize_results(selected_profile): + sql_query_result = st.session_state.current_sql_result[selected_profile] if sql_query_result is not None: - nlq_chain.set_visualization_config_change(False) # Auto-detect columns - visualize_config_columns = st.columns(3) + available_columns = sql_query_result.columns.tolist() + + # Initialize session state for x_column and y_column if not already present + if 'x_column' not in st.session_state or st.session_state.x_column is None: + st.session_state.x_column = available_columns[0] if available_columns else None + if 'y_column' not in st.session_state or st.session_state.x_column is None: + st.session_state.y_column = available_columns[0] if available_columns else None - available_columns = sql_query_result.columns + # Layout configuration + col1, col2, col3 = st.columns([1, 1, 2]) + + # Chart type selection + chart_type = col1.selectbox('Choose the chart type', ['Table', 'Bar', 'Line', 'Pie'], + on_change=set_vision_change) - # hacky way to get around the issue of selectbox not updating when the options change - chart_type = visualize_config_columns[0].selectbox('Choose the chart type', - ['Table', 'Bar', 'Line', 'Pie'], - on_change=nlq_chain.set_visualization_config_change - ) if chart_type != 'Table': - x_column = visualize_config_columns[1].selectbox(f'Choose x-axis column', available_columns, - on_change=nlq_chain.set_visualization_config_change, - key=random.randint(0, 10000) - ) - y_column = visualize_config_columns[2].selectbox('Choose y-axis column', - reversed(available_columns.to_list()), - on_change=nlq_chain.set_visualization_config_change, - key=random.randint(0, 
10000) - ) + # X-axis and Y-axis selection + st.session_state.x_column = col2.selectbox('Choose x-axis column', available_columns, + on_change=set_vision_change, + index=available_columns.index( + st.session_state.x_column) if st.session_state.x_column in available_columns else 0) + st.session_state.y_column = col3.selectbox('Choose y-axis column', available_columns, + on_change=set_vision_change, + index=available_columns.index( + st.session_state.y_column) if st.session_state.y_column in available_columns else 0) + + # Visualization if chart_type == 'Table': st.dataframe(sql_query_result, hide_index=True) elif chart_type == 'Bar': - st.plotly_chart(px.bar(sql_query_result, x=x_column, y=y_column)) + st.plotly_chart(px.bar(sql_query_result, x=st.session_state.x_column, y=st.session_state.y_column)) elif chart_type == 'Line': - st.plotly_chart(px.line(sql_query_result, x=x_column, y=y_column)) + st.plotly_chart(px.line(sql_query_result, x=st.session_state.x_column, y=st.session_state.y_column)) elif chart_type == 'Pie': - st.plotly_chart(px.pie(sql_query_result, names=x_column, values=y_column)) - else: - st.markdown('No visualization generated.') + st.plotly_chart(px.pie(sql_query_result, names=st.session_state.x_column, values=st.session_state.y_column)) + def recurrent_display(messages, i): @@ -247,6 +256,9 @@ def main(): all_profiles = ProfileManagement.get_all_profiles_with_info() st.session_state['profiles'] = all_profiles + if "vision_change" not in st.session_state: + st.session_state["vision_change"] = False + if 'selected_sample' not in st.session_state: st.session_state['selected_sample'] = '' @@ -313,6 +325,9 @@ def main(): st.session_state.query_rewrite_history[selected_profile] = [] st.session_state.nlq_chain = NLQChain(selected_profile) + if selected_profile not in st.session_state.current_sql_result: + st.session_state.current_sql_result[selected_profile] = None + if st.session_state.current_model_id != "" and st.session_state.current_model_id in model_ids: model_index = model_ids.index(st.session_state.current_model_id) model_type = st.selectbox("Choose your model", model_ids, index=model_index) @@ -384,9 +399,9 @@ def main(): knowledge_search_flag = False # add select box for which model to use - if search_box != "Type your query here..." or \ - current_nlq_chain.is_visualization_config_changed(): + if search_box != "Type your query here..." 
or st.session_state.vision_change: if search_box is not None and len(search_box) > 0: + st.session_state.current_sql_result[selected_profile] = None with st.chat_message("user"): current_nlq_chain.set_question(search_box) st.session_state.messages[selected_profile].append( @@ -642,7 +657,7 @@ def main(): st.session_state.messages[selected_profile].append( {"role": "assistant", "content": current_search_sql_result, "type": "pandas"}) - do_visualize_results(current_nlq_chain, st.session_state.current_sql_result[selected_profile]) + do_visualize_results(selected_profile) else: st.markdown("No relevant data found") @@ -666,10 +681,8 @@ def main(): on_click=sample_question_clicked, args=[gen_sq_list[2]]) else: - - if current_nlq_chain.is_visualization_config_changed(): - if visualize_results_flag: - do_visualize_results(current_nlq_chain, st.session_state.current_sql_result[selected_profile]) + if visualize_results_flag: + do_visualize_results(selected_profile) if __name__ == '__main__': From df343e72d5cdb455b4af3aead935b0cc1b1b1fdb Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Mon, 19 Aug 2024 17:55:50 +0800 Subject: [PATCH 129/130] add Platform.LINUX_AMD64 --- source/resources/lib/ecs/ecs-stack.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/source/resources/lib/ecs/ecs-stack.ts b/source/resources/lib/ecs/ecs-stack.ts index 70ce5e2..ac03b78 100644 --- a/source/resources/lib/ecs/ecs-stack.ts +++ b/source/resources/lib/ecs/ecs-stack.ts @@ -4,7 +4,7 @@ import * as ec2 from 'aws-cdk-lib/aws-ec2'; import * as ecs from 'aws-cdk-lib/aws-ecs'; import * as ecr from 'aws-cdk-lib/aws-ecr'; import * as iam from 'aws-cdk-lib/aws-iam'; -import {DockerImageAsset} from 'aws-cdk-lib/aws-ecr-assets'; +import {DockerImageAsset, Platform} from 'aws-cdk-lib/aws-ecr-assets'; import * as ecs_patterns from 'aws-cdk-lib/aws-ecs-patterns'; import * as path from 'path'; @@ -61,6 +61,7 @@ export class ECSStack extends cdk.Stack { 'dockerImageAsset': new DockerImageAsset(this, 'GenBiStreamlitDockerImage', { directory: services[0].dockerfileDirectory, file: services[0].dockerfile, + platform: Platform.LINUX_AMD64, buildArgs: { AWS_REGION: awsRegion, // Pass the AWS region as a build argument }, @@ -71,6 +72,7 @@ export class ECSStack extends cdk.Stack { 'dockerImageAsset': new DockerImageAsset(this, 'GenBiAPIDockerImage', { directory: services[1].dockerfileDirectory, file: services[1].dockerfile, + platform: Platform.LINUX_AMD64, buildArgs: { AWS_REGION: awsRegion, // Pass the AWS region as a build argument } @@ -275,6 +277,7 @@ export class ECSStack extends cdk.Stack { 'dockerImageAsset': new DockerImageAsset(this, 'GenBiFrontendDockerImage', { directory: services[2].dockerfileDirectory, file: services[2].dockerfile, + platform: Platform.LINUX_AMD64, buildArgs: { AWS_REGION: awsRegion, // Pass the AWS region as a build argument } From 6e38678226b3f2507c513353cc32cb6cdf6c3d2a Mon Sep 17 00:00:00 2001 From: Pinyu Su Date: Wed, 21 Aug 2024 16:42:54 +0800 Subject: [PATCH 130/130] change convert_timestamps_to_str --- application/utils/tool.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/application/utils/tool.py b/application/utils/tool.py index 949dfc8..380108b 100644 --- a/application/utils/tool.py +++ b/application/utils/tool.py @@ -2,7 +2,7 @@ import logging import time import random -from datetime import datetime +import datetime from multiprocessing import Manager import pandas as pd @@ -34,7 +34,7 @@ def generate_log_id(): def get_current_time(): - now = 
datetime.now()
+    now = datetime.datetime.now()
     formatted_time = now.strftime('%Y-%m-%d %H:%M:%S')
     return formatted_time
 
@@ -66,6 +66,9 @@ def convert_timestamps_to_str(data):
                 if isinstance(item, pd.Timestamp):
                     # Convert Timestamp to string
                     new_row.append(item.strftime('%Y-%m-%d %H:%M:%S'))
+                elif isinstance(item, datetime.date):
+                    # Convert datetime.date to string
+                    new_row.append(item.strftime('%Y-%m-%d %H:%M:%S'))
                 else:
                     new_row.append(item)
             converted_data.append(new_row)
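
To round out PATCH 121/130 and PATCH 130/130, a short usage sketch of convert_timestamps_to_str: data_visualization in application/utils/llm.py builds [columns] + data_list and passes it through this helper, so after these two patches both pandas Timestamps and plain datetime.date values reach the charting step as strings. The rows below are invented for illustration; only the function name and import path come from the patches, and the import assumes the application directory is on the path, as it is when llm.py runs.

    import datetime

    import pandas as pd

    # Import path as used in application/utils/llm.py (PATCH 121).
    from utils.tool import convert_timestamps_to_str

    # Row-oriented result set in the [columns] + data_list shape built by data_visualization.
    rows = [
        ["order_date", "ship_date", "amount"],
        [pd.Timestamp("2024-08-01 10:30:00"), datetime.date(2024, 8, 3), 42],
    ]

    # Expected output:
    # [['order_date', 'ship_date', 'amount'], ['2024-08-01 10:30:00', '2024-08-03 00:00:00', 42]]
    print(convert_timestamps_to_str(rows))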