diff --git a/azure-cli.pyproj b/azure-cli.pyproj
index ed4375552..2299fee6b 100644
--- a/azure-cli.pyproj
+++ b/azure-cli.pyproj
@@ -25,6 +25,7 @@
$(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\Python Tools\Microsoft.PythonTools.targets
+
Code
diff --git a/src/azure/cli/_profile.py b/src/azure/cli/_profile.py
index 88d0c2959..e55179caf 100644
--- a/src/azure/cli/_profile.py
+++ b/src/azure/cli/_profile.py
@@ -11,6 +11,7 @@ from .main import ACCOUNT
from ._util import CLIError
from ._azure_env import (get_authority_url, CLIENT_ID, get_management_endpoint_url,
ENV_DEFAULT, COMMON_TENANT)
+from .adal_authentication import AdalAuthentication
import azure.cli._logging as _logging
logger = _logging.get_az_logger(__name__)
@@ -33,6 +34,7 @@ _SERVICE_PRINCIPAL = 'servicePrincipal'
_SERVICE_PRINCIPAL_ID = 'servicePrincipalId'
_SERVICE_PRINCIPAL_TENANT = 'servicePrincipalTenant'
_TOKEN_ENTRY_USER_ID = 'userId'
+_TOKEN_ENTRY_TOKEN_TYPE = 'tokenType'
#This could mean either real access token, or client secret of a service principal
#This naming is no good, but can't change because xplat-cli does so.
_ACCESS_TOKEN = 'accessToken'
@@ -203,17 +205,15 @@ class Profile(object):
user_type = active_account[_USER_ENTITY][_USER_TYPE]
username_or_sp_id = active_account[_USER_ENTITY][_USER_NAME]
if user_type == _USER:
- try:
- access_token = self._creds_cache.retrieve_token_for_user(username_or_sp_id,
- active_account[_TENANT_ID])
- except adal.AdalError as err:
- raise CLIError(err)
+ token_retriever = lambda: self._creds_cache.retrieve_token_for_user(
+ username_or_sp_id, active_account[_TENANT_ID])
+ auth_object = AdalAuthentication(token_retriever)
else:
- access_token = self._creds_cache.retrieve_token_for_service_principal(
+ token_retriever = lambda: self._creds_cache.retrieve_token_for_service_principal(
username_or_sp_id)
+ auth_object = AdalAuthentication(token_retriever)
- return BasicTokenAuthentication(
- {'access_token': access_token}), str(active_account[_SUBSCRIPTION_ID])
+ return auth_object, str(active_account[_SUBSCRIPTION_ID])
class SubscriptionFinder(object):
@@ -323,7 +323,7 @@ class CredsCache(object):
if self.adal_token_cache.has_state_changed:
self.persist_cached_creds()
- return token_entry[_ACCESS_TOKEN]
+ return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN])
def retrieve_token_for_service_principal(self, sp_id):
matched = [x for x in self._service_principal_creds if sp_id == x[_SERVICE_PRINCIPAL_ID]]
@@ -335,7 +335,7 @@ class CredsCache(object):
token_entry = context.acquire_token_with_client_credentials(self._resource,
sp_id,
cred[_ACCESS_TOKEN])
- return token_entry[_ACCESS_TOKEN]
+ return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN])
def _load_creds(self):
if self.adal_token_cache is not None:
diff --git a/src/azure/cli/adal_authentication.py b/src/azure/cli/adal_authentication.py
new file mode 100644
index 000000000..23030b7bb
--- /dev/null
+++ b/src/azure/cli/adal_authentication.py
@@ -0,0 +1,22 @@
+import adal
+
+from msrest.authentication import Authentication
+
+from azure.cli._util import CLIError
+
+class AdalAuthentication(Authentication):#pylint: disable=too-few-public-methods
+
+ def __init__(self, token_retriever):
+ self._token_retriever = token_retriever
+
+ def signed_session(self):
+ session = super(AdalAuthentication, self).signed_session()
+
+ try:
+ scheme, token = self._token_retriever()
+ except adal.AdalError as err:
+ raise CLIError(err)
+
+ header = "{} {}".format(scheme, token)
+ session.headers['Authorization'] = header
+ return session
diff --git a/src/azure/cli/tests/test_profile.py b/src/azure/cli/tests/test_profile.py
index 955b91440..c8b989687 100644
--- a/src/azure/cli/tests/test_profile.py
+++ b/src/azure/cli/tests/test_profile.py
@@ -185,8 +185,9 @@ class Test_Profile(unittest.TestCase):
@mock.patch('azure.cli._profile._read_file_content', autospec=True)
@mock.patch('azure.cli._profile.CredsCache.retrieve_token_for_user', autospec=True)
def test_get_login_credentials(self, mock_get_token, mock_read_cred_file):
+ some_token_type = 'Bearer'
mock_read_cred_file.return_value = json.dumps([Test_Profile.token_entry1])
- mock_get_token.return_value = Test_Profile.raw_token1
+ mock_get_token.return_value = (some_token_type, Test_Profile.raw_token1)
#setup
storage_mock = {'subscriptions': None}
profile = Profile(storage_mock)
@@ -200,7 +201,11 @@ class Test_Profile(unittest.TestCase):
#verify
self.assertEqual(subscription_id, '1')
- self.assertEqual(cred.token['access_token'], self.raw_token1)
+
+ #verify the cred._tokenRetriever is a working lambda
+ token_type, token = cred._token_retriever()
+ self.assertEqual(token, self.raw_token1)
+ self.assertEqual(some_token_type, token_type)
self.assertEqual(mock_read_cred_file.call_count, 1)
self.assertEqual(mock_get_token.call_count, 1)
@@ -392,6 +397,7 @@ class Test_Profile(unittest.TestCase):
def test_credscache_new_token_added_by_adal(self, mock_adal_auth_context, mock_open_for_write, mock_read_file):
token_entry2 = {
"accessToken": "new token",
+ "tokenType": "Bearer",
"userId": self.user1
}
def acquire_token_side_effect(*args):
@@ -405,9 +411,9 @@ class Test_Profile(unittest.TestCase):
mock_open_for_write.return_value = FileHandleStub()
mock_read_file.return_value = json.dumps([self.token_entry1])
creds_cache = CredsCache(auth_ctx_factory=get_auth_context)
- token = creds_cache.retrieve_token_for_user(self.user1, self.tenant_id)
#action
+ token_type, token = creds_cache.retrieve_token_for_user(self.user1, self.tenant_id)
mock_adal_auth_context.acquire_token.assert_called_once_with(
'https://management.core.windows.net/',
self.user1,
@@ -416,6 +422,7 @@ class Test_Profile(unittest.TestCase):
#assert
mock_open_for_write.assert_called_with(mock.ANY, 'w', encoding='ascii')
self.assertEqual(token, 'new token')
+ self.assertEqual(token_type, token_entry2['tokenType'])
class FileHandleStub:
def write(self, content):
diff --git a/src/azure/cli/utils/vcr_test_base.py b/src/azure/cli/utils/vcr_test_base.py
index e868f5181..94c62db9e 100644
--- a/src/azure/cli/utils/vcr_test_base.py
+++ b/src/azure/cli/utils/vcr_test_base.py
@@ -44,7 +44,7 @@ def _mock_subscriptions(self): #pylint: disable=unused-argument
"isDefault": True}]
def _mock_user_access_token(_, _1, _2): #pylint: disable=unused-argument
- return 'top-secret-token-for-you'
+ return ('Bearer', 'top-secret-token-for-you')
def _mock_operation_delay(_):
# don't run time.sleep()