From 33dfe053f8ed220659b2f473b3deb8fcd211cb98 Mon Sep 17 00:00:00 2001 From: Weizhou Xing <169175349+wzxxing@users.noreply.github.com> Date: Thu, 27 Nov 2025 19:37:20 +0100 Subject: [PATCH] feat: pass client info such as `kiro/1.0.0` to user-agent --- mcp_proxy_for_aws/context.py | 32 ++++++ mcp_proxy_for_aws/middleware/client_info.py | 39 +++++++ mcp_proxy_for_aws/middleware/tool_filter.py | 10 +- mcp_proxy_for_aws/server.py | 18 ++- mcp_proxy_for_aws/sigv4_helper.py | 9 ++ tests/unit/test_client_info_middleware.py | 116 ++++++++++++++++++++ tests/unit/test_hooks.py | 41 +++++++ tests/unit/test_tool_filter.py | 2 - 8 files changed, 256 insertions(+), 11 deletions(-) create mode 100644 mcp_proxy_for_aws/context.py create mode 100644 mcp_proxy_for_aws/middleware/client_info.py create mode 100644 tests/unit/test_client_info_middleware.py diff --git a/mcp_proxy_for_aws/context.py b/mcp_proxy_for_aws/context.py new file mode 100644 index 0000000..3ebefc4 --- /dev/null +++ b/mcp_proxy_for_aws/context.py @@ -0,0 +1,32 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module-level storage for session-scoped data.""" + +from mcp.types import Implementation +from typing import Optional + + +_client_info: Optional[Implementation] = None + + +def get_client_info() -> Optional[Implementation]: + """Get the stored client info.""" + return _client_info + + +def set_client_info(info: Optional[Implementation]) -> None: + """Set the client info.""" + global _client_info + _client_info = info diff --git a/mcp_proxy_for_aws/middleware/client_info.py b/mcp_proxy_for_aws/middleware/client_info.py new file mode 100644 index 0000000..b08078b --- /dev/null +++ b/mcp_proxy_for_aws/middleware/client_info.py @@ -0,0 +1,39 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections.abc import Awaitable, Callable +from fastmcp.server.middleware import Middleware, MiddlewareContext +from mcp import types as mt +from mcp_proxy_for_aws.context import set_client_info + + +logger = logging.getLogger(__name__) + + +class ClientInfoMiddleware(Middleware): + """Middleware to capture client_info from initialize method.""" + + async def on_initialize( + self, + context: MiddlewareContext[mt.InitializeRequest], + call_next: Callable[[MiddlewareContext[mt.InitializeRequest]], Awaitable[None]], + ) -> None: + """Capture client_info from initialize request.""" + if context.message.params and context.message.params.clientInfo: + info = context.message.params.clientInfo + set_client_info(info) + logger.info('Captured client_info: name=%s, version=%s', info.name, info.version) + + await call_next(context) diff --git a/mcp_proxy_for_aws/middleware/tool_filter.py b/mcp_proxy_for_aws/middleware/tool_filter.py index 1cdbe87..ec9845f 100644 --- a/mcp_proxy_for_aws/middleware/tool_filter.py +++ b/mcp_proxy_for_aws/middleware/tool_filter.py @@ -19,13 +19,15 @@ from typing import Sequence +logger = logging.getLogger(__name__) + + class ToolFilteringMiddleware(Middleware): """Middleware to filter tools based on read only flag.""" - def __init__(self, read_only: bool, logger: logging.Logger | None = None): + def __init__(self, read_only: bool): """Initialize the middleware.""" self.read_only = read_only - self.logger = logger or logging.getLogger(__name__) async def on_list_tools( self, @@ -35,7 +37,7 @@ async def on_list_tools( """Filter tools based on read only flag.""" # Get list of FastMCP Components tools = await call_next(context) - self.logger.info('Filtering tools for read only: %s', self.read_only) + logger.info('Filtering tools for read only: %s', self.read_only) # If not read only, return the list of tools as is if not self.read_only: @@ -50,7 +52,7 @@ async def on_list_tools( read_only_hint = getattr(annotations, 'readOnlyHint', False) if not read_only_hint: # Skip tools that don't have readOnlyHint=True - self.logger.info('Skipping tool %s needing write permissions', tool.name) + logger.info('Skipping tool %s needing write permissions', tool.name) continue filtered_tools.append(tool) diff --git a/mcp_proxy_for_aws/server.py b/mcp_proxy_for_aws/server.py index 0438bdb..cbf1aa7 100644 --- a/mcp_proxy_for_aws/server.py +++ b/mcp_proxy_for_aws/server.py @@ -42,6 +42,7 @@ ) from mcp_proxy_for_aws.cli import parse_args from mcp_proxy_for_aws.logging_config import configure_logging +from mcp_proxy_for_aws.middleware.client_info import ClientInfoMiddleware from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware from mcp_proxy_for_aws.utils import ( create_transport_with_sigv4, @@ -167,6 +168,7 @@ async def client_factory(): 'This proxy handles authentication and request routing to the appropriate backend services.' ), ) + add_client_info_middleware(proxy) add_logging_middleware(proxy, args.log_level) add_tool_filtering_middleware(proxy, args.read_only) @@ -178,6 +180,16 @@ async def client_factory(): raise e +def add_client_info_middleware(mcp: FastMCP) -> None: + """Add client info middleware to capture client_info from initialize. + + Args: + mcp: The FastMCP instance to add client info middleware to + """ + logger.info('Adding client info middleware') + mcp.add_middleware(ClientInfoMiddleware()) + + def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None: """Add tool filtering middleware to target MCP server. @@ -186,11 +198,7 @@ def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None read_only: Whether or not to filter out tools that require write permissions """ logger.info('Adding tool filtering middleware') - mcp.add_middleware( - ToolFilteringMiddleware( - read_only=read_only, - ) - ) + mcp.add_middleware(ToolFilteringMiddleware(read_only)) def add_retry_middleware(mcp: FastMCP, retries: int) -> None: diff --git a/mcp_proxy_for_aws/sigv4_helper.py b/mcp_proxy_for_aws/sigv4_helper.py index 570df15..9df1c30 100644 --- a/mcp_proxy_for_aws/sigv4_helper.py +++ b/mcp_proxy_for_aws/sigv4_helper.py @@ -22,6 +22,8 @@ from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials from functools import partial +from mcp_proxy_for_aws import __version__ +from mcp_proxy_for_aws.context import get_client_info from typing import Any, Dict, Generator, Optional @@ -228,6 +230,13 @@ async def _sign_request_hook( # Set Content-Length for signing request.headers['Content-Length'] = str(len(request.content)) + # Build User-Agent from client_info if available + info = get_client_info() + if info: + user_agent = f'{info.name}/{info.version} mcp-proxy-for-aws/{__version__}' + request.headers['User-Agent'] = user_agent + logger.info('Set User-Agent header: %s', user_agent) + # Get AWS credentials session = create_aws_session(profile) credentials = session.get_credentials() diff --git a/tests/unit/test_client_info_middleware.py b/tests/unit/test_client_info_middleware.py new file mode 100644 index 0000000..79a9ae0 --- /dev/null +++ b/tests/unit/test_client_info_middleware.py @@ -0,0 +1,116 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from datetime import datetime +from fastmcp.server.middleware import MiddlewareContext +from mcp import types as mt +from mcp_proxy_for_aws.context import get_client_info, set_client_info +from mcp_proxy_for_aws.middleware.client_info import ClientInfoMiddleware + + +@pytest.fixture +def middleware(): + """Create a ClientInfoMiddleware instance.""" + return ClientInfoMiddleware() + + +@pytest.fixture +def mock_context_with_client_info(): + """Create a mock context with client_info.""" + params = mt.InitializeRequestParams( + protocolVersion='2024-11-05', + capabilities=mt.ClientCapabilities(), + clientInfo=mt.Implementation(name='test-client', version='1.0.0'), + ) + message = mt.InitializeRequest( + method='initialize', + params=params, + ) + return MiddlewareContext( + message=message, + fastmcp_context=None, + source='client', + type='request', + method='initialize', + timestamp=datetime.now(), + ) + + +@pytest.mark.asyncio +async def test_captures_client_info(middleware, mock_context_with_client_info): + """Test that middleware captures client_info from initialize request.""" + # Reset context variable + set_client_info(None) + + async def call_next(ctx): + pass + + await middleware.on_initialize(mock_context_with_client_info, call_next) + + # Verify client_info was captured + info = get_client_info() + assert info is not None + assert info.name == 'test-client' + assert info.version == '1.0.0' + + +@pytest.mark.asyncio +async def test_calls_next_middleware(middleware, mock_context_with_client_info): + """Test that middleware calls the next middleware in chain.""" + called = False + + async def call_next(ctx): + nonlocal called + called = True + + await middleware.on_initialize(mock_context_with_client_info, call_next) + + assert called is True + + +@pytest.mark.asyncio +async def test_captures_different_client_info(middleware): + """Test that middleware captures different client_info values.""" + # Reset context variable + set_client_info(None) + + params = mt.InitializeRequestParams( + protocolVersion='2024-11-05', + capabilities=mt.ClientCapabilities(), + clientInfo=mt.Implementation(name='another-client', version='2.5.3'), + ) + message = mt.InitializeRequest( + method='initialize', + params=params, + ) + context = MiddlewareContext( + message=message, + fastmcp_context=None, + source='client', + type='request', + method='initialize', + timestamp=datetime.now(), + ) + + async def call_next(ctx): + pass + + await middleware.on_initialize(context, call_next) + + # Verify client_info was captured with correct values + info = get_client_info() + assert info is not None + assert info.name == 'another-client' + assert info.version == '2.5.3' diff --git a/tests/unit/test_hooks.py b/tests/unit/test_hooks.py index 51b039a..4e7a5f0 100644 --- a/tests/unit/test_hooks.py +++ b/tests/unit/test_hooks.py @@ -435,3 +435,44 @@ async def test_sign_request_hook_with_partial_application(self, mock_create_sess assert 'authorization' in request.headers assert 'x-amz-date' in request.headers mock_create_session.assert_called_once_with(profile) + + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') + @pytest.mark.asyncio + async def test_sign_request_hook_sets_user_agent_from_client_info(self, mock_create_session): + """Test that sign_request_hook sets User-Agent from client_info context.""" + from mcp.types import Implementation + from mcp_proxy_for_aws import __version__ + from mcp_proxy_for_aws.context import set_client_info + + mock_create_session.return_value = create_mock_session() + + # Set client_info in context + info = Implementation(name='test-client', version='2.5.0') + set_client_info(info) + + request = httpx.Request('POST', 'https://example.com/mcp', content=b'test') + await _sign_request_hook('us-east-1', 'bedrock-agentcore', None, request) + + assert ( + request.headers['user-agent'] == f'test-client/2.5.0 mcp-proxy-for-aws/{__version__}' + ) + + # Clean up + set_client_info(None) + + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') + @pytest.mark.asyncio + async def test_sign_request_hook_without_client_info(self, mock_create_session): + """Test that sign_request_hook works without client_info.""" + from mcp_proxy_for_aws.context import set_client_info + + mock_create_session.return_value = create_mock_session() + + # Ensure client_info is None + set_client_info(None) + + request = httpx.Request('POST', 'https://example.com/mcp', content=b'test') + await _sign_request_hook('us-east-1', 'bedrock-agentcore', None, request) + + assert 'user-agent' not in request.headers + assert 'authorization' in request.headers diff --git a/tests/unit/test_tool_filter.py b/tests/unit/test_tool_filter.py index 7c3ab49..b418b01 100644 --- a/tests/unit/test_tool_filter.py +++ b/tests/unit/test_tool_filter.py @@ -40,7 +40,6 @@ def test_constructor_read_only_false(self): # Assert assert middleware.read_only is False - assert middleware.logger is not None def test_constructor_read_only_true(self): """Test constructor with read_only=True.""" @@ -49,7 +48,6 @@ def test_constructor_read_only_true(self): # Assert assert middleware.read_only is True - assert middleware.logger is not None class TestOnListTools: