зеркало из
1
0
Форкнуть 0
* Restructured SAS-related objects and their responsibilities
* Removed references to "renewing" a SAS token due to bad semantics
* Keeping a valid SAS token is no longer the domain of the IoTHubMQTTClient
* Made signing data with a SigningMechanism asynchronous (in signature) - actual async implementation will come later.
This commit is contained in:
Carter Tinney 2023-02-19 16:34:07 -08:00 коммит произвёл GitHub
Родитель bb4cb90295
Коммит a9c20135dc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
15 изменённых файлов: 1450 добавлений и 651 удалений

Просмотреть файл

@ -10,7 +10,9 @@ repos:
- id: flake8
args: ['--config=.flake8']
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.991
rev: v1.0.0
# NOTE: Azure SDK guidelines say to use 0.931 as a pin, but it seems bugged.
# Getting some clarity on this, but for now, we're going to just use the most recent
hooks:
- id: mypy
files: v3_async_wip/ # for now only the new v3 files have typings

Просмотреть файл

@ -19,8 +19,11 @@ class HangingAsyncMock(mock.AsyncMock):
self._stop_hanging = asyncio.Event()
async def _do_hang(self, *args, **kwargs):
self._is_hanging.set()
if not self._is_hanging.is_set():
self._stop_hanging.clear()
self._is_hanging.set()
await self._stop_hanging.wait()
return self.return_value
async def wait_for_hang(self):
await self._is_hanging.wait()
@ -29,4 +32,8 @@ class HangingAsyncMock(mock.AsyncMock):
return self._is_hanging.is_set()
def stop_hanging(self):
self._stop_hanging.set()
if self._is_hanging.is_set():
self._stop_hanging.set()
self._is_hanging.clear()
else:
raise RuntimeError("Not hanging")

2
mypy.ini Normal file
Просмотреть файл

@ -0,0 +1,2 @@
[mypy]
show_error_codes = True

Просмотреть файл

@ -14,5 +14,5 @@ def arbitrary_exception():
class ArbitraryException(Exception):
pass
e = ArbitraryException()
e = ArbitraryException("arbitrary description")
return e

Просмотреть файл

@ -109,14 +109,14 @@ class TestIoTEdgeHsmGetCertificate(object):
return mocker.patch.object(requests, "get")
@pytest.mark.it("Sends an HTTP GET request to retrieve the trust bundle from Edge")
def test_requests_trust_bundle(self, mocker, edge_hsm, mock_requests_get):
async def test_requests_trust_bundle(self, mocker, edge_hsm, mock_requests_get):
expected_url = edge_hsm.workload_uri + "trust-bundle"
expected_params = {"api-version": edge_hsm.api_version}
expected_headers = {
"User-Agent": urllib.parse.quote_plus(user_agent.get_iothub_user_agent())
}
edge_hsm.get_certificate()
await edge_hsm.get_certificate()
assert mock_requests_get.call_count == 1
assert mock_requests_get.call_args == mocker.call(
@ -124,43 +124,43 @@ class TestIoTEdgeHsmGetCertificate(object):
)
@pytest.mark.it("Returns the certificate from the trust bundle received from Edge")
def test_returns_certificate(self, edge_hsm, mock_requests_get):
async def test_returns_certificate(self, edge_hsm, mock_requests_get):
mock_response = mock_requests_get.return_value
certificate = "my certificate"
mock_response.json.return_value = {"certificate": certificate}
returned_cert = edge_hsm.get_certificate()
returned_cert = await edge_hsm.get_certificate()
assert returned_cert is certificate
@pytest.mark.it("Raises IoTEdgeError if a bad request is made to Edge")
def test_bad_request(self, edge_hsm, mock_requests_get):
async def test_bad_request(self, edge_hsm, mock_requests_get):
mock_response = mock_requests_get.return_value
error = requests.exceptions.HTTPError()
mock_response.raise_for_status.side_effect = error
with pytest.raises(IoTEdgeError) as e_info:
edge_hsm.get_certificate()
await edge_hsm.get_certificate()
assert e_info.value.__cause__ is error
@pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the trust bundle")
def test_bad_json(self, edge_hsm, mock_requests_get):
async def test_bad_json(self, edge_hsm, mock_requests_get):
mock_response = mock_requests_get.return_value
error = ValueError()
mock_response.json.side_effect = error
with pytest.raises(IoTEdgeError) as e_info:
edge_hsm.get_certificate()
await edge_hsm.get_certificate()
assert e_info.value.__cause__ is error
@pytest.mark.it("Raises IoTEdgeError if the certificate is missing from the trust bundle")
def test_bad_trust_bundle(self, edge_hsm, mock_requests_get):
async def test_bad_trust_bundle(self, edge_hsm, mock_requests_get):
mock_response = mock_requests_get.return_value
# Return an empty json dict with no 'certificate' key
mock_response.json.return_value = {}
with pytest.raises(IoTEdgeError):
edge_hsm.get_certificate()
await edge_hsm.get_certificate()
@pytest.mark.describe("IoTEdgeHsm - .sign()")
@ -172,7 +172,7 @@ class TestIoTEdgeHsmSign(object):
@pytest.mark.it(
"Makes an HTTP request to Edge to sign a piece of string data using the HMAC-SHA256 algorithm"
)
def test_requests_data_signing(self, mocker, edge_hsm, mock_requests_post):
async def test_requests_data_signing(self, mocker, edge_hsm, mock_requests_post):
data_str = "somedata"
data_str_b64 = "c29tZWRhdGE="
mock_requests_post.return_value.json.return_value = {"digest": "somedigest"}
@ -187,7 +187,7 @@ class TestIoTEdgeHsmSign(object):
}
expected_json = json.dumps({"keyId": "primary", "algo": "HMACSHA256", "data": data_str_b64})
edge_hsm.sign(data_str)
await edge_hsm.sign(data_str)
assert mock_requests_post.call_count == 1
assert mock_requests_post.call_args == mocker.call(
@ -195,14 +195,14 @@ class TestIoTEdgeHsmSign(object):
)
@pytest.mark.it("Base64 encodes the string data in the request")
def test_b64_encodes_data(self, edge_hsm, mock_requests_post):
async def test_b64_encodes_data(self, edge_hsm, mock_requests_post):
# This test is actually implicitly tested in the first test, but it's
# important to have an explicit test for it since it's a requirement
data_str = "somedata"
data_str_b64 = base64.b64encode(data_str.encode("utf-8")).decode()
mock_requests_post.return_value.json.return_value = {"digest": "somedigest"}
edge_hsm.sign(data_str)
await edge_hsm.sign(data_str)
sent_data = json.loads(mock_requests_post.call_args[1]["data"])["data"]
@ -210,11 +210,11 @@ class TestIoTEdgeHsmSign(object):
assert sent_data == data_str_b64
@pytest.mark.it("Returns the signed data received from Edge")
def test_returns_signed_data(self, edge_hsm, mock_requests_post):
async def test_returns_signed_data(self, edge_hsm, mock_requests_post):
expected_digest = "somedigest"
mock_requests_post.return_value.json.return_value = {"digest": expected_digest}
signed_data = edge_hsm.sign("somedata")
signed_data = await edge_hsm.sign("somedata")
assert signed_data == expected_digest
@ -226,38 +226,38 @@ class TestIoTEdgeHsmSign(object):
pytest.param(b"sign this message", "c2lnbiB0aGlzIG1lc3NhZ2U=", id="Bytes"),
],
)
def test_supported_types(
async def test_supported_types(
self, edge_hsm, data_string, expected_request_data, mock_requests_post
):
mock_requests_post.return_value.json.return_value = {"digest": "somedigest"}
edge_hsm.sign(data_string)
await edge_hsm.sign(data_string)
sent_data = json.loads(mock_requests_post.call_args[1]["data"])["data"]
assert sent_data == expected_request_data
@pytest.mark.it("Raises IoTEdgeError if a bad request is made to EdgeHub")
def test_bad_request(self, edge_hsm, mock_requests_post):
async def test_bad_request(self, edge_hsm, mock_requests_post):
mock_response = mock_requests_post.return_value
error = requests.exceptions.HTTPError()
mock_response.raise_for_status.side_effect = error
with pytest.raises(IoTEdgeError) as e_info:
edge_hsm.sign("somedata")
await edge_hsm.sign("somedata")
assert e_info.value.__cause__ is error
@pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the signed response")
def test_bad_json(self, edge_hsm, mock_requests_post):
async def test_bad_json(self, edge_hsm, mock_requests_post):
mock_response = mock_requests_post.return_value
error = ValueError()
mock_response.json.side_effect = error
with pytest.raises(IoTEdgeError) as e_info:
edge_hsm.sign("somedata")
await edge_hsm.sign("somedata")
assert e_info.value.__cause__ is error
@pytest.mark.it("Raises IoTEdgeError if the signed data is missing from the response")
def test_bad_response(self, edge_hsm, mock_requests_post):
async def test_bad_response(self, edge_hsm, mock_requests_post):
mock_response = mock_requests_post.return_value
mock_response.json.return_value = {}
with pytest.raises(IoTEdgeError):
edge_hsm.sign("somedata")
await edge_hsm.sign("somedata")

Просмотреть файл

@ -8,6 +8,7 @@ import asyncio
import json
import pytest
import ssl
import sys
import time
import typing
import urllib
@ -22,12 +23,13 @@ from v3_async_wip import config, constant, models, user_agent
from v3_async_wip import mqtt_client as mqtt
from v3_async_wip import request_response as rr
from v3_async_wip import mqtt_topic_iothub as mqtt_topic
from azure.iot.device.common.auth import sastoken as st
from azure.iot.device.common import alarm
from v3_async_wip import sastoken as st
FAKE_DEVICE_ID = "fake_device_id"
FAKE_MODULE_ID = "fake_module_id"
FAKE_DEVICE_CLIENT_ID = "fake_device_id"
FAKE_MODULE_CLIENT_ID = "fake_device_id/fake_module_id"
FAKE_HOSTNAME = "fake.hostname"
FAKE_GATEWAY_HOSTNAME = "fake.gateway.hostname"
FAKE_SIGNATURE = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI="
@ -71,21 +73,24 @@ mqtt_unsubscribe_exceptions = [
@pytest.fixture
def renewable_sastoken(mocker):
mock_signing_mechanism = mocker.MagicMock()
mock_signing_mechanism.sign.return_value = FAKE_SIGNATURE
sastoken = st.RenewableSasToken(uri=FAKE_URI, signing_mechanism=mock_signing_mechanism)
# sastoken.refresh = mocker.MagicMock()
return sastoken
def sastoken():
sastoken_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format(
resource=FAKE_URI, signature=FAKE_SIGNATURE, expiry=FAKE_EXPIRY
)
return st.SasToken(sastoken_str)
@pytest.fixture
def nonrenewable_sastoken():
token_str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format(
resource=FAKE_URI, signature=FAKE_SIGNATURE, expiry=FAKE_EXPIRY
)
sastoken = st.NonRenewableSasToken(token_str)
return sastoken
def mock_sastoken_provider(mocker, sastoken):
provider = mocker.MagicMock(spec=st.SasTokenProvider)
provider.get_current_sastoken.return_value = sastoken
# Use a HangingAsyncMock so that it isn't constantly returning
provider.wait_for_new_sastoken = custom_mock.HangingAsyncMock()
provider.wait_for_new_sastoken.return_value = sastoken
# NOTE: Technically, this mock just always returns the same SasToken,
# even after an "update", but for the purposes of testing at this level,
# it doesn't matter
return provider
@pytest.fixture
@ -109,6 +114,8 @@ async def client(mocker, client_config):
client._mqtt_client.subscribe = mocker.AsyncMock()
client._mqtt_client.unsubscribe = mocker.AsyncMock()
client._mqtt_client.publish = mocker.AsyncMock()
# Also mock the set credentials method since we test that
client._mqtt_client.set_credentials = mocker.MagicMock()
yield client
await client.shutdown()
@ -130,7 +137,7 @@ class TestIoTHubMQTTClientInstantiation:
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
],
)
async def test_ids(self, client_config, device_id, module_id):
async def test_simple_ids(self, client_config, device_id, module_id):
client_config.device_id = device_id
client_config.module_id = module_id
@ -139,69 +146,141 @@ class TestIoTHubMQTTClientInstantiation:
assert client._module_id == client_config.module_id
await client.shutdown()
@pytest.mark.it("Stores the SasToken value from the IoTHubClientConfig as an attribute")
@pytest.mark.parametrize(
"sastoken",
[
pytest.param(lazy_fixture("renewable_sastoken"), id="Renewable SAS Authentication"),
pytest.param(
lazy_fixture("nonrenewable_sastoken"), id="Non-Renewable SAS Authentication"
),
pytest.param(None, id="Non-SAS Authentication"),
],
)
async def test_sastoken(self, client_config, sastoken):
client_config.sastoken = sastoken
client = IoTHubMQTTClient(client_config)
assert client._sastoken is client_config.sastoken
await client.shutdown()
# NOTE: For testing the functionality of this update Alarm, see the corresponding test suite
@pytest.mark.it("Starts the SasToken update Alarm if the IoTHubClientConfig has a SasToken")
@pytest.mark.parametrize(
"sastoken",
[
pytest.param(lazy_fixture("renewable_sastoken"), id="Renewable SAS Authentication"),
pytest.param(
lazy_fixture("nonrenewable_sastoken"), id="Non-Renewable SAS Authentication"
),
],
)
async def test_sastoken_update_task(self, client_config, sastoken):
client_config.sastoken = sastoken
client = IoTHubMQTTClient(client_config)
assert isinstance(client._sastoken_update_alarm, alarm.Alarm)
assert client._sastoken_update_alarm.is_alive()
assert client._sastoken_update_alarm.daemon is True
await client.shutdown()
@pytest.mark.it("Does not start an Alarm if the IoTHubClientConfig does not have a SasToken")
async def test_no_sastoken_update_task(self, client_config):
assert client_config.sastoken is None
client = IoTHubMQTTClient(client_config)
assert client._sastoken_update_alarm is None
await client.shutdown()
@pytest.mark.it("Creates an empty RequestLedger")
async def test_request_ledger(self, client_config):
client = IoTHubMQTTClient(client_config)
assert isinstance(client._request_ledger, rr.RequestLedger)
assert len(client._request_ledger) == 0
await client.shutdown()
@pytest.mark.it(
"Creates an MQTTClient instance based on the configuration of IoTHubClientConfig"
"Derives the `client_id` from the `device_id` and `module_id` and stores it as an attribute"
)
@pytest.mark.parametrize(
"device_id, module_id, expected_client_id",
[
pytest.param(FAKE_DEVICE_ID, None, FAKE_DEVICE_ID, id="Device Configuration"),
pytest.param(FAKE_DEVICE_ID, None, FAKE_DEVICE_CLIENT_ID, id="Device Configuration"),
pytest.param(
FAKE_DEVICE_ID,
FAKE_MODULE_ID,
"{}/{}".format(FAKE_DEVICE_ID, FAKE_MODULE_ID),
id="Module Configuration",
FAKE_DEVICE_ID, FAKE_MODULE_ID, FAKE_MODULE_CLIENT_ID, id="Module Configuration"
),
],
)
async def test_client_id(self, client_config, device_id, module_id, expected_client_id):
client_config.device_id = device_id
client_config.module_id = module_id
client = IoTHubMQTTClient(client_config)
assert client._client_id == expected_client_id
await client.shutdown()
@pytest.mark.it("Derives the `username` and stores the result as an attribute")
@pytest.mark.parametrize(
"device_id, module_id, client_id",
[
pytest.param(FAKE_DEVICE_ID, None, FAKE_DEVICE_CLIENT_ID, id="Device Configuration"),
pytest.param(
FAKE_DEVICE_ID, FAKE_MODULE_ID, FAKE_MODULE_CLIENT_ID, id="Module Configuration"
),
],
)
@pytest.mark.parametrize(
"hostname, gateway_hostname",
[
pytest.param(FAKE_HOSTNAME, None, id="No Gateway Hostname"),
pytest.param(FAKE_HOSTNAME, FAKE_GATEWAY_HOSTNAME, id="Gateway Hostname"),
],
)
@pytest.mark.parametrize(
"product_info",
[
pytest.param("", id="No Product Info"),
pytest.param("my-product-info", id="Custom Product Info"),
pytest.param("my$product$info", id="Custom Product Info (URL encoding required)"),
pytest.param(
constant.DIGITAL_TWIN_PREFIX + ":com:example:ClimateSensor;1",
id="Digital Twin Product Info",
),
pytest.param(
constant.DIGITAL_TWIN_PREFIX + ":com:example:$Climate$ensor;1",
id="Digital Twin Product Info (URL encoding required)",
),
],
)
async def test_username(
self,
client_config,
device_id,
module_id,
client_id,
hostname,
gateway_hostname,
product_info,
):
client_config.device_id = device_id
client_config.module_id = module_id
client_config.hostname = hostname
client_config.gateway_hostname = gateway_hostname
client_config.product_info = product_info
ua = user_agent.get_iothub_user_agent()
url_encoded_user_agent = urllib.parse.quote(ua, safe="")
# NOTE: This assertion shows the URL encoding was meaningful
assert user_agent != url_encoded_user_agent
url_encoded_product_info = urllib.parse.quote(product_info, safe="")
# NOTE: We can't really make the same assertion here, because this isn't always meaningful
# Determine expected username based on config
if product_info.startswith(constant.DIGITAL_TWIN_PREFIX):
expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}&{digital_twin_prefix}={custom_product_info}".format(
hostname=hostname,
client_id=client_id,
api_version=constant.IOTHUB_API_VERSION,
user_agent=url_encoded_user_agent,
digital_twin_prefix=constant.DIGITAL_TWIN_QUERY_HEADER,
custom_product_info=url_encoded_product_info,
)
else:
expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}{custom_product_info}".format(
hostname=hostname,
client_id=client_id,
api_version=constant.IOTHUB_API_VERSION,
user_agent=url_encoded_user_agent,
custom_product_info=url_encoded_product_info,
)
# NOTE: Regarding the above, no matter if we have a gateway hostname set or not, it is the hostname that is always used.
client = IoTHubMQTTClient(client_config)
# The expected username was derived
assert client._username == expected_username
await client.shutdown()
@pytest.mark.it("Stores the `sastoken_provider` from the IoTHubClientConfig as an attribute")
@pytest.mark.parametrize(
"sastoken_provider",
[
pytest.param(lazy_fixture("mock_sastoken_provider"), id="SasTokenProvider present"),
pytest.param(None, id="No SasTokenProvider present"),
],
)
@pytest.mark.parametrize(
"device_id, module_id",
[
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
],
)
async def test_sastoken_provider(self, client_config, sastoken_provider, device_id, module_id):
client_config.device_id = device_id
client_config.module_id = module_id
client_config.sastoken_provider = sastoken_provider
client = IoTHubMQTTClient(client_config)
assert client._sastoken_provider is sastoken_provider
await client.shutdown()
@pytest.mark.it(
"Creates an MQTTClient instance based on the configuration of IoTHubClientConfig and stores it as an attribute"
)
@pytest.mark.parametrize(
"device_id, module_id, expected_client_id",
[
pytest.param(FAKE_DEVICE_ID, None, FAKE_DEVICE_CLIENT_ID, id="Device Configuration"),
pytest.param(
FAKE_DEVICE_ID, FAKE_MODULE_ID, FAKE_MODULE_CLIENT_ID, id="Module Configuration"
),
],
)
@ -268,108 +347,58 @@ class TestIoTHubMQTTClientInstantiation:
# Graceful exit
await client.shutdown()
@pytest.mark.it("Sets credentials on the newly created MQTTClient instance")
@pytest.mark.parametrize(
"device_id, module_id, expected_client_id",
[
pytest.param(FAKE_DEVICE_ID, None, FAKE_DEVICE_ID, id="Device Configuration"),
pytest.param(
FAKE_DEVICE_ID,
FAKE_MODULE_ID,
FAKE_DEVICE_ID + "/" + FAKE_MODULE_ID,
id="Module Configuration",
),
],
@pytest.mark.it(
"Uses the derived `username` as the username, with no password as the credentials for the newly created MQTTClient instance when not using SAS authentication"
)
@pytest.mark.parametrize(
"hostname, gateway_hostname, expected_hostname",
"device_id, module_id",
[
pytest.param(FAKE_HOSTNAME, None, FAKE_HOSTNAME, id="No Gateway Hostname"),
pytest.param(FAKE_HOSTNAME, FAKE_GATEWAY_HOSTNAME, FAKE_HOSTNAME, id="Gateway Hostname")
# NOTE: Yes, that's right, we expect to always use the hostname, never the gateway hostname
# at least, when it comes to credentials
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
],
)
@pytest.mark.parametrize(
"sastoken",
[
pytest.param(lazy_fixture("renewable_sastoken"), id="Renewable SAS Authentication"),
pytest.param(
lazy_fixture("nonrenewable_sastoken"), id="Non-Renewable SAS Authentication"
),
pytest.param(None, id="Non-SAS Authentication"),
],
)
@pytest.mark.parametrize(
"product_info",
[
pytest.param("", id="No Product Info"),
pytest.param("my-product-info", id="Custom Product Info"),
pytest.param("my$product$info", id="Custom Product Info (URL encoding required)"),
pytest.param(
constant.DIGITAL_TWIN_PREFIX + ":com:example:ClimateSensor;1",
id="Digital Twin Product Info",
),
pytest.param(
constant.DIGITAL_TWIN_PREFIX + ":com:example:$Climate$ensor;1",
id="Digital Twin Product Info (URL encoding required)",
),
],
)
async def test_mqtt_client_credentials(
self,
mocker,
client_config,
sastoken,
device_id,
module_id,
expected_client_id,
hostname,
gateway_hostname,
expected_hostname,
product_info,
async def test_mqtt_client_credentials_no_sas(
self, mocker, client_config, device_id, module_id
):
client_config.device_id = device_id
client_config.module_id = module_id
client_config.hostname = hostname
client_config.gateway_hostname = gateway_hostname
client_config.product_info = product_info
client_config.sastoken = sastoken
assert client_config.sastoken_provider is None
# Determine expected username based on config
if product_info.startswith(constant.DIGITAL_TWIN_PREFIX):
expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}&{digital_twin_prefix}={custom_product_info}".format(
hostname=hostname,
client_id=expected_client_id,
api_version=constant.IOTHUB_API_VERSION,
user_agent=urllib.parse.quote(user_agent.get_iothub_user_agent(), safe=""),
digital_twin_prefix=constant.DIGITAL_TWIN_QUERY_HEADER,
custom_product_info=urllib.parse.quote(product_info, safe=""),
)
else:
expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}{custom_product_info}".format(
hostname=hostname,
client_id=expected_client_id,
api_version=constant.IOTHUB_API_VERSION,
user_agent=urllib.parse.quote(user_agent.get_iothub_user_agent(), safe=""),
custom_product_info=urllib.parse.quote(product_info, safe=""),
)
# Determine expected password based on sastoken
if sastoken:
expected_password = str(renewable_sastoken)
else:
expected_password = None
# Create the client under test
mocker.patch.object(mqtt, "MQTTClient", spec=mqtt.MQTTClient)
client = IoTHubMQTTClient(client_config)
# Credentials were set
expected_username = client._username
expected_password = None
assert client._mqtt_client.set_credentials.call_count == 1
assert client._mqtt_client.set_credentials.call_args(expected_username, expected_password)
await client.shutdown()
@pytest.mark.it(
"Uses the derived `username` as the username and the string-converted current SasToken from the SasTokenProvider as the password when setting credentials for the newly created MQTTClient instance when using SAS authentication"
)
@pytest.mark.parametrize(
"device_id, module_id",
[
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
],
)
async def test_mqtt_client_credentials_with_sas(
self, mocker, client_config, device_id, module_id, mock_sastoken_provider
):
client_config.device_id = device_id
client_config.module_id = module_id
client_config.sastoken_provider = mock_sastoken_provider
fake_sastoken = mock_sastoken_provider.get_current_sastoken.return_value
mocker.patch.object(mqtt, "MQTTClient", spec=mqtt.MQTTClient)
client = IoTHubMQTTClient(client_config)
expected_username = client._username
expected_password = str(fake_sastoken)
assert client._mqtt_client.set_credentials.call_count == 1
assert client._mqtt_client.set_credentials.call_args(expected_username, expected_password)
# Graceful exit
await client.shutdown()
@pytest.mark.it("Adds incoming message filter on the MQTTClient for C2D messages")
@ -580,30 +609,129 @@ class TestIoTHubMQTTClientInstantiation:
await client.shutdown()
# NOTE: For testing the functionality of this task, see the corresponding test suite (TestIoTHubMQTTClientIncomingTwinResponse)
# TODO: Consider removing this test. Does it really test anything? A Task was created? Who cares?
@pytest.mark.it("Creates a ongoing task to listen for twin responses")
async def test_twin_response_task(self, client_config):
@pytest.mark.it("Creates an empty RequestLedger")
@pytest.mark.parametrize(
"device_id, module_id",
[
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
],
)
async def test_request_ledger(self, client_config, device_id, module_id):
client_config.device_id = device_id
client_config.module_id = module_id
client = IoTHubMQTTClient(client_config)
assert isinstance(client._twin_response_listener, asyncio.Task)
assert not client._twin_response_listener.done()
assert isinstance(client._request_ledger, rr.RequestLedger)
assert len(client._request_ledger) == 0
await client.shutdown()
@pytest.mark.it("Sets the twin_responses_enabled flag to False")
async def test_twin_responses_enabled(self, client_config):
@pytest.mark.parametrize(
"device_id, module_id",
[
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
],
)
async def test_twin_responses_enabled(self, client_config, device_id, module_id):
client_config.device_id = device_id
client_config.module_id = module_id
client = IoTHubMQTTClient(client_config)
assert client._twin_responses_enabled is False
await client.shutdown()
# NOTE: For testing the functionality of this task, see the corresponding test suite (TestIoTHubMQTTClientIncomingTwinResponse)
@pytest.mark.it(
"Begins running the ._process_twin_responses() coroutine method as a background task, storing it as an attribute"
)
@pytest.mark.parametrize(
"device_id, module_id",
[
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
],
)
async def test_process_twin_responses_bg_task(self, client_config, device_id, module_id):
client_config.device_id = device_id
client_config.module_id = module_id
client = IoTHubMQTTClient(client_config)
assert isinstance(client._process_twin_responses_bg_task, asyncio.Task)
assert not client._process_twin_responses_bg_task.done()
if sys.version_info > (3, 8):
# NOTE: There isn't a way to validate the contents of a task until 3.8
# as far as I can tell.
task_coro = client._process_twin_responses_bg_task.get_coro()
assert task_coro.__qualname__ == "IoTHubMQTTClient._process_twin_responses"
await client.shutdown()
# NOTE: For testing the functionality of this task, see the corresponding test suite (???)
# TODO: add this test suite
@pytest.mark.it(
"Begins running the ._keep_credentials_fresh() coroutine method as a background task, storing it as an attribute, if using SAS authentication"
)
@pytest.mark.parametrize(
"device_id, module_id",
[
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
],
)
async def test_keep_credentials_fresh_bg_task(
self, client_config, device_id, module_id, mock_sastoken_provider
):
client_config.device_id = device_id
client_config.module_id = module_id
client_config.sastoken_provider = mock_sastoken_provider
client = IoTHubMQTTClient(client_config)
assert isinstance(client._keep_credentials_fresh_bg_task, asyncio.Task)
assert not client._keep_credentials_fresh_bg_task.done()
if sys.version_info > (3, 8):
# NOTE: There isn't a way to validate the contents of a task until 3.8
# as far as I can tell.
task_coro = client._keep_credentials_fresh_bg_task.get_coro()
assert task_coro.__qualname__ == "IoTHubMQTTClient._keep_credentials_fresh"
await client.shutdown()
@pytest.mark.it(
"Does not begin running the ._keep_credentials_fresh() coroutine method as a background task if not using SAS authentication"
)
@pytest.mark.parametrize(
"device_id, module_id",
[
pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"),
pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"),
],
)
async def test_keep_credentials_fresh_bg_task_no_sas(self, client_config, device_id, module_id):
client_config.device_id = device_id
client_config.module_id = module_id
assert client_config.sastoken_provider is None
client = IoTHubMQTTClient(client_config)
assert client._keep_credentials_fresh_bg_task is None
await client.shutdown()
# TODO: exceptions
@pytest.mark.describe("IoTHubMQTTClient - .shutdown()")
class TestIotHubMQTTClientShutdown:
class TestIoTHubMQTTClientShutdown:
@pytest.fixture(autouse=True)
def modify_client_config(self, client_config, mock_sastoken_provider):
# NOTE: This has to be changed on the config, not the client,
# because it affects client initialization
client_config.sastoken_provider = mock_sastoken_provider
@pytest.mark.it("Disconnects the MQTTClient")
async def test_disconnect(self, mocker, client):
# NOTE: rather than mocking the MQTTClient, we just mock the .disconnect() method of the
# IoTHubMQTTClient instead, since it's been fully tested elsewhere, and we assume
# correctness, lest we have to repeat all .disconnect() tests here.
client.disconnect = mocker.AsyncMock()
assert client.disconnect.await_count == 0
@ -612,34 +740,135 @@ class TestIotHubMQTTClientShutdown:
assert client.disconnect.await_count == 1
assert client.disconnect.await_args == mocker.call()
@pytest.mark.it("Cancels the SasToken update Alarm, if it exists")
async def test_sastoken_alarm_exists(self, mocker, client):
assert client._sastoken_update_alarm is None
mock_alarm = client._sastoken_update_alarm = mocker.MagicMock()
@pytest.mark.it("Cancels the 'process_twin_responses' background task")
async def test_process_twin_responses_bg_task(self, client):
assert isinstance(client._process_twin_responses_bg_task, asyncio.Task)
assert not client._process_twin_responses_bg_task.done()
await client.shutdown()
assert mock_alarm.cancel.call_count == 1
assert mock_alarm.cancel.call_args == mocker.call()
assert client._process_twin_responses_bg_task.done()
assert client._process_twin_responses_bg_task.cancelled()
@pytest.mark.it("Handles the case where the SasToken update Alarm does not exist")
async def test_sastoken_alarm_no_exist(self, client):
assert client._sastoken_update_alarm is None
@pytest.mark.it("Cancels the 'keep_credentials_fresh' background task, if it exists")
async def test_keep_credentials_fresh_bg_task_exists(self, client):
assert isinstance(client._keep_credentials_fresh_bg_task, asyncio.Task)
assert not client._keep_credentials_fresh_bg_task.done()
await client.shutdown()
# No AttributeError raised means success
# TODO: Probably need to show that this is truly the twin response task by demonstrating that twin responses are no longer received after shutdown
@pytest.mark.it("Cancels the twin response listener task")
async def test_twin_response_listener(self, client):
assert isinstance(client._twin_response_listener, asyncio.Task)
assert not client._twin_response_listener.done()
assert client._keep_credentials_fresh_bg_task.done()
assert client._keep_credentials_fresh_bg_task.cancelled()
@pytest.mark.it("Handles the case where no 'keep_credentials_fresh' background task exists")
async def test_keep_credentials_fresh_bg_task_no_exist(self, client, client_config):
# NOTE: in this test we don't want to have the SAS bg task, so override the modified fixture
client_config.sastoken_provider = None
client = IoTHubMQTTClient(client_config)
assert client._keep_credentials_fresh_bg_task is None
await client.shutdown()
# No AttributeError means success!
assert client._twin_response_listener.done()
assert client._twin_response_listener.cancelled()
@pytest.mark.it(
"Allows any exception raised during MQTTClient disconnect to propagate, but only after cancelling background tasks"
)
@pytest.mark.parametrize("exception", mqtt_disconnect_exceptions)
async def test_disconnect_raises(self, mocker, client, exception):
# NOTE: rather than mocking the MQTTClient, we just mock the .disconnect() method of the
# IoTHubMQTTClient instead, since it's been fully tested elsewhere, and we assume
# correctness, lest we have to repeat all .disconnect() tests here.
original_disconnect = client.disconnect
client.disconnect = mocker.AsyncMock(side_effect=exception)
client.disconnect.side_effect = exception
assert not client._keep_credentials_fresh_bg_task.done()
assert not client._process_twin_responses_bg_task.done()
with pytest.raises(type(exception)) as e_info:
await client.shutdown()
assert e_info.value is exception
# Background tasks were also cancelled despite the exception
assert client._keep_credentials_fresh_bg_task.done()
assert client._keep_credentials_fresh_bg_task.cancelled()
assert client._process_twin_responses_bg_task.done()
assert client._process_twin_responses_bg_task.cancelled()
# Unset the mock so that tests can clean up
client.disconnect = original_disconnect
@pytest.mark.it(
"Can be cancelled while waiting for the MQTTClient disconnect to finish, but it won't stop background task cancellation"
)
async def test_cancel_disconnect(self, client):
# NOTE: rather than mocking the MQTTClient, we just mock the .disconnect() method of the
# IoTHubMQTTClient instead, since it's been fully tested elsewhere, and we assume
# correctness, lest we have to repeat all .disconnect() tests here.
original_disconnect = client.disconnect
client.disconnect = custom_mock.HangingAsyncMock()
t = asyncio.create_task(client.shutdown())
# Hanging, waiting for disconnect to finish
await client.disconnect.wait_for_hang()
assert not t.done()
# Background tasks have not been cancelled
assert not client._keep_credentials_fresh_bg_task.done()
assert not client._process_twin_responses_bg_task.done()
# Cancel
t.cancel()
with pytest.raises(asyncio.CancelledError):
await t
# Unset the mock so that tests can clean up.
# And do it now so that test assertion failure doesn't hang
client.disconnect = original_disconnect
# And yet the background tasks still were cancelled anyway
assert client._keep_credentials_fresh_bg_task.done()
assert client._keep_credentials_fresh_bg_task.cancelled()
assert client._process_twin_responses_bg_task.done()
assert client._process_twin_responses_bg_task.cancelled()
@pytest.mark.it(
"Can be cancelled while waiting for the background tasks to finish cancellation, but it won't stop the background task cancellation"
)
async def test_cancel_gather(self, mocker, client):
original_gather = asyncio.gather
asyncio.gather = custom_mock.HangingAsyncMock()
spy_twin_response_bg_task_cancel = mocker.spy(
client._process_twin_responses_bg_task, "cancel"
)
spy_credentials_bg_task_cancel = mocker.spy(
client._keep_credentials_fresh_bg_task, "cancel"
)
t = asyncio.create_task(client.shutdown())
# Hanging waiting for gather to return (indicating tasks are all done cancellation)
await asyncio.gather.wait_for_hang()
assert not t.done()
# Background tests may or may not have completed cancellation yet, hard to test accurately.
# But their cancellation HAS been requested.
assert spy_twin_response_bg_task_cancel.call_count == 1
assert spy_credentials_bg_task_cancel.call_count == 1
# Cancel
t.cancel()
with pytest.raises(asyncio.CancelledError):
await t
# Unset the mock so that tests can clean up.
# And do it now so that test assertion failure doesn't hang
asyncio.gather = original_gather
# Tasks will be cancelled very soon (if they aren't already)
await asyncio.sleep(0.1)
assert client._keep_credentials_fresh_bg_task.done()
assert client._keep_credentials_fresh_bg_task.cancelled()
assert client._process_twin_responses_bg_task.done()
assert client._process_twin_responses_bg_task.cancelled()
@pytest.mark.describe("IoTHubMQTTClient - .connect()")
@ -2387,6 +2616,9 @@ class TestIoTHubMQTTClientIncomingTwinPatches:
assert patch == expected_json
# TODO: To reflect the complexity of background tasks, these tests need to be adjusted
# to be about the background task itself, not a reaction to an event. This is probably the
# end of any "OCCURRENCE" tests in the client layer
@pytest.mark.describe("IoTHubMQTTClient - OCCURRENCE: Twin Response Received")
class TestIoTHubMQTTClientIncomingTwinResponse:
# NOTE: This test suite exists for simplicity - twin responses are used in both
@ -2455,3 +2687,58 @@ class TestIoTHubMQTTClientIncomingTwinResponse:
resp1 = spy_response_factory.spy_return
assert mock_ledger.match_response.call_count == 1
assert mock_ledger.match_response.call_args == mocker.call(resp1)
# TODO: To reflect the complexity of background tasks, these tests need to be adjusted
# to be about the background task itself, not a reaction to an event. This is probably the
# end of any "OCCURRENCE" tests in the client layer
@pytest.mark.describe("IoTHubMQTTClient - OCCURRENCE: SasTokenProvider Updates SasToken")
class TestIoTHubMQTTClientSasTokenUpdate:
@pytest.fixture(autouse=True)
def modify_client_config(self, client_config, mock_sastoken_provider):
# Need to use a client with SAS token auth
# NOTE: This has to be changed on the config, not the client,
# because it affects client initialization
client_config.sastoken_provider = mock_sastoken_provider
@pytest.mark.it(
"Updates the MQTTClient's credentials, using the stored username as the username, and the string-converted new SasToken as the password"
)
async def test_updates_credentials(self, mocker, client):
# Client is waiting on a new SasToken
assert client._sastoken_provider.wait_for_new_sastoken.await_count == 1
assert client._sastoken_provider.wait_for_new_sastoken.is_hanging()
assert client._mqtt_client.set_credentials.call_count == 0
# Trigger new SasToken arrival
client._sastoken_provider.wait_for_new_sastoken.stop_hanging()
new_sastoken = client._sastoken_provider.wait_for_new_sastoken.return_value
assert isinstance(new_sastoken, st.SasToken)
await asyncio.sleep(0.1)
# Credentials are updated
assert client._mqtt_client.set_credentials.call_count == 1
assert client._mqtt_client.set_credentials.call_args == mocker.call(
client._username, str(new_sastoken)
)
# Client is now waiting on a new SasToken again
assert client._sastoken_provider.wait_for_new_sastoken.await_count == 2
assert client._sastoken_provider.wait_for_new_sastoken.is_hanging()
assert client._mqtt_client.set_credentials.call_count == 1
# Trigger new SasToken arrival again
client._sastoken_provider.wait_for_new_sastoken.stop_hanging()
new_sastoken = client._sastoken_provider.wait_for_new_sastoken.return_value
assert isinstance(new_sastoken, st.SasToken)
await asyncio.sleep(0.1)
# Credentials are updated again
assert client._mqtt_client.set_credentials.call_count == 2
assert client._mqtt_client.set_credentials.call_args == mocker.call(
client._username, str(new_sastoken)
)
# And so it continues forever and ever
await client.shutdown()

Просмотреть файл

@ -13,7 +13,7 @@ from v3_async_wip.mqtt_client import (
expected_on_connect_rc,
expected_on_disconnect_rc,
)
from azure.iot.device.common import ProxyOptions
from v3_async_wip.config import ProxyOptions
import paho.mqtt.client as mqtt
import asyncio
import pytest
@ -406,11 +406,13 @@ class TestInstantiation:
proxy_type = "SOCKS5"
if "No Auth" in request.param:
proxy = ProxyOptions(proxy_type=proxy_type, proxy_addr="fake.address", proxy_port=1080)
proxy = ProxyOptions(
proxy_type=proxy_type, proxy_address="fake.address", proxy_port=1080
)
else:
proxy = ProxyOptions(
proxy_type=proxy_type,
proxy_addr="fake.address",
proxy_address="fake.address",
proxy_port=1080,
proxy_username="fake_username",
proxy_password="fake_password",

Просмотреть файл

@ -1,32 +1,36 @@
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*- # TODO: do we need this?
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import pytest
import time
import re
import asyncio
import logging
import urllib
from azure.iot.device.common.auth.sastoken import (
RenewableSasToken,
NonRenewableSasToken,
import pytest
import sys
import time
import urllib.parse
from pytest_lazyfixture import lazy_fixture
from v3_async_wip.sastoken import (
SasToken,
InternalSasTokenGenerator,
ExternalSasTokenGenerator,
SasTokenProvider,
SasTokenError,
TOKEN_FORMAT,
DEFAULT_TOKEN_UPDATE_MARGIN,
)
from v3_async_wip import sastoken as st
logging.basicConfig(level=logging.DEBUG)
fake_uri = "some/resource/location"
fake_signed_data = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI="
fake_key_name = "fakekeyname"
fake_expiry = 12321312
simple_token_format = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}"
auth_rule_token_format = (
"SharedAccessSignature sr={resource}&sig={signature}&se={expiry}&skn={keyname}"
)
FAKE_URI = "some/resource/location"
FAKE_SIGNED_DATA = "8NJRMT83CcplGrAGaUVIUM/md5914KpWVNngSVoF9/M="
FAKE_SIGNED_DATA2 = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI="
FAKE_CURRENT_TIME = 10000000000.0 # We living in 2286!
# TODO: make this a float
# TODO: should we mock out time.time to always return a fake time?
def token_parser(token_str):
@ -40,176 +44,70 @@ def token_parser(token_str):
return token_map
class RenewableSasTokenTestConfig(object):
@pytest.fixture
def signing_mechanism(self, mocker):
mechanism = mocker.MagicMock()
mechanism.sign.return_value = fake_signed_data
return mechanism
# TODO: Rename this. These are not "device" and "service" tokens, the distinction is more generic
@pytest.fixture(params=["Device Token", "Service Token"])
def sastoken(self, request, signing_mechanism):
token_type = request.param
if token_type == "Device Token":
return RenewableSasToken(uri=fake_uri, signing_mechanism=signing_mechanism)
elif token_type == "Service Token":
return RenewableSasToken(
uri=fake_uri, signing_mechanism=signing_mechanism, key_name=fake_key_name
)
def get_expiry_time():
return int(time.time()) + 3600 # One hour from right now,
@pytest.mark.describe("RenewableSasToken")
class TestRenewableSasToken(RenewableSasTokenTestConfig):
@pytest.mark.it("Instantiates with a default TTL of 3600 seconds if no TTL is provided")
def test_default_ttl(self, signing_mechanism):
s = RenewableSasToken(fake_uri, signing_mechanism)
assert s.ttl == 3600
@pytest.mark.it("Instantiates with a custom TTL if provided")
def test_custom_ttl(self, signing_mechanism):
custom_ttl = 4747
s = RenewableSasToken(fake_uri, signing_mechanism, ttl=custom_ttl)
assert s.ttl == custom_ttl
@pytest.mark.it("Instantiates with with no key name by default if no key name is provided")
def test_default_key_name(self, signing_mechanism):
s = RenewableSasToken(fake_uri, signing_mechanism)
assert s._key_name is None
@pytest.mark.it("Instantiates with the given key name if provided")
def test_custom_key_name(self, signing_mechanism):
s = RenewableSasToken(fake_uri, signing_mechanism, key_name=fake_key_name)
assert s._key_name == fake_key_name
@pytest.mark.it(
"Instantiates with an expiry time TTL seconds in the future from the moment of instantiation"
@pytest.fixture
def sastoken_str():
return TOKEN_FORMAT.format(
resource=urllib.parse.quote(FAKE_URI, safe=""),
signature=urllib.parse.quote(FAKE_SIGNED_DATA, safe=""),
expiry=get_expiry_time(),
)
def test_expiry_time(self, mocker, signing_mechanism):
fake_current_time = 1000
mocker.patch.object(time, "time", return_value=fake_current_time)
s = RenewableSasToken(fake_uri, signing_mechanism)
assert s.expiry_time == fake_current_time + s.ttl
@pytest.mark.it("Calls .refresh() to build the SAS token string on instantiation")
def test_refresh_on_instantiation(self, mocker, signing_mechanism):
refresh_mock = mocker.spy(RenewableSasToken, "refresh")
assert refresh_mock.call_count == 0
RenewableSasToken(fake_uri, signing_mechanism)
assert refresh_mock.call_count == 1
@pytest.mark.it("Returns the SAS token string as the string representation of the object")
def test_str_rep(self, sastoken):
assert str(sastoken) == sastoken._token
@pytest.mark.it(
"Maintains the .expiry_time attribute as a read-only property (raises AttributeError upon attempt)"
)
def test_expiry_time_read_only(self, sastoken):
with pytest.raises(AttributeError):
sastoken.expiry_time = 12321312
@pytest.mark.describe("RenewableSasToken - .refresh()")
class TestRenewableSasTokenRefresh(RenewableSasTokenTestConfig):
@pytest.mark.it("Sets a new expiry time of TTL seconds in the future")
def test_new_expiry(self, mocker, sastoken):
fake_current_time = 1000
mocker.patch.object(time, "time", return_value=fake_current_time)
sastoken.refresh()
assert sastoken.expiry_time == fake_current_time + sastoken.ttl
# TODO: reflect url encoding here?
@pytest.mark.it(
"Uses the token's signing mechanism to create a signature by signing a concatenation of the (URL encoded) URI and updated expiry time"
)
def test_generate_new_token(self, mocker, signing_mechanism, sastoken):
old_token_str = str(sastoken)
fake_future_time = 1000
mocker.patch.object(time, "time", return_value=fake_future_time)
signing_mechanism.reset_mock()
fake_signature = "new_fake_signature"
signing_mechanism.sign.return_value = fake_signature
sastoken.refresh()
# The token string has been updated
assert str(sastoken) != old_token_str
# The signing mechanism was used to sign a string
assert signing_mechanism.sign.call_count == 1
# The string being signed was a concatenation of the URI and expiry time
assert signing_mechanism.sign.call_args == mocker.call(
urllib.parse.quote(sastoken._uri, safe="") + "\n" + str(sastoken.expiry_time)
)
# The token string has the resulting signed string included as the signature
token_info = token_parser(str(sastoken))
assert token_info["sig"] == fake_signature
@pytest.mark.it(
"Builds a new token string using the token's URI (URL encoded) and expiry time, along with the signature created by the signing mechanism (also URL encoded)"
)
def test_token_string(self, sastoken):
token_str = sastoken._token
# Verify that token string representation matches token format
if not sastoken._key_name:
pattern = re.compile(r"SharedAccessSignature sr=(.+)&sig=(.+)&se=(.+)")
else:
pattern = re.compile(r"SharedAccessSignature sr=(.+)&sig=(.+)&se=(.+)&skn=(.+)")
assert pattern.match(token_str)
# Verify that content in the string representation is correct
token_info = token_parser(token_str)
assert token_info["sr"] == urllib.parse.quote(sastoken._uri, safe="")
assert token_info["sig"] == urllib.parse.quote(
sastoken._signing_mechanism.sign.return_value, safe=""
)
assert token_info["se"] == str(sastoken.expiry_time)
if sastoken._key_name:
assert token_info["skn"] == sastoken._key_name
@pytest.mark.it("Raises a SasTokenError if an exception is raised by the signing mechanism")
def test_signing_mechanism_raises_value_error(
self, mocker, signing_mechanism, sastoken, arbitrary_exception
):
signing_mechanism.sign.side_effect = arbitrary_exception
with pytest.raises(SasTokenError) as e_info:
sastoken.refresh()
assert e_info.value.__cause__ is arbitrary_exception
@pytest.fixture
def sastoken(sastoken_str):
return SasToken(sastoken_str)
@pytest.mark.describe("NonRenewableSasToken")
class TestNonRenewableSasToken(object):
# TODO: Rename this. These are not "device" and "service" tokens, the distinction is more generic
@pytest.fixture(params=["Device Token", "Service Token"])
def sastoken_str(self, request):
token_type = request.param
if token_type == "Device Token":
return simple_token_format.format(
resource=urllib.parse.quote(fake_uri, safe=""),
signature=urllib.parse.quote(fake_signed_data, safe=""),
expiry=fake_expiry,
)
elif token_type == "Service Token":
return auth_rule_token_format.format(
resource=urllib.parse.quote(fake_uri, safe=""),
signature=urllib.parse.quote(fake_signed_data, safe=""),
expiry=fake_expiry,
keyname=fake_key_name,
)
@pytest.fixture
async def mock_signing_mechanism(mocker):
mock_sm = mocker.AsyncMock()
mock_sm.sign.return_value = FAKE_SIGNED_DATA
return mock_sm
@pytest.fixture()
def sastoken(self, sastoken_str):
return NonRenewableSasToken(sastoken_str)
@pytest.fixture(params=["Generator Function", "Generator Coroutine Function"])
def mock_token_generator_fn(mocker, request, sastoken_str):
if request.param == "Function":
return mocker.MagicMock(return_value=sastoken_str)
else:
return mocker.AsyncMock(return_value=sastoken_str)
@pytest.fixture(params=["InternalSasTokenGenerator", "ExternalSasTokenGenerator"])
def sastoken_generator(request, mocker, mock_signing_mechanism, sastoken_str):
if request.param == "ExternalSasTokenGenerator":
# We don't care about the difference between sync/async generator_fns when testing
# at this level of abstraction, so just pick one
generator = ExternalSasTokenGenerator(mocker.MagicMock(return_value=sastoken_str))
else:
generator = InternalSasTokenGenerator(mock_signing_mechanism, FAKE_URI)
mocker.spy(generator, "generate_sastoken")
return generator
# TODO: adjust mocks so that initial token is not the same as generated token
@pytest.fixture
async def sastoken_provider(sastoken_generator):
provider = await SasTokenProvider.create_from_generator(sastoken_generator)
# Creating from the generator invokes a call on the generator, so reset the spy mock
# so it doesn't throw off any testing logic
provider._generator.generate_sastoken.reset_mock()
yield provider
await provider.shutdown()
@pytest.mark.describe("SasToken")
class TestSasToken:
@pytest.mark.it("Instantiates from a valid SAS Token string")
def test_instantiates_from_token_string(self, sastoken_str):
s = NonRenewableSasToken(sastoken_str)
assert s._token == sastoken_str
s = SasToken(sastoken_str)
assert s._token_str == sastoken_str
@pytest.mark.it("Raises a SasToken error if instantiating from an invalid SAS Token string")
@pytest.mark.it("Raises a ValueError error if instantiating from an invalid SAS Token string")
@pytest.mark.parametrize(
"invalid_token_str",
[
@ -237,49 +135,575 @@ class TestNonRenewableSasToken(object):
"SharedAccessSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=",
id="Missing expiry value",
),
pytest.param(
"SharedAccessSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312&foovalue=nonsense",
id="Extraneous invalid value",
),
],
)
def test_raises_error_invalid_token_string(self, invalid_token_str):
with pytest.raises(SasTokenError):
NonRenewableSasToken(invalid_token_str)
with pytest.raises(ValueError):
SasToken(invalid_token_str)
@pytest.mark.it("Returns the SAS token string as the string representation of the object")
def test_str_rep(self, sastoken_str):
sastoken = NonRenewableSasToken(sastoken_str)
sastoken = SasToken(sastoken_str)
assert str(sastoken) == sastoken_str
@pytest.mark.it(
"Instantiates with the .expiry_time attribute corresponding to the expiry time of the given SAS Token string (as an integer)"
"Instantiates with the .expiry_time property corresponding to the expiry time of the given SAS Token string (as a float)"
)
def test_instantiates_expiry_time(self, sastoken_str):
sastoken = NonRenewableSasToken(sastoken_str)
sastoken = SasToken(sastoken_str)
expected_expiry_time = token_parser(sastoken_str)["se"]
assert sastoken.expiry_time == int(expected_expiry_time)
assert sastoken.expiry_time == float(expected_expiry_time)
@pytest.mark.it(
"Maintains the .expiry_time attribute as a read-only property (raises AttributeError upon attempt)"
)
@pytest.mark.it("Maintains .expiry_time as a read-only property")
def test_expiry_time_read_only(self, sastoken):
with pytest.raises(AttributeError):
sastoken.expiry_time = 12312312312123
@pytest.mark.it(
"Instantiates with the .resource_uri attribute corresponding to the URL decoded URI of the given SAS Token string"
"Instantiates with the .resource_uri property corresponding to the URL decoded URI of the given SAS Token string"
)
def test_instantiates_resource_uri(self, sastoken_str):
sastoken = NonRenewableSasToken(sastoken_str)
sastoken = SasToken(sastoken_str)
resource_uri = token_parser(sastoken_str)["sr"]
assert resource_uri != sastoken.resource_uri
assert resource_uri == urllib.parse.quote(sastoken.resource_uri, safe="")
assert urllib.parse.unquote(resource_uri) == sastoken.resource_uri
@pytest.mark.it(
"Maintains the .resource_uri attribute as a read-only property (raises AttributeError upon attempt)"
)
@pytest.mark.it("Maintains .resource_uri as a read-only property")
def test_resource_uri_read_only(self, sastoken):
with pytest.raises(AttributeError):
sastoken.resource_uri = "new%2Ffake%2Furi"
@pytest.mark.it(
"Instantiates with the .signature property corresponding to the URL decoded signature of the given SAS Token string"
)
def test_instantiates_signature(self, sastoken_str):
sastoken = SasToken(sastoken_str)
signature = token_parser(sastoken_str)["sig"]
assert signature != sastoken.signature
assert signature == urllib.parse.quote(sastoken.signature, safe="")
assert urllib.parse.unquote(signature) == sastoken.signature
@pytest.mark.it("Maintains .signature as a read-only property")
def test_signature_read_only(self, sastoken):
with pytest.raises(AttributeError):
sastoken.signature = "asdfas"
@pytest.mark.describe("InternalSasTokenGenerator -- Instantiation")
class TestSasTokenGeneratorInstantiation:
@pytest.mark.it("Stores the provided signing mechanism as an attribute")
def test_signing_mechanism(self, mock_signing_mechanism):
generator = InternalSasTokenGenerator(
signing_mechanism=mock_signing_mechanism, uri=FAKE_URI, ttl=4700
)
assert generator.signing_mechanism is mock_signing_mechanism
@pytest.mark.it("Stores the provided URI as an attribute")
def test_uri(self, mock_signing_mechanism):
generator = InternalSasTokenGenerator(
signing_mechanism=mock_signing_mechanism, uri=FAKE_URI, ttl=4700
)
assert generator.uri == FAKE_URI
@pytest.mark.it("Stores the provided TTL as an attribute")
def test_ttl(self, mock_signing_mechanism):
generator = InternalSasTokenGenerator(
signing_mechanism=mock_signing_mechanism, uri=FAKE_URI, ttl=4700
)
assert generator.ttl == 4700
@pytest.mark.it("Defaults to using 3600 as the TTL if not provided")
def test_ttl_default(self, mock_signing_mechanism):
generator = InternalSasTokenGenerator(
signing_mechanism=mock_signing_mechanism, uri=FAKE_URI
)
assert generator.ttl == 3600
@pytest.mark.describe("InternalSasTokenGenerator - .generate_sastoken()")
class TestSasTokenGeneratorGenerateSastoken:
@pytest.fixture
def sastoken_generator(self, mock_signing_mechanism):
return InternalSasTokenGenerator(
signing_mechanism=mock_signing_mechanism, uri=FAKE_URI, ttl=4700
)
@pytest.mark.it(
"Returns a newly generated SasToken for the configured URI that is valid for TTL seconds"
)
async def test_token_expiry(self, mocker, sastoken_generator):
# Patch time.time() to return a fake time so that it's easy to check the delta with expiry
mocker.patch.object(time, "time", return_value=FAKE_CURRENT_TIME)
expected_expiry = FAKE_CURRENT_TIME + sastoken_generator.ttl
token = await sastoken_generator.generate_sastoken()
assert isinstance(token, SasToken)
assert token.expiry_time == expected_expiry
assert token.resource_uri == sastoken_generator.uri
assert token._token_info["sr"] == urllib.parse.quote(sastoken_generator.uri, safe="")
assert token.resource_uri != token._token_info["sr"]
@pytest.mark.it(
"Creates the resulting SasToken's signature by using the InternalSasTokenGenerator's signing mechanism to sign a concatenation of the (URL encoded) URI and (URL encoded, int converted) desired expiry time"
)
async def test_token_signature(self, mocker, sastoken_generator):
assert sastoken_generator.signing_mechanism.await_count == 0
mocker.patch.object(time, "time", return_value=FAKE_CURRENT_TIME)
expected_expiry = int(FAKE_CURRENT_TIME + sastoken_generator.ttl)
expected_data_to_sign = (
urllib.parse.quote(sastoken_generator.uri, safe="") + "\n" + str(expected_expiry)
)
token = await sastoken_generator.generate_sastoken()
assert sastoken_generator.signing_mechanism.sign.await_count == 1
assert sastoken_generator.signing_mechanism.sign.await_args == mocker.call(
expected_data_to_sign
)
assert token._token_info["sig"] == urllib.parse.quote(
sastoken_generator.signing_mechanism.sign.return_value, safe=""
)
assert token.signature == sastoken_generator.signing_mechanism.sign.return_value
assert token.signature != token._token_info["sig"]
@pytest.mark.it("Raises a SasTokenError if an exception is raised by the signing mechanism")
async def test_signing_mechanism_raises(self, sastoken_generator, arbitrary_exception):
sastoken_generator.signing_mechanism.sign.side_effect = arbitrary_exception
with pytest.raises(SasTokenError) as e_info:
await sastoken_generator.generate_sastoken()
assert e_info.value.__cause__ is arbitrary_exception
@pytest.mark.describe("ExternalSasTokenGenerator -- Instantiation")
class TestExternalSasTokenGeneratorInstantiation:
@pytest.mark.it("Stores the provided generator_fn callable as an attribute")
def test_generator_fn_attribute(self, mock_token_generator_fn):
sastoken_generator = ExternalSasTokenGenerator(mock_token_generator_fn)
assert sastoken_generator.generator_fn is mock_token_generator_fn
@pytest.mark.describe("ExternalSasTokenGenerator -- .generate_sastoken()")
class TestExternalSasTokenGeneratorGenerateSasToken:
@pytest.fixture
def sastoken_generator(self, mock_token_generator_fn):
return ExternalSasTokenGenerator(mock_token_generator_fn)
@pytest.mark.it(
"Generates a new SasToken from the SAS Token string returned by the configured generator_fn callable"
)
async def test_returns_token(self, mocker, sastoken_generator):
if isinstance(sastoken_generator.generator_fn, mocker.AsyncMock):
assert sastoken_generator.generator_fn.await_count == 0
else:
assert sastoken_generator.generator_fn.call_count == 0
token = await sastoken_generator.generate_sastoken()
assert isinstance(token, SasToken)
if isinstance(sastoken_generator.generator_fn, mocker.AsyncMock):
assert sastoken_generator.generator_fn.await_count == 1
assert sastoken_generator.generator_fn.await_args == mocker.call()
else:
assert sastoken_generator.generator_fn.call_count == 1
assert sastoken_generator.generator_fn.call_args == mocker.call()
assert str(token) == sastoken_generator.generator_fn.return_value
@pytest.mark.it(
"Raises SasTokenError if an exception is raised while trying to generate a SAS Token string with the generator_fn"
)
async def test_generator_fn_raises(self, sastoken_generator, arbitrary_exception):
sastoken_generator.generator_fn.side_effect = arbitrary_exception
with pytest.raises(SasTokenError) as e_info:
await sastoken_generator.generate_sastoken()
assert e_info.value.__cause__ is arbitrary_exception
@pytest.mark.it("Raises SasTokenError if the generated SAS Token string is invalid")
async def test_invalid_token(self, sastoken_generator):
sastoken_generator.generator_fn.return_value = "not a sastoken"
with pytest.raises(SasTokenError) as e_info:
await sastoken_generator.generate_sastoken()
assert isinstance(e_info.value.__cause__, ValueError)
@pytest.mark.describe("SasTokenProvider -- Instantiation")
class TestSasTokenProviderInstantiation:
@pytest.mark.it("Stores the provided SasTokenGenerator")
async def test_generator_fn(self, sastoken, sastoken_generator):
provider = SasTokenProvider(initial_token=sastoken, generator=sastoken_generator)
assert provider._generator is sastoken_generator
await provider.shutdown()
@pytest.mark.it("Sets the provided initial_token as the current SasToken")
async def test_initial_token(self, sastoken, sastoken_generator):
provider = SasTokenProvider(initial_token=sastoken, generator=sastoken_generator)
assert provider._sastoken is sastoken
await provider.shutdown()
@pytest.mark.it("Sets the token update margin to the DEFAULT_TOKEN_UPDATE_MARGIN")
async def test_token_update_margin(self, sastoken, sastoken_generator):
provider = SasTokenProvider(initial_token=sastoken, generator=sastoken_generator)
assert provider._token_update_margin == DEFAULT_TOKEN_UPDATE_MARGIN
await provider.shutdown()
# NOTE: The contents of this coroutine are tested in a separate test suite below.
# See TestSasTokenProviderKeepTokenFresh for more.
@pytest.mark.it("Begins running the ._keep_token_fresh() coroutine method, storing the task")
async def test_keep_token_fresh_running(self, sastoken, sastoken_generator):
provider = SasTokenProvider(initial_token=sastoken, generator=sastoken_generator)
assert isinstance(provider._keep_token_fresh_task, asyncio.Task)
assert not provider._keep_token_fresh_task.done()
if sys.version_info >= (3, 8):
# NOTE: There isn't a way to validate the contents of a task until 3.8
# as far as I can tell.
task_coro = provider._keep_token_fresh_task.get_coro()
assert task_coro.__qualname__ == "SasTokenProvider._keep_token_fresh"
await provider.shutdown()
@pytest.mark.describe("SasTokenProvider - .create_from_generator()")
class TestSasTokenProviderCreateFromGenerator:
@pytest.mark.it("Returns a SasTokenProvider instance")
async def test_instantiates(self, sastoken_generator):
provider = await SasTokenProvider.create_from_generator(sastoken_generator)
assert isinstance(provider, SasTokenProvider)
await provider.shutdown()
@pytest.mark.it(
"Generates a new SasToken using the provided SasTokenGenerator to use as the SasTokenProvider's initial sastoken"
)
async def test_generates_initial_token(self, mocker, sastoken_generator):
assert sastoken_generator.generate_sastoken.await_count == 0
provider = await SasTokenProvider.create_from_generator(sastoken_generator)
assert sastoken_generator.generate_sastoken.await_count == 1
assert sastoken_generator.generate_sastoken.await_args == mocker.call()
assert isinstance(provider._sastoken, SasToken)
assert provider._sastoken == sastoken_generator.generate_sastoken.spy_return
await provider.shutdown()
@pytest.mark.it("Allows any exception raised while trying to generate a SasToken to propagate")
@pytest.mark.parametrize(
"exception",
[
pytest.param(SasTokenError("token error"), id="SasTokenError"),
pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"),
],
)
async def test_generation_raises(self, sastoken_generator, exception):
sastoken_generator.generate_sastoken.side_effect = exception
with pytest.raises(type(exception)) as e_info:
await SasTokenProvider.create_from_generator(sastoken_generator)
assert e_info.value is exception
@pytest.mark.it("Raises a SasTokenError if the generated SAS Token string has already expired")
async def test_expired_token(self, mocker, sastoken_generator):
expired_token_str = TOKEN_FORMAT.format(
resource=urllib.parse.quote(FAKE_URI, safe=""),
signature=urllib.parse.quote(FAKE_SIGNED_DATA, safe=""),
expiry=int(time.time()) - 3600, # 1 hour ago
)
sastoken_generator.generate_sastoken = mocker.AsyncMock()
sastoken_generator.generate_sastoken.return_value = SasToken(expired_token_str)
with pytest.raises(SasTokenError):
await SasTokenProvider.create_from_generator(sastoken_generator)
@pytest.mark.describe("SasTokenProvider - .shutdown()")
class TestSasTokenProviderShutdown:
@pytest.mark.it("Cancels the stored ._keep_token_fresh() task")
async def test_cancels_keep_token_fresh(self, sastoken_provider):
assert isinstance(sastoken_provider._keep_token_fresh_task, asyncio.Task)
assert not sastoken_provider._keep_token_fresh_task.done()
await sastoken_provider.shutdown()
assert sastoken_provider._keep_token_fresh_task.done()
assert sastoken_provider._keep_token_fresh_task.cancelled()
@pytest.mark.describe("SasTokenProvider - .get_current_sastoken()")
class TestSasTokenGetCurrentSasToken:
@pytest.mark.it("Returns the current SasToken object")
def test_returns_current_token(self, sastoken_provider):
current_token = sastoken_provider.get_current_sastoken()
assert current_token is sastoken_provider._sastoken
new_token_str = TOKEN_FORMAT.format(
resource=urllib.parse.quote(FAKE_URI, safe=""),
signature=urllib.parse.quote(FAKE_SIGNED_DATA, safe=""),
expiry=int(time.time()) + 3600,
)
new_token = SasToken(new_token_str)
sastoken_provider._sastoken = new_token
assert sastoken_provider.get_current_sastoken() is new_token
@pytest.mark.describe("SasTokenProvider - .wait_for_new_sastoken()")
class TestSasTokenWaitForNewSasToken:
@pytest.mark.it(
"Returns the current SasToken object once a notified of a new token being available"
)
async def test_returns_new_current_token(self, sastoken_provider):
token_str_1 = TOKEN_FORMAT.format(
resource=urllib.parse.quote(FAKE_URI, safe=""),
signature=urllib.parse.quote(FAKE_SIGNED_DATA, safe=""),
expiry=int(time.time()) + 3600,
)
token1 = SasToken(token_str_1)
token_str_2 = TOKEN_FORMAT.format(
resource=urllib.parse.quote(FAKE_URI, safe=""),
signature=urllib.parse.quote(FAKE_SIGNED_DATA2, safe=""),
expiry=int(time.time()) + 4500,
)
token2 = SasToken(token_str_2)
sastoken_provider._sastoken = token1
assert sastoken_provider.get_current_sastoken() is token1
# Waiting for new token, but one is not yet available
task = asyncio.create_task(sastoken_provider.wait_for_new_sastoken())
await asyncio.sleep(0.1)
assert not task.done()
# Update the token, but without notification, the waiting task still does not return
sastoken_provider._sastoken = token2
await asyncio.sleep(0.1)
assert not task.done()
# Notify that a new token is available, and now the task will return
async with sastoken_provider._new_sastoken_available:
sastoken_provider._new_sastoken_available.notify_all()
returned_token = await task
# The task returned the new token
assert returned_token is token2
assert returned_token is not token1
assert returned_token is sastoken_provider.get_current_sastoken()
# NOTE: This test suite assumes the correct implementation of ._wait_until() for critical
# requirements. Find it tested in a separate suite below (TestWaitUntil)
@pytest.mark.describe("SasTokenProvider -- Keep Token Fresh Task")
class TestSasTokenProviderKeepTokenFresh:
@pytest.fixture(autouse=True)
def spy_time(self, mocker):
"""Spy on the time module so that we can find out last time that was returned"""
spy_time = mocker.spy(time, "time")
return spy_time
# NOTE: This is an autouse fixture to ensure that it gets called first, since we want to make sure
# this mock is running when the SasTokenProvider is created.
@pytest.fixture(autouse=True)
def mock_wait_until(self, mocker):
"""Mock out the wait_until function so these tests aren't dependent on real time passing"""
mock_wait_until = mocker.patch.object(st, "_wait_until")
mock_wait_until._allow_proceed = asyncio.Event()
# Fake implementation that will wait for an explicit trigger to proceed, rather than the
# passage of time
async def fake_wait_until(when):
await mock_wait_until._allow_proceed.wait()
mock_wait_until.side_effect = fake_wait_until
# Define a mechanism that will allow an explicit trigger to let the mocked coroutine return
def proceed():
mock_wait_until._allow_proceed.set()
mock_wait_until._allow_proceed = asyncio.Event()
mock_wait_until.proceed = proceed
return mock_wait_until
@pytest.mark.it(
"Waits until the configured update margin number of seconds before current SasToken expiry to generate a new SasToken"
)
async def test_wait_to_generate(self, mocker, mock_wait_until, sastoken_provider):
original_token = sastoken_provider.get_current_sastoken()
assert sastoken_provider._generator.generate_sastoken.await_count == 0
await asyncio.sleep(0.1)
# We are waiting the expected amount of time
expected_update_time = original_token.expiry_time - sastoken_provider._token_update_margin
assert mock_wait_until.await_count == 1
assert mock_wait_until.await_args == mocker.call(expected_update_time)
# Allow the waiting to end, and a new token to be generated
mock_wait_until.proceed()
await asyncio.sleep(0.1)
assert sastoken_provider._generator.generate_sastoken.await_count == 1
assert sastoken_provider._generator.generate_sastoken.await_args == mocker.call()
@pytest.mark.it(
"Sets the newly generated SasToken as the new current SasToken and sends notification of its availability"
)
async def test_replace_token_and_notify(self, mocker, sastoken_provider, mock_wait_until):
notification_spy = mocker.spy(sastoken_provider._new_sastoken_available, "notify_all")
# We have the original token, as we have not yet generated a new one
original_token = sastoken_provider.get_current_sastoken()
assert sastoken_provider._generator.generate_sastoken.await_count == 0
assert notification_spy.call_count == 0
# Allow waiting to proceed, and a new token to be generated
mock_wait_until.proceed()
await asyncio.sleep(0.1)
# A new token has now been generated
assert sastoken_provider._generator.generate_sastoken.await_count == 1
# The current token is now the token that was just generated
current_token = sastoken_provider.get_current_sastoken()
assert current_token is sastoken_provider._generator.generate_sastoken.spy_return
# This token is not the same as the original token
assert current_token is not original_token
# A notification was sent about the new token
assert notification_spy.call_count == 1
@pytest.mark.it(
"Waits until the configured update margin number of seconds before the NEW current SasToken expiry, after each time a new SasToken is generated, before once again generating a new SasToken"
)
async def test_wait_to_generate_again_and_again(
self, mocker, mock_wait_until, sastoken_provider
):
# Current token is the original, we have not yet generated a new one
original_token = sastoken_provider.get_current_sastoken()
assert sastoken_provider._generator.generate_sastoken.await_count == 0
await asyncio.sleep(0.1)
# We are waiting based on the original token's expiry time
expected_update_time = original_token.expiry_time - sastoken_provider._token_update_margin
assert mock_wait_until.await_count == 1
assert mock_wait_until.await_args == mocker.call(expected_update_time)
# Allow the waiting to end, and a new token to be generated
mock_wait_until.proceed()
await asyncio.sleep(0.1)
assert sastoken_provider._generator.generate_sastoken.await_count == 1
assert sastoken_provider._generator.generate_sastoken.await_args == mocker.call()
# New token is the one that was just generated
new_token = sastoken_provider.get_current_sastoken()
assert new_token is sastoken_provider._generator.generate_sastoken.spy_return
assert new_token is not original_token
# We are once again waiting, this time based on the new token's expiry time
expected_update_time = new_token.expiry_time - sastoken_provider._token_update_margin
assert mock_wait_until.await_count == 2
assert mock_wait_until.await_args == mocker.call(expected_update_time)
# Allow the waiting to end and another new token to be generated
mock_wait_until.proceed()
await asyncio.sleep(0.1)
assert sastoken_provider._generator.generate_sastoken.await_count == 2
assert sastoken_provider._generator.generate_sastoken.await_args == mocker.call()
# Newest token is the one that was just generated
newest_token = sastoken_provider.get_current_sastoken()
assert newest_token is sastoken_provider._generator.generate_sastoken.spy_return
assert newest_token is not original_token
assert newest_token is not new_token
# We are once again waiting, this time based on the newest token's expiry time
expected_update_time = newest_token.expiry_time - sastoken_provider._token_update_margin
assert mock_wait_until.await_count == 3
assert mock_wait_until.await_args == mocker.call(expected_update_time)
# And so on and so forth to infinity...
@pytest.mark.it(
"Sets the newly generated SasToken as the new current SasToken and sends notification of its availability each time a new token is generated"
)
async def test_replace_token_and_notify_each_time(
self, mocker, sastoken_provider, mock_wait_until
):
notification_spy = mocker.spy(sastoken_provider._new_sastoken_available, "notify_all")
# We have the original token, as we have not yet generated a new one
original_token = sastoken_provider.get_current_sastoken()
assert sastoken_provider._generator.generate_sastoken.await_count == 0
assert notification_spy.call_count == 0
# Allow waiting to proceed, and a new token to be generated
mock_wait_until.proceed()
await asyncio.sleep(0.1)
# A new token has now been generated
assert sastoken_provider._generator.generate_sastoken.await_count == 1
# The current token is now the token that was just generated
second_token = sastoken_provider.get_current_sastoken()
assert second_token is sastoken_provider._generator.generate_sastoken.spy_return
# This token is not the same as the original token
assert second_token is not original_token
# A notification was sent about the new token
assert notification_spy.call_count == 1
# Allow waiting to proceed and another new token to be generated
mock_wait_until.proceed()
await asyncio.sleep(0.1)
# Another new token has now been generated
assert sastoken_provider._generator.generate_sastoken.await_count == 2
# The current token is now the token that was just generated
third_token = sastoken_provider.get_current_sastoken()
assert third_token is sastoken_provider._generator.generate_sastoken.spy_return
# This token is not the same as any previous token
assert third_token is not original_token
assert third_token is not second_token
# A notification was sent about the new token
assert notification_spy.call_count == 2
# And so on and so forth to infinity...
@pytest.mark.it("Tries to generate again in 10 seconds if SasToken generation fails")
@pytest.mark.parametrize(
"exception",
[
pytest.param(SasTokenError("Some error in SAS"), id="SasTokenError"),
pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"),
],
)
async def test_generation_failure(
self, mocker, sastoken_provider, exception, mock_wait_until, spy_time
):
# Set token generation to raise exception
sastoken_provider._generator.generate_sastoken.side_effect = exception
# Token generation has not yet happened
assert sastoken_provider._generator.generate_sastoken.await_count == 0
# Allow waiting to proceed, and a new token to be generated
assert mock_wait_until.await_count == 1
mock_wait_until.proceed()
await asyncio.sleep(0.1)
# Waits 10 seconds past the current time
expected_generate_time = spy_time.spy_return + 10
assert mock_wait_until.await_count == 2
assert mock_wait_until.await_args == mocker.call(expected_generate_time)
# NOTE: We don't normally test convention-private helpers directly, but in this case, the
# complexity is high enough, and the function is critical enough, that it makes more sense
# to isolate rather than attempting to indirectly test.
@pytest.mark.describe("._wait_until()")
class TestWaitUntil:
@pytest.mark.it(
"Repeatedly does 1 second asyncio sleeps until the current time is greater than the provided 'when' parameter"
)
@pytest.mark.parametrize(
"time_from_now",
[
pytest.param(5, id="5 seconds from now"),
pytest.param(60, id="1 minute from now"),
pytest.param(3600, id="1 hour from now"),
],
)
async def test_sleep(self, mocker, time_from_now):
# Mock out the sleep coroutine so that we aren't waiting around forever on this test
mock_sleep = mocker.patch.object(asyncio, "sleep")
# mock out time
def fake_time():
"""Fake time implementation that will return a time float that is 1 larger
than the previous time it was called"""
fake_time_return = FAKE_CURRENT_TIME
while True:
yield fake_time_return
fake_time_return += 1
fake_time_gen = fake_time()
mock_time = mocker.patch.object(time, "time", side_effect=fake_time_gen)
desired_time = FAKE_CURRENT_TIME + time_from_now
await st._wait_until(desired_time)
assert mock_sleep.await_count == time_from_now
for call in mock_sleep.await_args_list:
assert call == mocker.call(1)
assert mock_time.call_count == time_from_now + 1
for call in mock_time.call_args_list:
assert call == mocker.call()

Просмотреть файл

@ -74,13 +74,13 @@ class TestSymmetricKeySigningMechanismSign(object):
@pytest.mark.it(
"Generates an HMAC message digest from the signing key and provided data string, using the HMAC-SHA256 algorithm"
)
def test_hmac(self, mocker, signing_mechanism):
async def test_hmac(self, mocker, signing_mechanism):
hmac_mock = mocker.patch.object(hmac, "HMAC")
hmac_digest_mock = hmac_mock.return_value.digest
hmac_digest_mock.return_value = b"\xd2\x06\xf7\x12\xf1\xe9\x95$\x90\xfd\x12\x9a\xb1\xbe\xb4\xf8\xf3\xc4\x1ap\x8a\xab'\x8a.D\xfb\x84\x96\xca\xf3z"
data_string = "sign this message"
signing_mechanism.sign(data_string)
await signing_mechanism.sign(data_string)
assert hmac_mock.call_count == 1
assert hmac_mock.call_args == mocker.call(
@ -93,13 +93,13 @@ class TestSymmetricKeySigningMechanismSign(object):
@pytest.mark.it(
"Returns the base64 encoded HMAC message digest (converted to string) as the signed data"
)
def test_b64encode(self, mocker, signing_mechanism):
async def test_b64encode(self, mocker, signing_mechanism):
hmac_mock = mocker.patch.object(hmac, "HMAC")
hmac_digest_mock = hmac_mock.return_value.digest
hmac_digest_mock.return_value = b"\xd2\x06\xf7\x12\xf1\xe9\x95$\x90\xfd\x12\x9a\xb1\xbe\xb4\xf8\xf3\xc4\x1ap\x8a\xab'\x8a.D\xfb\x84\x96\xca\xf3z"
data_string = "sign this message"
signature = signing_mechanism.sign(data_string)
signature = await signing_mechanism.sign(data_string)
assert signature == base64.b64encode(hmac_digest_mock.return_value).decode("utf-8")
@ -115,11 +115,11 @@ class TestSymmetricKeySigningMechanismSign(object):
),
],
)
def test_supported_types(self, signing_mechanism, data_string, expected_signature):
assert signing_mechanism.sign(data_string) == expected_signature
async def test_supported_types(self, signing_mechanism, data_string, expected_signature):
assert await signing_mechanism.sign(data_string) == expected_signature
@pytest.mark.it("Raises a ValueError if unable to sign the provided data string")
@pytest.mark.parametrize("data_string", [pytest.param(123, id="Integer input")])
def test_bad_input(self, signing_mechanism, data_string):
async def test_bad_input(self, signing_mechanism, data_string):
with pytest.raises(ValueError):
signing_mechanism.sign(data_string)
await signing_mechanism.sign(data_string)

Просмотреть файл

@ -8,7 +8,7 @@ import logging
import socks
import ssl
from typing import Optional, Any
from azure.iot.device.common.auth import sastoken as st # type: ignore
from .sastoken import SasTokenProvider
# TODO: add typings for imports
# TODO: update docs to ensure types are correct
@ -67,8 +67,8 @@ class ClientConfig:
ssl_context: ssl.SSLContext,
hostname: str,
gateway_hostname: Optional[str] = None,
sastoken_provider: Optional[SasTokenProvider] = None,
proxy_options: Optional[ProxyOptions] = None,
sastoken: Optional[st.SasToken] = None,
keep_alive: int = 60,
auto_reconnect: bool = True,
websockets: bool = False,
@ -77,12 +77,12 @@ class ClientConfig:
:param str hostname: The hostname being connected to
:param str gateway_hostname: The gateway hostname optionally being used
:param sastoken_provider: Object that can provide SasTokens
:type sastoken_provider: :class:`SasTokenProvider`
:param proxy_options: Details of proxy configuration
:type proxy_options: :class:`azure.iot.device.common.models.ProxyOptions`
:param ssl_context: SSLContext to use with the client
:type ssl_context: :class:`ssl.SSLContext`
:param sastoken: SasToken to be used for authentication. Mutually exclusive with x509.
:type sastoken: :class:`azure.iot.device.common.auth.SasToken`
:param int keepalive: Maximum period in seconds between communications with the
broker.
:param bool auto_reconnect: Indicates if dropped connection should result in attempts to
@ -94,10 +94,10 @@ class ClientConfig:
self.hostname = hostname
self.gateway_hostname = gateway_hostname
self.proxy_options = proxy_options
self.ssl_context = ssl_context
# Auth
self.sastoken = sastoken
self.sastoken_provider = sastoken_provider
self.ssl_context = ssl_context
# MQTT
self.keep_alive = _sanitize_keep_alive(keep_alive)

Просмотреть файл

@ -47,7 +47,8 @@ class IoTEdgeHsm(SigningMechanism):
self.generation_id = generation_id
self.workload_uri = _format_socket_uri(workload_uri)
def get_certificate(self) -> str:
# TODO: Use async http to make use of this being a coroutine
async def get_certificate(self) -> str:
"""
Return the server verification certificate from the trust bundle that can be used to
validate the server-side SSL TLS connection that we use to talk to Edge
@ -79,7 +80,8 @@ class IoTEdgeHsm(SigningMechanism):
raise IoTEdgeError("No certificate in trust bundle") from e
return cert
def sign(self, data_str: Union[str, bytes]) -> str:
# TODO: Use async http to make use of this being a coroutine
async def sign(self, data_str: Union[str, bytes]) -> str:
"""
Use the IoTEdge HSM to sign a piece of string data. The caller should then insert the
returned value (the signature) into the 'sig' field of a SharedAccessSignature string.

Просмотреть файл

@ -8,7 +8,7 @@ import asyncio
import json
import logging
import urllib.parse
from typing import Optional, AsyncGenerator
from typing import Optional, Union, AsyncGenerator
from .custom_typing import TwinPatch, Twin
from .iot_exceptions import IoTHubError, IoTHubClientError
from .models import Message, MethodResponse, MethodRequest
@ -16,15 +16,12 @@ from . import config, constant, user_agent
from . import request_response as rr
from . import mqtt_client as mqtt
from . import mqtt_topic_iothub as mqtt_topic
from azure.iot.device.common.auth import sastoken as st # type: ignore
from azure.iot.device.common import alarm # type: ignore
# TODO: update docstrings with correct class paths once repo structured better
logger = logging.getLogger(__name__)
DEFAULT_RECONNECT_INTERVAL = 10
DEFAULT_TOKEN_UPDATE_MARGIN = 120
DEFAULT_RECONNECT_INTERVAL: int = 10
# TODO: add exceptions to docstring
# TODO: background exceptions how
@ -43,23 +40,29 @@ class IoTHubMQTTClient:
:param client_config: The config object for the client
:type client_config: :class:`IoTHubClientConfig`
"""
# Basic Information
# Identity
self._device_id = client_config.device_id
self._module_id = client_config.module_id
self._client_id = _format_client_id(self._device_id, self._module_id)
self._username = _format_username(
# NOTE: Always use the original hostname, even if gateway hostname is set
hostname=client_config.hostname,
client_id=self._client_id,
product_info=client_config.product_info,
)
# Sastoken Auth
# TODO: Should this be handled by a separate utility? Would make testing easier, and would abstract out the difference
self._sastoken: Optional[st.SasToken]
self._sastoken_update_alarm: Optional[alarm.Alarm]
if client_config.sastoken is not None:
self._sastoken = client_config.sastoken
self._sastoken_update_alarm = self._create_token_update_alarm()
else:
self._sastoken = None
self._sastoken_update_alarm = None
# SAS (Optional)
self._sastoken_provider = client_config.sastoken_provider
# MQTT Configuration
self._mqtt_client = _create_mqtt_client(client_config)
self._mqtt_client = _create_mqtt_client(self._client_id, client_config)
if self._sastoken_provider:
logger.debug("Using SASToken as password")
password = str(self._sastoken_provider.get_current_sastoken())
else:
logger.debug("No password used")
password = None
self._mqtt_client.set_credentials(self._username, password)
# Create incoming IoTHub data generators
# TODO: expose these via method to make the device/module split cleaner
@ -79,40 +82,21 @@ class IoTHubMQTTClient:
# Internal request/response infrastructure
self._request_ledger = rr.RequestLedger()
self._twin_responses_enabled = False
self._twin_response_listener = asyncio.create_task(self._process_twin_responses())
def _create_token_update_alarm(self) -> alarm.Alarm:
if not self._sastoken:
# This should never happen, it's just for the type checker
raise ValueError("Can't create alarm for no SASToken")
update_time = self._sastoken.expiry_time - DEFAULT_TOKEN_UPDATE_MARGIN
if isinstance(self._sastoken, st.RenewableSasToken):
def on_token_needs_update():
# Renew the token
logger.debug("Renewing SAS Token...")
try:
self._sastoken.refresh()
logger.debug("SAS Token renewal succeeded")
except st.SasTokenError:
logger.error("SAS Token renewal failed")
# TODO: background exception?
# With the token renewed, now set a new Alarm
self._sastoken_update_alarm = self._create_token_update_alarm()
# Background Tasks
self._process_twin_responses_bg_task: asyncio.Task[None] = asyncio.create_task(
self._process_twin_responses()
)
self._keep_credentials_fresh_bg_task: Optional[asyncio.Task[None]]
if self._sastoken_provider:
self._keep_credentials_fresh_bg_task = asyncio.create_task(
self._keep_credentials_fresh()
)
else:
def on_token_needs_update():
pass
update_alarm = alarm.Alarm(update_time, on_token_needs_update)
update_alarm.daemon = True
update_alarm.start()
return update_alarm
self._keep_credentials_fresh_bg_task = None
async def _enable_twin_responses(self) -> None:
"""Enable receiving of twin responses (for twin requests, or twin patches) from IoTHub"""
logger.debug("Enabling receive of twin responses...")
topic = mqtt_topic.get_twin_response_topic_for_subscribe()
await self._mqtt_client.subscribe(topic)
@ -122,7 +106,7 @@ class IoTHubMQTTClient:
# TODO: add background exception handling
async def _process_twin_responses(self) -> None:
"""Run indefinitely, matching twin responses with request ID"""
logger.debug("Starting twin response listener")
logger.debug("Starting the 'process_twin_responses' background task")
twin_response_topic = mqtt_topic.get_twin_response_topic_for_subscribe()
twin_responses = self._mqtt_client.get_incoming_message_generator(twin_response_topic)
@ -144,28 +128,62 @@ class IoTHubMQTTClient:
# in-flight operations
logger.warning("Twin response (rid: {}) does not match any request")
async def _keep_credentials_fresh(self) -> None:
"""Run indefinitely, updating MQTT credentials when new SAS Token is available"""
logger.debug("Starting the 'keep_credentials_fresh' background task")
while True:
if self._sastoken_provider:
logger.debug("Waiting for new SAS Token to become available")
new_sastoken = await self._sastoken_provider.wait_for_new_sastoken()
logger.debug("New SAS Token available, updating MQTTClient credentials")
self._mqtt_client.set_credentials(self._username, str(new_sastoken))
# TODO: should we reconnect here? Or just wait for drop?
else:
# NOTE: This should never execute, it's mostly just here to keep the
# type checker happy
logger.error("No SasTokenProvider. Cannot update credentials")
break
async def shutdown(self) -> None:
"""
Shut down the client.
Invoke only when completely finished with the client for graceful exit.
Cannot be cancelled - if you try, the client will still fully shut down as much as
possible.
"""
# TODO: this breaks when called twice. Build some protections.
# TODO: is there an issue with cancellation here?
await self.disconnect()
# Cancel the SAS token update alarm. Note that this is not a task, it's a threaded Alarm.
# No need to wait for the result.
if self._sastoken_update_alarm:
logger.debug("Cancelling SAS Token update alarm")
self._sastoken_update_alarm.cancel()
# Cancel and wait for the completion of the twin response task
logger.debug("Cancelling twin response listener")
self._twin_response_listener.cancel()
# Wait for the cancellation to complete before returning
# NOTE: .disconnect() really shouldn't fail, but if it does, we temporarily suppress
# the exception so we can still do as much cleanup as possible.
cached_exception: Optional[Union[Exception, asyncio.CancelledError]] = None
logger.debug("Attempting disconnect in shutdown")
try:
await self._twin_response_listener
except asyncio.CancelledError:
pass
await self.disconnect()
except asyncio.CancelledError as e:
logger.warning("Cancellation during shutdown. Still attempting to clean up.")
cached_exception = e
except Exception as e:
logger.warning("Unexpected error disconnecting. Continuing shutdown procedure")
cached_exception = e
cancelled_tasks = []
logger.debug("Cancelling 'process_twin_responses' background task")
self._process_twin_responses_bg_task.cancel()
cancelled_tasks.append(self._process_twin_responses_bg_task)
if self._keep_credentials_fresh_bg_task:
logger.debug("Cancelling 'keep_credentials_fresh' background task")
self._keep_credentials_fresh_bg_task.cancel()
cancelled_tasks.append(self._keep_credentials_fresh_bg_task)
# Wait for the cancellation to complete before returning
# NOTE: If cancelled while awaiting here, all tasks in gather will still be cancelled
# because the cancellations have already been issued.
# NOTE: Also, cancelling a gather implicitly cancels all the tasks that are gathered anyway
await asyncio.gather(*cancelled_tasks, return_exceptions=True)
if cached_exception:
raise cached_exception
async def connect(self) -> None:
"""Connect to IoTHub
@ -468,17 +486,22 @@ class IoTHubMQTTClient:
logger.debug("Twin patch receive disabled")
# Auth Helpers
def _format_client_id(device_id: str, module_id: Optional[str] = None) -> str:
if module_id:
client_id = "{}/{}".format(device_id, module_id)
else:
client_id = device_id
return client_id
def _create_mqtt_client(client_config: config.IoTHubClientConfig) -> mqtt.MQTTClient:
def _create_mqtt_client(
client_id: str, client_config: config.IoTHubClientConfig
) -> mqtt.MQTTClient:
logger.debug("Creating MQTTClient")
if client_config.module_id:
client_id = "{}/{}".format(client_config.device_id, client_config.module_id)
logger.debug("Using IoTHub Module. Client ID is {}".format(client_id))
else:
client_id = client_config.device_id
logger.debug("Using IoTHub Device. Client ID is {}".format(client_id))
if client_config.gateway_hostname:
@ -512,24 +535,6 @@ def _create_mqtt_client(client_config: config.IoTHubClientConfig) -> mqtt.MQTTCl
proxy_options=client_config.proxy_options,
)
# NOTE: we use the original hostname here, even if gateway hostname is set
username = _create_username(
hostname=client_config.hostname,
client_id=client_id,
product_info=client_config.product_info,
)
logger.debug("Using {} as username".format(username))
if client_config.sastoken:
logger.debug("Using SASToken as password")
password = str(client_config.sastoken)
else:
logger.debug("No password used")
password = None
client.set_credentials(username, password)
# Add topic filters for receive
# IoTHub Receives
c2d_msg_topic = mqtt_topic.get_c2d_topic_for_subscribe(client_config.device_id)
@ -550,7 +555,7 @@ def _create_mqtt_client(client_config: config.IoTHubClientConfig) -> mqtt.MQTTCl
return client
def _create_username(hostname: str, client_id: str, product_info: str) -> str:
def _format_username(hostname: str, client_id: str, product_info: str) -> str:
query_param_seq = []
# Apply query parameters (i.e. key1=value1&key2=value2...&keyN=valueN format)

Просмотреть файл

@ -6,11 +6,19 @@
"""This module contains tools for working with Shared Access Signature (SAS) Tokens"""
import abc
import asyncio
import logging
import time
import urllib
from typing import Optional, Dict
import urllib.parse
from typing import Dict, List, Union, Awaitable, Callable, cast
from .signing_mechanism import SigningMechanism
logger = logging.getLogger(__name__)
DEFAULT_TOKEN_UPDATE_MARGIN: int = 120
REQUIRED_SASTOKEN_FIELDS: List[str] = ["sr", "sig", "se"]
TOKEN_FORMAT: str = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}"
class SasTokenError(Exception):
"""Error in SasToken"""
@ -18,160 +26,213 @@ class SasTokenError(Exception):
pass
class SasToken(abc.ABC):
"""Abstract parent class for SAS Tokens.
class SasToken:
def __init__(self, sastoken_str: str) -> None:
"""Create a SasToken object from a SAS Token string
:param str sastoken_str: The SAS Token string
Doesn't do much, but helps with type hints
"""
@property
@abc.abstractmethod
def expiry_time(self) -> int:
pass
class RenewableSasToken(SasToken):
"""Renewable Shared Access Signature Token used to authenticate a request.
This token is 'renewable', which means that it can be updated when necessary to
prevent expiry, by using the .refresh() method.
Data Attributes:
expiry_time (int): Time that token will expire (in UTC, since epoch)
ttl (int): Time to live for the token, in seconds
"""
_auth_rule_token_format = (
"SharedAccessSignature sr={resource}&sig={signature}&se={expiry}&skn={keyname}"
)
_simple_token_format = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}"
def __init__(
self,
uri: str,
signing_mechanism: SigningMechanism,
key_name: Optional[str] = None,
ttl: int = 3600,
) -> None:
:raises: ValueError if SAS Token string is invalid
"""
:param str uri: URI of the resource to be accessed
:param signing_mechanism: The signing mechanism to use in the SasToken
:type signing_mechanism: Child classes of :class:`azure.iot.common.SigningMechanism`
:param str key_name: Symmetric Key Name (optional)
:param int ttl: Time to live for the token, in seconds (default 3600)
:raises: SasTokenError if an error occurs building a SasToken
"""
self._uri = uri
self._signing_mechanism = signing_mechanism
self._key_name = key_name
# These two values will be set by the .refresh() call
self._expiry_time: int
self._token: str
self.ttl = ttl
self.refresh()
self._token_str: str = sastoken_str
self._token_info: Dict[str, str] = _get_sastoken_info_from_string(sastoken_str)
def __str__(self) -> str:
return self._token
def refresh(self) -> None:
"""
Refresh the SasToken lifespan, giving it a new expiry time, and generating a new token.
"""
self._expiry_time = int(time.time() + self.ttl)
self._token = self._build_token()
def _build_token(self) -> str:
"""Build SasToken representation
:returns: String representation of the token
"""
url_encoded_uri = urllib.parse.quote(self._uri, safe="")
message = url_encoded_uri + "\n" + str(self.expiry_time)
try:
signature = self._signing_mechanism.sign(message)
except Exception as e:
# Because of variant signing mechanisms, we don't know what error might be raised.
# So we catch all of them.
raise SasTokenError("Unable to build SasToken from given values") from e
url_encoded_signature = urllib.parse.quote(signature, safe="")
if self._key_name:
token = self._auth_rule_token_format.format(
resource=url_encoded_uri,
signature=url_encoded_signature,
expiry=str(self.expiry_time),
keyname=self._key_name,
)
else:
token = self._simple_token_format.format(
resource=url_encoded_uri,
signature=url_encoded_signature,
expiry=str(self.expiry_time),
)
return token
return self._token_str
@property
def expiry_time(self) -> int:
"""Expiry Time is READ ONLY"""
return self._expiry_time
class NonRenewableSasToken(SasToken):
"""NonRenewable Shared Access Signature Token used to authenticate a request.
This token is 'non-renewable', which means that it is invalid once it expires, and there
is no way to keep it alive. Instead, a new token must be created.
Data Attributes:
expiry_time (int): Time that token will expire (in UTC, since epoch)
resource_uri (str): URI for the resource the Token provides authentication to access
"""
def __init__(self, sastoken_string) -> None:
"""
:param str sastoken_string: A string representation of a SAS token
"""
self._token = sastoken_string
self._token_info = get_sastoken_info_from_string(self._token)
def __str__(self) -> str:
return self._token
@property
def expiry_time(self) -> int:
"""Expiry Time is READ ONLY"""
return int(self._token_info["se"])
def expiry_time(self) -> float:
# NOTE: Time is typically expressed in float in Python, even though a
# SAS Token expiry time should be a whole number.
return float(self._token_info["se"])
@property
def resource_uri(self) -> str:
"""Resource URI is READ ONLY"""
uri = self._token_info["sr"]
return urllib.parse.unquote(uri)
REQUIRED_SASTOKEN_FIELDS = ["sr", "sig", "se"]
VALID_SASTOKEN_FIELDS = REQUIRED_SASTOKEN_FIELDS + ["skn"]
@property
def signature(self) -> str:
signature = self._token_info["sig"]
return urllib.parse.unquote(signature)
def get_sastoken_info_from_string(sastoken_string: str) -> Dict[str, str]:
class SasTokenGenerator(abc.ABC):
@abc.abstractmethod
async def generate_sastoken(self):
pass
class InternalSasTokenGenerator(SasTokenGenerator):
def __init__(self, signing_mechanism: SigningMechanism, uri: str, ttl: int = 3600) -> None:
"""An object that can generate SasTokens using provided values
:param str uri: The URI of the resource you are generating a tokens to access
:param signing_mechanism: The signing mechanism that will be used to sign data
:type signing mechanism: :class:`SigningMechanism`
:param int ttl: Time to live for generated tokens, in seconds (default 3600)
"""
self.signing_mechanism = signing_mechanism
self.uri = uri
self.ttl = ttl
async def generate_sastoken(self) -> SasToken:
"""Generate a new SasToken
:raises: SasTokenError if the token cannot be generated
"""
expiry_time = int(time.time()) + self.ttl
url_encoded_uri = urllib.parse.quote(self.uri, safe="")
message = url_encoded_uri + "\n" + str(expiry_time)
try:
signature = await self.signing_mechanism.sign(message)
except Exception as e:
# Because of variant signing mechanisms, we don't know what error might be raised.
# So we catch all of them.
raise SasTokenError("Unable to generate SasToken") from e
url_encoded_signature = urllib.parse.quote(signature, safe="")
token_str = TOKEN_FORMAT.format(
resource=url_encoded_uri,
signature=url_encoded_signature,
expiry=str(expiry_time),
)
return SasToken(token_str)
class ExternalSasTokenGenerator(SasTokenGenerator):
def __init__(self, generator_fn: Union[Callable[[], str], Callable[[], Awaitable[str]]]):
"""An object that can generate SasTokens by invoking a provided callable.
This callable can be a function or a coroutine function.
:param generator_fn: A callable that takes no arguments and returns a SAS Token string
:type generator_fn: Function or Coroutine Function which returns a string
"""
self.generator_fn = generator_fn
async def generate_sastoken(self) -> SasToken:
"""Generate a new SasToken
:raises: SasTokenError if the token cannot be generated
"""
try:
# NOTE: the typechecker has some problems here, so we help it with a cast.
if asyncio.iscoroutinefunction(self.generator_fn):
generator_fn = cast(Callable[[], Awaitable[str]], self.generator_fn)
token_str = await generator_fn()
else:
generator_coro_fn = cast(Callable[[], str], self.generator_fn)
token_str = generator_coro_fn()
return SasToken(token_str)
except Exception as e:
raise SasTokenError("Unable to generate SasToken") from e
class SasTokenProvider:
def __init__(self, initial_token: SasToken, generator: SasTokenGenerator) -> None:
"""Object responsible for providing a valid SasToken.
Instantiate using a factory method instead of directly.
:param generator: A SasTokenGenerator to generate SasTokens with
:type generator: SasTokenGenerator
"""
# NOTE: There is no good way to invoke a coroutine from within the __init__, and since
# the the generator's .sign() method is a coroutine, that means we can't generate an
# initial token from it here. Thus, we have to take the initial token as a separate
# argument.
# However, this is inconvenient, and also prevents us from fast-failing if there's a
# problem with the generator_fn, so a factory coroutine method has been implemented.
self._event_loop = asyncio.get_running_loop()
self._generator = generator
self._sastoken = initial_token
self._token_update_margin = DEFAULT_TOKEN_UPDATE_MARGIN
self._new_sastoken_available = asyncio.Condition()
self._keep_token_fresh_task = asyncio.create_task(self._keep_token_fresh())
async def _keep_token_fresh(self):
"""Runs indefinitely and will generate a SasToken when the current one gets close to
expiration (based on the update margin)
"""
generate_time = self._sastoken.expiry_time - self._token_update_margin
while True:
await _wait_until(generate_time)
try:
logger.debug("Updating SAS Token...")
self._sastoken = await self._generator.generate_sastoken()
logger.debug("SAS Token update succeeded")
generate_time = self._sastoken.expiry_time - self._token_update_margin
async with self._new_sastoken_available:
self._new_sastoken_available.notify_all()
except Exception:
logger.error("SAS Token renewal failed. Trying again in 10 seconds")
generate_time = time.time() + 10
@classmethod
async def create_from_generator(cls, generator: SasTokenGenerator) -> "SasTokenProvider":
"""Create an instance of the SasTokenProvider that will rely on an external source
to generate new tokens via a callback function/coroutine.
:param generator: A SasTokenGenerator to generate SasTokens with
:type generator: SasTokenGenerator
:raises: SasTokenError if an initial SasToken cannot be generated
:raises: SasTokenError if the initial SasToken generated is invalid
"""
initial_token = await generator.generate_sastoken()
if initial_token.expiry_time < time.time():
raise SasTokenError("Newly generated SAS Token has already expired")
return cls(initial_token, generator)
async def shutdown(self) -> None:
"""Shut down the SasToken provider, and free any resources.
No further updates to the current SAS Token will be made
"""
self._keep_token_fresh_task.cancel()
# Wait for cancellation to complete
await asyncio.gather(self._keep_token_fresh_task, return_exceptions=True)
def get_current_sastoken(self) -> SasToken:
"""Return the current SasToken"""
return self._sastoken
async def wait_for_new_sastoken(self) -> SasToken:
"""Waits for a new SasToken to become available, and return it"""
async with self._new_sastoken_available:
await self._new_sastoken_available.wait()
return self.get_current_sastoken()
def _get_sastoken_info_from_string(sastoken_string: str) -> Dict[str, str]:
"""Given a SAS Token string, return a dictionary of it's keys and values"""
pieces = sastoken_string.split("SharedAccessSignature ")
if len(pieces) != 2:
raise SasTokenError("Invalid SasToken string: Not a SasToken ")
raise ValueError("Invalid SAS Token string: Not a SAS Token ")
# Get sastoken info as dictionary
try:
# TODO: fix this typehint later, it needs some kind of cast
sastoken_info = dict(map(str.strip, sub.split("=", 1)) for sub in pieces[1].split("&")) # type: ignore
except Exception as e:
raise SasTokenError("Invalid SasToken string: Incorrectly formatted") from e
raise ValueError("Invalid SAS Token string: Incorrectly formatted") from e
# Validate that all required fields are present
if not all(key in sastoken_info for key in REQUIRED_SASTOKEN_FIELDS):
raise SasTokenError("Invalid SasToken string: Not all required fields present")
raise ValueError("Invalid SAS Token string: Not all required fields present")
# Validate that no unexpected fields are present
if not all(key in VALID_SASTOKEN_FIELDS for key in sastoken_info):
raise SasTokenError("Invalid SasToken string: Unexpected fields present")
# Warn if extraneous fields are present
if not all(key in REQUIRED_SASTOKEN_FIELDS for key in sastoken_info):
logger.warning("Unexpected fields present in SAS Token")
return sastoken_info
# NOTE: Arguably, this doesn't really belong in this module, give it's lack of a specific
# relationship to SAS Tokens, and the fact that it needs to be unit-tested separately.
# These things suggest it should be more than just a convention-private helper, however
# its hard to justify making a separate module just for this function.
# This would be a candidate for some kind of misc utility module if other similar functions
# pop up over the course of development. Until then, it lives here.
async def _wait_until(when: float) -> None:
"""Wait until a specific time has passed (accurate within 1 second).
:param float when: The time to wait for, in seconds, since epoch
"""
while time.time() < when:
await asyncio.sleep(1)

Просмотреть файл

@ -8,17 +8,22 @@ import base64
import binascii
import hmac
import hashlib
from typing import Union
from typing import AnyStr
# TODO: remove commented signatures
class SigningMechanism(abc.ABC):
@abc.abstractmethod
def sign(self, data_str: Union[str, bytes]) -> str:
async def sign(self, data_str: AnyStr) -> str:
# NOTE: This is defined as a coroutine to allow for flexibility of implementation.
# Some implementations may not require a coroutine, but others may, so we err on the side
# of a coroutine for consistent interface.
pass
class SymmetricKeySigningMechanism(SigningMechanism):
def __init__(self, key: Union[str, bytes]) -> None:
def __init__(self, key: AnyStr) -> None:
"""
A mechanism that signs data using a symmetric key
@ -34,13 +39,12 @@ class SymmetricKeySigningMechanism(SigningMechanism):
key_bytes = key
# Derives the signing key
# CT-TODO: is "signing key" the right term?
try:
self._signing_key = base64.b64decode(key_bytes)
except (binascii.Error):
raise ValueError("Invalid Symmetric Key")
def sign(self, data_str: Union[str, bytes]) -> str:
async def sign(self, data_str: AnyStr) -> str:
"""
Sign a data string with symmetric key and the HMAC-SHA256 algorithm.
@ -52,6 +56,9 @@ class SymmetricKeySigningMechanism(SigningMechanism):
:raises: ValueError if an invalid data string is provided
"""
# NOTE: This implementation doesn't take advantage of being a coroutine, but this is by
# design. See the definition of the abstract base class above.
# Convert data_str to bytes (if not already)
if isinstance(data_str, str):
data_bytes = data_str.encode("utf-8")

Просмотреть файл

@ -6,7 +6,7 @@
"""This module is for creating agent strings for all clients"""
import platform
from azure.iot.device.constant import VERSION, IOTHUB_IDENTIFIER, PROVISIONING_IDENTIFIER
from .constant import VERSION, IOTHUB_IDENTIFIER, PROVISIONING_IDENTIFIER
python_runtime = platform.python_version()
os_type = platform.system()