Skip to content

Application initializer does not make tenant discovery calls #205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 22, 2020
40 changes: 26 additions & 14 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,9 @@ def __init__(
authority or "https://login.microsoftonline.com/common/",
self.http_client, validate_authority=validate_authority)
# Here the self.authority is not the same type as authority in input
self.client = None
self.token_cache = token_cache or TokenCache()
self.client = self._build_client(client_credential, self.authority)
self._client_credential = client_credential
self.authority_groups = None

def _build_client(self, client_credential, authority):
Expand Down Expand Up @@ -248,6 +249,12 @@ def _build_client(self, client_credential, authority):
on_removing_rt=self.token_cache.remove_rt,
on_updating_rt=self.token_cache.update_rt)

def _get_client(self):
if not self.client:
self.authority.initialize()
self.client = self._build_client(self._client_credential, self.authority)
return self.client

def get_authorization_request_url(
self,
scopes, # type: list[str]
Expand Down Expand Up @@ -307,6 +314,7 @@ def get_authorization_request_url(
authority,
self.http_client
) if authority else self.authority
the_authority.initialize()

client = Client(
{"authorization_endpoint": the_authority.authorization_endpoint},
Expand Down Expand Up @@ -367,7 +375,7 @@ def acquire_token_by_authorization_code(
# really empty.
assert isinstance(scopes, list), "Invalid parameter type"
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return self.client.obtain_token_by_authorization_code(
return self._get_client().obtain_token_by_authorization_code(
code, redirect_uri=redirect_uri,
scope=decorate_scope(scopes, self.client_id),
headers={
Expand All @@ -391,6 +399,7 @@ def get_accounts(self, username=None):
Your app can choose to display those information to end user,
and allow user to choose one of his/her accounts to proceed.
"""
self.authority.initialize()
accounts = self._find_msal_accounts(environment=self.authority.instance)
if not accounts: # Now try other aliases of this authority instance
for alias in self._get_authority_aliases(self.authority.instance):
Expand Down Expand Up @@ -543,6 +552,7 @@ def acquire_token_silent_with_error(
# authority,
# self.http_client,
# ) if authority else self.authority
self.authority.initialize()
result = self._acquire_token_silent_from_cache_and_possibly_refresh_it(
scopes, account, self.authority, force_refresh=force_refresh,
correlation_id=correlation_id,
Expand All @@ -555,6 +565,7 @@ def acquire_token_silent_with_error(
"https://" + alias + "/" + self.authority.tenant,
self.http_client,
validate_authority=False)
the_authority.initialize()
result = self._acquire_token_silent_from_cache_and_possibly_refresh_it(
scopes, account, the_authority, force_refresh=force_refresh,
correlation_id=correlation_id,
Expand Down Expand Up @@ -724,7 +735,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes):
* A dict contains "error" and some other keys, when error happened.
* A dict contains no "error" key means migration was successful.
"""
return self.client.obtain_token_by_refresh_token(
return self._get_client().obtain_token_by_refresh_token(
refresh_token,
decorate_scope(scopes, self.client_id),
rt_getter=lambda rt: rt,
Expand Down Expand Up @@ -754,7 +765,7 @@ def initiate_device_flow(self, scopes=None, **kwargs):
- an error response would contain some other readable key/value pairs.
"""
correlation_id = _get_new_correlation_id()
flow = self.client.initiate_device_flow(
flow = self._get_client().initiate_device_flow(
scope=decorate_scope(scopes or [], self.client_id),
headers={
CLIENT_REQUEST_ID: correlation_id,
Expand All @@ -778,7 +789,7 @@ def acquire_token_by_device_flow(self, flow, **kwargs):
- A successful response would contain "access_token" key,
- an error response would contain "error" and usually "error_description".
"""
return self.client.obtain_token_by_device_flow(
return self._get_client().obtain_token_by_device_flow(
flow,
data=dict(kwargs.pop("data", {}), code=flow["device_code"]),
# 2018-10-4 Hack:
Expand Down Expand Up @@ -815,14 +826,15 @@ def acquire_token_by_username_password(
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID),
}
self.authority.initialize()
if not self.authority.is_adfs:
user_realm_result = self.authority.user_realm_discovery(
username, correlation_id=headers[CLIENT_REQUEST_ID])
if user_realm_result.get("account_type") == "Federated":
return self._acquire_token_by_username_password_federated(
user_realm_result, username, password, scopes=scopes,
headers=headers, **kwargs)
return self.client.obtain_token_by_username_password(
return self._get_client().obtain_token_by_username_password(
username, password, scope=scopes,
headers=headers,
**kwargs)
Expand Down Expand Up @@ -851,16 +863,16 @@ def _acquire_token_by_username_password_federated(
GRANT_TYPE_SAML1_1 = 'urn:ietf:params:oauth:grant-type:saml1_1-bearer'
grant_type = {
SAML_TOKEN_TYPE_V1: GRANT_TYPE_SAML1_1,
SAML_TOKEN_TYPE_V2: self.client.GRANT_TYPE_SAML2,
SAML_TOKEN_TYPE_V2: Client.GRANT_TYPE_SAML2,
WSS_SAML_TOKEN_PROFILE_V1_1: GRANT_TYPE_SAML1_1,
WSS_SAML_TOKEN_PROFILE_V2: self.client.GRANT_TYPE_SAML2
WSS_SAML_TOKEN_PROFILE_V2: Client.GRANT_TYPE_SAML2
}.get(wstrust_result.get("type"))
if not grant_type:
raise RuntimeError(
"RSTR returned unknown token type: %s", wstrust_result.get("type"))
self.client.grant_assertion_encoders.setdefault( # Register a non-standard type
grant_type, self.client.encode_saml_assertion)
return self.client.obtain_token_by_assertion(
Client.grant_assertion_encoders.setdefault( # Register a non-standard type
grant_type, Client.encode_saml_assertion)
return self._get_client().obtain_token_by_assertion(
wstrust_result["token"], grant_type, scope=scopes, **kwargs)


Expand All @@ -878,7 +890,7 @@ def acquire_token_for_client(self, scopes, **kwargs):
- an error response would contain "error" and usually "error_description".
"""
# TBD: force_refresh behavior
return self.client.obtain_token_for_client(
return self._get_client().obtain_token_for_client(
scope=scopes, # This grant flow requires no scope decoration
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
Expand Down Expand Up @@ -910,9 +922,9 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, **kwargs):
"""
# The implementation is NOT based on Token Exchange
# https://tools.ietf.org/html/draft-ietf-oauth-token-exchange-16
return self.client.obtain_token_by_assertion( # bases on assertion RFC 7521
return self._get_client().obtain_token_by_assertion( # bases on assertion RFC 7521
user_assertion,
self.client.GRANT_TYPE_JWT, # IDTs and AAD ATs are all JWTs
Client.GRANT_TYPE_JWT, # IDTs and AAD ATs are all JWTs
scope=decorate_scope(scopes, self.client_id), # Decoration is used for:
# 1. Explicitly requesting an RT, without relying on AAD default
# behavior, even though it currently still issues an RT.
Expand Down
11 changes: 11 additions & 0 deletions msal/authority.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ def __init__(self, authority_url, http_client, validate_authority=True):
This parameter only controls whether an instance discovery will be
performed.
"""
self._http_client = http_client
self._authority_url = authority_url
self._validate_authority = validate_authority
self._is_initialized = False

def initialize(self):
if not self._is_initialized:
self.__initialize(self._authority_url, self._http_client, self._validate_authority)
self._is_initialized = True

def __initialize(self, authority_url, http_client, validate_authority):
self._http_client = http_client
authority, self.instance, tenant = canonicalize(authority_url)
parts = authority.path.split('/')
Expand Down
1 change: 1 addition & 0 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def setUp(self):
self.authority_url = "https://login.microsoftonline.com/common"
self.authority = msal.authority.Authority(
self.authority_url, MinimalHttpClient())
self.authority.initialize()
self.scopes = ["s1", "s2"]
self.uid = "my_uid"
self.utid = "my_utid"
Expand Down
5 changes: 3 additions & 2 deletions tests/test_authority.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def test_wellknown_host_and_tenant(self):
for host in WELL_KNOWN_AUTHORITY_HOSTS:
a = Authority(
'https://{}/common'.format(host), MinimalHttpClient())
a.initialize()
self.assertEqual(
a.authorization_endpoint,
'https://%s/common/oauth2/v2.0/authorize' % host)
Expand All @@ -34,7 +35,7 @@ def test_unknown_host_wont_pass_instance_discovery(self):
_assert = getattr(self, "assertRaisesRegex", self.assertRaisesRegexp) # Hack
with _assert(ValueError, "invalid_instance"):
Authority('https://example.com/tenant_doesnt_matter_in_this_case',
MinimalHttpClient())
MinimalHttpClient()).initialize()

def test_invalid_host_skipping_validation_can_be_turned_off(self):
try:
Expand Down Expand Up @@ -85,7 +86,7 @@ def test_memorize(self):
authority = "https://login.microsoftonline.com/common"
self.assertNotIn(authority, Authority._domains_without_user_realm_discovery)
a = Authority(authority, MinimalHttpClient(), validate_authority=False)

a.initialize()
# We now pretend this authority supports no User Realm Discovery
class MockResponse(object):
status_code = 404
Expand Down
2 changes: 2 additions & 0 deletions tests/test_authority_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_authority_honors_a_patched_requests(self):
# First, we test that the original, unmodified authority is working
a = msal.authority.Authority(
"https://login.microsoftonline.com/common", MinimalHttpClient())
a.initialize()
self.assertEqual(
a.authorization_endpoint,
'https://login.microsoftonline.com/common/oauth2/v2.0/authorize')
Expand All @@ -27,6 +28,7 @@ def test_authority_honors_a_patched_requests(self):
with self.assertRaises(RuntimeError):
a = msal.authority.Authority(
"https://login.microsoftonline.com/common", MinimalHttpClient())
a.initialize()
finally: # Tricky:
# Unpatch is necessary otherwise other test cases would be affected
msal.authority.requests = original