зеркало из
1
0
Форкнуть 0

Add pre-commit hooks for black and flake8 (#23)

This commit is contained in:
Pierre Cauchois 2018-11-21 15:23:14 -08:00 коммит произвёл GitHub
Родитель 899ad25c83
Коммит ec6f217aa8
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
33 изменённых файлов: 209 добавлений и 170 удалений

10
.flake8
Просмотреть файл

@ -1,2 +1,10 @@
[flake8]
max-line-length = 130
# E501: line length (handled by black)
# W503, E203: Not PEP8 compliant (incompatible with black)
# F401, F403: imports in Provisioning SDK
# Ignore generated code
ignore = E501,W503,E203,F401,F403
exclude =
.git,
__pycache__,
./azure-iot-provisioning-servicesdk/azure/iot/provisioning/servicesdk/protocol/models

11
.pre-commit-config.yaml Normal file
Просмотреть файл

@ -0,0 +1,11 @@
repos:
- repo: https://github.com/ambv/black
rev: 18.9b0
hooks:
- id: black
language_version: python3
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.0.0 # Use the ref you want to point at
hooks:
- id: flake8
args: ['--config=.flake8']

Просмотреть файл

@ -97,4 +97,4 @@ def test_get_item_does_not_exist_no_given_default():
cs = ConnectionString(
"HostName=my.host.name;SharedAccessKeyName=mykeyname;SharedAccessKey=Zm9vYmFy"
)
assert cs.get("invalidkey") == None
assert cs.get("invalidkey") is None

Просмотреть файл

@ -12,6 +12,7 @@ class AuthenticationProvider(object):
Super class for all providing known types of authentication mechanism like
x509 and SAS based authentication.
"""
def __init__(self, hostname, device_id, module_id=None):
self.hostname = hostname
self.device_id = device_id
@ -28,4 +29,3 @@ class AuthenticationProvider(object):
:param:source The source in string. This could be connections string or a shared access signature string.
"""
pass

Просмотреть файл

@ -25,12 +25,13 @@ def from_shared_access_signature(sas_token_str):
"""
return SharedAccessSignatureAuthenticationProvider.parse(sas_token_str)
def from_environment():
"""
Provides an `AuthenticationProvider` object that can be used inside of an Azure IoT Edge module.
This method does not need any parameters because all of the information necessary to connect
to Azure IoT Edge comes from the operating system of the module container and also from the
to Azure IoT Edge comes from the operating system of the module container and also from the
IoTEdge service.
:return: iotedge AuthenticationProvider

Просмотреть файл

@ -5,8 +5,6 @@ import time
import abc
import logging
import math
import weakref
from threading import Timer
import six.moves.urllib as urllib
from .authentication_provider import AuthenticationProvider
@ -80,9 +78,7 @@ class BaseRenewableTokenAuthenticationProvider(AuthenticationProvider):
quoted_resource_uri, signature, str(expiry), self.shared_access_key_name
)
else:
token = _device_token_format.format(
quoted_resource_uri, signature, str(expiry)
)
token = _device_token_format.format(quoted_resource_uri, signature, str(expiry))
self.sas_token_str = str(token)
@ -98,7 +94,6 @@ class BaseRenewableTokenAuthenticationProvider(AuthenticationProvider):
self.generate_new_sas_token()
return self.sas_token_str
@abc.abstractmethod
def _sign(self, quoted_resource_uri, expiry):
"""

Просмотреть файл

@ -4,7 +4,6 @@
# --------------------------------------------------------------------------------------------
import os
import logging
import six.moves.urllib as urllib
from .base_renewable_token_authentication_provider import BaseRenewableTokenAuthenticationProvider
from .iotedge_hsm import IotEdgeHsm
@ -24,13 +23,9 @@ class IotEdgeAuthenticationProvider(BaseRenewableTokenAuthenticationProvider):
device_id = os.environ["IOTEDGE_DEVICEID"]
module_id = os.environ["IOTEDGE_MODULEID"]
logger.info(
"Using IoTEdge authentication for {%s, %s, %s}", hostname, device_id, module_id
)
logger.info("Using IoTEdge authentication for {%s, %s, %s}", hostname, device_id, module_id)
BaseRenewableTokenAuthenticationProvider.__init__(
self, hostname, device_id, module_id
)
BaseRenewableTokenAuthenticationProvider.__init__(self, hostname, device_id, module_id)
self.hsm = IotEdgeHsm()
self.gateway_hostname = os.environ["IOTEDGE_GATEWAYHOSTNAME"]

Просмотреть файл

@ -10,30 +10,31 @@ requests_unixsocket.monkeypatch()
class IotEdgeHsm(object):
"""
Constructor for instantiating a iot hsm object. This is an object that
communicates with the Azure IoT Edge HSM in order to get connection credentials
for an Azure IoT Edge module. The credentials that this object return come in
Constructor for instantiating a iot hsm object. This is an object that
communicates with the Azure IoT Edge HSM in order to get connection credentials
for an Azure IoT Edge module. The credentials that this object return come in
two forms:
1. The trust bundle, which is a certificate that can be used as a trusted cert
to authenticate the SSL connection between the IoE Edge module and IoT Edge
2. A signing function, which can be used to create the sig field for a
2. A signing function, which can be used to create the sig field for a
SharedAccessSignature string which can be used to authenticate with Iot Edge
Instantiating this object does not require any parameters. All necessary parameters
come from environment variables that are set inside the IoT Edge module container
come from environment variables that are set inside the IoT Edge module container
by the edgeAgent that creates the module.
"""
@staticmethod
def _fix_socket_uri(old_uri):
"""
This function takes a socket URI in one form and converts it into another form.
The source form is based on what we receive inside the IOTEDGE_WORKLOADURI
environment variable, and it looks like this:
The source form is based on what we receive inside the IOTEDGE_WORKLOADURI
environment variable, and it looks like this:
"unix:///var/run/iotedge/workload.sock"
The destination form is based on what the requests_unixsocket library expects
The destination form is based on what the requests_unixsocket library expects
and it looks like this:
"http+unix://%2Fvar%2Frun%2Fiotedge%2Fworkload.sock/"
@ -90,12 +91,12 @@ class IotEdgeHsm(object):
def sign(self, data):
"""
Use the IoTEdge HSM to sign a piece of data. The caller should then insert the
Use the IoTEdge HSM to sign a piece of data. The caller should then insert the
returned value (the signature) into the 'sig' field of a SharedAccessSignature string.
:param data: The string to sign
:return: The signature, as a URI-encoded and base64-encoded value that is ready to
:return: The signature, as a URI-encoded and base64-encoded value that is ready to
directly insert into the SharedAccessSignature string.
"""
path = (
@ -109,13 +110,11 @@ class IotEdgeHsm(object):
sign_request = {
"keyId": "primary",
"algo": "HMACSHA256",
"data": base64.b64encode(data.encode('utf-8')).decode(),
"data": base64.b64encode(data.encode("utf-8")).decode(),
}
r = requests.post(
path,
params={"api-version": self.api_version},
data=json.dumps(sign_request),
path, params={"api-version": self.api_version}, data=json.dumps(sign_request)
)
r.raise_for_status()
return urllib.parse.quote(r.json()["digest"])

Просмотреть файл

@ -4,6 +4,7 @@
# --------------------------------------------------------------------------------------------
import logging
from .authentication_provider import AuthenticationProvider
"""
The urllib, urllib2, and urlparse modules from Python 2 have been combined in the urllib package in Python 3
The six.moves.urllib package is a python version-independent location of the above functionality.
@ -22,12 +23,7 @@ SHARED_ACCESS_KEY_NAME = "skn"
RESOURCE_URI = "sr"
EXPIRY = "se"
_valid_keys = [
SIGNATURE,
SHARED_ACCESS_KEY_NAME,
RESOURCE_URI,
EXPIRY
]
_valid_keys = [SIGNATURE, SHARED_ACCESS_KEY_NAME, RESOURCE_URI, EXPIRY]
class SharedAccessSignatureAuthenticationProvider(AuthenticationProvider):
@ -35,13 +31,12 @@ class SharedAccessSignatureAuthenticationProvider(AuthenticationProvider):
The Shared Access Signature Authentication Provider.
This provider already contains the sas token which will be needed to authenticate with The IoT hub.
"""
def __init__(self, hostname, device_id, module_id, sas_token_str):
"""
Constructor for Shared Access Signature Authentication Provider
"""
logger.info(
"Using SAS authentication for {%s, %s, %s}", hostname, device_id, module_id
)
logger.info("Using SAS authentication for {%s, %s, %s}", hostname, device_id, module_id)
AuthenticationProvider.__init__(self, hostname, device_id, module_id)
self.sas_token_str = sas_token_str
@ -65,7 +60,8 @@ class SharedAccessSignatureAuthenticationProvider(AuthenticationProvider):
parts = sas_token_str.split(PARTS_SEPARATOR)
if len(parts) != 2:
raise ValueError(
"The Shared Access Signature must be of the format 'SharedAccessSignature sr=<resource_uri>&sig=<signature>&se=<expiry>' or/and it can additionally contain an optional skn=<keyname> name=value pair.")
"The Shared Access Signature must be of the format 'SharedAccessSignature sr=<resource_uri>&sig=<signature>&se=<expiry>' or/and it can additionally contain an optional skn=<keyname> name=value pair."
)
sas_args = parts[1].split(DELIMITER)
d = dict(arg.split(VALUE_SEPARATOR, 1) for arg in sas_args)
@ -73,7 +69,8 @@ class SharedAccessSignatureAuthenticationProvider(AuthenticationProvider):
raise ValueError("Invalid Shared Access Signature - Unable to parse")
if not all(key in _valid_keys for key in d.keys()):
raise ValueError(
"Invalid keys in Shared Access Signature. The valid keys are sr, sig, se and an optional skn.")
"Invalid keys in Shared Access Signature. The valid keys are sr, sig, se and an optional skn."
)
_validate_required_keys(d)
@ -87,7 +84,9 @@ class SharedAccessSignatureAuthenticationProvider(AuthenticationProvider):
if len(url_segments) > 4:
module_id = url_segments[4]
return SharedAccessSignatureAuthenticationProvider(hostname, device_id, module_id, sas_token_str)
return SharedAccessSignatureAuthenticationProvider(
hostname, device_id, module_id, sas_token_str
)
def _validate_required_keys(d):
@ -102,6 +101,6 @@ def _validate_required_keys(d):
if resource_uri and signature and expiry:
pass
else:
raise ValueError("Invalid Shared Access Signature. It must be of the format 'SharedAccessSignature sr=<resource_uri>&sig=<signature>&se=<expiry>' or/and it can additionally contain an optional skn=<keyname> name=value pair.")
raise ValueError(
"Invalid Shared Access Signature. It must be of the format 'SharedAccessSignature sr=<resource_uri>&sig=<signature>&se=<expiry>' or/and it can additionally contain an optional skn=<keyname> name=value pair."
)

Просмотреть файл

@ -5,7 +5,6 @@
import base64
import hmac
import hashlib
import time
import logging
import six.moves.urllib as urllib
from .base_renewable_token_authentication_provider import BaseRenewableTokenAuthenticationProvider
@ -47,7 +46,7 @@ class SymmetricKeyAuthenticationProvider(BaseRenewableTokenAuthenticationProvide
module_id,
shared_access_key,
shared_access_key_name=None,
gateway_hostname=None
gateway_hostname=None,
):
"""
@ -57,9 +56,7 @@ class SymmetricKeyAuthenticationProvider(BaseRenewableTokenAuthenticationProvide
"Using Shared Key authentication for {%s, %s, %s}", hostname, device_id, module_id
)
BaseRenewableTokenAuthenticationProvider.__init__(
self, hostname, device_id, module_id
)
BaseRenewableTokenAuthenticationProvider.__init__(self, hostname, device_id, module_id)
self.shared_access_key = shared_access_key
self.shared_access_key_name = shared_access_key_name
self.gateway_hostname = gateway_hostname
@ -108,9 +105,7 @@ class SymmetricKeyAuthenticationProvider(BaseRenewableTokenAuthenticationProvide
signed_hmac = hmac.HMAC(signing_key, message, hashlib.sha256)
signature = urllib.parse.quote(base64.b64encode(signed_hmac.digest()))
except (TypeError, base64.binascii.Error) as e:
raise TypeError(
"Unable to build shared access signature from given values", e
)
raise TypeError("Unable to build shared access signature from given values", e)
return signature

Просмотреть файл

@ -4,7 +4,6 @@
# --------------------------------------------------------------------------------------------
import logging
import types
from .transport.mqtt.mqtt_transport import MQTTTransport
logger = logging.getLogger(__name__)

Просмотреть файл

@ -8,4 +8,3 @@ from .internal_client import InternalClient
class ModuleClient(InternalClient):
pass

Просмотреть файл

@ -5,7 +5,6 @@
import paho.mqtt.client as mqtt
import logging
import types
import ssl
logger = logging.getLogger(__name__)
@ -23,7 +22,7 @@ class MQTTProvider(object):
:param client_id: The id of the client connecting to the broker.
:param hostname: hostname or IP address of the remote broker.
:param password: The password to authenticate with.
:param ca_cert: Certificate which can be used to validate a server-side TLS connection.
:param ca_cert: Certificate which can be used to validate a server-side TLS connection.
"""
self._client_id = client_id
self._hostname = hostname

Просмотреть файл

@ -74,7 +74,7 @@ class MQTTTransport(AbstractTransport):
"source": "sending",
"dest": None,
"before": "_before_action_notify_publish_complete",
"after": "_trig_check_send_event_queue"
"after": "_trig_check_send_event_queue",
},
{
"trigger": "_trig_send_event",
@ -91,13 +91,13 @@ class MQTTTransport(AbstractTransport):
},
{
"trigger": "_trig_check_send_event_queue",
"source": [ "connected", "sending" ],
"source": ["connected", "sending"],
"dest": "connected",
"conditions": "_queue_is_empty",
},
{
"trigger": "_trig_check_send_event_queue",
"source": [ "connected", "sending" ],
"source": ["connected", "sending"],
"dest": "sending",
"unless": "_queue_is_empty",
"after": "_after_action_deliver_next_queued_event",
@ -147,7 +147,7 @@ class MQTTTransport(AbstractTransport):
def _before_action_notify_publish_complete(self, event):
logger.info("publish complete:" + str(event))
logger.info("publish error:" + str(event.error));
logger.info("publish error:" + str(event.error))
if not event.error:
self.on_event_sent()
@ -160,7 +160,7 @@ class MQTTTransport(AbstractTransport):
def _after_action_provider_disconnect(self, event):
"""
Call into the provider to disconnect the transport.
Call into the provider to disconnect the transport.
This is meant to be called by the state machine as an "after" action
"""
self._mqtt_provider.disconnect()

Просмотреть файл

@ -5,7 +5,9 @@
import os
import logging
from azure.iot.hub.devicesdk.device_client import DeviceClient
from azure.iot.hub.devicesdk.auth.authentication_provider_factory import from_shared_access_signature
from azure.iot.hub.devicesdk.auth.authentication_provider_factory import (
from_shared_access_signature,
)
logging.basicConfig(level=logging.INFO)

Просмотреть файл

@ -10,7 +10,7 @@ with open("doc/package-readme.md", "r") as fh:
setup(
name="azure_iot_hub_devicesdk",
version="0.0.0a1", # Alpha Release
version="0.0.0a1", # Alpha Release
description="Microsoft Azure IoT Hub Device SDK",
license="MIT License",
url="https://github.com/Azure/azure-iot-sdk-python",

Просмотреть файл

@ -13,4 +13,3 @@ def test_raises_exception_on_init_of_abstract_transport():
msg = str(error.value)
expected_msg = "Can't instantiate abstract class AbstractTransport with abstract methods connect, disconnect, send_event"
assert msg == expected_msg

Просмотреть файл

@ -3,19 +3,23 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from azure.iot.hub.devicesdk.auth.authentication_provider_factory import from_connection_string, from_shared_access_signature
from azure.iot.hub.devicesdk.auth.sk_authentication_provider import SymmetricKeyAuthenticationProvider
from azure.iot.hub.devicesdk.auth.sas_authentication_provider import SharedAccessSignatureAuthenticationProvider
import pytest
from azure.iot.hub.devicesdk.auth.authentication_provider_factory import (
from_connection_string,
from_shared_access_signature,
)
from azure.iot.hub.devicesdk.auth.sk_authentication_provider import (
SymmetricKeyAuthenticationProvider,
)
from azure.iot.hub.devicesdk.auth.sas_authentication_provider import (
SharedAccessSignatureAuthenticationProvider,
)
connection_string_device_sk_format = "HostName={};DeviceId={};SharedAccessKey={}"
connection_string_device_skn_format = (
"HostName={};DeviceId={};SharedAccessKeyName={};SharedAccessKey={}"
)
connection_string_module_sk_format = (
"HostName={};DeviceId={};ModuleId={};SharedAccessKey={}"
)
connection_string_module_sk_format = "HostName={};DeviceId={};ModuleId={};SharedAccessKey={}"
connection_string_module_gateway_sk_format = (
"HostName={};DeviceId={};ModuleId={};SharedAccessKey={};GatewayHostName={}"
)
@ -95,4 +99,3 @@ def create_sas_token_string(is_module=False, is_key_name=False):
return sas_device_skn_token_format.format(uri, signature, expiry, shared_access_key_name)
else:
return sas_device_token_format.format(uri, signature, expiry)

Просмотреть файл

@ -11,10 +11,9 @@ from azure.iot.hub.devicesdk.transport.abstract_transport import AbstractTranspo
import pytest
from six import add_move, MovedModule
from mock import MagicMock
add_move(MovedModule("mock", "mock", "unittest.mock"))
from six.moves import mock
from mock import MagicMock
connection_string_format = "HostName={};DeviceId={};SharedAccessKey={}"
@ -34,12 +33,15 @@ def authentication_provider(connection_string):
auth_provider = from_connection_string(connection_string)
return auth_provider
@pytest.fixture(scope="function")
def mock_transport():
return MagicMock(spec=AbstractTransport)
def test_internal_client_connect_in_turn_calls_transport_connect(authentication_provider, mock_transport):
def test_internal_client_connect_in_turn_calls_transport_connect(
authentication_provider, mock_transport
):
client = InternalClient(authentication_provider, mock_transport)
client.connect()
@ -47,7 +49,9 @@ def test_internal_client_connect_in_turn_calls_transport_connect(authentication_
mock_transport.connect.assert_called_once_with()
def test_connected_state_handler_called_wth_new_state_once_transport_gets_connected(mocker, authentication_provider, mock_transport):
def test_connected_state_handler_called_wth_new_state_once_transport_gets_connected(
mocker, authentication_provider, mock_transport
):
client = InternalClient(authentication_provider, mock_transport)
stub_on_connection_state = mocker.stub(name="on_connection_state")
client.on_connection_state = stub_on_connection_state
@ -58,7 +62,10 @@ def test_connected_state_handler_called_wth_new_state_once_transport_gets_connec
assert client.state == "connected"
stub_on_connection_state.assert_called_once_with("connected")
def test_connected_state_handler_called_wth_new_state_once_transport_gets_connected(mocker, authentication_provider, mock_transport):
def test_connected_state_handler_called_wth_new_state_once_transport_gets_disconnected(
mocker, authentication_provider, mock_transport
):
client = InternalClient(authentication_provider, mock_transport)
stub_on_connection_state = mocker.stub(name="on_connection_state")
client.on_connection_state = stub_on_connection_state
@ -74,7 +81,10 @@ def test_connected_state_handler_called_wth_new_state_once_transport_gets_connec
assert client.state == "disconnected"
stub_on_connection_state.assert_called_once_with("disconnected")
def test_internal_client_send_event_in_turn_calls_transport_send_event(authentication_provider, mock_transport):
def test_internal_client_send_event_in_turn_calls_transport_send_event(
authentication_provider, mock_transport
):
event = "Levicorpus"
client = InternalClient(authentication_provider, mock_transport)
@ -98,12 +108,14 @@ def test_transport_any_error_surfaces_to_internal_client(authentication_provider
mock_transport.send_event.assert_called_once_with(event)
@pytest.mark.parametrize("kind_of_client, auth", [
("Module", authentication_provider),
("Device", authentication_provider),
])
@pytest.mark.parametrize(
"kind_of_client, auth",
[("Module", authentication_provider), ("Device", authentication_provider)],
)
def test_client_gets_created_correctly(mocker, kind_of_client, auth, mock_transport):
mock_constructor_transport = mocker.patch("azure.iot.hub.devicesdk.internal_client.MQTTTransport")
mock_constructor_transport = mocker.patch(
"azure.iot.hub.devicesdk.internal_client.MQTTTransport"
)
mock_constructor_transport.return_value = mock_transport
if kind_of_client == "Module":
@ -121,12 +133,14 @@ def test_client_gets_created_correctly(mocker, kind_of_client, auth, mock_transp
assert isinstance(device_client, DeviceClient)
@pytest.mark.parametrize("kind_of_client, auth", [
("Module", authentication_provider),
("Device", authentication_provider),
])
@pytest.mark.parametrize(
"kind_of_client, auth",
[("Module", authentication_provider), ("Device", authentication_provider)],
)
def test_raises_on_creation_of_client_when_transport_is_incorrect(kind_of_client, auth):
with pytest.raises(NotImplementedError, match="No specific transport can be instantiated based on the choice."):
with pytest.raises(
NotImplementedError, match="No specific transport can be instantiated based on the choice."
):
if kind_of_client == "Module":
ModuleClient.from_authentication_provider(authentication_provider, "floo")
else:

Просмотреть файл

@ -3,15 +3,14 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import os
import pytest
from azure.iot.hub.devicesdk.auth.iotedge_authentication_provider import IotEdgeAuthenticationProvider
from azure.iot.hub.devicesdk.auth.iotedge_authentication_provider import (
IotEdgeAuthenticationProvider,
)
from six import add_move, MovedModule
from mock import patch
add_move(MovedModule("mock", "mock", "unittest.mock"))
from six.moves import mock
from mock import Mock
from mock import patch
fake_ca_cert = "__FAKE_CA_CERTIFICATE__"
fake_module_id = "__FAKE_MODULE_ID__"
@ -24,25 +23,27 @@ required_environment_variables = {
"IOTEDGE_MODULEID": fake_module_id,
"IOTEDGE_DEVICEID": fake_device_id,
"IOTEDGE_IOTHUBHOSTNAME": fake_hostname,
"IOTEDGE_GATEWAYHOSTNAME": fake_gateway_hostname
"IOTEDGE_GATEWAYHOSTNAME": fake_gateway_hostname,
}
@patch.dict(os.environ, required_environment_variables)
@patch("azure.iot.hub.devicesdk.auth.iotedge_authentication_provider.IotEdgeHsm")
def test_initializer_gets_details_from_environment(mock_hsm):
auth_provider = IotEdgeAuthenticationProvider()
assert(auth_provider.gateway_hostname == fake_gateway_hostname)
assert(auth_provider.device_id == fake_device_id)
assert(auth_provider.module_id == fake_module_id)
assert(auth_provider.hostname == fake_hostname)
assert auth_provider.gateway_hostname == fake_gateway_hostname
assert auth_provider.device_id == fake_device_id
assert auth_provider.module_id == fake_module_id
assert auth_provider.hostname == fake_hostname
@patch.dict(os.environ, required_environment_variables)
@patch("azure.iot.hub.devicesdk.auth.iotedge_authentication_provider.IotEdgeHsm")
def test_initializer_gets_ca_certificate_from_hsm(MockHsm):
MockHsm.return_value.get_trust_bundle.return_value = fake_ca_cert
auth_provider = IotEdgeAuthenticationProvider()
assert(auth_provider.ca_cert == fake_ca_cert)
assert auth_provider.ca_cert == fake_ca_cert
@patch.dict(os.environ, required_environment_variables)
@patch("azure.iot.hub.devicesdk.auth.iotedge_authentication_provider.IotEdgeHsm")
@ -50,8 +51,11 @@ def test_get_shared_access_key_uses_hsm_to_sign(MockHsm):
MockHsm.return_value.sign.return_value = fake_digest
auth_provider = IotEdgeAuthenticationProvider()
sas_token = auth_provider.get_current_sas_token()
assert(MockHsm.return_value.sign.call_args[0][0].startswith("{}%2Fdevices%2F{}%2Fmodules%2F{}\n".format(fake_hostname, fake_device_id, fake_module_id)))
assert(sas_token.startswith("SharedAccessSignature sr={}%2Fdevices%2F{}%2Fmodules%2F{}&sig={}&se=".format(fake_hostname, fake_device_id, fake_module_id, fake_digest)))
assert MockHsm.return_value.sign.call_args[0][0].startswith(
"{}%2Fdevices%2F{}%2Fmodules%2F{}\n".format(fake_hostname, fake_device_id, fake_module_id)
)
assert sas_token.startswith(
"SharedAccessSignature sr={}%2Fdevices%2F{}%2Fmodules%2F{}&sig={}&se=".format(
fake_hostname, fake_device_id, fake_module_id, fake_digest
)
)

Просмотреть файл

@ -11,11 +11,10 @@ import json
import base64
from six import add_move, MovedModule
from six.moves import mock
from mock import patch
add_move(MovedModule("mock", "mock", "unittest.mock"))
from six.moves import mock
from mock import MagicMock
from mock import patch
fake_module_id = "__FAKE_MODULE__ID__"
fake_api_version = "__FAKE_API_VERSION__"
@ -35,7 +34,7 @@ required_environment_variables = {
@patch.dict(os.environ, required_environment_variables)
def test_initializer_doesnt_throw_when_all_environment_variables_are_present():
hsm = IotEdgeHsm()
IotEdgeHsm()
def test_initializer_throws_with_missing_environment_variables():
@ -44,7 +43,7 @@ def test_initializer_throws_with_missing_environment_variables():
del env[key]
with patch.dict(os.environ, env):
with pytest.raises(KeyError, match=key):
hsm = IotEdgeHsm()
IotEdgeHsm()
@patch.object(requests, "get")
@ -60,16 +59,15 @@ def test_get_trust_bundle_returns_certificate(mock_get):
assert cert == fake_certificate
mock_response.raise_for_status.assert_called_once_with() # this verifies that a failed status code will throw
mock_get.assert_called_once_with(
fake_http_workload_uri + "trust-bundle",
params={"api-version": fake_api_version},
fake_http_workload_uri + "trust-bundle", params={"api-version": fake_api_version}
)
@patch.object(requests, "post")
@patch.dict(os.environ, required_environment_variables)
def test_get_trust_bundle_returns_certificate(mock_post):
def test_sign_sends_post_with_proper_url_and_data(mock_post):
mock_response = mock.Mock(spec=requests.Response)
mock_response.json.return_value = {"digest": fake_digest }
mock_response.json.return_value = {"digest": fake_digest}
mock_post.return_value = mock_response
hsm = IotEdgeHsm()
@ -113,7 +111,7 @@ def test_workload_uri_values_get_adjusted_correctly(mock_get):
env["IOTEDGE_WORKLOADURI"] = original_uri
with patch.dict(os.environ, env):
hsm = IotEdgeHsm()
cert = hsm.get_trust_bundle()
hsm.get_trust_bundle()
mock_get.assert_called_once_with(
adjusted_uri + "trust-bundle", params={"api-version": fake_api_version}

Просмотреть файл

@ -5,16 +5,13 @@
from azure.iot.hub.devicesdk.transport.mqtt.mqtt_provider import MQTTProvider
import paho.mqtt.client as mqtt
import os
import ssl
import pytest
from six import add_move, MovedModule
add_move(MovedModule("mock", "mock", "unittest.mock"))
from six.moves import mock
from mock import MagicMock
from mock import patch
add_move(MovedModule("mock", "mock", "unittest.mock"))
fake_hostname = "beauxbatons.academy-net"
fake_device_id = "MyFirebolt"
@ -32,12 +29,11 @@ def test_connect_triggers_client_connect(MockMqttClient, MockSsl):
mock_mqtt_client = MockMqttClient.return_value
MockSsl.assert_called_once_with(ssl.PROTOCOL_TLSv1_2)
mock_ssl = MockSsl.return_value
assert(mock_mqtt_client.tls_set_context.call_count == 1)
assert mock_mqtt_client.tls_set_context.call_count == 1
context = mock_mqtt_client.tls_set_context.call_args[0][0]
assert(context.check_hostname == True)
assert(context.verify_mode == ssl.CERT_REQUIRED)
assert context.check_hostname is True
assert context.verify_mode == ssl.CERT_REQUIRED
context.load_default_certs.assert_called_once_with()
mock_mqtt_client.tls_insecure_set.assert_called_once_with(False)
mock_mqtt_client.connect.assert_called_once_with(host=fake_hostname, port=8883)
@ -50,13 +46,22 @@ def test_connect_triggers_client_connect(MockMqttClient, MockSsl):
@patch.object(mqtt, "Client")
@pytest.mark.parametrize("client_callback_name, client_callback_args, provider_callback_name, provider_callback_args", [
("on_connect", [None, None, None, 0], "on_mqtt_connected", ["connected"]),
("on_disconnect", [None, None, 0], "on_mqtt_disconnected", ["disconnected"]),
("on_publish", [None, None, 0], "on_mqtt_published", []),
("on_subscribe", [None, None, 0], "on_mqtt_subscribed", [])
])
def test_mqtt_client_callback_triggers_provider_callback(MockMqttClient, client_callback_name, client_callback_args, provider_callback_name, provider_callback_args):
@pytest.mark.parametrize(
"client_callback_name, client_callback_args, provider_callback_name, provider_callback_args",
[
("on_connect", [None, None, None, 0], "on_mqtt_connected", ["connected"]),
("on_disconnect", [None, None, 0], "on_mqtt_disconnected", ["disconnected"]),
("on_publish", [None, None, 0], "on_mqtt_published", []),
("on_subscribe", [None, None, 0], "on_mqtt_subscribed", []),
],
)
def test_mqtt_client_callback_triggers_provider_callback(
MockMqttClient,
client_callback_name,
client_callback_args,
provider_callback_name,
provider_callback_args,
):
mock_mqtt_client = MockMqttClient.return_value
mqtt_provider = MQTTProvider(fake_device_id, fake_hostname, fake_username, fake_password)

Просмотреть файл

@ -6,14 +6,13 @@
import pytest
import logging
from azure.iot.hub.devicesdk.transport.mqtt.mqtt_transport import MQTTTransport
from azure.iot.hub.devicesdk.transport.mqtt.mqtt_provider import MQTTProvider
from azure.iot.hub.devicesdk.auth.authentication_provider_factory import from_connection_string
from six import add_move, MovedModule
add_move(MovedModule("mock", "mock", "unittest.mock"))
from six.moves import mock
from mock import MagicMock
add_move(MovedModule("mock", "mock", "unittest.mock"))
logging.basicConfig(level=logging.INFO)
connection_string_format = "HostName={};DeviceId={};SharedAccessKey={}"
@ -27,7 +26,9 @@ fake_topic = "devices/" + fake_device_id + "/messages/events/"
@pytest.fixture(scope="function")
def authentication_provider():
connection_string = connection_string_format.format(fake_hostname, fake_device_id, fake_shared_access_key)
connection_string = connection_string_format.format(
fake_hostname, fake_device_id, fake_shared_access_key
)
auth_provider = from_connection_string(connection_string)
return auth_provider
@ -48,13 +49,15 @@ def test_instantiation_creates_proper_transport(authentication_provider):
assert trans._mqtt_provider is not None
class TestConnect():
class TestConnect:
def test_connect_calls_connect_on_provider(self, transport):
mock_mqtt_provider = transport._mqtt_provider
transport.connect()
mock_mqtt_provider.connect.assert_called_once_with()
def test_connected_state_handler_called_wth_new_state_once_provider_gets_connected(self, transport):
def test_connected_state_handler_called_wth_new_state_once_provider_gets_connected(
self, transport
):
mock_mqtt_provider = transport._mqtt_provider
transport.connect()
@ -93,7 +96,7 @@ class TestConnect():
transport.on_transport_connected.assert_not_called()
class TestSendEvent():
class TestSendEvent:
def test_sendevent_calls_publish_on_provider(self, transport):
mock_mqtt_provider = transport._mqtt_provider
@ -109,11 +112,11 @@ class TestSendEvent():
# send an event
transport.send_event(fake_event)
# verify that we called connect
mock_mqtt_provider.connect.assert_called_once_with()
# verify that we're not connected yet and verify that we havent't published yet
# verify that we're not connected yet and verify that we havent't published yet
transport.on_transport_connected.assert_not_called()
mock_mqtt_provider.publish.assert_not_called()
@ -134,7 +137,7 @@ class TestSendEvent():
# send an event
transport.send_event(fake_event)
# verify that we're not connected yet and verify that we havent't published yet
# verify that we're not connected yet and verify that we havent't published yet
transport.on_transport_connected.assert_not_called()
mock_mqtt_provider.publish.assert_not_called()
@ -182,7 +185,7 @@ class TestSendEvent():
# assert
transport.on_event_sent.assert_called_once_with()
def test_connect_send_disconnect(self, transport):
mock_mqtt_provider = transport._mqtt_provider
@ -198,7 +201,8 @@ class TestSendEvent():
transport.disconnect()
mock_mqtt_provider.disconnect.assert_called_once_with()
class TestDisconnect():
class TestDisconnect:
def test_disconnect_calls_disconnect_on_provider(self, transport):
mock_mqtt_provider = transport._mqtt_provider
@ -216,8 +220,6 @@ class TestDisconnect():
mock_mqtt_provider.disconnect.assert_not_called()
def test_disconnect_calls_client_disconnect_callback(self, transport):
mock_mqtt_provider = transport._mqtt_provider
transport.connect()
transport._trig_provider_connect_complete()
@ -225,5 +227,3 @@ class TestDisconnect():
transport._trig_provider_disconnect_complete()
transport.on_transport_disconnected.assert_called_once_with("disconnected")

Просмотреть файл

@ -3,7 +3,9 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import pytest
from azure.iot.hub.devicesdk.auth.sas_authentication_provider import SharedAccessSignatureAuthenticationProvider
from azure.iot.hub.devicesdk.auth.sas_authentication_provider import (
SharedAccessSignatureAuthenticationProvider,
)
sas_device_token_format = "SharedAccessSignature sr={}&sig={}&se={}"
@ -71,7 +73,10 @@ def test_sas_auth_provider_is_created_from_device_sas_token_string_quoted():
def test_raises_auth_provider_created_from_missing_part_shared_access_signature_string():
with pytest.raises(ValueError, match="The Shared Access Signature must be of the format 'SharedAccessSignature sr=<resource_uri>&sig=<signature>&se=<expiry>' or/and it can additionally contain an optional skn=<keyname> name=value pair."):
with pytest.raises(
ValueError,
match="The Shared Access Signature must be of the format 'SharedAccessSignature sr=<resource_uri>&sig=<signature>&se=<expiry>' or/and it can additionally contain an optional skn=<keyname> name=value pair.",
):
one_part_sas_str = "sr=beauxbatons.academy-net%2Fdevices%2FMyPensieve&sig=IsolemnlySwearThatIamuUptoNogood&se=1539043658&skn=alohomora"
SharedAccessSignatureAuthenticationProvider.parse(one_part_sas_str)
@ -83,14 +88,18 @@ def test_raises_auth_provider_created_from_shared_access_signature_string_duplic
def test_raises_auth_provider_created_from_shared_access_signature_string_bad_keys():
with pytest.raises(ValueError, match="Invalid keys in Shared Access Signature. The valid keys are sr, sig, se and an optional skn."):
with pytest.raises(
ValueError,
match="Invalid keys in Shared Access Signature. The valid keys are sr, sig, se and an optional skn.",
):
bad_key_sas_str = "SharedAccessSignature sr=beauxbatons.academy-net%2Fdevices%2FMyPensieve&signature=IsolemnlySwearThatIamuUptoNogood&se=1539043658&skn=alohomora"
SharedAccessSignatureAuthenticationProvider.parse(bad_key_sas_str)
def test_raises_auth_provider_created_from_incomplete_shared_access_signature_string():
with pytest.raises(ValueError, match="Invalid Shared Access Signature. It must be of the format 'SharedAccessSignature sr=<resource_uri>&sig=<signature>&se=<expiry>' or/and it can additionally contain an optional skn=<keyname> name=value pair."):
with pytest.raises(
ValueError,
match="Invalid Shared Access Signature. It must be of the format 'SharedAccessSignature sr=<resource_uri>&sig=<signature>&se=<expiry>' or/and it can additionally contain an optional skn=<keyname> name=value pair.",
):
incomplete_sas_str = "SharedAccessSignature sr=beauxbatons.academy-net%2Fdevices%2FMyPensieve&se=1539043658&skn=alohomora"
SharedAccessSignatureAuthenticationProvider.parse(incomplete_sas_str)

Просмотреть файл

@ -3,15 +3,15 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import pytest
from azure.iot.hub.devicesdk.auth.sk_authentication_provider import SymmetricKeyAuthenticationProvider
from azure.iot.hub.devicesdk.auth.sk_authentication_provider import (
SymmetricKeyAuthenticationProvider,
)
connection_string_device_sk_format = "HostName={};DeviceId={};SharedAccessKey={}"
connection_string_device_skn_format = (
"HostName={};DeviceId={};SharedAccessKeyName={};SharedAccessKey={}"
)
connection_string_module_sk_format = (
"HostName={};DeviceId={};ModuleId={};SharedAccessKey={}"
)
connection_string_module_sk_format = "HostName={};DeviceId={};ModuleId={};SharedAccessKey={}"
shared_access_key = "Zm9vYmFy"
shared_access_key_name = "alohomora"

Просмотреть файл

@ -30,7 +30,7 @@ class ConnectionStringAuthentication(ConnectionString, Authentication):
If a session object is provided, configure it directly. Otherwise,
create a new session and return it.
:param session: The session to configure for authentication
:type session: requests.Session
:rtype: requests.Session

Просмотреть файл

@ -11,7 +11,6 @@ from azure.iot.sdk.provisioning.service.models import (
IndividualEnrollment,
AttestationMechanism,
TpmAttestation,
QuerySpecification,
)
@ -40,7 +39,7 @@ def run_sample(cs, ek):
new_enrollments = []
for i in range(0, 10):
new_tpm = TpmAttestation(endorsement_key=ek)
new_am = AttestationMechanism(type="tpm", tpm=tpm)
new_am = AttestationMechanism(type="tpm", tpm=new_tpm)
new_ie = IndividualEnrollment(registration_id=("id-" + str(i)), attestation=new_am)
new_enrollments.append(new_ie)
bulk_op = BulkEnrollmentOperation(enrollments=new_enrollments, mode="create")

Просмотреть файл

@ -3,11 +3,8 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import sys
import os
import copy
import six
import pytest
import e2e_convenience
@ -104,7 +101,7 @@ def purge_individual_enrollments(client):
enrollments = []
qs = QuerySpecification(query="*")
cont = ""
while cont != None:
while cont is not None:
qrr = client.query_individual_enrollments(
query_specification=qs, x_ms_continuation=cont, raw=True
)
@ -127,7 +124,7 @@ def purge_enrollment_groups(client):
enrollments = []
qs = QuerySpecification(query="*")
cont = ""
while cont != None:
while cont is not None:
qrr = client.query_enrollment_groups(
query_specification=qs, x_ms_continuation=cont, raw=True
)
@ -157,14 +154,14 @@ class TestIndividualEnrollment(object):
assert ret_ie.registration_id == REGISTRATION_ID
assert ret_ie.initial_twin.tags.additional_properties == TAGS
assert ret_ie.initial_twin.properties.desired.additional_properties == DESIRED_PROPERTIES
assert ret_ie.capabilities.iot_edge == True
assert ret_ie.capabilities.iot_edge is True
# get
ret_ie = client.get_individual_enrollment(REGISTRATION_ID)
assert ret_ie.registration_id == REGISTRATION_ID
assert ret_ie.initial_twin.tags.additional_properties == TAGS
assert ret_ie.initial_twin.properties.desired.additional_properties == DESIRED_PROPERTIES
assert ret_ie.capabilities.iot_edge == True
assert ret_ie.capabilities.iot_edge is True
# delete
client.delete_individual_enrollment(REGISTRATION_ID)

Просмотреть файл

@ -4,7 +4,6 @@
# --------------------------------------------------------------------------------------------
import pytest
from pytest_mock import mocker
from azure.iot.provisioning.servicesdk.auth import (
ConnectionStringAuthentication,
HOST_NAME,

Просмотреть файл

@ -4,7 +4,6 @@
# --------------------------------------------------------------------------------------------
import pytest
from pytest_mock import mocker
from azure.iot.provisioning.servicesdk import ProvisioningServiceClient
from azure.iot.provisioning.servicesdk.protocol import (
ProvisioningServiceClient as BaseProvisioningServiceClient,

7
pyproject.toml Normal file
Просмотреть файл

@ -0,0 +1,7 @@
[tool.black]
line-length = 100
exclude = '''
/(
\.git
)/
'''

Просмотреть файл

@ -11,3 +11,4 @@ msrest
six
mock
black; python_version >= '3.6'
pre-commit

Просмотреть файл

@ -32,6 +32,9 @@ jobs:
- script: 'python dev_setup.py'
displayName: 'Prepare environment (install packages + dependencies + tools)'
- script: 'flake8 .'
displayName: 'Flake8'
- script: 'python test_packages.py'
displayName: pytest