Skip to content

Commit 9b3c93b

Browse files
committed
Add client
1 parent 692a291 commit 9b3c93b

File tree

3 files changed

+870
-1
lines changed

3 files changed

+870
-1
lines changed

mcp_proxy_for_aws/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
"""mcp-proxy-for-aws"""
1616

1717
from importlib.metadata import version as _metadata_version
18+
from .client import aws_iam_mcp_client
1819

1920

20-
__all__ = ['__version__']
2121
__version__ = _metadata_version('mcp-proxy-for-aws')
22+
23+
__all__ = [
24+
'__version__',
25+
'aws_iam_mcp_client',
26+
]

mcp_proxy_for_aws/client.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import boto3
16+
import logging
17+
18+
from collections.abc import AsyncGenerator
19+
from contextlib import asynccontextmanager
20+
from datetime import timedelta
21+
from typing import Optional
22+
23+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
24+
25+
from mcp.client.streamable_http import (
26+
GetSessionIdCallback,
27+
create_mcp_http_client,
28+
streamablehttp_client
29+
)
30+
from mcp.shared._httpx_utils import McpHttpClientFactory
31+
from mcp.shared.message import SessionMessage
32+
33+
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth
34+
35+
36+
logger = logging.getLogger(__name__)
37+
38+
@asynccontextmanager
39+
async def aws_iam_mcp_client(
40+
endpoint: str,
41+
aws_service: str,
42+
aws_region: Optional[str] = None,
43+
aws_profile: Optional[str] = None,
44+
headers: Optional[dict[str, str]] = None,
45+
timeout: Optional[float | timedelta] = 30,
46+
sse_read_timeout: Optional[float | timedelta] = 300,
47+
terminate_on_close: Optional[bool] = True,
48+
httpx_client_factory: Optional[McpHttpClientFactory] = create_mcp_http_client,
49+
) -> AsyncGenerator[
50+
tuple[
51+
MemoryObjectReceiveStream[SessionMessage | Exception],
52+
MemoryObjectSendStream[SessionMessage],
53+
GetSessionIdCallback,
54+
],
55+
None,
56+
]:
57+
"""
58+
Create an AWS IAM-authenticated MCP streamable HTTP client.
59+
60+
This function establishes a connection to an MCP server using AWS IAM authentication
61+
via SigV4 signing. It returns the raw transport components for use with MCP client
62+
sessions or framework integrations.
63+
64+
Args:
65+
endpoint: The URL of the MCP server to connect to. Must be a valid HTTP/HTTPS URL.
66+
aws_service: The name of the AWS service the MCP server is hosted on, e.g. "bedrock-agentcore".
67+
aws_region: The AWS region name of the MCP server, e.g. "us-west-2".
68+
aws_profile: The AWS profile to use for authentication.
69+
headers: Optional additional HTTP headers to include in requests.
70+
timeout: Request timeout in seconds or timedelta object. Defaults to 30 seconds.
71+
sse_read_timeout: Server-sent events read timeout in seconds or timedelta object.
72+
terminate_on_close: Whether to terminate the connection on close.
73+
httpx_client_factory: Factory function for creating HTTPX clients.
74+
75+
Yields:
76+
tuple: Transport components for MCP communication:
77+
- read_stream: Async generator for reading server responses
78+
- write_stream: Async generator for sending requests to server
79+
- get_session_id: Function to retrieve the current session ID
80+
"""
81+
# Create a SigV4 authentication handler with AWS credentials
82+
logger.info("Preparing AWS IAM MCP client for endpoint: %s", endpoint)
83+
84+
kwargs = {}
85+
if aws_region is not None:
86+
kwargs['region_name'] = aws_region
87+
if aws_profile is not None:
88+
kwargs['profile_name'] = aws_profile
89+
90+
# Create a boto3 session with the provided arguments
91+
session = boto3.Session(**kwargs)
92+
93+
profile = session.profile_name
94+
region = session.region_name
95+
96+
logger.debug("AWS profile: %s", profile)
97+
logger.debug("AWS region: %s", region)
98+
logger.debug("AWS service: %s", aws_service)
99+
100+
# Create a SigV4 authentication handler with AWS credentials
101+
auth = SigV4HTTPXAuth(session.get_credentials(), aws_service, region)
102+
103+
# Establish connection using MCP SDK's streamable HTTP client
104+
async with streamablehttp_client(
105+
url=endpoint,
106+
auth=auth,
107+
headers=headers,
108+
timeout=timeout,
109+
sse_read_timeout=sse_read_timeout,
110+
terminate_on_close=terminate_on_close,
111+
httpx_client_factory=httpx_client_factory,
112+
) as (read_stream, write_stream, get_session_id):
113+
# Return transport components for external session management
114+
logger.info("Successfully prepared AWS IAM MCP client for endpoint: %s", endpoint)
115+
yield (read_stream, write_stream, get_session_id)

0 commit comments

Comments
 (0)