Skip to content

Commit 4d50207

Browse files
author
Kyon Caldera
committed
refactor: move the hooks from sigv4_helper.py into a new folder and add tests
1 parent 40c4c74 commit 4d50207

File tree

8 files changed

+426
-260
lines changed

8 files changed

+426
-260
lines changed

mcp_proxy_for_aws/cli.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import os
1919
from mcp_proxy_for_aws import __version__
2020
from mcp_proxy_for_aws.utils import within_range
21-
from typing import Dict, Optional, Sequence
21+
from typing import Any, Dict, Optional, Sequence
2222

2323

2424
class KeyValueAction(argparse.Action):
@@ -28,7 +28,7 @@ def __call__(
2828
self,
2929
parser: argparse.ArgumentParser,
3030
namespace: argparse.Namespace,
31-
values: str | Sequence[str],
31+
values: str | Sequence[Any] | None,
3232
option_string: Optional[str] = None,
3333
) -> None:
3434
"""Parse key=value pairs into a dictionary.
@@ -41,6 +41,11 @@ def __call__(
4141
"""
4242
metadata: Dict[str, str] = {}
4343
# Ensure values is a sequence
44+
if values is None:
45+
# No values provided, set empty dict
46+
setattr(namespace, self.dest, metadata)
47+
return
48+
4449
if isinstance(values, str):
4550
values = [values]
4651

mcp_proxy_for_aws/hooks.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""HTTPX event hooks for request/response processing."""
16+
17+
import httpx
18+
import json
19+
import logging
20+
from botocore.auth import SigV4Auth
21+
from botocore.awsrequest import AWSRequest
22+
from httpx._content import ByteStream
23+
from typing import Any, Dict, Optional
24+
25+
26+
logger = logging.getLogger(__name__)
27+
28+
29+
async def _handle_error_response(response: httpx.Response) -> None:
30+
"""Event hook to handle HTTP error responses and extract details.
31+
32+
This function is called for every HTTP response to check for errors
33+
and provide more detailed error information when requests fail.
34+
35+
Args:
36+
response: The HTTP response object
37+
38+
Raises:
39+
No raises. let the mcp http client handle the errors.
40+
"""
41+
if response.is_error:
42+
try:
43+
# read the content and settle the response content. required to get body (.json(), .text)
44+
await response.aread()
45+
except Exception as e:
46+
logger.error('Failed to read response: %s', e)
47+
# do nothing and let the client handle the error
48+
return
49+
50+
# Try to extract error details with fallbacks
51+
try:
52+
# Try to parse JSON error details
53+
error_details = response.json()
54+
logger.error('HTTP %d Error Details: %s', response.status_code, error_details)
55+
except Exception:
56+
# If JSON parsing fails, use response text or status code
57+
try:
58+
response_text = response.text
59+
logger.error('HTTP %d Error: %s', response.status_code, response_text)
60+
except Exception:
61+
# Fallback to just status code and URL
62+
logger.error('HTTP %d Error for url %s', response.status_code, response.url)
63+
64+
65+
def _resign_request_with_sigv4(
66+
request: httpx.Request,
67+
region: str,
68+
service: str,
69+
profile: Optional[str] = None,
70+
) -> None:
71+
"""Re-sign an HTTP request with AWS SigV4 after content modification.
72+
73+
This function removes old signature headers, creates a new signature based on
74+
the current request content, and updates the request headers with the new signature.
75+
76+
Args:
77+
request: The HTTP request object to re-sign (modified in-place)
78+
region: AWS region for SigV4 signing
79+
service: AWS service name for SigV4 signing
80+
profile: AWS profile to use (optional)
81+
"""
82+
# Import here to avoid circular dependency
83+
from mcp_proxy_for_aws.sigv4_helper import create_aws_session
84+
85+
# Remove old signature headers before re-signing
86+
headers_to_remove = ['Content-Length', 'x-amz-date', 'x-amz-security-token', 'authorization']
87+
for header in headers_to_remove:
88+
request.headers.pop(header, None)
89+
90+
# Set the new Content-Length
91+
request.headers['Content-Length'] = str(len(request.content))
92+
93+
logger.info('Headers after cleanup: %s', request.headers)
94+
95+
# Get AWS credentials
96+
session = create_aws_session(profile)
97+
credentials = session.get_credentials()
98+
logger.info('Re-signing request with credentials for access key: %s', credentials.access_key)
99+
100+
# Create headers dict for signing, removing connection header like in auth_flow
101+
headers_for_signing = dict(request.headers)
102+
headers_for_signing.pop('connection', None) # Remove connection header for signing
103+
104+
# Create SigV4 signer and AWS request
105+
signer = SigV4Auth(credentials, service, region)
106+
aws_request = AWSRequest(
107+
method=request.method,
108+
url=str(request.url),
109+
data=request.content,
110+
headers=headers_for_signing,
111+
)
112+
113+
# Sign the request
114+
logger.info('AWS request before signing: %s', aws_request.headers)
115+
signer.add_auth(aws_request)
116+
logger.info('AWS request after signing: %s', aws_request.headers)
117+
118+
# Update request headers with signed headers
119+
request.headers.update(dict(aws_request.headers))
120+
logger.info('Request headers after re-signing: %s', request.headers)
121+
122+
123+
async def _inject_metadata_hook(
124+
metadata: Dict[str, Any], region: str, service: str, request: httpx.Request
125+
) -> None:
126+
"""Request hook to inject metadata into MCP calls.
127+
128+
Args:
129+
metadata: Dictionary of metadata to inject into _meta field
130+
region: AWS region for SigV4 re-signing after metadata injection
131+
service: AWS service name for SigV4 re-signing after metadata injection
132+
request: The HTTP request object
133+
"""
134+
logger.info('=== Outgoing Request ===')
135+
logger.info('URL: %s', request.url)
136+
logger.info('Method: %s', request.method)
137+
138+
# Try to inject metadata if it's a JSON-RPC/MCP request
139+
if request.content and metadata:
140+
try:
141+
# Parse the request body
142+
body = json.loads(await request.aread())
143+
144+
# Check if it's a JSON-RPC request
145+
if isinstance(body, dict) and 'jsonrpc' in body:
146+
# Ensure _meta exists in params
147+
if '_meta' not in body['params']:
148+
body['params']['_meta'] = {}
149+
150+
# Get existing metadata
151+
existing_meta = body['params']['_meta']
152+
153+
# Merge metadata (existing takes precedence)
154+
if isinstance(existing_meta, dict):
155+
# Check for conflicting keys before merge
156+
conflicting_keys = set(metadata.keys()) & set(existing_meta.keys())
157+
if conflicting_keys:
158+
for key in conflicting_keys:
159+
logger.warning(
160+
'Metadata key "%s" already exists in _meta. '
161+
'Keeping existing value "%s", ignoring injected value "%s"',
162+
key,
163+
existing_meta[key],
164+
metadata[key],
165+
)
166+
body['params']['_meta'] = {**metadata, **existing_meta}
167+
else:
168+
logger.info('Replacing non-dict _meta value with injected metadata')
169+
body['params']['_meta'] = metadata
170+
171+
# Create new content with updated metadata
172+
new_content = json.dumps(body).encode('utf-8')
173+
174+
# Update the request with new content
175+
request.stream = ByteStream(new_content)
176+
request._content = new_content
177+
178+
# Re-sign the request with the new content
179+
_resign_request_with_sigv4(request, region, service)
180+
181+
logger.info('Injected metadata into _meta: %s', body['params']['_meta'])
182+
183+
except (json.JSONDecodeError, KeyError, TypeError) as e:
184+
# Not a JSON request or invalid format, skip metadata injection
185+
logger.error('Skipping metadata injection: %s', e)

mcp_proxy_for_aws/sigv4_helper.py

Lines changed: 1 addition & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,12 @@
1616

1717
import boto3
1818
import httpx
19-
import json
2019
import logging
2120
from botocore.auth import SigV4Auth
2221
from botocore.awsrequest import AWSRequest
2322
from botocore.credentials import Credentials
2423
from functools import partial
25-
from httpx._content import ByteStream
24+
from mcp_proxy_for_aws.hooks import _handle_error_response, _inject_metadata_hook
2625
from typing import Any, Dict, Generator, Optional
2726

2827

@@ -74,162 +73,6 @@ def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Re
7473
yield request
7574

7675

77-
async def _handle_error_response(response: httpx.Response) -> None:
78-
"""Event hook to handle HTTP error responses and extract details.
79-
80-
This function is called for every HTTP response to check for errors
81-
and provide more detailed error information when requests fail.
82-
83-
Args:
84-
response: The HTTP response object
85-
86-
Raises:
87-
No raises. let the mcp http client handle the errors.
88-
"""
89-
if response.is_error:
90-
try:
91-
# read the content and settle the response content. required to get body (.json(), .text)
92-
await response.aread()
93-
except Exception as e:
94-
logger.error('Failed to read response: %s', e)
95-
# do nothing and let the client handle the error
96-
return
97-
98-
# Try to extract error details with fallbacks
99-
try:
100-
# Try to parse JSON error details
101-
error_details = response.json()
102-
logger.error('HTTP %d Error Details: %s', response.status_code, error_details)
103-
except Exception:
104-
# If JSON parsing fails, use response text or status code
105-
try:
106-
response_text = response.text
107-
logger.error('HTTP %d Error: %s', response.status_code, response_text)
108-
except Exception:
109-
# Fallback to just status code and URL
110-
logger.error('HTTP %d Error for url %s', response.status_code, response.url)
111-
112-
113-
def _resign_request_with_sigv4(
114-
request: httpx.Request,
115-
region: str,
116-
service: str,
117-
profile: Optional[str] = None,
118-
) -> None:
119-
"""Re-sign an HTTP request with AWS SigV4 after content modification.
120-
121-
This function removes old signature headers, creates a new signature based on
122-
the current request content, and updates the request headers with the new signature.
123-
124-
Args:
125-
request: The HTTP request object to re-sign (modified in-place)
126-
region: AWS region for SigV4 signing
127-
service: AWS service name for SigV4 signing
128-
profile: AWS profile to use (optional)
129-
"""
130-
# Remove old signature headers before re-signing
131-
headers_to_remove = ['Content-Length', 'x-amz-date', 'x-amz-security-token', 'authorization']
132-
for header in headers_to_remove:
133-
request.headers.pop(header, None)
134-
135-
# Set the new Content-Length
136-
request.headers['Content-Length'] = str(len(request.content))
137-
138-
logger.info('Headers after cleanup: %s', request.headers)
139-
140-
# Get AWS credentials
141-
session = create_aws_session(profile)
142-
credentials = session.get_credentials()
143-
logger.info('Re-signing request with credentials for access key: %s', credentials.access_key)
144-
145-
# Create headers dict for signing, removing connection header like in auth_flow
146-
headers_for_signing = dict(request.headers)
147-
headers_for_signing.pop('connection', None) # Remove connection header for signing
148-
149-
# Create SigV4 signer and AWS request
150-
signer = SigV4Auth(credentials, service, region)
151-
aws_request = AWSRequest(
152-
method=request.method,
153-
url=str(request.url),
154-
data=request.content,
155-
headers=headers_for_signing,
156-
)
157-
158-
# Sign the request
159-
logger.info('AWS request before signing: %s', aws_request.headers)
160-
signer.add_auth(aws_request)
161-
logger.info('AWS request after signing: %s', aws_request.headers)
162-
163-
# Update request headers with signed headers
164-
request.headers.update(dict(aws_request.headers))
165-
logger.info('Request headers after re-signing: %s', request.headers)
166-
167-
168-
async def _inject_metadata_hook(
169-
metadata: Dict[str, Any], region: str, service: str, request: httpx.Request
170-
) -> None:
171-
"""Request hook to inject metadata into MCP calls.
172-
173-
Args:
174-
metadata: Dictionary of metadata to inject into _meta field
175-
region: AWS region for SigV4 re-signing after metadata injection
176-
service: AWS service name for SigV4 re-signing after metadata injection
177-
request: The HTTP request object
178-
"""
179-
logger.info('=== Outgoing Request ===')
180-
logger.info('URL: %s', request.url)
181-
logger.info('Method: %s', request.method)
182-
183-
# Try to inject metadata if it's a JSON-RPC/MCP request
184-
if request.content and metadata:
185-
try:
186-
# Parse the request body
187-
body = json.loads(await request.aread())
188-
189-
# Check if it's a JSON-RPC request
190-
if isinstance(body, dict) and 'jsonrpc' in body:
191-
# Ensure _meta exists in params
192-
if '_meta' not in body['params']:
193-
body['params']['_meta'] = {}
194-
195-
# Get existing metadata
196-
existing_meta = body['params']['_meta']
197-
198-
# Merge metadata (existing takes precedence)
199-
if isinstance(existing_meta, dict):
200-
# Check for conflicting keys before merge
201-
conflicting_keys = set(metadata.keys()) & set(existing_meta.keys())
202-
if conflicting_keys:
203-
for key in conflicting_keys:
204-
logger.warning(
205-
'Metadata key "%s" already exists in _meta. '
206-
'Keeping existing value "%s", ignoring injected value "%s"',
207-
key,
208-
existing_meta[key],
209-
metadata[key],
210-
)
211-
body['params']['_meta'] = {**metadata, **existing_meta}
212-
else:
213-
logger.info('Replacing non-dict _meta value with injected metadata')
214-
body['params']['_meta'] = metadata
215-
216-
# Create new content with updated metadata
217-
new_content = json.dumps(body).encode('utf-8')
218-
219-
# Update the request with new content
220-
request.stream = ByteStream(new_content)
221-
request._content = new_content
222-
223-
# Re-sign the request with the new content
224-
_resign_request_with_sigv4(request, region, service)
225-
226-
logger.info('Injected metadata into _meta: %s', body['params']['_meta'])
227-
228-
except (json.JSONDecodeError, KeyError, TypeError) as e:
229-
# Not a JSON request or invalid format, skip metadata injection
230-
logger.error('Skipping metadata injection: %s', e)
231-
232-
23376
def create_aws_session(profile: Optional[str] = None) -> boto3.Session:
23477
"""Create an AWS session with optional profile.
23578

0 commit comments

Comments
 (0)