|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import boto3 |
| 16 | +import logging |
| 17 | + |
| 18 | +from collections.abc import AsyncGenerator |
| 19 | +from contextlib import asynccontextmanager |
| 20 | +from datetime import timedelta |
| 21 | +from typing import Optional |
| 22 | + |
| 23 | +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
| 24 | + |
| 25 | +from mcp.client.streamable_http import ( |
| 26 | + GetSessionIdCallback, |
| 27 | + create_mcp_http_client, |
| 28 | + streamablehttp_client |
| 29 | +) |
| 30 | +from mcp.shared._httpx_utils import McpHttpClientFactory |
| 31 | +from mcp.shared.message import SessionMessage |
| 32 | + |
| 33 | +from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth |
| 34 | + |
| 35 | + |
| 36 | +logger = logging.getLogger(__name__) |
| 37 | + |
| 38 | +@asynccontextmanager |
| 39 | +async def aws_iam_mcp_client( |
| 40 | + endpoint: str, |
| 41 | + aws_service: str, |
| 42 | + aws_region: Optional[str] = None, |
| 43 | + aws_profile: Optional[str] = None, |
| 44 | + headers: Optional[dict[str, str]] = None, |
| 45 | + timeout: Optional[float | timedelta] = 30, |
| 46 | + sse_read_timeout: Optional[float | timedelta] = 300, |
| 47 | + terminate_on_close: Optional[bool] = True, |
| 48 | + httpx_client_factory: Optional[McpHttpClientFactory] = create_mcp_http_client, |
| 49 | +) -> AsyncGenerator[ |
| 50 | + tuple[ |
| 51 | + MemoryObjectReceiveStream[SessionMessage | Exception], |
| 52 | + MemoryObjectSendStream[SessionMessage], |
| 53 | + GetSessionIdCallback, |
| 54 | + ], |
| 55 | + None, |
| 56 | +]: |
| 57 | + """ |
| 58 | + Create an AWS IAM-authenticated MCP streamable HTTP client. |
| 59 | +
|
| 60 | + This function establishes a connection to an MCP server using AWS IAM authentication |
| 61 | + via SigV4 signing. It returns the raw transport components for use with MCP client |
| 62 | + sessions or framework integrations. |
| 63 | +
|
| 64 | + Args: |
| 65 | + endpoint: The URL of the MCP server to connect to. Must be a valid HTTP/HTTPS URL. |
| 66 | + aws_service: The name of the AWS service the MCP server is hosted on, e.g. "bedrock-agentcore". |
| 67 | + aws_region: The AWS region name of the MCP server, e.g. "us-west-2". |
| 68 | + aws_profile: The AWS profile to use for authentication. |
| 69 | + headers: Optional additional HTTP headers to include in requests. |
| 70 | + timeout: Request timeout in seconds or timedelta object. Defaults to 30 seconds. |
| 71 | + sse_read_timeout: Server-sent events read timeout in seconds or timedelta object. |
| 72 | + terminate_on_close: Whether to terminate the connection on close. |
| 73 | + httpx_client_factory: Factory function for creating HTTPX clients. |
| 74 | +
|
| 75 | + Yields: |
| 76 | + tuple: Transport components for MCP communication: |
| 77 | + - read_stream: Async generator for reading server responses |
| 78 | + - write_stream: Async generator for sending requests to server |
| 79 | + - get_session_id: Function to retrieve the current session ID |
| 80 | + """ |
| 81 | + # Create a SigV4 authentication handler with AWS credentials |
| 82 | + logger.info("Preparing AWS IAM MCP client for endpoint: %s", endpoint) |
| 83 | + |
| 84 | + kwargs = {} |
| 85 | + if aws_region is not None: |
| 86 | + kwargs['region_name'] = aws_region |
| 87 | + if aws_profile is not None: |
| 88 | + kwargs['profile_name'] = aws_profile |
| 89 | + |
| 90 | + # Create a boto3 session with the provided arguments |
| 91 | + session = boto3.Session(**kwargs) |
| 92 | + |
| 93 | + profile = session.profile_name |
| 94 | + region = session.region_name |
| 95 | + |
| 96 | + logger.debug("AWS profile: %s", profile) |
| 97 | + logger.debug("AWS region: %s", region) |
| 98 | + logger.debug("AWS service: %s", aws_service) |
| 99 | + |
| 100 | + # Create a SigV4 authentication handler with AWS credentials |
| 101 | + auth = SigV4HTTPXAuth(session.get_credentials(), aws_service, region) |
| 102 | + |
| 103 | + # Establish connection using MCP SDK's streamable HTTP client |
| 104 | + async with streamablehttp_client( |
| 105 | + url=endpoint, |
| 106 | + auth=auth, |
| 107 | + headers=headers, |
| 108 | + timeout=timeout, |
| 109 | + sse_read_timeout=sse_read_timeout, |
| 110 | + terminate_on_close=terminate_on_close, |
| 111 | + httpx_client_factory=httpx_client_factory, |
| 112 | + ) as (read_stream, write_stream, get_session_id): |
| 113 | + # Return transport components for external session management |
| 114 | + logger.info("Successfully prepared AWS IAM MCP client for endpoint: %s", endpoint) |
| 115 | + yield (read_stream, write_stream, get_session_id) |
0 commit comments