diff --git a/msal/application.py b/msal/application.py index ba16df83..5d86927f 100644 --- a/msal/application.py +++ b/msal/application.py @@ -11,7 +11,12 @@ from .oauth2cli import Client, JwtAssertionCreator from .oauth2cli.oidc import decode_part -from .authority import Authority, WORLD_WIDE +from .authority import ( + Authority, + WORLD_WIDE, + _get_instance_discovery_endpoint, + _get_instance_discovery_host, +) from .mex import send_request as mex_send_request from .wstrust_request import send_request as wst_send_request from .wstrust_response import * @@ -671,7 +676,7 @@ def __init__( self._region_detected = None self.client, self._regional_client = self._build_client( client_credential, self.authority) - self.authority_groups = None + self.authority_groups = {} self._telemetry_buffer = {} self._telemetry_lock = Lock() _msal_extension_check() @@ -1304,9 +1309,16 @@ def _find_msal_accounts(self, environment): } return list(grouped_accounts.values()) - def _get_instance_metadata(self): # This exists so it can be mocked in unit test + def _get_instance_metadata(self, instance): # This exists so it can be mocked in unit test + instance_discovery_host = _get_instance_discovery_host(instance) resp = self.http_client.get( - "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", # TBD: We may extend this to use self._instance_discovery endpoint + _get_instance_discovery_endpoint(instance), + params={ + 'api-version': '1.1', + 'authorization_endpoint': ( + "https://{}/common/oauth2/authorize".format(instance_discovery_host) + ), + }, headers={'Accept': 'application/json'}) resp.raise_for_status() return json.loads(resp.text)['metadata'] @@ -1318,10 +1330,10 @@ def _get_authority_aliases(self, instance): # Then it is an ADFS/B2C/known_authority_hosts situation # which may not reach the central endpoint, so we skip it. return [] - if not self.authority_groups: - self.authority_groups = [ - set(group['aliases']) for group in self._get_instance_metadata()] - for group in self.authority_groups: + if instance not in self.authority_groups: + self.authority_groups[instance] = [ + set(group['aliases']) for group in self._get_instance_metadata(instance)] + for group in self.authority_groups[instance]: if instance in group: return [alias for alias in group if alias != instance] return [] diff --git a/msal/authority.py b/msal/authority.py index b114831f..6bc9e816 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -11,38 +11,28 @@ # Endpoints were copied from here # https://docs.microsoft.com/en-us/azure/active-directory/develop/authentication-national-cloud#azure-ad-authentication-endpoints AZURE_US_GOVERNMENT = "login.microsoftonline.us" -AZURE_CHINA = "login.chinacloudapi.cn" +DEPRECATED_AZURE_CHINA = "login.chinacloudapi.cn" AZURE_PUBLIC = "login.microsoftonline.com" +AZURE_GOV_FR = "login.sovcloud-identity.fr" +AZURE_GOV_DE = "login.sovcloud-identity.de" +AZURE_GOV_SG = "login.sovcloud-identity.sg" WORLD_WIDE = 'login.microsoftonline.com' # There was an alias login.windows.net -WELL_KNOWN_AUTHORITY_HOSTS = set([ +WELL_KNOWN_AUTHORITY_HOSTS = frozenset([ WORLD_WIDE, - AZURE_CHINA, - 'login-us.microsoftonline.com', - AZURE_US_GOVERNMENT, - ]) - -# Trusted issuer hosts for OIDC issuer validation -# Includes all well-known Microsoft identity provider hosts and national clouds -TRUSTED_ISSUER_HOSTS = frozenset([ - # Global/Public cloud - "login.microsoftonline.com", "login.microsoft.com", "login.windows.net", "sts.windows.net", - # China cloud - "login.chinacloudapi.cn", + DEPRECATED_AZURE_CHINA, "login.partner.microsoftonline.cn", - # Germany cloud (legacy) - "login.microsoftonline.de", - # US Government clouds - "login.microsoftonline.us", + "login.microsoftonline.de", # deprecated + 'login-us.microsoftonline.com', + AZURE_US_GOVERNMENT, "login.usgovcloudapi.net", - "login-us.microsoftonline.com", - "https://login.sovcloud-identity.fr", # AzureBleu - "https://login.sovcloud-identity.de", # AzureDelos - "https://login.sovcloud-identity.sg", # AzureGovSG -]) + AZURE_GOV_FR, + AZURE_GOV_DE, + AZURE_GOV_SG, + ]) WELL_KNOWN_B2C_HOSTS = [ "b2clogin.com", @@ -54,6 +44,15 @@ _CIAM_DOMAIN_SUFFIX = ".ciamlogin.com" +def _get_instance_discovery_host(instance): + return instance if instance in WELL_KNOWN_AUTHORITY_HOSTS else WORLD_WIDE + + +def _get_instance_discovery_endpoint(instance): + return 'https://{}/common/discovery/instance'.format( + _get_instance_discovery_host(instance)) + + class AuthorityBuilder(object): def __init__(self, instance, tenant): """A helper to save caller from doing string concatenation. @@ -162,10 +161,8 @@ def _initialize_entra_authority( ) or (len(parts) == 3 and parts[2].lower().startswith("b2c_")) self._is_known_to_developer = self.is_adfs or self._is_b2c or not validate_authority is_known_to_microsoft = self.instance in WELL_KNOWN_AUTHORITY_HOSTS - instance_discovery_endpoint = 'https://{}/common/discovery/instance'.format( # Note: This URL seemingly returns V1 endpoint only - WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too - # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103 - # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33 + instance_discovery_endpoint = _get_instance_discovery_endpoint( # Note: This URL seemingly returns V1 endpoint only + self.instance ) if instance_discovery in (None, True) else instance_discovery if instance_discovery_endpoint and not ( is_known_to_microsoft or self._is_known_to_developer): @@ -177,8 +174,8 @@ def _initialize_entra_authority( if payload.get("error") == "invalid_instance": raise ValueError( "invalid_instance: " - "The authority you provided, %s, is not whitelisted. " - "If it is indeed your legit customized domain name, " + "The authority you provided, %s, is not known. " + "If it is a valid domain name known to you, " "you can turn off this check by passing in " "instance_discovery=False" % authority_url) @@ -235,7 +232,7 @@ def has_valid_issuer(self): return False # Case 2: Issuer is from a trusted Microsoft host - O(1) lookup - if issuer_host in TRUSTED_ISSUER_HOSTS: + if issuer_host in WELL_KNOWN_AUTHORITY_HOSTS: return True # Case 3: Regional variant check - O(1) lookup @@ -245,7 +242,7 @@ def has_valid_issuer(self): potential_base = issuer_host[dot_index + 1:] if "." not in issuer_host[:dot_index]: # 3a: Base host is a trusted Microsoft host - if potential_base in TRUSTED_ISSUER_HOSTS: + if potential_base in WELL_KNOWN_AUTHORITY_HOSTS: return True # 3b: Issuer has a region prefix on the authority host # e.g. issuer=us.someweb.com, authority=someweb.com diff --git a/tests/http_client.py b/tests/http_client.py index 34d430f0..88c71180 100644 --- a/tests/http_client.py +++ b/tests/http_client.py @@ -40,3 +40,44 @@ def raise_for_status(self): if self._raw_resp is not None: # Turns out `if requests.response` won't work # cause it would be True when 200<=status<400 self._raw_resp.raise_for_status() + + +class RecordingHttpClient(object): + def __init__(self): + self.get_calls = [] + self.post_calls = [] + self._get_routes = [] + self._post_routes = [] + + def add_get_route(self, matcher, responder): + self._get_routes.append((matcher, responder)) + + def add_post_route(self, matcher, responder): + self._post_routes.append((matcher, responder)) + + def get(self, url, params=None, headers=None, **kwargs): + call = { + "url": url, + "params": params, + "headers": headers, + "kwargs": kwargs, + } + self.get_calls.append(call) + for matcher, responder in self._get_routes: + if matcher(call): + return responder(call) + return MinimalResponse(status_code=404, text="") + + def post(self, url, params=None, data=None, headers=None, **kwargs): + call = { + "url": url, + "params": params, + "data": data, + "headers": headers, + "kwargs": kwargs, + } + self.post_calls.append(call) + for matcher, responder in self._post_routes: + if matcher(call): + return responder(call) + return MinimalResponse(status_code=404, text="") diff --git a/tests/test_authority.py b/tests/test_authority.py index 481b03a7..d19a8b4b 100644 --- a/tests/test_authority.py +++ b/tests/test_authority.py @@ -6,7 +6,7 @@ import msal from msal.authority import * -from msal.authority import _CIAM_DOMAIN_SUFFIX, TRUSTED_ISSUER_HOSTS # Explicitly import private/new constants +from msal.authority import _CIAM_DOMAIN_SUFFIX from tests import unittest from tests.http_client import MinimalHttpClient @@ -37,10 +37,90 @@ def _test_authority_builder(self, host, tenant): c.close() def test_wellknown_host_and_tenant(self): - # Assert all well known authority hosts are using their own "common" tenant + # This test makes real HTTP calls to authority endpoints. + # It is intentionally network-based to validate reachable hosts end-to-end. + excluded_hosts = { + DEPRECATED_AZURE_CHINA, + "login.microsoftonline.de", # deprecated + "login.microsoft.com", # issuer-only in this test context + "login.windows.net", # issuer-only in this test context + "sts.windows.net", # issuer-only in this test context + "login.partner.microsoftonline.cn", # issuer-only in this test context + "login.usgovcloudapi.net", # issuer-only in this test context + AZURE_GOV_FR, # currently unreachable in this environment + AZURE_GOV_DE, # currently unreachable in this environment + AZURE_GOV_SG, # currently unreachable in this environment + } + for host in WELL_KNOWN_AUTHORITY_HOSTS: + if host in excluded_hosts: + continue + self._test_given_host_and_tenant(host, "common") + + @patch("msal.authority._instance_discovery") + @patch("msal.authority.tenant_discovery") + def test_new_sovereign_hosts_should_build_authority_endpoints( + self, tenant_discovery_mock, instance_discovery_mock): + for host in WELL_KNOWN_AUTHORITY_HOSTS: + tenant_discovery_mock.return_value = { + "authorization_endpoint": "https://{}/common/oauth2/v2.0/authorize".format(host), + "token_endpoint": "https://{}/common/oauth2/v2.0/token".format(host), + "issuer": "https://{}/common/v2.0".format(host), + } + instance_discovery_mock.return_value = { + "tenant_discovery_endpoint": ( + "https://{}/common/v2.0/.well-known/openid-configuration".format(host) + ), + } + c = MinimalHttpClient() + a = Authority(AuthorityBuilder(host, "common"), c) + self.assertEqual( + a.authorization_endpoint, + "https://{}/common/oauth2/v2.0/authorize".format(host)) + self.assertEqual( + a.token_endpoint, + "https://{}/common/oauth2/v2.0/token".format(host)) + c.close() + + @patch("msal.authority._instance_discovery") + @patch("msal.authority.tenant_discovery") + def test_known_authority_should_use_same_host_and_skip_instance_discovery( + self, tenant_discovery_mock, instance_discovery_mock): for host in WELL_KNOWN_AUTHORITY_HOSTS: - if host != AZURE_CHINA: # It is prone to ConnectionError - self._test_given_host_and_tenant(host, "common") + tenant_discovery_mock.return_value = { + "authorization_endpoint": "https://{}/common/oauth2/v2.0/authorize".format(host), + "token_endpoint": "https://{}/common/oauth2/v2.0/token".format(host), + "issuer": "https://{}/common/v2.0".format(host), + } + c = MinimalHttpClient() + Authority("https://{}/common".format(host), c) + c.close() + + instance_discovery_mock.assert_not_called() + tenant_discovery_endpoint = tenant_discovery_mock.call_args[0][0] + self.assertTrue( + tenant_discovery_endpoint.startswith( + "https://{}/common/v2.0/.well-known/openid-configuration".format(host))) + + @patch("msal.authority._instance_discovery") + @patch("msal.authority.tenant_discovery") + def test_unknown_authority_should_use_world_wide_instance_discovery_endpoint( + self, tenant_discovery_mock, instance_discovery_mock): + tenant_discovery_mock.return_value = { + "authorization_endpoint": "https://example.com/tenant/oauth2/v2.0/authorize", + "token_endpoint": "https://example.com/tenant/oauth2/v2.0/token", + "issuer": "https://example.com/tenant/v2.0", + } + instance_discovery_mock.return_value = { + "tenant_discovery_endpoint": "https://example.com/tenant/v2.0/.well-known/openid-configuration", + } + + c = MinimalHttpClient() + Authority("https://example.com/tenant", c) + c.close() + + self.assertEqual( + "https://{}/common/discovery/instance".format(WORLD_WIDE), + instance_discovery_mock.call_args[0][2]) def test_wellknown_host_and_tenant_using_new_authority_builder(self): self._test_authority_builder(AZURE_PUBLIC, "consumers") @@ -276,7 +356,24 @@ def test_by_default_a_known_to_microsoft_authority_should_skip_validation_but_st app = msal.ClientApplication("id", authority="https://login.microsoftonline.com/common") known_to_microsoft_validation.assert_not_called() app.get_accounts() # This could make an instance metadata call for authority aliases - instance_metadata.assert_called_once_with() + instance_metadata.assert_called_once_with("login.microsoftonline.com") + + def test_by_default_a_sovereign_known_authority_should_use_cloud_local_instance_metadata( + self, instance_metadata, known_to_microsoft_validation, _): + app = msal.ClientApplication("id", authority="https://login.microsoftonline.us/common") + known_to_microsoft_validation.assert_not_called() + app.get_accounts() # This could make an instance metadata call for authority aliases + instance_metadata.assert_called_once_with("login.microsoftonline.us") + + def test_fr_known_authority_should_still_work_when_instance_metadata_has_no_alias_entry( + self, instance_metadata, known_to_microsoft_validation, _): + app = msal.ClientApplication("id", authority="https://{}/common".format(AZURE_GOV_FR)) + known_to_microsoft_validation.assert_not_called() + + accounts = app.get_accounts() + + self.assertEqual([], accounts) + instance_metadata.assert_called_once_with(AZURE_GOV_FR) def test_validate_authority_boolean_should_skip_validation_and_instance_metadata( self, instance_metadata, known_to_microsoft_validation, _): diff --git a/tests/test_recording_http_client.py b/tests/test_recording_http_client.py new file mode 100644 index 00000000..c6288deb --- /dev/null +++ b/tests/test_recording_http_client.py @@ -0,0 +1,155 @@ +import json + +from tests import unittest +from tests.http_client import RecordingHttpClient, MinimalResponse +from msal.application import ConfidentialClientApplication + + +class TestSovereignAuthorityForClientCredentialWithRecordingHttpClient(unittest.TestCase): + def test_acquire_token_for_client_on_gov_fr_should_keep_calls_on_same_host(self): + host = "login.sovcloud-identity.fr" + expected_instance_discovery_url = "https://{}/common/discovery/instance".format(host) + expected_instance_discovery_params = { + "api-version": "1.1", + "authorization_endpoint": ( + "https://{}/common/oauth2/authorize".format(host) + ), + } + + http_client = RecordingHttpClient() + + def is_oidc_discovery(call): + return call["url"].startswith( + "https://{}/common/v2.0/.well-known/openid-configuration".format(host)) + + def oidc_discovery_response(_call): + return MinimalResponse(status_code=200, text=json.dumps({ + "authorization_endpoint": "https://{}/common/oauth2/v2.0/authorize".format(host), + "token_endpoint": "https://{}/common/oauth2/v2.0/token".format(host), + "issuer": "https://{}/common/v2.0".format(host), + })) + + def is_instance_discovery(call): + return ( + call["url"] == expected_instance_discovery_url + and call["params"] == expected_instance_discovery_params + ) + + def instance_discovery_response(_call): + return MinimalResponse(status_code=200, text=json.dumps({ + "tenant_discovery_endpoint": ( + "https://login.microsoftonline.us/" + "cab8a31a-1906-4287-a0d8-4eef66b95f6e/" + "v2.0/.well-known/openid-configuration" + ), + "api-version": "1.1", + "metadata": [ + { + "preferred_network": "login.microsoftonline.com", + "preferred_cache": "login.windows.net", + "aliases": [ + "login.microsoftonline.com", + "login.windows.net", + "login.microsoft.com", + "sts.windows.net", + ], + }, + { + "preferred_network": "login.partner.microsoftonline.cn", + "preferred_cache": "login.partner.microsoftonline.cn", + "aliases": [ + "login.partner.microsoftonline.cn", + "login.chinacloudapi.cn", + ], + }, + { + "preferred_network": "login.microsoftonline.de", + "preferred_cache": "login.microsoftonline.de", + "aliases": ["login.microsoftonline.de"], + }, + { + "preferred_network": "login.microsoftonline.us", + "preferred_cache": "login.microsoftonline.us", + "aliases": [ + "login.microsoftonline.us", + "login.usgovcloudapi.net", + ], + }, + { + "preferred_network": "login-us.microsoftonline.com", + "preferred_cache": "login-us.microsoftonline.com", + "aliases": ["login-us.microsoftonline.com"], + }, + ], + })) + + token_counter = {"value": 0} + + def is_token_call(call): + return call["url"].startswith("https://{}/common/oauth2/v2.0/token".format(host)) + + def token_response(_call): + token_counter["value"] += 1 + return MinimalResponse(status_code=200, text=json.dumps({ + "access_token": "AT_{}".format(token_counter["value"]), + "expires_in": 3600, + })) + + http_client.add_get_route(is_oidc_discovery, oidc_discovery_response) + http_client.add_get_route(is_instance_discovery, instance_discovery_response) + http_client.add_post_route(is_token_call, token_response) + + app = ConfidentialClientApplication( + "client_id", + client_credential="secret", + authority="https://{}/common".format(host), + http_client=http_client, + ) + + result1 = app.acquire_token_for_client(["scope1"]) + self.assertEqual("AT_1", result1.get("access_token")) + + get_calls_after_first = list(http_client.get_calls) + post_calls_after_first = list(http_client.post_calls) + + result2 = app.acquire_token_for_client(["scope2"]) + self.assertEqual("AT_2", result2.get("access_token")) + + post_count_after_scope2 = len(http_client.post_calls) + get_count_after_scope2 = len(http_client.get_calls) + + cached_result1 = app.acquire_token_for_client(["scope1"]) + self.assertEqual("AT_1", cached_result1.get("access_token")) + + cached_result2 = app.acquire_token_for_client(["scope2"]) + self.assertEqual("AT_2", cached_result2.get("access_token")) + + cached_result3 = app.acquire_token_for_client(["scope1"]) + self.assertEqual("AT_1", cached_result3.get("access_token")) + + self.assertEqual( + post_count_after_scope2, + len(http_client.post_calls), + "Subsequent same-scope calls should be served from cache without token POST") + self.assertEqual( + get_count_after_scope2, + len(http_client.get_calls), + "Subsequent same-authority calls should not trigger additional discovery GET") + + self.assertEqual(1, len(get_calls_after_first), "First acquire should trigger one discovery GET") + self.assertTrue( + get_calls_after_first[0]["url"].startswith( + "https://{}/common/v2.0/.well-known/openid-configuration".format(host))) + + self.assertEqual(1, len(post_calls_after_first), "First acquire should trigger one token POST") + self.assertTrue( + post_calls_after_first[0]["url"].startswith("https://{}/common/oauth2/v2.0/token".format(host))) + + self.assertEqual(1, len(http_client.get_calls), "Second acquire on same authority should not re-discover") + self.assertEqual(2, len(http_client.post_calls), "Second acquire with a different scope should request another token") + self.assertTrue( + http_client.post_calls[1]["url"].startswith("https://{}/common/oauth2/v2.0/token".format(host))) + + all_urls = [c["url"] for c in http_client.get_calls + http_client.post_calls] + self.assertTrue(all("login.microsoftonline.com" not in url for url in all_urls)) + self.assertTrue(all("https://{}/".format(host) in url for url in all_urls))