[V3] Port and add type hints to auth modules (#1101)
* port * updated credscan suppression
This commit is contained in:
Родитель
904e3213e1
Коммит
bb4cb90295
|
@ -1,5 +1,5 @@
|
|||
# Python V3 Workspace
|
||||
v3_async_wip/sandbox/
|
||||
v3_async_wip/_sandbox/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
|
|
@ -13,10 +13,18 @@
|
|||
"file": "\\tests\\unit\\common\\auth\\test_signing_mechanism.py",
|
||||
"_justification": "Test containing fake keys"
|
||||
},
|
||||
{
|
||||
"file": "\\v3_async_wip\\tests\\test_signing_mechanism.py",
|
||||
"_justification": "Test containing fake keys"
|
||||
},
|
||||
{
|
||||
"file": "\\tests\\unit\\common\\auth\\test_sastoken.py",
|
||||
"_justification": "Test containing fake signed data"
|
||||
},
|
||||
{
|
||||
"file": "\\v3_async_wip\\tests\\test_sastoken.py",
|
||||
"_justification": "Test containing fake signed data"
|
||||
},
|
||||
{
|
||||
"file": "\\tests\\unit\\common\\test_mqtt_transport.py",
|
||||
"_justification": "Test containing fake passwords"
|
||||
|
|
|
@ -0,0 +1,263 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import base64
|
||||
import json
|
||||
import pytest
|
||||
import requests
|
||||
import urllib.parse
|
||||
from v3_async_wip.edge_hsm import IoTEdgeHsm
|
||||
from v3_async_wip.iot_exceptions import IoTEdgeError
|
||||
from v3_async_wip import user_agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def edge_hsm():
|
||||
return IoTEdgeHsm(
|
||||
module_id="my_module_id",
|
||||
generation_id="module_generation_id",
|
||||
workload_uri="unix:///var/run/iotedge/workload.sock",
|
||||
api_version="my_api_version",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.describe("IoTEdgeHsm - Instantiation")
|
||||
class TestIoTEdgeHsmInstantiation(object):
|
||||
@pytest.mark.it("URL encodes the provided module_id parameter and sets it as an attribute")
|
||||
def test_encode_and_set_module_id(self):
|
||||
module_id = "my_module_id"
|
||||
generation_id = "my_generation_id"
|
||||
api_version = "my_api_version"
|
||||
workload_uri = "unix:///var/run/iotedge/workload.sock"
|
||||
|
||||
edge_hsm = IoTEdgeHsm(
|
||||
module_id=module_id,
|
||||
generation_id=generation_id,
|
||||
workload_uri=workload_uri,
|
||||
api_version=api_version,
|
||||
)
|
||||
|
||||
assert edge_hsm.module_id == urllib.parse.quote(module_id, safe="")
|
||||
|
||||
@pytest.mark.it(
|
||||
"Formats the provided workload_uri parameter for use with the requests library and sets it as an attribute"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"workload_uri, expected_formatted_uri",
|
||||
[
|
||||
pytest.param(
|
||||
"unix:///var/run/iotedge/workload.sock",
|
||||
"http+unix://%2Fvar%2Frun%2Fiotedge%2Fworkload.sock/",
|
||||
id="Domain Socket URI",
|
||||
),
|
||||
pytest.param("http://127.0.0.1:15580", "http://127.0.0.1:15580/", id="IP Address URI"),
|
||||
],
|
||||
)
|
||||
def test_workload_uri_formatting(self, workload_uri, expected_formatted_uri):
|
||||
module_id = "my_module_id"
|
||||
generation_id = "my_generation_id"
|
||||
api_version = "my_api_version"
|
||||
|
||||
edge_hsm = IoTEdgeHsm(
|
||||
module_id=module_id,
|
||||
generation_id=generation_id,
|
||||
workload_uri=workload_uri,
|
||||
api_version=api_version,
|
||||
)
|
||||
|
||||
assert edge_hsm.workload_uri == expected_formatted_uri
|
||||
|
||||
@pytest.mark.it("Sets the provided generation_id parameter as an attribute")
|
||||
def test_set_generation_id(self):
|
||||
module_id = "my_module_id"
|
||||
generation_id = "my_generation_id"
|
||||
api_version = "my_api_version"
|
||||
workload_uri = "unix:///var/run/iotedge/workload.sock"
|
||||
|
||||
edge_hsm = IoTEdgeHsm(
|
||||
module_id=module_id,
|
||||
generation_id=generation_id,
|
||||
workload_uri=workload_uri,
|
||||
api_version=api_version,
|
||||
)
|
||||
|
||||
assert edge_hsm.generation_id == generation_id
|
||||
|
||||
@pytest.mark.it("Sets the provided api_version parameter as an attribute")
|
||||
def test_set_api_version(self):
|
||||
module_id = "my_module_id"
|
||||
generation_id = "my_generation_id"
|
||||
api_version = "my_api_version"
|
||||
workload_uri = "unix:///var/run/iotedge/workload.sock"
|
||||
|
||||
edge_hsm = IoTEdgeHsm(
|
||||
module_id=module_id,
|
||||
generation_id=generation_id,
|
||||
workload_uri=workload_uri,
|
||||
api_version=api_version,
|
||||
)
|
||||
|
||||
assert edge_hsm.api_version == api_version
|
||||
|
||||
|
||||
@pytest.mark.describe("IoTEdgeHsm - .get_certificate()")
|
||||
class TestIoTEdgeHsmGetCertificate(object):
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_requests_get(self, mocker):
|
||||
return mocker.patch.object(requests, "get")
|
||||
|
||||
@pytest.mark.it("Sends an HTTP GET request to retrieve the trust bundle from Edge")
|
||||
def test_requests_trust_bundle(self, mocker, edge_hsm, mock_requests_get):
|
||||
expected_url = edge_hsm.workload_uri + "trust-bundle"
|
||||
expected_params = {"api-version": edge_hsm.api_version}
|
||||
expected_headers = {
|
||||
"User-Agent": urllib.parse.quote_plus(user_agent.get_iothub_user_agent())
|
||||
}
|
||||
|
||||
edge_hsm.get_certificate()
|
||||
|
||||
assert mock_requests_get.call_count == 1
|
||||
assert mock_requests_get.call_args == mocker.call(
|
||||
expected_url, params=expected_params, headers=expected_headers
|
||||
)
|
||||
|
||||
@pytest.mark.it("Returns the certificate from the trust bundle received from Edge")
|
||||
def test_returns_certificate(self, edge_hsm, mock_requests_get):
|
||||
mock_response = mock_requests_get.return_value
|
||||
certificate = "my certificate"
|
||||
mock_response.json.return_value = {"certificate": certificate}
|
||||
|
||||
returned_cert = edge_hsm.get_certificate()
|
||||
|
||||
assert returned_cert is certificate
|
||||
|
||||
@pytest.mark.it("Raises IoTEdgeError if a bad request is made to Edge")
|
||||
def test_bad_request(self, edge_hsm, mock_requests_get):
|
||||
mock_response = mock_requests_get.return_value
|
||||
error = requests.exceptions.HTTPError()
|
||||
mock_response.raise_for_status.side_effect = error
|
||||
|
||||
with pytest.raises(IoTEdgeError) as e_info:
|
||||
edge_hsm.get_certificate()
|
||||
assert e_info.value.__cause__ is error
|
||||
|
||||
@pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the trust bundle")
|
||||
def test_bad_json(self, edge_hsm, mock_requests_get):
|
||||
mock_response = mock_requests_get.return_value
|
||||
error = ValueError()
|
||||
mock_response.json.side_effect = error
|
||||
|
||||
with pytest.raises(IoTEdgeError) as e_info:
|
||||
edge_hsm.get_certificate()
|
||||
assert e_info.value.__cause__ is error
|
||||
|
||||
@pytest.mark.it("Raises IoTEdgeError if the certificate is missing from the trust bundle")
|
||||
def test_bad_trust_bundle(self, edge_hsm, mock_requests_get):
|
||||
mock_response = mock_requests_get.return_value
|
||||
# Return an empty json dict with no 'certificate' key
|
||||
mock_response.json.return_value = {}
|
||||
|
||||
with pytest.raises(IoTEdgeError):
|
||||
edge_hsm.get_certificate()
|
||||
|
||||
|
||||
@pytest.mark.describe("IoTEdgeHsm - .sign()")
|
||||
class TestIoTEdgeHsmSign(object):
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_requests_post(self, mocker):
|
||||
return mocker.patch.object(requests, "post")
|
||||
|
||||
@pytest.mark.it(
|
||||
"Makes an HTTP request to Edge to sign a piece of string data using the HMAC-SHA256 algorithm"
|
||||
)
|
||||
def test_requests_data_signing(self, mocker, edge_hsm, mock_requests_post):
|
||||
data_str = "somedata"
|
||||
data_str_b64 = "c29tZWRhdGE="
|
||||
mock_requests_post.return_value.json.return_value = {"digest": "somedigest"}
|
||||
expected_url = "{workload_uri}modules/{module_id}/genid/{generation_id}/sign".format(
|
||||
workload_uri=edge_hsm.workload_uri,
|
||||
module_id=edge_hsm.module_id,
|
||||
generation_id=edge_hsm.generation_id,
|
||||
)
|
||||
expected_params = {"api-version": edge_hsm.api_version}
|
||||
expected_headers = {
|
||||
"User-Agent": urllib.parse.quote(user_agent.get_iothub_user_agent(), safe="")
|
||||
}
|
||||
expected_json = json.dumps({"keyId": "primary", "algo": "HMACSHA256", "data": data_str_b64})
|
||||
|
||||
edge_hsm.sign(data_str)
|
||||
|
||||
assert mock_requests_post.call_count == 1
|
||||
assert mock_requests_post.call_args == mocker.call(
|
||||
url=expected_url, params=expected_params, headers=expected_headers, data=expected_json
|
||||
)
|
||||
|
||||
@pytest.mark.it("Base64 encodes the string data in the request")
|
||||
def test_b64_encodes_data(self, edge_hsm, mock_requests_post):
|
||||
# This test is actually implicitly tested in the first test, but it's
|
||||
# important to have an explicit test for it since it's a requirement
|
||||
data_str = "somedata"
|
||||
data_str_b64 = base64.b64encode(data_str.encode("utf-8")).decode()
|
||||
mock_requests_post.return_value.json.return_value = {"digest": "somedigest"}
|
||||
|
||||
edge_hsm.sign(data_str)
|
||||
|
||||
sent_data = json.loads(mock_requests_post.call_args[1]["data"])["data"]
|
||||
|
||||
assert data_str != data_str_b64
|
||||
assert sent_data == data_str_b64
|
||||
|
||||
@pytest.mark.it("Returns the signed data received from Edge")
|
||||
def test_returns_signed_data(self, edge_hsm, mock_requests_post):
|
||||
expected_digest = "somedigest"
|
||||
mock_requests_post.return_value.json.return_value = {"digest": expected_digest}
|
||||
|
||||
signed_data = edge_hsm.sign("somedata")
|
||||
|
||||
assert signed_data == expected_digest
|
||||
|
||||
@pytest.mark.it("Supports data strings in both string and byte formats")
|
||||
@pytest.mark.parametrize(
|
||||
"data_string, expected_request_data",
|
||||
[
|
||||
pytest.param("sign this message", "c2lnbiB0aGlzIG1lc3NhZ2U=", id="String"),
|
||||
pytest.param(b"sign this message", "c2lnbiB0aGlzIG1lc3NhZ2U=", id="Bytes"),
|
||||
],
|
||||
)
|
||||
def test_supported_types(
|
||||
self, edge_hsm, data_string, expected_request_data, mock_requests_post
|
||||
):
|
||||
mock_requests_post.return_value.json.return_value = {"digest": "somedigest"}
|
||||
edge_hsm.sign(data_string)
|
||||
sent_data = json.loads(mock_requests_post.call_args[1]["data"])["data"]
|
||||
|
||||
assert sent_data == expected_request_data
|
||||
|
||||
@pytest.mark.it("Raises IoTEdgeError if a bad request is made to EdgeHub")
|
||||
def test_bad_request(self, edge_hsm, mock_requests_post):
|
||||
mock_response = mock_requests_post.return_value
|
||||
error = requests.exceptions.HTTPError()
|
||||
mock_response.raise_for_status.side_effect = error
|
||||
|
||||
with pytest.raises(IoTEdgeError) as e_info:
|
||||
edge_hsm.sign("somedata")
|
||||
assert e_info.value.__cause__ is error
|
||||
|
||||
@pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the signed response")
|
||||
def test_bad_json(self, edge_hsm, mock_requests_post):
|
||||
mock_response = mock_requests_post.return_value
|
||||
error = ValueError()
|
||||
mock_response.json.side_effect = error
|
||||
with pytest.raises(IoTEdgeError) as e_info:
|
||||
edge_hsm.sign("somedata")
|
||||
assert e_info.value.__cause__ is error
|
||||
|
||||
@pytest.mark.it("Raises IoTEdgeError if the signed data is missing from the response")
|
||||
def test_bad_response(self, edge_hsm, mock_requests_post):
|
||||
mock_response = mock_requests_post.return_value
|
||||
mock_response.json.return_value = {}
|
||||
|
||||
with pytest.raises(IoTEdgeError):
|
||||
edge_hsm.sign("somedata")
|
|
@ -15,10 +15,9 @@ from pytest_lazyfixture import lazy_fixture
|
|||
from dev_utils import custom_mock
|
||||
from v3_async_wip.iothub_mqtt_client import (
|
||||
IoTHubMQTTClient,
|
||||
IoTHubError,
|
||||
IoTHubClientError,
|
||||
DEFAULT_RECONNECT_INTERVAL,
|
||||
)
|
||||
from v3_async_wip.iot_exceptions import IoTHubClientError, IoTHubError
|
||||
from v3_async_wip import config, constant, models, user_agent
|
||||
from v3_async_wip import mqtt_client as mqtt
|
||||
from v3_async_wip import request_response as rr
|
||||
|
@ -31,7 +30,6 @@ FAKE_DEVICE_ID = "fake_device_id"
|
|||
FAKE_MODULE_ID = "fake_module_id"
|
||||
FAKE_HOSTNAME = "fake.hostname"
|
||||
FAKE_GATEWAY_HOSTNAME = "fake.gateway.hostname"
|
||||
# FAKE_SHARED_ACCESS_KEY = "Zm9vYmFy"
|
||||
FAKE_SIGNATURE = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI="
|
||||
FAKE_EXPIRY = str(int(time.time()) + 3600)
|
||||
FAKE_URI = "fake/resource/location"
|
||||
|
|
|
@ -0,0 +1,285 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import pytest
|
||||
import time
|
||||
import re
|
||||
import logging
|
||||
import urllib
|
||||
from azure.iot.device.common.auth.sastoken import (
|
||||
RenewableSasToken,
|
||||
NonRenewableSasToken,
|
||||
SasTokenError,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
fake_uri = "some/resource/location"
|
||||
fake_signed_data = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI="
|
||||
fake_key_name = "fakekeyname"
|
||||
fake_expiry = 12321312
|
||||
|
||||
simple_token_format = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}"
|
||||
auth_rule_token_format = (
|
||||
"SharedAccessSignature sr={resource}&sig={signature}&se={expiry}&skn={keyname}"
|
||||
)
|
||||
|
||||
|
||||
def token_parser(token_str):
|
||||
"""helper function that parses a token string for individual values"""
|
||||
token_map = {}
|
||||
kv_string = token_str.split(" ")[1]
|
||||
kv_pairs = kv_string.split("&")
|
||||
for kv in kv_pairs:
|
||||
t = kv.split("=")
|
||||
token_map[t[0]] = t[1]
|
||||
return token_map
|
||||
|
||||
|
||||
class RenewableSasTokenTestConfig(object):
|
||||
@pytest.fixture
|
||||
def signing_mechanism(self, mocker):
|
||||
mechanism = mocker.MagicMock()
|
||||
mechanism.sign.return_value = fake_signed_data
|
||||
return mechanism
|
||||
|
||||
# TODO: Rename this. These are not "device" and "service" tokens, the distinction is more generic
|
||||
@pytest.fixture(params=["Device Token", "Service Token"])
|
||||
def sastoken(self, request, signing_mechanism):
|
||||
token_type = request.param
|
||||
if token_type == "Device Token":
|
||||
return RenewableSasToken(uri=fake_uri, signing_mechanism=signing_mechanism)
|
||||
elif token_type == "Service Token":
|
||||
return RenewableSasToken(
|
||||
uri=fake_uri, signing_mechanism=signing_mechanism, key_name=fake_key_name
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.describe("RenewableSasToken")
|
||||
class TestRenewableSasToken(RenewableSasTokenTestConfig):
|
||||
@pytest.mark.it("Instantiates with a default TTL of 3600 seconds if no TTL is provided")
|
||||
def test_default_ttl(self, signing_mechanism):
|
||||
s = RenewableSasToken(fake_uri, signing_mechanism)
|
||||
assert s.ttl == 3600
|
||||
|
||||
@pytest.mark.it("Instantiates with a custom TTL if provided")
|
||||
def test_custom_ttl(self, signing_mechanism):
|
||||
custom_ttl = 4747
|
||||
s = RenewableSasToken(fake_uri, signing_mechanism, ttl=custom_ttl)
|
||||
assert s.ttl == custom_ttl
|
||||
|
||||
@pytest.mark.it("Instantiates with with no key name by default if no key name is provided")
|
||||
def test_default_key_name(self, signing_mechanism):
|
||||
s = RenewableSasToken(fake_uri, signing_mechanism)
|
||||
assert s._key_name is None
|
||||
|
||||
@pytest.mark.it("Instantiates with the given key name if provided")
|
||||
def test_custom_key_name(self, signing_mechanism):
|
||||
s = RenewableSasToken(fake_uri, signing_mechanism, key_name=fake_key_name)
|
||||
assert s._key_name == fake_key_name
|
||||
|
||||
@pytest.mark.it(
|
||||
"Instantiates with an expiry time TTL seconds in the future from the moment of instantiation"
|
||||
)
|
||||
def test_expiry_time(self, mocker, signing_mechanism):
|
||||
fake_current_time = 1000
|
||||
mocker.patch.object(time, "time", return_value=fake_current_time)
|
||||
|
||||
s = RenewableSasToken(fake_uri, signing_mechanism)
|
||||
assert s.expiry_time == fake_current_time + s.ttl
|
||||
|
||||
@pytest.mark.it("Calls .refresh() to build the SAS token string on instantiation")
|
||||
def test_refresh_on_instantiation(self, mocker, signing_mechanism):
|
||||
refresh_mock = mocker.spy(RenewableSasToken, "refresh")
|
||||
assert refresh_mock.call_count == 0
|
||||
RenewableSasToken(fake_uri, signing_mechanism)
|
||||
assert refresh_mock.call_count == 1
|
||||
|
||||
@pytest.mark.it("Returns the SAS token string as the string representation of the object")
|
||||
def test_str_rep(self, sastoken):
|
||||
assert str(sastoken) == sastoken._token
|
||||
|
||||
@pytest.mark.it(
|
||||
"Maintains the .expiry_time attribute as a read-only property (raises AttributeError upon attempt)"
|
||||
)
|
||||
def test_expiry_time_read_only(self, sastoken):
|
||||
with pytest.raises(AttributeError):
|
||||
sastoken.expiry_time = 12321312
|
||||
|
||||
|
||||
@pytest.mark.describe("RenewableSasToken - .refresh()")
|
||||
class TestRenewableSasTokenRefresh(RenewableSasTokenTestConfig):
|
||||
@pytest.mark.it("Sets a new expiry time of TTL seconds in the future")
|
||||
def test_new_expiry(self, mocker, sastoken):
|
||||
fake_current_time = 1000
|
||||
mocker.patch.object(time, "time", return_value=fake_current_time)
|
||||
sastoken.refresh()
|
||||
assert sastoken.expiry_time == fake_current_time + sastoken.ttl
|
||||
|
||||
# TODO: reflect url encoding here?
|
||||
@pytest.mark.it(
|
||||
"Uses the token's signing mechanism to create a signature by signing a concatenation of the (URL encoded) URI and updated expiry time"
|
||||
)
|
||||
def test_generate_new_token(self, mocker, signing_mechanism, sastoken):
|
||||
old_token_str = str(sastoken)
|
||||
fake_future_time = 1000
|
||||
mocker.patch.object(time, "time", return_value=fake_future_time)
|
||||
signing_mechanism.reset_mock()
|
||||
fake_signature = "new_fake_signature"
|
||||
signing_mechanism.sign.return_value = fake_signature
|
||||
|
||||
sastoken.refresh()
|
||||
|
||||
# The token string has been updated
|
||||
assert str(sastoken) != old_token_str
|
||||
# The signing mechanism was used to sign a string
|
||||
assert signing_mechanism.sign.call_count == 1
|
||||
# The string being signed was a concatenation of the URI and expiry time
|
||||
assert signing_mechanism.sign.call_args == mocker.call(
|
||||
urllib.parse.quote(sastoken._uri, safe="") + "\n" + str(sastoken.expiry_time)
|
||||
)
|
||||
# The token string has the resulting signed string included as the signature
|
||||
token_info = token_parser(str(sastoken))
|
||||
assert token_info["sig"] == fake_signature
|
||||
|
||||
@pytest.mark.it(
|
||||
"Builds a new token string using the token's URI (URL encoded) and expiry time, along with the signature created by the signing mechanism (also URL encoded)"
|
||||
)
|
||||
def test_token_string(self, sastoken):
|
||||
token_str = sastoken._token
|
||||
|
||||
# Verify that token string representation matches token format
|
||||
if not sastoken._key_name:
|
||||
pattern = re.compile(r"SharedAccessSignature sr=(.+)&sig=(.+)&se=(.+)")
|
||||
else:
|
||||
pattern = re.compile(r"SharedAccessSignature sr=(.+)&sig=(.+)&se=(.+)&skn=(.+)")
|
||||
assert pattern.match(token_str)
|
||||
|
||||
# Verify that content in the string representation is correct
|
||||
token_info = token_parser(token_str)
|
||||
assert token_info["sr"] == urllib.parse.quote(sastoken._uri, safe="")
|
||||
assert token_info["sig"] == urllib.parse.quote(
|
||||
sastoken._signing_mechanism.sign.return_value, safe=""
|
||||
)
|
||||
assert token_info["se"] == str(sastoken.expiry_time)
|
||||
if sastoken._key_name:
|
||||
assert token_info["skn"] == sastoken._key_name
|
||||
|
||||
@pytest.mark.it("Raises a SasTokenError if an exception is raised by the signing mechanism")
|
||||
def test_signing_mechanism_raises_value_error(
|
||||
self, mocker, signing_mechanism, sastoken, arbitrary_exception
|
||||
):
|
||||
signing_mechanism.sign.side_effect = arbitrary_exception
|
||||
|
||||
with pytest.raises(SasTokenError) as e_info:
|
||||
sastoken.refresh()
|
||||
assert e_info.value.__cause__ is arbitrary_exception
|
||||
|
||||
|
||||
@pytest.mark.describe("NonRenewableSasToken")
|
||||
class TestNonRenewableSasToken(object):
|
||||
# TODO: Rename this. These are not "device" and "service" tokens, the distinction is more generic
|
||||
@pytest.fixture(params=["Device Token", "Service Token"])
|
||||
def sastoken_str(self, request):
|
||||
token_type = request.param
|
||||
if token_type == "Device Token":
|
||||
return simple_token_format.format(
|
||||
resource=urllib.parse.quote(fake_uri, safe=""),
|
||||
signature=urllib.parse.quote(fake_signed_data, safe=""),
|
||||
expiry=fake_expiry,
|
||||
)
|
||||
elif token_type == "Service Token":
|
||||
return auth_rule_token_format.format(
|
||||
resource=urllib.parse.quote(fake_uri, safe=""),
|
||||
signature=urllib.parse.quote(fake_signed_data, safe=""),
|
||||
expiry=fake_expiry,
|
||||
keyname=fake_key_name,
|
||||
)
|
||||
|
||||
@pytest.fixture()
|
||||
def sastoken(self, sastoken_str):
|
||||
return NonRenewableSasToken(sastoken_str)
|
||||
|
||||
@pytest.mark.it("Instantiates from a valid SAS Token string")
|
||||
def test_instantiates_from_token_string(self, sastoken_str):
|
||||
s = NonRenewableSasToken(sastoken_str)
|
||||
assert s._token == sastoken_str
|
||||
|
||||
@pytest.mark.it("Raises a SasToken error if instantiating from an invalid SAS Token string")
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_token_str",
|
||||
[
|
||||
pytest.param(
|
||||
"sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312",
|
||||
id="Incomplete token format",
|
||||
),
|
||||
pytest.param(
|
||||
"SharedERRORSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312",
|
||||
id="Invalid token format",
|
||||
),
|
||||
pytest.param(
|
||||
"SharedAccessignature sr=some%2Fresource%2Flocationsig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se12321312",
|
||||
id="Token values incorectly formatted",
|
||||
),
|
||||
pytest.param(
|
||||
"SharedAccessSignature sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312",
|
||||
id="Missing resource value",
|
||||
),
|
||||
pytest.param(
|
||||
"SharedAccessSignature sr=some%2Fresource%2Flocation&se=12321312",
|
||||
id="Missing signature value",
|
||||
),
|
||||
pytest.param(
|
||||
"SharedAccessSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=",
|
||||
id="Missing expiry value",
|
||||
),
|
||||
pytest.param(
|
||||
"SharedAccessSignature sr=some%2Fresource%2Flocation&sig=ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=&se=12321312&foovalue=nonsense",
|
||||
id="Extraneous invalid value",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_raises_error_invalid_token_string(self, invalid_token_str):
|
||||
with pytest.raises(SasTokenError):
|
||||
NonRenewableSasToken(invalid_token_str)
|
||||
|
||||
@pytest.mark.it("Returns the SAS token string as the string representation of the object")
|
||||
def test_str_rep(self, sastoken_str):
|
||||
sastoken = NonRenewableSasToken(sastoken_str)
|
||||
assert str(sastoken) == sastoken_str
|
||||
|
||||
@pytest.mark.it(
|
||||
"Instantiates with the .expiry_time attribute corresponding to the expiry time of the given SAS Token string (as an integer)"
|
||||
)
|
||||
def test_instantiates_expiry_time(self, sastoken_str):
|
||||
sastoken = NonRenewableSasToken(sastoken_str)
|
||||
expected_expiry_time = token_parser(sastoken_str)["se"]
|
||||
assert sastoken.expiry_time == int(expected_expiry_time)
|
||||
|
||||
@pytest.mark.it(
|
||||
"Maintains the .expiry_time attribute as a read-only property (raises AttributeError upon attempt)"
|
||||
)
|
||||
def test_expiry_time_read_only(self, sastoken):
|
||||
with pytest.raises(AttributeError):
|
||||
sastoken.expiry_time = 12312312312123
|
||||
|
||||
@pytest.mark.it(
|
||||
"Instantiates with the .resource_uri attribute corresponding to the URL decoded URI of the given SAS Token string"
|
||||
)
|
||||
def test_instantiates_resource_uri(self, sastoken_str):
|
||||
sastoken = NonRenewableSasToken(sastoken_str)
|
||||
resource_uri = token_parser(sastoken_str)["sr"]
|
||||
assert resource_uri != sastoken.resource_uri
|
||||
assert resource_uri == urllib.parse.quote(sastoken.resource_uri, safe="")
|
||||
assert urllib.parse.unquote(resource_uri) == sastoken.resource_uri
|
||||
|
||||
@pytest.mark.it(
|
||||
"Maintains the .resource_uri attribute as a read-only property (raises AttributeError upon attempt)"
|
||||
)
|
||||
def test_resource_uri_read_only(self, sastoken):
|
||||
with pytest.raises(AttributeError):
|
||||
sastoken.resource_uri = "new%2Ffake%2Furi"
|
|
@ -0,0 +1,125 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import pytest
|
||||
import hmac
|
||||
import hashlib
|
||||
import base64
|
||||
from v3_async_wip.signing_mechanism import SymmetricKeySigningMechanism
|
||||
|
||||
|
||||
@pytest.mark.describe("SymmetricKeySigningMechanism - Instantiation")
|
||||
class TestSymmetricKeySigningMechanismInstantiation(object):
|
||||
@pytest.mark.it(
|
||||
"Derives and stores the signing key from the provided symmetric key by base64 decoding it"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"key, expected_signing_key",
|
||||
[
|
||||
pytest.param(
|
||||
"NMgJDvdKTxjLi+xBxxkDDEwDJxEvOE5u8BiT0mVgPeg=",
|
||||
b"4\xc8\t\x0e\xf7JO\x18\xcb\x8b\xecA\xc7\x19\x03\x0cL\x03'\x11/8Nn\xf0\x18\x93\xd2e`=\xe8",
|
||||
id="Example 1",
|
||||
),
|
||||
pytest.param(
|
||||
"zqtyZCGuKg/UHvSzgYnNod/uHChWrzGGtHSgPi4cC2U=",
|
||||
b"\xce\xabrd!\xae*\x0f\xd4\x1e\xf4\xb3\x81\x89\xcd\xa1\xdf\xee\x1c(V\xaf1\x86\xb4t\xa0>.\x1c\x0be",
|
||||
id="Example 2",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_dervies_signing_key(self, key, expected_signing_key):
|
||||
sm = SymmetricKeySigningMechanism(key)
|
||||
assert sm._signing_key == expected_signing_key
|
||||
|
||||
@pytest.mark.it("Supports symmetric keys in both string and byte formats")
|
||||
@pytest.mark.parametrize(
|
||||
"key, expected_signing_key",
|
||||
[
|
||||
pytest.param(
|
||||
"NMgJDvdKTxjLi+xBxxkDDEwDJxEvOE5u8BiT0mVgPeg=",
|
||||
b"4\xc8\t\x0e\xf7JO\x18\xcb\x8b\xecA\xc7\x19\x03\x0cL\x03'\x11/8Nn\xf0\x18\x93\xd2e`=\xe8",
|
||||
id="String",
|
||||
),
|
||||
pytest.param(
|
||||
b"NMgJDvdKTxjLi+xBxxkDDEwDJxEvOE5u8BiT0mVgPeg=",
|
||||
b"4\xc8\t\x0e\xf7JO\x18\xcb\x8b\xecA\xc7\x19\x03\x0cL\x03'\x11/8Nn\xf0\x18\x93\xd2e`=\xe8",
|
||||
id="Bytes",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_supported_types(self, key, expected_signing_key):
|
||||
sm = SymmetricKeySigningMechanism(key)
|
||||
assert sm._signing_key == expected_signing_key
|
||||
|
||||
@pytest.mark.it("Raises a ValueError if the provided symmetric key is invalid")
|
||||
@pytest.mark.parametrize(
|
||||
"key",
|
||||
[pytest.param("not a key", id="Not a key"), pytest.param("YWJjx", id="Incomplete key")],
|
||||
)
|
||||
def test_invalid_key(self, key):
|
||||
with pytest.raises(ValueError):
|
||||
SymmetricKeySigningMechanism(key)
|
||||
|
||||
|
||||
@pytest.mark.describe("SymmetricKeySigningMechanism - .sign()")
|
||||
class TestSymmetricKeySigningMechanismSign(object):
|
||||
@pytest.fixture
|
||||
def signing_mechanism(self):
|
||||
return SymmetricKeySigningMechanism("NMgJDvdKTxjLi+xBxxkDDEwDJxEvOE5u8BiT0mVgPeg=")
|
||||
|
||||
@pytest.mark.it(
|
||||
"Generates an HMAC message digest from the signing key and provided data string, using the HMAC-SHA256 algorithm"
|
||||
)
|
||||
def test_hmac(self, mocker, signing_mechanism):
|
||||
hmac_mock = mocker.patch.object(hmac, "HMAC")
|
||||
hmac_digest_mock = hmac_mock.return_value.digest
|
||||
hmac_digest_mock.return_value = b"\xd2\x06\xf7\x12\xf1\xe9\x95$\x90\xfd\x12\x9a\xb1\xbe\xb4\xf8\xf3\xc4\x1ap\x8a\xab'\x8a.D\xfb\x84\x96\xca\xf3z"
|
||||
|
||||
data_string = "sign this message"
|
||||
signing_mechanism.sign(data_string)
|
||||
|
||||
assert hmac_mock.call_count == 1
|
||||
assert hmac_mock.call_args == mocker.call(
|
||||
key=signing_mechanism._signing_key,
|
||||
msg=data_string.encode("utf-8"),
|
||||
digestmod=hashlib.sha256,
|
||||
)
|
||||
assert hmac_digest_mock.call_count == 1
|
||||
|
||||
@pytest.mark.it(
|
||||
"Returns the base64 encoded HMAC message digest (converted to string) as the signed data"
|
||||
)
|
||||
def test_b64encode(self, mocker, signing_mechanism):
|
||||
hmac_mock = mocker.patch.object(hmac, "HMAC")
|
||||
hmac_digest_mock = hmac_mock.return_value.digest
|
||||
hmac_digest_mock.return_value = b"\xd2\x06\xf7\x12\xf1\xe9\x95$\x90\xfd\x12\x9a\xb1\xbe\xb4\xf8\xf3\xc4\x1ap\x8a\xab'\x8a.D\xfb\x84\x96\xca\xf3z"
|
||||
|
||||
data_string = "sign this message"
|
||||
signature = signing_mechanism.sign(data_string)
|
||||
|
||||
assert signature == base64.b64encode(hmac_digest_mock.return_value).decode("utf-8")
|
||||
|
||||
@pytest.mark.it("Supports data strings in both string and byte formats")
|
||||
@pytest.mark.parametrize(
|
||||
"data_string, expected_signature",
|
||||
[
|
||||
pytest.param(
|
||||
"sign this message", "8NJRMT83CcplGrAGaUVIUM/md5914KpWVNngSVoF9/M=", id="String"
|
||||
),
|
||||
pytest.param(
|
||||
b"sign this message", "8NJRMT83CcplGrAGaUVIUM/md5914KpWVNngSVoF9/M=", id="Bytes"
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_supported_types(self, signing_mechanism, data_string, expected_signature):
|
||||
assert signing_mechanism.sign(data_string) == expected_signature
|
||||
|
||||
@pytest.mark.it("Raises a ValueError if unable to sign the provided data string")
|
||||
@pytest.mark.parametrize("data_string", [pytest.param(123, id="Integer input")])
|
||||
def test_bad_input(self, signing_mechanism, data_string):
|
||||
with pytest.raises(ValueError):
|
||||
signing_mechanism.sign(data_string)
|
|
@ -0,0 +1,165 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import base64
|
||||
import logging
|
||||
import json
|
||||
import requests # type: ignore
|
||||
import requests_unixsocket # type: ignore
|
||||
import urllib.parse
|
||||
from typing import Union
|
||||
from . import user_agent
|
||||
from .iot_exceptions import IoTEdgeError
|
||||
from .signing_mechanism import SigningMechanism
|
||||
|
||||
requests_unixsocket.monkeypatch()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IoTEdgeHsm(SigningMechanism):
|
||||
"""
|
||||
Constructor for instantiating a iot hsm object. This is an object that
|
||||
communicates with the Azure IoT Edge HSM in order to get connection credentials
|
||||
for an Azure IoT Edge module. The credentials that this object return come in
|
||||
two forms:
|
||||
|
||||
1. The trust bundle, which is a certificate that can be used as a trusted cert
|
||||
to authenticate the SSL connection between the IoE Edge module and IoT Edge
|
||||
2. A signing function, which can be used to create the sig field for a
|
||||
SharedAccessSignature string which can be used to authenticate with Iot Edge
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, module_id: str, generation_id: str, workload_uri: str, api_version: str
|
||||
) -> None:
|
||||
"""
|
||||
Constructor for instantiating a Azure IoT Edge HSM object
|
||||
|
||||
:param str module_id: The module id
|
||||
:param str api_version: The API version
|
||||
:param str generation_id: The module generation id
|
||||
:param str workload_uri: The workload uri
|
||||
"""
|
||||
self.module_id = urllib.parse.quote(module_id, safe="")
|
||||
self.api_version = api_version
|
||||
self.generation_id = generation_id
|
||||
self.workload_uri = _format_socket_uri(workload_uri)
|
||||
|
||||
def get_certificate(self) -> str:
|
||||
"""
|
||||
Return the server verification certificate from the trust bundle that can be used to
|
||||
validate the server-side SSL TLS connection that we use to talk to Edge
|
||||
|
||||
:return: The server verification certificate to use for connections to the Azure IoT Edge
|
||||
instance, as a PEM certificate in string form.
|
||||
|
||||
:raises: IoTEdgeError if unable to retrieve the certificate.
|
||||
"""
|
||||
r = requests.get(
|
||||
self.workload_uri + "trust-bundle",
|
||||
params={"api-version": self.api_version},
|
||||
headers={"User-Agent": urllib.parse.quote_plus(user_agent.get_iothub_user_agent())},
|
||||
)
|
||||
# Validate that the request was successful
|
||||
try:
|
||||
r.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
raise IoTEdgeError("Unable to get trust bundle from Edge") from e
|
||||
# Decode the trust bundle
|
||||
try:
|
||||
bundle = r.json()
|
||||
except ValueError as e:
|
||||
raise IoTEdgeError("Unable to decode trust bundle") from e
|
||||
# Retrieve the certificate
|
||||
try:
|
||||
cert = bundle["certificate"]
|
||||
except KeyError as e:
|
||||
raise IoTEdgeError("No certificate in trust bundle") from e
|
||||
return cert
|
||||
|
||||
def sign(self, data_str: Union[str, bytes]) -> str:
|
||||
"""
|
||||
Use the IoTEdge HSM to sign a piece of string data. The caller should then insert the
|
||||
returned value (the signature) into the 'sig' field of a SharedAccessSignature string.
|
||||
|
||||
:param str data_str: The data string to sign
|
||||
|
||||
:return: The signature, as a URI-encoded and base64-encoded value that is ready to
|
||||
directly insert into the SharedAccessSignature string.
|
||||
|
||||
:raises: IoTEdgeError if unable to sign the data.
|
||||
"""
|
||||
# Convert data_str to bytes (if not already)
|
||||
if isinstance(data_str, str):
|
||||
data_bytes = data_str.encode("utf-8")
|
||||
else:
|
||||
data_bytes = data_str
|
||||
|
||||
encoded_data_str = base64.b64encode(data_bytes).decode()
|
||||
|
||||
path = "{workload_uri}modules/{module_id}/genid/{gen_id}/sign".format(
|
||||
workload_uri=self.workload_uri, module_id=self.module_id, gen_id=self.generation_id
|
||||
)
|
||||
sign_request = {"keyId": "primary", "algo": "HMACSHA256", "data": encoded_data_str}
|
||||
|
||||
r = requests.post( # can we use json field instead of data?
|
||||
url=path,
|
||||
params={"api-version": self.api_version},
|
||||
headers={"User-Agent": urllib.parse.quote(user_agent.get_iothub_user_agent(), safe="")},
|
||||
data=json.dumps(sign_request),
|
||||
)
|
||||
try:
|
||||
r.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
raise IoTEdgeError("Unable to sign data") from e
|
||||
try:
|
||||
sign_response = r.json()
|
||||
except ValueError as e:
|
||||
raise IoTEdgeError("Unable to decode signed data") from e
|
||||
try:
|
||||
signed_data_str = sign_response["digest"]
|
||||
except KeyError as e:
|
||||
raise IoTEdgeError("No signed data received") from e
|
||||
|
||||
return signed_data_str # what format is this? string? bytes?
|
||||
|
||||
|
||||
def _format_socket_uri(old_uri: str) -> str:
|
||||
"""
|
||||
This function takes a socket URI in one form and converts it into another form.
|
||||
|
||||
The source form is based on what we receive inside the IOTEDGE_WORKLOADURI
|
||||
environment variable, and it looks like this:
|
||||
"unix:///var/run/iotedge/workload.sock"
|
||||
|
||||
The destination form is based on what the requests_unixsocket library expects
|
||||
and it looks like this:
|
||||
"http+unix://%2Fvar%2Frun%2Fiotedge%2Fworkload.sock/"
|
||||
|
||||
The function changes the prefix, uri-encodes the path, and adds a slash
|
||||
at the end.
|
||||
|
||||
If the socket URI does not start with unix:// this function only adds
|
||||
a slash at the end.
|
||||
|
||||
:param old_uri: The URI in IOTEDGE_WORKLOADURI form
|
||||
|
||||
:return: The URI in requests_unixsocket form
|
||||
"""
|
||||
old_prefix = "unix://"
|
||||
new_prefix = "http+unix://"
|
||||
|
||||
if old_uri.startswith(old_prefix):
|
||||
stripped_uri = old_uri[len(old_prefix) :]
|
||||
if stripped_uri.endswith("/"):
|
||||
stripped_uri = stripped_uri[:-1]
|
||||
new_uri = new_prefix + urllib.parse.quote(stripped_uri, safe="")
|
||||
else:
|
||||
new_uri = old_uri
|
||||
|
||||
if not new_uri.endswith("/"):
|
||||
new_uri += "/"
|
||||
|
||||
return new_uri
|
|
@ -0,0 +1,24 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
"""Define Azure IoT domain exceptions to be shared across modules"""
|
||||
|
||||
|
||||
class IoTHubError(Exception):
|
||||
"""Represents a failure reported by IoTHub"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class IoTEdgeError(Exception):
|
||||
"""Represents a failure reported by IoTEdge"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class IoTHubClientError(Exception):
|
||||
"""Represents a failure from the IoTHub Client"""
|
||||
|
||||
pass
|
|
@ -10,6 +10,7 @@ import logging
|
|||
import urllib.parse
|
||||
from typing import Optional, AsyncGenerator
|
||||
from .custom_typing import TwinPatch, Twin
|
||||
from .iot_exceptions import IoTHubError, IoTHubClientError
|
||||
from .models import Message, MethodResponse, MethodRequest
|
||||
from . import config, constant, user_agent
|
||||
from . import request_response as rr
|
||||
|
@ -31,18 +32,6 @@ DEFAULT_TOKEN_UPDATE_MARGIN = 120
|
|||
# TODO: error handling in generators
|
||||
|
||||
|
||||
class IoTHubError(Exception):
|
||||
"""Represents a failure reported by IoTHub"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class IoTHubClientError(Exception):
|
||||
"""Represents a failure from the IoTHub Client"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class IoTHubMQTTClient:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -0,0 +1,177 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
"""This module contains tools for working with Shared Access Signature (SAS) Tokens"""
|
||||
|
||||
import abc
|
||||
import time
|
||||
import urllib
|
||||
from typing import Optional, Dict
|
||||
from .signing_mechanism import SigningMechanism
|
||||
|
||||
|
||||
class SasTokenError(Exception):
|
||||
"""Error in SasToken"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SasToken(abc.ABC):
|
||||
"""Abstract parent class for SAS Tokens.
|
||||
|
||||
Doesn't do much, but helps with type hints
|
||||
"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def expiry_time(self) -> int:
|
||||
pass
|
||||
|
||||
|
||||
class RenewableSasToken(SasToken):
|
||||
"""Renewable Shared Access Signature Token used to authenticate a request.
|
||||
|
||||
This token is 'renewable', which means that it can be updated when necessary to
|
||||
prevent expiry, by using the .refresh() method.
|
||||
|
||||
Data Attributes:
|
||||
expiry_time (int): Time that token will expire (in UTC, since epoch)
|
||||
ttl (int): Time to live for the token, in seconds
|
||||
"""
|
||||
|
||||
_auth_rule_token_format = (
|
||||
"SharedAccessSignature sr={resource}&sig={signature}&se={expiry}&skn={keyname}"
|
||||
)
|
||||
_simple_token_format = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
signing_mechanism: SigningMechanism,
|
||||
key_name: Optional[str] = None,
|
||||
ttl: int = 3600,
|
||||
) -> None:
|
||||
"""
|
||||
:param str uri: URI of the resource to be accessed
|
||||
:param signing_mechanism: The signing mechanism to use in the SasToken
|
||||
:type signing_mechanism: Child classes of :class:`azure.iot.common.SigningMechanism`
|
||||
:param str key_name: Symmetric Key Name (optional)
|
||||
:param int ttl: Time to live for the token, in seconds (default 3600)
|
||||
|
||||
:raises: SasTokenError if an error occurs building a SasToken
|
||||
"""
|
||||
self._uri = uri
|
||||
self._signing_mechanism = signing_mechanism
|
||||
self._key_name = key_name
|
||||
# These two values will be set by the .refresh() call
|
||||
self._expiry_time: int
|
||||
self._token: str
|
||||
|
||||
self.ttl = ttl
|
||||
self.refresh()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._token
|
||||
|
||||
def refresh(self) -> None:
|
||||
"""
|
||||
Refresh the SasToken lifespan, giving it a new expiry time, and generating a new token.
|
||||
"""
|
||||
self._expiry_time = int(time.time() + self.ttl)
|
||||
self._token = self._build_token()
|
||||
|
||||
def _build_token(self) -> str:
|
||||
"""Build SasToken representation
|
||||
|
||||
:returns: String representation of the token
|
||||
"""
|
||||
url_encoded_uri = urllib.parse.quote(self._uri, safe="")
|
||||
message = url_encoded_uri + "\n" + str(self.expiry_time)
|
||||
try:
|
||||
signature = self._signing_mechanism.sign(message)
|
||||
except Exception as e:
|
||||
# Because of variant signing mechanisms, we don't know what error might be raised.
|
||||
# So we catch all of them.
|
||||
raise SasTokenError("Unable to build SasToken from given values") from e
|
||||
url_encoded_signature = urllib.parse.quote(signature, safe="")
|
||||
if self._key_name:
|
||||
token = self._auth_rule_token_format.format(
|
||||
resource=url_encoded_uri,
|
||||
signature=url_encoded_signature,
|
||||
expiry=str(self.expiry_time),
|
||||
keyname=self._key_name,
|
||||
)
|
||||
else:
|
||||
token = self._simple_token_format.format(
|
||||
resource=url_encoded_uri,
|
||||
signature=url_encoded_signature,
|
||||
expiry=str(self.expiry_time),
|
||||
)
|
||||
return token
|
||||
|
||||
@property
|
||||
def expiry_time(self) -> int:
|
||||
"""Expiry Time is READ ONLY"""
|
||||
return self._expiry_time
|
||||
|
||||
|
||||
class NonRenewableSasToken(SasToken):
|
||||
"""NonRenewable Shared Access Signature Token used to authenticate a request.
|
||||
|
||||
This token is 'non-renewable', which means that it is invalid once it expires, and there
|
||||
is no way to keep it alive. Instead, a new token must be created.
|
||||
|
||||
Data Attributes:
|
||||
expiry_time (int): Time that token will expire (in UTC, since epoch)
|
||||
resource_uri (str): URI for the resource the Token provides authentication to access
|
||||
"""
|
||||
|
||||
def __init__(self, sastoken_string) -> None:
|
||||
"""
|
||||
:param str sastoken_string: A string representation of a SAS token
|
||||
"""
|
||||
self._token = sastoken_string
|
||||
self._token_info = get_sastoken_info_from_string(self._token)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._token
|
||||
|
||||
@property
|
||||
def expiry_time(self) -> int:
|
||||
"""Expiry Time is READ ONLY"""
|
||||
return int(self._token_info["se"])
|
||||
|
||||
@property
|
||||
def resource_uri(self) -> str:
|
||||
"""Resource URI is READ ONLY"""
|
||||
uri = self._token_info["sr"]
|
||||
return urllib.parse.unquote(uri)
|
||||
|
||||
|
||||
REQUIRED_SASTOKEN_FIELDS = ["sr", "sig", "se"]
|
||||
VALID_SASTOKEN_FIELDS = REQUIRED_SASTOKEN_FIELDS + ["skn"]
|
||||
|
||||
|
||||
def get_sastoken_info_from_string(sastoken_string: str) -> Dict[str, str]:
|
||||
pieces = sastoken_string.split("SharedAccessSignature ")
|
||||
if len(pieces) != 2:
|
||||
raise SasTokenError("Invalid SasToken string: Not a SasToken ")
|
||||
|
||||
# Get sastoken info as dictionary
|
||||
try:
|
||||
# TODO: fix this typehint later, it needs some kind of cast
|
||||
sastoken_info = dict(map(str.strip, sub.split("=", 1)) for sub in pieces[1].split("&")) # type: ignore
|
||||
except Exception as e:
|
||||
raise SasTokenError("Invalid SasToken string: Incorrectly formatted") from e
|
||||
|
||||
# Validate that all required fields are present
|
||||
if not all(key in sastoken_info for key in REQUIRED_SASTOKEN_FIELDS):
|
||||
raise SasTokenError("Invalid SasToken string: Not all required fields present")
|
||||
|
||||
# Validate that no unexpected fields are present
|
||||
if not all(key in VALID_SASTOKEN_FIELDS for key in sastoken_info):
|
||||
raise SasTokenError("Invalid SasToken string: Unexpected fields present")
|
||||
|
||||
return sastoken_info
|
|
@ -0,0 +1,70 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import abc
|
||||
import base64
|
||||
import binascii
|
||||
import hmac
|
||||
import hashlib
|
||||
from typing import Union
|
||||
|
||||
|
||||
class SigningMechanism(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def sign(self, data_str: Union[str, bytes]) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class SymmetricKeySigningMechanism(SigningMechanism):
|
||||
def __init__(self, key: Union[str, bytes]) -> None:
|
||||
"""
|
||||
A mechanism that signs data using a symmetric key
|
||||
|
||||
:param key: Symmetric Key (base64 encoded)
|
||||
:type key: str or bytes
|
||||
|
||||
:raises: ValueError if provided key is invalid
|
||||
"""
|
||||
# Convert key to bytes (if not already)
|
||||
if isinstance(key, str):
|
||||
key_bytes = key.encode("utf-8")
|
||||
else:
|
||||
key_bytes = key
|
||||
|
||||
# Derives the signing key
|
||||
# CT-TODO: is "signing key" the right term?
|
||||
try:
|
||||
self._signing_key = base64.b64decode(key_bytes)
|
||||
except (binascii.Error):
|
||||
raise ValueError("Invalid Symmetric Key")
|
||||
|
||||
def sign(self, data_str: Union[str, bytes]) -> str:
|
||||
"""
|
||||
Sign a data string with symmetric key and the HMAC-SHA256 algorithm.
|
||||
|
||||
:param data_str: Data string to be signed
|
||||
:type data_str: str or bytes
|
||||
|
||||
:returns: The signed data
|
||||
:rtype: str
|
||||
|
||||
:raises: ValueError if an invalid data string is provided
|
||||
"""
|
||||
# Convert data_str to bytes (if not already)
|
||||
if isinstance(data_str, str):
|
||||
data_bytes = data_str.encode("utf-8")
|
||||
else:
|
||||
data_bytes = data_str
|
||||
|
||||
# Derive signature via HMAC-SHA256 algorithm
|
||||
try:
|
||||
hmac_digest = hmac.HMAC(
|
||||
key=self._signing_key, msg=data_bytes, digestmod=hashlib.sha256
|
||||
).digest()
|
||||
signed_data = base64.b64encode(hmac_digest)
|
||||
except (TypeError):
|
||||
raise ValueError("Unable to sign string using the provided symmetric key")
|
||||
# Convert from bytes to string
|
||||
return signed_data.decode("utf-8")
|
Загрузка…
Ссылка в новой задаче