Skip to content

Commit 40c4c74

Browse files
author
Kyon Caldera
committed
feat: replace forwarding region with metadata forwarding
1 parent 8de5aa4 commit 40c4c74

File tree

7 files changed

+197
-21
lines changed

7 files changed

+197
-21
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ docker build -t mcp-proxy-for-aws .
4747
| `--service` | AWS service name for SigV4 signing | Inferred from endpoint if not provided |No |
4848
| `--profile` | AWS profile for AWS credentials to use | Uses `AWS_PROFILE` environment variable if not set |No |
4949
| `--region` | AWS region to use | Uses `AWS_REGION` environment variable if not set, defaults to `us-east-1` |No |
50+
| `--metadata` | Metadata to inject into MCP requests as key=value pairs (e.g., `--metadata KEY1=value1 KEY2=value2`) | `AWS_REGION` is automatically injected based on `--region` if not provided |No |
5051
| `--read-only` | Disable tools which may require write permissions (tools which DO NOT require write permissions are annotated with [`readOnlyHint=true`](https://modelcontextprotocol.io/specification/2025-06-18/schema#toolannotations-readonlyhint)) | `False` |No |
5152
| `--retries` | Configures number of retries done when calling upstream services, setting this to 0 disables retries. | 0 |No |
5253
| `--log-level` | Set the logging level (`DEBUG/INFO/WARNING/ERROR/CRITICAL`) | `INFO` |No |

mcp_proxy_for_aws/cli.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,38 @@
1818
import os
1919
from mcp_proxy_for_aws import __version__
2020
from mcp_proxy_for_aws.utils import within_range
21+
from typing import Dict, Optional, Sequence
22+
23+
24+
class KeyValueAction(argparse.Action):
25+
"""Custom argparse action to parse key=value pairs into a dictionary."""
26+
27+
def __call__(
28+
self,
29+
parser: argparse.ArgumentParser,
30+
namespace: argparse.Namespace,
31+
values: str | Sequence[str],
32+
option_string: Optional[str] = None,
33+
) -> None:
34+
"""Parse key=value pairs into a dictionary.
35+
36+
Args:
37+
parser: The argument parser
38+
namespace: The namespace object to update
39+
values: The values to parse (list of key=value strings)
40+
option_string: The option string that triggered this action
41+
"""
42+
metadata: Dict[str, str] = {}
43+
# Ensure values is a sequence
44+
if isinstance(values, str):
45+
values = [values]
46+
47+
for item in values:
48+
if '=' not in item:
49+
parser.error(f'Metadata must be in key=value format, got: {item}')
50+
key, value = item.split('=', 1)
51+
metadata[key] = value
52+
setattr(namespace, self.dest, metadata)
2153

2254

2355
def parse_args():
@@ -65,9 +97,11 @@ def parse_args():
6597
)
6698

6799
parser.add_argument(
68-
'--forwarding-region',
69-
help='AWS region to forward to server (uses --region if not provided)',
100+
'--metadata',
101+
nargs='*',
102+
action=KeyValueAction,
70103
default=None,
104+
help='Metadata to inject into MCP requests as key=value pairs (e.g., --metadata AWS_REGION=us-west-2 KEY=VALUE)',
71105
)
72106

73107
parser.add_argument(

mcp_proxy_for_aws/server.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,13 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
5252

5353
# Validate and determine region
5454
region = determine_aws_region(args.endpoint, args.region)
55-
forwarding_region = args.forwarding_region or region
5655
logger.debug('Using region: %s', region)
5756

57+
# Build metadata dictionary - start with AWS_REGION, then merge user metadata
58+
metadata = {'AWS_REGION': region}
59+
if args.metadata:
60+
metadata.update(args.metadata)
61+
5862
# Get profile
5963
profile = args.profile
6064

@@ -63,7 +67,7 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
6367
'Using service: %s, region: %s, forwarding region: %s, profile: %s',
6468
service,
6569
region,
66-
forwarding_region,
70+
metadata.get('AWS_REGION'),
6771
profile,
6872
)
6973
logger.info('Running in MCP mode')
@@ -77,7 +81,7 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
7781

7882
# Create transport with SigV4 authentication
7983
transport = create_transport_with_sigv4(
80-
args.endpoint, service, region, forwarding_region, timeout, profile
84+
args.endpoint, service, region, metadata, timeout, profile
8185
)
8286

8387
# Create proxy with the transport

mcp_proxy_for_aws/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import re
2222
from fastmcp.client.transports import StreamableHttpTransport
2323
from mcp_proxy_for_aws.sigv4_helper import create_sigv4_client
24-
from typing import Dict, Optional
24+
from typing import Any, Dict, Optional
2525
from urllib.parse import urlparse
2626

2727

@@ -32,7 +32,7 @@ def create_transport_with_sigv4(
3232
url: str,
3333
service: str,
3434
region: str,
35-
forwarding_region: str,
35+
metadata: Dict[str, Any],
3636
custom_timeout: httpx.Timeout,
3737
profile: Optional[str] = None,
3838
) -> StreamableHttpTransport:
@@ -42,7 +42,7 @@ def create_transport_with_sigv4(
4242
url: The endpoint URL
4343
service: AWS service name for SigV4 signing
4444
region: AWS region to use
45-
forwarding_region: AWS region to forward to server
45+
metadata: Metadata dictionary to inject into MCP requests
4646
custom_timeout: httpx.Timeout used to connect to the endpoint
4747
profile: AWS profile to use (optional)
4848
@@ -62,7 +62,7 @@ def client_factory(
6262
region=region,
6363
headers=headers,
6464
timeout=custom_timeout,
65-
metadata={'AWS_REGION': forwarding_region},
65+
metadata=metadata,
6666
auth=auth,
6767
)
6868

tests/unit/test_cli.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,39 @@ def test_parse_args_negative_timeout(self):
140140
"""Test parsing fails with negative timeout (within_range validation)."""
141141
with pytest.raises(SystemExit):
142142
parse_args()
143+
144+
@patch(
145+
'sys.argv',
146+
['mcp-proxy-for-aws', 'https://example.com', '--metadata', 'KEY1=value1', 'KEY2=value2'],
147+
)
148+
def test_parse_metadata_argument(self):
149+
"""Test parsing metadata key=value pairs."""
150+
args = parse_args()
151+
assert args.metadata == {'KEY1': 'value1', 'KEY2': 'value2'}
152+
153+
@patch(
154+
'sys.argv',
155+
['mcp-proxy-for-aws', 'https://example.com', '--metadata', 'AWS_REGION=us-west-2'],
156+
)
157+
def test_parse_metadata_single_pair(self):
158+
"""Test parsing single metadata key=value pair."""
159+
args = parse_args()
160+
assert args.metadata == {'AWS_REGION': 'us-west-2'}
161+
162+
@patch(
163+
'sys.argv',
164+
['mcp-proxy-for-aws', 'https://example.com', '--metadata', 'KEY=value with spaces'],
165+
)
166+
def test_parse_metadata_with_spaces_in_value(self):
167+
"""Test parsing metadata with spaces in value."""
168+
args = parse_args()
169+
assert args.metadata == {'KEY': 'value with spaces'}
170+
171+
@patch('sys.argv', ['mcp-proxy-for-aws', 'https://example.com', '--metadata', 'INVALID'])
172+
def test_parse_metadata_invalid_format(self):
173+
"""Test that invalid metadata format raises an error."""
174+
import argparse
175+
176+
with pytest.raises((SystemExit, argparse.ArgumentTypeError)):
177+
# argparse may call sys.exit or raise ArgumentTypeError
178+
parse_args()

tests/unit/test_server.py

Lines changed: 107 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def test_setup_mcp_mode(
5656
mock_args.profile = None
5757
mock_args.read_only = True
5858
mock_args.retries = 1
59-
mock_args.forwarding_region = None
59+
mock_args.metadata = None
6060
# Add timeout parameters
6161
mock_args.timeout = 180.0
6262
mock_args.connect_timeout = 60.0
@@ -87,7 +87,7 @@ async def test_setup_mcp_mode(
8787
assert call_args[0][0] == 'https://test.example.com'
8888
assert call_args[0][1] == 'test-service'
8989
assert call_args[0][2] == 'us-east-1'
90-
assert call_args[0][3] == 'us-east-1' # forwarding_region (defaults to region)
90+
assert call_args[0][3] == {'AWS_REGION': 'us-east-1'} # metadata
9191
# call_args[0][4] is the Timeout object
9292
assert call_args[0][5] is None # profile
9393
mock_as_proxy.assert_called_once_with(mock_transport)
@@ -118,7 +118,7 @@ async def test_setup_mcp_mode_no_retries(
118118
mock_args.profile = 'test-profile'
119119
mock_args.read_only = False
120120
mock_args.retries = 0 # No retries
121-
mock_args.forwarding_region = 'eu-west-1'
121+
mock_args.metadata = {'AWS_REGION': 'eu-west-1', 'CUSTOM_KEY': 'custom_value'}
122122
# Add timeout parameters
123123
mock_args.timeout = 180.0
124124
mock_args.connect_timeout = 60.0
@@ -149,13 +149,116 @@ async def test_setup_mcp_mode_no_retries(
149149
assert call_args[0][0] == 'https://test.example.com'
150150
assert call_args[0][1] == 'test-service'
151151
assert call_args[0][2] == 'us-east-1'
152-
assert call_args[0][3] == 'eu-west-1' # forwarding_region
152+
assert call_args[0][3] == {
153+
'AWS_REGION': 'eu-west-1',
154+
'CUSTOM_KEY': 'custom_value',
155+
} # metadata
153156
# call_args[0][4] is the Timeout object
154157
assert call_args[0][5] == 'test-profile' # profile
155158
mock_as_proxy.assert_called_once_with(mock_transport)
156159
mock_add_filtering.assert_called_once_with(mock_proxy, False)
157160
mock_proxy.run_async.assert_called_once()
158161

162+
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
163+
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
164+
@patch('mcp_proxy_for_aws.server.determine_aws_region')
165+
@patch('mcp_proxy_for_aws.server.determine_service_name')
166+
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
167+
async def test_setup_mcp_mode_no_metadata_injects_aws_region(
168+
self,
169+
mock_add_filtering,
170+
mock_determine_service,
171+
mock_determine_region,
172+
mock_as_proxy,
173+
mock_create_transport,
174+
):
175+
"""Test that AWS_REGION is automatically injected when no metadata is provided."""
176+
# Arrange
177+
local_mcp = Mock(spec=FastMCP)
178+
mock_args = Mock()
179+
mock_args.endpoint = 'https://test.example.com'
180+
mock_args.service = 'test-service'
181+
mock_args.region = 'ap-southeast-1'
182+
mock_args.profile = None
183+
mock_args.read_only = False
184+
mock_args.retries = 0
185+
mock_args.metadata = None # No metadata provided
186+
mock_args.timeout = 180.0
187+
mock_args.connect_timeout = 60.0
188+
mock_args.read_timeout = 120.0
189+
mock_args.write_timeout = 180.0
190+
mock_args.log_level = 'INFO'
191+
192+
mock_determine_service.return_value = 'test-service'
193+
mock_determine_region.return_value = 'ap-southeast-1'
194+
195+
mock_transport = Mock()
196+
mock_create_transport.return_value = mock_transport
197+
mock_proxy = Mock()
198+
mock_proxy.run_async = AsyncMock()
199+
mock_as_proxy.return_value = mock_proxy
200+
201+
# Act
202+
await setup_mcp_mode(local_mcp, mock_args)
203+
204+
# Assert - verify AWS_REGION was automatically injected
205+
assert mock_create_transport.call_count == 1
206+
call_args = mock_create_transport.call_args
207+
metadata = call_args[0][3]
208+
assert metadata == {'AWS_REGION': 'ap-southeast-1'}
209+
210+
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
211+
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
212+
@patch('mcp_proxy_for_aws.server.determine_aws_region')
213+
@patch('mcp_proxy_for_aws.server.determine_service_name')
214+
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
215+
async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
216+
self,
217+
mock_add_filtering,
218+
mock_determine_service,
219+
mock_determine_region,
220+
mock_as_proxy,
221+
mock_create_transport,
222+
):
223+
"""Test that AWS_REGION is injected even when other metadata is provided."""
224+
# Arrange
225+
local_mcp = Mock(spec=FastMCP)
226+
mock_args = Mock()
227+
mock_args.endpoint = 'https://test.example.com'
228+
mock_args.service = 'test-service'
229+
mock_args.region = 'us-west-1'
230+
mock_args.profile = None
231+
mock_args.read_only = False
232+
mock_args.retries = 0
233+
mock_args.metadata = {'CUSTOM_KEY': 'custom_value', 'ANOTHER_KEY': 'another_value'}
234+
mock_args.timeout = 180.0
235+
mock_args.connect_timeout = 60.0
236+
mock_args.read_timeout = 120.0
237+
mock_args.write_timeout = 180.0
238+
mock_args.log_level = 'INFO'
239+
240+
mock_determine_service.return_value = 'test-service'
241+
mock_determine_region.return_value = 'us-west-1'
242+
243+
mock_transport = Mock()
244+
mock_create_transport.return_value = mock_transport
245+
mock_proxy = Mock()
246+
mock_proxy.run_async = AsyncMock()
247+
mock_as_proxy.return_value = mock_proxy
248+
249+
# Act
250+
await setup_mcp_mode(local_mcp, mock_args)
251+
252+
# Assert - verify AWS_REGION was injected along with custom metadata
253+
assert mock_create_transport.call_count == 1
254+
call_args = mock_create_transport.call_args
255+
metadata = call_args[0][3]
256+
assert metadata == {
257+
'AWS_REGION': 'us-west-1',
258+
'CUSTOM_KEY': 'custom_value',
259+
'ANOTHER_KEY': 'another_value',
260+
}
261+
159262
def test_add_tool_filtering_middleware(self):
160263
"""Test that tool filtering middleware is added correctly."""
161264
# Arrange

tests/unit/test_utils.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ def test_create_transport_with_sigv4(self, mock_create_sigv4_client):
3939
service = 'test-service'
4040
profile = 'test-profile'
4141
region = 'us-east-1'
42-
forwarding_region = 'us-west-2'
42+
metadata = {'AWS_REGION': 'us-west-2', 'CUSTOM_KEY': 'custom_value'}
4343
custom_timeout = Timeout(30.0)
4444

4545
result = create_transport_with_sigv4(
46-
url, service, region, forwarding_region, custom_timeout, profile
46+
url, service, region, metadata, custom_timeout, profile
4747
)
4848

4949
# Verify result is StreamableHttpTransport
@@ -64,7 +64,7 @@ def test_create_transport_with_sigv4(self, mock_create_sigv4_client):
6464
headers={'test': 'header'},
6565
timeout=custom_timeout,
6666
auth=None,
67-
metadata={'AWS_REGION': forwarding_region},
67+
metadata=metadata,
6868
)
6969
else:
7070
# If we can't access the factory directly, just verify the transport was created
@@ -78,12 +78,10 @@ def test_create_transport_with_sigv4_no_profile(self, mock_create_sigv4_client):
7878
url = 'https://test-service.us-west-2.api.aws/mcp'
7979
service = 'test-service'
8080
region = 'test-region'
81-
forwarding_region = 'test-forwarding-region'
81+
metadata = {'AWS_REGION': 'test-forwarding-region'}
8282
custom_timeout = Timeout(60.0)
8383

84-
result = create_transport_with_sigv4(
85-
url, service, region, forwarding_region, custom_timeout
86-
)
84+
result = create_transport_with_sigv4(url, service, region, metadata, custom_timeout)
8785

8886
# Test that the httpx_client_factory calls create_sigv4_client correctly
8987
# We need to access the factory through the transport's internal structure
@@ -98,7 +96,7 @@ def test_create_transport_with_sigv4_no_profile(self, mock_create_sigv4_client):
9896
headers=None,
9997
timeout=custom_timeout,
10098
auth=None,
101-
metadata={'AWS_REGION': forwarding_region},
99+
metadata=metadata,
102100
)
103101
else:
104102
# If we can't access the factory directly, just verify the transport was created

0 commit comments

Comments
 (0)