зеркало из
1
0
Форкнуть 0

[V3] Port and add type hints to auth modules (#1101)

* port

* updated credscan suppression
This commit is contained in:
Carter Tinney 2023-02-10 16:44:34 -08:00 коммит произвёл GitHub
Родитель 904e3213e1
Коммит bb4cb90295
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
11 изменённых файлов: 1120 добавлений и 16 удалений

2
.gitignore поставляемый
Просмотреть файл

@ -1,5 +1,5 @@
# Python V3 Workspace
v3_async_wip/sandbox/
v3_async_wip/_sandbox/
# Byte-compiled / optimized / DLL files
__pycache__/

Просмотреть файл

@ -13,10 +13,18 @@
"file": "\\tests\\unit\\common\\auth\\test_signing_mechanism.py",
"_justification": "Test containing fake keys"
},
{
"file": "\\v3_async_wip\\tests\\test_signing_mechanism.py",
"_justification": "Test containing fake keys"
},
{
"file": "\\tests\\unit\\common\\auth\\test_sastoken.py",
"_justification": "Test containing fake signed data"
},
{
"file": "\\v3_async_wip\\tests\\test_sastoken.py",
"_justification": "Test containing fake signed data"
},
{
"file": "\\tests\\unit\\common\\test_mqtt_transport.py",
"_justification": "Test containing fake passwords"

Просмотреть файл

@ -0,0 +1,263 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import base64
import json
import pytest
import requests
import urllib.parse
from v3_async_wip.edge_hsm import IoTEdgeHsm
from v3_async_wip.iot_exceptions import IoTEdgeError
from v3_async_wip import user_agent
@pytest.fixture
def edge_hsm():
return IoTEdgeHsm(
module_id="my_module_id",
generation_id="module_generation_id",
workload_uri="unix:///var/run/iotedge/workload.sock",
api_version="my_api_version",
)
@pytest.mark.describe("IoTEdgeHsm - Instantiation")
class TestIoTEdgeHsmInstantiation(object):
@pytest.mark.it("URL encodes the provided module_id parameter and sets it as an attribute")
def test_encode_and_set_module_id(self):
module_id = "my_module_id"
generation_id = "my_generation_id"
api_version = "my_api_version"
workload_uri = "unix:///var/run/iotedge/workload.sock"
edge_hsm = IoTEdgeHsm(
module_id=module_id,
generation_id=generation_id,
workload_uri=workload_uri,
api_version=api_version,
)
assert edge_hsm.module_id == urllib.parse.quote(module_id, safe="")
@pytest.mark.it(
"Formats the provided workload_uri parameter for use with the requests library and sets it as an attribute"
)
@pytest.mark.parametrize(
"workload_uri, expected_formatted_uri",
[
pytest.param(
"unix:///var/run/iotedge/workload.sock",
"http+unix://%2Fvar%2Frun%2Fiotedge%2Fworkload.sock/",
id="Domain Socket URI",
),
pytest.param("http://127.0.0.1:15580", "http://127.0.0.1:15580/", id="IP Address URI"),
],
)
def test_workload_uri_formatting(self, workload_uri, expected_formatted_uri):
module_id = "my_module_id"
generation_id = "my_generation_id"
api_version = "my_api_version"
edge_hsm = IoTEdgeHsm(
module_id=module_id,
generation_id=generation_id,
workload_uri=workload_uri,
api_version=api_version,
)
assert edge_hsm.workload_uri == expected_formatted_uri
@pytest.mark.it("Sets the provided generation_id parameter as an attribute")
def test_set_generation_id(self):
module_id = "my_module_id"
generation_id = "my_generation_id"
api_version = "my_api_version"
workload_uri = "unix:///var/run/iotedge/workload.sock"
edge_hsm = IoTEdgeHsm(
module_id=module_id,
generation_id=generation_id,
workload_uri=workload_uri,
api_version=api_version,
)
assert edge_hsm.generation_id == generation_id
@pytest.mark.it("Sets the provided api_version parameter as an attribute")
def test_set_api_version(self):
module_id = "my_module_id"
generation_id = "my_generation_id"
api_version = "my_api_version"
workload_uri = "unix:///var/run/iotedge/workload.sock"
edge_hsm = IoTEdgeHsm(
module_id=module_id,
generation_id=generation_id,
workload_uri=workload_uri,
api_version=api_version,
)
assert edge_hsm.api_version == api_version
@pytest.mark.describe("IoTEdgeHsm - .get_certificate()")
class TestIoTEdgeHsmGetCertificate(object):
@pytest.fixture(autouse=True)
def mock_requests_get(self, mocker):
return mocker.patch.object(requests, "get")
@pytest.mark.it("Sends an HTTP GET request to retrieve the trust bundle from Edge")
def test_requests_trust_bundle(self, mocker, edge_hsm, mock_requests_get):
expected_url = edge_hsm.workload_uri + "trust-bundle"
expected_params = {"api-version": edge_hsm.api_version}
expected_headers = {
"User-Agent": urllib.parse.quote_plus(user_agent.get_iothub_user_agent())
}
edge_hsm.get_certificate()
assert mock_requests_get.call_count == 1
assert mock_requests_get.call_args == mocker.call(
expected_url, params=expected_params, headers=expected_headers
)
@pytest.mark.it("Returns the certificate from the trust bundle received from Edge")
def test_returns_certificate(self, edge_hsm, mock_requests_get):
mock_response = mock_requests_get.return_value
certificate = "my certificate"
mock_response.json.return_value = {"certificate": certificate}
returned_cert = edge_hsm.get_certificate()
assert returned_cert is certificate
@pytest.mark.it("Raises IoTEdgeError if a bad request is made to Edge")
def test_bad_request(self, edge_hsm, mock_requests_get):
mock_response = mock_requests_get.return_value
error = requests.exceptions.HTTPError()
mock_response.raise_for_status.side_effect = error
with pytest.raises(IoTEdgeError) as e_info:
edge_hsm.get_certificate()
assert e_info.value.__cause__ is error
@pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the trust bundle")
def test_bad_json(self, edge_hsm, mock_requests_get):
mock_response = mock_requests_get.return_value
error = ValueError()
mock_response.json.side_effect = error
with pytest.raises(IoTEdgeError) as e_info:
edge_hsm.get_certificate()
assert e_info.value.__cause__ is error
@pytest.mark.it("Raises IoTEdgeError if the certificate is missing from the trust bundle")
def test_bad_trust_bundle(self, edge_hsm, mock_requests_get):
mock_response = mock_requests_get.return_value
# Return an empty json dict with no 'certificate' key
mock_response.json.return_value = {}
with pytest.raises(IoTEdgeError):
edge_hsm.get_certificate()
@pytest.mark.describe("IoTEdgeHsm - .sign()")
class TestIoTEdgeHsmSign(object):
@pytest.fixture(autouse=True)
def mock_requests_post(self, mocker):
return mocker.patch.object(requests, "post")
@pytest.mark.it(
"Makes an HTTP request to Edge to sign a piece of string data using the HMAC-SHA256 algorithm"
)
def test_requests_data_signing(self, mocker, edge_hsm, mock_requests_post):
data_str = "somedata"
data_str_b64 = "c29tZWRhdGE="
mock_requests_post.return_value.json.return_value = {"digest": "somedigest"}
expected_url = "{workload_uri}modules/{module_id}/genid/{generation_id}/sign".format(
workload_uri=edge_hsm.workload_uri,
module_id=edge_hsm.module_id,
generation_id=edge_hsm.generation_id,
)
expected_params = {"api-version": edge_hsm.api_version}
expected_headers = {
"User-Agent": urllib.parse.quote(user_agent.get_iothub_user_agent(), safe="")
}
expected_json = json.dumps({"keyId": "primary", "algo": "HMACSHA256", "data": data_str_b64})
edge_hsm.sign(data_str)
assert mock_requests_post.call_count == 1
assert mock_requests_post.call_args == mocker.call(
url=expected_url, params=expected_params, headers=expected_headers, data=expected_json
)
@pytest.mark.it("Base64 encodes the string data in the request")
def test_b64_encodes_data(self, edge_hsm, mock_requests_post):
# This test is actually implicitly tested in the first test, but it's
# important to have an explicit test for it since it's a requirement
data_str = "somedata"
data_str_b64 = base64.b64encode(data_str.encode("utf-8")).decode()
mock_requests_post.return_value.json.return_value = {"digest": "somedigest"}
edge_hsm.sign(data_str)
sent_data = json.loads(mock_requests_post.call_args[1]["data"])["data"]
assert data_str != data_str_b64
assert sent_data == data_str_b64
@pytest.mark.it("Returns the signed data received from Edge")
def test_returns_signed_data(self, edge_hsm, mock_requests_post):
expected_digest = "somedigest"
mock_requests_post.return_value.json.return_value = {"digest": expected_digest}
signed_data = edge_hsm.sign("somedata")
assert signed_data == expected_digest
@pytest.mark.it("Supports data strings in both string and byte formats")
@pytest.mark.parametrize(
"data_string, expected_request_data",
[
pytest.param("sign this message", "c2lnbiB0aGlzIG1lc3NhZ2U=", id="String"),
pytest.param(b"sign this message", "c2lnbiB0aGlzIG1lc3NhZ2U=", id="Bytes"),
],
)
def test_supported_types(
self, edge_hsm, data_string, expected_request_data, mock_requests_post
):
mock_requests_post.return_value.json.return_value = {"digest": "somedigest"}
edge_hsm.sign(data_string)
sent_data = json.loads(mock_requests_post.call_args[1]["data"])["data"]
assert sent_data == expected_request_data
@pytest.mark.it("Raises IoTEdgeError if a bad request is made to EdgeHub")
def test_bad_request(self, edge_hsm, mock_requests_post):
mock_response = mock_requests_post.return_value
error = requests.exceptions.HTTPError()
mock_response.raise_for_status.side_effect = error
with pytest.raises(IoTEdgeError) as e_info:
edge_hsm.sign("somedata")
assert e_info.value.__cause__ is error
@pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the signed response")
def test_bad_json(self, edge_hsm, mock_requests_post):
mock_response = mock_requests_post.return_value
error = ValueError()
mock_response.json.side_effect = error
with pytest.raises(IoTEdgeError) as e_info:
edge_hsm.sign("somedata")
assert e_info.value.__cause__ is error
@pytest.mark.it("Raises IoTEdgeError if the signed data is missing from the response")
def test_bad_response(self, edge_hsm, mock_requests_post):
mock_response = mock_requests_post.return_value
mock_response.json.return_value = {}
with pytest.raises(IoTEdgeError):
edge_hsm.sign("somedata")

Просмотреть файл

@ -15,10 +15,9 @@ from pytest_lazyfixture import lazy_fixture
from dev_utils import custom_mock
from v3_async_wip.iothub_mqtt_client import (
IoTHubMQTTClient,
IoTHubError,
IoTHubClientError,
DEFAULT_RECONNECT_INTERVAL,
)
from v3_async_wip.iot_exceptions import IoTHubClientError, IoTHubError
from v3_async_wip import config, constant, models, user_agent
from v3_async_wip import mqtt_client as mqtt
from v3_async_wip import request_response as rr
@ -31,7 +30,6 @@ FAKE_DEVICE_ID = "fake_device_id"
FAKE_MODULE_ID = "fake_module_id"
FAKE_HOSTNAME = "fake.hostname"
FAKE_GATEWAY_HOSTNAME = "fake.gateway.hostname"
# FAKE_SHARED_ACCESS_KEY = "Zm9vYmFy"
FAKE_SIGNATURE = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI="
FAKE_EXPIRY = str(int(time.time()) + 3600)
FAKE_URI = "fake/resource/location"

Просмотреть файл

@ -0,0 +1,285 @@
# -*- coding: utf-8 -*-
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import pytest
import time
import re
import logging
import urllib
from azure.iot.device.common.auth.sastoken import (
RenewableSasToken,
NonRenewableSasToken,
SasTokenError,
)
logging.basicConfig(level=logging.DEBUG)
fake_uri = "some/resource/location"
fake_signed_data = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI="
fake_key_name = "fakekeyname"
fake_expiry = 12321312
simple_token_format = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}"
auth_rule_token_format = (
"SharedAccessSignature sr={resource}&sig={signature}&se={expiry}&skn={keyname}"
)
def token_parser(token_str):
"""helper function that parses a token string for individual values"""
token_map = {}
kv_string = token_str.split(" ")[1]
kv_pairs = kv_string.split("&")
for kv in kv_pairs:
t = kv.split("=")
token_map[t[0]] = t[1]
return token_map
class RenewableSasTokenTestConfig(object):
@pytest.fixture
def signing_mechanism(self, mocker):
mechanism = mocker.MagicMock()
mechanism.sign.return_value = fake_signed_data
return mechanism
# TODO: Rename this. These are not "device" and "service" tokens, the distinction is more generic
@pytest.fixture(params=["Device Token", "Service Token"])
def sastoken(self, request, signing_mechanism):
token_type = request.param
if token_type == "Device Token":
return RenewableSasToken(uri=fake_uri, signing_mechanism=signing_mechanism)
elif token_type == "Service Token":
return RenewableSasToken(
uri=fake_uri, signing_mechanism=signing_mechanism, key_name=fake_key_name
)
@pytest.mark.describe("RenewableSasToken")
class TestRenewableSasToken(RenewableSasTokenTestConfig):
@pytest.mark.it("Instantiates with a default TTL of 3600 seconds if no TTL is provided")
def test_default_ttl(self, signing_mechanism):
s = RenewableSasToken(fake_uri, signing_mechanism)
assert s.ttl == 3600
@pytest.mark.it("Instantiates with a custom TTL if provided")
def test_custom_ttl(self, signing_mechanism):
custom_ttl = 4747
s = RenewableSasToken(fake_uri, signing_mechanism, ttl=custom_ttl)
assert s.ttl == custom_ttl
@pytest.mark.it("Instantiates with with no key name by default if no key name is provided")
def test_default_key_name(self, signing_mechanism):
s = RenewableSasToken(fake_uri, signing_mechanism)
assert s._key_name is None
@pytest.mark.it("Instantiates with the given key name if provided")
def test_custom_key_name(self, signing_mechanism):
s = RenewableSasToken(fake_uri, signing_mechanism, key_name=fake_key_name)
assert s._key_name == fake_key_name
@pytest.mark.it(
"Instantiates with an expiry time TTL seconds in the future from the moment of instantiation"
)
def test_expiry_time(self, mocker, signing_mechanism):
fake_current_time = 1000
mocker.patch.object(time, "time", return_value=fake_current_time)
s = RenewableSasToken(fake_uri, signing_mechanism)
assert s.expiry_time == fake_current_time + s.ttl
@pytest.mark.it("Calls .refresh() to build the SAS token string on instantiation")
def test_refresh_on_instantiation(self, mocker, signing_mechanism):
refresh_mock = mocker.spy(RenewableSasToken, "refresh")
assert refresh_mock.call_count == 0
RenewableSasToken(fake_uri, signing_mechanism)
assert refresh_mock.call_count == 1
@pytest.mark.it("Returns the SAS token string as the string representation of the object")
def test_str_rep(self, sastoken):
assert str(sastoken) == sastoken._token
@pytest.mark.it(
"Maintains the .expiry_time attribute as a read-only property (raises AttributeError upon attempt)"
)
def test_expiry_time_read_only(self, sastoken):
with pytest.raises(AttributeError):
sastoken.expiry_time = 12321312
@pytest.mark.describe("RenewableSasToken - .refresh()")
class TestRenewableSasTokenRefresh(RenewableSasTokenTestConfig):
@pytest.mark.it("Sets a new expiry time of TTL seconds in the future")
def test_new_expiry(self, mocker, sastoken):
fake_current_time = 1000
mocker.patch.object(time, "time", return_value=fake_current_time)
sastoken.refresh()
assert sastoken.expiry_time == fake_current_time + sastoken.ttl
# TODO: reflect url encoding here?
@pytest.mark.it(
"Uses the token's signing mechanism to create a signature by signing a concatenation of the (URL encoded) URI and updated expiry time"
)
def test_generate_new_token(self, mocker, signing_mechanism, sastoken):
old_token_str = str(sastoken)
fake_future_time = 1000
mocker.patch.object(time, "time", return_value=fake_future_time)
signing_mechanism.reset_mock()
fake_signature = "new_fake_signature"
signing_mechanism.sign.return_value = fake_signature
sastoken.refresh()
# The token string has been updated
assert str(sastoken) != old_token_str
# The signing mechanism was used to sign a string
assert signing_mechanism.sign.call_count == 1
# The string being signed was a concatenation of the URI and expiry time
assert signing_mechanism.sign.call_args == mocker.call(
urllib.parse.quote(sastoken._uri, safe="") + "\n" + str(sastoken.expiry_time)
)
# The token string has the resulting signed string included as the signature
token_info = token_parser(str(sastoken))
assert token_info["sig"] == fake_signature
@pytest.mark.it(
"Builds a new token string using the token's URI (URL encoded) and expiry time, along with the signature created by the signing mechanism (also URL encoded)"
)
def test_token_string(self, sastoken):
token_str = sastoken._token
# Verify that token string representation matches token format
if not sastoken._key_name:
pattern = re.compile(r"SharedAccessSignature sr=(.+)&sig=(.+)&se=(.+)")
else:
pattern = re.compile(r"SharedAccessSignature sr=(.+)&sig=(.+)&se=(.+)&skn=(.+)")
assert pattern.match(token_str)
# Verify that content in the string representation is correct
token_info = token_parser(token_str)
assert token_info["sr"] == urllib.parse.quote(sastoken._uri, safe="")
assert token_info["sig"] == urllib.parse.quote(
sastoken._signing_mechanism.sign.return_value, safe=""
)
assert token_info["se"] == str(sastoken.expiry_time)
if sastoken._key_name:
assert token_info["skn"] == sastoken._key_name
@pytest.mark.it("Raises a SasTokenError if an exception is raised by the signing mechanism")
def test_signing_mechanism_raises_value_error(
self, mocker, signing_mechanism, sastoken, arbitrary_exception
):
signing_mechanism.sign.side_effect = arbitrary_exception
with pytest.raises(SasTokenError) as e_info:
sastoken.refresh()
assert e_info.value.__cause__ is arbitrary_exception
@pytest.mark.describe("NonRenewableSasToken")
class TestNonRenewableSasToken(object):
# TODO: Rename this. These are not "device" and "service" tokens, the distinction is more generic
@pytest.fixture(params=["Device Token", "Service Token"])
def sastoken_str(self, request):
token_type = request.param
if token_type == "Device Token":
return simple_token_format.format(
resource=urllib.parse.quote(fake_uri, safe=""),
signature=urllib.parse.quote(fake_signed_data, safe=""),
expiry=fake_expiry,
)
elif token_type == "Service Token":
return auth_rule_token_format.format(
resource=urllib.parse.quote(fake_uri, safe=""),
signature=urllib.parse.quote(fake_signed_data, safe=""),
expiry=fake_expiry,
keyname=fake_key_name,
)
@pytest.fixture()
def sastoken(self, sastoken_str):
return NonRenewableSasToken(sastoken_str)
@pytest.mark.it("Instantiates from a valid SAS Token string")
def test_instantiates_from_token_string(self, sastoken_str):
s = NonRenewableSasToken(sastoken_str)
assert s._token == sastoken_str
@pytest.mark.it("Raises a SasToken error if instantiating from an invalid SAS Token string")
@pytest.mark.parametrize(
"invalid_token_str",
[
pytest.param(
"sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312",
id="Incomplete token format",
),
pytest.param(
"SharedERRORSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312",
id="Invalid token format",
),
pytest.param(
"SharedAccessignature sr=some%2Fresource%2Flocationsig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se12321312",
id="Token values incorectly formatted",
),
pytest.param(
"SharedAccessSignature sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312",
id="Missing resource value",
),
pytest.param(
"SharedAccessSignature sr=some%2Fresource%2Flocation&se=12321312",
id="Missing signature value",
),
pytest.param(
"SharedAccessSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=",
id="Missing expiry value",
),
pytest.param(
"SharedAccessSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312&foovalue=nonsense",
id="Extraneous invalid value",
),
],
)
def test_raises_error_invalid_token_string(self, invalid_token_str):
with pytest.raises(SasTokenError):
NonRenewableSasToken(invalid_token_str)
@pytest.mark.it("Returns the SAS token string as the string representation of the object")
def test_str_rep(self, sastoken_str):
sastoken = NonRenewableSasToken(sastoken_str)
assert str(sastoken) == sastoken_str
@pytest.mark.it(
"Instantiates with the .expiry_time attribute corresponding to the expiry time of the given SAS Token string (as an integer)"
)
def test_instantiates_expiry_time(self, sastoken_str):
sastoken = NonRenewableSasToken(sastoken_str)
expected_expiry_time = token_parser(sastoken_str)["se"]
assert sastoken.expiry_time == int(expected_expiry_time)
@pytest.mark.it(
"Maintains the .expiry_time attribute as a read-only property (raises AttributeError upon attempt)"
)
def test_expiry_time_read_only(self, sastoken):
with pytest.raises(AttributeError):
sastoken.expiry_time = 12312312312123
@pytest.mark.it(
"Instantiates with the .resource_uri attribute corresponding to the URL decoded URI of the given SAS Token string"
)
def test_instantiates_resource_uri(self, sastoken_str):
sastoken = NonRenewableSasToken(sastoken_str)
resource_uri = token_parser(sastoken_str)["sr"]
assert resource_uri != sastoken.resource_uri
assert resource_uri == urllib.parse.quote(sastoken.resource_uri, safe="")
assert urllib.parse.unquote(resource_uri) == sastoken.resource_uri
@pytest.mark.it(
"Maintains the .resource_uri attribute as a read-only property (raises AttributeError upon attempt)"
)
def test_resource_uri_read_only(self, sastoken):
with pytest.raises(AttributeError):
sastoken.resource_uri = "new%2Ffake%2Furi"

Просмотреть файл

@ -0,0 +1,125 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import pytest
import hmac
import hashlib
import base64
from v3_async_wip.signing_mechanism import SymmetricKeySigningMechanism
@pytest.mark.describe("SymmetricKeySigningMechanism - Instantiation")
class TestSymmetricKeySigningMechanismInstantiation(object):
@pytest.mark.it(
"Derives and stores the signing key from the provided symmetric key by base64 decoding it"
)
@pytest.mark.parametrize(
"key, expected_signing_key",
[
pytest.param(
"NMgJDvdKTxjLi+xBxxkDDEwDJxEvOE5u8BiT0mVgPeg=",
b"4\xc8\t\x0e\xf7JO\x18\xcb\x8b\xecA\xc7\x19\x03\x0cL\x03'\x11/8Nn\xf0\x18\x93\xd2e`=\xe8",
id="Example 1",
),
pytest.param(
"zqtyZCGuKg/UHvSzgYnNod/uHChWrzGGtHSgPi4cC2U=",
b"\xce\xabrd!\xae*\x0f\xd4\x1e\xf4\xb3\x81\x89\xcd\xa1\xdf\xee\x1c(V\xaf1\x86\xb4t\xa0>.\x1c\x0be",
id="Example 2",
),
],
)
def test_dervies_signing_key(self, key, expected_signing_key):
sm = SymmetricKeySigningMechanism(key)
assert sm._signing_key == expected_signing_key
@pytest.mark.it("Supports symmetric keys in both string and byte formats")
@pytest.mark.parametrize(
"key, expected_signing_key",
[
pytest.param(
"NMgJDvdKTxjLi+xBxxkDDEwDJxEvOE5u8BiT0mVgPeg=",
b"4\xc8\t\x0e\xf7JO\x18\xcb\x8b\xecA\xc7\x19\x03\x0cL\x03'\x11/8Nn\xf0\x18\x93\xd2e`=\xe8",
id="String",
),
pytest.param(
b"NMgJDvdKTxjLi+xBxxkDDEwDJxEvOE5u8BiT0mVgPeg=",
b"4\xc8\t\x0e\xf7JO\x18\xcb\x8b\xecA\xc7\x19\x03\x0cL\x03'\x11/8Nn\xf0\x18\x93\xd2e`=\xe8",
id="Bytes",
),
],
)
def test_supported_types(self, key, expected_signing_key):
sm = SymmetricKeySigningMechanism(key)
assert sm._signing_key == expected_signing_key
@pytest.mark.it("Raises a ValueError if the provided symmetric key is invalid")
@pytest.mark.parametrize(
"key",
[pytest.param("not a key", id="Not a key"), pytest.param("YWJjx", id="Incomplete key")],
)
def test_invalid_key(self, key):
with pytest.raises(ValueError):
SymmetricKeySigningMechanism(key)
@pytest.mark.describe("SymmetricKeySigningMechanism - .sign()")
class TestSymmetricKeySigningMechanismSign(object):
@pytest.fixture
def signing_mechanism(self):
return SymmetricKeySigningMechanism("NMgJDvdKTxjLi+xBxxkDDEwDJxEvOE5u8BiT0mVgPeg=")
@pytest.mark.it(
"Generates an HMAC message digest from the signing key and provided data string, using the HMAC-SHA256 algorithm"
)
def test_hmac(self, mocker, signing_mechanism):
hmac_mock = mocker.patch.object(hmac, "HMAC")
hmac_digest_mock = hmac_mock.return_value.digest
hmac_digest_mock.return_value = b"\xd2\x06\xf7\x12\xf1\xe9\x95$\x90\xfd\x12\x9a\xb1\xbe\xb4\xf8\xf3\xc4\x1ap\x8a\xab'\x8a.D\xfb\x84\x96\xca\xf3z"
data_string = "sign this message"
signing_mechanism.sign(data_string)
assert hmac_mock.call_count == 1
assert hmac_mock.call_args == mocker.call(
key=signing_mechanism._signing_key,
msg=data_string.encode("utf-8"),
digestmod=hashlib.sha256,
)
assert hmac_digest_mock.call_count == 1
@pytest.mark.it(
"Returns the base64 encoded HMAC message digest (converted to string) as the signed data"
)
def test_b64encode(self, mocker, signing_mechanism):
hmac_mock = mocker.patch.object(hmac, "HMAC")
hmac_digest_mock = hmac_mock.return_value.digest
hmac_digest_mock.return_value = b"\xd2\x06\xf7\x12\xf1\xe9\x95$\x90\xfd\x12\x9a\xb1\xbe\xb4\xf8\xf3\xc4\x1ap\x8a\xab'\x8a.D\xfb\x84\x96\xca\xf3z"
data_string = "sign this message"
signature = signing_mechanism.sign(data_string)
assert signature == base64.b64encode(hmac_digest_mock.return_value).decode("utf-8")
@pytest.mark.it("Supports data strings in both string and byte formats")
@pytest.mark.parametrize(
"data_string, expected_signature",
[
pytest.param(
"sign this message", "8NJRMT83CcplGrAGaUVIUM/md5914KpWVNngSVoF9/M=", id="String"
),
pytest.param(
b"sign this message", "8NJRMT83CcplGrAGaUVIUM/md5914KpWVNngSVoF9/M=", id="Bytes"
),
],
)
def test_supported_types(self, signing_mechanism, data_string, expected_signature):
assert signing_mechanism.sign(data_string) == expected_signature
@pytest.mark.it("Raises a ValueError if unable to sign the provided data string")
@pytest.mark.parametrize("data_string", [pytest.param(123, id="Integer input")])
def test_bad_input(self, signing_mechanism, data_string):
with pytest.raises(ValueError):
signing_mechanism.sign(data_string)

Просмотреть файл

@ -0,0 +1,165 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import base64
import logging
import json
import requests # type: ignore
import requests_unixsocket # type: ignore
import urllib.parse
from typing import Union
from . import user_agent
from .iot_exceptions import IoTEdgeError
from .signing_mechanism import SigningMechanism
requests_unixsocket.monkeypatch()
logger = logging.getLogger(__name__)
class IoTEdgeHsm(SigningMechanism):
"""
Constructor for instantiating a iot hsm object. This is an object that
communicates with the Azure IoT Edge HSM in order to get connection credentials
for an Azure IoT Edge module. The credentials that this object return come in
two forms:
1. The trust bundle, which is a certificate that can be used as a trusted cert
to authenticate the SSL connection between the IoE Edge module and IoT Edge
2. A signing function, which can be used to create the sig field for a
SharedAccessSignature string which can be used to authenticate with Iot Edge
"""
def __init__(
self, module_id: str, generation_id: str, workload_uri: str, api_version: str
) -> None:
"""
Constructor for instantiating a Azure IoT Edge HSM object
:param str module_id: The module id
:param str api_version: The API version
:param str generation_id: The module generation id
:param str workload_uri: The workload uri
"""
self.module_id = urllib.parse.quote(module_id, safe="")
self.api_version = api_version
self.generation_id = generation_id
self.workload_uri = _format_socket_uri(workload_uri)
def get_certificate(self) -> str:
"""
Return the server verification certificate from the trust bundle that can be used to
validate the server-side SSL TLS connection that we use to talk to Edge
:return: The server verification certificate to use for connections to the Azure IoT Edge
instance, as a PEM certificate in string form.
:raises: IoTEdgeError if unable to retrieve the certificate.
"""
r = requests.get(
self.workload_uri + "trust-bundle",
params={"api-version": self.api_version},
headers={"User-Agent": urllib.parse.quote_plus(user_agent.get_iothub_user_agent())},
)
# Validate that the request was successful
try:
r.raise_for_status()
except requests.exceptions.HTTPError as e:
raise IoTEdgeError("Unable to get trust bundle from Edge") from e
# Decode the trust bundle
try:
bundle = r.json()
except ValueError as e:
raise IoTEdgeError("Unable to decode trust bundle") from e
# Retrieve the certificate
try:
cert = bundle["certificate"]
except KeyError as e:
raise IoTEdgeError("No certificate in trust bundle") from e
return cert
def sign(self, data_str: Union[str, bytes]) -> str:
"""
Use the IoTEdge HSM to sign a piece of string data. The caller should then insert the
returned value (the signature) into the 'sig' field of a SharedAccessSignature string.
:param str data_str: The data string to sign
:return: The signature, as a URI-encoded and base64-encoded value that is ready to
directly insert into the SharedAccessSignature string.
:raises: IoTEdgeError if unable to sign the data.
"""
# Convert data_str to bytes (if not already)
if isinstance(data_str, str):
data_bytes = data_str.encode("utf-8")
else:
data_bytes = data_str
encoded_data_str = base64.b64encode(data_bytes).decode()
path = "{workload_uri}modules/{module_id}/genid/{gen_id}/sign".format(
workload_uri=self.workload_uri, module_id=self.module_id, gen_id=self.generation_id
)
sign_request = {"keyId": "primary", "algo": "HMACSHA256", "data": encoded_data_str}
r = requests.post( # can we use json field instead of data?
url=path,
params={"api-version": self.api_version},
headers={"User-Agent": urllib.parse.quote(user_agent.get_iothub_user_agent(), safe="")},
data=json.dumps(sign_request),
)
try:
r.raise_for_status()
except requests.exceptions.HTTPError as e:
raise IoTEdgeError("Unable to sign data") from e
try:
sign_response = r.json()
except ValueError as e:
raise IoTEdgeError("Unable to decode signed data") from e
try:
signed_data_str = sign_response["digest"]
except KeyError as e:
raise IoTEdgeError("No signed data received") from e
return signed_data_str # what format is this? string? bytes?
def _format_socket_uri(old_uri: str) -> str:
"""
This function takes a socket URI in one form and converts it into another form.
The source form is based on what we receive inside the IOTEDGE_WORKLOADURI
environment variable, and it looks like this:
"unix:///var/run/iotedge/workload.sock"
The destination form is based on what the requests_unixsocket library expects
and it looks like this:
"http+unix://%2Fvar%2Frun%2Fiotedge%2Fworkload.sock/"
The function changes the prefix, uri-encodes the path, and adds a slash
at the end.
If the socket URI does not start with unix:// this function only adds
a slash at the end.
:param old_uri: The URI in IOTEDGE_WORKLOADURI form
:return: The URI in requests_unixsocket form
"""
old_prefix = "unix://"
new_prefix = "http+unix://"
if old_uri.startswith(old_prefix):
stripped_uri = old_uri[len(old_prefix) :]
if stripped_uri.endswith("/"):
stripped_uri = stripped_uri[:-1]
new_uri = new_prefix + urllib.parse.quote(stripped_uri, safe="")
else:
new_uri = old_uri
if not new_uri.endswith("/"):
new_uri += "/"
return new_uri

Просмотреть файл

@ -0,0 +1,24 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""Define Azure IoT domain exceptions to be shared across modules"""
class IoTHubError(Exception):
"""Represents a failure reported by IoTHub"""
pass
class IoTEdgeError(Exception):
"""Represents a failure reported by IoTEdge"""
pass
class IoTHubClientError(Exception):
"""Represents a failure from the IoTHub Client"""
pass

Просмотреть файл

@ -10,6 +10,7 @@ import logging
import urllib.parse
from typing import Optional, AsyncGenerator
from .custom_typing import TwinPatch, Twin
from .iot_exceptions import IoTHubError, IoTHubClientError
from .models import Message, MethodResponse, MethodRequest
from . import config, constant, user_agent
from . import request_response as rr
@ -31,18 +32,6 @@ DEFAULT_TOKEN_UPDATE_MARGIN = 120
# TODO: error handling in generators
class IoTHubError(Exception):
"""Represents a failure reported by IoTHub"""
pass
class IoTHubClientError(Exception):
"""Represents a failure from the IoTHub Client"""
pass
class IoTHubMQTTClient:
def __init__(
self,

Просмотреть файл

@ -0,0 +1,177 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""This module contains tools for working with Shared Access Signature (SAS) Tokens"""
import abc
import time
import urllib
from typing import Optional, Dict
from .signing_mechanism import SigningMechanism
class SasTokenError(Exception):
"""Error in SasToken"""
pass
class SasToken(abc.ABC):
"""Abstract parent class for SAS Tokens.
Doesn't do much, but helps with type hints
"""
@property
@abc.abstractmethod
def expiry_time(self) -> int:
pass
class RenewableSasToken(SasToken):
"""Renewable Shared Access Signature Token used to authenticate a request.
This token is 'renewable', which means that it can be updated when necessary to
prevent expiry, by using the .refresh() method.
Data Attributes:
expiry_time (int): Time that token will expire (in UTC, since epoch)
ttl (int): Time to live for the token, in seconds
"""
_auth_rule_token_format = (
"SharedAccessSignature sr={resource}&sig={signature}&se={expiry}&skn={keyname}"
)
_simple_token_format = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}"
def __init__(
self,
uri: str,
signing_mechanism: SigningMechanism,
key_name: Optional[str] = None,
ttl: int = 3600,
) -> None:
"""
:param str uri: URI of the resource to be accessed
:param signing_mechanism: The signing mechanism to use in the SasToken
:type signing_mechanism: Child classes of :class:`azure.iot.common.SigningMechanism`
:param str key_name: Symmetric Key Name (optional)
:param int ttl: Time to live for the token, in seconds (default 3600)
:raises: SasTokenError if an error occurs building a SasToken
"""
self._uri = uri
self._signing_mechanism = signing_mechanism
self._key_name = key_name
# These two values will be set by the .refresh() call
self._expiry_time: int
self._token: str
self.ttl = ttl
self.refresh()
def __str__(self) -> str:
return self._token
def refresh(self) -> None:
"""
Refresh the SasToken lifespan, giving it a new expiry time, and generating a new token.
"""
self._expiry_time = int(time.time() + self.ttl)
self._token = self._build_token()
def _build_token(self) -> str:
"""Build SasToken representation
:returns: String representation of the token
"""
url_encoded_uri = urllib.parse.quote(self._uri, safe="")
message = url_encoded_uri + "\n" + str(self.expiry_time)
try:
signature = self._signing_mechanism.sign(message)
except Exception as e:
# Because of variant signing mechanisms, we don't know what error might be raised.
# So we catch all of them.
raise SasTokenError("Unable to build SasToken from given values") from e
url_encoded_signature = urllib.parse.quote(signature, safe="")
if self._key_name:
token = self._auth_rule_token_format.format(
resource=url_encoded_uri,
signature=url_encoded_signature,
expiry=str(self.expiry_time),
keyname=self._key_name,
)
else:
token = self._simple_token_format.format(
resource=url_encoded_uri,
signature=url_encoded_signature,
expiry=str(self.expiry_time),
)
return token
@property
def expiry_time(self) -> int:
"""Expiry Time is READ ONLY"""
return self._expiry_time
class NonRenewableSasToken(SasToken):
"""NonRenewable Shared Access Signature Token used to authenticate a request.
This token is 'non-renewable', which means that it is invalid once it expires, and there
is no way to keep it alive. Instead, a new token must be created.
Data Attributes:
expiry_time (int): Time that token will expire (in UTC, since epoch)
resource_uri (str): URI for the resource the Token provides authentication to access
"""
def __init__(self, sastoken_string) -> None:
"""
:param str sastoken_string: A string representation of a SAS token
"""
self._token = sastoken_string
self._token_info = get_sastoken_info_from_string(self._token)
def __str__(self) -> str:
return self._token
@property
def expiry_time(self) -> int:
"""Expiry Time is READ ONLY"""
return int(self._token_info["se"])
@property
def resource_uri(self) -> str:
"""Resource URI is READ ONLY"""
uri = self._token_info["sr"]
return urllib.parse.unquote(uri)
REQUIRED_SASTOKEN_FIELDS = ["sr", "sig", "se"]
VALID_SASTOKEN_FIELDS = REQUIRED_SASTOKEN_FIELDS + ["skn"]
def get_sastoken_info_from_string(sastoken_string: str) -> Dict[str, str]:
pieces = sastoken_string.split("SharedAccessSignature ")
if len(pieces) != 2:
raise SasTokenError("Invalid SasToken string: Not a SasToken ")
# Get sastoken info as dictionary
try:
# TODO: fix this typehint later, it needs some kind of cast
sastoken_info = dict(map(str.strip, sub.split("=", 1)) for sub in pieces[1].split("&")) # type: ignore
except Exception as e:
raise SasTokenError("Invalid SasToken string: Incorrectly formatted") from e
# Validate that all required fields are present
if not all(key in sastoken_info for key in REQUIRED_SASTOKEN_FIELDS):
raise SasTokenError("Invalid SasToken string: Not all required fields present")
# Validate that no unexpected fields are present
if not all(key in VALID_SASTOKEN_FIELDS for key in sastoken_info):
raise SasTokenError("Invalid SasToken string: Unexpected fields present")
return sastoken_info

Просмотреть файл

@ -0,0 +1,70 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import abc
import base64
import binascii
import hmac
import hashlib
from typing import Union
class SigningMechanism(abc.ABC):
@abc.abstractmethod
def sign(self, data_str: Union[str, bytes]) -> str:
pass
class SymmetricKeySigningMechanism(SigningMechanism):
def __init__(self, key: Union[str, bytes]) -> None:
"""
A mechanism that signs data using a symmetric key
:param key: Symmetric Key (base64 encoded)
:type key: str or bytes
:raises: ValueError if provided key is invalid
"""
# Convert key to bytes (if not already)
if isinstance(key, str):
key_bytes = key.encode("utf-8")
else:
key_bytes = key
# Derives the signing key
# CT-TODO: is "signing key" the right term?
try:
self._signing_key = base64.b64decode(key_bytes)
except (binascii.Error):
raise ValueError("Invalid Symmetric Key")
def sign(self, data_str: Union[str, bytes]) -> str:
"""
Sign a data string with symmetric key and the HMAC-SHA256 algorithm.
:param data_str: Data string to be signed
:type data_str: str or bytes
:returns: The signed data
:rtype: str
:raises: ValueError if an invalid data string is provided
"""
# Convert data_str to bytes (if not already)
if isinstance(data_str, str):
data_bytes = data_str.encode("utf-8")
else:
data_bytes = data_str
# Derive signature via HMAC-SHA256 algorithm
try:
hmac_digest = hmac.HMAC(
key=self._signing_key, msg=data_bytes, digestmod=hashlib.sha256
).digest()
signed_data = base64.b64encode(hmac_digest)
except (TypeError):
raise ValueError("Unable to sign string using the provided symmetric key")
# Convert from bytes to string
return signed_data.decode("utf-8")