diff --git a/python/semantic_kernel/connectors/mcp.py b/python/semantic_kernel/connectors/mcp.py index 6d7f8d2e182d..1ff15b722ff8 100644 --- a/python/semantic_kernel/connectors/mcp.py +++ b/python/semantic_kernel/connectors/mcp.py @@ -293,6 +293,8 @@ async def connect(self) -> None: try: self._current_task = asyncio.create_task(self._inner_connect(ready_event)) await ready_event.wait() + if self._current_task.done(): + await self._current_task except KernelPluginInvalidConfigurationError: ready_event.clear() raise @@ -314,55 +316,60 @@ async def close(self) -> None: self.session = None async def _inner_connect(self, ready_event: asyncio.Event) -> None: - if not self.session: - try: - transport = await self._exit_stack.enter_async_context(self.get_mcp_client()) - except Exception as ex: - await self._exit_stack.aclose() - ready_event.set() - raise KernelPluginInvalidConfigurationError( - "Failed to connect to the MCP server. Please check your configuration." - ) from ex - try: - session = await self._exit_stack.enter_async_context( - ClientSession( - read_stream=transport[0], - write_stream=transport[1], - read_timeout_seconds=timedelta(seconds=self.request_timeout) if self.request_timeout else None, - message_handler=self.message_handler, - logging_callback=self.logging_callback, - sampling_callback=self.sampling_callback, + try: + if not self.session: + try: + transport = await self._exit_stack.enter_async_context(self.get_mcp_client()) + except Exception as ex: + await self._exit_stack.aclose() + raise KernelPluginInvalidConfigurationError( + "Failed to connect to the MCP server. Please check your configuration." + ) from ex + try: + session = await self._exit_stack.enter_async_context( + ClientSession( + read_stream=transport[0], + write_stream=transport[1], + read_timeout_seconds=( + timedelta(seconds=self.request_timeout) if self.request_timeout else None + ), + message_handler=self.message_handler, + logging_callback=self.logging_callback, + sampling_callback=self.sampling_callback, + ) ) - ) - except Exception as ex: - await self._exit_stack.aclose() - raise KernelPluginInvalidConfigurationError( - "Failed to create a session. Please check your configuration." - ) from ex - try: - await session.initialize() - except Exception as ex: - await self._exit_stack.aclose() - raise KernelPluginInvalidConfigurationError( - "Failed to initialize session. Please check your configuration." - ) from ex - self.session = session - elif self.session._request_id == 0: - # If the session is not initialized, we need to reinitialize it - await self.session.initialize() - logger.debug("Connected to MCP server: %s", self.session) - if self.load_tools_flag: - await self.load_tools() - if self.load_prompts_flag: - await self.load_prompts() - - if logger.level != logging.NOTSET: - try: - await self.session.set_logging_level( - next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level) - ) - except Exception: - logger.warning("Failed to set log level to %s", logger.level) + except Exception as ex: + await self._exit_stack.aclose() + raise KernelPluginInvalidConfigurationError( + "Failed to create a session. Please check your configuration." + ) from ex + try: + await session.initialize() + except Exception as ex: + await self._exit_stack.aclose() + raise KernelPluginInvalidConfigurationError( + "Failed to initialize session. Please check your configuration." + ) from ex + self.session = session + elif self.session._request_id == 0: + # If the session is not initialized, we need to reinitialize it + await self.session.initialize() + logger.debug("Connected to MCP server: %s", self.session) + if self.load_tools_flag: + await self.load_tools() + if self.load_prompts_flag: + await self.load_prompts() + + if logger.level != logging.NOTSET: + try: + await self.session.set_logging_level( + next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level) + ) + except Exception: + logger.warning("Failed to set log level to %s", logger.level) + except Exception: + ready_event.set() + raise # Setting up is complete, will now signal the main loop that we are ready ready_event.set() # Create a stop event to signal the exit stack to close diff --git a/python/tests/unit/connectors/mcp/test_mcp.py b/python/tests/unit/connectors/mcp/test_mcp.py index dc8ea38330d3..fa8d4c0868db 100644 --- a/python/tests/unit/connectors/mcp/test_mcp.py +++ b/python/tests/unit/connectors/mcp/test_mcp.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +import asyncio import logging import re from typing import TYPE_CHECKING @@ -228,6 +229,31 @@ async def test_mcp_plugin_failed_get_session(): pass +@patch("semantic_kernel.connectors.mcp.streamablehttp_client") +@patch("semantic_kernel.connectors.mcp.ClientSession") +async def test_streamable_http_initialization_failure_unblocks_connect(mock_session, mock_client): + mock_read = MagicMock() + mock_write = MagicMock() + mock_callback = MagicMock() + + mock_generator = MagicMock() + mock_generator.__aenter__.return_value = (mock_read, mock_write, mock_callback) + mock_generator.__aexit__.return_value = (mock_read, mock_write, mock_callback) + mock_client.return_value = mock_generator + + mock_session.return_value.__aenter__.return_value.initialize.side_effect = RuntimeError("Unauthorized") + + plugin = MCPStreamableHttpPlugin( + name="test", + url="http://localhost:8080/mcp", + load_tools=False, + load_prompts=False, + ) + + with pytest.raises(KernelPluginInvalidConfigurationError, match="Failed to initialize session"): + await asyncio.wait_for(plugin.connect(), timeout=1) + + @patch("semantic_kernel.connectors.mcp.stdio_client") @patch("semantic_kernel.connectors.mcp.ClientSession") async def test_with_kwargs_stdio(mock_session, mock_client, list_tool_calls, kernel: "Kernel"):