Skip to content

Commit 8158ecf

Browse files
author
Kyon Caldera
committed
refactor: move the hooks from sigv4_helper.py into a new folder and add tests
1 parent c5a0ddc commit 8158ecf

File tree

8 files changed

+440
-274
lines changed

8 files changed

+440
-274
lines changed

mcp_proxy_for_aws/cli.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import os
1919
from mcp_proxy_for_aws import __version__
2020
from mcp_proxy_for_aws.utils import within_range
21-
from typing import Dict, Optional, Sequence
21+
from typing import Any, Dict, Optional, Sequence
2222

2323

2424
class KeyValueAction(argparse.Action):
@@ -28,7 +28,7 @@ def __call__(
2828
self,
2929
parser: argparse.ArgumentParser,
3030
namespace: argparse.Namespace,
31-
values: str | Sequence[str],
31+
values: str | Sequence[Any] | None,
3232
option_string: Optional[str] = None,
3333
) -> None:
3434
"""Parse key=value pairs into a dictionary.
@@ -41,6 +41,11 @@ def __call__(
4141
"""
4242
metadata: Dict[str, str] = {}
4343
# Ensure values is a sequence
44+
if values is None:
45+
# No values provided, set empty dict
46+
setattr(namespace, self.dest, metadata)
47+
return
48+
4449
if isinstance(values, str):
4550
values = [values]
4651

mcp_proxy_for_aws/hooks.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""HTTPX event hooks for request/response processing."""
16+
17+
import httpx
18+
import json
19+
import logging
20+
from botocore.auth import SigV4Auth
21+
from botocore.awsrequest import AWSRequest
22+
from httpx._content import ByteStream
23+
from typing import Any, Dict, Optional
24+
25+
26+
logger = logging.getLogger(__name__)
27+
28+
29+
async def _handle_error_response(response: httpx.Response) -> None:
30+
"""Event hook to handle HTTP error responses and extract details.
31+
32+
This function is called for every HTTP response to check for errors
33+
and provide more detailed error information when requests fail.
34+
35+
Args:
36+
response: The HTTP response object
37+
38+
Raises:
39+
No raises. let the mcp http client handle the errors.
40+
"""
41+
if response.is_error:
42+
# warning only because the SDK logs error
43+
log_level = logging.WARNING
44+
if (
45+
# The server MAY respond 405 to GET (SSE) and DELETE (session).
46+
response.status_code == 405 and response.request.method in ('GET', 'DELETE')
47+
) or (
48+
# The server MAY terminate the session at any time, after which it MUST
49+
# respond to requests containing that session ID with HTTP 404 Not Found.
50+
response.status_code == 404 and response.request.method == 'POST'
51+
):
52+
log_level = logging.DEBUG
53+
54+
try:
55+
# read the content and settle the response content. required to get body (.json(), .text)
56+
await response.aread()
57+
except Exception as e:
58+
logger.debug('Failed to read response: %s', e)
59+
# do nothing and let the client and SDK handle the error
60+
return
61+
62+
# Try to extract error details with fallbacks
63+
try:
64+
# Try to parse JSON error details
65+
error_details = response.json()
66+
logger.log(log_level, 'HTTP %d Error Details: %s', response.status_code, error_details)
67+
except Exception:
68+
# If JSON parsing fails, use response text or status code
69+
try:
70+
response_text = response.text
71+
logger.log(log_level, 'HTTP %d Error: %s', response.status_code, response_text)
72+
except Exception:
73+
# Fallback to just status code and URL
74+
logger.log(
75+
log_level, 'HTTP %d Error for url %s', response.status_code, response.url
76+
)
77+
78+
79+
def _resign_request_with_sigv4(
80+
request: httpx.Request,
81+
region: str,
82+
service: str,
83+
profile: Optional[str] = None,
84+
) -> None:
85+
"""Re-sign an HTTP request with AWS SigV4 after content modification.
86+
87+
This function removes old signature headers, creates a new signature based on
88+
the current request content, and updates the request headers with the new signature.
89+
90+
Args:
91+
request: The HTTP request object to re-sign (modified in-place)
92+
region: AWS region for SigV4 signing
93+
service: AWS service name for SigV4 signing
94+
profile: AWS profile to use (optional)
95+
"""
96+
# Import here to avoid circular dependency
97+
from mcp_proxy_for_aws.sigv4_helper import create_aws_session
98+
99+
# Remove old signature headers before re-signing
100+
headers_to_remove = ['Content-Length', 'x-amz-date', 'x-amz-security-token', 'authorization']
101+
for header in headers_to_remove:
102+
request.headers.pop(header, None)
103+
104+
# Set the new Content-Length
105+
request.headers['Content-Length'] = str(len(request.content))
106+
107+
logger.info('Headers after cleanup: %s', request.headers)
108+
109+
# Get AWS credentials
110+
session = create_aws_session(profile)
111+
credentials = session.get_credentials()
112+
logger.info('Re-signing request with credentials for access key: %s', credentials.access_key)
113+
114+
# Create headers dict for signing, removing connection header like in auth_flow
115+
headers_for_signing = dict(request.headers)
116+
headers_for_signing.pop('connection', None) # Remove connection header for signing
117+
118+
# Create SigV4 signer and AWS request
119+
signer = SigV4Auth(credentials, service, region)
120+
aws_request = AWSRequest(
121+
method=request.method,
122+
url=str(request.url),
123+
data=request.content,
124+
headers=headers_for_signing,
125+
)
126+
127+
# Sign the request
128+
logger.info('AWS request before signing: %s', aws_request.headers)
129+
signer.add_auth(aws_request)
130+
logger.info('AWS request after signing: %s', aws_request.headers)
131+
132+
# Update request headers with signed headers
133+
request.headers.update(dict(aws_request.headers))
134+
logger.info('Request headers after re-signing: %s', request.headers)
135+
136+
137+
async def _inject_metadata_hook(
138+
metadata: Dict[str, Any], region: str, service: str, request: httpx.Request
139+
) -> None:
140+
"""Request hook to inject metadata into MCP calls.
141+
142+
Args:
143+
metadata: Dictionary of metadata to inject into _meta field
144+
region: AWS region for SigV4 re-signing after metadata injection
145+
service: AWS service name for SigV4 re-signing after metadata injection
146+
request: The HTTP request object
147+
"""
148+
logger.info('=== Outgoing Request ===')
149+
logger.info('URL: %s', request.url)
150+
logger.info('Method: %s', request.method)
151+
152+
# Try to inject metadata if it's a JSON-RPC/MCP request
153+
if request.content and metadata:
154+
try:
155+
# Parse the request body
156+
body = json.loads(await request.aread())
157+
158+
# Check if it's a JSON-RPC request
159+
if isinstance(body, dict) and 'jsonrpc' in body:
160+
# Ensure _meta exists in params
161+
if '_meta' not in body['params']:
162+
body['params']['_meta'] = {}
163+
164+
# Get existing metadata
165+
existing_meta = body['params']['_meta']
166+
167+
# Merge metadata (existing takes precedence)
168+
if isinstance(existing_meta, dict):
169+
# Check for conflicting keys before merge
170+
conflicting_keys = set(metadata.keys()) & set(existing_meta.keys())
171+
if conflicting_keys:
172+
for key in conflicting_keys:
173+
logger.warning(
174+
'Metadata key "%s" already exists in _meta. '
175+
'Keeping existing value "%s", ignoring injected value "%s"',
176+
key,
177+
existing_meta[key],
178+
metadata[key],
179+
)
180+
body['params']['_meta'] = {**metadata, **existing_meta}
181+
else:
182+
logger.info('Replacing non-dict _meta value with injected metadata')
183+
body['params']['_meta'] = metadata
184+
185+
# Create new content with updated metadata
186+
new_content = json.dumps(body).encode('utf-8')
187+
188+
# Update the request with new content
189+
request.stream = ByteStream(new_content)
190+
request._content = new_content
191+
192+
# Re-sign the request with the new content
193+
_resign_request_with_sigv4(request, region, service)
194+
195+
logger.info('Injected metadata into _meta: %s', body['params']['_meta'])
196+
197+
except (json.JSONDecodeError, KeyError, TypeError) as e:
198+
# Not a JSON request or invalid format, skip metadata injection
199+
logger.error('Skipping metadata injection: %s', e)

0 commit comments

Comments
 (0)