Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions mcp_proxy_for_aws/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module-level storage for session-scoped data."""

from mcp.types import Implementation
from typing import Optional


_client_info: Optional[Implementation] = None


def get_client_info() -> Optional[Implementation]:
"""Get the stored client info."""
return _client_info


def set_client_info(info: Optional[Implementation]) -> None:
"""Set the client info."""
global _client_info
_client_info = info
39 changes: 39 additions & 0 deletions mcp_proxy_for_aws/middleware/client_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections.abc import Awaitable, Callable
from fastmcp.server.middleware import Middleware, MiddlewareContext
from mcp import types as mt
from mcp_proxy_for_aws.context import set_client_info


logger = logging.getLogger(__name__)


class ClientInfoMiddleware(Middleware):
"""Middleware to capture client_info from initialize method."""

async def on_initialize(
self,
context: MiddlewareContext[mt.InitializeRequest],
call_next: Callable[[MiddlewareContext[mt.InitializeRequest]], Awaitable[None]],
) -> None:
"""Capture client_info from initialize request."""
if context.message.params and context.message.params.clientInfo:
info = context.message.params.clientInfo
set_client_info(info)
logger.info('Captured client_info: name=%s, version=%s', info.name, info.version)

await call_next(context)
10 changes: 6 additions & 4 deletions mcp_proxy_for_aws/middleware/tool_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
from typing import Sequence


logger = logging.getLogger(__name__)


class ToolFilteringMiddleware(Middleware):
"""Middleware to filter tools based on read only flag."""

def __init__(self, read_only: bool, logger: logging.Logger | None = None):
def __init__(self, read_only: bool):
"""Initialize the middleware."""
self.read_only = read_only
self.logger = logger or logging.getLogger(__name__)

async def on_list_tools(
self,
Expand All @@ -35,7 +37,7 @@ async def on_list_tools(
"""Filter tools based on read only flag."""
# Get list of FastMCP Components
tools = await call_next(context)
self.logger.info('Filtering tools for read only: %s', self.read_only)
logger.info('Filtering tools for read only: %s', self.read_only)

# If not read only, return the list of tools as is
if not self.read_only:
Expand All @@ -50,7 +52,7 @@ async def on_list_tools(
read_only_hint = getattr(annotations, 'readOnlyHint', False)
if not read_only_hint:
# Skip tools that don't have readOnlyHint=True
self.logger.info('Skipping tool %s needing write permissions', tool.name)
logger.info('Skipping tool %s needing write permissions', tool.name)
continue

filtered_tools.append(tool)
Expand Down
18 changes: 13 additions & 5 deletions mcp_proxy_for_aws/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)
from mcp_proxy_for_aws.cli import parse_args
from mcp_proxy_for_aws.logging_config import configure_logging
from mcp_proxy_for_aws.middleware.client_info import ClientInfoMiddleware
from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware
from mcp_proxy_for_aws.utils import (
create_transport_with_sigv4,
Expand Down Expand Up @@ -167,6 +168,7 @@ async def client_factory():
'This proxy handles authentication and request routing to the appropriate backend services.'
),
)
add_client_info_middleware(proxy)
add_logging_middleware(proxy, args.log_level)
add_tool_filtering_middleware(proxy, args.read_only)

Expand All @@ -178,6 +180,16 @@ async def client_factory():
raise e


def add_client_info_middleware(mcp: FastMCP) -> None:
"""Add client info middleware to capture client_info from initialize.

Args:
mcp: The FastMCP instance to add client info middleware to
"""
logger.info('Adding client info middleware')
mcp.add_middleware(ClientInfoMiddleware())


def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None:
"""Add tool filtering middleware to target MCP server.

Expand All @@ -186,11 +198,7 @@ def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None
read_only: Whether or not to filter out tools that require write permissions
"""
logger.info('Adding tool filtering middleware')
mcp.add_middleware(
ToolFilteringMiddleware(
read_only=read_only,
)
)
mcp.add_middleware(ToolFilteringMiddleware(read_only))


def add_retry_middleware(mcp: FastMCP, retries: int) -> None:
Expand Down
9 changes: 9 additions & 0 deletions mcp_proxy_for_aws/sigv4_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
from functools import partial
from mcp_proxy_for_aws import __version__
from mcp_proxy_for_aws.context import get_client_info
from typing import Any, Dict, Generator, Optional


Expand Down Expand Up @@ -228,6 +230,13 @@ async def _sign_request_hook(
# Set Content-Length for signing
request.headers['Content-Length'] = str(len(request.content))

# Build User-Agent from client_info if available
info = get_client_info()
if info:
user_agent = f'{info.name}/{info.version} mcp-proxy-for-aws/{__version__}'
request.headers['User-Agent'] = user_agent
logger.info('Set User-Agent header: %s', user_agent)

# Get AWS credentials
session = create_aws_session(profile)
credentials = session.get_credentials()
Expand Down
116 changes: 116 additions & 0 deletions tests/unit/test_client_info_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from datetime import datetime
from fastmcp.server.middleware import MiddlewareContext
from mcp import types as mt
from mcp_proxy_for_aws.context import get_client_info, set_client_info
from mcp_proxy_for_aws.middleware.client_info import ClientInfoMiddleware


@pytest.fixture
def middleware():
"""Create a ClientInfoMiddleware instance."""
return ClientInfoMiddleware()


@pytest.fixture
def mock_context_with_client_info():
"""Create a mock context with client_info."""
params = mt.InitializeRequestParams(
protocolVersion='2024-11-05',
capabilities=mt.ClientCapabilities(),
clientInfo=mt.Implementation(name='test-client', version='1.0.0'),
)
message = mt.InitializeRequest(
method='initialize',
params=params,
)
return MiddlewareContext(
message=message,
fastmcp_context=None,
source='client',
type='request',
method='initialize',
timestamp=datetime.now(),
)


@pytest.mark.asyncio
async def test_captures_client_info(middleware, mock_context_with_client_info):
"""Test that middleware captures client_info from initialize request."""
# Reset context variable
set_client_info(None)

async def call_next(ctx):
pass

await middleware.on_initialize(mock_context_with_client_info, call_next)

# Verify client_info was captured
info = get_client_info()
assert info is not None
assert info.name == 'test-client'
assert info.version == '1.0.0'


@pytest.mark.asyncio
async def test_calls_next_middleware(middleware, mock_context_with_client_info):
"""Test that middleware calls the next middleware in chain."""
called = False

async def call_next(ctx):
nonlocal called
called = True

await middleware.on_initialize(mock_context_with_client_info, call_next)

assert called is True


@pytest.mark.asyncio
async def test_captures_different_client_info(middleware):
"""Test that middleware captures different client_info values."""
# Reset context variable
set_client_info(None)

params = mt.InitializeRequestParams(
protocolVersion='2024-11-05',
capabilities=mt.ClientCapabilities(),
clientInfo=mt.Implementation(name='another-client', version='2.5.3'),
)
message = mt.InitializeRequest(
method='initialize',
params=params,
)
context = MiddlewareContext(
message=message,
fastmcp_context=None,
source='client',
type='request',
method='initialize',
timestamp=datetime.now(),
)

async def call_next(ctx):
pass

await middleware.on_initialize(context, call_next)

# Verify client_info was captured with correct values
info = get_client_info()
assert info is not None
assert info.name == 'another-client'
assert info.version == '2.5.3'
41 changes: 41 additions & 0 deletions tests/unit/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,3 +435,44 @@ async def test_sign_request_hook_with_partial_application(self, mock_create_sess
assert 'authorization' in request.headers
assert 'x-amz-date' in request.headers
mock_create_session.assert_called_once_with(profile)

@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
@pytest.mark.asyncio
async def test_sign_request_hook_sets_user_agent_from_client_info(self, mock_create_session):
"""Test that sign_request_hook sets User-Agent from client_info context."""
from mcp.types import Implementation
from mcp_proxy_for_aws import __version__
from mcp_proxy_for_aws.context import set_client_info

mock_create_session.return_value = create_mock_session()

# Set client_info in context
info = Implementation(name='test-client', version='2.5.0')
set_client_info(info)

request = httpx.Request('POST', 'https://example.com/mcp', content=b'test')
await _sign_request_hook('us-east-1', 'bedrock-agentcore', None, request)

assert (
request.headers['user-agent'] == f'test-client/2.5.0 mcp-proxy-for-aws/{__version__}'
)

# Clean up
set_client_info(None)

@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
@pytest.mark.asyncio
async def test_sign_request_hook_without_client_info(self, mock_create_session):
"""Test that sign_request_hook works without client_info."""
from mcp_proxy_for_aws.context import set_client_info

mock_create_session.return_value = create_mock_session()

# Ensure client_info is None
set_client_info(None)

request = httpx.Request('POST', 'https://example.com/mcp', content=b'test')
await _sign_request_hook('us-east-1', 'bedrock-agentcore', None, request)

assert 'user-agent' not in request.headers
assert 'authorization' in request.headers
2 changes: 0 additions & 2 deletions tests/unit/test_tool_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def test_constructor_read_only_false(self):

# Assert
assert middleware.read_only is False
assert middleware.logger is not None

def test_constructor_read_only_true(self):
"""Test constructor with read_only=True."""
Expand All @@ -49,7 +48,6 @@ def test_constructor_read_only_true(self):

# Assert
assert middleware.read_only is True
assert middleware.logger is not None


class TestOnListTools:
Expand Down
Loading