Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ docker build -t mcp-proxy-for-aws .
| `--service` | AWS service name for SigV4 signing | Inferred from endpoint if not provided |No |
| `--profile` | AWS profile for AWS credentials to use | Uses `AWS_PROFILE` environment variable if not set |No |
| `--region` | AWS region to use | Uses `AWS_REGION` environment variable if not set, defaults to `us-east-1` |No |
| `--metadata` | Metadata to inject into MCP requests as key=value pairs (e.g., `--metadata KEY1=value1 KEY2=value2`) | `AWS_REGION` is automatically injected based on `--region` if not provided |No |
| `--read-only` | Disable tools which may require write permissions (tools which DO NOT require write permissions are annotated with [`readOnlyHint=true`](https://modelcontextprotocol.io/specification/2025-06-18/schema#toolannotations-readonlyhint)) | `False` |No |
| `--retries` | Configures number of retries done when calling upstream services, setting this to 0 disables retries. | 0 |No |
| `--log-level` | Set the logging level (`DEBUG/INFO/WARNING/ERROR/CRITICAL`) | `INFO` |No |
Expand Down
47 changes: 46 additions & 1 deletion mcp_proxy_for_aws/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,43 @@
import os
from mcp_proxy_for_aws import __version__
from mcp_proxy_for_aws.utils import within_range
from typing import Any, Dict, Optional, Sequence


class KeyValueAction(argparse.Action):
"""Custom argparse action to parse key=value pairs into a dictionary."""

def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: str | Sequence[Any] | None,
option_string: Optional[str] = None,
) -> None:
"""Parse key=value pairs into a dictionary.

Args:
parser: The argument parser
namespace: The namespace object to update
values: The values to parse (list of key=value strings)
option_string: The option string that triggered this action
"""
metadata: Dict[str, str] = {}
# Ensure values is a sequence
if values is None:
# No values provided, set empty dict
setattr(namespace, self.dest, metadata)
return

if isinstance(values, str):
values = [values]

for item in values:
if '=' not in item:
parser.error(f'Metadata must be in key=value format, got: {item}')
key, value = item.split('=', 1)
metadata[key] = value
setattr(namespace, self.dest, metadata)


def parse_args():
Expand Down Expand Up @@ -60,10 +97,18 @@ def parse_args():

parser.add_argument(
'--region',
help='AWS region to use (uses AWS_REGION environment variable if not provided, with final fallback to us-east-1)',
help='AWS region to sign (uses AWS_REGION environment variable if not provided, with final fallback to us-east-1)',
default=None,
)

parser.add_argument(
'--metadata',
nargs='*',
action=KeyValueAction,
default=None,
help='Metadata to inject into MCP requests as key=value pairs (e.g., --metadata AWS_REGION=us-west-2 KEY=VALUE)',
)

parser.add_argument(
'--read-only',
action='store_true',
Expand Down
171 changes: 171 additions & 0 deletions mcp_proxy_for_aws/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""HTTPX event hooks for request/response processing."""

import httpx
import json
import logging
from httpx._content import ByteStream
from typing import Any, Dict, Optional


logger = logging.getLogger(__name__)


async def _handle_error_response(response: httpx.Response) -> None:
"""Event hook to handle HTTP error responses and extract details.

This function is called for every HTTP response to check for errors
and provide more detailed error information when requests fail.

Args:
response: The HTTP response object

Raises:
No raises. let the mcp http client handle the errors.
"""
if response.is_error:
# warning only because the SDK logs error
log_level = logging.WARNING
if (
# The server MAY respond 405 to GET (SSE) and DELETE (session).
response.status_code == 405 and response.request.method in ('GET', 'DELETE')
) or (
# The server MAY terminate the session at any time, after which it MUST
# respond to requests containing that session ID with HTTP 404 Not Found.
response.status_code == 404 and response.request.method == 'POST'
):
log_level = logging.DEBUG

try:
# read the content and settle the response content. required to get body (.json(), .text)
await response.aread()
except Exception as e:
logger.debug('Failed to read response: %s', e)
# do nothing and let the client and SDK handle the error
return

# Try to extract error details with fallbacks
try:
# Try to parse JSON error details
error_details = response.json()
logger.log(log_level, 'HTTP %d Error Details: %s', response.status_code, error_details)
except Exception:
# If JSON parsing fails, use response text or status code
try:
response_text = response.text
logger.log(log_level, 'HTTP %d Error: %s', response.status_code, response_text)
except Exception:
# Fallback to just status code and URL
logger.log(
log_level, 'HTTP %d Error for url %s', response.status_code, response.url
)


async def _sign_request_hook(
region: str,
service: str,
profile: Optional[str],
request: httpx.Request,
) -> None:
"""Request hook to sign HTTP requests with AWS SigV4.

This hook signs the request with AWS SigV4 credentials and adds signature headers.

This should be the last hook called to ensure the signature includes any modifications.

Args:
region: AWS region for SigV4 signing
service: AWS service name for SigV4 signing
profile: AWS profile to use (optional)
request: The HTTP request object to sign (modified in-place)
"""
# Import here to avoid circular dependency and for compatibility
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session

# Set Content-Length for signing
request.headers['Content-Length'] = str(len(request.content))

# Get AWS credentials
session = create_aws_session(profile)
credentials = session.get_credentials()
logger.info('Signing request with credentials for access key: %s', credentials.access_key)

# Create SigV4 auth and use its signing logic
auth = SigV4HTTPXAuth(credentials, service, region)

# Call auth_flow to sign the request (it modifies request in-place)
auth_flow = auth.auth_flow(request)
next(auth_flow) # Execute the generator to perform signing

logger.debug('Request headers after signing: %s', request.headers)


async def _inject_metadata_hook(metadata: Dict[str, Any], request: httpx.Request) -> None:
"""Request hook to inject metadata into MCP calls.

Args:
metadata: Dictionary of metadata to inject into _meta field
request: The HTTP request object
"""
logger.info('=== Outgoing Request ===')
logger.info('URL: %s', request.url)
logger.info('Method: %s', request.method)

# Try to inject metadata if it's a JSON-RPC/MCP request
if request.content and metadata:
try:
# Parse the request body
body = json.loads(await request.aread())

# Check if it's a JSON-RPC request
if isinstance(body, dict) and 'jsonrpc' in body:
# Ensure _meta exists in params
if '_meta' not in body['params']:
body['params']['_meta'] = {}

# Get existing metadata
existing_meta = body['params']['_meta']

# Merge metadata (existing takes precedence)
if isinstance(existing_meta, dict):
# Check for conflicting keys before merge
conflicting_keys = set(metadata.keys()) & set(existing_meta.keys())
if conflicting_keys:
for key in conflicting_keys:
logger.warning(
'Metadata key "%s" already exists in _meta. '
'Keeping existing value "%s", ignoring injected value "%s"',
key,
existing_meta[key],
metadata[key],
)
body['params']['_meta'] = {**metadata, **existing_meta}
else:
logger.info('Replacing non-dict _meta value with injected metadata')
body['params']['_meta'] = metadata

# Create new content with updated metadata
new_content = json.dumps(body).encode('utf-8')

# Update the request with new content
request.stream = ByteStream(new_content)
request._content = new_content

logger.info('Injected metadata into _meta: %s', body['params']['_meta'])

except (json.JSONDecodeError, KeyError, TypeError) as e:
# Not a JSON request or invalid format, skip metadata injection
logger.error('Skipping metadata injection: %s', e)
17 changes: 15 additions & 2 deletions mcp_proxy_for_aws/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,22 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
region = determine_aws_region(args.endpoint, args.region)
logger.debug('Using region: %s', region)

# Build metadata dictionary - start with AWS_REGION, then merge user metadata
metadata = {'AWS_REGION': region}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why we are always setting this metadata?

I'm just wondering if we shouldn't set nothing in the default case, and users must specify the param if they want to use this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my understanding of the intended workflow, there is no variation wherein an AWS_REGION is not necessary. If this is not the case, I'm happy to make it entirely optional.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's maybe differentiate between "feature-request" and "how other projects will use it". The "feature-request" here is to add the ability to add metadata.

With that in mind, I think let's make it completely optional

if args.metadata:
metadata.update(args.metadata)

# Get profile
profile = args.profile

# Log server configuration
logger.info('Using service: %s, region: %s, profile: %s', service, region, profile)
logger.info(
'Using service: %s, region: %s, metadata: %s, profile: %s',
service,
region,
metadata,
profile,
)
logger.info('Running in MCP mode')

timeout = httpx.Timeout(
Expand All @@ -69,7 +80,9 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
)

# Create transport with SigV4 authentication
transport = create_transport_with_sigv4(args.endpoint, service, region, timeout, profile)
transport = create_transport_with_sigv4(
args.endpoint, service, region, metadata, timeout, profile
)

# Create proxy with the transport
proxy = FastMCP.as_proxy(transport)
Expand Down
Loading