From 7e9bb65571256f3dc64e40c557eb34e6432e64f5 Mon Sep 17 00:00:00 2001 From: Yugang Wang Date: Wed, 22 Jun 2016 15:58:33 -0700 Subject: [PATCH] bug: fix #439 by validating & refreshing token before each rest call (#443) --- azure-cli.pyproj | 1 + src/azure/cli/_profile.py | 20 ++++++++++---------- src/azure/cli/adal_authentication.py | 22 ++++++++++++++++++++++ src/azure/cli/tests/test_profile.py | 13 ++++++++++--- src/azure/cli/utils/vcr_test_base.py | 2 +- 5 files changed, 44 insertions(+), 14 deletions(-) create mode 100644 src/azure/cli/adal_authentication.py diff --git a/azure-cli.pyproj b/azure-cli.pyproj index ed4375552..2299fee6b 100644 --- a/azure-cli.pyproj +++ b/azure-cli.pyproj @@ -25,6 +25,7 @@ $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\Python Tools\Microsoft.PythonTools.targets + Code diff --git a/src/azure/cli/_profile.py b/src/azure/cli/_profile.py index 88d0c2959..e55179caf 100644 --- a/src/azure/cli/_profile.py +++ b/src/azure/cli/_profile.py @@ -11,6 +11,7 @@ from .main import ACCOUNT from ._util import CLIError from ._azure_env import (get_authority_url, CLIENT_ID, get_management_endpoint_url, ENV_DEFAULT, COMMON_TENANT) +from .adal_authentication import AdalAuthentication import azure.cli._logging as _logging logger = _logging.get_az_logger(__name__) @@ -33,6 +34,7 @@ _SERVICE_PRINCIPAL = 'servicePrincipal' _SERVICE_PRINCIPAL_ID = 'servicePrincipalId' _SERVICE_PRINCIPAL_TENANT = 'servicePrincipalTenant' _TOKEN_ENTRY_USER_ID = 'userId' +_TOKEN_ENTRY_TOKEN_TYPE = 'tokenType' #This could mean either real access token, or client secret of a service principal #This naming is no good, but can't change because xplat-cli does so. _ACCESS_TOKEN = 'accessToken' @@ -203,17 +205,15 @@ class Profile(object): user_type = active_account[_USER_ENTITY][_USER_TYPE] username_or_sp_id = active_account[_USER_ENTITY][_USER_NAME] if user_type == _USER: - try: - access_token = self._creds_cache.retrieve_token_for_user(username_or_sp_id, - active_account[_TENANT_ID]) - except adal.AdalError as err: - raise CLIError(err) + token_retriever = lambda: self._creds_cache.retrieve_token_for_user( + username_or_sp_id, active_account[_TENANT_ID]) + auth_object = AdalAuthentication(token_retriever) else: - access_token = self._creds_cache.retrieve_token_for_service_principal( + token_retriever = lambda: self._creds_cache.retrieve_token_for_service_principal( username_or_sp_id) + auth_object = AdalAuthentication(token_retriever) - return BasicTokenAuthentication( - {'access_token': access_token}), str(active_account[_SUBSCRIPTION_ID]) + return auth_object, str(active_account[_SUBSCRIPTION_ID]) class SubscriptionFinder(object): @@ -323,7 +323,7 @@ class CredsCache(object): if self.adal_token_cache.has_state_changed: self.persist_cached_creds() - return token_entry[_ACCESS_TOKEN] + return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN]) def retrieve_token_for_service_principal(self, sp_id): matched = [x for x in self._service_principal_creds if sp_id == x[_SERVICE_PRINCIPAL_ID]] @@ -335,7 +335,7 @@ class CredsCache(object): token_entry = context.acquire_token_with_client_credentials(self._resource, sp_id, cred[_ACCESS_TOKEN]) - return token_entry[_ACCESS_TOKEN] + return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN]) def _load_creds(self): if self.adal_token_cache is not None: diff --git a/src/azure/cli/adal_authentication.py b/src/azure/cli/adal_authentication.py new file mode 100644 index 000000000..23030b7bb --- /dev/null +++ b/src/azure/cli/adal_authentication.py @@ -0,0 +1,22 @@ +import adal + +from msrest.authentication import Authentication + +from azure.cli._util import CLIError + +class AdalAuthentication(Authentication):#pylint: disable=too-few-public-methods + + def __init__(self, token_retriever): + self._token_retriever = token_retriever + + def signed_session(self): + session = super(AdalAuthentication, self).signed_session() + + try: + scheme, token = self._token_retriever() + except adal.AdalError as err: + raise CLIError(err) + + header = "{} {}".format(scheme, token) + session.headers['Authorization'] = header + return session diff --git a/src/azure/cli/tests/test_profile.py b/src/azure/cli/tests/test_profile.py index 955b91440..c8b989687 100644 --- a/src/azure/cli/tests/test_profile.py +++ b/src/azure/cli/tests/test_profile.py @@ -185,8 +185,9 @@ class Test_Profile(unittest.TestCase): @mock.patch('azure.cli._profile._read_file_content', autospec=True) @mock.patch('azure.cli._profile.CredsCache.retrieve_token_for_user', autospec=True) def test_get_login_credentials(self, mock_get_token, mock_read_cred_file): + some_token_type = 'Bearer' mock_read_cred_file.return_value = json.dumps([Test_Profile.token_entry1]) - mock_get_token.return_value = Test_Profile.raw_token1 + mock_get_token.return_value = (some_token_type, Test_Profile.raw_token1) #setup storage_mock = {'subscriptions': None} profile = Profile(storage_mock) @@ -200,7 +201,11 @@ class Test_Profile(unittest.TestCase): #verify self.assertEqual(subscription_id, '1') - self.assertEqual(cred.token['access_token'], self.raw_token1) + + #verify the cred._tokenRetriever is a working lambda + token_type, token = cred._token_retriever() + self.assertEqual(token, self.raw_token1) + self.assertEqual(some_token_type, token_type) self.assertEqual(mock_read_cred_file.call_count, 1) self.assertEqual(mock_get_token.call_count, 1) @@ -392,6 +397,7 @@ class Test_Profile(unittest.TestCase): def test_credscache_new_token_added_by_adal(self, mock_adal_auth_context, mock_open_for_write, mock_read_file): token_entry2 = { "accessToken": "new token", + "tokenType": "Bearer", "userId": self.user1 } def acquire_token_side_effect(*args): @@ -405,9 +411,9 @@ class Test_Profile(unittest.TestCase): mock_open_for_write.return_value = FileHandleStub() mock_read_file.return_value = json.dumps([self.token_entry1]) creds_cache = CredsCache(auth_ctx_factory=get_auth_context) - token = creds_cache.retrieve_token_for_user(self.user1, self.tenant_id) #action + token_type, token = creds_cache.retrieve_token_for_user(self.user1, self.tenant_id) mock_adal_auth_context.acquire_token.assert_called_once_with( 'https://management.core.windows.net/', self.user1, @@ -416,6 +422,7 @@ class Test_Profile(unittest.TestCase): #assert mock_open_for_write.assert_called_with(mock.ANY, 'w', encoding='ascii') self.assertEqual(token, 'new token') + self.assertEqual(token_type, token_entry2['tokenType']) class FileHandleStub: def write(self, content): diff --git a/src/azure/cli/utils/vcr_test_base.py b/src/azure/cli/utils/vcr_test_base.py index e868f5181..94c62db9e 100644 --- a/src/azure/cli/utils/vcr_test_base.py +++ b/src/azure/cli/utils/vcr_test_base.py @@ -44,7 +44,7 @@ def _mock_subscriptions(self): #pylint: disable=unused-argument "isDefault": True}] def _mock_user_access_token(_, _1, _2): #pylint: disable=unused-argument - return 'top-secret-token-for-you' + return ('Bearer', 'top-secret-token-for-you') def _mock_operation_delay(_): # don't run time.sleep()