Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 55 additions & 48 deletions python/semantic_kernel/connectors/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ async def connect(self) -> None:
try:
self._current_task = asyncio.create_task(self._inner_connect(ready_event))
await ready_event.wait()
if self._current_task.done():
await self._current_task
except KernelPluginInvalidConfigurationError:
ready_event.clear()
raise
Expand All @@ -314,55 +316,60 @@ async def close(self) -> None:
self.session = None

async def _inner_connect(self, ready_event: asyncio.Event) -> None:
if not self.session:
try:
transport = await self._exit_stack.enter_async_context(self.get_mcp_client())
except Exception as ex:
await self._exit_stack.aclose()
ready_event.set()
raise KernelPluginInvalidConfigurationError(
"Failed to connect to the MCP server. Please check your configuration."
) from ex
try:
session = await self._exit_stack.enter_async_context(
ClientSession(
read_stream=transport[0],
write_stream=transport[1],
read_timeout_seconds=timedelta(seconds=self.request_timeout) if self.request_timeout else None,
message_handler=self.message_handler,
logging_callback=self.logging_callback,
sampling_callback=self.sampling_callback,
try:
if not self.session:
try:
transport = await self._exit_stack.enter_async_context(self.get_mcp_client())
except Exception as ex:
await self._exit_stack.aclose()
raise KernelPluginInvalidConfigurationError(
"Failed to connect to the MCP server. Please check your configuration."
) from ex
try:
session = await self._exit_stack.enter_async_context(
ClientSession(
read_stream=transport[0],
write_stream=transport[1],
read_timeout_seconds=(
timedelta(seconds=self.request_timeout) if self.request_timeout else None
),
message_handler=self.message_handler,
logging_callback=self.logging_callback,
sampling_callback=self.sampling_callback,
)
)
)
except Exception as ex:
await self._exit_stack.aclose()
raise KernelPluginInvalidConfigurationError(
"Failed to create a session. Please check your configuration."
) from ex
try:
await session.initialize()
except Exception as ex:
await self._exit_stack.aclose()
raise KernelPluginInvalidConfigurationError(
"Failed to initialize session. Please check your configuration."
) from ex
self.session = session
elif self.session._request_id == 0:
# If the session is not initialized, we need to reinitialize it
await self.session.initialize()
logger.debug("Connected to MCP server: %s", self.session)
if self.load_tools_flag:
await self.load_tools()
if self.load_prompts_flag:
await self.load_prompts()

if logger.level != logging.NOTSET:
try:
await self.session.set_logging_level(
next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level)
)
except Exception:
logger.warning("Failed to set log level to %s", logger.level)
except Exception as ex:
await self._exit_stack.aclose()
raise KernelPluginInvalidConfigurationError(
"Failed to create a session. Please check your configuration."
) from ex
try:
await session.initialize()
except Exception as ex:
await self._exit_stack.aclose()
raise KernelPluginInvalidConfigurationError(
"Failed to initialize session. Please check your configuration."
) from ex
self.session = session
elif self.session._request_id == 0:
# If the session is not initialized, we need to reinitialize it
await self.session.initialize()
logger.debug("Connected to MCP server: %s", self.session)
if self.load_tools_flag:
await self.load_tools()
if self.load_prompts_flag:
await self.load_prompts()

if logger.level != logging.NOTSET:
try:
await self.session.set_logging_level(
next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level)
)
except Exception:
logger.warning("Failed to set log level to %s", logger.level)
except Exception:
ready_event.set()
raise
# Setting up is complete, will now signal the main loop that we are ready
ready_event.set()
# Create a stop event to signal the exit stack to close
Expand Down
26 changes: 26 additions & 0 deletions python/tests/unit/connectors/mcp/test_mcp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import logging
import re
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -228,6 +229,31 @@ async def test_mcp_plugin_failed_get_session():
pass


@patch("semantic_kernel.connectors.mcp.streamablehttp_client")
@patch("semantic_kernel.connectors.mcp.ClientSession")
async def test_streamable_http_initialization_failure_unblocks_connect(mock_session, mock_client):
mock_read = MagicMock()
mock_write = MagicMock()
mock_callback = MagicMock()

mock_generator = MagicMock()
mock_generator.__aenter__.return_value = (mock_read, mock_write, mock_callback)
mock_generator.__aexit__.return_value = (mock_read, mock_write, mock_callback)
mock_client.return_value = mock_generator

mock_session.return_value.__aenter__.return_value.initialize.side_effect = RuntimeError("Unauthorized")

plugin = MCPStreamableHttpPlugin(
name="test",
url="http://localhost:8080/mcp",
load_tools=False,
load_prompts=False,
)

with pytest.raises(KernelPluginInvalidConfigurationError, match="Failed to initialize session"):
await asyncio.wait_for(plugin.connect(), timeout=1)


@patch("semantic_kernel.connectors.mcp.stdio_client")
@patch("semantic_kernel.connectors.mcp.ClientSession")
async def test_with_kwargs_stdio(mock_session, mock_client, list_tool_calls, kernel: "Kernel"):
Expand Down
Loading