bug: fix #439 by validating & refreshing token before each rest call (#443)

This commit is contained in:
Yugang Wang 2016-06-22 15:58:33 -07:00 коммит произвёл GitHub
Родитель d440d252f8
Коммит 7e9bb65571
5 изменённых файлов: 44 добавлений и 14 удалений

Просмотреть файл

@ -25,6 +25,7 @@
<PtvsTargetsFile>$(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\Python Tools\Microsoft.PythonTools.targets</PtvsTargetsFile>
</PropertyGroup>
<ItemGroup>
<Compile Include="azure\cli\adal_authentication.py" />
<Compile Include="azure\cli\application.py" />
<Compile Include="azure\cli\commands\azure_resource_id.py">
<SubType>Code</SubType>

Просмотреть файл

@ -11,6 +11,7 @@ from .main import ACCOUNT
from ._util import CLIError
from ._azure_env import (get_authority_url, CLIENT_ID, get_management_endpoint_url,
ENV_DEFAULT, COMMON_TENANT)
from .adal_authentication import AdalAuthentication
import azure.cli._logging as _logging
logger = _logging.get_az_logger(__name__)
@ -33,6 +34,7 @@ _SERVICE_PRINCIPAL = 'servicePrincipal'
_SERVICE_PRINCIPAL_ID = 'servicePrincipalId'
_SERVICE_PRINCIPAL_TENANT = 'servicePrincipalTenant'
_TOKEN_ENTRY_USER_ID = 'userId'
_TOKEN_ENTRY_TOKEN_TYPE = 'tokenType'
#This could mean either real access token, or client secret of a service principal
#This naming is no good, but can't change because xplat-cli does so.
_ACCESS_TOKEN = 'accessToken'
@ -203,17 +205,15 @@ class Profile(object):
user_type = active_account[_USER_ENTITY][_USER_TYPE]
username_or_sp_id = active_account[_USER_ENTITY][_USER_NAME]
if user_type == _USER:
try:
access_token = self._creds_cache.retrieve_token_for_user(username_or_sp_id,
active_account[_TENANT_ID])
except adal.AdalError as err:
raise CLIError(err)
token_retriever = lambda: self._creds_cache.retrieve_token_for_user(
username_or_sp_id, active_account[_TENANT_ID])
auth_object = AdalAuthentication(token_retriever)
else:
access_token = self._creds_cache.retrieve_token_for_service_principal(
token_retriever = lambda: self._creds_cache.retrieve_token_for_service_principal(
username_or_sp_id)
auth_object = AdalAuthentication(token_retriever)
return BasicTokenAuthentication(
{'access_token': access_token}), str(active_account[_SUBSCRIPTION_ID])
return auth_object, str(active_account[_SUBSCRIPTION_ID])
class SubscriptionFinder(object):
@ -323,7 +323,7 @@ class CredsCache(object):
if self.adal_token_cache.has_state_changed:
self.persist_cached_creds()
return token_entry[_ACCESS_TOKEN]
return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN])
def retrieve_token_for_service_principal(self, sp_id):
matched = [x for x in self._service_principal_creds if sp_id == x[_SERVICE_PRINCIPAL_ID]]
@ -335,7 +335,7 @@ class CredsCache(object):
token_entry = context.acquire_token_with_client_credentials(self._resource,
sp_id,
cred[_ACCESS_TOKEN])
return token_entry[_ACCESS_TOKEN]
return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN])
def _load_creds(self):
if self.adal_token_cache is not None:

Просмотреть файл

@ -0,0 +1,22 @@
import adal
from msrest.authentication import Authentication
from azure.cli._util import CLIError
class AdalAuthentication(Authentication):#pylint: disable=too-few-public-methods
def __init__(self, token_retriever):
self._token_retriever = token_retriever
def signed_session(self):
session = super(AdalAuthentication, self).signed_session()
try:
scheme, token = self._token_retriever()
except adal.AdalError as err:
raise CLIError(err)
header = "{} {}".format(scheme, token)
session.headers['Authorization'] = header
return session

Просмотреть файл

@ -185,8 +185,9 @@ class Test_Profile(unittest.TestCase):
@mock.patch('azure.cli._profile._read_file_content', autospec=True)
@mock.patch('azure.cli._profile.CredsCache.retrieve_token_for_user', autospec=True)
def test_get_login_credentials(self, mock_get_token, mock_read_cred_file):
some_token_type = 'Bearer'
mock_read_cred_file.return_value = json.dumps([Test_Profile.token_entry1])
mock_get_token.return_value = Test_Profile.raw_token1
mock_get_token.return_value = (some_token_type, Test_Profile.raw_token1)
#setup
storage_mock = {'subscriptions': None}
profile = Profile(storage_mock)
@ -200,7 +201,11 @@ class Test_Profile(unittest.TestCase):
#verify
self.assertEqual(subscription_id, '1')
self.assertEqual(cred.token['access_token'], self.raw_token1)
#verify the cred._tokenRetriever is a working lambda
token_type, token = cred._token_retriever()
self.assertEqual(token, self.raw_token1)
self.assertEqual(some_token_type, token_type)
self.assertEqual(mock_read_cred_file.call_count, 1)
self.assertEqual(mock_get_token.call_count, 1)
@ -392,6 +397,7 @@ class Test_Profile(unittest.TestCase):
def test_credscache_new_token_added_by_adal(self, mock_adal_auth_context, mock_open_for_write, mock_read_file):
token_entry2 = {
"accessToken": "new token",
"tokenType": "Bearer",
"userId": self.user1
}
def acquire_token_side_effect(*args):
@ -405,9 +411,9 @@ class Test_Profile(unittest.TestCase):
mock_open_for_write.return_value = FileHandleStub()
mock_read_file.return_value = json.dumps([self.token_entry1])
creds_cache = CredsCache(auth_ctx_factory=get_auth_context)
token = creds_cache.retrieve_token_for_user(self.user1, self.tenant_id)
#action
token_type, token = creds_cache.retrieve_token_for_user(self.user1, self.tenant_id)
mock_adal_auth_context.acquire_token.assert_called_once_with(
'https://management.core.windows.net/',
self.user1,
@ -416,6 +422,7 @@ class Test_Profile(unittest.TestCase):
#assert
mock_open_for_write.assert_called_with(mock.ANY, 'w', encoding='ascii')
self.assertEqual(token, 'new token')
self.assertEqual(token_type, token_entry2['tokenType'])
class FileHandleStub:
def write(self, content):

Просмотреть файл

@ -44,7 +44,7 @@ def _mock_subscriptions(self): #pylint: disable=unused-argument
"isDefault": True}]
def _mock_user_access_token(_, _1, _2): #pylint: disable=unused-argument
return 'top-secret-token-for-you'
return ('Bearer', 'top-secret-token-for-you')
def _mock_operation_delay(_):
# don't run time.sleep()