diff --git a/python/semantic_kernel/connectors/openapi_plugin/__init__.py b/python/semantic_kernel/connectors/openapi_plugin/__init__.py index 875c5155d301..a8501d4b6ee5 100644 --- a/python/semantic_kernel/connectors/openapi_plugin/__init__.py +++ b/python/semantic_kernel/connectors/openapi_plugin/__init__.py @@ -7,5 +7,11 @@ from semantic_kernel.connectors.openapi_plugin.operation_selection_predicate_context import ( OperationSelectionPredicateContext, ) +from semantic_kernel.connectors.openapi_plugin.server_url_validator import ServerUrlValidationOptions -__all__ = ["OpenAPIFunctionExecutionParameters", "OpenApiParser", "OperationSelectionPredicateContext"] +__all__ = [ + "OpenAPIFunctionExecutionParameters", + "OpenApiParser", + "OperationSelectionPredicateContext", + "ServerUrlValidationOptions", +] diff --git a/python/semantic_kernel/connectors/openapi_plugin/models/rest_api_operation.py b/python/semantic_kernel/connectors/openapi_plugin/models/rest_api_operation.py index 52e1c8aac59f..94480a81ae43 100644 --- a/python/semantic_kernel/connectors/openapi_plugin/models/rest_api_operation.py +++ b/python/semantic_kernel/connectors/openapi_plugin/models/rest_api_operation.py @@ -221,7 +221,7 @@ def build_headers(self, arguments: dict[str, Any]) -> dict[str, str]: def build_operation_url(self, arguments, server_url_override=None, api_host_url=None): """Build the URL for the operation.""" - server_url = self.get_server_url(server_url_override, api_host_url) + server_url = self.get_server_url(server_url_override, api_host_url, arguments) path = self.build_path(self.path, arguments) try: return urljoin(server_url, path.lstrip("/")) @@ -253,11 +253,11 @@ def get_server_url(self, server_url_override=None, api_host_url=None, arguments= argument_name = variable_def.get("argument_name", variable_name) if argument_name in arguments: value = arguments[argument_name] - server_url_string = server_url_string.replace(f"{{{variable_name}}}", value) + server_url_string = server_url_string.replace(f"{{{variable_name}}}", str(value)) elif "default" in variable_def and variable_def["default"] is not None: # Use the default value if no argument is provided value = variable_def["default"] - server_url_string = server_url_string.replace(f"{{{variable_name}}}", value) + server_url_string = server_url_string.replace(f"{{{variable_name}}}", str(value)) else: # Raise an exception if no value is available raise FunctionExecutionException( diff --git a/python/semantic_kernel/connectors/openapi_plugin/openapi_function_execution_parameters.py b/python/semantic_kernel/connectors/openapi_plugin/openapi_function_execution_parameters.py index 2d1ac19df68b..d22a0fd44faa 100644 --- a/python/semantic_kernel/connectors/openapi_plugin/openapi_function_execution_parameters.py +++ b/python/semantic_kernel/connectors/openapi_plugin/openapi_function_execution_parameters.py @@ -16,7 +16,13 @@ class OpenAPIFunctionExecutionParameters(KernelBaseModel): - """OpenAPI function execution parameters.""" + """OpenAPI function execution parameters. + + OpenAPI operation request URLs are validated by default to reduce SSRF risk. Requests must use HTTPS + and must not resolve to private, loopback, link-local, or otherwise non-public IP addresses unless the + target is explicitly trusted through `server_url_validation_allowed_base_urls` or + `allow_private_network_access`. + """ http_client: httpx.AsyncClient | None = None auth_callback: AuthCallbackType | None = None @@ -47,9 +53,24 @@ class OpenAPIFunctionExecutionParameters(KernelBaseModel): "and need external HTTP $ref resolution." ), ) + server_url_validation_allowed_base_urls: list[str] = Field( + default_factory=list, + description=( + "Base URLs that are explicitly allowed for OpenAPI operation requests. Matching URLs bypass " + "the default HTTPS-only and private-network validation gates. Set only for trusted endpoints." + ), + ) + allow_private_network_access: bool = Field( + False, + description=( + "Whether OpenAPI operation requests may target private, loopback, link-local, or otherwise " + "non-public IP addresses. Disabled by default to prevent SSRF." + ), + ) def model_post_init(self, __context: Any) -> None: """Post initialization method for the model.""" + from semantic_kernel.connectors.openapi_plugin.server_url_validator import ServerUrlValidationOptions from semantic_kernel.utils.telemetry.user_agent import HTTP_USER_AGENT if self.server_url_override: @@ -57,5 +78,7 @@ def model_post_init(self, __context: Any) -> None: if not parsed_url.scheme or not parsed_url.netloc: raise ValueError(f"Invalid server_url_override: {self.server_url_override}") + ServerUrlValidationOptions(allowed_base_urls=self.server_url_validation_allowed_base_urls) + if not self.user_agent: self.user_agent = HTTP_USER_AGENT diff --git a/python/semantic_kernel/connectors/openapi_plugin/openapi_manager.py b/python/semantic_kernel/connectors/openapi_plugin/openapi_manager.py index b825a1635cae..a715dac6b9a4 100644 --- a/python/semantic_kernel/connectors/openapi_plugin/openapi_manager.py +++ b/python/semantic_kernel/connectors/openapi_plugin/openapi_manager.py @@ -11,6 +11,7 @@ from semantic_kernel.connectors.openapi_plugin.models.rest_api_uri import Uri from semantic_kernel.connectors.openapi_plugin.openapi_parser import OpenApiParser from semantic_kernel.connectors.openapi_plugin.openapi_runner import OpenApiRunner +from semantic_kernel.connectors.openapi_plugin.server_url_validator import ServerUrlValidationOptions from semantic_kernel.exceptions.function_exceptions import FunctionExecutionException from semantic_kernel.functions.kernel_arguments import KernelArguments from semantic_kernel.functions.kernel_function_decorator import kernel_function @@ -79,6 +80,12 @@ def create_functions_from_openapi( http_client=execution_settings.http_client if execution_settings else None, enable_dynamic_payload=execution_settings.enable_dynamic_payload if execution_settings else True, enable_payload_namespacing=execution_settings.enable_payload_namespacing if execution_settings else False, + server_url_validation_options=ServerUrlValidationOptions( + allowed_base_urls=execution_settings.server_url_validation_allowed_base_urls, + allow_private_network_access=execution_settings.allow_private_network_access, + ) + if execution_settings + else None, ) functions = [] diff --git a/python/semantic_kernel/connectors/openapi_plugin/openapi_runner.py b/python/semantic_kernel/connectors/openapi_plugin/openapi_runner.py index d1d2db141b2b..ee88d4111ffb 100644 --- a/python/semantic_kernel/connectors/openapi_plugin/openapi_runner.py +++ b/python/semantic_kernel/connectors/openapi_plugin/openapi_runner.py @@ -17,6 +17,10 @@ from semantic_kernel.connectors.openapi_plugin.models.rest_api_operation import RestApiOperation from semantic_kernel.connectors.openapi_plugin.models.rest_api_payload import RestApiPayload from semantic_kernel.connectors.openapi_plugin.models.rest_api_run_options import RestApiRunOptions +from semantic_kernel.connectors.openapi_plugin.server_url_validator import ( + ServerUrlValidationOptions, + validate_server_url, +) from semantic_kernel.exceptions.function_exceptions import FunctionExecutionException from semantic_kernel.functions.kernel_arguments import KernelArguments from semantic_kernel.utils.feature_stage_decorator import experimental @@ -39,6 +43,7 @@ def __init__( http_client: httpx.AsyncClient | None = None, enable_dynamic_payload: bool = True, enable_payload_namespacing: bool = False, + server_url_validation_options: ServerUrlValidationOptions | None = None, ): """Initialize the OpenApiRunner.""" self.spec = Spec.from_dict(parsed_openapi_document) # type: ignore @@ -46,6 +51,7 @@ def __init__( self.http_client = http_client self.enable_dynamic_payload = enable_dynamic_payload self.enable_payload_namespacing = enable_payload_namespacing + self.server_url_validation_options = server_url_validation_options or ServerUrlValidationOptions() def build_full_url(self, base_url, query_string): """Build the full URL.""" @@ -137,6 +143,7 @@ async def run_operation( server_url_override=options.server_url_override if options else None, api_host_url=options.api_host_url if options else None, ) + await validate_server_url(url, self.server_url_validation_options) headers = operation.build_headers(arguments=arguments) payload, _ = self.build_operation_payload(operation=operation, arguments=arguments) diff --git a/python/semantic_kernel/connectors/openapi_plugin/server_url_validator.py b/python/semantic_kernel/connectors/openapi_plugin/server_url_validator.py new file mode 100644 index 000000000000..a3ede15f0f3c --- /dev/null +++ b/python/semantic_kernel/connectors/openapi_plugin/server_url_validator.py @@ -0,0 +1,241 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import ipaddress +import socket +from collections.abc import Awaitable, Callable, Sequence +from typing import Any +from urllib.parse import ParseResult, urlparse + +from pydantic import Field + +from semantic_kernel.exceptions.function_exceptions import FunctionExecutionException +from semantic_kernel.kernel_pydantic import KernelBaseModel + +DnsResolver = Callable[[str], Awaitable[Sequence[str | ipaddress.IPv4Address | ipaddress.IPv6Address]]] + +DEFAULT_ALLOWED_SCHEME = "https" + + +class ServerUrlValidationOptions(KernelBaseModel): + """Options for validating OpenAPI operation request URLs.""" + + allowed_base_urls: list[str] = Field(default_factory=list) + allow_private_network_access: bool = False + + def model_post_init(self, __context: Any) -> None: + """Validate configured allowed base URLs.""" + for allowed_base_url in self.allowed_base_urls: + _parse_absolute_url(allowed_base_url, option_name="allowed_base_urls") + + +async def validate_server_url( + url: str, + options: ServerUrlValidationOptions | None = None, + dns_resolver: DnsResolver | None = None, +) -> None: + """Validate a fully resolved OpenAPI operation URL against the supplied policy.""" + options = options or ServerUrlValidationOptions() + try: + parsed_url = _parse_absolute_url(url) + except ValueError as exc: + raise FunctionExecutionException( + f"The request URI '{url}' is not allowed because it is not a valid absolute URI." + ) from exc + + if _matches_allowed_base_url(parsed_url, options.allowed_base_urls): + return + + if options.allowed_base_urls: + raise FunctionExecutionException( + f"The request URI '{url}' is not allowed. It does not match any of the allowed base URLs." + ) + + if parsed_url.scheme.lower() != DEFAULT_ALLOWED_SCHEME: + raise FunctionExecutionException( + f"The request URI scheme '{parsed_url.scheme}' is not allowed. " + f"Only '{DEFAULT_ALLOWED_SCHEME}' is permitted by default. " + "To allow this URL, add it to server_url_validation_allowed_base_urls." + ) + + if options.allow_private_network_access: + return + + await _ensure_public_host(parsed_url, dns_resolver) + + +def try_categorize_non_public_address( + address: str | ipaddress.IPv4Address | ipaddress.IPv6Address, +) -> tuple[bool, str]: + """Return whether an IP address is non-public and the category when blocked.""" + ip_address = ipaddress.ip_address(address) + + if isinstance(ip_address, ipaddress.IPv6Address) and ip_address.ipv4_mapped: + ip_address = ip_address.ipv4_mapped + + if isinstance(ip_address, ipaddress.IPv4Address): + return _try_classify_ipv4(ip_address) + + return _try_classify_ipv6(ip_address) + + +def _parse_absolute_url(url: str, option_name: str = "url") -> ParseResult: + parsed_url = urlparse(url) + try: + parsed_url.port + except ValueError as exc: + raise ValueError(f"Invalid {option_name}: {url}") from exc + + if not parsed_url.scheme or not parsed_url.netloc or not parsed_url.hostname: + raise ValueError(f"Invalid {option_name}: {url}") + return parsed_url + + +def _matches_allowed_base_url(url: ParseResult, allowed_base_urls: list[str]) -> bool: + for allowed_base_url in allowed_base_urls: + base_url = _parse_absolute_url(allowed_base_url, option_name="allowed_base_urls") + if url.scheme.lower() != base_url.scheme.lower(): + continue + if (url.hostname or "").lower() != (base_url.hostname or "").lower(): + continue + if _effective_port(url) != _effective_port(base_url): + continue + if _matches_path_prefix(url.path, base_url.path): + return True + + return False + + +def _effective_port(url: ParseResult) -> int | None: + if url.port is not None: + return url.port + if url.scheme.lower() == "https": + return 443 + if url.scheme.lower() == "http": + return 80 + return None + + +def _matches_path_prefix(url_path: str, base_path: str) -> bool: + url_path = url_path or "/" + base_path = base_path or "/" + + if url_path.lower() == base_path.lower(): + return True + + base_path_with_slash = base_path if base_path.endswith("/") else f"{base_path}/" + return url_path.lower().startswith(base_path_with_slash.lower()) + + +async def _ensure_public_host(parsed_url: ParseResult, dns_resolver: DnsResolver | None) -> None: + host = parsed_url.hostname + if host is None: + raise FunctionExecutionException(f"The request URI '{parsed_url.geturl()}' does not contain a valid host.") + + try: + ip_address = ipaddress.ip_address(host) + except ValueError: + addresses = await _resolve_host(host, dns_resolver) + else: + _ensure_public_address(parsed_url.geturl(), ip_address) + return + + if not addresses: + raise FunctionExecutionException( + f"The request URI '{parsed_url.geturl()}' is not allowed: DNS resolution for host " + f"'{host}' returned no addresses. The request is blocked as a precaution." + ) + + for address in addresses: + _ensure_public_address(parsed_url.geturl(), address) + + +async def _resolve_host( + host: str, + dns_resolver: DnsResolver | None, +) -> list[ipaddress.IPv4Address | ipaddress.IPv6Address]: + try: + if dns_resolver: + resolved_addresses = await dns_resolver(host) + return [ipaddress.ip_address(address) for address in resolved_addresses] + + loop = asyncio.get_running_loop() + addr_info = await loop.getaddrinfo(host, None, type=socket.SOCK_STREAM) + except (OSError, ValueError) as exc: + raise FunctionExecutionException( + f"The request URI host '{host}' is not allowed: DNS resolution failed. " + "The request is blocked as a precaution to prevent potential access to private network addresses." + ) from exc + + addresses: list[ipaddress.IPv4Address | ipaddress.IPv6Address] = [] + seen_addresses: set[str] = set() + for family, _, _, _, sockaddr in addr_info: + if family not in (socket.AF_INET, socket.AF_INET6): + continue + address = ipaddress.ip_address(sockaddr[0]) + address_string = str(address) + if address_string not in seen_addresses: + addresses.append(address) + seen_addresses.add(address_string) + return addresses + + +def _ensure_public_address(url: str, address: ipaddress.IPv4Address | ipaddress.IPv6Address) -> None: + blocked, category = try_categorize_non_public_address(address) + if blocked: + raise FunctionExecutionException( + f"The request URI '{url}' is not allowed: host resolves to a {category} address ({address}), " + "which is blocked by default to prevent Server-Side Request Forgery (SSRF). " + "To allow this URL, add it to server_url_validation_allowed_base_urls or set " + "allow_private_network_access=True." + ) + + +def _try_classify_ipv4(address: ipaddress.IPv4Address) -> tuple[bool, str]: + b0, b1, b2, _ = address.packed + + if b0 == 0: + return True, "unspecified" + if b0 == 10: + return True, "private (RFC1918)" + if b0 == 127: + return True, "loopback" + if b0 == 169 and b1 == 254: + return True, "link-local" + if b0 == 172 and 16 <= b1 <= 31: + return True, "private (RFC1918)" + if b0 == 192 and b1 == 168: + return True, "private (RFC1918)" + if b0 == 100 and 64 <= b1 <= 127: + return True, "carrier-grade NAT" + if b0 == 198 and b1 in (18, 19): + return True, "benchmarking" + if b0 == 192 and b1 == 0 and b2 in (0, 2): + return True, "reserved" + if b0 == 198 and b1 == 51 and b2 == 100: + return True, "reserved" + if b0 == 203 and b1 == 0 and b2 == 113: + return True, "reserved" + if 224 <= b0 <= 239: + return True, "multicast" + if b0 >= 240: + return True, "reserved" + + return False, "" + + +def _try_classify_ipv6(address: ipaddress.IPv6Address) -> tuple[bool, str]: + if address.is_loopback: + return True, "loopback" + if address.is_unspecified: + return True, "unspecified" + if address.is_link_local: + return True, "link-local" + if address in ipaddress.ip_network("fc00::/7"): + return True, "private (IPv6 ULA)" + if address.is_multicast: + return True, "multicast" + if address in ipaddress.ip_network("2001:db8::/32"): + return True, "reserved" + + return False, "" diff --git a/python/tests/integration/cross_language/test_cross_language.py b/python/tests/integration/cross_language/test_cross_language.py index ee86e8888d64..72aa0a16bac0 100644 --- a/python/tests/integration/cross_language/test_cross_language.py +++ b/python/tests/integration/cross_language/test_cross_language.py @@ -762,6 +762,7 @@ async def mock_request(request: httpx.Request): openapi_document_path=openapi_spec_file, execution_settings=OpenAPIFunctionExecutionParameters( http_client=client, + server_url_validation_allowed_base_urls=["https://127.0.0.1"], ), ) diff --git a/python/tests/unit/connectors/openapi_plugin/test_openapi_manager.py b/python/tests/unit/connectors/openapi_plugin/test_openapi_manager.py index 37bd6b324d77..dd32f70998dd 100644 --- a/python/tests/unit/connectors/openapi_plugin/test_openapi_manager.py +++ b/python/tests/unit/connectors/openapi_plugin/test_openapi_manager.py @@ -9,11 +9,15 @@ RestApiParameterLocation, ) from semantic_kernel.connectors.openapi_plugin.models.rest_api_run_options import RestApiRunOptions +from semantic_kernel.connectors.openapi_plugin.openapi_function_execution_parameters import ( + OpenAPIFunctionExecutionParameters, +) from semantic_kernel.connectors.openapi_plugin.openapi_manager import ( _create_function_from_operation, create_functions_from_openapi, ) from semantic_kernel.connectors.openapi_plugin.openapi_runner import OpenApiRunner +from semantic_kernel.connectors.openapi_plugin.server_url_validator import ServerUrlValidationOptions from semantic_kernel.exceptions import FunctionExecutionException from semantic_kernel.functions.kernel_function_decorator import kernel_function from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata @@ -236,7 +240,10 @@ async def test_run_operation_uses_timeout_from_run_options(): "info": {"title": "Test", "version": "1.0.0"}, "paths": {}, } - runner = OpenApiRunner(parsed_openapi_document=minimal_openapi_spec) + runner = OpenApiRunner( + parsed_openapi_document=minimal_openapi_spec, + server_url_validation_options=ServerUrlValidationOptions(allowed_base_urls=["https://api.example.com"]), + ) operation = MagicMock() operation.method = "GET" operation.build_headers.return_value = {} @@ -274,3 +281,31 @@ async def test_run_operation_uses_timeout_from_run_options(): found = True break assert found, f"httpx.AsyncClient was not called with timeout={desired_timeout}" + + +@patch("semantic_kernel.connectors.openapi_plugin.openapi_manager.OpenApiRunner") +@patch("semantic_kernel.connectors.openapi_plugin.openapi_manager.OpenApiParser") +def test_create_functions_from_openapi_propagates_server_url_validation_settings(mock_parser_class, mock_runner_class): + parsed_doc = { + "openapi": "3.0.0", + "info": {"title": "Test", "version": "1.0.0"}, + "paths": {}, + } + mock_parser = MagicMock() + mock_parser.parse.return_value = parsed_doc + mock_parser.create_rest_api_operations.return_value = {} + mock_parser_class.return_value = mock_parser + execution_settings = OpenAPIFunctionExecutionParameters( + server_url_validation_allowed_base_urls=["http://192.168.1.100/v1"], + allow_private_network_access=True, + ) + + create_functions_from_openapi( + plugin_name="test_plugin", + openapi_document_path="test_openapi_document_path", + execution_settings=execution_settings, + ) + + validation_options = mock_runner_class.call_args.kwargs["server_url_validation_options"] + assert validation_options.allowed_base_urls == ["http://192.168.1.100/v1"] + assert validation_options.allow_private_network_access is True diff --git a/python/tests/unit/connectors/openapi_plugin/test_openapi_runner.py b/python/tests/unit/connectors/openapi_plugin/test_openapi_runner.py index 1665314a903a..b990ebfb834e 100644 --- a/python/tests/unit/connectors/openapi_plugin/test_openapi_runner.py +++ b/python/tests/unit/connectors/openapi_plugin/test_openapi_runner.py @@ -8,6 +8,7 @@ from semantic_kernel.connectors.openapi_plugin.models.rest_api_operation import RestApiOperation from semantic_kernel.connectors.openapi_plugin.models.rest_api_payload import RestApiPayload from semantic_kernel.connectors.openapi_plugin.openapi_manager import OpenApiRunner +from semantic_kernel.connectors.openapi_plugin.server_url_validator import ServerUrlValidationOptions from semantic_kernel.exceptions import FunctionExecutionException @@ -294,7 +295,9 @@ def test_get_first_response_media_type_default(): async def test_run_operation(): - runner = OpenApiRunner({}) + runner = OpenApiRunner( + {}, server_url_validation_options=ServerUrlValidationOptions(allowed_base_urls=["http://example.com"]) + ) operation = MagicMock() arguments = {} options = MagicMock() @@ -323,3 +326,40 @@ async def mock_request(*args, **kwargs): result = await runner.run_operation(operation, arguments, options) assert result == "response text" + + +async def test_run_operation_blocks_disallowed_url_before_request(): + runner = OpenApiRunner({}) + operation = MagicMock() + operation.method = "GET" + runner.build_operation_url = MagicMock(return_value="https://127.0.0.1/latest/meta-data/") + runner.http_client = AsyncMock() + runner.http_client.request = AsyncMock() + + with pytest.raises(FunctionExecutionException, match="loopback"): + await runner.run_operation(operation, {}, None) + + operation.build_headers.assert_not_called() + runner.http_client.request.assert_not_called() + + +async def test_run_operation_blocks_server_variable_ssrf_before_request(): + runner = OpenApiRunner({}) + operation = RestApiOperation( + id="getCloudMetadata", + method="GET", + servers=[ + { + "url": "https://{api_server}/", + "variables": {"api_server": {"default": "api.example.com"}}, + } + ], + path="latest/meta-data/", + ) + runner.http_client = AsyncMock() + runner.http_client.request = AsyncMock() + + with pytest.raises(FunctionExecutionException, match="link-local"): + await runner.run_operation(operation, {"api_server": "169.254.169.254"}, None) + + runner.http_client.request.assert_not_called() diff --git a/python/tests/unit/connectors/openapi_plugin/test_server_url_validator.py b/python/tests/unit/connectors/openapi_plugin/test_server_url_validator.py new file mode 100644 index 000000000000..177ff15cb2c6 --- /dev/null +++ b/python/tests/unit/connectors/openapi_plugin/test_server_url_validator.py @@ -0,0 +1,170 @@ +# Copyright (c) Microsoft. All rights reserved. + +import socket + +import pytest + +from semantic_kernel.connectors.openapi_plugin.server_url_validator import ( + ServerUrlValidationOptions, + try_categorize_non_public_address, + validate_server_url, +) +from semantic_kernel.exceptions import FunctionExecutionException + + +@pytest.mark.parametrize( + ("address", "expected_category"), + [ + ("127.0.0.1", "loopback"), + ("127.255.255.254", "loopback"), + ("169.254.169.254", "link-local"), + ("169.254.0.1", "link-local"), + ("10.0.0.1", "private (RFC1918)"), + ("172.16.0.1", "private (RFC1918)"), + ("172.31.255.255", "private (RFC1918)"), + ("192.168.0.1", "private (RFC1918)"), + ("100.64.0.1", "carrier-grade NAT"), + ("100.127.255.254", "carrier-grade NAT"), + ("0.0.0.0", "unspecified"), + ("224.0.0.1", "multicast"), + ("239.255.255.255", "multicast"), + ("240.0.0.1", "reserved"), + ("255.255.255.255", "reserved"), + ("198.18.0.1", "benchmarking"), + ("192.0.2.1", "reserved"), + ("198.51.100.1", "reserved"), + ("203.0.113.1", "reserved"), + ("::1", "loopback"), + ("::", "unspecified"), + ("fe80::1", "link-local"), + ("fc00::1", "private (IPv6 ULA)"), + ("fd00::1", "private (IPv6 ULA)"), + ("ff02::1", "multicast"), + ("2001:db8::1", "reserved"), + ("::ffff:127.0.0.1", "loopback"), + ("::ffff:169.254.169.254", "link-local"), + ], +) +def test_try_categorize_non_public_address(address: str, expected_category: str): + blocked, category = try_categorize_non_public_address(address) + + assert blocked is True + assert category == expected_category + + +@pytest.mark.parametrize( + "address", + [ + "8.8.8.8", + "1.1.1.1", + "93.184.216.34", + "172.15.255.255", + "172.32.0.1", + "11.0.0.1", + "192.169.0.1", + "100.63.255.255", + "100.128.0.1", + "2606:4700:4700::1111", + ], +) +def test_try_categorize_non_public_address_allows_public_addresses(address: str): + blocked, category = try_categorize_non_public_address(address) + + assert blocked is False + assert category == "" + + +async def test_validate_server_url_rejects_literal_link_local_ipv4(): + with pytest.raises(FunctionExecutionException, match="link-local"): + await validate_server_url("https://169.254.169.254/latest/meta-data/") + + +async def test_validate_server_url_rejects_literal_loopback_ipv6(): + with pytest.raises(FunctionExecutionException, match="loopback"): + await validate_server_url("https://[::1]/") + + +async def test_validate_server_url_rejects_http_scheme_by_default(): + with pytest.raises(FunctionExecutionException, match="scheme"): + await validate_server_url("http://api.example.com/") + + +async def test_validate_server_url_allows_public_https_literal_by_default(): + await validate_server_url("https://1.1.1.1/") + + +async def test_validate_server_url_rejects_invalid_uri_with_function_execution_exception(): + with pytest.raises(FunctionExecutionException, match="not a valid absolute URI"): + await validate_server_url("invalid_url") + + +async def test_validate_server_url_allows_explicit_base_url_for_private_http_address(): + options = ServerUrlValidationOptions(allowed_base_urls=["http://192.168.1.100/v1"]) + + await validate_server_url("http://192.168.1.100/v1/orders", options) + + +async def test_validate_server_url_rejects_when_allowed_base_urls_do_not_match(): + options = ServerUrlValidationOptions(allowed_base_urls=["https://api.example.com/v1"]) + + with pytest.raises(FunctionExecutionException, match="allowed base URLs"): + await validate_server_url("https://api.example.com/v2/orders", options) + + +async def test_validate_server_url_allows_private_network_access_after_scheme_gate(): + options = ServerUrlValidationOptions(allow_private_network_access=True) + + await validate_server_url("https://10.0.0.5/", options) + + +async def test_validate_server_url_blocks_hostname_resolving_to_link_local(): + async def fake_resolver(host: str): + assert host == "evil.example.com" + return ["169.254.169.254"] + + with pytest.raises(FunctionExecutionException, match="link-local"): + await validate_server_url("https://evil.example.com/latest/meta-data/", dns_resolver=fake_resolver) + + +async def test_validate_server_url_blocks_hostname_resolving_to_loopback(): + async def fake_resolver(host: str): + assert host == "attacker.example.com" + return ["127.0.0.1"] + + with pytest.raises(FunctionExecutionException, match="loopback"): + await validate_server_url("https://attacker.example.com/api", dns_resolver=fake_resolver) + + +async def test_validate_server_url_blocks_when_any_resolved_address_is_private(): + async def fake_resolver(host: str): + assert host == "rebind.example.com" + return ["93.184.216.34", "10.0.0.1"] + + with pytest.raises(FunctionExecutionException, match="private"): + await validate_server_url("https://rebind.example.com/", dns_resolver=fake_resolver) + + +async def test_validate_server_url_allows_hostname_resolving_to_public_ip(): + async def fake_resolver(host: str): + assert host == "api.example.com" + return ["93.184.216.34"] + + await validate_server_url("https://api.example.com/", dns_resolver=fake_resolver) + + +async def test_validate_server_url_blocks_dns_resolution_failure(): + async def fake_resolver(host: str): + assert host == "unreachable.example.com" + raise socket.gaierror() + + with pytest.raises(FunctionExecutionException, match="DNS resolution"): + await validate_server_url("https://unreachable.example.com/", dns_resolver=fake_resolver) + + +async def test_validate_server_url_blocks_empty_dns_response(): + async def fake_resolver(host: str): + assert host == "empty-dns.example.com" + return [] + + with pytest.raises(FunctionExecutionException, match="returned no addresses"): + await validate_server_url("https://empty-dns.example.com/", dns_resolver=fake_resolver) diff --git a/python/tests/unit/connectors/openapi_plugin/test_sk_openapi.py b/python/tests/unit/connectors/openapi_plugin/test_sk_openapi.py index c59be28e9c55..1b44733562c9 100644 --- a/python/tests/unit/connectors/openapi_plugin/test_sk_openapi.py +++ b/python/tests/unit/connectors/openapi_plugin/test_sk_openapi.py @@ -342,6 +342,23 @@ def test_get_server_url_with_servers_and_variables(): assert operation.get_server_url(arguments=arguments) == expected_url +def test_get_server_url_with_servers_coerces_variable_argument_to_string(): + operation = RestApiOperation( + id="test", + method="GET", + servers=[ + { + "url": "https://example.com/{version}", + "variables": {"version": {"default": "v1", "argument_name": "api_version"}}, + } + ], + path="/resource/{id}", + ) + arguments = {"api_version": 2} + expected_url = "https://example.com/2/" + assert operation.get_server_url(arguments=arguments) == expected_url + + def test_get_server_url_with_servers_and_default_variable(): operation = RestApiOperation( id="test", @@ -353,6 +370,17 @@ def test_get_server_url_with_servers_and_default_variable(): assert operation.get_server_url() == expected_url +def test_get_server_url_with_servers_coerces_default_variable_to_string(): + operation = RestApiOperation( + id="test", + method="GET", + servers=[{"url": "https://example.com/{version}", "variables": {"version": {"default": 1}}}], + path="/resource/{id}", + ) + expected_url = "https://example.com/1/" + assert operation.get_server_url() == expected_url + + def test_get_server_url_with_override(): operation = RestApiOperation( id="test", @@ -896,3 +924,8 @@ def test_invalid_server_url_override(): with pytest.raises(ValueError, match="Invalid server_url_override: invalid_url"): params = OpenAPIFunctionExecutionParameters(server_url_override="invalid_url") params.model_post_init(None) + + +def test_invalid_server_url_validation_allowed_base_url(): + with pytest.raises(ValueError, match="Invalid allowed_base_urls: invalid_url"): + OpenAPIFunctionExecutionParameters(server_url_validation_allowed_base_urls=["invalid_url"])