[V3] SAS Auth Revisions (#1103)
* Restructured SAS-related objects and their responsibilities * Removed references to "renewing" a SAS token due to bad semantics * Keeping a valid SAS token is no longer the domain of the IoTHubMQTTClient * Made signing data with a SigningMechanism asynchronous (in signature) - actual async implementation will come later.
This commit is contained in:
Родитель
bb4cb90295
Коммит
a9c20135dc
|
@ -10,7 +10,9 @@ repos:
|
|||
- id: flake8
|
||||
args: ['--config=.flake8']
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v0.991
|
||||
rev: v1.0.0
|
||||
# NOTE: Azure SDK guidelines say to use 0.931 as a pin, but it seems bugged.
|
||||
# Getting some clarity on this, but for now, we're going to just use the most recent
|
||||
hooks:
|
||||
- id: mypy
|
||||
files: v3_async_wip/ # for now only the new v3 files have typings
|
||||
|
|
|
@ -19,8 +19,11 @@ class HangingAsyncMock(mock.AsyncMock):
|
|||
self._stop_hanging = asyncio.Event()
|
||||
|
||||
async def _do_hang(self, *args, **kwargs):
|
||||
self._is_hanging.set()
|
||||
if not self._is_hanging.is_set():
|
||||
self._stop_hanging.clear()
|
||||
self._is_hanging.set()
|
||||
await self._stop_hanging.wait()
|
||||
return self.return_value
|
||||
|
||||
async def wait_for_hang(self):
|
||||
await self._is_hanging.wait()
|
||||
|
@ -29,4 +32,8 @@ class HangingAsyncMock(mock.AsyncMock):
|
|||
return self._is_hanging.is_set()
|
||||
|
||||
def stop_hanging(self):
|
||||
self._stop_hanging.set()
|
||||
if self._is_hanging.is_set():
|
||||
self._stop_hanging.set()
|
||||
self._is_hanging.clear()
|
||||
else:
|
||||
raise RuntimeError("Not hanging")
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
[mypy]
|
||||
show_error_codes = True
|
|
@ -14,5 +14,5 @@ def arbitrary_exception():
|
|||
class ArbitraryException(Exception):
|
||||
pass
|
||||
|
||||
e = ArbitraryException()
|
||||
e = ArbitraryException("arbitrary description")
|
||||
return e
|
||||
|
|
|
@ -109,14 +109,14 @@ class TestIoTEdgeHsmGetCertificate(object):
|
|||
return mocker.patch.object(requests, "get")
|
||||
|
||||
@pytest.mark.it("Sends an HTTP GET request to retrieve the trust bundle from Edge")
|
||||
def test_requests_trust_bundle(self, mocker, edge_hsm, mock_requests_get):
|
||||
async def test_requests_trust_bundle(self, mocker, edge_hsm, mock_requests_get):
|
||||
expected_url = edge_hsm.workload_uri + "trust-bundle"
|
||||
expected_params = {"api-version": edge_hsm.api_version}
|
||||
expected_headers = {
|
||||
"User-Agent": urllib.parse.quote_plus(user_agent.get_iothub_user_agent())
|
||||
}
|
||||
|
||||
edge_hsm.get_certificate()
|
||||
await edge_hsm.get_certificate()
|
||||
|
||||
assert mock_requests_get.call_count == 1
|
||||
assert mock_requests_get.call_args == mocker.call(
|
||||
|
@ -124,43 +124,43 @@ class TestIoTEdgeHsmGetCertificate(object):
|
|||
)
|
||||
|
||||
@pytest.mark.it("Returns the certificate from the trust bundle received from Edge")
|
||||
def test_returns_certificate(self, edge_hsm, mock_requests_get):
|
||||
async def test_returns_certificate(self, edge_hsm, mock_requests_get):
|
||||
mock_response = mock_requests_get.return_value
|
||||
certificate = "my certificate"
|
||||
mock_response.json.return_value = {"certificate": certificate}
|
||||
|
||||
returned_cert = edge_hsm.get_certificate()
|
||||
returned_cert = await edge_hsm.get_certificate()
|
||||
|
||||
assert returned_cert is certificate
|
||||
|
||||
@pytest.mark.it("Raises IoTEdgeError if a bad request is made to Edge")
|
||||
def test_bad_request(self, edge_hsm, mock_requests_get):
|
||||
async def test_bad_request(self, edge_hsm, mock_requests_get):
|
||||
mock_response = mock_requests_get.return_value
|
||||
error = requests.exceptions.HTTPError()
|
||||
mock_response.raise_for_status.side_effect = error
|
||||
|
||||
with pytest.raises(IoTEdgeError) as e_info:
|
||||
edge_hsm.get_certificate()
|
||||
await edge_hsm.get_certificate()
|
||||
assert e_info.value.__cause__ is error
|
||||
|
||||
@pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the trust bundle")
|
||||
def test_bad_json(self, edge_hsm, mock_requests_get):
|
||||
async def test_bad_json(self, edge_hsm, mock_requests_get):
|
||||
mock_response = mock_requests_get.return_value
|
||||
error = ValueError()
|
||||
mock_response.json.side_effect = error
|
||||
|
||||
with pytest.raises(IoTEdgeError) as e_info:
|
||||
edge_hsm.get_certificate()
|
||||
await edge_hsm.get_certificate()
|
||||
assert e_info.value.__cause__ is error
|
||||
|
||||
@pytest.mark.it("Raises IoTEdgeError if the certificate is missing from the trust bundle")
|
||||
def test_bad_trust_bundle(self, edge_hsm, mock_requests_get):
|
||||
async def test_bad_trust_bundle(self, edge_hsm, mock_requests_get):
|
||||
mock_response = mock_requests_get.return_value
|
||||
# Return an empty json dict with no 'certificate' key
|
||||
mock_response.json.return_value = {}
|
||||
|
||||
with pytest.raises(IoTEdgeError):
|
||||
edge_hsm.get_certificate()
|
||||
await edge_hsm.get_certificate()
|
||||
|
||||
|
||||
@pytest.mark.describe("IoTEdgeHsm - .sign()")
|
||||
|
@ -172,7 +172,7 @@ class TestIoTEdgeHsmSign(object):
|
|||
@pytest.mark.it(
|
||||
"Makes an HTTP request to Edge to sign a piece of string data using the HMAC-SHA256 algorithm"
|
||||
)
|
||||
def test_requests_data_signing(self, mocker, edge_hsm, mock_requests_post):
|
||||
async def test_requests_data_signing(self, mocker, edge_hsm, mock_requests_post):
|
||||
data_str = "somedata"
|
||||
data_str_b64 = "c29tZWRhdGE="
|
||||
mock_requests_post.return_value.json.return_value = {"digest": "somedigest"}
|
||||
|
@ -187,7 +187,7 @@ class TestIoTEdgeHsmSign(object):
|
|||
}
|
||||
expected_json = json.dumps({"keyId": "primary", "algo": "HMACSHA256", "data": data_str_b64})
|
||||
|
||||
edge_hsm.sign(data_str)
|
||||
await edge_hsm.sign(data_str)
|
||||
|
||||
assert mock_requests_post.call_count == 1
|
||||
assert mock_requests_post.call_args == mocker.call(
|
||||
|
@ -195,14 +195,14 @@ class TestIoTEdgeHsmSign(object):
|
|||
)
|
||||
|
||||
@pytest.mark.it("Base64 encodes the string data in the request")
|
||||
def test_b64_encodes_data(self, edge_hsm, mock_requests_post):
|
||||
async def test_b64_encodes_data(self, edge_hsm, mock_requests_post):
|
||||
# This test is actually implicitly tested in the first test, but it's
|
||||
# important to have an explicit test for it since it's a requirement
|
||||
data_str = "somedata"
|
||||
data_str_b64 = base64.b64encode(data_str.encode("utf-8")).decode()
|
||||
mock_requests_post.return_value.json.return_value = {"digest": "somedigest"}
|
||||
|
||||
edge_hsm.sign(data_str)
|
||||
await edge_hsm.sign(data_str)
|
||||
|
||||
sent_data = json.loads(mock_requests_post.call_args[1]["data"])["data"]
|
||||
|
||||
|
@ -210,11 +210,11 @@ class TestIoTEdgeHsmSign(object):
|
|||
assert sent_data == data_str_b64
|
||||
|
||||
@pytest.mark.it("Returns the signed data received from Edge")
|
||||
def test_returns_signed_data(self, edge_hsm, mock_requests_post):
|
||||
async def test_returns_signed_data(self, edge_hsm, mock_requests_post):
|
||||
expected_digest = "somedigest"
|
||||
mock_requests_post.return_value.json.return_value = {"digest": expected_digest}
|
||||
|
||||
signed_data = edge_hsm.sign("somedata")
|
||||
signed_data = await edge_hsm.sign("somedata")
|
||||
|
||||
assert signed_data == expected_digest
|
||||
|
||||
|
@ -226,38 +226,38 @@ class TestIoTEdgeHsmSign(object):
|
|||
pytest.param(b"sign this message", "c2lnbiB0aGlzIG1lc3NhZ2U=", id="Bytes"),
|
||||
],
|
||||
)
|
||||
def test_supported_types(
|
||||
async def test_supported_types(
|
||||
self, edge_hsm, data_string, expected_request_data, mock_requests_post
|
||||
):
|
||||
mock_requests_post.return_value.json.return_value = {"digest": "somedigest"}
|
||||
edge_hsm.sign(data_string)
|
||||
await edge_hsm.sign(data_string)
|
||||
sent_data = json.loads(mock_requests_post.call_args[1]["data"])["data"]
|
||||
|
||||
assert sent_data == expected_request_data
|
||||
|
||||
@pytest.mark.it("Raises IoTEdgeError if a bad request is made to EdgeHub")
|
||||
def test_bad_request(self, edge_hsm, mock_requests_post):
|
||||
async def test_bad_request(self, edge_hsm, mock_requests_post):
|
||||
mock_response = mock_requests_post.return_value
|
||||
error = requests.exceptions.HTTPError()
|
||||
mock_response.raise_for_status.side_effect = error
|
||||
|
||||
with pytest.raises(IoTEdgeError) as e_info:
|
||||
edge_hsm.sign("somedata")
|
||||
await edge_hsm.sign("somedata")
|
||||
assert e_info.value.__cause__ is error
|
||||
|
||||
@pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the signed response")
|
||||
def test_bad_json(self, edge_hsm, mock_requests_post):
|
||||
async def test_bad_json(self, edge_hsm, mock_requests_post):
|
||||
mock_response = mock_requests_post.return_value
|
||||
error = ValueError()
|
||||
mock_response.json.side_effect = error
|
||||
with pytest.raises(IoTEdgeError) as e_info:
|
||||
edge_hsm.sign("somedata")
|
||||
await edge_hsm.sign("somedata")
|
||||
assert e_info.value.__cause__ is error
|
||||
|
||||
@pytest.mark.it("Raises IoTEdgeError if the signed data is missing from the response")
|
||||
def test_bad_response(self, edge_hsm, mock_requests_post):
|
||||
async def test_bad_response(self, edge_hsm, mock_requests_post):
|
||||
mock_response = mock_requests_post.return_value
|
||||
mock_response.json.return_value = {}
|
||||
|
||||
with pytest.raises(IoTEdgeError):
|
||||
edge_hsm.sign("somedata")
|
||||
await edge_hsm.sign("somedata")
|
||||
|
|
|
@ -8,6 +8,7 @@ import asyncio
|
|||
import json
|
||||
import pytest
|
||||
import ssl
|
||||
import sys
|
||||
import time
|
||||
import typing
|
||||
import urllib
|
||||
|
@ -22,12 +23,13 @@ from v3_async_wip import config, constant, models, user_agent
|
|||
from v3_async_wip import mqtt_client as mqtt
|
||||
from v3_async_wip import request_response as rr
|
||||
from v3_async_wip import mqtt_topic_iothub as mqtt_topic
|
||||
from azure.iot.device.common.auth import sastoken as st
|
||||
from azure.iot.device.common import alarm
|
||||
from v3_async_wip import sastoken as st
|
||||
|
||||
|
||||
FAKE_DEVICE_ID = "fake_device_id"
|
||||
FAKE_MODULE_ID = "fake_module_id"
|
||||
FAKE_DEVICE_CLIENT_ID = "fake_device_id"
|
||||
FAKE_MODULE_CLIENT_ID = "fake_device_id/fake_module_id"
|
||||
FAKE_HOSTNAME = "fake.hostname"
|
||||
FAKE_GATEWAY_HOSTNAME = "fake.gateway.hostname"
|
||||
FAKE_SIGNATURE = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI="
|
||||
|
@ -71,21 +73,24 @@ mqtt_unsubscribe_exceptions = [
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def renewable_sastoken(mocker):
|
||||
mock_signing_mechanism = mocker.MagicMock()
|
||||
mock_signing_mechanism.sign.return_value = FAKE_SIGNATURE
|
||||
sastoken = st.RenewableSasToken(uri=FAKE_URI, signing_mechanism=mock_signing_mechanism)
|
||||
# sastoken.refresh = mocker.MagicMock()
|
||||
return sastoken
|
||||
def sastoken():
|
||||
sastoken_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format(
|
||||
resource=FAKE_URI, signature=FAKE_SIGNATURE, expiry=FAKE_EXPIRY
|
||||
)
|
||||
return st.SasToken(sastoken_str)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nonrenewable_sastoken():
|
||||
token_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format(
|
||||
resource=FAKE_URI, signature=FAKE_SIGNATURE, expiry=FAKE_EXPIRY
|
||||
)
|
||||
sastoken = st.NonRenewableSasToken(token_str)
|
||||
return sastoken
|
||||
def mock_sastoken_provider(mocker, sastoken):
|
||||
provider = mocker.MagicMock(spec=st.SasTokenProvider)
|
||||
provider.get_current_sastoken.return_value = sastoken
|
||||
# Use a HangingAsyncMock so that it isn't constantly returning
|
||||
provider.wait_for_new_sastoken = custom_mock.HangingAsyncMock()
|
||||
provider.wait_for_new_sastoken.return_value = sastoken
|
||||
# NOTE: Technically, this mock just always returns the same SasToken,
|
||||
# even after an "update", but for the purposes of testing at this level,
|
||||
# it doesn't matter
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -109,6 +114,8 @@ async def client(mocker, client_config):
|
|||
client._mqtt_client.subscribe = mocker.AsyncMock()
|
||||
client._mqtt_client.unsubscribe = mocker.AsyncMock()
|
||||
client._mqtt_client.publish = mocker.AsyncMock()
|
||||
# Also mock the set credentials method since we test that
|
||||
client._mqtt_client.set_credentials = mocker.MagicMock()
|
||||
|
||||
yield client
|
||||
await client.shutdown()
|
||||
|
@ -130,7 +137,7 @@ class TestIoTHubMQTTClientInstantiation:
|
|||
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
|
||||
],
|
||||
)
|
||||
async def test_ids(self, client_config, device_id, module_id):
|
||||
async def test_simple_ids(self, client_config, device_id, module_id):
|
||||
client_config.device_id = device_id
|
||||
client_config.module_id = module_id
|
||||
|
||||
|
@ -139,69 +146,141 @@ class TestIoTHubMQTTClientInstantiation:
|
|||
assert client._module_id == client_config.module_id
|
||||
await client.shutdown()
|
||||
|
||||
@pytest.mark.it("Stores the SasToken value from the IoTHubClientConfig as an attribute")
|
||||
@pytest.mark.parametrize(
|
||||
"sastoken",
|
||||
[
|
||||
pytest.param(lazy_fixture("renewable_sastoken"), id="Renewable SAS Authentication"),
|
||||
pytest.param(
|
||||
lazy_fixture("nonrenewable_sastoken"), id="Non-Renewable SAS Authentication"
|
||||
),
|
||||
pytest.param(None, id="Non-SAS Authentication"),
|
||||
],
|
||||
)
|
||||
async def test_sastoken(self, client_config, sastoken):
|
||||
client_config.sastoken = sastoken
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
assert client._sastoken is client_config.sastoken
|
||||
await client.shutdown()
|
||||
|
||||
# NOTE: For testing the functionality of this update Alarm, see the corresponding test suite
|
||||
@pytest.mark.it("Starts the SasToken update Alarm if the IoTHubClientConfig has a SasToken")
|
||||
@pytest.mark.parametrize(
|
||||
"sastoken",
|
||||
[
|
||||
pytest.param(lazy_fixture("renewable_sastoken"), id="Renewable SAS Authentication"),
|
||||
pytest.param(
|
||||
lazy_fixture("nonrenewable_sastoken"), id="Non-Renewable SAS Authentication"
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_sastoken_update_task(self, client_config, sastoken):
|
||||
client_config.sastoken = sastoken
|
||||
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
assert isinstance(client._sastoken_update_alarm, alarm.Alarm)
|
||||
assert client._sastoken_update_alarm.is_alive()
|
||||
assert client._sastoken_update_alarm.daemon is True
|
||||
await client.shutdown()
|
||||
|
||||
@pytest.mark.it("Does not start an Alarm if the IoTHubClientConfig does not have a SasToken")
|
||||
async def test_no_sastoken_update_task(self, client_config):
|
||||
assert client_config.sastoken is None
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
assert client._sastoken_update_alarm is None
|
||||
await client.shutdown()
|
||||
|
||||
@pytest.mark.it("Creates an empty RequestLedger")
|
||||
async def test_request_ledger(self, client_config):
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
assert isinstance(client._request_ledger, rr.RequestLedger)
|
||||
assert len(client._request_ledger) == 0
|
||||
await client.shutdown()
|
||||
|
||||
@pytest.mark.it(
|
||||
"Creates an MQTTClient instance based on the configuration of IoTHubClientConfig"
|
||||
"Derives the `client_id` from the `device_id` and `module_id` and stores it as an attribute"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"device_id, module_id, expected_client_id",
|
||||
[
|
||||
pytest.param(FAKE_DEVICE_ID, None, FAKE_DEVICE_ID, id="Device Configuration"),
|
||||
pytest.param(FAKE_DEVICE_ID, None, FAKE_DEVICE_CLIENT_ID, id="Device Configuration"),
|
||||
pytest.param(
|
||||
FAKE_DEVICE_ID,
|
||||
FAKE_MODULE_ID,
|
||||
"{}/{}".format(FAKE_DEVICE_ID, FAKE_MODULE_ID),
|
||||
id="Module Configuration",
|
||||
FAKE_DEVICE_ID, FAKE_MODULE_ID, FAKE_MODULE_CLIENT_ID, id="Module Configuration"
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_client_id(self, client_config, device_id, module_id, expected_client_id):
|
||||
client_config.device_id = device_id
|
||||
client_config.module_id = module_id
|
||||
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
assert client._client_id == expected_client_id
|
||||
await client.shutdown()
|
||||
|
||||
@pytest.mark.it("Derives the `username` and stores the result as an attribute")
|
||||
@pytest.mark.parametrize(
|
||||
"device_id, module_id, client_id",
|
||||
[
|
||||
pytest.param(FAKE_DEVICE_ID, None, FAKE_DEVICE_CLIENT_ID, id="Device Configuration"),
|
||||
pytest.param(
|
||||
FAKE_DEVICE_ID, FAKE_MODULE_ID, FAKE_MODULE_CLIENT_ID, id="Module Configuration"
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"hostname, gateway_hostname",
|
||||
[
|
||||
pytest.param(FAKE_HOSTNAME, None, id="No Gateway Hostname"),
|
||||
pytest.param(FAKE_HOSTNAME, FAKE_GATEWAY_HOSTNAME, id="Gateway Hostname"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"product_info",
|
||||
[
|
||||
pytest.param("", id="No Product Info"),
|
||||
pytest.param("my-product-info", id="Custom Product Info"),
|
||||
pytest.param("my$product$info", id="Custom Product Info (URL encoding required)"),
|
||||
pytest.param(
|
||||
constant.DIGITAL_TWIN_PREFIX + ":com:example:ClimateSensor;1",
|
||||
id="Digital Twin Product Info",
|
||||
),
|
||||
pytest.param(
|
||||
constant.DIGITAL_TWIN_PREFIX + ":com:example:$Climate$ensor;1",
|
||||
id="Digital Twin Product Info (URL encoding required)",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_username(
|
||||
self,
|
||||
client_config,
|
||||
device_id,
|
||||
module_id,
|
||||
client_id,
|
||||
hostname,
|
||||
gateway_hostname,
|
||||
product_info,
|
||||
):
|
||||
client_config.device_id = device_id
|
||||
client_config.module_id = module_id
|
||||
client_config.hostname = hostname
|
||||
client_config.gateway_hostname = gateway_hostname
|
||||
client_config.product_info = product_info
|
||||
|
||||
ua = user_agent.get_iothub_user_agent()
|
||||
url_encoded_user_agent = urllib.parse.quote(ua, safe="")
|
||||
# NOTE: This assertion shows the URL encoding was meaningful
|
||||
assert user_agent != url_encoded_user_agent
|
||||
|
||||
url_encoded_product_info = urllib.parse.quote(product_info, safe="")
|
||||
# NOTE: We can't really make the same assertion here, because this isn't always meaningful
|
||||
|
||||
# Determine expected username based on config
|
||||
if product_info.startswith(constant.DIGITAL_TWIN_PREFIX):
|
||||
expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}&{digital_twin_prefix}={custom_product_info}".format(
|
||||
hostname=hostname,
|
||||
client_id=client_id,
|
||||
api_version=constant.IOTHUB_API_VERSION,
|
||||
user_agent=url_encoded_user_agent,
|
||||
digital_twin_prefix=constant.DIGITAL_TWIN_QUERY_HEADER,
|
||||
custom_product_info=url_encoded_product_info,
|
||||
)
|
||||
else:
|
||||
expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}{custom_product_info}".format(
|
||||
hostname=hostname,
|
||||
client_id=client_id,
|
||||
api_version=constant.IOTHUB_API_VERSION,
|
||||
user_agent=url_encoded_user_agent,
|
||||
custom_product_info=url_encoded_product_info,
|
||||
)
|
||||
# NOTE: Regarding the above, no matter if we have a gateway hostname set or not, it is the hostname that is always used.
|
||||
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
# The expected username was derived
|
||||
assert client._username == expected_username
|
||||
await client.shutdown()
|
||||
|
||||
@pytest.mark.it("Stores the `sastoken_provider` from the IoTHubClientConfig as an attribute")
|
||||
@pytest.mark.parametrize(
|
||||
"sastoken_provider",
|
||||
[
|
||||
pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"),
|
||||
pytest.param(None, id="No SasTokenProvider present"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"device_id, module_id",
|
||||
[
|
||||
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
|
||||
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
|
||||
],
|
||||
)
|
||||
async def test_sastoken_provider(self, client_config, sastoken_provider, device_id, module_id):
|
||||
client_config.device_id = device_id
|
||||
client_config.module_id = module_id
|
||||
client_config.sastoken_provider = sastoken_provider
|
||||
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
assert client._sastoken_provider is sastoken_provider
|
||||
await client.shutdown()
|
||||
|
||||
@pytest.mark.it(
|
||||
"Creates an MQTTClient instance based on the configuration of IoTHubClientConfig and stores it as an attribute"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"device_id, module_id, expected_client_id",
|
||||
[
|
||||
pytest.param(FAKE_DEVICE_ID, None, FAKE_DEVICE_CLIENT_ID, id="Device Configuration"),
|
||||
pytest.param(
|
||||
FAKE_DEVICE_ID, FAKE_MODULE_ID, FAKE_MODULE_CLIENT_ID, id="Module Configuration"
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -268,108 +347,58 @@ class TestIoTHubMQTTClientInstantiation:
|
|||
# Graceful exit
|
||||
await client.shutdown()
|
||||
|
||||
@pytest.mark.it("Sets credentials on the newly created MQTTClient instance")
|
||||
@pytest.mark.parametrize(
|
||||
"device_id, module_id, expected_client_id",
|
||||
[
|
||||
pytest.param(FAKE_DEVICE_ID, None, FAKE_DEVICE_ID, id="Device Configuration"),
|
||||
pytest.param(
|
||||
FAKE_DEVICE_ID,
|
||||
FAKE_MODULE_ID,
|
||||
FAKE_DEVICE_ID + "/" + FAKE_MODULE_ID,
|
||||
id="Module Configuration",
|
||||
),
|
||||
],
|
||||
@pytest.mark.it(
|
||||
"Uses the derived `username` as the username, with no password as the credentials for the newly created MQTTClient instance when not using SAS authentication"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"hostname, gateway_hostname, expected_hostname",
|
||||
"device_id, module_id",
|
||||
[
|
||||
pytest.param(FAKE_HOSTNAME, None, FAKE_HOSTNAME, id="No Gateway Hostname"),
|
||||
pytest.param(FAKE_HOSTNAME, FAKE_GATEWAY_HOSTNAME, FAKE_HOSTNAME, id="Gateway Hostname")
|
||||
# NOTE: Yes, that's right, we expect to always use the hostname, never the gateway hostname
|
||||
# at least, when it comes to credentials
|
||||
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
|
||||
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"sastoken",
|
||||
[
|
||||
pytest.param(lazy_fixture("renewable_sastoken"), id="Renewable SAS Authentication"),
|
||||
pytest.param(
|
||||
lazy_fixture("nonrenewable_sastoken"), id="Non-Renewable SAS Authentication"
|
||||
),
|
||||
pytest.param(None, id="Non-SAS Authentication"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"product_info",
|
||||
[
|
||||
pytest.param("", id="No Product Info"),
|
||||
pytest.param("my-product-info", id="Custom Product Info"),
|
||||
pytest.param("my$product$info", id="Custom Product Info (URL encoding required)"),
|
||||
pytest.param(
|
||||
constant.DIGITAL_TWIN_PREFIX + ":com:example:ClimateSensor;1",
|
||||
id="Digital Twin Product Info",
|
||||
),
|
||||
pytest.param(
|
||||
constant.DIGITAL_TWIN_PREFIX + ":com:example:$Climate$ensor;1",
|
||||
id="Digital Twin Product Info (URL encoding required)",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_mqtt_client_credentials(
|
||||
self,
|
||||
mocker,
|
||||
client_config,
|
||||
sastoken,
|
||||
device_id,
|
||||
module_id,
|
||||
expected_client_id,
|
||||
hostname,
|
||||
gateway_hostname,
|
||||
expected_hostname,
|
||||
product_info,
|
||||
async def test_mqtt_client_credentials_no_sas(
|
||||
self, mocker, client_config, device_id, module_id
|
||||
):
|
||||
client_config.device_id = device_id
|
||||
client_config.module_id = module_id
|
||||
client_config.hostname = hostname
|
||||
client_config.gateway_hostname = gateway_hostname
|
||||
client_config.product_info = product_info
|
||||
client_config.sastoken = sastoken
|
||||
assert client_config.sastoken_provider is None
|
||||
|
||||
# Determine expected username based on config
|
||||
if product_info.startswith(constant.DIGITAL_TWIN_PREFIX):
|
||||
expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}&{digital_twin_prefix}={custom_product_info}".format(
|
||||
hostname=hostname,
|
||||
client_id=expected_client_id,
|
||||
api_version=constant.IOTHUB_API_VERSION,
|
||||
user_agent=urllib.parse.quote(user_agent.get_iothub_user_agent(), safe=""),
|
||||
digital_twin_prefix=constant.DIGITAL_TWIN_QUERY_HEADER,
|
||||
custom_product_info=urllib.parse.quote(product_info, safe=""),
|
||||
)
|
||||
else:
|
||||
expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}{custom_product_info}".format(
|
||||
hostname=hostname,
|
||||
client_id=expected_client_id,
|
||||
api_version=constant.IOTHUB_API_VERSION,
|
||||
user_agent=urllib.parse.quote(user_agent.get_iothub_user_agent(), safe=""),
|
||||
custom_product_info=urllib.parse.quote(product_info, safe=""),
|
||||
)
|
||||
|
||||
# Determine expected password based on sastoken
|
||||
if sastoken:
|
||||
expected_password = str(renewable_sastoken)
|
||||
else:
|
||||
expected_password = None
|
||||
|
||||
# Create the client under test
|
||||
mocker.patch.object(mqtt, "MQTTClient", spec=mqtt.MQTTClient)
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
|
||||
# Credentials were set
|
||||
expected_username = client._username
|
||||
expected_password = None
|
||||
assert client._mqtt_client.set_credentials.call_count == 1
|
||||
assert client._mqtt_client.set_credentials.call_args(expected_username, expected_password)
|
||||
|
||||
await client.shutdown()
|
||||
|
||||
@pytest.mark.it(
|
||||
"Uses the derived `username` as the username and the string-converted current SasToken from the SasTokenProvider as the password when setting credentials for the newly created MQTTClient instance when using SAS authentication"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"device_id, module_id",
|
||||
[
|
||||
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
|
||||
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
|
||||
],
|
||||
)
|
||||
async def test_mqtt_client_credentials_with_sas(
|
||||
self, mocker, client_config, device_id, module_id, mock_sastoken_provider
|
||||
):
|
||||
client_config.device_id = device_id
|
||||
client_config.module_id = module_id
|
||||
client_config.sastoken_provider = mock_sastoken_provider
|
||||
|
||||
fake_sastoken = mock_sastoken_provider.get_current_sastoken.return_value
|
||||
|
||||
mocker.patch.object(mqtt, "MQTTClient", spec=mqtt.MQTTClient)
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
expected_username = client._username
|
||||
expected_password = str(fake_sastoken)
|
||||
assert client._mqtt_client.set_credentials.call_count == 1
|
||||
assert client._mqtt_client.set_credentials.call_args(expected_username, expected_password)
|
||||
|
||||
# Graceful exit
|
||||
await client.shutdown()
|
||||
|
||||
@pytest.mark.it("Adds incoming message filter on the MQTTClient for C2D messages")
|
||||
|
@ -580,30 +609,129 @@ class TestIoTHubMQTTClientInstantiation:
|
|||
|
||||
await client.shutdown()
|
||||
|
||||
# NOTE: For testing the functionality of this task, see the corresponding test suite (TestIoTHubMQTTClientIncomingTwinResponse)
|
||||
# TODO: Consider removing this test. Does it really test anything? A Task was created? Who cares?
|
||||
@pytest.mark.it("Creates a ongoing task to listen for twin responses")
|
||||
async def test_twin_response_task(self, client_config):
|
||||
@pytest.mark.it("Creates an empty RequestLedger")
|
||||
@pytest.mark.parametrize(
|
||||
"device_id, module_id",
|
||||
[
|
||||
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
|
||||
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
|
||||
],
|
||||
)
|
||||
async def test_request_ledger(self, client_config, device_id, module_id):
|
||||
client_config.device_id = device_id
|
||||
client_config.module_id = module_id
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
|
||||
assert isinstance(client._twin_response_listener, asyncio.Task)
|
||||
assert not client._twin_response_listener.done()
|
||||
|
||||
assert isinstance(client._request_ledger, rr.RequestLedger)
|
||||
assert len(client._request_ledger) == 0
|
||||
await client.shutdown()
|
||||
|
||||
@pytest.mark.it("Sets the twin_responses_enabled flag to False")
|
||||
async def test_twin_responses_enabled(self, client_config):
|
||||
@pytest.mark.parametrize(
|
||||
"device_id, module_id",
|
||||
[
|
||||
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
|
||||
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
|
||||
],
|
||||
)
|
||||
async def test_twin_responses_enabled(self, client_config, device_id, module_id):
|
||||
client_config.device_id = device_id
|
||||
client_config.module_id = module_id
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
assert client._twin_responses_enabled is False
|
||||
|
||||
await client.shutdown()
|
||||
|
||||
# NOTE: For testing the functionality of this task, see the corresponding test suite (TestIoTHubMQTTClientIncomingTwinResponse)
|
||||
@pytest.mark.it(
|
||||
"Begins running the ._process_twin_responses() coroutine method as a background task, storing it as an attribute"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"device_id, module_id",
|
||||
[
|
||||
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
|
||||
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
|
||||
],
|
||||
)
|
||||
async def test_process_twin_responses_bg_task(self, client_config, device_id, module_id):
|
||||
client_config.device_id = device_id
|
||||
client_config.module_id = module_id
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
|
||||
assert isinstance(client._process_twin_responses_bg_task, asyncio.Task)
|
||||
assert not client._process_twin_responses_bg_task.done()
|
||||
if sys.version_info > (3, 8):
|
||||
# NOTE: There isn't a way to validate the contents of a task until 3.8
|
||||
# as far as I can tell.
|
||||
task_coro = client._process_twin_responses_bg_task.get_coro()
|
||||
assert task_coro.__qualname__ == "IoTHubMQTTClient._process_twin_responses"
|
||||
|
||||
await client.shutdown()
|
||||
|
||||
# NOTE: For testing the functionality of this task, see the corresponding test suite (???)
|
||||
# TODO: add this test suite
|
||||
@pytest.mark.it(
|
||||
"Begins running the ._keep_credentials_fresh() coroutine method as a background task, storing it as an attribute, if using SAS authentication"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"device_id, module_id",
|
||||
[
|
||||
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
|
||||
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
|
||||
],
|
||||
)
|
||||
async def test_keep_credentials_fresh_bg_task(
|
||||
self, client_config, device_id, module_id, mock_sastoken_provider
|
||||
):
|
||||
client_config.device_id = device_id
|
||||
client_config.module_id = module_id
|
||||
client_config.sastoken_provider = mock_sastoken_provider
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
|
||||
assert isinstance(client._keep_credentials_fresh_bg_task, asyncio.Task)
|
||||
assert not client._keep_credentials_fresh_bg_task.done()
|
||||
if sys.version_info > (3, 8):
|
||||
# NOTE: There isn't a way to validate the contents of a task until 3.8
|
||||
# as far as I can tell.
|
||||
task_coro = client._keep_credentials_fresh_bg_task.get_coro()
|
||||
assert task_coro.__qualname__ == "IoTHubMQTTClient._keep_credentials_fresh"
|
||||
|
||||
await client.shutdown()
|
||||
|
||||
@pytest.mark.it(
|
||||
"Does not begin running the ._keep_credentials_fresh() coroutine method as a background task if not using SAS authentication"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"device_id, module_id",
|
||||
[
|
||||
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
|
||||
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
|
||||
],
|
||||
)
|
||||
async def test_keep_credentials_fresh_bg_task_no_sas(self, client_config, device_id, module_id):
|
||||
client_config.device_id = device_id
|
||||
client_config.module_id = module_id
|
||||
assert client_config.sastoken_provider is None
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
|
||||
assert client._keep_credentials_fresh_bg_task is None
|
||||
|
||||
await client.shutdown()
|
||||
|
||||
|
||||
# TODO: exceptions
|
||||
@pytest.mark.describe("IoTHubMQTTClient - .shutdown()")
|
||||
class TestIotHubMQTTClientShutdown:
|
||||
class TestIoTHubMQTTClientShutdown:
|
||||
@pytest.fixture(autouse=True)
|
||||
def modify_client_config(self, client_config, mock_sastoken_provider):
|
||||
# NOTE: This has to be changed on the config, not the client,
|
||||
# because it affects client initialization
|
||||
client_config.sastoken_provider = mock_sastoken_provider
|
||||
|
||||
@pytest.mark.it("Disconnects the MQTTClient")
|
||||
async def test_disconnect(self, mocker, client):
|
||||
# NOTE: rather than mocking the MQTTClient, we just mock the .disconnect() method of the
|
||||
# IoTHubMQTTClient instead, since it's been fully tested elsewhere, and we assume
|
||||
# correctness, lest we have to repeat all .disconnect() tests here.
|
||||
client.disconnect = mocker.AsyncMock()
|
||||
assert client.disconnect.await_count == 0
|
||||
|
||||
|
@ -612,34 +740,135 @@ class TestIotHubMQTTClientShutdown:
|
|||
assert client.disconnect.await_count == 1
|
||||
assert client.disconnect.await_args == mocker.call()
|
||||
|
||||
@pytest.mark.it("Cancels the SasToken update Alarm, if it exists")
|
||||
async def test_sastoken_alarm_exists(self, mocker, client):
|
||||
assert client._sastoken_update_alarm is None
|
||||
mock_alarm = client._sastoken_update_alarm = mocker.MagicMock()
|
||||
@pytest.mark.it("Cancels the 'process_twin_responses' background task")
|
||||
async def test_process_twin_responses_bg_task(self, client):
|
||||
assert isinstance(client._process_twin_responses_bg_task, asyncio.Task)
|
||||
assert not client._process_twin_responses_bg_task.done()
|
||||
|
||||
await client.shutdown()
|
||||
|
||||
assert mock_alarm.cancel.call_count == 1
|
||||
assert mock_alarm.cancel.call_args == mocker.call()
|
||||
assert client._process_twin_responses_bg_task.done()
|
||||
assert client._process_twin_responses_bg_task.cancelled()
|
||||
|
||||
@pytest.mark.it("Handles the case where the SasToken update Alarm does not exist")
|
||||
async def test_sastoken_alarm_no_exist(self, client):
|
||||
assert client._sastoken_update_alarm is None
|
||||
@pytest.mark.it("Cancels the 'keep_credentials_fresh' background task, if it exists")
|
||||
async def test_keep_credentials_fresh_bg_task_exists(self, client):
|
||||
assert isinstance(client._keep_credentials_fresh_bg_task, asyncio.Task)
|
||||
assert not client._keep_credentials_fresh_bg_task.done()
|
||||
|
||||
await client.shutdown()
|
||||
|
||||
# No AttributeError raised means success
|
||||
|
||||
# TODO: Probably need to show that this is truly the twin response task by demonstrating that twin responses are no longer received after shutdown
|
||||
@pytest.mark.it("Cancels the twin response listener task")
|
||||
async def test_twin_response_listener(self, client):
|
||||
assert isinstance(client._twin_response_listener, asyncio.Task)
|
||||
assert not client._twin_response_listener.done()
|
||||
assert client._keep_credentials_fresh_bg_task.done()
|
||||
assert client._keep_credentials_fresh_bg_task.cancelled()
|
||||
|
||||
@pytest.mark.it("Handles the case where no 'keep_credentials_fresh' background task exists")
|
||||
async def test_keep_credentials_fresh_bg_task_no_exist(self, client, client_config):
|
||||
# NOTE: in this test we don't want to have the SAS bg task, so override the modified fixture
|
||||
client_config.sastoken_provider = None
|
||||
client = IoTHubMQTTClient(client_config)
|
||||
assert client._keep_credentials_fresh_bg_task is None
|
||||
await client.shutdown()
|
||||
# No AttributeError means success!
|
||||
|
||||
assert client._twin_response_listener.done()
|
||||
assert client._twin_response_listener.cancelled()
|
||||
@pytest.mark.it(
|
||||
"Allows any exception raised during MQTTClient disconnect to propagate, but only after cancelling background tasks"
|
||||
)
|
||||
@pytest.mark.parametrize("exception", mqtt_disconnect_exceptions)
|
||||
async def test_disconnect_raises(self, mocker, client, exception):
|
||||
# NOTE: rather than mocking the MQTTClient, we just mock the .disconnect() method of the
|
||||
# IoTHubMQTTClient instead, since it's been fully tested elsewhere, and we assume
|
||||
# correctness, lest we have to repeat all .disconnect() tests here.
|
||||
original_disconnect = client.disconnect
|
||||
client.disconnect = mocker.AsyncMock(side_effect=exception)
|
||||
client.disconnect.side_effect = exception
|
||||
assert not client._keep_credentials_fresh_bg_task.done()
|
||||
assert not client._process_twin_responses_bg_task.done()
|
||||
|
||||
with pytest.raises(type(exception)) as e_info:
|
||||
await client.shutdown()
|
||||
assert e_info.value is exception
|
||||
|
||||
# Background tasks were also cancelled despite the exception
|
||||
assert client._keep_credentials_fresh_bg_task.done()
|
||||
assert client._keep_credentials_fresh_bg_task.cancelled()
|
||||
assert client._process_twin_responses_bg_task.done()
|
||||
assert client._process_twin_responses_bg_task.cancelled()
|
||||
|
||||
# Unset the mock so that tests can clean up
|
||||
client.disconnect = original_disconnect
|
||||
|
||||
@pytest.mark.it(
|
||||
"Can be cancelled while waiting for the MQTTClient disconnect to finish, but it won't stop background task cancellation"
|
||||
)
|
||||
async def test_cancel_disconnect(self, client):
|
||||
# NOTE: rather than mocking the MQTTClient, we just mock the .disconnect() method of the
|
||||
# IoTHubMQTTClient instead, since it's been fully tested elsewhere, and we assume
|
||||
# correctness, lest we have to repeat all .disconnect() tests here.
|
||||
original_disconnect = client.disconnect
|
||||
client.disconnect = custom_mock.HangingAsyncMock()
|
||||
|
||||
t = asyncio.create_task(client.shutdown())
|
||||
|
||||
# Hanging, waiting for disconnect to finish
|
||||
await client.disconnect.wait_for_hang()
|
||||
assert not t.done()
|
||||
# Background tasks have not been cancelled
|
||||
assert not client._keep_credentials_fresh_bg_task.done()
|
||||
assert not client._process_twin_responses_bg_task.done()
|
||||
|
||||
# Cancel
|
||||
t.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await t
|
||||
|
||||
# Unset the mock so that tests can clean up.
|
||||
# And do it now so that test assertion failure doesn't hang
|
||||
client.disconnect = original_disconnect
|
||||
|
||||
# And yet the background tasks still were cancelled anyway
|
||||
assert client._keep_credentials_fresh_bg_task.done()
|
||||
assert client._keep_credentials_fresh_bg_task.cancelled()
|
||||
assert client._process_twin_responses_bg_task.done()
|
||||
assert client._process_twin_responses_bg_task.cancelled()
|
||||
|
||||
@pytest.mark.it(
|
||||
"Can be cancelled while waiting for the background tasks to finish cancellation, but it won't stop the background task cancellation"
|
||||
)
|
||||
async def test_cancel_gather(self, mocker, client):
|
||||
original_gather = asyncio.gather
|
||||
asyncio.gather = custom_mock.HangingAsyncMock()
|
||||
|
||||
spy_twin_response_bg_task_cancel = mocker.spy(
|
||||
client._process_twin_responses_bg_task, "cancel"
|
||||
)
|
||||
spy_credentials_bg_task_cancel = mocker.spy(
|
||||
client._keep_credentials_fresh_bg_task, "cancel"
|
||||
)
|
||||
|
||||
t = asyncio.create_task(client.shutdown())
|
||||
|
||||
# Hanging waiting for gather to return (indicating tasks are all done cancellation)
|
||||
await asyncio.gather.wait_for_hang()
|
||||
assert not t.done()
|
||||
# Background tests may or may not have completed cancellation yet, hard to test accurately.
|
||||
# But their cancellation HAS been requested.
|
||||
assert spy_twin_response_bg_task_cancel.call_count == 1
|
||||
assert spy_credentials_bg_task_cancel.call_count == 1
|
||||
|
||||
# Cancel
|
||||
t.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await t
|
||||
|
||||
# Unset the mock so that tests can clean up.
|
||||
# And do it now so that test assertion failure doesn't hang
|
||||
asyncio.gather = original_gather
|
||||
|
||||
# Tasks will be cancelled very soon (if they aren't already)
|
||||
await asyncio.sleep(0.1)
|
||||
assert client._keep_credentials_fresh_bg_task.done()
|
||||
assert client._keep_credentials_fresh_bg_task.cancelled()
|
||||
assert client._process_twin_responses_bg_task.done()
|
||||
assert client._process_twin_responses_bg_task.cancelled()
|
||||
|
||||
|
||||
@pytest.mark.describe("IoTHubMQTTClient - .connect()")
|
||||
|
@ -2387,6 +2616,9 @@ class TestIoTHubMQTTClientIncomingTwinPatches:
|
|||
assert patch == expected_json
|
||||
|
||||
|
||||
# TODO: To reflect the complexity of background tasks, these tests need to be adjusted
|
||||
# to be about the background task itself, not a reaction to an event. This is probably the
|
||||
# end of any "OCCURRENCE" tests in the client layer
|
||||
@pytest.mark.describe("IoTHubMQTTClient - OCCURRENCE: Twin Response Received")
|
||||
class TestIoTHubMQTTClientIncomingTwinResponse:
|
||||
# NOTE: This test suite exists for simplicity - twin responses are used in both
|
||||
|
@ -2455,3 +2687,58 @@ class TestIoTHubMQTTClientIncomingTwinResponse:
|
|||
resp1 = spy_response_factory.spy_return
|
||||
assert mock_ledger.match_response.call_count == 1
|
||||
assert mock_ledger.match_response.call_args == mocker.call(resp1)
|
||||
|
||||
|
||||
# TODO: To reflect the complexity of background tasks, these tests need to be adjusted
|
||||
# to be about the background task itself, not a reaction to an event. This is probably the
|
||||
# end of any "OCCURRENCE" tests in the client layer
|
||||
@pytest.mark.describe("IoTHubMQTTClient - OCCURRENCE: SasTokenProvider Updates SasToken")
|
||||
class TestIoTHubMQTTClientSasTokenUpdate:
|
||||
@pytest.fixture(autouse=True)
|
||||
def modify_client_config(self, client_config, mock_sastoken_provider):
|
||||
# Need to use a client with SAS token auth
|
||||
# NOTE: This has to be changed on the config, not the client,
|
||||
# because it affects client initialization
|
||||
client_config.sastoken_provider = mock_sastoken_provider
|
||||
|
||||
@pytest.mark.it(
|
||||
"Updates the MQTTClient's credentials, using the stored username as the username, and the string-converted new SasToken as the password"
|
||||
)
|
||||
async def test_updates_credentials(self, mocker, client):
|
||||
# Client is waiting on a new SasToken
|
||||
assert client._sastoken_provider.wait_for_new_sastoken.await_count == 1
|
||||
assert client._sastoken_provider.wait_for_new_sastoken.is_hanging()
|
||||
assert client._mqtt_client.set_credentials.call_count == 0
|
||||
|
||||
# Trigger new SasToken arrival
|
||||
client._sastoken_provider.wait_for_new_sastoken.stop_hanging()
|
||||
new_sastoken = client._sastoken_provider.wait_for_new_sastoken.return_value
|
||||
assert isinstance(new_sastoken, st.SasToken)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Credentials are updated
|
||||
assert client._mqtt_client.set_credentials.call_count == 1
|
||||
assert client._mqtt_client.set_credentials.call_args == mocker.call(
|
||||
client._username, str(new_sastoken)
|
||||
)
|
||||
|
||||
# Client is now waiting on a new SasToken again
|
||||
assert client._sastoken_provider.wait_for_new_sastoken.await_count == 2
|
||||
assert client._sastoken_provider.wait_for_new_sastoken.is_hanging()
|
||||
assert client._mqtt_client.set_credentials.call_count == 1
|
||||
|
||||
# Trigger new SasToken arrival again
|
||||
client._sastoken_provider.wait_for_new_sastoken.stop_hanging()
|
||||
new_sastoken = client._sastoken_provider.wait_for_new_sastoken.return_value
|
||||
assert isinstance(new_sastoken, st.SasToken)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Credentials are updated again
|
||||
assert client._mqtt_client.set_credentials.call_count == 2
|
||||
assert client._mqtt_client.set_credentials.call_args == mocker.call(
|
||||
client._username, str(new_sastoken)
|
||||
)
|
||||
|
||||
# And so it continues forever and ever
|
||||
|
||||
await client.shutdown()
|
||||
|
|
|
@ -13,7 +13,7 @@ from v3_async_wip.mqtt_client import (
|
|||
expected_on_connect_rc,
|
||||
expected_on_disconnect_rc,
|
||||
)
|
||||
from azure.iot.device.common import ProxyOptions
|
||||
from v3_async_wip.config import ProxyOptions
|
||||
import paho.mqtt.client as mqtt
|
||||
import asyncio
|
||||
import pytest
|
||||
|
@ -406,11 +406,13 @@ class TestInstantiation:
|
|||
proxy_type = "SOCKS5"
|
||||
|
||||
if "No Auth" in request.param:
|
||||
proxy = ProxyOptions(proxy_type=proxy_type, proxy_addr="fake.address", proxy_port=1080)
|
||||
proxy = ProxyOptions(
|
||||
proxy_type=proxy_type, proxy_address="fake.address", proxy_port=1080
|
||||
)
|
||||
else:
|
||||
proxy = ProxyOptions(
|
||||
proxy_type=proxy_type,
|
||||
proxy_addr="fake.address",
|
||||
proxy_address="fake.address",
|
||||
proxy_port=1080,
|
||||
proxy_username="fake_username",
|
||||
proxy_password="fake_password",
|
||||
|
|
|
@ -1,32 +1,36 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# -*- coding: utf-8 -*- # TODO: do we need this?
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import pytest
|
||||
import time
|
||||
import re
|
||||
import asyncio
|
||||
import logging
|
||||
import urllib
|
||||
from azure.iot.device.common.auth.sastoken import (
|
||||
RenewableSasToken,
|
||||
NonRenewableSasToken,
|
||||
import pytest
|
||||
import sys
|
||||
import time
|
||||
import urllib.parse
|
||||
from pytest_lazyfixture import lazy_fixture
|
||||
from v3_async_wip.sastoken import (
|
||||
SasToken,
|
||||
InternalSasTokenGenerator,
|
||||
ExternalSasTokenGenerator,
|
||||
SasTokenProvider,
|
||||
SasTokenError,
|
||||
TOKEN_FORMAT,
|
||||
DEFAULT_TOKEN_UPDATE_MARGIN,
|
||||
)
|
||||
from v3_async_wip import sastoken as st
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
fake_uri = "some/resource/location"
|
||||
fake_signed_data = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI="
|
||||
fake_key_name = "fakekeyname"
|
||||
fake_expiry = 12321312
|
||||
|
||||
simple_token_format = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}"
|
||||
auth_rule_token_format = (
|
||||
"SharedAccessSignature sr={resource}&sig={signature}&se={expiry}&skn={keyname}"
|
||||
)
|
||||
FAKE_URI = "some/resource/location"
|
||||
FAKE_SIGNED_DATA = "8NJRMT83CcplGrAGaUVIUM/md5914KpWVNngSVoF9/M="
|
||||
FAKE_SIGNED_DATA2 = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI="
|
||||
FAKE_CURRENT_TIME = 10000000000.0 # We living in 2286!
|
||||
# TODO: make this a float
|
||||
# TODO: should we mock out time.time to always return a fake time?
|
||||
|
||||
|
||||
def token_parser(token_str):
|
||||
|
@ -40,176 +44,70 @@ def token_parser(token_str):
|
|||
return token_map
|
||||
|
||||
|
||||
class RenewableSasTokenTestConfig(object):
|
||||
@pytest.fixture
|
||||
def signing_mechanism(self, mocker):
|
||||
mechanism = mocker.MagicMock()
|
||||
mechanism.sign.return_value = fake_signed_data
|
||||
return mechanism
|
||||
|
||||
# TODO: Rename this. These are not "device" and "service" tokens, the distinction is more generic
|
||||
@pytest.fixture(params=["Device Token", "Service Token"])
|
||||
def sastoken(self, request, signing_mechanism):
|
||||
token_type = request.param
|
||||
if token_type == "Device Token":
|
||||
return RenewableSasToken(uri=fake_uri, signing_mechanism=signing_mechanism)
|
||||
elif token_type == "Service Token":
|
||||
return RenewableSasToken(
|
||||
uri=fake_uri, signing_mechanism=signing_mechanism, key_name=fake_key_name
|
||||
)
|
||||
def get_expiry_time():
|
||||
return int(time.time()) + 3600 # One hour from right now,
|
||||
|
||||
|
||||
@pytest.mark.describe("RenewableSasToken")
|
||||
class TestRenewableSasToken(RenewableSasTokenTestConfig):
|
||||
@pytest.mark.it("Instantiates with a default TTL of 3600 seconds if no TTL is provided")
|
||||
def test_default_ttl(self, signing_mechanism):
|
||||
s = RenewableSasToken(fake_uri, signing_mechanism)
|
||||
assert s.ttl == 3600
|
||||
|
||||
@pytest.mark.it("Instantiates with a custom TTL if provided")
|
||||
def test_custom_ttl(self, signing_mechanism):
|
||||
custom_ttl = 4747
|
||||
s = RenewableSasToken(fake_uri, signing_mechanism, ttl=custom_ttl)
|
||||
assert s.ttl == custom_ttl
|
||||
|
||||
@pytest.mark.it("Instantiates with with no key name by default if no key name is provided")
|
||||
def test_default_key_name(self, signing_mechanism):
|
||||
s = RenewableSasToken(fake_uri, signing_mechanism)
|
||||
assert s._key_name is None
|
||||
|
||||
@pytest.mark.it("Instantiates with the given key name if provided")
|
||||
def test_custom_key_name(self, signing_mechanism):
|
||||
s = RenewableSasToken(fake_uri, signing_mechanism, key_name=fake_key_name)
|
||||
assert s._key_name == fake_key_name
|
||||
|
||||
@pytest.mark.it(
|
||||
"Instantiates with an expiry time TTL seconds in the future from the moment of instantiation"
|
||||
@pytest.fixture
|
||||
def sastoken_str():
|
||||
return TOKEN_FORMAT.format(
|
||||
resource=urllib.parse.quote(FAKE_URI, safe=""),
|
||||
signature=urllib.parse.quote(FAKE_SIGNED_DATA, safe=""),
|
||||
expiry=get_expiry_time(),
|
||||
)
|
||||
def test_expiry_time(self, mocker, signing_mechanism):
|
||||
fake_current_time = 1000
|
||||
mocker.patch.object(time, "time", return_value=fake_current_time)
|
||||
|
||||
s = RenewableSasToken(fake_uri, signing_mechanism)
|
||||
assert s.expiry_time == fake_current_time + s.ttl
|
||||
|
||||
@pytest.mark.it("Calls .refresh() to build the SAS token string on instantiation")
|
||||
def test_refresh_on_instantiation(self, mocker, signing_mechanism):
|
||||
refresh_mock = mocker.spy(RenewableSasToken, "refresh")
|
||||
assert refresh_mock.call_count == 0
|
||||
RenewableSasToken(fake_uri, signing_mechanism)
|
||||
assert refresh_mock.call_count == 1
|
||||
|
||||
@pytest.mark.it("Returns the SAS token string as the string representation of the object")
|
||||
def test_str_rep(self, sastoken):
|
||||
assert str(sastoken) == sastoken._token
|
||||
|
||||
@pytest.mark.it(
|
||||
"Maintains the .expiry_time attribute as a read-only property (raises AttributeError upon attempt)"
|
||||
)
|
||||
def test_expiry_time_read_only(self, sastoken):
|
||||
with pytest.raises(AttributeError):
|
||||
sastoken.expiry_time = 12321312
|
||||
|
||||
|
||||
@pytest.mark.describe("RenewableSasToken - .refresh()")
|
||||
class TestRenewableSasTokenRefresh(RenewableSasTokenTestConfig):
|
||||
@pytest.mark.it("Sets a new expiry time of TTL seconds in the future")
|
||||
def test_new_expiry(self, mocker, sastoken):
|
||||
fake_current_time = 1000
|
||||
mocker.patch.object(time, "time", return_value=fake_current_time)
|
||||
sastoken.refresh()
|
||||
assert sastoken.expiry_time == fake_current_time + sastoken.ttl
|
||||
|
||||
# TODO: reflect url encoding here?
|
||||
@pytest.mark.it(
|
||||
"Uses the token's signing mechanism to create a signature by signing a concatenation of the (URL encoded) URI and updated expiry time"
|
||||
)
|
||||
def test_generate_new_token(self, mocker, signing_mechanism, sastoken):
|
||||
old_token_str = str(sastoken)
|
||||
fake_future_time = 1000
|
||||
mocker.patch.object(time, "time", return_value=fake_future_time)
|
||||
signing_mechanism.reset_mock()
|
||||
fake_signature = "new_fake_signature"
|
||||
signing_mechanism.sign.return_value = fake_signature
|
||||
|
||||
sastoken.refresh()
|
||||
|
||||
# The token string has been updated
|
||||
assert str(sastoken) != old_token_str
|
||||
# The signing mechanism was used to sign a string
|
||||
assert signing_mechanism.sign.call_count == 1
|
||||
# The string being signed was a concatenation of the URI and expiry time
|
||||
assert signing_mechanism.sign.call_args == mocker.call(
|
||||
urllib.parse.quote(sastoken._uri, safe="") + "\n" + str(sastoken.expiry_time)
|
||||
)
|
||||
# The token string has the resulting signed string included as the signature
|
||||
token_info = token_parser(str(sastoken))
|
||||
assert token_info["sig"] == fake_signature
|
||||
|
||||
@pytest.mark.it(
|
||||
"Builds a new token string using the token's URI (URL encoded) and expiry time, along with the signature created by the signing mechanism (also URL encoded)"
|
||||
)
|
||||
def test_token_string(self, sastoken):
|
||||
token_str = sastoken._token
|
||||
|
||||
# Verify that token string representation matches token format
|
||||
if not sastoken._key_name:
|
||||
pattern = re.compile(r"SharedAccessSignature sr=(.+)&sig=(.+)&se=(.+)")
|
||||
else:
|
||||
pattern = re.compile(r"SharedAccessSignature sr=(.+)&sig=(.+)&se=(.+)&skn=(.+)")
|
||||
assert pattern.match(token_str)
|
||||
|
||||
# Verify that content in the string representation is correct
|
||||
token_info = token_parser(token_str)
|
||||
assert token_info["sr"] == urllib.parse.quote(sastoken._uri, safe="")
|
||||
assert token_info["sig"] == urllib.parse.quote(
|
||||
sastoken._signing_mechanism.sign.return_value, safe=""
|
||||
)
|
||||
assert token_info["se"] == str(sastoken.expiry_time)
|
||||
if sastoken._key_name:
|
||||
assert token_info["skn"] == sastoken._key_name
|
||||
|
||||
@pytest.mark.it("Raises a SasTokenError if an exception is raised by the signing mechanism")
|
||||
def test_signing_mechanism_raises_value_error(
|
||||
self, mocker, signing_mechanism, sastoken, arbitrary_exception
|
||||
):
|
||||
signing_mechanism.sign.side_effect = arbitrary_exception
|
||||
|
||||
with pytest.raises(SasTokenError) as e_info:
|
||||
sastoken.refresh()
|
||||
assert e_info.value.__cause__ is arbitrary_exception
|
||||
@pytest.fixture
|
||||
def sastoken(sastoken_str):
|
||||
return SasToken(sastoken_str)
|
||||
|
||||
|
||||
@pytest.mark.describe("NonRenewableSasToken")
|
||||
class TestNonRenewableSasToken(object):
|
||||
# TODO: Rename this. These are not "device" and "service" tokens, the distinction is more generic
|
||||
@pytest.fixture(params=["Device Token", "Service Token"])
|
||||
def sastoken_str(self, request):
|
||||
token_type = request.param
|
||||
if token_type == "Device Token":
|
||||
return simple_token_format.format(
|
||||
resource=urllib.parse.quote(fake_uri, safe=""),
|
||||
signature=urllib.parse.quote(fake_signed_data, safe=""),
|
||||
expiry=fake_expiry,
|
||||
)
|
||||
elif token_type == "Service Token":
|
||||
return auth_rule_token_format.format(
|
||||
resource=urllib.parse.quote(fake_uri, safe=""),
|
||||
signature=urllib.parse.quote(fake_signed_data, safe=""),
|
||||
expiry=fake_expiry,
|
||||
keyname=fake_key_name,
|
||||
)
|
||||
@pytest.fixture
|
||||
async def mock_signing_mechanism(mocker):
|
||||
mock_sm = mocker.AsyncMock()
|
||||
mock_sm.sign.return_value = FAKE_SIGNED_DATA
|
||||
return mock_sm
|
||||
|
||||
@pytest.fixture()
|
||||
def sastoken(self, sastoken_str):
|
||||
return NonRenewableSasToken(sastoken_str)
|
||||
|
||||
@pytest.fixture(params=["Generator Function", "Generator Coroutine Function"])
|
||||
def mock_token_generator_fn(mocker, request, sastoken_str):
|
||||
if request.param == "Function":
|
||||
return mocker.MagicMock(return_value=sastoken_str)
|
||||
else:
|
||||
return mocker.AsyncMock(return_value=sastoken_str)
|
||||
|
||||
|
||||
@pytest.fixture(params=["InternalSasTokenGenerator", "ExternalSasTokenGenerator"])
|
||||
def sastoken_generator(request, mocker, mock_signing_mechanism, sastoken_str):
|
||||
if request.param == "ExternalSasTokenGenerator":
|
||||
# We don't care about the difference between sync/async generator_fns when testing
|
||||
# at this level of abstraction, so just pick one
|
||||
generator = ExternalSasTokenGenerator(mocker.MagicMock(return_value=sastoken_str))
|
||||
else:
|
||||
generator = InternalSasTokenGenerator(mock_signing_mechanism, FAKE_URI)
|
||||
mocker.spy(generator, "generate_sastoken")
|
||||
return generator
|
||||
|
||||
|
||||
# TODO: adjust mocks so that initial token is not the same as generated token
|
||||
@pytest.fixture
|
||||
async def sastoken_provider(sastoken_generator):
|
||||
provider = await SasTokenProvider.create_from_generator(sastoken_generator)
|
||||
# Creating from the generator invokes a call on the generator, so reset the spy mock
|
||||
# so it doesn't throw off any testing logic
|
||||
provider._generator.generate_sastoken.reset_mock()
|
||||
yield provider
|
||||
await provider.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.describe("SasToken")
|
||||
class TestSasToken:
|
||||
@pytest.mark.it("Instantiates from a valid SAS Token string")
|
||||
def test_instantiates_from_token_string(self, sastoken_str):
|
||||
s = NonRenewableSasToken(sastoken_str)
|
||||
assert s._token == sastoken_str
|
||||
s = SasToken(sastoken_str)
|
||||
assert s._token_str == sastoken_str
|
||||
|
||||
@pytest.mark.it("Raises a SasToken error if instantiating from an invalid SAS Token string")
|
||||
@pytest.mark.it("Raises a ValueError error if instantiating from an invalid SAS Token string")
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_token_str",
|
||||
[
|
||||
|
@ -237,49 +135,575 @@ class TestNonRenewableSasToken(object):
|
|||
"SharedAccessSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=",
|
||||
id="Missing expiry value",
|
||||
),
|
||||
pytest.param(
|
||||
"SharedAccessSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312&foovalue=nonsense",
|
||||
id="Extraneous invalid value",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_raises_error_invalid_token_string(self, invalid_token_str):
|
||||
with pytest.raises(SasTokenError):
|
||||
NonRenewableSasToken(invalid_token_str)
|
||||
with pytest.raises(ValueError):
|
||||
SasToken(invalid_token_str)
|
||||
|
||||
@pytest.mark.it("Returns the SAS token string as the string representation of the object")
|
||||
def test_str_rep(self, sastoken_str):
|
||||
sastoken = NonRenewableSasToken(sastoken_str)
|
||||
sastoken = SasToken(sastoken_str)
|
||||
assert str(sastoken) == sastoken_str
|
||||
|
||||
@pytest.mark.it(
|
||||
"Instantiates with the .expiry_time attribute corresponding to the expiry time of the given SAS Token string (as an integer)"
|
||||
"Instantiates with the .expiry_time property corresponding to the expiry time of the given SAS Token string (as a float)"
|
||||
)
|
||||
def test_instantiates_expiry_time(self, sastoken_str):
|
||||
sastoken = NonRenewableSasToken(sastoken_str)
|
||||
sastoken = SasToken(sastoken_str)
|
||||
expected_expiry_time = token_parser(sastoken_str)["se"]
|
||||
assert sastoken.expiry_time == int(expected_expiry_time)
|
||||
assert sastoken.expiry_time == float(expected_expiry_time)
|
||||
|
||||
@pytest.mark.it(
|
||||
"Maintains the .expiry_time attribute as a read-only property (raises AttributeError upon attempt)"
|
||||
)
|
||||
@pytest.mark.it("Maintains .expiry_time as a read-only property")
|
||||
def test_expiry_time_read_only(self, sastoken):
|
||||
with pytest.raises(AttributeError):
|
||||
sastoken.expiry_time = 12312312312123
|
||||
|
||||
@pytest.mark.it(
|
||||
"Instantiates with the .resource_uri attribute corresponding to the URL decoded URI of the given SAS Token string"
|
||||
"Instantiates with the .resource_uri property corresponding to the URL decoded URI of the given SAS Token string"
|
||||
)
|
||||
def test_instantiates_resource_uri(self, sastoken_str):
|
||||
sastoken = NonRenewableSasToken(sastoken_str)
|
||||
sastoken = SasToken(sastoken_str)
|
||||
resource_uri = token_parser(sastoken_str)["sr"]
|
||||
assert resource_uri != sastoken.resource_uri
|
||||
assert resource_uri == urllib.parse.quote(sastoken.resource_uri, safe="")
|
||||
assert urllib.parse.unquote(resource_uri) == sastoken.resource_uri
|
||||
|
||||
@pytest.mark.it(
|
||||
"Maintains the .resource_uri attribute as a read-only property (raises AttributeError upon attempt)"
|
||||
)
|
||||
@pytest.mark.it("Maintains .resource_uri as a read-only property")
|
||||
def test_resource_uri_read_only(self, sastoken):
|
||||
with pytest.raises(AttributeError):
|
||||
sastoken.resource_uri = "new%2Ffake%2Furi"
|
||||
|
||||
@pytest.mark.it(
|
||||
"Instantiates with the .signature property corresponding to the URL decoded signature of the given SAS Token string"
|
||||
)
|
||||
def test_instantiates_signature(self, sastoken_str):
|
||||
sastoken = SasToken(sastoken_str)
|
||||
signature = token_parser(sastoken_str)["sig"]
|
||||
assert signature != sastoken.signature
|
||||
assert signature == urllib.parse.quote(sastoken.signature, safe="")
|
||||
assert urllib.parse.unquote(signature) == sastoken.signature
|
||||
|
||||
@pytest.mark.it("Maintains .signature as a read-only property")
|
||||
def test_signature_read_only(self, sastoken):
|
||||
with pytest.raises(AttributeError):
|
||||
sastoken.signature = "asdfas"
|
||||
|
||||
|
||||
@pytest.mark.describe("InternalSasTokenGenerator -- Instantiation")
|
||||
class TestSasTokenGeneratorInstantiation:
|
||||
@pytest.mark.it("Stores the provided signing mechanism as an attribute")
|
||||
def test_signing_mechanism(self, mock_signing_mechanism):
|
||||
generator = InternalSasTokenGenerator(
|
||||
signing_mechanism=mock_signing_mechanism, uri=FAKE_URI, ttl=4700
|
||||
)
|
||||
assert generator.signing_mechanism is mock_signing_mechanism
|
||||
|
||||
@pytest.mark.it("Stores the provided URI as an attribute")
|
||||
def test_uri(self, mock_signing_mechanism):
|
||||
generator = InternalSasTokenGenerator(
|
||||
signing_mechanism=mock_signing_mechanism, uri=FAKE_URI, ttl=4700
|
||||
)
|
||||
assert generator.uri == FAKE_URI
|
||||
|
||||
@pytest.mark.it("Stores the provided TTL as an attribute")
|
||||
def test_ttl(self, mock_signing_mechanism):
|
||||
generator = InternalSasTokenGenerator(
|
||||
signing_mechanism=mock_signing_mechanism, uri=FAKE_URI, ttl=4700
|
||||
)
|
||||
assert generator.ttl == 4700
|
||||
|
||||
@pytest.mark.it("Defaults to using 3600 as the TTL if not provided")
|
||||
def test_ttl_default(self, mock_signing_mechanism):
|
||||
generator = InternalSasTokenGenerator(
|
||||
signing_mechanism=mock_signing_mechanism, uri=FAKE_URI
|
||||
)
|
||||
assert generator.ttl == 3600
|
||||
|
||||
|
||||
@pytest.mark.describe("InternalSasTokenGenerator - .generate_sastoken()")
|
||||
class TestSasTokenGeneratorGenerateSastoken:
|
||||
@pytest.fixture
|
||||
def sastoken_generator(self, mock_signing_mechanism):
|
||||
return InternalSasTokenGenerator(
|
||||
signing_mechanism=mock_signing_mechanism, uri=FAKE_URI, ttl=4700
|
||||
)
|
||||
|
||||
@pytest.mark.it(
|
||||
"Returns a newly generated SasToken for the configured URI that is valid for TTL seconds"
|
||||
)
|
||||
async def test_token_expiry(self, mocker, sastoken_generator):
|
||||
# Patch time.time() to return a fake time so that it's easy to check the delta with expiry
|
||||
mocker.patch.object(time, "time", return_value=FAKE_CURRENT_TIME)
|
||||
expected_expiry = FAKE_CURRENT_TIME + sastoken_generator.ttl
|
||||
token = await sastoken_generator.generate_sastoken()
|
||||
assert isinstance(token, SasToken)
|
||||
assert token.expiry_time == expected_expiry
|
||||
assert token.resource_uri == sastoken_generator.uri
|
||||
assert token._token_info["sr"] == urllib.parse.quote(sastoken_generator.uri, safe="")
|
||||
assert token.resource_uri != token._token_info["sr"]
|
||||
|
||||
@pytest.mark.it(
|
||||
"Creates the resulting SasToken's signature by using the InternalSasTokenGenerator's signing mechanism to sign a concatenation of the (URL encoded) URI and (URL encoded, int converted) desired expiry time"
|
||||
)
|
||||
async def test_token_signature(self, mocker, sastoken_generator):
|
||||
assert sastoken_generator.signing_mechanism.await_count == 0
|
||||
mocker.patch.object(time, "time", return_value=FAKE_CURRENT_TIME)
|
||||
expected_expiry = int(FAKE_CURRENT_TIME + sastoken_generator.ttl)
|
||||
expected_data_to_sign = (
|
||||
urllib.parse.quote(sastoken_generator.uri, safe="") + "\n" + str(expected_expiry)
|
||||
)
|
||||
|
||||
token = await sastoken_generator.generate_sastoken()
|
||||
|
||||
assert sastoken_generator.signing_mechanism.sign.await_count == 1
|
||||
assert sastoken_generator.signing_mechanism.sign.await_args == mocker.call(
|
||||
expected_data_to_sign
|
||||
)
|
||||
assert token._token_info["sig"] == urllib.parse.quote(
|
||||
sastoken_generator.signing_mechanism.sign.return_value, safe=""
|
||||
)
|
||||
assert token.signature == sastoken_generator.signing_mechanism.sign.return_value
|
||||
assert token.signature != token._token_info["sig"]
|
||||
|
||||
@pytest.mark.it("Raises a SasTokenError if an exception is raised by the signing mechanism")
|
||||
async def test_signing_mechanism_raises(self, sastoken_generator, arbitrary_exception):
|
||||
sastoken_generator.signing_mechanism.sign.side_effect = arbitrary_exception
|
||||
|
||||
with pytest.raises(SasTokenError) as e_info:
|
||||
await sastoken_generator.generate_sastoken()
|
||||
assert e_info.value.__cause__ is arbitrary_exception
|
||||
|
||||
|
||||
@pytest.mark.describe("ExternalSasTokenGenerator -- Instantiation")
|
||||
class TestExternalSasTokenGeneratorInstantiation:
|
||||
@pytest.mark.it("Stores the provided generator_fn callable as an attribute")
|
||||
def test_generator_fn_attribute(self, mock_token_generator_fn):
|
||||
sastoken_generator = ExternalSasTokenGenerator(mock_token_generator_fn)
|
||||
assert sastoken_generator.generator_fn is mock_token_generator_fn
|
||||
|
||||
|
||||
@pytest.mark.describe("ExternalSasTokenGenerator -- .generate_sastoken()")
|
||||
class TestExternalSasTokenGeneratorGenerateSasToken:
|
||||
@pytest.fixture
|
||||
def sastoken_generator(self, mock_token_generator_fn):
|
||||
return ExternalSasTokenGenerator(mock_token_generator_fn)
|
||||
|
||||
@pytest.mark.it(
|
||||
"Generates a new SasToken from the SAS Token string returned by the configured generator_fn callable"
|
||||
)
|
||||
async def test_returns_token(self, mocker, sastoken_generator):
|
||||
if isinstance(sastoken_generator.generator_fn, mocker.AsyncMock):
|
||||
assert sastoken_generator.generator_fn.await_count == 0
|
||||
else:
|
||||
assert sastoken_generator.generator_fn.call_count == 0
|
||||
|
||||
token = await sastoken_generator.generate_sastoken()
|
||||
assert isinstance(token, SasToken)
|
||||
|
||||
if isinstance(sastoken_generator.generator_fn, mocker.AsyncMock):
|
||||
assert sastoken_generator.generator_fn.await_count == 1
|
||||
assert sastoken_generator.generator_fn.await_args == mocker.call()
|
||||
else:
|
||||
assert sastoken_generator.generator_fn.call_count == 1
|
||||
assert sastoken_generator.generator_fn.call_args == mocker.call()
|
||||
|
||||
assert str(token) == sastoken_generator.generator_fn.return_value
|
||||
|
||||
@pytest.mark.it(
|
||||
"Raises SasTokenError if an exception is raised while trying to generate a SAS Token string with the generator_fn"
|
||||
)
|
||||
async def test_generator_fn_raises(self, sastoken_generator, arbitrary_exception):
|
||||
sastoken_generator.generator_fn.side_effect = arbitrary_exception
|
||||
|
||||
with pytest.raises(SasTokenError) as e_info:
|
||||
await sastoken_generator.generate_sastoken()
|
||||
assert e_info.value.__cause__ is arbitrary_exception
|
||||
|
||||
@pytest.mark.it("Raises SasTokenError if the generated SAS Token string is invalid")
|
||||
async def test_invalid_token(self, sastoken_generator):
|
||||
sastoken_generator.generator_fn.return_value = "not a sastoken"
|
||||
|
||||
with pytest.raises(SasTokenError) as e_info:
|
||||
await sastoken_generator.generate_sastoken()
|
||||
assert isinstance(e_info.value.__cause__, ValueError)
|
||||
|
||||
|
||||
@pytest.mark.describe("SasTokenProvider -- Instantiation")
|
||||
class TestSasTokenProviderInstantiation:
|
||||
@pytest.mark.it("Stores the provided SasTokenGenerator")
|
||||
async def test_generator_fn(self, sastoken, sastoken_generator):
|
||||
provider = SasTokenProvider(initial_token=sastoken, generator=sastoken_generator)
|
||||
assert provider._generator is sastoken_generator
|
||||
await provider.shutdown()
|
||||
|
||||
@pytest.mark.it("Sets the provided initial_token as the current SasToken")
|
||||
async def test_initial_token(self, sastoken, sastoken_generator):
|
||||
provider = SasTokenProvider(initial_token=sastoken, generator=sastoken_generator)
|
||||
assert provider._sastoken is sastoken
|
||||
await provider.shutdown()
|
||||
|
||||
@pytest.mark.it("Sets the token update margin to the DEFAULT_TOKEN_UPDATE_MARGIN")
|
||||
async def test_token_update_margin(self, sastoken, sastoken_generator):
|
||||
provider = SasTokenProvider(initial_token=sastoken, generator=sastoken_generator)
|
||||
assert provider._token_update_margin == DEFAULT_TOKEN_UPDATE_MARGIN
|
||||
await provider.shutdown()
|
||||
|
||||
# NOTE: The contents of this coroutine are tested in a separate test suite below.
|
||||
# See TestSasTokenProviderKeepTokenFresh for more.
|
||||
@pytest.mark.it("Begins running the ._keep_token_fresh() coroutine method, storing the task")
|
||||
async def test_keep_token_fresh_running(self, sastoken, sastoken_generator):
|
||||
provider = SasTokenProvider(initial_token=sastoken, generator=sastoken_generator)
|
||||
assert isinstance(provider._keep_token_fresh_task, asyncio.Task)
|
||||
assert not provider._keep_token_fresh_task.done()
|
||||
if sys.version_info >= (3, 8):
|
||||
# NOTE: There isn't a way to validate the contents of a task until 3.8
|
||||
# as far as I can tell.
|
||||
task_coro = provider._keep_token_fresh_task.get_coro()
|
||||
assert task_coro.__qualname__ == "SasTokenProvider._keep_token_fresh"
|
||||
await provider.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.describe("SasTokenProvider - .create_from_generator()")
|
||||
class TestSasTokenProviderCreateFromGenerator:
|
||||
@pytest.mark.it("Returns a SasTokenProvider instance")
|
||||
async def test_instantiates(self, sastoken_generator):
|
||||
provider = await SasTokenProvider.create_from_generator(sastoken_generator)
|
||||
assert isinstance(provider, SasTokenProvider)
|
||||
await provider.shutdown()
|
||||
|
||||
@pytest.mark.it(
|
||||
"Generates a new SasToken using the provided SasTokenGenerator to use as the SasTokenProvider's initial sastoken"
|
||||
)
|
||||
async def test_generates_initial_token(self, mocker, sastoken_generator):
|
||||
assert sastoken_generator.generate_sastoken.await_count == 0
|
||||
|
||||
provider = await SasTokenProvider.create_from_generator(sastoken_generator)
|
||||
|
||||
assert sastoken_generator.generate_sastoken.await_count == 1
|
||||
assert sastoken_generator.generate_sastoken.await_args == mocker.call()
|
||||
assert isinstance(provider._sastoken, SasToken)
|
||||
assert provider._sastoken == sastoken_generator.generate_sastoken.spy_return
|
||||
|
||||
await provider.shutdown()
|
||||
|
||||
@pytest.mark.it("Allows any exception raised while trying to generate a SasToken to propagate")
|
||||
@pytest.mark.parametrize(
|
||||
"exception",
|
||||
[
|
||||
pytest.param(SasTokenError("token error"), id="SasTokenError"),
|
||||
pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"),
|
||||
],
|
||||
)
|
||||
async def test_generation_raises(self, sastoken_generator, exception):
|
||||
sastoken_generator.generate_sastoken.side_effect = exception
|
||||
|
||||
with pytest.raises(type(exception)) as e_info:
|
||||
await SasTokenProvider.create_from_generator(sastoken_generator)
|
||||
assert e_info.value is exception
|
||||
|
||||
@pytest.mark.it("Raises a SasTokenError if the generated SAS Token string has already expired")
|
||||
async def test_expired_token(self, mocker, sastoken_generator):
|
||||
expired_token_str = TOKEN_FORMAT.format(
|
||||
resource=urllib.parse.quote(FAKE_URI, safe=""),
|
||||
signature=urllib.parse.quote(FAKE_SIGNED_DATA, safe=""),
|
||||
expiry=int(time.time()) - 3600, # 1 hour ago
|
||||
)
|
||||
sastoken_generator.generate_sastoken = mocker.AsyncMock()
|
||||
sastoken_generator.generate_sastoken.return_value = SasToken(expired_token_str)
|
||||
|
||||
with pytest.raises(SasTokenError):
|
||||
await SasTokenProvider.create_from_generator(sastoken_generator)
|
||||
|
||||
|
||||
@pytest.mark.describe("SasTokenProvider - .shutdown()")
|
||||
class TestSasTokenProviderShutdown:
|
||||
@pytest.mark.it("Cancels the stored ._keep_token_fresh() task")
|
||||
async def test_cancels_keep_token_fresh(self, sastoken_provider):
|
||||
assert isinstance(sastoken_provider._keep_token_fresh_task, asyncio.Task)
|
||||
assert not sastoken_provider._keep_token_fresh_task.done()
|
||||
await sastoken_provider.shutdown()
|
||||
assert sastoken_provider._keep_token_fresh_task.done()
|
||||
assert sastoken_provider._keep_token_fresh_task.cancelled()
|
||||
|
||||
|
||||
@pytest.mark.describe("SasTokenProvider - .get_current_sastoken()")
|
||||
class TestSasTokenGetCurrentSasToken:
|
||||
@pytest.mark.it("Returns the current SasToken object")
|
||||
def test_returns_current_token(self, sastoken_provider):
|
||||
current_token = sastoken_provider.get_current_sastoken()
|
||||
assert current_token is sastoken_provider._sastoken
|
||||
new_token_str = TOKEN_FORMAT.format(
|
||||
resource=urllib.parse.quote(FAKE_URI, safe=""),
|
||||
signature=urllib.parse.quote(FAKE_SIGNED_DATA, safe=""),
|
||||
expiry=int(time.time()) + 3600,
|
||||
)
|
||||
new_token = SasToken(new_token_str)
|
||||
sastoken_provider._sastoken = new_token
|
||||
assert sastoken_provider.get_current_sastoken() is new_token
|
||||
|
||||
|
||||
@pytest.mark.describe("SasTokenProvider - .wait_for_new_sastoken()")
|
||||
class TestSasTokenWaitForNewSasToken:
|
||||
@pytest.mark.it(
|
||||
"Returns the current SasToken object once a notified of a new token being available"
|
||||
)
|
||||
async def test_returns_new_current_token(self, sastoken_provider):
|
||||
token_str_1 = TOKEN_FORMAT.format(
|
||||
resource=urllib.parse.quote(FAKE_URI, safe=""),
|
||||
signature=urllib.parse.quote(FAKE_SIGNED_DATA, safe=""),
|
||||
expiry=int(time.time()) + 3600,
|
||||
)
|
||||
token1 = SasToken(token_str_1)
|
||||
token_str_2 = TOKEN_FORMAT.format(
|
||||
resource=urllib.parse.quote(FAKE_URI, safe=""),
|
||||
signature=urllib.parse.quote(FAKE_SIGNED_DATA2, safe=""),
|
||||
expiry=int(time.time()) + 4500,
|
||||
)
|
||||
token2 = SasToken(token_str_2)
|
||||
|
||||
sastoken_provider._sastoken = token1
|
||||
assert sastoken_provider.get_current_sastoken() is token1
|
||||
|
||||
# Waiting for new token, but one is not yet available
|
||||
task = asyncio.create_task(sastoken_provider.wait_for_new_sastoken())
|
||||
await asyncio.sleep(0.1)
|
||||
assert not task.done()
|
||||
|
||||
# Update the token, but without notification, the waiting task still does not return
|
||||
sastoken_provider._sastoken = token2
|
||||
await asyncio.sleep(0.1)
|
||||
assert not task.done()
|
||||
|
||||
# Notify that a new token is available, and now the task will return
|
||||
async with sastoken_provider._new_sastoken_available:
|
||||
sastoken_provider._new_sastoken_available.notify_all()
|
||||
returned_token = await task
|
||||
|
||||
# The task returned the new token
|
||||
assert returned_token is token2
|
||||
assert returned_token is not token1
|
||||
assert returned_token is sastoken_provider.get_current_sastoken()
|
||||
|
||||
|
||||
# NOTE: This test suite assumes the correct implementation of ._wait_until() for critical
|
||||
# requirements. Find it tested in a separate suite below (TestWaitUntil)
|
||||
@pytest.mark.describe("SasTokenProvider -- Keep Token Fresh Task")
|
||||
class TestSasTokenProviderKeepTokenFresh:
|
||||
@pytest.fixture(autouse=True)
|
||||
def spy_time(self, mocker):
|
||||
"""Spy on the time module so that we can find out last time that was returned"""
|
||||
spy_time = mocker.spy(time, "time")
|
||||
return spy_time
|
||||
|
||||
# NOTE: This is an autouse fixture to ensure that it gets called first, since we want to make sure
|
||||
# this mock is running when the SasTokenProvider is created.
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_wait_until(self, mocker):
|
||||
"""Mock out the wait_until function so these tests aren't dependent on real time passing"""
|
||||
mock_wait_until = mocker.patch.object(st, "_wait_until")
|
||||
mock_wait_until._allow_proceed = asyncio.Event()
|
||||
|
||||
# Fake implementation that will wait for an explicit trigger to proceed, rather than the
|
||||
# passage of time
|
||||
async def fake_wait_until(when):
|
||||
await mock_wait_until._allow_proceed.wait()
|
||||
|
||||
mock_wait_until.side_effect = fake_wait_until
|
||||
|
||||
# Define a mechanism that will allow an explicit trigger to let the mocked coroutine return
|
||||
def proceed():
|
||||
mock_wait_until._allow_proceed.set()
|
||||
mock_wait_until._allow_proceed = asyncio.Event()
|
||||
|
||||
mock_wait_until.proceed = proceed
|
||||
|
||||
return mock_wait_until
|
||||
|
||||
@pytest.mark.it(
|
||||
"Waits until the configured update margin number of seconds before current SasToken expiry to generate a new SasToken"
|
||||
)
|
||||
async def test_wait_to_generate(self, mocker, mock_wait_until, sastoken_provider):
|
||||
original_token = sastoken_provider.get_current_sastoken()
|
||||
assert sastoken_provider._generator.generate_sastoken.await_count == 0
|
||||
await asyncio.sleep(0.1)
|
||||
# We are waiting the expected amount of time
|
||||
expected_update_time = original_token.expiry_time - sastoken_provider._token_update_margin
|
||||
assert mock_wait_until.await_count == 1
|
||||
assert mock_wait_until.await_args == mocker.call(expected_update_time)
|
||||
# Allow the waiting to end, and a new token to be generated
|
||||
mock_wait_until.proceed()
|
||||
await asyncio.sleep(0.1)
|
||||
assert sastoken_provider._generator.generate_sastoken.await_count == 1
|
||||
assert sastoken_provider._generator.generate_sastoken.await_args == mocker.call()
|
||||
|
||||
@pytest.mark.it(
|
||||
"Sets the newly generated SasToken as the new current SasToken and sends notification of its availability"
|
||||
)
|
||||
async def test_replace_token_and_notify(self, mocker, sastoken_provider, mock_wait_until):
|
||||
notification_spy = mocker.spy(sastoken_provider._new_sastoken_available, "notify_all")
|
||||
# We have the original token, as we have not yet generated a new one
|
||||
original_token = sastoken_provider.get_current_sastoken()
|
||||
assert sastoken_provider._generator.generate_sastoken.await_count == 0
|
||||
assert notification_spy.call_count == 0
|
||||
# Allow waiting to proceed, and a new token to be generated
|
||||
mock_wait_until.proceed()
|
||||
await asyncio.sleep(0.1)
|
||||
# A new token has now been generated
|
||||
assert sastoken_provider._generator.generate_sastoken.await_count == 1
|
||||
# The current token is now the token that was just generated
|
||||
current_token = sastoken_provider.get_current_sastoken()
|
||||
assert current_token is sastoken_provider._generator.generate_sastoken.spy_return
|
||||
# This token is not the same as the original token
|
||||
assert current_token is not original_token
|
||||
# A notification was sent about the new token
|
||||
assert notification_spy.call_count == 1
|
||||
|
||||
@pytest.mark.it(
|
||||
"Waits until the configured update margin number of seconds before the NEW current SasToken expiry, after each time a new SasToken is generated, before once again generating a new SasToken"
|
||||
)
|
||||
async def test_wait_to_generate_again_and_again(
|
||||
self, mocker, mock_wait_until, sastoken_provider
|
||||
):
|
||||
# Current token is the original, we have not yet generated a new one
|
||||
original_token = sastoken_provider.get_current_sastoken()
|
||||
assert sastoken_provider._generator.generate_sastoken.await_count == 0
|
||||
await asyncio.sleep(0.1)
|
||||
# We are waiting based on the original token's expiry time
|
||||
expected_update_time = original_token.expiry_time - sastoken_provider._token_update_margin
|
||||
assert mock_wait_until.await_count == 1
|
||||
assert mock_wait_until.await_args == mocker.call(expected_update_time)
|
||||
# Allow the waiting to end, and a new token to be generated
|
||||
mock_wait_until.proceed()
|
||||
await asyncio.sleep(0.1)
|
||||
assert sastoken_provider._generator.generate_sastoken.await_count == 1
|
||||
assert sastoken_provider._generator.generate_sastoken.await_args == mocker.call()
|
||||
# New token is the one that was just generated
|
||||
new_token = sastoken_provider.get_current_sastoken()
|
||||
assert new_token is sastoken_provider._generator.generate_sastoken.spy_return
|
||||
assert new_token is not original_token
|
||||
# We are once again waiting, this time based on the new token's expiry time
|
||||
expected_update_time = new_token.expiry_time - sastoken_provider._token_update_margin
|
||||
assert mock_wait_until.await_count == 2
|
||||
assert mock_wait_until.await_args == mocker.call(expected_update_time)
|
||||
# Allow the waiting to end and another new token to be generated
|
||||
mock_wait_until.proceed()
|
||||
await asyncio.sleep(0.1)
|
||||
assert sastoken_provider._generator.generate_sastoken.await_count == 2
|
||||
assert sastoken_provider._generator.generate_sastoken.await_args == mocker.call()
|
||||
# Newest token is the one that was just generated
|
||||
newest_token = sastoken_provider.get_current_sastoken()
|
||||
assert newest_token is sastoken_provider._generator.generate_sastoken.spy_return
|
||||
assert newest_token is not original_token
|
||||
assert newest_token is not new_token
|
||||
# We are once again waiting, this time based on the newest token's expiry time
|
||||
expected_update_time = newest_token.expiry_time - sastoken_provider._token_update_margin
|
||||
assert mock_wait_until.await_count == 3
|
||||
assert mock_wait_until.await_args == mocker.call(expected_update_time)
|
||||
# And so on and so forth to infinity...
|
||||
|
||||
@pytest.mark.it(
|
||||
"Sets the newly generated SasToken as the new current SasToken and sends notification of its availability each time a new token is generated"
|
||||
)
|
||||
async def test_replace_token_and_notify_each_time(
|
||||
self, mocker, sastoken_provider, mock_wait_until
|
||||
):
|
||||
notification_spy = mocker.spy(sastoken_provider._new_sastoken_available, "notify_all")
|
||||
# We have the original token, as we have not yet generated a new one
|
||||
original_token = sastoken_provider.get_current_sastoken()
|
||||
assert sastoken_provider._generator.generate_sastoken.await_count == 0
|
||||
assert notification_spy.call_count == 0
|
||||
# Allow waiting to proceed, and a new token to be generated
|
||||
mock_wait_until.proceed()
|
||||
await asyncio.sleep(0.1)
|
||||
# A new token has now been generated
|
||||
assert sastoken_provider._generator.generate_sastoken.await_count == 1
|
||||
# The current token is now the token that was just generated
|
||||
second_token = sastoken_provider.get_current_sastoken()
|
||||
assert second_token is sastoken_provider._generator.generate_sastoken.spy_return
|
||||
# This token is not the same as the original token
|
||||
assert second_token is not original_token
|
||||
# A notification was sent about the new token
|
||||
assert notification_spy.call_count == 1
|
||||
# Allow waiting to proceed and another new token to be generated
|
||||
mock_wait_until.proceed()
|
||||
await asyncio.sleep(0.1)
|
||||
# Another new token has now been generated
|
||||
assert sastoken_provider._generator.generate_sastoken.await_count == 2
|
||||
# The current token is now the token that was just generated
|
||||
third_token = sastoken_provider.get_current_sastoken()
|
||||
assert third_token is sastoken_provider._generator.generate_sastoken.spy_return
|
||||
# This token is not the same as any previous token
|
||||
assert third_token is not original_token
|
||||
assert third_token is not second_token
|
||||
# A notification was sent about the new token
|
||||
assert notification_spy.call_count == 2
|
||||
# And so on and so forth to infinity...
|
||||
|
||||
@pytest.mark.it("Tries to generate again in 10 seconds if SasToken generation fails")
|
||||
@pytest.mark.parametrize(
|
||||
"exception",
|
||||
[
|
||||
pytest.param(SasTokenError("Some error in SAS"), id="SasTokenError"),
|
||||
pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"),
|
||||
],
|
||||
)
|
||||
async def test_generation_failure(
|
||||
self, mocker, sastoken_provider, exception, mock_wait_until, spy_time
|
||||
):
|
||||
# Set token generation to raise exception
|
||||
sastoken_provider._generator.generate_sastoken.side_effect = exception
|
||||
# Token generation has not yet happened
|
||||
assert sastoken_provider._generator.generate_sastoken.await_count == 0
|
||||
# Allow waiting to proceed, and a new token to be generated
|
||||
assert mock_wait_until.await_count == 1
|
||||
mock_wait_until.proceed()
|
||||
await asyncio.sleep(0.1)
|
||||
# Waits 10 seconds past the current time
|
||||
expected_generate_time = spy_time.spy_return + 10
|
||||
assert mock_wait_until.await_count == 2
|
||||
assert mock_wait_until.await_args == mocker.call(expected_generate_time)
|
||||
|
||||
|
||||
# NOTE: We don't normally test convention-private helpers directly, but in this case, the
|
||||
# complexity is high enough, and the function is critical enough, that it makes more sense
|
||||
# to isolate rather than attempting to indirectly test.
|
||||
@pytest.mark.describe("._wait_until()")
|
||||
class TestWaitUntil:
|
||||
@pytest.mark.it(
|
||||
"Repeatedly does 1 second asyncio sleeps until the current time is greater than the provided 'when' parameter"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"time_from_now",
|
||||
[
|
||||
pytest.param(5, id="5 seconds from now"),
|
||||
pytest.param(60, id="1 minute from now"),
|
||||
pytest.param(3600, id="1 hour from now"),
|
||||
],
|
||||
)
|
||||
async def test_sleep(self, mocker, time_from_now):
|
||||
# Mock out the sleep coroutine so that we aren't waiting around forever on this test
|
||||
mock_sleep = mocker.patch.object(asyncio, "sleep")
|
||||
|
||||
# mock out time
|
||||
def fake_time():
|
||||
"""Fake time implementation that will return a time float that is 1 larger
|
||||
than the previous time it was called"""
|
||||
fake_time_return = FAKE_CURRENT_TIME
|
||||
while True:
|
||||
yield fake_time_return
|
||||
fake_time_return += 1
|
||||
|
||||
fake_time_gen = fake_time()
|
||||
mock_time = mocker.patch.object(time, "time", side_effect=fake_time_gen)
|
||||
|
||||
desired_time = FAKE_CURRENT_TIME + time_from_now
|
||||
|
||||
await st._wait_until(desired_time)
|
||||
|
||||
assert mock_sleep.await_count == time_from_now
|
||||
for call in mock_sleep.await_args_list:
|
||||
assert call == mocker.call(1)
|
||||
assert mock_time.call_count == time_from_now + 1
|
||||
for call in mock_time.call_args_list:
|
||||
assert call == mocker.call()
|
||||
|
|
|
@ -74,13 +74,13 @@ class TestSymmetricKeySigningMechanismSign(object):
|
|||
@pytest.mark.it(
|
||||
"Generates an HMAC message digest from the signing key and provided data string, using the HMAC-SHA256 algorithm"
|
||||
)
|
||||
def test_hmac(self, mocker, signing_mechanism):
|
||||
async def test_hmac(self, mocker, signing_mechanism):
|
||||
hmac_mock = mocker.patch.object(hmac, "HMAC")
|
||||
hmac_digest_mock = hmac_mock.return_value.digest
|
||||
hmac_digest_mock.return_value = b"\xd2\x06\xf7\x12\xf1\xe9\x95$\x90\xfd\x12\x9a\xb1\xbe\xb4\xf8\xf3\xc4\x1ap\x8a\xab'\x8a.D\xfb\x84\x96\xca\xf3z"
|
||||
|
||||
data_string = "sign this message"
|
||||
signing_mechanism.sign(data_string)
|
||||
await signing_mechanism.sign(data_string)
|
||||
|
||||
assert hmac_mock.call_count == 1
|
||||
assert hmac_mock.call_args == mocker.call(
|
||||
|
@ -93,13 +93,13 @@ class TestSymmetricKeySigningMechanismSign(object):
|
|||
@pytest.mark.it(
|
||||
"Returns the base64 encoded HMAC message digest (converted to string) as the signed data"
|
||||
)
|
||||
def test_b64encode(self, mocker, signing_mechanism):
|
||||
async def test_b64encode(self, mocker, signing_mechanism):
|
||||
hmac_mock = mocker.patch.object(hmac, "HMAC")
|
||||
hmac_digest_mock = hmac_mock.return_value.digest
|
||||
hmac_digest_mock.return_value = b"\xd2\x06\xf7\x12\xf1\xe9\x95$\x90\xfd\x12\x9a\xb1\xbe\xb4\xf8\xf3\xc4\x1ap\x8a\xab'\x8a.D\xfb\x84\x96\xca\xf3z"
|
||||
|
||||
data_string = "sign this message"
|
||||
signature = signing_mechanism.sign(data_string)
|
||||
signature = await signing_mechanism.sign(data_string)
|
||||
|
||||
assert signature == base64.b64encode(hmac_digest_mock.return_value).decode("utf-8")
|
||||
|
||||
|
@ -115,11 +115,11 @@ class TestSymmetricKeySigningMechanismSign(object):
|
|||
),
|
||||
],
|
||||
)
|
||||
def test_supported_types(self, signing_mechanism, data_string, expected_signature):
|
||||
assert signing_mechanism.sign(data_string) == expected_signature
|
||||
async def test_supported_types(self, signing_mechanism, data_string, expected_signature):
|
||||
assert await signing_mechanism.sign(data_string) == expected_signature
|
||||
|
||||
@pytest.mark.it("Raises a ValueError if unable to sign the provided data string")
|
||||
@pytest.mark.parametrize("data_string", [pytest.param(123, id="Integer input")])
|
||||
def test_bad_input(self, signing_mechanism, data_string):
|
||||
async def test_bad_input(self, signing_mechanism, data_string):
|
||||
with pytest.raises(ValueError):
|
||||
signing_mechanism.sign(data_string)
|
||||
await signing_mechanism.sign(data_string)
|
||||
|
|
|
@ -8,7 +8,7 @@ import logging
|
|||
import socks
|
||||
import ssl
|
||||
from typing import Optional, Any
|
||||
from azure.iot.device.common.auth import sastoken as st # type: ignore
|
||||
from .sastoken import SasTokenProvider
|
||||
|
||||
# TODO: add typings for imports
|
||||
# TODO: update docs to ensure types are correct
|
||||
|
@ -67,8 +67,8 @@ class ClientConfig:
|
|||
ssl_context: ssl.SSLContext,
|
||||
hostname: str,
|
||||
gateway_hostname: Optional[str] = None,
|
||||
sastoken_provider: Optional[SasTokenProvider] = None,
|
||||
proxy_options: Optional[ProxyOptions] = None,
|
||||
sastoken: Optional[st.SasToken] = None,
|
||||
keep_alive: int = 60,
|
||||
auto_reconnect: bool = True,
|
||||
websockets: bool = False,
|
||||
|
@ -77,12 +77,12 @@ class ClientConfig:
|
|||
|
||||
:param str hostname: The hostname being connected to
|
||||
:param str gateway_hostname: The gateway hostname optionally being used
|
||||
:param sastoken_provider: Object that can provide SasTokens
|
||||
:type sastoken_provider: :class:`SasTokenProvider`
|
||||
:param proxy_options: Details of proxy configuration
|
||||
:type proxy_options: :class:`azure.iot.device.common.models.ProxyOptions`
|
||||
:param ssl_context: SSLContext to use with the client
|
||||
:type ssl_context: :class:`ssl.SSLContext`
|
||||
:param sastoken: SasToken to be used for authentication. Mutually exclusive with x509.
|
||||
:type sastoken: :class:`azure.iot.device.common.auth.SasToken`
|
||||
:param int keepalive: Maximum period in seconds between communications with the
|
||||
broker.
|
||||
:param bool auto_reconnect: Indicates if dropped connection should result in attempts to
|
||||
|
@ -94,10 +94,10 @@ class ClientConfig:
|
|||
self.hostname = hostname
|
||||
self.gateway_hostname = gateway_hostname
|
||||
self.proxy_options = proxy_options
|
||||
self.ssl_context = ssl_context
|
||||
|
||||
# Auth
|
||||
self.sastoken = sastoken
|
||||
self.sastoken_provider = sastoken_provider
|
||||
self.ssl_context = ssl_context
|
||||
|
||||
# MQTT
|
||||
self.keep_alive = _sanitize_keep_alive(keep_alive)
|
||||
|
|
|
@ -47,7 +47,8 @@ class IoTEdgeHsm(SigningMechanism):
|
|||
self.generation_id = generation_id
|
||||
self.workload_uri = _format_socket_uri(workload_uri)
|
||||
|
||||
def get_certificate(self) -> str:
|
||||
# TODO: Use async http to make use of this being a coroutine
|
||||
async def get_certificate(self) -> str:
|
||||
"""
|
||||
Return the server verification certificate from the trust bundle that can be used to
|
||||
validate the server-side SSL TLS connection that we use to talk to Edge
|
||||
|
@ -79,7 +80,8 @@ class IoTEdgeHsm(SigningMechanism):
|
|||
raise IoTEdgeError("No certificate in trust bundle") from e
|
||||
return cert
|
||||
|
||||
def sign(self, data_str: Union[str, bytes]) -> str:
|
||||
# TODO: Use async http to make use of this being a coroutine
|
||||
async def sign(self, data_str: Union[str, bytes]) -> str:
|
||||
"""
|
||||
Use the IoTEdge HSM to sign a piece of string data. The caller should then insert the
|
||||
returned value (the signature) into the 'sig' field of a SharedAccessSignature string.
|
||||
|
|
|
@ -8,7 +8,7 @@ import asyncio
|
|||
import json
|
||||
import logging
|
||||
import urllib.parse
|
||||
from typing import Optional, AsyncGenerator
|
||||
from typing import Optional, Union, AsyncGenerator
|
||||
from .custom_typing import TwinPatch, Twin
|
||||
from .iot_exceptions import IoTHubError, IoTHubClientError
|
||||
from .models import Message, MethodResponse, MethodRequest
|
||||
|
@ -16,15 +16,12 @@ from . import config, constant, user_agent
|
|||
from . import request_response as rr
|
||||
from . import mqtt_client as mqtt
|
||||
from . import mqtt_topic_iothub as mqtt_topic
|
||||
from azure.iot.device.common.auth import sastoken as st # type: ignore
|
||||
from azure.iot.device.common import alarm # type: ignore
|
||||
|
||||
# TODO: update docstrings with correct class paths once repo structured better
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_RECONNECT_INTERVAL = 10
|
||||
DEFAULT_TOKEN_UPDATE_MARGIN = 120
|
||||
DEFAULT_RECONNECT_INTERVAL: int = 10
|
||||
|
||||
# TODO: add exceptions to docstring
|
||||
# TODO: background exceptions how
|
||||
|
@ -43,23 +40,29 @@ class IoTHubMQTTClient:
|
|||
:param client_config: The config object for the client
|
||||
:type client_config: :class:`IoTHubClientConfig`
|
||||
"""
|
||||
# Basic Information
|
||||
# Identity
|
||||
self._device_id = client_config.device_id
|
||||
self._module_id = client_config.module_id
|
||||
self._client_id = _format_client_id(self._device_id, self._module_id)
|
||||
self._username = _format_username(
|
||||
# NOTE: Always use the original hostname, even if gateway hostname is set
|
||||
hostname=client_config.hostname,
|
||||
client_id=self._client_id,
|
||||
product_info=client_config.product_info,
|
||||
)
|
||||
|
||||
# Sastoken Auth
|
||||
# TODO: Should this be handled by a separate utility? Would make testing easier, and would abstract out the difference
|
||||
self._sastoken: Optional[st.SasToken]
|
||||
self._sastoken_update_alarm: Optional[alarm.Alarm]
|
||||
if client_config.sastoken is not None:
|
||||
self._sastoken = client_config.sastoken
|
||||
self._sastoken_update_alarm = self._create_token_update_alarm()
|
||||
else:
|
||||
self._sastoken = None
|
||||
self._sastoken_update_alarm = None
|
||||
# SAS (Optional)
|
||||
self._sastoken_provider = client_config.sastoken_provider
|
||||
|
||||
# MQTT Configuration
|
||||
self._mqtt_client = _create_mqtt_client(client_config)
|
||||
self._mqtt_client = _create_mqtt_client(self._client_id, client_config)
|
||||
if self._sastoken_provider:
|
||||
logger.debug("Using SASToken as password")
|
||||
password = str(self._sastoken_provider.get_current_sastoken())
|
||||
else:
|
||||
logger.debug("No password used")
|
||||
password = None
|
||||
self._mqtt_client.set_credentials(self._username, password)
|
||||
|
||||
# Create incoming IoTHub data generators
|
||||
# TODO: expose these via method to make the device/module split cleaner
|
||||
|
@ -79,40 +82,21 @@ class IoTHubMQTTClient:
|
|||
# Internal request/response infrastructure
|
||||
self._request_ledger = rr.RequestLedger()
|
||||
self._twin_responses_enabled = False
|
||||
self._twin_response_listener = asyncio.create_task(self._process_twin_responses())
|
||||
|
||||
def _create_token_update_alarm(self) -> alarm.Alarm:
|
||||
if not self._sastoken:
|
||||
# This should never happen, it's just for the type checker
|
||||
raise ValueError("Can't create alarm for no SASToken")
|
||||
|
||||
update_time = self._sastoken.expiry_time - DEFAULT_TOKEN_UPDATE_MARGIN
|
||||
|
||||
if isinstance(self._sastoken, st.RenewableSasToken):
|
||||
|
||||
def on_token_needs_update():
|
||||
# Renew the token
|
||||
logger.debug("Renewing SAS Token...")
|
||||
try:
|
||||
self._sastoken.refresh()
|
||||
logger.debug("SAS Token renewal succeeded")
|
||||
except st.SasTokenError:
|
||||
logger.error("SAS Token renewal failed")
|
||||
# TODO: background exception?
|
||||
# With the token renewed, now set a new Alarm
|
||||
self._sastoken_update_alarm = self._create_token_update_alarm()
|
||||
|
||||
# Background Tasks
|
||||
self._process_twin_responses_bg_task: asyncio.Task[None] = asyncio.create_task(
|
||||
self._process_twin_responses()
|
||||
)
|
||||
self._keep_credentials_fresh_bg_task: Optional[asyncio.Task[None]]
|
||||
if self._sastoken_provider:
|
||||
self._keep_credentials_fresh_bg_task = asyncio.create_task(
|
||||
self._keep_credentials_fresh()
|
||||
)
|
||||
else:
|
||||
|
||||
def on_token_needs_update():
|
||||
pass
|
||||
|
||||
update_alarm = alarm.Alarm(update_time, on_token_needs_update)
|
||||
update_alarm.daemon = True
|
||||
update_alarm.start()
|
||||
return update_alarm
|
||||
self._keep_credentials_fresh_bg_task = None
|
||||
|
||||
async def _enable_twin_responses(self) -> None:
|
||||
"""Enable receiving of twin responses (for twin requests, or twin patches) from IoTHub"""
|
||||
logger.debug("Enabling receive of twin responses...")
|
||||
topic = mqtt_topic.get_twin_response_topic_for_subscribe()
|
||||
await self._mqtt_client.subscribe(topic)
|
||||
|
@ -122,7 +106,7 @@ class IoTHubMQTTClient:
|
|||
# TODO: add background exception handling
|
||||
async def _process_twin_responses(self) -> None:
|
||||
"""Run indefinitely, matching twin responses with request ID"""
|
||||
logger.debug("Starting twin response listener")
|
||||
logger.debug("Starting the 'process_twin_responses' background task")
|
||||
twin_response_topic = mqtt_topic.get_twin_response_topic_for_subscribe()
|
||||
twin_responses = self._mqtt_client.get_incoming_message_generator(twin_response_topic)
|
||||
|
||||
|
@ -144,28 +128,62 @@ class IoTHubMQTTClient:
|
|||
# in-flight operations
|
||||
logger.warning("Twin response (rid: {}) does not match any request")
|
||||
|
||||
async def _keep_credentials_fresh(self) -> None:
|
||||
"""Run indefinitely, updating MQTT credentials when new SAS Token is available"""
|
||||
logger.debug("Starting the 'keep_credentials_fresh' background task")
|
||||
while True:
|
||||
if self._sastoken_provider:
|
||||
logger.debug("Waiting for new SAS Token to become available")
|
||||
new_sastoken = await self._sastoken_provider.wait_for_new_sastoken()
|
||||
logger.debug("New SAS Token available, updating MQTTClient credentials")
|
||||
self._mqtt_client.set_credentials(self._username, str(new_sastoken))
|
||||
# TODO: should we reconnect here? Or just wait for drop?
|
||||
else:
|
||||
# NOTE: This should never execute, it's mostly just here to keep the
|
||||
# type checker happy
|
||||
logger.error("No SasTokenProvider. Cannot update credentials")
|
||||
break
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Shut down the client.
|
||||
|
||||
Invoke only when completely finished with the client for graceful exit.
|
||||
Cannot be cancelled - if you try, the client will still fully shut down as much as
|
||||
possible.
|
||||
"""
|
||||
# TODO: this breaks when called twice. Build some protections.
|
||||
# TODO: is there an issue with cancellation here?
|
||||
await self.disconnect()
|
||||
# Cancel the SAS token update alarm. Note that this is not a task, it's a threaded Alarm.
|
||||
# No need to wait for the result.
|
||||
if self._sastoken_update_alarm:
|
||||
logger.debug("Cancelling SAS Token update alarm")
|
||||
self._sastoken_update_alarm.cancel()
|
||||
# Cancel and wait for the completion of the twin response task
|
||||
logger.debug("Cancelling twin response listener")
|
||||
self._twin_response_listener.cancel()
|
||||
# Wait for the cancellation to complete before returning
|
||||
# NOTE: .disconnect() really shouldn't fail, but if it does, we temporarily suppress
|
||||
# the exception so we can still do as much cleanup as possible.
|
||||
cached_exception: Optional[Union[Exception, asyncio.CancelledError]] = None
|
||||
logger.debug("Attempting disconnect in shutdown")
|
||||
try:
|
||||
await self._twin_response_listener
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await self.disconnect()
|
||||
except asyncio.CancelledError as e:
|
||||
logger.warning("Cancellation during shutdown. Still attempting to clean up.")
|
||||
cached_exception = e
|
||||
except Exception as e:
|
||||
logger.warning("Unexpected error disconnecting. Continuing shutdown procedure")
|
||||
cached_exception = e
|
||||
|
||||
cancelled_tasks = []
|
||||
|
||||
logger.debug("Cancelling 'process_twin_responses' background task")
|
||||
self._process_twin_responses_bg_task.cancel()
|
||||
cancelled_tasks.append(self._process_twin_responses_bg_task)
|
||||
|
||||
if self._keep_credentials_fresh_bg_task:
|
||||
logger.debug("Cancelling 'keep_credentials_fresh' background task")
|
||||
self._keep_credentials_fresh_bg_task.cancel()
|
||||
cancelled_tasks.append(self._keep_credentials_fresh_bg_task)
|
||||
|
||||
# Wait for the cancellation to complete before returning
|
||||
# NOTE: If cancelled while awaiting here, all tasks in gather will still be cancelled
|
||||
# because the cancellations have already been issued.
|
||||
# NOTE: Also, cancelling a gather implicitly cancels all the tasks that are gathered anyway
|
||||
await asyncio.gather(*cancelled_tasks, return_exceptions=True)
|
||||
|
||||
if cached_exception:
|
||||
raise cached_exception
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to IoTHub
|
||||
|
@ -468,17 +486,22 @@ class IoTHubMQTTClient:
|
|||
logger.debug("Twin patch receive disabled")
|
||||
|
||||
|
||||
# Auth Helpers
|
||||
def _format_client_id(device_id: str, module_id: Optional[str] = None) -> str:
|
||||
if module_id:
|
||||
client_id = "{}/{}".format(device_id, module_id)
|
||||
else:
|
||||
client_id = device_id
|
||||
return client_id
|
||||
|
||||
|
||||
def _create_mqtt_client(client_config: config.IoTHubClientConfig) -> mqtt.MQTTClient:
|
||||
def _create_mqtt_client(
|
||||
client_id: str, client_config: config.IoTHubClientConfig
|
||||
) -> mqtt.MQTTClient:
|
||||
logger.debug("Creating MQTTClient")
|
||||
|
||||
if client_config.module_id:
|
||||
client_id = "{}/{}".format(client_config.device_id, client_config.module_id)
|
||||
logger.debug("Using IoTHub Module. Client ID is {}".format(client_id))
|
||||
else:
|
||||
client_id = client_config.device_id
|
||||
logger.debug("Using IoTHub Device. Client ID is {}".format(client_id))
|
||||
|
||||
if client_config.gateway_hostname:
|
||||
|
@ -512,24 +535,6 @@ def _create_mqtt_client(client_config: config.IoTHubClientConfig) -> mqtt.MQTTCl
|
|||
proxy_options=client_config.proxy_options,
|
||||
)
|
||||
|
||||
# NOTE: we use the original hostname here, even if gateway hostname is set
|
||||
username = _create_username(
|
||||
hostname=client_config.hostname,
|
||||
client_id=client_id,
|
||||
product_info=client_config.product_info,
|
||||
)
|
||||
logger.debug("Using {} as username".format(username))
|
||||
|
||||
if client_config.sastoken:
|
||||
logger.debug("Using SASToken as password")
|
||||
password = str(client_config.sastoken)
|
||||
|
||||
else:
|
||||
logger.debug("No password used")
|
||||
password = None
|
||||
|
||||
client.set_credentials(username, password)
|
||||
|
||||
# Add topic filters for receive
|
||||
# IoTHub Receives
|
||||
c2d_msg_topic = mqtt_topic.get_c2d_topic_for_subscribe(client_config.device_id)
|
||||
|
@ -550,7 +555,7 @@ def _create_mqtt_client(client_config: config.IoTHubClientConfig) -> mqtt.MQTTCl
|
|||
return client
|
||||
|
||||
|
||||
def _create_username(hostname: str, client_id: str, product_info: str) -> str:
|
||||
def _format_username(hostname: str, client_id: str, product_info: str) -> str:
|
||||
query_param_seq = []
|
||||
|
||||
# Apply query parameters (i.e. key1=value1&key2=value2...&keyN=valueN format)
|
||||
|
|
|
@ -6,11 +6,19 @@
|
|||
"""This module contains tools for working with Shared Access Signature (SAS) Tokens"""
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import urllib
|
||||
from typing import Optional, Dict
|
||||
import urllib.parse
|
||||
from typing import Dict, List, Union, Awaitable, Callable, cast
|
||||
from .signing_mechanism import SigningMechanism
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TOKEN_UPDATE_MARGIN: int = 120
|
||||
REQUIRED_SASTOKEN_FIELDS: List[str] = ["sr", "sig", "se"]
|
||||
TOKEN_FORMAT: str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}"
|
||||
|
||||
|
||||
class SasTokenError(Exception):
|
||||
"""Error in SasToken"""
|
||||
|
@ -18,160 +26,213 @@ class SasTokenError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class SasToken(abc.ABC):
|
||||
"""Abstract parent class for SAS Tokens.
|
||||
class SasToken:
|
||||
def __init__(self, sastoken_str: str) -> None:
|
||||
"""Create a SasToken object from a SAS Token string
|
||||
:param str sastoken_str: The SAS Token string
|
||||
|
||||
Doesn't do much, but helps with type hints
|
||||
"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def expiry_time(self) -> int:
|
||||
pass
|
||||
|
||||
|
||||
class RenewableSasToken(SasToken):
|
||||
"""Renewable Shared Access Signature Token used to authenticate a request.
|
||||
|
||||
This token is 'renewable', which means that it can be updated when necessary to
|
||||
prevent expiry, by using the .refresh() method.
|
||||
|
||||
Data Attributes:
|
||||
expiry_time (int): Time that token will expire (in UTC, since epoch)
|
||||
ttl (int): Time to live for the token, in seconds
|
||||
"""
|
||||
|
||||
_auth_rule_token_format = (
|
||||
"SharedAccessSignature sr={resource}&sig={signature}&se={expiry}&skn={keyname}"
|
||||
)
|
||||
_simple_token_format = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
signing_mechanism: SigningMechanism,
|
||||
key_name: Optional[str] = None,
|
||||
ttl: int = 3600,
|
||||
) -> None:
|
||||
:raises: ValueError if SAS Token string is invalid
|
||||
"""
|
||||
:param str uri: URI of the resource to be accessed
|
||||
:param signing_mechanism: The signing mechanism to use in the SasToken
|
||||
:type signing_mechanism: Child classes of :class:`azure.iot.common.SigningMechanism`
|
||||
:param str key_name: Symmetric Key Name (optional)
|
||||
:param int ttl: Time to live for the token, in seconds (default 3600)
|
||||
|
||||
:raises: SasTokenError if an error occurs building a SasToken
|
||||
"""
|
||||
self._uri = uri
|
||||
self._signing_mechanism = signing_mechanism
|
||||
self._key_name = key_name
|
||||
# These two values will be set by the .refresh() call
|
||||
self._expiry_time: int
|
||||
self._token: str
|
||||
|
||||
self.ttl = ttl
|
||||
self.refresh()
|
||||
self._token_str: str = sastoken_str
|
||||
self._token_info: Dict[str, str] = _get_sastoken_info_from_string(sastoken_str)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._token
|
||||
|
||||
def refresh(self) -> None:
|
||||
"""
|
||||
Refresh the SasToken lifespan, giving it a new expiry time, and generating a new token.
|
||||
"""
|
||||
self._expiry_time = int(time.time() + self.ttl)
|
||||
self._token = self._build_token()
|
||||
|
||||
def _build_token(self) -> str:
|
||||
"""Build SasToken representation
|
||||
|
||||
:returns: String representation of the token
|
||||
"""
|
||||
url_encoded_uri = urllib.parse.quote(self._uri, safe="")
|
||||
message = url_encoded_uri + "\n" + str(self.expiry_time)
|
||||
try:
|
||||
signature = self._signing_mechanism.sign(message)
|
||||
except Exception as e:
|
||||
# Because of variant signing mechanisms, we don't know what error might be raised.
|
||||
# So we catch all of them.
|
||||
raise SasTokenError("Unable to build SasToken from given values") from e
|
||||
url_encoded_signature = urllib.parse.quote(signature, safe="")
|
||||
if self._key_name:
|
||||
token = self._auth_rule_token_format.format(
|
||||
resource=url_encoded_uri,
|
||||
signature=url_encoded_signature,
|
||||
expiry=str(self.expiry_time),
|
||||
keyname=self._key_name,
|
||||
)
|
||||
else:
|
||||
token = self._simple_token_format.format(
|
||||
resource=url_encoded_uri,
|
||||
signature=url_encoded_signature,
|
||||
expiry=str(self.expiry_time),
|
||||
)
|
||||
return token
|
||||
return self._token_str
|
||||
|
||||
@property
|
||||
def expiry_time(self) -> int:
|
||||
"""Expiry Time is READ ONLY"""
|
||||
return self._expiry_time
|
||||
|
||||
|
||||
class NonRenewableSasToken(SasToken):
|
||||
"""NonRenewable Shared Access Signature Token used to authenticate a request.
|
||||
|
||||
This token is 'non-renewable', which means that it is invalid once it expires, and there
|
||||
is no way to keep it alive. Instead, a new token must be created.
|
||||
|
||||
Data Attributes:
|
||||
expiry_time (int): Time that token will expire (in UTC, since epoch)
|
||||
resource_uri (str): URI for the resource the Token provides authentication to access
|
||||
"""
|
||||
|
||||
def __init__(self, sastoken_string) -> None:
|
||||
"""
|
||||
:param str sastoken_string: A string representation of a SAS token
|
||||
"""
|
||||
self._token = sastoken_string
|
||||
self._token_info = get_sastoken_info_from_string(self._token)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._token
|
||||
|
||||
@property
|
||||
def expiry_time(self) -> int:
|
||||
"""Expiry Time is READ ONLY"""
|
||||
return int(self._token_info["se"])
|
||||
def expiry_time(self) -> float:
|
||||
# NOTE: Time is typically expressed in float in Python, even though a
|
||||
# SAS Token expiry time should be a whole number.
|
||||
return float(self._token_info["se"])
|
||||
|
||||
@property
|
||||
def resource_uri(self) -> str:
|
||||
"""Resource URI is READ ONLY"""
|
||||
uri = self._token_info["sr"]
|
||||
return urllib.parse.unquote(uri)
|
||||
|
||||
|
||||
REQUIRED_SASTOKEN_FIELDS = ["sr", "sig", "se"]
|
||||
VALID_SASTOKEN_FIELDS = REQUIRED_SASTOKEN_FIELDS + ["skn"]
|
||||
@property
|
||||
def signature(self) -> str:
|
||||
signature = self._token_info["sig"]
|
||||
return urllib.parse.unquote(signature)
|
||||
|
||||
|
||||
def get_sastoken_info_from_string(sastoken_string: str) -> Dict[str, str]:
|
||||
class SasTokenGenerator(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
async def generate_sastoken(self):
|
||||
pass
|
||||
|
||||
|
||||
class InternalSasTokenGenerator(SasTokenGenerator):
|
||||
def __init__(self, signing_mechanism: SigningMechanism, uri: str, ttl: int = 3600) -> None:
|
||||
"""An object that can generate SasTokens using provided values
|
||||
|
||||
:param str uri: The URI of the resource you are generating a tokens to access
|
||||
:param signing_mechanism: The signing mechanism that will be used to sign data
|
||||
:type signing mechanism: :class:`SigningMechanism`
|
||||
:param int ttl: Time to live for generated tokens, in seconds (default 3600)
|
||||
"""
|
||||
self.signing_mechanism = signing_mechanism
|
||||
self.uri = uri
|
||||
self.ttl = ttl
|
||||
|
||||
async def generate_sastoken(self) -> SasToken:
|
||||
"""Generate a new SasToken
|
||||
|
||||
:raises: SasTokenError if the token cannot be generated
|
||||
"""
|
||||
expiry_time = int(time.time()) + self.ttl
|
||||
url_encoded_uri = urllib.parse.quote(self.uri, safe="")
|
||||
message = url_encoded_uri + "\n" + str(expiry_time)
|
||||
try:
|
||||
signature = await self.signing_mechanism.sign(message)
|
||||
except Exception as e:
|
||||
# Because of variant signing mechanisms, we don't know what error might be raised.
|
||||
# So we catch all of them.
|
||||
raise SasTokenError("Unable to generate SasToken") from e
|
||||
url_encoded_signature = urllib.parse.quote(signature, safe="")
|
||||
token_str = TOKEN_FORMAT.format(
|
||||
resource=url_encoded_uri,
|
||||
signature=url_encoded_signature,
|
||||
expiry=str(expiry_time),
|
||||
)
|
||||
return SasToken(token_str)
|
||||
|
||||
|
||||
class ExternalSasTokenGenerator(SasTokenGenerator):
|
||||
def __init__(self, generator_fn: Union[Callable[[], str], Callable[[], Awaitable[str]]]):
|
||||
"""An object that can generate SasTokens by invoking a provided callable.
|
||||
This callable can be a function or a coroutine function.
|
||||
|
||||
:param generator_fn: A callable that takes no arguments and returns a SAS Token string
|
||||
:type generator_fn: Function or Coroutine Function which returns a string
|
||||
"""
|
||||
self.generator_fn = generator_fn
|
||||
|
||||
async def generate_sastoken(self) -> SasToken:
|
||||
"""Generate a new SasToken
|
||||
|
||||
:raises: SasTokenError if the token cannot be generated
|
||||
"""
|
||||
try:
|
||||
# NOTE: the typechecker has some problems here, so we help it with a cast.
|
||||
if asyncio.iscoroutinefunction(self.generator_fn):
|
||||
generator_fn = cast(Callable[[], Awaitable[str]], self.generator_fn)
|
||||
token_str = await generator_fn()
|
||||
else:
|
||||
generator_coro_fn = cast(Callable[[], str], self.generator_fn)
|
||||
token_str = generator_coro_fn()
|
||||
return SasToken(token_str)
|
||||
except Exception as e:
|
||||
raise SasTokenError("Unable to generate SasToken") from e
|
||||
|
||||
|
||||
class SasTokenProvider:
|
||||
def __init__(self, initial_token: SasToken, generator: SasTokenGenerator) -> None:
|
||||
"""Object responsible for providing a valid SasToken.
|
||||
Instantiate using a factory method instead of directly.
|
||||
|
||||
:param generator: A SasTokenGenerator to generate SasTokens with
|
||||
:type generator: SasTokenGenerator
|
||||
"""
|
||||
# NOTE: There is no good way to invoke a coroutine from within the __init__, and since
|
||||
# the the generator's .sign() method is a coroutine, that means we can't generate an
|
||||
# initial token from it here. Thus, we have to take the initial token as a separate
|
||||
# argument.
|
||||
# However, this is inconvenient, and also prevents us from fast-failing if there's a
|
||||
# problem with the generator_fn, so a factory coroutine method has been implemented.
|
||||
self._event_loop = asyncio.get_running_loop()
|
||||
self._generator = generator
|
||||
self._sastoken = initial_token
|
||||
self._token_update_margin = DEFAULT_TOKEN_UPDATE_MARGIN
|
||||
self._new_sastoken_available = asyncio.Condition()
|
||||
self._keep_token_fresh_task = asyncio.create_task(self._keep_token_fresh())
|
||||
|
||||
async def _keep_token_fresh(self):
|
||||
"""Runs indefinitely and will generate a SasToken when the current one gets close to
|
||||
expiration (based on the update margin)
|
||||
"""
|
||||
generate_time = self._sastoken.expiry_time - self._token_update_margin
|
||||
while True:
|
||||
await _wait_until(generate_time)
|
||||
try:
|
||||
logger.debug("Updating SAS Token...")
|
||||
self._sastoken = await self._generator.generate_sastoken()
|
||||
logger.debug("SAS Token update succeeded")
|
||||
generate_time = self._sastoken.expiry_time - self._token_update_margin
|
||||
async with self._new_sastoken_available:
|
||||
self._new_sastoken_available.notify_all()
|
||||
except Exception:
|
||||
logger.error("SAS Token renewal failed. Trying again in 10 seconds")
|
||||
generate_time = time.time() + 10
|
||||
|
||||
@classmethod
|
||||
async def create_from_generator(cls, generator: SasTokenGenerator) -> "SasTokenProvider":
|
||||
"""Create an instance of the SasTokenProvider that will rely on an external source
|
||||
to generate new tokens via a callback function/coroutine.
|
||||
|
||||
:param generator: A SasTokenGenerator to generate SasTokens with
|
||||
:type generator: SasTokenGenerator
|
||||
:raises: SasTokenError if an initial SasToken cannot be generated
|
||||
:raises: SasTokenError if the initial SasToken generated is invalid
|
||||
"""
|
||||
initial_token = await generator.generate_sastoken()
|
||||
if initial_token.expiry_time < time.time():
|
||||
raise SasTokenError("Newly generated SAS Token has already expired")
|
||||
return cls(initial_token, generator)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shut down the SasToken provider, and free any resources.
|
||||
No further updates to the current SAS Token will be made
|
||||
"""
|
||||
self._keep_token_fresh_task.cancel()
|
||||
# Wait for cancellation to complete
|
||||
await asyncio.gather(self._keep_token_fresh_task, return_exceptions=True)
|
||||
|
||||
def get_current_sastoken(self) -> SasToken:
|
||||
"""Return the current SasToken"""
|
||||
return self._sastoken
|
||||
|
||||
async def wait_for_new_sastoken(self) -> SasToken:
|
||||
"""Waits for a new SasToken to become available, and return it"""
|
||||
async with self._new_sastoken_available:
|
||||
await self._new_sastoken_available.wait()
|
||||
return self.get_current_sastoken()
|
||||
|
||||
|
||||
def _get_sastoken_info_from_string(sastoken_string: str) -> Dict[str, str]:
|
||||
"""Given a SAS Token string, return a dictionary of it's keys and values"""
|
||||
pieces = sastoken_string.split("SharedAccessSignature ")
|
||||
if len(pieces) != 2:
|
||||
raise SasTokenError("Invalid SasToken string: Not a SasToken ")
|
||||
raise ValueError("Invalid SAS Token string: Not a SAS Token ")
|
||||
|
||||
# Get sastoken info as dictionary
|
||||
try:
|
||||
# TODO: fix this typehint later, it needs some kind of cast
|
||||
sastoken_info = dict(map(str.strip, sub.split("=", 1)) for sub in pieces[1].split("&")) # type: ignore
|
||||
except Exception as e:
|
||||
raise SasTokenError("Invalid SasToken string: Incorrectly formatted") from e
|
||||
raise ValueError("Invalid SAS Token string: Incorrectly formatted") from e
|
||||
|
||||
# Validate that all required fields are present
|
||||
if not all(key in sastoken_info for key in REQUIRED_SASTOKEN_FIELDS):
|
||||
raise SasTokenError("Invalid SasToken string: Not all required fields present")
|
||||
raise ValueError("Invalid SAS Token string: Not all required fields present")
|
||||
|
||||
# Validate that no unexpected fields are present
|
||||
if not all(key in VALID_SASTOKEN_FIELDS for key in sastoken_info):
|
||||
raise SasTokenError("Invalid SasToken string: Unexpected fields present")
|
||||
# Warn if extraneous fields are present
|
||||
if not all(key in REQUIRED_SASTOKEN_FIELDS for key in sastoken_info):
|
||||
logger.warning("Unexpected fields present in SAS Token")
|
||||
|
||||
return sastoken_info
|
||||
|
||||
|
||||
# NOTE: Arguably, this doesn't really belong in this module, give it's lack of a specific
|
||||
# relationship to SAS Tokens, and the fact that it needs to be unit-tested separately.
|
||||
# These things suggest it should be more than just a convention-private helper, however
|
||||
# its hard to justify making a separate module just for this function.
|
||||
# This would be a candidate for some kind of misc utility module if other similar functions
|
||||
# pop up over the course of development. Until then, it lives here.
|
||||
async def _wait_until(when: float) -> None:
|
||||
"""Wait until a specific time has passed (accurate within 1 second).
|
||||
|
||||
:param float when: The time to wait for, in seconds, since epoch
|
||||
"""
|
||||
while time.time() < when:
|
||||
await asyncio.sleep(1)
|
||||
|
|
|
@ -8,17 +8,22 @@ import base64
|
|||
import binascii
|
||||
import hmac
|
||||
import hashlib
|
||||
from typing import Union
|
||||
from typing import AnyStr
|
||||
|
||||
# TODO: remove commented signatures
|
||||
|
||||
|
||||
class SigningMechanism(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def sign(self, data_str: Union[str, bytes]) -> str:
|
||||
async def sign(self, data_str: AnyStr) -> str:
|
||||
# NOTE: This is defined as a coroutine to allow for flexibility of implementation.
|
||||
# Some implementations may not require a coroutine, but others may, so we err on the side
|
||||
# of a coroutine for consistent interface.
|
||||
pass
|
||||
|
||||
|
||||
class SymmetricKeySigningMechanism(SigningMechanism):
|
||||
def __init__(self, key: Union[str, bytes]) -> None:
|
||||
def __init__(self, key: AnyStr) -> None:
|
||||
"""
|
||||
A mechanism that signs data using a symmetric key
|
||||
|
||||
|
@ -34,13 +39,12 @@ class SymmetricKeySigningMechanism(SigningMechanism):
|
|||
key_bytes = key
|
||||
|
||||
# Derives the signing key
|
||||
# CT-TODO: is "signing key" the right term?
|
||||
try:
|
||||
self._signing_key = base64.b64decode(key_bytes)
|
||||
except (binascii.Error):
|
||||
raise ValueError("Invalid Symmetric Key")
|
||||
|
||||
def sign(self, data_str: Union[str, bytes]) -> str:
|
||||
async def sign(self, data_str: AnyStr) -> str:
|
||||
"""
|
||||
Sign a data string with symmetric key and the HMAC-SHA256 algorithm.
|
||||
|
||||
|
@ -52,6 +56,9 @@ class SymmetricKeySigningMechanism(SigningMechanism):
|
|||
|
||||
:raises: ValueError if an invalid data string is provided
|
||||
"""
|
||||
# NOTE: This implementation doesn't take advantage of being a coroutine, but this is by
|
||||
# design. See the definition of the abstract base class above.
|
||||
|
||||
# Convert data_str to bytes (if not already)
|
||||
if isinstance(data_str, str):
|
||||
data_bytes = data_str.encode("utf-8")
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
"""This module is for creating agent strings for all clients"""
|
||||
|
||||
import platform
|
||||
from azure.iot.device.constant import VERSION, IOTHUB_IDENTIFIER, PROVISIONING_IDENTIFIER
|
||||
from .constant import VERSION, IOTHUB_IDENTIFIER, PROVISIONING_IDENTIFIER
|
||||
|
||||
python_runtime = platform.python_version()
|
||||
os_type = platform.system()
|
||||
|
|
Загрузка…
Ссылка в новой задаче