Skip to content

Commit c3909e2

Browse files
author
Kyon Caldera
committed
feat(server.py): add forwarding region as optional argument
1 parent 80e2441 commit c3909e2

File tree

6 files changed

+43
-12
lines changed

6 files changed

+43
-12
lines changed

mcp_proxy_for_aws/cli.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,16 @@ def parse_args():
6060

6161
parser.add_argument(
6262
'--region',
63-
help='AWS region to use (uses AWS_REGION environment variable if not provided, with final fallback to us-east-1)',
63+
help='AWS region to sign (uses AWS_REGION environment variable if not provided, with final fallback to us-east-1)',
6464
default=None,
6565
)
6666

67+
parser.add_argument(
68+
'--forwarding-region',
69+
help='AWS region to forward to server (uses --region if not provided)',
70+
default=None,
71+
)
72+
6773
parser.add_argument(
6874
'--read-only',
6975
action='store_true',

mcp_proxy_for_aws/server.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,20 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
5252

5353
# Validate and determine region
5454
region = determine_aws_region(args.endpoint, args.region)
55+
forwarding_region = args.forwarding_region or region
5556
logger.debug('Using region: %s', region)
5657

5758
# Get profile
5859
profile = args.profile
5960

6061
# Log server configuration
61-
logger.info('Using service: %s, region: %s, profile: %s', service, region, profile)
62+
logger.info(
63+
'Using service: %s, region: %s, forwarding region: %s, profile: %s',
64+
service,
65+
region,
66+
forwarding_region,
67+
profile,
68+
)
6269
logger.info('Running in MCP mode')
6370

6471
timeout = httpx.Timeout(
@@ -69,7 +76,9 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
6976
)
7077

7178
# Create transport with SigV4 authentication
72-
transport = create_transport_with_sigv4(args.endpoint, service, region, timeout, profile)
79+
transport = create_transport_with_sigv4(
80+
args.endpoint, service, region, forwarding_region, timeout, profile
81+
)
7382

7483
# Create proxy with the transport
7584
proxy = FastMCP.as_proxy(transport)

mcp_proxy_for_aws/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def create_transport_with_sigv4(
3232
url: str,
3333
service: str,
3434
region: str,
35+
forwarding_region: str,
3536
custom_timeout: httpx.Timeout,
3637
profile: Optional[str] = None,
3738
) -> StreamableHttpTransport:
@@ -41,6 +42,7 @@ def create_transport_with_sigv4(
4142
url: The endpoint URL
4243
service: AWS service name for SigV4 signing
4344
region: AWS region to use
45+
forwarding_region: AWS region to forward to server
4446
custom_timeout: httpx.Timeout used to connect to the endpoint
4547
profile: AWS profile to use (optional)
4648
@@ -60,7 +62,7 @@ def client_factory(
6062
region=region,
6163
headers=headers,
6264
timeout=custom_timeout,
63-
metadata={'AWS_REGION': region},
65+
metadata={'AWS_REGION': forwarding_region},
6466
auth=auth,
6567
)
6668

tests/integ/test_proxy_simple_mcp_server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,7 @@ async def test_metadata_injection_aws_region(
124124
assert 'AWS_REGION' in response_data['received_meta'], (
125125
f'Metadata should contain AWS_REGION: {response_data["received_meta"]}'
126126
)
127+
assert (
128+
response_data['received_meta']['AWS_REGION']
129+
== remote_mcp_server_configuration['region_name']
130+
), f'AWS_REGION should be {remote_mcp_server_configuration["region_name"]}'

tests/unit/test_server.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ async def test_setup_mcp_mode(
5656
mock_args.profile = None
5757
mock_args.read_only = True
5858
mock_args.retries = 1
59+
mock_args.forwarding_region = None
5960
# Add timeout parameters
6061
mock_args.timeout = 180.0
6162
mock_args.connect_timeout = 60.0
@@ -86,8 +87,9 @@ async def test_setup_mcp_mode(
8687
assert call_args[0][0] == 'https://test.example.com'
8788
assert call_args[0][1] == 'test-service'
8889
assert call_args[0][2] == 'us-east-1'
89-
# call_args[0][3] is the Timeout object
90-
assert call_args[0][4] is None # profile
90+
assert call_args[0][3] == 'us-east-1' # forwarding_region (defaults to region)
91+
# call_args[0][4] is the Timeout object
92+
assert call_args[0][5] is None # profile
9193
mock_as_proxy.assert_called_once_with(mock_transport)
9294
mock_add_filtering.assert_called_once_with(mock_proxy, True)
9395
mock_add_retry.assert_called_once_with(mock_proxy, 1)
@@ -116,6 +118,7 @@ async def test_setup_mcp_mode_no_retries(
116118
mock_args.profile = 'test-profile'
117119
mock_args.read_only = False
118120
mock_args.retries = 0 # No retries
121+
mock_args.forwarding_region = 'eu-west-1'
119122
# Add timeout parameters
120123
mock_args.timeout = 180.0
121124
mock_args.connect_timeout = 60.0
@@ -146,8 +149,9 @@ async def test_setup_mcp_mode_no_retries(
146149
assert call_args[0][0] == 'https://test.example.com'
147150
assert call_args[0][1] == 'test-service'
148151
assert call_args[0][2] == 'us-east-1'
149-
# call_args[0][3] is the Timeout object
150-
assert call_args[0][4] == 'test-profile' # profile
152+
assert call_args[0][3] == 'eu-west-1' # forwarding_region
153+
# call_args[0][4] is the Timeout object
154+
assert call_args[0][5] == 'test-profile' # profile
151155
mock_as_proxy.assert_called_once_with(mock_transport)
152156
mock_add_filtering.assert_called_once_with(mock_proxy, False)
153157
mock_proxy.run_async.assert_called_once()

tests/unit/test_utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,12 @@ def test_create_transport_with_sigv4(self, mock_create_sigv4_client):
3939
service = 'test-service'
4040
profile = 'test-profile'
4141
region = 'us-east-1'
42+
forwarding_region = 'us-west-2'
4243
custom_timeout = Timeout(30.0)
4344

44-
result = create_transport_with_sigv4(url, service, region, custom_timeout, profile)
45+
result = create_transport_with_sigv4(
46+
url, service, region, forwarding_region, custom_timeout, profile
47+
)
4548

4649
# Verify result is StreamableHttpTransport
4750
assert isinstance(result, StreamableHttpTransport)
@@ -61,7 +64,7 @@ def test_create_transport_with_sigv4(self, mock_create_sigv4_client):
6164
headers={'test': 'header'},
6265
timeout=custom_timeout,
6366
auth=None,
64-
metadata={'AWS_REGION': region},
67+
metadata={'AWS_REGION': forwarding_region},
6568
)
6669
else:
6770
# If we can't access the factory directly, just verify the transport was created
@@ -75,9 +78,12 @@ def test_create_transport_with_sigv4_no_profile(self, mock_create_sigv4_client):
7578
url = 'https://test-service.us-west-2.api.aws/mcp'
7679
service = 'test-service'
7780
region = 'test-region'
81+
forwarding_region = 'test-forwarding-region'
7882
custom_timeout = Timeout(60.0)
7983

80-
result = create_transport_with_sigv4(url, service, region, custom_timeout)
84+
result = create_transport_with_sigv4(
85+
url, service, region, forwarding_region, custom_timeout
86+
)
8187

8288
# Test that the httpx_client_factory calls create_sigv4_client correctly
8389
# We need to access the factory through the transport's internal structure
@@ -92,7 +98,7 @@ def test_create_transport_with_sigv4_no_profile(self, mock_create_sigv4_client):
9298
headers=None,
9399
timeout=custom_timeout,
94100
auth=None,
95-
metadata={'AWS_REGION': region},
101+
metadata={'AWS_REGION': forwarding_region},
96102
)
97103
else:
98104
# If we can't access the factory directly, just verify the transport was created

0 commit comments

Comments
 (0)