517 строки
19 KiB
Python
517 строки
19 KiB
Python
# --------------------------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License. See License.txt in the project root for license information.
|
|
# --------------------------------------------------------------------------------------------
|
|
|
|
# pylint: disable=wrong-import-order
|
|
|
|
from __future__ import print_function
|
|
|
|
import collections
|
|
import json
|
|
import os
|
|
import re
|
|
import shlex
|
|
import sys
|
|
import tempfile
|
|
from random import choice
|
|
from string import digits, ascii_lowercase
|
|
|
|
from six.moves.urllib.parse import urlparse, parse_qs # pylint: disable=import-error
|
|
|
|
import unittest
|
|
try:
|
|
import unittest.mock as mock
|
|
except ImportError:
|
|
import mock
|
|
|
|
import vcr
|
|
import jmespath
|
|
|
|
from azure.cli.core import get_default_cli
|
|
from azure.cli.core import __version__ as core_version
|
|
import azure.cli.core._debug as _debug
|
|
from azure.cli.core._profile import Profile
|
|
from azure.cli.core.util import CLIError, random_string
|
|
|
|
# Environment variable that, when set to the string 'True', forces every VCR test to run live.
LIVE_TEST_CONTROL_ENV = 'AZURE_CLI_TEST_RUN_LIVE'
# Environment variable whose value (if set) is a file path; every executed command line is
# appended to that file for command-coverage tracking.
COMMAND_COVERAGE_CONTROL_ENV = 'AZURE_CLI_TEST_COMMAND_COVERAGE'
# Placeholder values substituted for real identifiers in recordings and playback mocks.
MOCKED_SUBSCRIPTION_ID = '00000000-0000-0000-0000-000000000000'
MOCKED_TENANT_ID = '00000000-0000-0000-0000-000000000000'
MOCKED_STORAGE_ACCOUNT = 'dummystorage'
|
|
|
|
|
|
# MOCK METHODS
|
|
|
|
# Workaround until https://github.com/kevin1024/vcrpy/issues/293 is fixed.
vcr_connection_request = vcr.stubs.VCRConnection.request


def patch_vcr_connection_request(*args, **kwargs):
    """Call the original ``VCRConnection.request`` with any ``encode_chunked`` keyword removed.

    Presumably the stubbed ``request`` does not accept ``encode_chunked`` (see the vcrpy issue
    linked above) -- TODO confirm once the upstream issue is resolved.
    """
    kwargs.pop('encode_chunked', None)
    # Forward the result so the wrapper is transparent to whatever calls it.
    return vcr_connection_request(*args, **kwargs)


vcr.stubs.VCRConnection.request = patch_vcr_connection_request
|
|
|
|
|
|
def _mock_get_mgmt_service_client(client_type, subscription_bound=True, subscription_id=None,
                                  api_version=None):
    """Test replacement for ``_get_mgmt_service_client`` used when recording or playing tests.

    Builds ``client_type`` from the credentials returned by ``Profile.get_login_credentials``.
    The subscription id is passed as a second positional argument only when
    ``subscription_bound`` is true, and ``api_version`` is passed as a keyword only when truthy.

    :returns: a ``(client, subscription_id)`` tuple.
    """
    cred, subscription_id, _ = Profile().get_login_credentials(subscription_id=subscription_id)

    ctor_args = [cred, subscription_id] if subscription_bound else [cred]
    ctor_kwargs = {'api_version': api_version} if api_version else {}
    client = client_type(*ctor_args, **ctor_kwargs)

    client = _debug.change_ssl_cert_verification(client)
    client.config.add_user_agent("AZURECLI/TEST/{}".format(core_version))
    return client, subscription_id
|
|
|
|
|
|
def _mock_generate_deployment_name(namespace):
|
|
if not namespace.deployment_name:
|
|
namespace.deployment_name = 'mock-deployment'
|
|
|
|
|
|
def _mock_handle_exceptions(ex):
    """Replacement for ``azure.cli.core.util.handle_exception`` during recording/live runs.

    Instead of handling the error, re-raise it so the test fails with the original exception.
    """
    raise ex
|
|
|
|
|
|
def _mock_subscriptions(self):  # pylint: disable=unused-argument
    """Replacement for ``Profile.load_cached_subscriptions`` used during playback.

    Returns a one-element list holding a fake, enabled, default subscription that uses the
    mocked subscription and tenant ids.
    """
    fake_user = {"name": "example@example.com", "type": "user"}
    subscription = {
        "id": MOCKED_SUBSCRIPTION_ID,
        "user": fake_user,
        "state": "Enabled",
        "name": "Example",
        "tenantId": MOCKED_TENANT_ID,
        "isDefault": True,
    }
    return [subscription]
|
|
|
|
|
|
def _mock_user_access_token(_, _1, _2, _3): # pylint: disable=unused-argument
|
|
return ('Bearer', 'top-secret-token-for-you', '_')
|
|
|
|
|
|
def _mock_operation_delay(_):
|
|
# don't run time.sleep()
|
|
return
|
|
|
|
|
|
# TEST CHECKS
|
|
|
|
|
|
class JMESPathCheckAssertionError(AssertionError):
    """AssertionError raised when a JMESPath query result does not satisfy a check.

    The message includes the actual value, the comparator's ``expected_result``, its ``query``
    and the JSON data the query was run against.
    """

    def __init__(self, comparator, actual_result, json_data):
        pieces = [
            "Actual value '{}' != Expected value '{}'. ".format(
                actual_result, comparator.expected_result),
            "Query '{}' used on json data '{}'".format(comparator.query, json_data),
        ]
        super(JMESPathCheckAssertionError, self).__init__(''.join(pieces))
|
|
|
|
|
|
class JMESPathCheck(object):  # pylint: disable=too-few-public-methods
    """Check that a JMESPath query over JSON command output equals an expected value."""

    def __init__(self, query, expected_result):
        self.query = query  # JMESPath expression applied to the command output
        self.expected_result = expected_result  # value the query result must equal

    def compare(self, json_data):
        """Raise ``JMESPathCheckAssertionError`` unless the query result equals the expectation."""
        found = _search_result_by_jmespath(json_data, self.query)
        if found == self.expected_result:
            return
        raise JMESPathCheckAssertionError(self, found, json_data)
|
|
|
|
|
|
class JMESPathPatternCheck(object):  # pylint: disable=too-few-public-methods
    """Check that a JMESPath query result, as a string, matches a regex (case-insensitive).

    ``re.match`` is used, so the pattern is anchored at the start of the string only.
    """

    def __init__(self, query, expected_result):
        self.query = query  # JMESPath expression applied to the command output
        self.expected_result = expected_result  # regular expression the result must match

    def compare(self, json_data):
        """Raise ``JMESPathCheckAssertionError`` when the pattern does not match."""
        found = _search_result_by_jmespath(json_data, self.query)
        matched = re.match(self.expected_result, str(found), re.IGNORECASE)
        if matched is None:
            raise JMESPathCheckAssertionError(self, found, json_data)
|
|
|
|
|
|
class BooleanCheck(object):  # pylint: disable=too-few-public-methods
    """Check that command output, read as a boolean, matches an expected value.

    The output counts as true when ``str(data).lower()`` is one of ``'yes'``, ``'true'`` or
    ``'1'``. The resulting bool and ``expected_result`` are compared as strings.
    """

    def __init__(self, expected_result):
        self.expected_result = expected_result

    def compare(self, data):
        """Raise ``AssertionError`` if the boolean reading of *data* differs from the expectation."""
        result = str(str(data).lower() in ['yes', 'true', '1'])
        # Explicit raise instead of an ``assert`` statement: asserts are stripped under
        # ``python -O``, which would make this check silently pass.
        if result != str(self.expected_result):
            raise AssertionError("Actual value '{}' != Expected value {}".format(
                result, self.expected_result))
|
|
|
|
|
|
class NoneCheck(object):  # pylint: disable=too-few-public-methods
    """Check that command output is empty.

    Output passes when it is falsy (None, '', [] ...) or is one of the strings ``'[]'``,
    ``'{}'`` or ``'false'``.
    """

    def compare(self, data):  # pylint: disable=no-self-use
        """Raise ``AssertionError`` if *data* is neither falsy nor an 'empty' string form."""
        none_strings = ['[]', '{}', 'false']
        # Explicit raise instead of an ``assert`` statement, which ``python -O`` would strip.
        if data and data not in none_strings:
            raise AssertionError("Actual value '{}' != Expected value falsy (None, '', []) or "
                                 "string in {}".format(data, none_strings))
|
|
|
|
|
|
class StringCheck(object):  # pylint: disable=too-few-public-methods
    """Check that string command output, with all double quotes removed, equals a value."""

    def __init__(self, expected_result):
        self.expected_result = expected_result

    def compare(self, data):
        """Raise ``AssertionError`` if *data* (quotes stripped) differs from the expectation.

        *data* must support ``.replace`` (i.e. be a string); other types raise as before.
        """
        result = data.replace('"', '')
        # Explicit raise instead of an ``assert`` statement, which ``python -O`` would strip.
        if result != self.expected_result:
            raise AssertionError("Actual value '{}' != Expected value {}".format(
                data, self.expected_result))
|
|
|
|
|
|
# HELPER METHODS
|
|
|
|
|
|
def _scrub_deployment_name(uri):
|
|
return re.sub('/deployments/([^/?]+)', '/deployments/mock-deployment', uri)
|
|
|
|
|
|
def _search_result_by_jmespath(json_data, query):
    """Parse the JSON text *json_data* and evaluate the JMESPath *query* against it.

    Empty or None input is treated as the empty object ``'{}'``. ``collections.OrderedDict``
    is passed to jmespath as its dict class via ``jmespath.Options``.
    """
    document = json.loads(json_data or '{}')
    search_options = jmespath.Options(collections.OrderedDict)
    return jmespath.search(query, document, search_options)
|
|
|
|
|
|
def _custom_request_matcher(r1, r2):
|
|
""" Ensure method, path, and query parameters match. """
|
|
if r1.method != r2.method:
|
|
return False
|
|
|
|
url1 = urlparse(r1.uri)
|
|
url2 = urlparse(r2.uri)
|
|
|
|
if url1.path != url2.path:
|
|
return False
|
|
|
|
q1 = parse_qs(url1.query)
|
|
q2 = parse_qs(url2.query)
|
|
shared_keys = set(q1.keys()).intersection(set(q2.keys()))
|
|
|
|
if len(shared_keys) != len(q1) or len(shared_keys) != len(q2):
|
|
return False
|
|
|
|
for key in shared_keys:
|
|
if q1[key][0].lower() != q2[key][0].lower():
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
# MAIN CLASS
|
|
|
|
|
|
class VCRTestBase(unittest.TestCase):  # pylint: disable=too-many-instance-attributes
    """Base class for CLI scenario tests backed by VCR cassettes.

    The test has three modes:

    * Run live, when ``run_live`` is set or the environment variable named by
      LIVE_TEST_CONTROL_ENV equals ``'True'``.
    * Play back an existing cassette, when one is present in the ``recordings`` directory.
    * Record a new cassette, otherwise.

    Subclasses implement ``body()`` and may implement ``set_up()`` / ``tear_down()``.
    """

    # Headers deleted from every recorded response before it is stored.
    FILTER_HEADERS = [
        'authorization',
        'client-request-id',
        'x-ms-client-request-id',
        'x-ms-correlation-request-id',
        'x-ms-ratelimit-remaining-subscription-reads',
        'x-ms-request-id',
        'x-ms-routing-request-id',
        'x-ms-gateway-service-instanceid',
        'x-ms-ratelimit-remaining-tenant-reads',
        'x-ms-served-by',
    ]

    # pylint: disable=too-many-arguments
    def __init__(self, test_file, test_name, run_live=False, debug=False, debug_vcr=False,
                 skip_setup=False, skip_teardown=False):
        super(VCRTestBase, self).__init__(test_name)
        self.cli = get_default_cli()
        self.test_name = test_name
        self.recording_dir = os.path.join(os.path.dirname(test_file), 'recordings')
        self.cassette_path = os.path.join(self.recording_dir, '{}.yaml'.format(test_name))
        self.playback = os.path.isfile(self.cassette_path)

        # The environment variable overrides the constructor argument.
        if os.environ.get(LIVE_TEST_CONTROL_ENV, None) == 'True':
            self.run_live = True
        else:
            self.run_live = run_live

        self.skip_setup = skip_setup
        self.skip_teardown = skip_teardown
        self.success = False
        self.exception = None
        self.track_commands = os.environ.get(COMMAND_COVERAGE_CONTROL_ENV, None)
        self._debug = debug

        # Use the effective live flag (including the environment override): a live run does
        # not need a recording.
        if not self.playback and ('--buffer' in sys.argv) and not self.run_live:
            self.exception = CLIError('No recorded result provided for {}.'.format(self.test_name))

        if debug_vcr:
            import logging
            logging.basicConfig()
            vcr_log = logging.getLogger('vcr')
            vcr_log.setLevel(logging.INFO)
        self.my_vcr = vcr.VCR(
            cassette_library_dir=self.recording_dir,
            before_record_request=self._before_record_request,
            before_record_response=self._before_record_response,
            decode_compressed_response=True,
            serializer='json'
        )
        self.my_vcr.register_matcher('custom', _custom_request_matcher)
        self.my_vcr.match_on = ['custom']

    def _track_executed_commands(self, command):
        """Append the space-joined *command* list to the coverage file, if one is configured."""
        if self.track_commands:
            with open(self.track_commands, 'a+') as f:
                f.write(' '.join(command))
                f.write('\n')

    def _before_record_request(self, request):  # pylint: disable=no-self-use
        """VCR hook that scrubs identifying values from the request URI.

        Returns the scrubbed request, or ``None`` so that token requests are not recorded.
        """
        # scrub subscription and tenant ids from the uri
        request.uri = re.sub('/subscriptions/([^/]+)/',
                             '/subscriptions/{}/'.format(MOCKED_SUBSCRIPTION_ID), request.uri)
        request.uri = re.sub('/graph.windows.net/([^/]+)/',
                             '/graph.windows.net/{}/'.format(MOCKED_TENANT_ID), request.uri)
        request.uri = re.sub('/sig=([^/]+)&', '/sig=0000&', request.uri)
        request.uri = _scrub_deployment_name(request.uri)

        # replace random storage account name with dummy name
        request.uri = re.sub(r'(vcrstorage[\d]+)', MOCKED_STORAGE_ACCOUNT, request.uri)
        # prevents URI mismatch between Python 2 and 3 if request URI has extra / chars:
        # collapse every '//' and then restore the first one (after the scheme)
        request.uri = re.sub('//', '/', request.uri)
        request.uri = re.sub('/', '//', request.uri, count=1)
        # do not record requests sent for token refresh
        # NOTE(review): OAuth2 form bodies normally spell the field 'grant_type' (underscore), so
        # the 'grant-type' body test may never match; the '/oauth2/token' URI test then does the
        # filtering -- confirm before relying on the body check.
        if (request.body and 'grant-type=refresh_token' in str(request.body)) or \
                ('/oauth2/token' in request.uri):
            request = None
        return request

    def _before_record_response(self, response):  # pylint: disable=no-self-use
        """VCR hook that drops FILTER_HEADERS and scrubs subscription ids from the body."""
        for key in VCRTestBase.FILTER_HEADERS:
            if key in response['headers']:
                del response['headers'][key]

        subscription_replacement = '/subscriptions/{}/'.format(MOCKED_SUBSCRIPTION_ID)
        for key in response['body']:
            value = response['body'][key].decode('utf-8')
            value = re.sub('/subscriptions/([^/]+)/', subscription_replacement, value)
            # .encode yields bytes on Python 3 and a byte string on Python 2
            response['body'][key] = value.encode('utf-8')
        return response

    @mock.patch('azure.cli.core.util.handle_exception', _mock_handle_exceptions)
    @mock.patch('azure.cli.core.commands.client_factory._get_mgmt_service_client',
                _mock_get_mgmt_service_client)  # pylint: disable=line-too-long
    def _execute_live_or_recording(self):
        """Run set_up, body and tear_down against real services.

        The body is recorded into the cassette unless running live. tear_down runs even when
        set_up or the body fails. Exceptions propagate unchanged, with their original traceback.
        """
        # pylint: disable=no-member
        try:
            set_up = getattr(self, "set_up", None)
            if callable(set_up) and not self.skip_setup:
                self.set_up()

            if self.run_live:
                self.body()
            else:
                with self.my_vcr.use_cassette(self.cassette_path):
                    self.body()
            self.success = True
        finally:
            tear_down = getattr(self, "tear_down", None)
            if callable(tear_down) and not self.skip_teardown:
                self.tear_down()

    @mock.patch('azure.cli.core._profile.Profile.load_cached_subscriptions', _mock_subscriptions)
    @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user',
                _mock_user_access_token)  # pylint: disable=line-too-long
    @mock.patch('azure.cli.core.util.handle_exception', _mock_handle_exceptions)
    @mock.patch('azure.cli.core.commands.client_factory._get_mgmt_service_client',
                _mock_get_mgmt_service_client)  # pylint: disable=line-too-long
    @mock.patch('msrestazure.azure_operation.AzureOperationPoller._delay', _mock_operation_delay)
    @mock.patch('time.sleep', _mock_operation_delay)
    @mock.patch('azure.cli.core.commands.LongRunningOperation._delay', _mock_operation_delay)
    @mock.patch('azure.cli.core.commands.validators.generate_deployment_name',
                _mock_generate_deployment_name)
    def _execute_playback(self):
        """Replay the body from the cassette, with credentials, delays and naming mocked out."""
        # pylint: disable=no-member
        with self.my_vcr.use_cassette(self.cassette_path):
            self.body()
        self.success = True

    def _post_recording_scrub(self):
        """ Perform post-recording cleanup on the YAML file that can't be accomplished with the
        VCR recording hooks.

        The cleanup restores the original resource-group name in place of the randomized one
        and drops lines containing 'authorization:'.
        """
        src_path = self.cassette_path
        rg_name = getattr(self, 'resource_group', None)
        rg_original = getattr(self, 'resource_group_original', None)
        # Substitute only when both names are known and differ. str.replace with None raises,
        # and with an empty search string it would rewrite every position of every line.
        scrub_rg = bool(rg_name) and bool(rg_original) and rg_name != rg_original

        # The context manager guarantees the temporary file is closed even if I/O fails.
        with tempfile.NamedTemporaryFile('r+') as t:
            with open(src_path, 'r') as f:
                for line in f:
                    if scrub_rg:
                        line = line.replace(rg_name, rg_original)
                    # omit bearer tokens
                    # NOTE(review): the VCR is configured with serializer='json', where a header
                    # key would appear as '"authorization":' -- this substring test may not
                    # match it; confirm tokens are really removed from recordings.
                    if 'authorization:' not in line.lower():
                        t.write(line)
            t.seek(0)
            with open(src_path, 'w') as f:
                f.writelines(t)

    # COMMAND METHODS

    def cmd(self, command, checks=None, allowed_exceptions=None,
            debug=False):  # pylint: disable=no-self-use
        """Invoke a CLI command, run the given checks on its result, and return the result.

        :param command: command line string, split with ``shlex.split``.
        :param checks: one check object or a list of them; each gets ``compare(result)``.
        :param allowed_exceptions: a substring or list of substrings. An exception whose message
            contains any of them is swallowed, and ``result`` is then ``None``.
        :param debug: print the command and its result (also enabled by the instance's debug flag).
        :returns: the raw result when the command contains both '-o' and 'tsv'; otherwise the
            result parsed as JSON (an empty result becomes ``{}``), or the raw result when it
            cannot be parsed.
        """
        allowed_exceptions = allowed_exceptions or []
        if not isinstance(allowed_exceptions, list):
            allowed_exceptions = [allowed_exceptions]

        if self._debug or debug:
            print('\n\tRUNNING: {}'.format(command))
        command_list = shlex.split(command)
        # Defined up front so a swallowed (allowed) exception does not leave it unbound.
        result = None
        try:
            result = self.cli.invoke(command_list)
        except Exception as ex:  # pylint: disable=broad-except
            ex_msg = str(ex)
            if not any(allowed in ex_msg for allowed in allowed_exceptions):
                raise  # bare raise keeps the original traceback (also on Python 2)
        self._track_executed_commands(command_list)

        if self._debug or debug:
            print('\tRESULT: {}\n'.format(result))

        if checks:
            checks = [checks] if not isinstance(checks, list) else checks
            for check in checks:
                check.compare(result)

        if '-o' in command_list and 'tsv' in command_list:
            return result
        try:
            return json.loads(result or '{}')
        except Exception:  # pylint: disable=broad-except
            return result

    def set_env(self, key, val):  # pylint: disable=no-self-use
        """Set environment variable *key* to *val* for the current process."""
        os.environ[key] = val

    def pop_env(self, key):  # pylint: disable=no-self-use
        """Remove environment variable *key*, returning its old value or ``None``."""
        return os.environ.pop(key, None)

    def execute(self):
        ''' Method to actually start execution of the test. Must be called from the test_<name>
        method of the test class.

        After a failed recording, the partial cassette is deleted. After a successful recording,
        the cassette is scrubbed; it is deleted if scrubbing fails. '''
        try:
            if self.run_live:
                print('RUN LIVE: {}'.format(self.test_name))
                self._execute_live_or_recording()
            elif self.playback:
                print('PLAYBACK: {}'.format(self.test_name))
                self._execute_playback()
            else:
                print('RECORDING: {}'.format(self.test_name))
                self._execute_live_or_recording()
        finally:
            if not self.success and not self.playback and os.path.isfile(self.cassette_path):
                print('DISCARDING RECORDING: {}'.format(self.cassette_path))
                os.remove(self.cassette_path)
            elif self.success and not self.playback and os.path.isfile(self.cassette_path):
                try:
                    self._post_recording_scrub()
                except Exception:  # pylint: disable=broad-except
                    # best effort: never keep a cassette that could not be scrubbed
                    os.remove(self.cassette_path)
|
|
|
|
|
|
class ResourceGroupVCRTestBase(VCRTestBase):
    """VCR test base that creates a resource group in set_up and deletes it in tear_down.

    Outside playback, the group name gets a random ``_xxxx_`` suffix of four lowercase
    letters/digits. During playback (an existing cassette) the unsuffixed name is used.
    """
    # pylint: disable=too-many-arguments

    def __init__(self, test_file, test_name, resource_group='vcr_resource_group', run_live=False,
                 debug=False, debug_vcr=False, skip_setup=False, skip_teardown=False):
        super(ResourceGroupVCRTestBase, self).__init__(
            test_file, test_name, run_live=run_live, debug=debug, debug_vcr=debug_vcr,
            skip_setup=skip_setup, skip_teardown=skip_teardown)
        self.resource_group_original = resource_group
        alphabet = ascii_lowercase + digits
        suffix = '_{}_'.format(''.join(choice(alphabet) for _ in range(4)))
        if self.playback:
            suffix = ''
        self.resource_group = '{}{}'.format(resource_group, suffix)
        self.location = 'westus'

    def set_up(self):
        """Create the test resource group, tagged ``use=az-test``."""
        create_cmd = 'group create --location {} --name {} --tags use=az-test'
        self.cmd(create_cmd.format(self.location, self.resource_group))

    def tear_down(self):
        """Start deleting the test resource group without waiting for completion."""
        delete_cmd = 'group delete --name {} --no-wait --yes'
        self.cmd(delete_cmd.format(self.resource_group))
|
|
|
|
|
|
class StorageAccountVCRTestBase(VCRTestBase):
    """VCR test base that provisions a resource group plus a storage account around the test.

    During playback the plain group name and MOCKED_STORAGE_ACCOUNT are used. Otherwise the
    group gets a random suffix and the account a random ``vcrstorage<12 digits>`` name.
    """
    account_location = 'westus'
    account_sku = 'Standard_LRS'

    # pylint: disable=too-many-arguments
    def __init__(self, test_file, test_name, resource_group='vcr_resource_group', run_live=False,
                 debug=False, debug_vcr=False, skip_setup=False, skip_teardown=False):
        super(StorageAccountVCRTestBase, self).__init__(
            test_file, test_name, run_live=run_live, debug=debug, debug_vcr=debug_vcr,
            skip_setup=skip_setup, skip_teardown=skip_teardown)
        self.resource_group_original = resource_group
        if self.playback:
            self.resource_group = '{}{}'.format(resource_group, '')
            self.account = MOCKED_STORAGE_ACCOUNT
        else:
            self.resource_group = '{}{}'.format(resource_group, self.generate_random_tag())
            self.account = self.generate_account_name()

    def set_up(self):
        """Create the resource group, then the storage account inside it."""
        location = self.account_location
        group = self.resource_group
        self.cmd('group create --location {} --name {} --tags use=az-test'.format(
            location, group))
        self.cmd('storage account create --sku {} -l {} -n {} -g {}'.format(
            self.account_sku, location, self.account, group))

    def tear_down(self):
        """Delete the storage account, then start deleting the resource group (no wait)."""
        group = self.resource_group
        self.cmd('storage account delete -g {} -n {} --yes'.format(group, self.account))
        self.cmd('group delete --name {} --no-wait --yes'.format(group))

    @classmethod
    def generate_account_name(cls):
        """Return a random account name: 'vcrstorage' followed by 12 random digits."""
        number_part = random_string(12, digits_only=True)
        return 'vcrstorage{}'.format(number_part)

    @classmethod
    def generate_random_tag(cls):
        """Return a ``_xxxx_`` tag built from four random lowercase characters."""
        tag_body = random_string(4, force_lower=True)
        return '_{}_'.format(tag_body)
|