Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@

from .oauth2cli import Client, JwtAssertionCreator
from .oauth2cli.oidc import decode_part
from .authority import Authority, WORLD_WIDE
from .authority import (
Authority,
WORLD_WIDE,
_get_instance_discovery_endpoint,
_get_instance_discovery_host,
)
from .mex import send_request as mex_send_request
from .wstrust_request import send_request as wst_send_request
from .wstrust_response import *
Expand Down Expand Up @@ -671,7 +676,7 @@ def __init__(
self._region_detected = None
self.client, self._regional_client = self._build_client(
client_credential, self.authority)
self.authority_groups = None
self.authority_groups = {}
self._telemetry_buffer = {}
self._telemetry_lock = Lock()
_msal_extension_check()
Expand Down Expand Up @@ -1304,9 +1309,16 @@ def _find_msal_accounts(self, environment):
}
return list(grouped_accounts.values())

def _get_instance_metadata(self): # This exists so it can be mocked in unit test
def _get_instance_metadata(self, instance): # This exists so it can be mocked in unit test
instance_discovery_host = _get_instance_discovery_host(instance)
resp = self.http_client.get(
"https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", # TBD: We may extend this to use self._instance_discovery endpoint
_get_instance_discovery_endpoint(instance),
params={
'api-version': '1.1',
'authorization_endpoint': (
"https://{}/common/oauth2/authorize".format(instance_discovery_host)
),
},
headers={'Accept': 'application/json'})
resp.raise_for_status()
return json.loads(resp.text)['metadata']
Expand All @@ -1318,10 +1330,10 @@ def _get_authority_aliases(self, instance):
# Then it is an ADFS/B2C/known_authority_hosts situation
# which may not reach the central endpoint, so we skip it.
return []
if not self.authority_groups:
self.authority_groups = [
set(group['aliases']) for group in self._get_instance_metadata()]
for group in self.authority_groups:
if instance not in self.authority_groups:
self.authority_groups[instance] = [
set(group['aliases']) for group in self._get_instance_metadata(instance)]
for group in self.authority_groups[instance]:
if instance in group:
return [alias for alias in group if alias != instance]
return []
Expand Down
59 changes: 28 additions & 31 deletions msal/authority.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,28 @@
# Endpoints were copied from here
# https://docs.microsoft.com/en-us/azure/active-directory/develop/authentication-national-cloud#azure-ad-authentication-endpoints
AZURE_US_GOVERNMENT = "login.microsoftonline.us"
AZURE_CHINA = "login.chinacloudapi.cn"
DEPRECATED_AZURE_CHINA = "login.chinacloudapi.cn"
AZURE_PUBLIC = "login.microsoftonline.com"
AZURE_GOV_FR = "login.sovcloud-identity.fr"
AZURE_GOV_DE = "login.sovcloud-identity.de"
AZURE_GOV_SG = "login.sovcloud-identity.sg"

WORLD_WIDE = 'login.microsoftonline.com' # There was an alias login.windows.net
WELL_KNOWN_AUTHORITY_HOSTS = set([
WELL_KNOWN_AUTHORITY_HOSTS = frozenset([
WORLD_WIDE,
AZURE_CHINA,
'login-us.microsoftonline.com',
AZURE_US_GOVERNMENT,
])

# Trusted issuer hosts for OIDC issuer validation
# Includes all well-known Microsoft identity provider hosts and national clouds
TRUSTED_ISSUER_HOSTS = frozenset([
# Global/Public cloud
"login.microsoftonline.com",
"login.microsoft.com",
"login.windows.net",
"sts.windows.net",
# China cloud
"login.chinacloudapi.cn",
DEPRECATED_AZURE_CHINA,
"login.partner.microsoftonline.cn",
# Germany cloud (legacy)
"login.microsoftonline.de",
# US Government clouds
"login.microsoftonline.us",
"login.microsoftonline.de", # deprecated
'login-us.microsoftonline.com',
AZURE_US_GOVERNMENT,
"login.usgovcloudapi.net",
"login-us.microsoftonline.com",
"https://login.sovcloud-identity.fr", # AzureBleu
"https://login.sovcloud-identity.de", # AzureDelos
"https://login.sovcloud-identity.sg", # AzureGovSG
])
AZURE_GOV_FR,
AZURE_GOV_DE,
AZURE_GOV_SG,
])

WELL_KNOWN_B2C_HOSTS = [
"b2clogin.com",
Expand All @@ -54,6 +44,15 @@
_CIAM_DOMAIN_SUFFIX = ".ciamlogin.com"


def _get_instance_discovery_host(instance):
return instance if instance in WELL_KNOWN_AUTHORITY_HOSTS else WORLD_WIDE


def _get_instance_discovery_endpoint(instance):
return 'https://{}/common/discovery/instance'.format(
_get_instance_discovery_host(instance))


class AuthorityBuilder(object):
def __init__(self, instance, tenant):
"""A helper to save caller from doing string concatenation.
Expand Down Expand Up @@ -162,10 +161,8 @@ def _initialize_entra_authority(
) or (len(parts) == 3 and parts[2].lower().startswith("b2c_"))
self._is_known_to_developer = self.is_adfs or self._is_b2c or not validate_authority
is_known_to_microsoft = self.instance in WELL_KNOWN_AUTHORITY_HOSTS
instance_discovery_endpoint = 'https://{}/common/discovery/instance'.format( # Note: This URL seemingly returns V1 endpoint only
WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too
# See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103
# and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33
instance_discovery_endpoint = _get_instance_discovery_endpoint( # Note: This URL seemingly returns V1 endpoint only
self.instance
) if instance_discovery in (None, True) else instance_discovery
if instance_discovery_endpoint and not (
is_known_to_microsoft or self._is_known_to_developer):
Expand All @@ -177,8 +174,8 @@ def _initialize_entra_authority(
if payload.get("error") == "invalid_instance":
raise ValueError(
"invalid_instance: "
"The authority you provided, %s, is not whitelisted. "
"If it is indeed your legit customized domain name, "
"The authority you provided, %s, is not known. "
"If it is a valid domain name known to you, "
"you can turn off this check by passing in "
"instance_discovery=False"
% authority_url)
Expand Down Expand Up @@ -235,7 +232,7 @@ def has_valid_issuer(self):
return False

# Case 2: Issuer is from a trusted Microsoft host - O(1) lookup
if issuer_host in TRUSTED_ISSUER_HOSTS:
if issuer_host in WELL_KNOWN_AUTHORITY_HOSTS:
return True

# Case 3: Regional variant check - O(1) lookup
Expand All @@ -245,7 +242,7 @@ def has_valid_issuer(self):
potential_base = issuer_host[dot_index + 1:]
if "." not in issuer_host[:dot_index]:
# 3a: Base host is a trusted Microsoft host
if potential_base in TRUSTED_ISSUER_HOSTS:
if potential_base in WELL_KNOWN_AUTHORITY_HOSTS:
return True
# 3b: Issuer has a region prefix on the authority host
# e.g. issuer=us.someweb.com, authority=someweb.com
Expand Down
41 changes: 41 additions & 0 deletions tests/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,44 @@ def raise_for_status(self):
if self._raw_resp is not None: # Turns out `if requests.response` won't work
# cause it would be True when 200<=status<400
self._raw_resp.raise_for_status()


class RecordingHttpClient(object):
def __init__(self):
self.get_calls = []
self.post_calls = []
self._get_routes = []
self._post_routes = []

def add_get_route(self, matcher, responder):
self._get_routes.append((matcher, responder))

def add_post_route(self, matcher, responder):
self._post_routes.append((matcher, responder))

def get(self, url, params=None, headers=None, **kwargs):
call = {
"url": url,
"params": params,
"headers": headers,
"kwargs": kwargs,
}
self.get_calls.append(call)
for matcher, responder in self._get_routes:
if matcher(call):
return responder(call)
return MinimalResponse(status_code=404, text="")

def post(self, url, params=None, data=None, headers=None, **kwargs):
call = {
"url": url,
"params": params,
"data": data,
"headers": headers,
"kwargs": kwargs,
}
self.post_calls.append(call)
for matcher, responder in self._post_routes:
if matcher(call):
return responder(call)
return MinimalResponse(status_code=404, text="")
107 changes: 102 additions & 5 deletions tests/test_authority.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import msal
from msal.authority import *
from msal.authority import _CIAM_DOMAIN_SUFFIX, TRUSTED_ISSUER_HOSTS # Explicitly import private/new constants
from msal.authority import _CIAM_DOMAIN_SUFFIX
from tests import unittest
from tests.http_client import MinimalHttpClient

Expand Down Expand Up @@ -37,10 +37,90 @@ def _test_authority_builder(self, host, tenant):
c.close()

def test_wellknown_host_and_tenant(self):
# Assert all well known authority hosts are using their own "common" tenant
# This test makes real HTTP calls to authority endpoints.
# It is intentionally network-based to validate reachable hosts end-to-end.
excluded_hosts = {
DEPRECATED_AZURE_CHINA,
"login.microsoftonline.de", # deprecated
"login.microsoft.com", # issuer-only in this test context
"login.windows.net", # issuer-only in this test context
"sts.windows.net", # issuer-only in this test context
"login.partner.microsoftonline.cn", # issuer-only in this test context
"login.usgovcloudapi.net", # issuer-only in this test context
AZURE_GOV_FR, # currently unreachable in this environment
AZURE_GOV_DE, # currently unreachable in this environment
AZURE_GOV_SG, # currently unreachable in this environment
}
for host in WELL_KNOWN_AUTHORITY_HOSTS:
if host in excluded_hosts:
continue
self._test_given_host_and_tenant(host, "common")

@patch("msal.authority._instance_discovery")
@patch("msal.authority.tenant_discovery")
def test_new_sovereign_hosts_should_build_authority_endpoints(
self, tenant_discovery_mock, instance_discovery_mock):
for host in WELL_KNOWN_AUTHORITY_HOSTS:
tenant_discovery_mock.return_value = {
"authorization_endpoint": "https://{}/common/oauth2/v2.0/authorize".format(host),
"token_endpoint": "https://{}/common/oauth2/v2.0/token".format(host),
"issuer": "https://{}/common/v2.0".format(host),
}
instance_discovery_mock.return_value = {
"tenant_discovery_endpoint": (
"https://{}/common/v2.0/.well-known/openid-configuration".format(host)
),
}
c = MinimalHttpClient()
a = Authority(AuthorityBuilder(host, "common"), c)
self.assertEqual(
a.authorization_endpoint,
"https://{}/common/oauth2/v2.0/authorize".format(host))
self.assertEqual(
a.token_endpoint,
"https://{}/common/oauth2/v2.0/token".format(host))
c.close()

@patch("msal.authority._instance_discovery")
@patch("msal.authority.tenant_discovery")
def test_known_authority_should_use_same_host_and_skip_instance_discovery(
self, tenant_discovery_mock, instance_discovery_mock):
for host in WELL_KNOWN_AUTHORITY_HOSTS:
if host != AZURE_CHINA: # It is prone to ConnectionError
self._test_given_host_and_tenant(host, "common")
tenant_discovery_mock.return_value = {
"authorization_endpoint": "https://{}/common/oauth2/v2.0/authorize".format(host),
"token_endpoint": "https://{}/common/oauth2/v2.0/token".format(host),
"issuer": "https://{}/common/v2.0".format(host),
}
c = MinimalHttpClient()
Authority("https://{}/common".format(host), c)
c.close()

instance_discovery_mock.assert_not_called()
tenant_discovery_endpoint = tenant_discovery_mock.call_args[0][0]
self.assertTrue(
tenant_discovery_endpoint.startswith(
"https://{}/common/v2.0/.well-known/openid-configuration".format(host)))

@patch("msal.authority._instance_discovery")
@patch("msal.authority.tenant_discovery")
def test_unknown_authority_should_use_world_wide_instance_discovery_endpoint(
self, tenant_discovery_mock, instance_discovery_mock):
tenant_discovery_mock.return_value = {
"authorization_endpoint": "https://example.com/tenant/oauth2/v2.0/authorize",
"token_endpoint": "https://example.com/tenant/oauth2/v2.0/token",
"issuer": "https://example.com/tenant/v2.0",
}
instance_discovery_mock.return_value = {
"tenant_discovery_endpoint": "https://example.com/tenant/v2.0/.well-known/openid-configuration",
}

c = MinimalHttpClient()
Authority("https://example.com/tenant", c)
c.close()

self.assertEqual(
"https://{}/common/discovery/instance".format(WORLD_WIDE),
instance_discovery_mock.call_args[0][2])

def test_wellknown_host_and_tenant_using_new_authority_builder(self):
self._test_authority_builder(AZURE_PUBLIC, "consumers")
Expand Down Expand Up @@ -276,7 +356,24 @@ def test_by_default_a_known_to_microsoft_authority_should_skip_validation_but_st
app = msal.ClientApplication("id", authority="https://login.microsoftonline.com/common")
known_to_microsoft_validation.assert_not_called()
app.get_accounts() # This could make an instance metadata call for authority aliases
instance_metadata.assert_called_once_with()
instance_metadata.assert_called_once_with("login.microsoftonline.com")

def test_by_default_a_sovereign_known_authority_should_use_cloud_local_instance_metadata(
self, instance_metadata, known_to_microsoft_validation, _):
app = msal.ClientApplication("id", authority="https://login.microsoftonline.us/common")
known_to_microsoft_validation.assert_not_called()
app.get_accounts() # This could make an instance metadata call for authority aliases
instance_metadata.assert_called_once_with("login.microsoftonline.us")

def test_fr_known_authority_should_still_work_when_instance_metadata_has_no_alias_entry(
self, instance_metadata, known_to_microsoft_validation, _):
app = msal.ClientApplication("id", authority="https://{}/common".format(AZURE_GOV_FR))
known_to_microsoft_validation.assert_not_called()

accounts = app.get_accounts()

self.assertEqual([], accounts)
instance_metadata.assert_called_once_with(AZURE_GOV_FR)

def test_validate_authority_boolean_should_skip_validation_and_instance_metadata(
self, instance_metadata, known_to_microsoft_validation, _):
Expand Down
Loading
Loading