Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(
*,
sampling_capabilities: types.SamplingCapability | None = None,
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
capability_extensions: dict[str, Any] | None = None,
) -> None:
super().__init__(
read_stream,
Expand All @@ -143,6 +144,10 @@ def __init__(
# Experimental: Task handlers (use defaults if not provided)
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()

# Capability extensions to include in initialize request
# These are passed as the 'extensions' field of ClientCapabilities
self._capability_extensions = capability_extensions

async def initialize(self) -> types.InitializeResult:
sampling = (
(self._sampling_capabilities or types.SamplingCapability())
Expand Down Expand Up @@ -177,6 +182,7 @@ async def initialize(self) -> types.InitializeResult:
experimental=None,
roots=roots,
tasks=self._task_handlers.build_capability(),
extensions=self._capability_extensions or None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's already dict[str, Any] | None, the or is not needed, is it?

),
client_info=self._client_info,
),
Expand Down
2 changes: 2 additions & 0 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ class ClientCapabilities(MCPModel):
"""Present if the client supports listing roots."""
tasks: ClientTasksCapability | None = None
"""Present if the client supports task-augmented requests."""
extensions: dict[str, Any] | None = None
"""Protocol extensions advertised by the client."""


class PromptsCapability(MCPModel):
Expand Down
71 changes: 71 additions & 0 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,3 +768,74 @@ async def mock_server():
await session.initialize()

await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta)


@pytest.mark.anyio
async def test_client_session_capability_extensions():
"""Test that capability_extensions are included in the initialize request."""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

received_capabilities = None

# Define capability extensions (e.g., UI extension)
# These will be placed under the 'extensions' key of ClientCapabilities
capability_extensions = {"io.modelcontextprotocol/ui": {"mimeTypes": ["text/html;profile=mcp-app"]}}

async def mock_server():
nonlocal received_capabilities

session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
request = ClientRequest.model_validate(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
)
assert isinstance(request.root, InitializeRequest)
received_capabilities = request.root.params.capabilities

result = ServerResult(
InitializeResult(
protocol_version=LATEST_PROTOCOL_VERSION,
capabilities=ServerCapabilities(),
server_info=Implementation(name="mock-server", version="0.1.0"),
)
)

async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
)
# Receive initialized notification
await client_to_server_receive.receive()

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
capability_extensions=capability_extensions,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)
await session.initialize()

# Assert that the capability extensions were included in the request
assert received_capabilities is not None
# The extensions should be present under the 'extensions' key
caps_dict = received_capabilities.model_dump()
assert "extensions" in caps_dict
assert "io.modelcontextprotocol/ui" in caps_dict["extensions"]
assert caps_dict["extensions"]["io.modelcontextprotocol/ui"]["mimeTypes"] == ["text/html;profile=mcp-app"]