From a04c6a638a0a322a3e180491f2ceb92499990d66 Mon Sep 17 00:00:00 2001
From: Ishu Kumar
Date: Thu, 7 Aug 2025 20:35:37 +0530
Subject: [PATCH 1/4] Add MCP functionality with tool calls and tool-call display

---
 src/client/content/chatbot.py            | 250 +++++++++----
 src/client/content/config/mcp_servers.py |  24 ++
 src/client/content/config/settings.py    |  47 ++-
 src/client/mcp/client.py                 | 446 +++++++++++++++++++++++
 src/client/mcp/frontend.py               |  94 +++++
 src/client/utils/st_common.py            |  18 +
 src/common/schema.py                     |  39 +-
 src/launch_client.py                     |   4 +
 src/launch_server.py                     | 120 +++++-
 src/server/api/core/bootstrap.py         |   3 +-
 src/server/api/core/mcp.py               |  31 ++
 src/server/api/core/settings.py          |  10 +
 src/server/api/v1/__init__.py            |   2 +-
 src/server/api/v1/mcp.py                 | 147 ++++++++
 src/server/api/v1/settings.py            |  13 +-
 src/server/bootstrap/mcp.py              |  89 +++++
 src/server/mcp/server/archive_mcp.py     | 182 +++++++++
 src/server/mcp/server_config.json        |  20 +
 18 files changed, 1448 insertions(+), 91 deletions(-)
 create mode 100644 src/client/content/config/mcp_servers.py
 create mode 100644 src/client/mcp/client.py
 create mode 100644 src/client/mcp/frontend.py
 create mode 100644 src/server/api/core/mcp.py
 create mode 100644 src/server/api/v1/mcp.py
 create mode 100644 src/server/bootstrap/mcp.py
 create mode 100644 src/server/mcp/server/archive_mcp.py
 create mode 100644 src/server/mcp/server_config.json

diff --git a/src/client/content/chatbot.py b/src/client/content/chatbot.py
index d8382ecc..7ba3fe05 100644
--- a/src/client/content/chatbot.py
+++ b/src/client/content/chatbot.py
@@ -2,11 +2,11 @@
 Copyright (c) 2024, 2025, Oracle and/or its affiliates.
 Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
 
-Session States Set:
-- user_client: Stores the Client
+This file integrates the Streamlit chatbot GUI with the MCPClient so that
+chat turns can invoke MCP tools and display each tool call.
 """
-# spell-checker:ignore streamlit, oraclevs, selectai
+# spell-checker:ignore streamlit, oraclevs, selectai, langgraph, prebuilt
 import asyncio
 import inspect
 import json
@@ -21,8 +21,8 @@
 import client.utils.api_call as api_call
 from client.utils.st_footer import render_chat_footer
-import client.utils.client as client
 import common.logging_config as logging_config
+from client.mcp.client import MCPClient
 
 logger = logging_config.logging.getLogger("client.content.chatbot")
 
@@ -67,95 +68,220 @@ async def main() -> None:
     #########################################################################
     # Sidebar Settings
     #########################################################################
-    # Get a list of available language models, if none, then stop
     ll_models_enabled = st_common.enabled_models_lookup("ll")
     if not ll_models_enabled:
         st.error("No language models are configured and/or enabled. Disabling Client.", icon="🛑")
         st.stop()
-    # the sidebars will set this to False if not everything is configured.
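+    # Sidebar renderers set this to False when required settings are missing.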
state.enable_client = True st_common.tools_sidebar() st_common.history_sidebar() st_common.ll_sidebar() st_common.selectai_sidebar() st_common.vector_search_sidebar() - # Stop when sidebar configurations not set if not state.enable_client: st.stop() ######################################################################### # Chatty-Bot Centre ######################################################################### - # Establish the Client - if "user_client" not in state: - state.user_client = client.Client( - server=state.server, - settings=state.client_settings, - timeout=1200, - ) - user_client: client.Client = state.user_client - - history = await user_client.get_history() + + if "messages" not in state: + state.messages = [] + st.chat_message("ai").write("Hello, how can I help you?") - vector_search_refs = [] - for message in history or []: - if not message["content"]: + + for message in state.messages: + role = message.get("role") + display_role = "" + if role in ("human", "user"): + display_role = "human" + elif role in ("ai", "assistant"): + if not message.get("content") and not message.get("tool_trace"): + continue + display_role = "assistant" + else: continue - if message["role"] == "tool" and message["name"] == "oraclevs_tool": - vector_search_refs = json.loads(message["content"]) - if message["role"] in ("ai", "assistant"): - with st.chat_message("ai"): - st.markdown(message["content"]) - if vector_search_refs: - show_vector_search_refs(vector_search_refs) - vector_search_refs = [] - elif message["role"] in ("human", "user"): - with st.chat_message("human"): - content = message["content"] + + with st.chat_message(display_role): + if "tool_trace" in message and message["tool_trace"]: + for tool_call in message["tool_trace"]: + with st.expander(f"🛠️ **Tool Call:** `{tool_call['name']}`", expanded=False): + st.text("Arguments:") + st.code(json.dumps(tool_call.get('args', {}), indent=2), language="json") + if "error" in tool_call: + st.text("Error:") + st.error(tool_call['error']) + else: + st.text("Result:") + st.code(tool_call.get('result', ''), language="json") + if message.get("content"): + # Display file attachments if present + if "attachments" in message and message["attachments"]: + for file in message["attachments"]: + # Show appropriate icon based on file type + if file["type"].startswith("image/"): + st.image(file["preview"], use_container_width=True) + st.markdown(f"🖼️ **{file['name']}** ({file['size']//1024} KB)") + elif file["type"] == "application/pdf": + st.markdown(f"📄 **{file['name']}** ({file['size']//1024} KB)") + elif file["type"] in ("text/plain", "text/markdown"): + st.markdown(f"📝 **{file['name']}** ({file['size']//1024} KB)") + else: + st.markdown(f"📎 **{file['name']}** ({file['size']//1024} KB)") + + # Display message content - handle both string and list formats + content = message.get("content") if isinstance(content, list): - for part in content: - if part["type"] == "text": - st.write(part["text"]) - elif part["type"] == "image_url" and part["image_url"]["url"].startswith("data:image"): - st.image(part["image_url"]["url"]) + # Extract and display only text parts + text_parts = [part["text"] for part in content if part["type"] == "text"] + st.markdown("\n".join(text_parts)) else: - st.write(content) + st.markdown(content) sys_prompt = state.client_settings["prompts"]["sys"] render_chat_footer() + if human_request := st.chat_input( f"Ask your question here... 
(current prompt: {sys_prompt})", accept_file=True, - file_type=["jpg", "jpeg", "png"], + file_type=["jpg", "jpeg", "png", "pdf", "txt", "docx"], + key=f"chat_input_{len(state.messages)}", ): - st.chat_message("human").write(human_request.text) - file_b64 = None - if human_request["files"]: - file = human_request["files"][0] - file_bytes = file.read() - file_b64 = base64.b64encode(file_bytes).decode("utf-8") + # Process message with potential file attachments + message = {"role": "user", "content": human_request.text} + + # Handle file attachments + if hasattr(human_request, "files") and human_request.files: + # Store file information separately from content + message["attachments"] = [] + for file in human_request.files: + file_bytes = file.read() + file_b64 = base64.b64encode(file_bytes).decode("utf-8") + message["attachments"].append({ + "name": file.name, + "type": file.type, + "size": len(file_bytes), + "data": file_b64, + "preview": f"data:{file.type};base64,{file_b64}" if file.type.startswith("image/") else None + }) + + state.messages.append(message) + st.rerun() + if state.messages and state.messages[-1]["role"] == "user": try: - message_placeholder = st.chat_message("ai").empty() - full_answer = "" - async for chunk in user_client.stream(message=human_request.text, image_b64=file_b64): - full_answer += chunk - message_placeholder.markdown(full_answer) - # Stream until we hit the end then refresh to replace with history - st.rerun() - except Exception: - logger.error("Exception:", exc_info=1) - st.chat_message("ai").write( - """ - I'm sorry, something's gone wrong. Please try again. - If the problem persists, please raise an issue. - """ - ) - if st.button("Retry", key="reload_chatbot"): - st_common.clear_state_key("user_client") + with st.chat_message("ai"): + with st.spinner("Thinking..."): + client_settings_for_request = state.client_settings.copy() + model_id = client_settings_for_request.get('ll_model', {}).get('model') + if model_id: + all_model_configs = st_common.enabled_models_lookup("ll") + model_config = all_model_configs.get(model_id, {}) + if 'api_key' in model_config: + if 'll_model' not in client_settings_for_request: + client_settings_for_request['ll_model'] = {} + client_settings_for_request['ll_model']['api_key'] = model_config['api_key'] + + # Prepare message history for backend + message_history = [] + for msg in state.messages: + # Create a copy of the message + processed_msg = msg.copy() + + # If there are attachments, include them in the content + if "attachments" in msg and msg["attachments"]: + # Start with the text content + text_content = msg["content"] + + # Handle list content format (from OpenAI API) + if isinstance(text_content, list): + text_parts = [part["text"] for part in text_content if part["type"] == "text"] + text_content = "\n".join(text_parts) + + # Create a list to hold structured content parts + content_list = [{"type": "text", "text": text_content}] + + non_image_references = [] + for attachment in msg["attachments"]: + if attachment["type"].startswith("image/"): + # Only add image URLs for user messages + if msg["role"] in ("human", "user"): + # Normalize image MIME types for compatibility + mime_type = attachment["type"] + if mime_type == "image/jpg": + mime_type = "image/jpeg" + + content_list.append({ + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{attachment['data']}", + "detail": "low" + } + }) + else: + # Handle non-image files as text references + non_image_references.append(f"\n[File: 
{attachment['name']} ({attachment['size']//1024} KB)]") + + # If there were non-image files, append their references to the main text part + if non_image_references: + content_list[0]['text'] += "".join(non_image_references) + + processed_msg["content"] = content_list + # Convert list content to string format + elif isinstance(msg.get("content"), list): + text_parts = [part["text"] for part in msg["content"] if part["type"] == "text"] + processed_msg["content"] = str("\n".join(text_parts)) + # Otherwise, ensure content is a string + else: + processed_msg["content"] = str(msg.get("content", "")) + + message_history.append(processed_msg) + + async with MCPClient(client_settings=client_settings_for_request) as mcp_client: + final_text, tool_trace, new_history = await mcp_client.invoke( + message_history=message_history + ) + + # Update the history for display. + # Keep the original message structure with attachments + for i in range(len(new_history) - 1, -1, -1): + if new_history[i].get("role") == "assistant": + # Preserve any attachments from the user message + user_message = state.messages[-1] + if "attachments" in user_message: + new_history[-1]["attachments"] = user_message["attachments"] + + new_history[i]["content"] = final_text + new_history[i]["tool_trace"] = tool_trace + break + + state.messages = new_history + st.rerun() + + except Exception as e: + logger.error("Exception during invoke call:", exc_info=True) + # Extract just the error message + error_msg = str(e) + + # Check if it's a file-related error + if "file" in error_msg.lower() or "image" in error_msg.lower() or "content" in error_msg.lower(): + st.error(f"Error: {error_msg}") + + # Add a button to remove files and retry + if st.button("Remove files and retry", key="remove_files_retry"): + # Remove attachments from the latest message + if state.messages and "attachments" in state.messages[-1]: + del state.messages[-1]["attachments"] + st.rerun() + else: + st.error(f"Error: {error_msg}") + + if st.button("Retry", key="reload_chatbot_error"): + if state.messages and state.messages[-1]["role"] == "user": + state.messages.pop() st.rerun() -if __name__ == "__main__" or "page.py" in inspect.stack()[1].filename: +if __name__ == "__main__" or ("page" in inspect.stack()[1].filename if inspect.stack() else False): try: asyncio.run(main()) except ValueError as ex: diff --git a/src/client/content/config/mcp_servers.py b/src/client/content/config/mcp_servers.py new file mode 100644 index 00000000..5535227d --- /dev/null +++ b/src/client/content/config/mcp_servers.py @@ -0,0 +1,24 @@ +import inspect + +from client.mcp.frontend import display_commands_tab, display_ide_tab, get_fastapi_base_url, get_server_capabilities + +import streamlit as st + +def main(): + fastapi_base_url = get_fastapi_base_url() + tools, resources, prompts = get_server_capabilities(fastapi_base_url) + if "chat_history" not in st.session_state: + st.session_state.chat_history = [] + ide, commands = st.tabs(["🛠️ IDE", "📚 Available Commands"]) + + with ide: + # Display the IDE tab using the original AI Optimizer logic. + display_ide_tab() + with commands: + # Display the commands tab using the original AI Optimizer logic. 
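+        # get_server_capabilities() degrades to empty or error placeholders
+        # on backend failure, so this tab renders rather than raising.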
+ display_commands_tab(tools, resources, prompts) + + + +if __name__ == "__main__" or "page.py" in inspect.stack()[1].filename: + main() diff --git a/src/client/content/config/settings.py b/src/client/content/config/settings.py index 399a141e..cdebe9dc 100644 --- a/src/client/content/config/settings.py +++ b/src/client/content/config/settings.py @@ -38,15 +38,32 @@ ############################################################################# def get_settings(include_sensitive: bool = False): """Get Server-Side Settings""" - settings = api_call.get( - endpoint="v1/settings", - params={ - "client": state.client_settings["client"], - "full_config": True, - "incl_sensitive": include_sensitive, - }, - ) - return settings + try: + settings = api_call.get( + endpoint="v1/settings", + params={ + "client": state.client_settings["client"], + "full_config": True, + "incl_sensitive": include_sensitive, + }, + ) + return settings + except api_call.ApiError as e: + if "not found" in str(e): + # If client settings not found, create them + logger.info("Client settings not found, creating new ones") + api_call.post(endpoint="v1/settings", params={"client": state.client_settings["client"]}) + settings = api_call.get( + endpoint="v1/settings", + params={ + "client": state.client_settings["client"], + "full_config": True, + "incl_sensitive": include_sensitive, + }, + ) + return settings + else: + raise def save_settings(settings): @@ -141,11 +158,11 @@ def apply_uploaded_settings(uploaded): def spring_ai_conf_check(ll_model, embed_model) -> str: """Check if configuration is valid for SpringAI package""" - if ll_model is None or embed_model is None: + if not ll_model or not embed_model: return "hybrid" - ll_api = ll_model["api"] - embed_api = embed_model["api"] + ll_api = ll_model.get("api", "") + embed_api = embed_model.get("api", "") if "OpenAI" in ll_api and "OpenAI" in embed_api: return "openai" @@ -287,9 +304,11 @@ def main(): st.header("SpringAI Settings", divider="red") # Merge the User Settings into the Model Config model_lookup = st_common.state_configs_lookup("model_configs", "id") - ll_config = model_lookup[state.client_settings["ll_model"]["model"]] | state.client_settings["ll_model"] + ll_model_id = state.client_settings["ll_model"].get("model") + ll_config = model_lookup.get(ll_model_id, {}) | state.client_settings["ll_model"] + embed_model_id = state.client_settings["vector_search"].get("model") embed_config = ( - model_lookup[state.client_settings["vector_search"]["model"]] | state.client_settings["vector_search"] + model_lookup.get(embed_model_id, {}) | state.client_settings["vector_search"] ) spring_ai_conf = spring_ai_conf_check(ll_config, embed_config) diff --git a/src/client/mcp/client.py b/src/client/mcp/client.py new file mode 100644 index 00000000..d4282828 --- /dev/null +++ b/src/client/mcp/client.py @@ -0,0 +1,446 @@ +import json +import os +import time +import asyncio +from dotenv import load_dotenv +from mcp import ClientSession, StdioServerParameters, types +from mcp.client.stdio import stdio_client +from typing import List, Dict, Optional, Tuple, Type, Any +from contextlib import AsyncExitStack + +# --- MODIFICATION: Import LangChain components --- +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage, BaseMessage +from langchain_core.language_models.base import BaseLanguageModel +from pydantic import create_model, BaseModel, Field +# Import the specific chat models you want to support +from langchain_openai import ChatOpenAI +from 
langchain_anthropic import ChatAnthropic +from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_cohere import ChatCohere +from langchain_ollama import ChatOllama +from langchain_groq import ChatGroq +from langchain_mistralai import ChatMistralAI + +load_dotenv() + +if os.getenv("IS_STREAMLIT_CONTEXT"): + import nest_asyncio + nest_asyncio.apply() + +class MCPClient: + # MODIFICATION: Changed the constructor to accept client_settings + def __init__(self, client_settings: Dict): + """ + Initialize MCP Client using a settings dictionary from the Streamlit client. + + Args: + client_settings: The state.client_settings object. + """ + # 1. Validate the incoming settings dictionary + if not client_settings or 'll_model' not in client_settings: + raise ValueError("Client settings are incomplete. 'll_model' is required.") + + # 2. Store the settings and extract the model ID + self.model_settings = client_settings['ll_model'] + + # This is our new "Service Factory" using LangChain classes + # If no model is specified, we'll initialize with a default one + if 'model' not in self.model_settings or not self.model_settings['model']: + # Set a default model if none is specified + self.model_settings['model'] = 'llama3.1' + # Remove any OpenAI-specific parameters that might cause issues + self.model_settings.pop('openai_api_key', None) + + self.langchain_model = self._create_langchain_model(**self.model_settings) + + self.exit_stack = AsyncExitStack() + self.sessions: Dict[str, ClientSession] = {} + self.tool_to_session: Dict[str, Tuple[ClientSession, types.Tool]] = {} + self.available_prompts: Dict[str, types.Prompt] = {} + self.static_resources: Dict[str, str] = {} + self.dynamic_resources: List[str] = [] + self.resource_to_session: Dict[str, str] = {} + self.prompt_to_session: Dict[str, str] = {} + self.available_tools: List[Dict] = [] + self._stdio_generators: Dict[str, Any] = {} # To store stdio generators for cleanup + print(f"Initialized MCPClient with LangChain model: {self.langchain_model.__class__.__name__}") + + # --- FIX: Add __aenter__ and __aexit__ to make this a context manager --- + async def __aenter__(self): + """Enter the async context, connecting to all servers.""" + await self.connect_to_servers() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit the async context, ensuring all connections are cleaned up.""" + await self.cleanup() + + def _create_langchain_model(self, model: str, **kwargs) -> BaseLanguageModel: + """Factory to create and return a LangChain ChatModel instance.""" + # If no model is specified, default to llama3.1 which works with Ollama + if not model: + model = "llama3.1" + # Remove any OpenAI-specific parameters that might cause issues + kwargs.pop('openai_api_key', None) + + model_lower = model.lower() + + # Handle OpenAI models + if model_lower.startswith('gpt-'): + # Check if api_key is in kwargs and rename it to openai_api_key for ChatOpenAI + if 'api_key' in kwargs: + kwargs['openai_api_key'] = kwargs.pop('api_key') + # Remove parameters that shouldn't be passed to ChatOpenAI + kwargs.pop('context_length', None) + kwargs.pop('chat_history', None) + return ChatOpenAI(model=model, **kwargs) + + # Handle Anthropic models + elif model_lower.startswith('claude-'): + kwargs.pop('openai_api_key', None) + return ChatAnthropic(model=model, **kwargs) + + # Handle Google models + elif model_lower.startswith('gemini-'): + kwargs.pop('openai_api_key', None) + return ChatGoogleGenerativeAI(model=model, **kwargs) + + # Handle 
Mistral models + elif model_lower.startswith('mistral-'): + kwargs.pop('openai_api_key', None) + return ChatMistralAI(model=model, **kwargs) + + # Handle Cohere models + elif model_lower.startswith('cohere-'): + kwargs.pop('openai_api_key', None) + return ChatCohere(model=model, **kwargs) + + # Handle Groq models + elif model_lower.startswith('groq-'): + kwargs.pop('openai_api_key', None) + return ChatGroq(model=model, **kwargs) + + # Default to Ollama for any other model name + else: + return ChatOllama(model=model, **kwargs) + + def _convert_dict_to_langchain_messages(self, message_history: List[Dict]) -> List[BaseMessage]: + """Converts a list of message dictionaries to a list of LangChain message objects.""" + messages: List[BaseMessage] = [] + for msg in message_history: + role = msg.get("role") + content = msg.get("content", "") + if role == "user": + messages.append(HumanMessage(content=content)) # type: ignore + elif role == "assistant": + # AIMessage can handle tool calls directly from the dictionary format + tool_calls = msg.get("tool_calls") + messages.append(AIMessage(content=content, tool_calls=tool_calls or [])) # type: ignore + elif role == "system": + messages.append(SystemMessage(content=content)) # type: ignore + elif role == "tool": + messages.append(ToolMessage(content=content, tool_call_id=msg.get("tool_call_id", ""))) # type: ignore + return messages # type: ignore + + def _convert_langchain_messages_to_dict(self, langchain_messages: List[BaseMessage]) -> List[Dict]: + """Converts a list of LangChain message objects back to a list of dictionaries for session state.""" + dict_messages = [] + for msg in langchain_messages: + if isinstance(msg, HumanMessage): + dict_messages.append({"role": "user", "content": msg.content}) + elif isinstance(msg, AIMessage): + # Preserve tool calls in the dictionary format + dict_messages.append({"role": "assistant", "content": msg.content, "tool_calls": msg.tool_calls}) + elif isinstance(msg, SystemMessage): + dict_messages.append({"role": "system", "content": msg.content}) + elif isinstance(msg, ToolMessage): + dict_messages.append({"role": "tool", "content": msg.content, "tool_call_id": msg.tool_call_id}) + return dict_messages + + def _prepare_messages_for_service(self, message_history: List[Dict]) -> List[Dict]: + """ + FIX: Translates the rich message history from the GUI into a simple, + text-only format that AI services can understand. + """ + prepared_messages = [] + for msg in message_history: + content = msg.get("content") + # If content is a list (multimodal), extract only the text. + if isinstance(content, list): + text_content = " ".join( + part["text"] for part in content if part.get("type") == "text" + ) + prepared_messages.append({"role": msg["role"], "content": text_content}) + # Otherwise, use the content as is (assuming it's a string). 
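+            # Example of the flattening above: {"role": "user", "content":
+            #   [{"type": "text", "text": "hi"}, {"type": "image_url", ...}]}
+            # becomes {"role": "user", "content": "hi"}.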
+ else: + prepared_messages.append(msg) + return prepared_messages + + async def connect_to_servers(self): + try: + config_paths = ["server/mcp/server_config.json", os.path.join(os.path.dirname(__file__), "..", "..", "server", "mcp", "server_config.json")] + servers = {} + for config_path in config_paths: + try: + with open(config_path, "r") as file: + servers = json.load(file).get("mcpServers", {}) + print(f"Loaded MCP server configuration from: {config_path}") + print(f"Found servers: {list(servers.keys())}") + break + except FileNotFoundError: + print(f"MCP server config not found at: {config_path}") + continue + except Exception as e: + print(f"Error reading MCP server config from {config_path}: {e}") + continue + if not servers: + print("No MCP server configuration found!") + for name, config in servers.items(): + print(f"Connecting to MCP server: {name}") + await self.connect_to_server(name, config) + except Exception as e: print(f"Error loading server configuration: {e}") + + async def connect_to_server(self, server_name: str, server_config: dict): + try: + print(f"Connecting to server '{server_name}' with config: {server_config}") + server_params = StdioServerParameters(**server_config) + + # Create the stdio client connection using the exit stack for proper cleanup + try: + read, write = await self.exit_stack.enter_async_context(stdio_client(server_params)) + + # Create the client session using the exit stack for proper cleanup + session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + + await session.initialize() + self.sessions[server_name] = session + + # Load tools, resources, and prompts from this server + await self._load_server_capabilities(session, server_name) + except RuntimeError as e: + # Handle runtime errors related to task context + if "cancel scope" not in str(e).lower(): + raise + print(f"Warning: Connection to '{server_name}' had context issues: {e}") + except Exception as e: + raise + except Exception as e: + print(f"Failed to connect to '{server_name}': {e}") + import traceback + traceback.print_exc() + + async def _run_async_generator(self, generator): + """Helper method to run an async generator in the current task context.""" + return await generator.__anext__() + + async def _load_server_capabilities(self, session: ClientSession, server_name: str): + """Load tools, resources, and prompts from a connected server.""" + try: + # List tools + tools_list = await session.list_tools() + print(f"Found {len(tools_list.tools)} tools from server '{server_name}'") + for tool in tools_list.tools: + self.tool_to_session[tool.name] = (session, tool) + print(f"Loaded tool '{tool.name}' from server '{server_name}'") + + # List resources + try: + resp = await session.list_resources() + if resp.resources: print(f" - Found Static Resources: {[r.name for r in resp.resources]}") + for resource in resp.resources: + uri = resource.uri.encoded_string() + self.resource_to_session[uri] = server_name + user_shortcut = uri.split('//')[-1] + self.static_resources[user_shortcut] = uri + if resource.name and resource.name != user_shortcut: + self.static_resources[resource.name] = uri + except Exception as e: + print(f"Failed to load resources from server '{server_name}': {e}") + + # Discover DYNAMIC resource templates + try: + # The response object for templates has a `.templates` attribute + resp = await session.list_resource_templates() + if resp.resourceTemplates: print(f" - Found Dynamic Resource Templates: {[t.name for t in resp.resourceTemplates]}") + for 
template in resp.resourceTemplates: + uri = template.uriTemplate + # The key for the session map MUST be the pattern itself. + self.resource_to_session[uri] = server_name + if uri not in self.dynamic_resources: + self.dynamic_resources.append(uri) + except Exception as e: + # This is also okay, some servers don't have dynamic resources. + print(f"Failed to load dynamic resources from server '{server_name}': {e}") + + + # List prompts + try: + prompts_list = await session.list_prompts() + print(f"Found {len(prompts_list.prompts)} prompts from server '{server_name}'") + for prompt in prompts_list.prompts: + self.available_prompts[prompt.name] = prompt + self.prompt_to_session[prompt.name] = server_name + print(f"Loaded prompt '{prompt.name}' from server '{server_name}'") + except Exception as e: + print(f"Failed to load prompts from server '{server_name}': {e}") + + except Exception as e: + print(f"Failed to load capabilities from server '{server_name}': {e}") + + async def _rebuild_mcp_tool_schemas(self): + """Rebuilds the list of tools from connected MCP servers in a LangChain-compatible format.""" + self.available_tools = [] + for _, (_, tool_object) in self.tool_to_session.items(): + # LangChain's .bind_tools can often work directly with this MCP schema + tool_schema = { + "name": tool_object.name, + "description": tool_object.description, + "args_schema": self.create_pydantic_model_from_schema(tool_object.name, tool_object.inputSchema) + } + self.available_tools.append(tool_schema) + print(f"Available tools after rebuild: {len(self.available_tools)}") + + def create_pydantic_model_from_schema(self, name: str, schema: dict) -> Type[BaseModel]: + """Dynamically creates a Pydantic model from a JSON schema for LangChain tool binding.""" + fields = {} + if schema and 'properties' in schema: + for prop_name, prop_details in schema['properties'].items(): + field_type = str # Default to string + # A more robust implementation would map JSON schema types to Python types + if prop_details.get('type') == 'integer': field_type = int + elif prop_details.get('type') == 'number': field_type = float + elif prop_details.get('type') == 'boolean': field_type = bool + + fields[prop_name] = (field_type, Field(..., description=prop_details.get('description'))) + + return create_model(name, **fields) # type: ignore + + async def execute_mcp_tool(self, tool_name: str, tool_args: Dict) -> str: + try: + session, _ = self.tool_to_session[tool_name] + result = await session.call_tool(tool_name, arguments=tool_args) + if not result.content: return "Tool executed successfully." + + # Handle different content types properly + if isinstance(result.content, list): + text_parts = [] + for item in result.content: + # Check if item has a text attribute + if hasattr(item, 'text'): + text_parts.append(str(item.text)) + else: + # Handle other content types + text_parts.append(str(item)) + return " | ".join(text_parts) + else: + return str(result.content) + except Exception as e: + # Check if it's a closed resource error + if "ClosedResourceError" in str(type(e)) or "closed" in str(e).lower(): + raise Exception("MCP session is closed. Please try again.") from e + else: + raise + + async def invoke(self, message_history: List[Dict]) -> Tuple[str, List[Dict], List[Dict]]: + """ + Main entry point. 
Now returns a tuple of: + (final_text_response, tool_calls_trace, new_full_history) + """ + max_retries = 3 + for attempt in range(max_retries): + try: + langchain_messages = self._convert_dict_to_langchain_messages(message_history) + + # Separate the final text response from the tool trace + final_text_response = "" + tool_calls_trace = [] + + max_iterations = 10 + tool_execution_failed = False + for iteration in range(max_iterations): + await self._rebuild_mcp_tool_schemas() + model_with_tools = self.langchain_model.bind_tools(self.available_tools) + response_message: AIMessage = await model_with_tools.ainvoke(langchain_messages) + langchain_messages.append(response_message) + + # Capture the final text response from the last message + if response_message.content: + final_text_response = response_message.content + + if not response_message.tool_calls: + break + + for tool_call in response_message.tool_calls: + tool_name = tool_call['name'] + tool_args = tool_call['args'] + + try: + result_content = await self.execute_mcp_tool(tool_name, tool_args) + tool_calls_trace.append({ + "name": tool_name, + "args": tool_args, + "result": result_content + }) + except Exception as e: + if "MCP session is closed" in str(e) and attempt < max_retries - 1: + print(f"MCP session closed, reinitializing (attempt {attempt + 1})") + await self.cleanup(); await self.connect_to_servers() + await asyncio.sleep(0.1); tool_execution_failed = True; break + else: + result_content = f"Error executing tool {tool_name}: {e}" + tool_calls_trace.append({ + "name": tool_name, + "args": tool_args, + "error": result_content + }) + + langchain_messages.append(ToolMessage(content=result_content, tool_call_id=tool_call['id'])) + + if tool_execution_failed: break + + if tool_execution_failed and attempt < max_retries - 1: continue + + final_history_dict = self._convert_langchain_messages_to_dict(langchain_messages) + + return final_text_response, tool_calls_trace, final_history_dict + + except RuntimeError as e: + if "Event loop is closed" in str(e) and attempt < max_retries - 1: + print(f"Event loop closed, reinitializing model (attempt {attempt + 1})") + self.langchain_model = self._create_langchain_model(**self.model_settings) + await asyncio.sleep(0.1); continue + else: raise Exception("Event loop closed. 
Please try again.") from e + except Exception as e: + if attempt >= max_retries - 1: raise + print(f"Invoke attempt {attempt + 1} failed, retrying: {e}") + await asyncio.sleep(0.1) + + raise Exception("Failed to invoke MCP client after all retries") + + async def cleanup(self): + """Clean up all resources properly.""" + try: + # Close all sessions using the exit stack to avoid context issues + await self.exit_stack.aclose() + except Exception as e: + # Suppress errors related to async context management as they don't affect functionality + if "cancel scope" not in str(e).lower() and "asyncio" not in str(e).lower(): + print(f"Error during cleanup: {e}") + + try: + # Clear sessions + self.sessions.clear() + + # Clear other data structures + self.tool_to_session.clear() + self.available_prompts.clear() + self.static_resources.clear() + self.dynamic_resources.clear() + self.resource_to_session.clear() + self.prompt_to_session.clear() + self.available_tools.clear() + + # Recreate the exit stack for future use + self.exit_stack = AsyncExitStack() + except Exception as e: + print(f"Error during cleanup: {e}") diff --git a/src/client/mcp/frontend.py b/src/client/mcp/frontend.py new file mode 100644 index 00000000..383bb07f --- /dev/null +++ b/src/client/mcp/frontend.py @@ -0,0 +1,94 @@ +import streamlit as st +import os +import requests +import json + +def set_page(): + st.set_page_config( + page_title="MCP Universal Chatbot", + page_icon="🤖", + layout="wide" + ) + +def get_fastapi_base_url(): + return os.getenv("FASTAPI_BASE_URL", "http://127.0.0.1:8000") + +@st.cache_data(show_spinner="Connecting to MCP Backend...", ttl=60) +def get_server_capabilities(fastapi_base_url): + """Fetches the lists of tools and resources from the FastAPI backend.""" + try: + # Get API key from environment or generate one + api_key = os.getenv("API_SERVER_KEY") + headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} + + # First check if MCP is enabled and initialized + status_response = requests.get(f"{fastapi_base_url}/v1/mcp/status", headers=headers) + if status_response.status_code == 200: + status = status_response.json() + if not status.get("enabled", False): + st.warning("MCP is not enabled. Please enable it in the configuration.") + return {"error": "MCP not enabled"}, {"error": "MCP not enabled"}, {"error": "MCP not enabled"} + if not status.get("initialized", False): + st.info("MCP is enabled but not yet initialized. Please select a model first.") + return {"tools": []}, {"static": [], "dynamic": []}, {"prompts": []} + + tools_response = requests.get(f"{fastapi_base_url}/v1/mcp/tools", headers=headers) + tools_response.raise_for_status() + tools = tools_response.json() + + resources_response = requests.get(f"{fastapi_base_url}/v1/mcp/resources", headers=headers) + resources_response.raise_for_status() + resources = resources_response.json() + + prompts_response = requests.get(f"{fastapi_base_url}/v1/mcp/prompts", headers=headers) + prompts_response.raise_for_status() + prompts = prompts_response.json() + + return tools, resources, prompts + except requests.exceptions.RequestException as e: + st.error(f"Could not connect to the MCP backend at {fastapi_base_url}. Is it running? 
Error: {e}") + return {"tools": []}, {"static": [], "dynamic": []}, {"prompts": []} + +def get_server_files(): + files = ["server/mcp/server_config.json"] + try: + with open("server/mcp/server_config.json", "r") as f: config = json.load(f) + for server in config.get("mcpServers", {}).values(): + script_path = server.get("args", [None])[0] + if script_path and os.path.exists(script_path): files.append(script_path) + except FileNotFoundError: st.sidebar.error("server_config.json not found!") + return list(set(files)) + +def display_ide_tab(): + st.header("🔧 Integrated MCP Server IDE") + st.info("Edit your server configuration or scripts. Restart the launcher for changes to take effect.") + server_files = get_server_files() + selected_file = st.selectbox("Select a file to edit", options=server_files) + if selected_file: + with open(selected_file, "r") as f: file_content = f.read() + from streamlit_ace import st_ace + new_content = st_ace(value=file_content, language="python" if selected_file.endswith(".py") else "json", theme="monokai", keybinding="vscode", height=500, auto_update=True) + if st.button("Save Changes"): + with open(selected_file, "w") as f: f.write(new_content) + st.success(f"Successfully saved {selected_file}!") + +def display_commands_tab(tools, resources, prompts): + st.header("📖 Discovered MCP Commands") + st.info("These commands were discovered from the MCP backend.") + + if tools: + with st.expander("🛠️ Available Tools (Used automatically by the AI)", expanded=True): + # Extract just the tool names from the tools response + if "tools" in tools and isinstance(tools["tools"], list): + tool_names = [tool.get("name", tool) if isinstance(tool, dict) else tool for tool in tools["tools"]] + st.write(tool_names) + else: + st.json(tools) + + if resources: + with st.expander("📦 Available Resources (Use with `@` or just ``)"): + st.json(resources) + + if prompts: + with st.expander("📝 Available Prompts (Use with `/prompt ` or select in chat)"): + st.json(prompts) diff --git a/src/client/utils/st_common.py b/src/client/utils/st_common.py index b1ec6d98..b386215d 100644 --- a/src/client/utils/st_common.py +++ b/src/client/utils/st_common.py @@ -17,6 +17,12 @@ import common.logging_config as logging_config from common.schema import PromptPromptType, PromptNameType, SelectAISettings, ClientIdType +# Import the MCP initialization function +try: + from launch_server import initialize_mcp_engine_with_model +except ImportError: + initialize_mcp_engine_with_model = None + logger = logging_config.logging.getLogger("client.utils.st_common") @@ -161,6 +167,8 @@ def ll_sidebar() -> None: selected_model = state.client_settings["ll_model"]["model"] ll_idx = list(ll_models_enabled.keys()).index(selected_model) if not state.client_settings["selectai"]["enabled"]: + # Store the previous model to detect changes + previous_model = selected_model selected_model = st.sidebar.selectbox( "Chat model:", options=list(ll_models_enabled.keys()), @@ -169,6 +177,16 @@ def ll_sidebar() -> None: on_change=update_client_settings("ll_model"), disabled=state.client_settings["selectai"]["enabled"], ) + + # If the model has changed, reinitialize the MCP engine + if selected_model != previous_model and initialize_mcp_engine_with_model: + try: + # Instead of creating a new event loop, we'll set a flag to indicate + # that the MCP engine needs to be reinitialized + state.mcp_needs_reinit = selected_model + logger.info(f"MCP engine marked for reinitialization with model: {selected_model}") + except Exception as e: + 
logger.error(f"Failed to mark MCP engine for reinitialization with model {selected_model}: {e}") # Temperature temperature = ll_models_enabled[selected_model]["temperature"] diff --git a/src/common/schema.py b/src/common/schema.py index ecd2fb98..da2d5ca3 100644 --- a/src/common/schema.py +++ b/src/common/schema.py @@ -4,8 +4,10 @@ """ # spell-checker:ignore ollama, hnsw, mult, ocid, testset, selectai, explainsql, showsql, vector_search, aioptimizer +from __future__ import annotations + import time -from typing import Optional, Literal, Union, get_args, Any +from typing import Optional, Literal, Union, get_args, Any, Dict, List from pydantic import BaseModel, Field, PrivateAttr, model_validator from langchain_core.messages import ChatMessage @@ -301,6 +303,7 @@ class Configuration(BaseModel): model_configs: Optional[list[Model]] = None oci_configs: Optional[list[OracleCloudSettings]] = None prompt_configs: Optional[list[Prompt]] = None + mcp_configs: Optional[list[MCPModelConfig]] = Field(default=None, description="List of MCP configurations") def model_dump_public(self, incl_sensitive: bool = False, incl_readonly: bool = False) -> dict: """Remove marked fields for FastAPI Response""" @@ -452,6 +455,37 @@ class EvaluationReport(Evaluation): html_report: str = Field(description="HTML Report") +##################################################### +# MCP +##################################################### +class MCPModelConfig(BaseModel): + """MCP Model Configuration""" + model_id: str = Field(..., description="Model identifier") + service_type: Literal["ollama", "openai"] = Field(..., description="AI service type") + base_url: str = Field(default="http://localhost:11434", description="Base URL for API") + api_key: Optional[str] = Field(default=None, description="API key", json_schema_extra={"sensitive": True}) + enabled: bool = Field(default=True, description="Model availability status") + streaming: bool = Field(default=False, description="Enable streaming responses") + temperature: float = Field(default=1.0, description="Model temperature") + max_tokens: int = Field(default=2048, description="Maximum tokens per response") + + +class MCPToolConfig(BaseModel): + """MCP Tool Configuration""" + name: str = Field(..., description="Tool name") + description: str = Field(..., description="Tool description") + parameters: Dict[str, Any] = Field(..., description="Tool parameters") + enabled: bool = Field(default=True, description="Tool availability status") + + +class MCPSettings(BaseModel): + """MCP Global Settings""" + models: List[MCPModelConfig] = Field(default_factory=list, description="Available MCP models") + tools: List[MCPToolConfig] = Field(default_factory=list, description="Available MCP tools") + default_model: Optional[str] = Field(default=None, description="Default model identifier") + enabled: bool = Field(default=True, description="Enable or disable MCP functionality") + + ##################################################### # Types ##################################################### @@ -469,3 +503,6 @@ class EvaluationReport(Evaluation): TestSetsIdType = TestSets.__annotations__["tid"] TestSetsNameType = TestSets.__annotations__["name"] TestSetDateType = TestSets.__annotations__["created"] +MCPModelIdType = MCPModelConfig.__annotations__["model_id"] +MCPServiceType = MCPModelConfig.__annotations__["service_type"] +MCPToolNameType = MCPToolConfig.__annotations__["name"] diff --git a/src/launch_client.py b/src/launch_client.py index 4e5e4797..39330c3d 100644 --- 
a/src/launch_client.py +++ b/src/launch_client.py @@ -128,6 +128,7 @@ def main() -> None: state.disabled["model_cfg"] = os.environ.get("DISABLE_MODEL_CFG", "false").lower() == "true" state.disabled["oci_cfg"] = os.environ.get("DISABLE_OCI_CFG", "false").lower() == "true" state.disabled["settings"] = os.environ.get("DISABLE_SETTINGS", "false").lower() == "true" + state.disabled["mcp_cfg"] = os.environ.get("DISABLE_MCP_CFG", "false").lower() == "true" # Left Hand Side - Navigation chatbot = st.Page("client/content/chatbot.py", title="ChatBot", icon="💬", default=True) @@ -166,6 +167,9 @@ def main() -> None: # When we get here, if there's nothing in "Configuration" delete it if not navigation["Configuration"]: del navigation["Configuration"] + if not state.disabled["mcp_cfg"]: + mcp_config = st.Page("client/content/config/mcp_servers.py", title="MCP Servers", icon="💾") + navigation["Configuration"].append(mcp_config) pg = st.navigation(navigation, position="sidebar", expanded=False) pg.run() diff --git a/src/launch_server.py b/src/launch_server.py index b9c02194..fceea5c0 100644 --- a/src/launch_server.py +++ b/src/launch_server.py @@ -25,19 +25,25 @@ import server.patches.litellm_patch # pylint: disable=unused-import import argparse +import json import queue import secrets import socket import subprocess import threading -from typing import Annotated +from typing import Annotated, Any, Dict, Optional from pathlib import Path import uvicorn +from contextlib import asynccontextmanager import psutil -from fastapi import FastAPI, HTTPException, Depends, status, APIRouter +from client.mcp.client import MCPClient +from fastapi import APIRouter, Depends, FastAPI, HTTPException, status +from fastapi.openapi.utils import get_openapi +from fastapi.routing import APIRoute from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from pydantic import BaseModel # Logging import common.logging_config as logging_config @@ -45,9 +51,50 @@ # Configuration import server.bootstrap.configfile as configfile +from server.bootstrap import mcp as mcp_bootstrap logger = logging_config.logging.getLogger("launch_server") +mcp_engine: Optional[MCPClient] = None + +def get_mcp_engine() -> Optional[MCPClient]: + """Get the current MCP engine instance.""" + global mcp_engine + logger.debug(f"get_mcp_engine() called, returning: {mcp_engine}") + # Additional debugging to check if the variable exists + if 'mcp_engine' in globals(): + print(f"DEBUG: mcp_engine in globals: {globals().get('mcp_engine')}") + else: + print("DEBUG: mcp_engine not in globals") + # Print the module name to see which module this is + print(f"DEBUG: This is module: {__name__}") + return mcp_engine + +async def initialize_mcp_engine_with_model(model_name: str) -> Optional[MCPClient]: + """Initialize or reinitialize the MCP engine with a specific model.""" + global mcp_engine + + # Clean up existing engine if it exists + if mcp_engine: + try: + await mcp_engine.cleanup() + except Exception as e: + logger.error(f"Error cleaning up existing MCP engine: {e}") + + # Initialize new engine with the specified model + try: + mcp_engine = MCPClient(client_settings={'ll_model': {'model': model_name}}) + logger.info("MCP Client created with model %s, connecting to servers...", model_name) + await mcp_engine.connect_to_servers() + logger.info("MCP Engine initialized successfully with model %s", model_name) + return mcp_engine + except Exception as e: + logger.error(f"Failed to initialize MCP Engine with model {model_name}: {e}", exc_info=True) + 
mcp_engine = None + return None +class McpToolCallRequest(BaseModel): + tool_name: str + tool_args: Dict[str, Any] ########################################## # Process Control @@ -97,8 +144,7 @@ def start_subprocess(port: int, logfile: bool) -> subprocess.Popen: return process port = port or find_available_port() - existing_pid = get_pid_using_port(port) - if existing_pid: + if existing_pid := get_pid_using_port(port): logger.info("API server already running on port: %i (PID: %i)", port, existing_pid) return existing_pid @@ -118,11 +164,10 @@ def stop_server(pid: int) -> None: proc = psutil.Process(pid) proc.terminate() proc.wait() + logger.info("API server stopped.") except (psutil.NoSuchProcess, psutil.AccessDenied) as ex: logger.error("Failed to terminate process with PID: %i - %s", pid, ex) - logger.info("API server stopped.") - ########################################## # Server App and API Key @@ -170,12 +215,72 @@ def register_endpoints(noauth: APIRouter, auth: APIRouter): auth.include_router(api_v1.selectai.auth, prefix="/v1/selectai", tags=["SelectAI"]) auth.include_router(api_v1.settings.auth, prefix="/v1/settings", tags=["Tools - Settings"]) auth.include_router(api_v1.testbed.auth, prefix="/v1/testbed", tags=["Tools - Testbed"]) + auth.include_router(api_v1.mcp.auth, prefix="/v1/mcp", tags=["Config - MCP Servers"]) ############################################################################# # APP FACTORY ############################################################################# -def create_app(config: str = None) -> FastAPI: +@asynccontextmanager +async def lifespan(app: FastAPI): + """FastAPI startup/shutdown lifecycle for the MCP Engine.""" + logger.info("Starting API Server...") + global mcp_engine + + # Define a single, authoritative path for the configuration file. + config_path = Path("server/etc/mcp_config.json") + + # 1. Handle the missing configuration file as a critical error. + if not config_path.exists(): + logger.error( + f"CRITICAL: MCP configuration file not found at '{config_path}'. " + "MCP Engine cannot be initialized." + ) + # Yield control to allow the server to run, but without the MCP engine. + yield + return + + # 2. Load the configuration and initialize the engine. + try: + logger.info(f"Loading MCP configuration from '{config_path}'...") + with open(config_path, encoding='utf-8') as f: + mcp_config = json.load(f) + + mcp_bootstrap.load_mcp_settings(mcp_config) + + # 3. Check if MCP is enabled in the loaded configuration. + if mcp_bootstrap.MCP_SETTINGS and mcp_bootstrap.MCP_SETTINGS.enabled: + logger.info("MCP is enabled. Initializing MCP Engine...") + + # This structure assumes MCPClient can be initialized with just the default model. + client_init_settings = { + 'll_model': {'model': mcp_bootstrap.MCP_SETTINGS.default_model} + } + mcp_engine = MCPClient(client_settings=client_init_settings) + + await mcp_engine.connect_to_servers() + logger.info("MCP Engine initialized successfully.") + else: + logger.warning("MCP is disabled in the configuration file. Skipping initialization.") + + except Exception as e: + logger.error(f"Failed to initialize MCP Engine from configuration: {e}", exc_info=True) + # Ensure the engine is not set if initialization fails. + mcp_engine = None + + # Yield control to the running application. + yield + + # Shutdown the engine if it was successfully initialized. 
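+    # (Everything after the yield is FastAPI's shutdown phase; it runs once
+    # the application stops serving requests.)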
+ if mcp_engine: + logger.info("Shutting down MCP Engine...") + try: + await mcp_engine.cleanup() + logger.info("MCP Engine cleanup completed.") + except Exception as e: + logger.error(f"Error during MCP Engine cleanup: {e}") + +def create_app(config: str = "") -> FastAPI: """Create and configure the FastAPI app.""" if not config: config = configfile.config_file_path() @@ -187,6 +292,7 @@ def create_app(config: str = None) -> FastAPI: version=__version__, docs_url="/v1/docs", openapi_url="/v1/openapi.json", + lifespan=lifespan, license_info={ "name": "Universal Permissive License", "url": "http://oss.oracle.com/licenses/upl", diff --git a/src/server/api/core/bootstrap.py b/src/server/api/core/bootstrap.py index f4087446..db95de41 100644 --- a/src/server/api/core/bootstrap.py +++ b/src/server/api/core/bootstrap.py @@ -3,10 +3,11 @@ Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ -from server.bootstrap import databases, models, oci, prompts, settings +from server.bootstrap import databases, models, oci, prompts, settings, mcp DATABASE_OBJECTS = databases.main() MODEL_OBJECTS = models.main() OCI_OBJECTS = oci.main() PROMPT_OBJECTS = prompts.main() SETTINGS_OBJECTS = settings.main() +MCP_OBJECTS = mcp.main() diff --git a/src/server/api/core/mcp.py b/src/server/api/core/mcp.py new file mode 100644 index 00000000..751a8fc0 --- /dev/null +++ b/src/server/api/core/mcp.py @@ -0,0 +1,31 @@ +from typing import Optional, List, Dict, Any +from common.schema import MCPModelConfig, MCPToolConfig, MCPSettings +from server.bootstrap import mcp as mcp_bootstrap +import common.logging_config as logging_config + +logger = logging_config.logging.getLogger("api.core.mcp") + +def get_mcp_model(model_id: str) -> Optional[MCPModelConfig]: + """Get MCP model configuration by ID""" + for model in mcp_bootstrap.MCP_MODELS: + if model.model_id == model_id: + return model + return None + +def get_mcp_tool(tool_name: str) -> Optional[MCPToolConfig]: + """Get MCP tool configuration by name""" + for tool in mcp_bootstrap.MCP_TOOLS: + if tool.name == tool_name: + return tool + return None + +def update_mcp_settings(settings: Dict[str, Any]) -> MCPSettings: + """Update MCP settings""" + if not mcp_bootstrap.MCP_SETTINGS: + raise ValueError("MCP settings not initialized") + + for key, value in settings.items(): + if hasattr(mcp_bootstrap.MCP_SETTINGS, key): + setattr(mcp_bootstrap.MCP_SETTINGS, key, value) + + return mcp_bootstrap.MCP_SETTINGS \ No newline at end of file diff --git a/src/server/api/core/settings.py b/src/server/api/core/settings.py index 7013eb48..6121a52e 100644 --- a/src/server/api/core/settings.py +++ b/src/server/api/core/settings.py @@ -54,11 +54,16 @@ def get_server_config() -> schema.Configuration: prompt_objects = bootstrap.PROMPT_OBJECTS prompt_configs = [prompt for prompt in prompt_objects] + # Add MCP configs as a list (similar to other configs) + mcp_objects = bootstrap.mcp.MCP_MODELS # Get list of models from bootstrap + mcp_configs = [model for model in mcp_objects] # Convert to list like other configs + full_config = { "database_configs": database_configs, "model_configs": model_configs, "oci_configs": oci_configs, "prompt_configs": prompt_configs, + "mcp_configs": mcp_configs, # Now it's a list like other configs } return full_config @@ -91,6 +96,11 @@ def update_server_config(config_data: dict) -> None: if "prompt_configs" in config_data: bootstrap.PROMPT_OBJECTS = config.prompt_configs or [] + + # Add MCP config handling (similar to 
other configs) + if "mcp_configs" in config_data: + from server.bootstrap import mcp + mcp.MCP_MODELS = config.mcp_configs or [] # Store as list like other configs def load_config_from_json_data(config_data: dict, client: schema.ClientIdType = None) -> None: diff --git a/src/server/api/v1/__init__.py b/src/server/api/v1/__init__.py index fcd6743f..873ce855 100644 --- a/src/server/api/v1/__init__.py +++ b/src/server/api/v1/__init__.py @@ -3,4 +3,4 @@ Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ -from . import chat, databases, embed, models, oci, probes, prompts, testbed, settings, selectai +from . import chat, databases, embed, models, oci, probes, prompts, testbed, settings, selectai, mcp diff --git a/src/server/api/v1/mcp.py b/src/server/api/v1/mcp.py new file mode 100644 index 00000000..c1e008a4 --- /dev/null +++ b/src/server/api/v1/mcp.py @@ -0,0 +1,147 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. +This file is being used in APIs, and not the backend.py file. +""" + +from typing import Optional, Dict, Any +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel +from datetime import datetime + +import common.logging_config as logging_config + +logger = logging_config.logging.getLogger("endpoints.v1.mcp") + +auth = APIRouter() + +def mcp_engine_obj(): + """Check if the MCP engine is initialized.""" + try: + from launch_server import get_mcp_engine + mcp_engine = get_mcp_engine() + except ImportError: + return None + return mcp_engine + +class McpToolCallRequest(BaseModel): + tool_name: str + tool_args: Dict[str, Any] + +class ChatRequest(BaseModel): + query: str + prompt_name: Optional[str] = None + resource_uri: Optional[str] = None + message_history: Optional[list] = None + +@auth.get( + "/tools", + description="List available MCP tools", + response_model=dict +) +async def list_mcp_tools(): + # Import here to avoid circular imports + mcp_engine = mcp_engine_obj() + if not mcp_engine: + raise HTTPException(status_code=503, detail="MCP Engine not initialized.") + try: + await mcp_engine._rebuild_mcp_tool_schemas() + except Exception as e: + logger.error(f"Error rebuilding tool schemas: {e}") + + tools_info = [] + for tool_name, (session, tool_object) in mcp_engine.tool_to_session.items(): + tools_info.append({ + "name": tool_object.name, + "description": tool_object.description, + "input_schema": tool_object.inputSchema + }) + return {"tools": tools_info} + +@auth.post( + "/execute", + description="Execute an MCP tool", + response_model=dict +) +async def execute_mcp_tool(request: McpToolCallRequest): + # Import here to avoid circular imports + mcp_engine = mcp_engine_obj() + if not mcp_engine: + raise HTTPException(status_code=503, detail="MCP Engine not initialized.") + try: + result = await mcp_engine.execute_mcp_tool(request.tool_name, request.tool_args) + return {"result": result} + except Exception as e: + logger.error(f"Error executing MCP tool: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +@auth.post( + "/chat", + description="Chat with MCP engine", + response_model=dict +) +async def chat_endpoint(request: ChatRequest): + mcp_engine = mcp_engine_obj() + if not mcp_engine: + raise HTTPException(status_code=503, detail="MCP Engine not initialized.") + try: + message_history = request.message_history or [{"role": "user", "content": request.query}] + response_text, 
_, _ = await mcp_engine.invoke(
+            message_history=message_history
+        )
+        return {"response": response_text}
+    except Exception as e:
+        logger.error(f"Error in MCP chat: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@auth.get(
+    "/resources",
+    description="List MCP resources",
+    response_model=dict
+)
+async def list_resources():
+    mcp_engine = mcp_engine_obj()
+    if not mcp_engine:
+        raise HTTPException(status_code=503, detail="MCP Engine not initialized.")
+
+    try:
+        # Refresh tool schemas; resources themselves are loaded at connect time
+        await mcp_engine._rebuild_mcp_tool_schemas()
+    except Exception as e:
+        logger.error(f"Error loading resources: {e}")
+
+    return {
+        "static": list(getattr(mcp_engine, "static_resources", {}).keys()),
+        "dynamic": getattr(mcp_engine, "dynamic_resources", [])
+    }
+
+@auth.get(
+    "/prompts",
+    description="List MCP prompts",
+    response_model=dict
+)
+async def list_prompts():
+    mcp_engine = mcp_engine_obj()
+    if not mcp_engine:
+        raise HTTPException(status_code=503, detail="MCP Engine not initialized.")
+    try:
+        # Refresh tool schemas; prompts themselves are loaded at connect time
+        await mcp_engine._rebuild_mcp_tool_schemas()
+    except Exception as e:
+        logger.error(f"Error loading prompts: {e}")
+
+    return {
+        "prompts": list(getattr(mcp_engine, "available_prompts", {}).keys())
+    }
+
+@auth.get("/health", response_model=dict)
+async def health_check():
+    """Check MCP engine health status"""
+    actual_mcp_engine = mcp_engine_obj()
+    return {
+        "status": "initialized" if actual_mcp_engine else "not_initialized",
+        "engine_type": str(type(actual_mcp_engine)) if actual_mcp_engine else None,
+        "available_tools": len(getattr(actual_mcp_engine, "available_tools", [])) if actual_mcp_engine else 0,
+        "timestamp": datetime.now().isoformat()
+    }
diff --git a/src/server/api/v1/settings.py b/src/server/api/v1/settings.py
index 4bcd9817..528560d5 100644
--- a/src/server/api/v1/settings.py
+++ b/src/server/api/v1/settings.py
@@ -38,7 +38,7 @@ async def settings_get(
     full_config: bool = False,
     incl_sensitive: bool = Depends(_incl_sensitive_param),
     incl_readonly: bool = Depends(_incl_readonly_param),
-) -> Union[schema.Configuration, schema.Settings]:
+) -> Union[schema.Configuration, schema.Settings, JSONResponse]:
     """Get settings for a specific client by name"""
     try:
         client_settings = settings.get_client_settings(client)
@@ -55,8 +55,11 @@ async def settings_get(
             model_configs=config.get("model_configs"),
             oci_configs=config.get("oci_configs"),
             prompt_configs=config.get("prompt_configs"),
+            mcp_configs=config.get("mcp_configs"),
         )
-        return JSONResponse(content=response.model_dump_public(incl_sensitive=incl_sensitive, incl_readonly=incl_readonly))
+        if incl_sensitive or incl_readonly:
+            return JSONResponse(content=response.model_dump_public(incl_sensitive=incl_sensitive, incl_readonly=incl_readonly))
+        return response
 
 
 @auth.patch(
@@ -114,12 +117,12 @@ async def load_settings_from_file(
         pass
 
     try:
-        if not file.filename.endswith(".json"):
+        if not file.filename or not file.filename.endswith(".json"):
             raise HTTPException(status_code=400, detail="Settings: Only JSON files are supported.")
         contents = await file.read()
         config_data = json.loads(contents)
         settings.load_config_from_json_data(config_data, client)
-        return {"message": "Configuration loaded successfully."}
+        return JSONResponse(content={"message": "Configuration loaded successfully."})
     except json.JSONDecodeError as ex:
         raise HTTPException(status_code=400, detail="Settings: Invalid JSON 
diff --git a/src/server/bootstrap/mcp.py b/src/server/bootstrap/mcp.py
new file mode 100644
index 00000000..95e2e34a
--- /dev/null
+++ b/src/server/bootstrap/mcp.py
@@ -0,0 +1,89 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+from typing import List, Optional
+import os
+
+from server.bootstrap.configfile import ConfigStore
+from common.schema import MCPSettings, MCPModelConfig, MCPToolConfig
+import common.logging_config as logging_config
+
+logger = logging_config.logging.getLogger("bootstrap.mcp")
+
+# Global configuration holders
+MCP_SETTINGS: Optional[MCPSettings] = None
+MCP_MODELS: List[MCPModelConfig] = []
+MCP_TOOLS: List[MCPToolConfig] = []
+
+
+def load_mcp_settings(config: dict) -> None:
+    """Load MCP configuration from config file"""
+    global MCP_SETTINGS, MCP_MODELS, MCP_TOOLS
+
+    # Convert to settings object first
+    mcp_settings = MCPSettings(
+        models=[MCPModelConfig(**model) for model in config.get("models", [])],
+        tools=[MCPToolConfig(**tool) for tool in config.get("tools", [])],
+        default_model=config.get("default_model"),
+        enabled=config.get("enabled", True)
+    )
+
+    # Set globals
+    MCP_SETTINGS = mcp_settings
+    MCP_MODELS = mcp_settings.models
+    MCP_TOOLS = mcp_settings.tools
+
+    logger.info("Loaded %i MCP Models and %i Tools", len(MCP_MODELS), len(MCP_TOOLS))
+
+
+def main() -> MCPSettings:
+    """Bootstrap MCP Configuration"""
+    logger.debug("*** Bootstrapping MCP - Start")
+
+    # Load from ConfigStore if available
+    configuration = ConfigStore.get()
+    if configuration and configuration.mcp_configs:
+        logger.debug("Using MCP configs from ConfigStore")
+        # Convert list of MCPModelConfig objects to MCPSettings
+        mcp_settings = MCPSettings(
+            models=configuration.mcp_configs,
+            tools=[],  # No tools in the current schema
+            default_model=configuration.mcp_configs[0].model_id if configuration.mcp_configs else None,
+            enabled=True
+        )
+    else:
+        # Default MCP configuration
+        mcp_settings = MCPSettings(
+            models=[
+                MCPModelConfig(
+                    model_id="llama3.1",
+                    service_type="ollama",
+                    base_url=os.environ.get("ON_PREM_OLLAMA_URL", "http://localhost:11434"),
+                    enabled=True,
+                    streaming=False,
+                    temperature=1.0,
+                    max_tokens=2048
+                )
+            ],
+            tools=[
+                MCPToolConfig(
+                    name="file_reader",
+                    description="Read contents of files",
+                    parameters={
+                        "path": "string",
+                        "encoding": "string"
+                    },
+                    enabled=True
+                )
+            ],
+            default_model=None,
+            enabled=True
+        )
+
+    logger.info("Loaded %i MCP Models and %i Tools", len(mcp_settings.models), len(mcp_settings.tools))
+    logger.debug("*** Bootstrapping MCP - End")
+    logger.info("MCP Settings: %s", mcp_settings.model_dump_json())
+    return mcp_settings
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
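To make the expected shape of the incoming dict concrete, a small sketch of driving load_mcp_settings directly; the key names come from the code above, while the values are invented:

    from server.bootstrap import mcp

    mcp.load_mcp_settings({
        "models": [{
            "model_id": "llama3.1",
            "service_type": "ollama",
            "base_url": "http://localhost:11434",
        }],
        "tools": [{
            "name": "file_reader",
            "description": "Read contents of files",
            "parameters": {"path": "string", "encoding": "string"},
        }],
        "default_model": "llama3.1",
        "enabled": True,
    })
    assert mcp.MCP_SETTINGS is not None and len(mcp.MCP_MODELS) == 1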
diff --git a/src/server/mcp/server/archive_mcp.py b/src/server/mcp/server/archive_mcp.py
new file mode 100644
index 00000000..d38a091f
--- /dev/null
+++ b/src/server/mcp/server/archive_mcp.py
@@ -0,0 +1,182 @@
+import json
+import os
+from dotenv import load_dotenv
+import arxiv
+from typing import List
+from mcp.server.fastmcp import FastMCP
+import textwrap
+
+# --- Configuration and Setup ---
+load_dotenv()
+PAPER_DIR = "papers"
+# Initialize FastMCP server with a name
+mcp = FastMCP("research")
+_paper_cache = {}
+
+# --- Tool Definitions ---
+
+@mcp.tool()
+def search_papers(topic: str, max_results: int = 5) -> List[str]:
+    """
+    Searches for papers on arXiv based on a topic and saves their metadata.
+
+    Args:
+        topic (str): The topic to search for.
+        max_results (int): Maximum number of results to retrieve.
+
+    Returns:
+        List[str]: A list of the paper IDs found and saved.
+    """
+    client_arxiv = arxiv.Client()
+    search = arxiv.Search(
+        query=topic,
+        max_results=max_results,
+        sort_by=arxiv.SortCriterion.Relevance
+    )
+    papers = list(client_arxiv.results(search))
+
+    if not papers:
+        # Print feedback on the server side
+        print(f"Server: No papers found for topic '{topic}'")
+        return []
+
+    path = os.path.join(PAPER_DIR, topic.lower().replace(" ", "_"))
+    os.makedirs(path, exist_ok=True)
+    file_path = os.path.join(path, "papers_info.json")
+
+    try:
+        with open(file_path, "r") as json_file:
+            papers_info = json.load(json_file)
+    except (FileNotFoundError, json.JSONDecodeError):
+        papers_info = {}
+
+    paper_ids = []
+    for paper in papers:
+        paper_id = paper.get_short_id()
+        paper_ids.append(paper_id)
+        papers_info[paper_id] = {
+            'title': paper.title,
+            'authors': [author.name for author in paper.authors],
+            'summary': paper.summary,
+            'pdf_url': paper.pdf_url,
+            'published': str(paper.published.date())
+        }
+
+    with open(file_path, "w") as json_file:
+        json.dump(papers_info, json_file, indent=2)
+
+    print(f"Server: Saved {len(paper_ids)} papers to {file_path}")
+    return paper_ids
+
+
+@mcp.tool()
+def extract_info(paper_id: str) -> str:
+    """
+    Retrieves saved information for a specific paper ID from all topics.
+    Uses an in-memory cache for performance.
+
+    Args:
+        paper_id (str): The ID of the paper to look for.
+
+    Returns:
+        str: JSON string with paper information if found, else an error message.
+    """
+    # 1. First, check the cache for an exact match
+    if paper_id in _paper_cache:
+        return json.dumps(_paper_cache[paper_id], indent=2)
+
+    # 2. If not cached, fall back to scanning the saved topic folders on disk
+    #    (skip the scan entirely if no papers have been saved yet)
+    for item in (os.listdir(PAPER_DIR) if os.path.isdir(PAPER_DIR) else []):
+        item_path = os.path.join(PAPER_DIR, item)
+        if os.path.isdir(item_path):
+            file_path = os.path.join(item_path, "papers_info.json")
+            if os.path.isfile(file_path):
+                try:
+                    with open(file_path, "r") as json_file:
+                        papers_info = json.load(json_file)
+
+                    for key, value in papers_info.items():
+                        # Cache every paper from this file to avoid re-reading it
+                        if key not in _paper_cache:
+                            _paper_cache[key] = value
+
+                except (FileNotFoundError, json.JSONDecodeError):
+                    continue
+
+    # 3. Now that the cache is populated from relevant files, check again.
+    # This handles version differences as well.
+    if paper_id in _paper_cache:
+        return json.dumps(_paper_cache[paper_id], indent=2)
+
+    base_id = paper_id.split('v')[0]
+    for key, value in _paper_cache.items():
+        if key.startswith(base_id):
+            return json.dumps(value, indent=2)
+
+    return f"Error: No saved information found for paper ID {paper_id}."
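Since FastMCP's decorators leave the functions directly callable, the tools can be smoke-tested without an MCP client. A quick sketch, assuming network access to arXiv:

    ids = search_papers("retrieval augmented generation", max_results=2)
    if ids:
        print(extract_info(ids[0]))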
+
+# --- Resource Definitions ---
+
+@mcp.resource("papers://folders")
+def get_available_folders() -> str:
+    """Lists all available topic folders that contain saved paper information."""
+    print(f"Server: Listing available topic folders in {PAPER_DIR}")
+    folders = []
+    if os.path.exists(PAPER_DIR):
+        for topic_dir in os.listdir(PAPER_DIR):
+            if os.path.isdir(os.path.join(PAPER_DIR, topic_dir)):
+                folders.append(topic_dir)
+
+    content = "# Available Research Topics\n\n"
+    if folders:
+        content += "You can retrieve info for any of these topics using `@`.\n\n"
+        for folder in folders:
+            content += f"- `{folder}`\n"
+    else:
+        content += "No topic folders found. Use `search_papers` to create one."
+    print(f"Server: Found {len(folders)} topic folders.")
+    return content
+
+
+@mcp.resource("papers://{topic}")
+def get_topic_papers(topic: str) -> str:
+    """Gets detailed information about all saved papers for a specific topic."""
+    print(f"Server: Retrieving papers for topic '{topic}'")
+    topic_dir = topic.lower().replace(" ", "_")
+    papers_file = os.path.join(PAPER_DIR, topic_dir, "papers_info.json")
+
+    if not os.path.exists(papers_file):
+        return f"# No papers found for topic: {topic}"
+
+    with open(papers_file, 'r') as f:
+        papers_data = json.load(f)
+
+    content = f"# Papers on {topic.replace('_', ' ').title()}\n\n"
+    for paper_id, info in papers_data.items():
+        content += f"## {info['title']} (`{paper_id}`)\n"
+        content += f"- **Authors**: {', '.join(info['authors'])}\n"
+        content += f"- **Summary**: {info['summary'][:200]}...\n---\n"
+    print(f"Server: Found {len(papers_data)} papers for topic '{topic}'")
+    return content
+
+
+# --- Prompt Definition ---
+
+@mcp.prompt()
+def generate_search_prompt(topic: str) -> str:
+    """Generates a system prompt to guide an AI in researching a topic."""
+    return textwrap.dedent(f"""
+        You are a research assistant. Your goal is to provide a comprehensive overview of a topic.
+        When asked about '{topic}', follow these steps:
+        1. Use the `search_papers` tool to find relevant papers.
+        2. For each paper ID returned, use the `extract_info` tool to get its details.
+        3. Synthesize the information from all papers into a cohesive summary.
+        4. Present the key findings, common themes, and any differing conclusions.
+        Do not present the raw JSON. Format the final output for readability.
+    """)
+
+
+# --- Main Execution Block ---
+
+if __name__ == "__main__":
+    # Run the MCP server over stdio transport.
+    print("Research MCP Server running on stdio...")
+    mcp.run(transport='stdio')
diff --git a/src/server/mcp/server_config.json b/src/server/mcp/server_config.json
new file mode 100644
index 00000000..3d8b0321
--- /dev/null
+++ b/src/server/mcp/server_config.json
@@ -0,0 +1,20 @@
+{
+    "mcpServers": {
+        "filesystem": {
+            "command": "npx",
+            "args": [
+                "-y",
+                "@modelcontextprotocol/server-filesystem",
+                "."
+            ]
+        },
+        "research": {
+            "command": "python3",
+            "args": ["server/mcp/server/archive_mcp.py"]
+        },
+        "fetch": {
+            "command": "python3",
+            "args": ["-m", "mcp_server_fetch"]
+        }
+    }
+}
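To show how these server_config.json entries get consumed, a minimal sketch using the mcp package's stdio client. This is my own illustration of the mechanism; the project's actual consumer is src/client/mcp/client.py and may differ in detail:

    import asyncio
    import json

    from mcp import ClientSession, StdioServerParameters
    from mcp.client.stdio import stdio_client

    async def main():
        with open("src/server/mcp/server_config.json") as fh:
            entry = json.load(fh)["mcpServers"]["research"]
        params = StdioServerParameters(command=entry["command"], args=entry["args"])
        # Spawn the configured server as a subprocess and talk to it over stdio
        async with stdio_client(params) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()
                tools = await session.list_tools()
                print([tool.name for tool in tools.tools])

    asyncio.run(main())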

From 16ef70aba387dcb4a2bcd045ce9cc97e20d09bbd Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Mon, 11 Aug 2025 14:46:25 +0100
Subject: [PATCH 2/4] Include additional langchain and MCP modules

---
 src/pyproject.toml | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/pyproject.toml b/src/pyproject.toml
index 7650f9da..b31fe996 100644
--- a/src/pyproject.toml
+++ b/src/pyproject.toml
@@ -28,9 +28,13 @@ server = [
     "fastapi==0.116.1",
     "faiss-cpu==1.11.0.post1",
     "giskard==2.17.0",
+    "langchain-anthropic==0.3.18",
     "langchain-cohere==0.4.5",
     "langchain-community==0.3.27",
+    "langchain-google-genai==2.1.9",
+    "langchain-groq==0.3.7",
     "langchain-huggingface==0.3.1",
+    "langchain-mistralai==0.2.11",
     "langchain-ollama==0.3.6",
     "langchain-openai==0.3.29",
     "langgraph==0.6.4",
@@ -38,6 +42,7 @@ server = [
     "llama-index==0.13.1",
     "lxml==6.0.0",
     "matplotlib==3.10.5",
+    "mcp==1.12.4",
     "oci~=2.0",
     "psutil==7.0.0",
     "python-multipart==0.0.20",

From 0c3dfd0a4c57dd37848f93994cc21f18f31085eb Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Mon, 11 Aug 2025 15:09:07 +0100
Subject: [PATCH 3/4] Re-Org Schema

---
 src/common/schema.py | 69 ++++++++++++++++++++++----------------------
 1 file changed, 35 insertions(+), 34 deletions(-)

diff --git a/src/common/schema.py b/src/common/schema.py
index ea7bbdcf..b5ffebd6 100644
--- a/src/common/schema.py
+++ b/src/common/schema.py
@@ -4,10 +4,8 @@
 """
 # spell-checker:ignore ollama hnsw mult ocid testset selectai explainsql showsql vector_search aioptimizer genai
-from __future__ import annotations
-
 import time
-from typing import Optional, Literal, Union, get_args, Any, Dict, List
+from typing import Optional, Literal, Union, get_args, Any
 from pydantic import BaseModel, Field, PrivateAttr, model_validator
 from langchain_core.messages import ChatMessage
@@ -101,6 +99,40 @@ def set_connection(self, connection: oracledb.Connection) -> None:
         self._connection = connection
 
 
+#####################################################
+# MCP
+#####################################################
+class MCPModelConfig(BaseModel):
+    """MCP Model Configuration"""
+
+    model_id: str = Field(..., description="Model identifier")
+    service_type: Literal["ollama", "openai"] = Field(..., description="AI service type")
+    base_url: str = Field(default="http://localhost:11434", description="Base URL for API")
+    api_key: Optional[str] = Field(default=None, description="API key", json_schema_extra={"sensitive": True})
+    enabled: bool = Field(default=True, description="Model availability status")
+    streaming: bool = Field(default=False, description="Enable streaming responses")
+    temperature: float = Field(default=1.0, description="Model temperature")
+    max_tokens: int = Field(default=2048, description="Maximum tokens per response")
+
+
+class MCPToolConfig(BaseModel):
+    """MCP Tool Configuration"""
+
+    name: str = Field(..., description="Tool name")
+    description: str = Field(..., description="Tool description")
+    parameters: dict[str, Any] = Field(..., description="Tool parameters")
+    enabled: bool = Field(default=True, description="Tool availability status")
+
+
+class MCPSettings(BaseModel):
+    """MCP Global Settings"""
+
+    models: list[MCPModelConfig] = Field(default_factory=list, description="Available MCP models")
+    tools: list[MCPToolConfig] = Field(default_factory=list, description="Available MCP tools")
+    default_model: Optional[str] = Field(default=None, description="Default model identifier")
+    enabled: bool = Field(default=True, description="Enable or disable MCP functionality")
+
+
 #####################################################
 # Models
 #####################################################
@@ -474,37 +506,6 @@ class EvaluationReport(Evaluation):
     html_report: str = Field(description="HTML Report")
 
 
-#####################################################
-# MCP
-#####################################################
-class MCPModelConfig(BaseModel):
-    """MCP Model Configuration"""
-    model_id: str = Field(..., description="Model identifier")
-    service_type: Literal["ollama", "openai"] = Field(..., description="AI service type")
-    base_url: str = Field(default="http://localhost:11434", description="Base URL for API")
-    api_key: Optional[str] = Field(default=None, description="API key", json_schema_extra={"sensitive": True})
-    enabled: bool = Field(default=True, description="Model availability status")
-    streaming: bool = Field(default=False, description="Enable streaming responses")
-    temperature: float = Field(default=1.0, description="Model temperature")
-    max_tokens: int = Field(default=2048, description="Maximum tokens per response")
-
-
-class MCPToolConfig(BaseModel):
-    """MCP Tool Configuration"""
-    name: str = Field(..., description="Tool name")
-    description: str = Field(..., description="Tool description")
-    parameters: Dict[str, Any] = Field(..., description="Tool parameters")
-    enabled: bool = Field(default=True, description="Tool availability status")
-
-
-class MCPSettings(BaseModel):
-    """MCP Global Settings"""
-    models: List[MCPModelConfig] = Field(default_factory=list, description="Available MCP models")
-    tools: List[MCPToolConfig] = Field(default_factory=list, description="Available MCP tools")
-    default_model: Optional[str] = Field(default=None, description="Default model identifier")
-    enabled: bool = Field(default=True, description="Enable or disable MCP functionality")
-
-
 #####################################################
 # Types
 #####################################################
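A short sketch of how the relocated models compose; the field names come from the diff above, while the values are invented:

    from common.schema import MCPModelConfig, MCPSettings

    settings = MCPSettings(
        models=[MCPModelConfig(model_id="llama3.1", service_type="ollama")],
        default_model="llama3.1",
    )
    # api_key defaults to None and is flagged sensitive for public dumps
    print(settings.model_dump_json())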

From 1c4e662553cc287df102f020fcb9d936afe4bcd1 Mon Sep 17 00:00:00 2001
From: Ishu Kumar
Date: Wed, 13 Aug 2025 10:07:58 +0530
Subject: [PATCH 4/4] Preserve tool call history and refactor the chatbot main
 method

---
 src/client/content/chatbot.py | 353 ++++++++++++++++++++--------------
 1 file changed, 208 insertions(+), 145 deletions(-)

diff --git a/src/client/content/chatbot.py b/src/client/content/chatbot.py
index 7ba3fe05..0db4e016 100644
--- a/src/client/content/chatbot.py
+++ b/src/client/content/chatbot.py
@@ -57,17 +57,10 @@ def show_vector_search_refs(context):
 
 
 #############################################################################
-# MAIN
+# Helper Functions
 #############################################################################
-async def main() -> None:
-    """Streamlit GUI"""
-    try:
-        get_models()
-    except api_call.ApiError:
-        st.stop()
-    #########################################################################
-    # Sidebar Settings
-    #########################################################################
+def setup_sidebar():
+    """Initialize and validate sidebar components"""
     ll_models_enabled = st_common.enabled_models_lookup("ll")
     if not ll_models_enabled:
         st.error("No language models are configured and/or enabled. Disabling Client.", icon="🛑")
@@ -81,15 +74,11 @@ async def main() -> None:
     if not state.enable_client:
         st.stop()
 
-    #########################################################################
-    # Chatty-Bot Centre
-    #########################################################################
+def display_messages():
+    """Render chat message history"""
+    if not state.messages:
+        st.chat_message("ai").write("Hello, how can I help you?")
 
-    if "messages" not in state:
-        state.messages = []
-
-    st.chat_message("ai").write("Hello, how can I help you?")
-
     for message in state.messages:
         role = message.get("role")
         display_role = ""
@@ -120,8 +109,9 @@ async def main() -> None:
                 for file in message["attachments"]:
                     # Show appropriate icon based on file type
                     if file["type"].startswith("image/"):
-                        st.image(file["preview"], use_container_width=True)
-                        st.markdown(f"🖼️ **{file['name']}** ({file['size']//1024} KB)")
+                        cols = st.columns([1, 3])
+                        with cols[0]:
+                            st.image(file["preview"], use_container_width=True)
                     elif file["type"] == "application/pdf":
                         st.markdown(f"📄 **{file['name']}** ({file['size']//1024} KB)")
@@ -132,12 +122,182 @@ async def main() -> None:
         # Display message content - handle both string and list formats
         content = message.get("content")
         if isinstance(content, list):
-            # Extract and display only text parts
-            text_parts = [part["text"] for part in content if part["type"] == "text"]
+            text_parts = []
+            for part in content:
+                if isinstance(part, dict) and part.get("type") == "text":
+                    if "text" in part:
+                        text_parts.append(part["text"])
+                elif isinstance(part, str):
+                    text_parts.append(part)
             st.markdown("\n".join(text_parts))
         else:
             st.markdown(content)
 
+
+def process_user_input(human_request):
+    """Process user input including file attachments"""
+    message = {"role": "user", "content": human_request.text}
+
+    # Handle file attachments with base64 for display
+    if hasattr(human_request, "files") and human_request.files:
+        message["attachments"] = []
+        for file in human_request.files:
+            file_bytes = file.read()
+            file_b64 = base64.b64encode(file_bytes).decode("utf-8")
+            message["attachments"].append({
+                "name": file.name,
+                "type": file.type,
+                "size": len(file_bytes),
+                "data": file_b64,
+                "preview": f"data:{file.type};base64,{file_b64}" if file.type.startswith("image/") else None
+            })
+
+    state.messages.append(message)
+
+
+def prepare_client_settings():
+    """Prepare client settings for MCPClient invocation"""
+    client_settings_for_request = state.client_settings.copy()
+    model_id = client_settings_for_request.get('ll_model', {}).get('model')
+    if model_id:
+        all_model_configs = st_common.enabled_models_lookup("ll")
+        model_config = all_model_configs.get(model_id, {})
+        if 'api_key' in model_config:
+            if 'll_model' not in client_settings_for_request:
+                client_settings_for_request['ll_model'] = {}
+            client_settings_for_request['ll_model']['api_key'] = model_config['api_key']
+    return client_settings_for_request
+
+
+def prepare_message_history():
+    """Process message history for backend"""
+    message_history = []
+    for msg in state.messages:
+        processed_msg = msg.copy()
+
+        if "attachments" in msg and msg["attachments"]:
+            text_content = msg["content"]
+            if isinstance(text_content, list):
+                text_parts = []
+                for part in text_content:
+                    if isinstance(part, dict) and part.get("type") == "text":
+                        if "text" in part:
+                            text_parts.append(part["text"])
+                    elif isinstance(part, str):
+                        text_parts.append(part)
+                text_content = "\n".join(text_parts)
+
+            content_list = [{"type": "text", "text": text_content}]
+            for attachment in msg["attachments"]:
+                if attachment["type"].startswith("image/"):
+                    mime_type = attachment["type"]
+                    if mime_type == "image/jpg":
+                        mime_type = "image/jpeg"
+                    content_list.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:{mime_type};base64,{attachment['data']}",
+                            "detail": "low"
+                        }
+                    })
+                else:
+                    # Handle non-image files as text references
+                    content_list.append({
+                        "type": "text",
+                        "text": f"\n[File: {attachment['name']} ({attachment['size']//1024} KB)]"
+                    })
+
+            processed_msg["content"] = content_list
+        # Convert list content to string format
+        elif isinstance(msg.get("content"), list):
+            text_parts = []
+            for part in msg["content"]:
+                if isinstance(part, dict) and part.get("type") == "text":
+                    if "text" in part:
+                        text_parts.append(part["text"])
+                elif isinstance(part, str):
+                    text_parts.append(part)
+            processed_msg["content"] = "\n".join(text_parts)
+        # Otherwise, ensure content is a string
+        else:
+            processed_msg["content"] = str(msg.get("content", ""))
+
+        message_history.append(processed_msg)
+    return message_history
+
+
+def process_final_text(final_text):
+    """Convert final response text to string format"""
+    if isinstance(final_text, list):
+        text_parts = []
+        for part in final_text:
+            if isinstance(part, dict):
+                part_type = part.get("type")
+                part_text = part.get("text")
+                if part_type == "text" and isinstance(part_text, str):
+                    text_parts.append(part_text)
+            elif isinstance(part, str):
+                text_parts.append(part)
+        return "\n".join(text_parts)
+    return final_text
+
+
+def find_last_user_message_index():
+    """Find index of last user message in history"""
+    last_user_idx = -1
+    for i, msg in enumerate(state.messages):
+        if msg.get("role") in ("human", "user"):
+            last_user_idx = i
+    return last_user_idx
+
+
+def handle_invoke_error(e):
+    """Handle exceptions during MCPClient invocation"""
+    logger.error("Exception during invoke call:", exc_info=True)
+    error_msg = str(e)
+
+    if "file" in error_msg.lower() or "image" in error_msg.lower() or "content" in error_msg.lower():
+        st.error(f"Error: {error_msg}")
+        if st.button("Remove files and retry", key="remove_files_retry"):
+            # Remove attachments from the latest message
+            if state.messages and "attachments" in state.messages[-1]:
+                del state.messages[-1]["attachments"]
+            st.rerun()
+    else:
+        st.error(f"Error: {error_msg}")
+
+    if st.button("Retry", key="reload_chatbot_error"):
+        if state.messages and state.messages[-1]["role"] == "user":
+            state.messages.pop()
+        st.rerun()
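For orientation, the dicts these helpers exchange with MCPClient.invoke look roughly like the sketch below. The shapes are inferred from prepare_message_history() and the invoke handling in this patch, not from a schema the project defines:

    history = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Summarize the attached paper"},
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,...", "detail": "low"}},
            ],
            "attachments": [
                {"name": "figure.png", "type": "image/png", "size": 12345,
                 "data": "<base64>", "preview": "data:image/png;base64,..."},
            ],
        },
        {
            "role": "assistant",
            "content": "The paper argues ...",
            "tool_trace": [
                {"name": "search_papers", "args": {"topic": "mcp"}, "result": "[...]"},
            ],
        },
    ]
    # final_text, tool_trace, new_history = await mcp_client.invoke(message_history=history)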
+
+
+#############################################################################
+# MAIN
+#############################################################################
+async def main() -> None:
+    """Streamlit GUI"""
+    # Initialize critical session state variables
+    if 'enable_client' not in state:
+        state.enable_client = True
+    if 'messages' not in state:
+        state.messages = []
+        # Add initial greeting message
+        state.messages.append({
+            "role": "assistant",
+            "content": "Hello, how can I help you?"
+        })
+
+    try:
+        get_models()
+    except api_call.ApiError:
+        st.stop()
+
+    setup_sidebar()
+
+    # Final safety check
+    if not state.enable_client:
+        st.stop()
+
+    #########################################################################
+    # Chatty-Bot Centre
+    #########################################################################
+
+    display_messages()
 
     sys_prompt = state.client_settings["prompts"]["sys"]
     render_chat_footer()
@@ -147,138 +307,41 @@ async def main() -> None:
         file_type=["jpg", "jpeg", "png", "pdf", "txt", "docx"],
         key=f"chat_input_{len(state.messages)}",
     ):
-        # Process message with potential file attachments
-        message = {"role": "user", "content": human_request.text}
-
-        # Handle file attachments
-        if hasattr(human_request, "files") and human_request.files:
-            # Store file information separately from content
-            message["attachments"] = []
-            for file in human_request.files:
-                file_bytes = file.read()
-                file_b64 = base64.b64encode(file_bytes).decode("utf-8")
-                message["attachments"].append({
-                    "name": file.name,
-                    "type": file.type,
-                    "size": len(file_bytes),
-                    "data": file_b64,
-                    "preview": f"data:{file.type};base64,{file_b64}" if file.type.startswith("image/") else None
-                })
-
-        state.messages.append(message)
+        process_user_input(human_request)
         st.rerun()
 
+
     if state.messages and state.messages[-1]["role"] == "user":
         try:
             with st.chat_message("ai"):
                 with st.spinner("Thinking..."):
-                    client_settings_for_request = state.client_settings.copy()
-                    model_id = client_settings_for_request.get('ll_model', {}).get('model')
-                    if model_id:
-                        all_model_configs = st_common.enabled_models_lookup("ll")
-                        model_config = all_model_configs.get(model_id, {})
-                        if 'api_key' in model_config:
-                            if 'll_model' not in client_settings_for_request:
-                                client_settings_for_request['ll_model'] = {}
-                            client_settings_for_request['ll_model']['api_key'] = model_config['api_key']
-
-                    # Prepare message history for backend
-                    message_history = []
-                    for msg in state.messages:
-                        # Create a copy of the message
-                        processed_msg = msg.copy()
-
-                        # If there are attachments, include them in the content
-                        if "attachments" in msg and msg["attachments"]:
-                            # Start with the text content
-                            text_content = msg["content"]
-
-                            # Handle list content format (from OpenAI API)
-                            if isinstance(text_content, list):
-                                text_parts = [part["text"] for part in text_content if part["type"] == "text"]
-                                text_content = "\n".join(text_parts)
-
-                            # Create a list to hold structured content parts
-                            content_list = [{"type": "text", "text": text_content}]
-
-                            non_image_references = []
-                            for attachment in msg["attachments"]:
-                                if attachment["type"].startswith("image/"):
-                                    # Only add image URLs for user messages
-                                    if msg["role"] in ("human", "user"):
-                                        # Normalize image MIME types for compatibility
-                                        mime_type = attachment["type"]
-                                        if mime_type == "image/jpg":
-                                            mime_type = "image/jpeg"
-
-                                        content_list.append({
-                                            "type": "image_url",
-                                            "image_url": {
-                                                "url": f"data:{mime_type};base64,{attachment['data']}",
-                                                "detail": "low"
-                                            }
-                                        })
-                                else:
-                                    # Handle non-image files as text references
-                                    non_image_references.append(f"\n[File: {attachment['name']} ({attachment['size']//1024} KB)]")
-
-                            # If there were non-image files, append their references to the main text part
-                            if non_image_references:
-                                content_list[0]['text'] += "".join(non_image_references)
-
-                            processed_msg["content"] = content_list
-                        # Convert list content to string format
-                        elif isinstance(msg.get("content"), list):
-                            text_parts = [part["text"] for part in msg["content"] if part["type"] == "text"]
-                            processed_msg["content"] = str("\n".join(text_parts))
-                        # Otherwise, ensure content is a string
-                        else:
-                            processed_msg["content"] = str(msg.get("content", ""))
-
-                        message_history.append(processed_msg)
-
-                    async with MCPClient(client_settings=client_settings_for_request) as mcp_client:
+                    client_settings = prepare_client_settings()
+                    message_history = prepare_message_history()
+                    async with MCPClient(client_settings=client_settings) as mcp_client:
                         final_text, tool_trace, new_history = await mcp_client.invoke(
                             message_history=message_history
                         )
-
-                        # Update the history for display.
-                        # Keep the original message structure with attachments
-                        for i in range(len(new_history) - 1, -1, -1):
-                            if new_history[i].get("role") == "assistant":
-                                # Preserve any attachments from the user message
-                                user_message = state.messages[-1]
-                                if "attachments" in user_message:
-                                    new_history[-1]["attachments"] = user_message["attachments"]
-
-                                new_history[i]["content"] = final_text
-                                new_history[i]["tool_trace"] = tool_trace
-                                break
-
-                        state.messages = new_history
-                        st.rerun()
-
-        except Exception as e:
-            logger.error("Exception during invoke call:", exc_info=True)
-            # Extract just the error message
-            error_msg = str(e)
-
-            # Check if it's a file-related error
-            if "file" in error_msg.lower() or "image" in error_msg.lower() or "content" in error_msg.lower():
-                st.error(f"Error: {error_msg}")
-
-                # Add a button to remove files and retry
-                if st.button("Remove files and retry", key="remove_files_retry"):
-                    # Remove attachments from the latest message
-                    if state.messages and "attachments" in state.messages[-1]:
+
+                    final_text_str = process_final_text(final_text)
+                    assistant_msg = {
+                        "role": "assistant",
+                        "content": final_text_str,
+                        "tool_trace": tool_trace
+                    }
+
+                    # Preserve attachments from user message
+                    if "attachments" in state.messages[-1]:
+                        assistant_msg["attachments"] = state.messages[-1]["attachments"]
+
+                    # Update or add assistant message
+                    last_user_idx = find_last_user_message_index()
+                    if last_user_idx + 1 < len(state.messages) and state.messages[last_user_idx + 1].get("role") == "assistant":
+                        state.messages[last_user_idx + 1] = assistant_msg
+                    else:
+                        state.messages.append(assistant_msg)
+
+                    st.rerun()
-                        del state.messages[-1]["attachments"]
-                        st.rerun()
-            else:
-                st.error(f"Error: {error_msg}")
-
-                if st.button("Retry", key="reload_chatbot_error"):
-                    if state.messages and state.messages[-1]["role"] == "user":
-                        state.messages.pop()
-                        st.rerun()
+        except Exception as e:
+            handle_invoke_error(e)
 
 
 if __name__ == "__main__" or ("page" in inspect.stack()[1].filename if inspect.stack() else False):