diff --git a/migration_guide.md b/migration_guide.md deleted file mode 100644 index 27fa23e15..000000000 --- a/migration_guide.md +++ /dev/null @@ -1,198 +0,0 @@ -# IoTHub Python SDK Migration Guide - -This guide details how to update existing code that uses an `azure-iot-device` V2 release to use a V3 release instead. While the APIs remain mostly the same, there are a few differences you may need to account for in your application, as we have removed some of the implicit behaviors present in V2 in order to provide a more reliable and consistent user experience. - -## Connecting to IoTHub -One of the primary changes in V3 is the removal of automatic connections when invoking other APIs on the `IoTHubDeviceClient` and `IoTHubModuleClient`. You must now make an explicit manual connection before sending or receiving any data. - -### V2 -```python -from azure.iot.device import IoTHubDeviceClient - -client = IoTHubDeviceClient.create_from_connection_string("") -client.send_message("some message") -``` - -### V3 -```python -from azure.iot.device import IoTHubDeviceClient - -client = IoTHubDeviceClient.create_from_connection_string("") -client.connect() -client.send_message("some message") -``` - -Note that many people using V2 may already have been doing manual connects, as for some time, this has been our recommended practice. - -Note also that this change does *not* affect automatic reconnection attempts in the case of network failure. Once the manual connect has been successful, the client will (under default settings) still attempt to retain that connected state as it did in V2. - - -## Receiving data from IoTHub -Similarly to the above, there is an additional explicit step you must now make when trying to receive data. In addition to setting your handler, you must explicitly start/stop receiving. Note also that the above step of manually connecting must also be done before starting to receive data. - -Furthermore, note that the content of the message is now referred to by the 'payload' attribute on the message, rather than the 'data' attribute (see "Message" section below) - -### V2 -```python -from azure.iot.device import IoTHubDeviceClient - -client = IoTHubDeviceClient.create_from_connection_string("") - -# define behavior for receiving a message -def message_handler(message): - print("the data in the message received was ") - print(message.data) - print("custom properties are") - print(message.custom_properties) - -# set the message handler on the client -client.on_message_received = message_handler -``` - -### V3 -```python -from azure.iot.device import IoTHubDeviceClient - -client = IoTHubDeviceClient.create_from_connection_string("") - -# define behavior for receiving a message -def message_handler(message): - print("the payload of the message received was ") - print(message.payload) - print("custom properties are") - print(message.custom_properties) - -# set the message handler on the client -client.on_message_received = message_handler - -# connect and start receiving messages -client.connect() -client.start_message_receive() -``` - -Note that this must be done not just for receiving messages, but receiving any data. Consult the chart below to see which APIs you will need for the type of data you are receiving. - - -| Data Type | Handler name | Start Receive API | Stop Receive API | -|---------------------------------|----------------------------------------------|--------------------------------------------------|-------------------------------------------------| -| Messages | `.on_message_received` | `.start_message_receive()` | `.stop_message_receive()` | -| Method Requests | `.on_method_request_received` | `.start_method_request_receive()` | `.stop_method_request_receive()` | -| Twin Desired Properties Patches | `.on_twin_desired_properties_patch_received` | `.start_twin_desired_properties_patch_receive()` | `.stop_twin_desired_properties_patch_receive()` | - - -Finally, it should be clarified that the following receive APIs that were deprecated in V2 have been fully removed in V3: -* `.receive_message()` -* `.receive_message_on_input()` -* `.receive_method_request()` -* `.receive_twin_desired_properties_patch()` - -All receives should now be done using the handlers in the table above. - - -## Message object - IoTHubDeviceClient/IoTHubModuleClient - -Some changes have been made to the `Message` object used for sending and receiving data. -* The `.data` attribute is now called `.payload` for consistency with other objects in the API -* The `message_id` parameter is no longer part of the constructor arguments. It should be manually added as an attribute, just like all other attributes -* The payload of a received Message is now a unicode string value instead of a bytestring value. -It will be decoded according to the content encoding property sent along with the message. - -### V2 -```python -from azure.iot.device import Message - -payload = "this is a payload" -message_id = "1234" -m = Message(data=payload, message_id=message_id) - -assert m.data == payload -assert m.message_id = message_id -``` - -### V3 -```python -from azure.iot.device import Message - -payload = "this is a payload" -message_id = "1234" -m = Message(payload=payload) -m.message_id = message_id - -assert m.payload == payload -``` - -## Modified Client Options - IoTHubDeviceClient/IoTHubModuleClient - -Some keyword arguments provided at client creation have changed or been removed - -| V2 | V3 | Explanation | -|-----------------------------|-------------|----------------------------------------| -| `auto_connect` | **REMOVED** | Initial manual connection now required | -| `ensure_desired_properties` | **REMOVED** | No more implicit twin updates | - - -## Shutting down - IoTHubDeviceClient/IoTHubModuleClient - -While using the `.shutdown()` method when you are completely finished with an instance of the client has been a highly recommended practice for some time, some early versions of V2 did not require it. As of V3, in order to ensure a graceful exit, you must make an explicit shutdown. - -### V2 -```python -from azure.iot.device import IoTHubDeviceClient - -client = IoTHubDeviceClient.create_from_connection_string("") - -# ... -# -# ... -``` - -### V3 -```python -from azure.iot.device import IoTHubDeviceClient - -client = IoTHubDeviceClient.create_from_connection_string("") - -# ... -# -# ... - -client.shutdown() -``` - - -## Shutting down - ProvisioningDeviceClient - -As with the IoTHub clients mentioned above, the Provisioning clients now also require shutdown. This was implicit in V2, but now it must be explicit and manual to ensure graceful exit. - -### V2 -```python -from azure.iot.device import ProvisioningDeviceClient - -client = ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host="", - registration_id="", - id_scope="", - symmetric_key=", - ) - -registration_result = client.register() - -# Shutdown is implicit upon successful registration -``` - -### V3 -```python -from azure.iot.device import ProvisioningDeviceClient - -client = ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host="", - registration_id="", - id_scope="", - symmetric_key=", - ) - -registration_result = client.register() - -# Manual shutdown for graceful exit -client.shutdown() -``` \ No newline at end of file diff --git a/migration_guide_iothub.md b/migration_guide_iothub.md new file mode 100644 index 000000000..7cf8879cc --- /dev/null +++ b/migration_guide_iothub.md @@ -0,0 +1,359 @@ +# Azure IoT Device SDK for Python Migration Guide - IoTHubDeviceClient and IoTHubModuleClient + +This guide details how to update existing code that uses an `azure-iot-device` V2 release to use a V3 release instead. While the APIs remain mostly the same, there are several differences you will need to account for in your application, as some APIs have changed, and we have removed some of the implicit behaviors present in V2 in order to provide a more reliable and consistent user experience. + +Note that this guide mostly refers to the `IoTHubDeviceClient`, although it's contents apply equally to the `IoTHubModuleClient`. + +For changes to the `ProvisioningDeviceClient` please refer to `migration_guide_provisioning.md` in this same directory. + +## Connecting to IoTHub +One of the primary changes in V3 is the removal of automatic connections when invoking other APIs on the `IoTHubDeviceClient` and `IoTHubModuleClient`. You must now make an explicit manual connection before sending or receiving any data. + +### V2 +```python +from azure.iot.device import IoTHubDeviceClient + +client = IoTHubDeviceClient.create_from_connection_string("") +client.send_message("some message") +``` + +### V3 +```python +from azure.iot.device import IoTHubDeviceClient + +client = IoTHubDeviceClient.create_from_connection_string("") +client.connect() +client.send_message("some message") +``` + +Note that many people using V2 may already have been doing manual connects, as for some time, this has been our recommended practice. + +Note also that this change does *not* affect automatic reconnection attempts in the case of network failure. Once the manual connect has been successful, the client will (under default settings) still attempt to retain that connected state as it did in V2. + + +## Receiving data from IoTHub +Similarly to the above, there is an additional explicit step you must now make when trying to receive data. In addition to setting your handler, you must explicitly start/stop receiving. Note also that the above step of manually connecting must also be done before starting to receive data. + +Furthermore, note that the content of the message is now referred to by the 'payload' attribute on the message, rather than the 'data' attribute (see "Message" section below) + +### V2 +```python +from azure.iot.device import IoTHubDeviceClient + +client = IoTHubDeviceClient.create_from_connection_string("") + +# define behavior for receiving a message +def message_handler(message): + print("the data in the message received was ") + print(message.data) + print("custom properties are") + print(message.custom_properties) + +# set the message handler on the client +client.on_message_received = message_handler +``` + +### V3 +```python +from azure.iot.device import IoTHubDeviceClient + +client = IoTHubDeviceClient.create_from_connection_string("") + +# define behavior for receiving a message +def message_handler(message): + print("the payload of the message received was ") + print(message.payload) + print("custom properties are") + print(message.custom_properties) + +# set the message handler on the client +client.on_message_received = message_handler + +# connect and start receiving messages +client.connect() +client.start_message_receive() +``` + +Note that this must be done not just for receiving messages, but receiving any data. Consult the chart below to see which APIs you will need for the type of data you are receiving. + + +| Data Type | Handler name | Start Receive API | Stop Receive API | +|---------------------------------|----------------------------------------------|--------------------------------------------------|-------------------------------------------------| +| Messages | `.on_message_received` | `.start_message_receive()` | `.stop_message_receive()` | +| Method Requests | `.on_method_request_received` | `.start_direct_method_request_receive()` | `.stop_direct_method_request_receive()` | +| Twin Desired Properties Patches | `.on_twin_desired_properties_patch_received` | `.start_twin_desired_properties_patch_receive()` | `.stop_twin_desired_properties_patch_receive()` | + + +Finally, it should be clarified that the following receive APIs that were deprecated in V2 have been fully removed in V3: +* `.receive_message()` +* `.receive_message_on_input()` +* `.receive_method_request()` +* `.receive_twin_desired_properties_patch()` + +All receives should now be done using the handlers in the table above. + + +## Direct Methods +For clarity, all references to direct methods are now explicit about being "direct methods", rather than the more generic (and overloaded) "method". As such, the following methods and objects have all had a name change: +* `.invoke_method()` -> `.invoke_direct_method()` +* `MethodRequest` -> `DirectMethodRequest` +* `MethodResponse` -> `DirectMethodResponse` + + +## Message object + +Some changes have been made to the `Message` object used for sending and receiving data. +* The `.data` attribute is now called `.payload` for consistency with other objects in the API +* The `message_id` parameter is no longer part of the constructor arguments. It should be manually added as an attribute, just like all other attributes +* The payload of a received Message is now a unicode string value instead of a bytestring value. +It will be decoded according to the content encoding property sent along with the message. + +### V2 +```python +from azure.iot.device import Message + +payload = "this is a payload" +message_id = "1234" +m = Message(data=payload, message_id=message_id) + +assert m.data == payload +assert m.message_id = message_id +``` + +### V3 +```python +from azure.iot.device import Message + +payload = "this is a payload" +message_id = "1234" +m = Message(payload=payload) +m.message_id = message_id + +assert m.payload == payload +``` + + +## Shutting down +While using the `.shutdown()` method when you are completely finished with an instance of the client has been a highly recommended practice for some time, some early versions of V2 did not require it. As of V3, in order to ensure a graceful exit, you must make an explicit shutdown. + +### V2 +```python +from azure.iot.device import IoTHubDeviceClient + +client = IoTHubDeviceClient.create_from_connection_string("") + +# ... +# +# ... +``` + +### V3 +```python +from azure.iot.device import IoTHubDeviceClient + +client = IoTHubDeviceClient.create_from_connection_string("") + +# ... +# +# ... + +client.shutdown() +``` + + +## Symmetric Key Authentication +Creating a client that uses a symmetric key to authenticate is now done via the new `.create()` factory method instead of `.create_from_symmetric_key()` + +### V2 +```python +from azure.iot.device import IoTHubDeviceClient + +client = IoTHubDeviceClient.create_from_symmetric_key( + symmetric_key="", + hostname="", + device_id="" +) +``` + +### V3 +```python +from azure.iot.device import IoTHubDeviceClient + +client = IoTHubDeviceClient.create( + symmetric_key="", + hostname="", + device_id="" +) +``` + +## Custom SAS Token Authentication +There have been significant changes surrounding this style of authentication - it was rather complex in V2, and we have tried to simplify it for V3. It now also uses the new `.create()` method rather than `.create_from_sastoken()`. With this new style of providing a custom token via callback, you no longer +will have to manually update the SAS token via the `.on_new_sastoken_required` handler, and as such, +the handler no longer exists. + +### V2 +```python +from azure.iot.device import IoTHubDeviceClient + +def get_new_sastoken(): + sastoken = # Do something here to create/retrieve a token + return sastoken + +sastoken = get_new_sastoken() +client = IoTHubDeviceClient.create_from_sastoken(sastoken) + +def sastoken_update_handler(): + print("Updating SAS Token...") + sastoken = get_new_sastoken() + client.update_sastoken(sastoken) + print("SAS Token updated") + +client.on_new_sastoken_required = sastoken_update_handler +``` + +### V3 +```python +from azure.iot.device import IoTHubDeviceClient + +def get_new_sastoken(): + sastoken = # Do something here to create/retrieve a token + return sastoken + +client = IoTHubDeviceClient.create( + hostname="", + device_id="", + sastoken_fn=get_new_sastoken, +) +``` + +## X509 Authentication +Using X509 authentication is now provided via the new `ssl_context` keyword for the `.create()` method, rather than having it's own `.create_from_x509_certificate()` method. This is to allow additional flexibility for customers who wish for more control over their TLS/SSL authorization. See "TLS/SSL customization" below for more information. + +### V2 +```python +from azure.iot.device import IoTHubDeviceClient, X509 + +x509 = X509( + cert_file="", + key_file="", + pass_phrase="", +) + +client = IoTHubDeviceClient.create_from_x509_certificate( + hostname="", + device_id="", + x509=x509, +) +``` + +### V3 +```python +from azure.iot.device import IoTHubDeviceClient +import ssl + +ssl_context = ssl.SSLContext.create_default_context() +ssl_context.load_cert_chain( + certfile="", + keyfile="", + password="", +) + +client = IoTHubDeviceClient.create( + hostname="", + device_id="", + ssl_context=ssl_context, +) +``` + +Note that SSLContexts can be used with the `.create_from_connection_string()` factory method as well, so V3 now fully supports X509 connection strings. +### V3 +```python +from azure.iot.device import IoTHubDeviceClient +import ssl + +ssl_context = ssl.SSLContext.create_default_context() +ssl_context.load_cert_chain( + certfile="", + keyfile="", + password="", +) + +client = IoTHubDeviceClient.create_from_connection_string( + "", + ssl_context=ssl_context, +) +``` + +## TLS/SSL Customization +To allow users more flexibility, we have added the ability to inject an `SSLContext` object into the client via the optional `ssl_context` keyword argument to factory methods in order to customize the TLS/SSL encryption and authentication. As a result, some features previously handled via client APIs are now expected to have been directly set on the injected `SSLContext`. + +By moving to a model that allows `SSLContext` injection we not only bring our client in line with standard practices, but we also allow for users to modify any aspect of their `SSLContext`, not just the ones we previously supported via API. + +### **Server Verification Certificates (CA certs)** +### V2 +```python +from azure.iot.device import IoTHubDeviceClient + +certfile = open("") +root_ca_cert = certfile.read() + +client = IoTHubDeviceClient.create_from_connection_string( + "", + server_verification_cert=root_ca_cert +) +``` + +### V3 +```python +from azure.iot.device import IoTHubDeviceClient +import ssl + +ssl_context = ssl.SSLContext.create_default_context( + cafile="", +) + +client = IoTHubDeviceClient.create_from_connection_string( + "", + ssl_context=ssl_context, +) +``` + +### **Cipher Suites** +### V2 +```python +from azure.iot.device import IoTHubDeviceClient + +client = IoTHubDeviceClient.create_from_connection_string( + "", + cipher="" +) +``` + +### V3 +```python +from azure.iot.device import IoTHubDeviceClient +import ssl + +ssl_context = ssl.SSLContext.create_default_context() +ssl_context.set_ciphers("") + +client = IoTHubDeviceClient.create_from_connection_string( + "", + ssl_context=ssl_context, +) +``` + +## Modified Client Options + +Some keyword arguments provided at client creation have changed or been removed + +| V2 | V3 | Explanation | +|-----------------------------|------------------|----------------------------------------------------------| +| `connection_retry` | `auto_reconnect` | Improved clarity | +| `connection_retry_interval` | **REMOVED** | Automatic reconnect no longer uses a static interval | +| `auto_connect` | **REMOVED** | Initial manual connection now required | +| `ensure_desired_properties` | **REMOVED** | No more implicit twin updates | +| `sastoken_ttl` | **REMOVED** | Unnecessary, but open to re-adding if a use case emerges | +| `gateway_hostname` | **REMOVED** | Supported via `hostname` parameter | +| `server_verification_cert` | **REMOVED** | Supported via SSL injection | +| `cipher` | **REMOVED** | Supported via SSL injection | diff --git a/migration_guide_provisioning.md b/migration_guide_provisioning.md new file mode 100644 index 000000000..cb782755a --- /dev/null +++ b/migration_guide_provisioning.md @@ -0,0 +1,42 @@ +# Azure IoT Device SDK for Python Migration Guide - ProvisioningDeviceClient + +This guide details how to update existing code that uses an `azure-iot-device` V2 release to use a V3 release instead. While the APIs remain mostly the same, there are several differences you will need to account for in your application, as changes have been made in order to provide a more reliable and consistent user experience. + +Note that this guide is a work in progress. + +## Shutting down - ProvisioningDeviceClient + +As with the IoTHub clients mentioned above, the Provisioning clients now also require shutdown. This was implicit in V2, but now it must be explicit and manual to ensure graceful exit. + +### V2 +```python +from azure.iot.device import ProvisioningDeviceClient + +client = ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host="", + registration_id="", + id_scope="", + symmetric_key=", +) + +registration_result = client.register() + +# Shutdown is implicit upon successful registration +``` + +### V3 +```python +from azure.iot.device import ProvisioningDeviceClient + +client = ProvisioningDeviceClient.create_from_symmetric_key( + provisioning_host="", + registration_id="", + id_scope="", + symmetric_key=", +) + +registration_result = client.register() + +# Manual shutdown for graceful exit +client.shutdown() +``` \ No newline at end of file diff --git a/v3_async_wip/tests/test_connection_string.py b/v3_async_wip/tests/test_connection_string.py new file mode 100644 index 000000000..f490f2529 --- /dev/null +++ b/v3_async_wip/tests/test_connection_string.py @@ -0,0 +1,154 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import pytest +import logging +from v3_async_wip.connection_string import ConnectionString + +logging.basicConfig(level=logging.DEBUG) + +# TODO: eliminate refernces to service connection string + + +@pytest.mark.describe("ConnectionString") +class TestConnectionString(object): + @pytest.mark.it("Instantiates from a given connection string") + @pytest.mark.parametrize( + "input_string", + [ + pytest.param( + "HostName=my.host.name;SharedAccessKeyName=mykeyname;SharedAccessKey=Zm9vYmFy", + id="Service connection string", + ), + pytest.param( + "HostName=my.host.name;DeviceId=my-device;SharedAccessKey=Zm9vYmFy", + id="Device connection string", + ), + pytest.param( + "HostName=my.host.name;DeviceId=my-device;SharedAccessKey=Zm9vYmFy;GatewayHostName=mygateway", + id="Device connection string w/ gatewayhostname", + ), + pytest.param( + "HostName=my.host.name;DeviceId=my-device;x509=True", + id="Device connection string w/ X509", + ), + pytest.param( + "HostName=my.host.name;DeviceId=my-device;ModuleId=my-module;SharedAccessKey=Zm9vYmFy", + id="Module connection string", + ), + pytest.param( + "HostName=my.host.name;DeviceId=my-device;ModuleId=my-module;SharedAccessKey=Zm9vYmFy;GatewayHostName=mygateway", + id="Module connection string w/ gatewayhostname", + ), + pytest.param( + "HostName=my.host.name;DeviceId=my-device;ModuleId=my-module;x509=True", + id="Module connection string w/ X509", + ), + ], + ) + def test_instantiates_correctly_from_string(self, input_string): + cs = ConnectionString(input_string) + assert isinstance(cs, ConnectionString) + + @pytest.mark.it("Raises ValueError on invalid string input during instantiation") + @pytest.mark.parametrize( + "input_string", + [ + pytest.param("", id="Empty string"), + pytest.param("garbage", id="Not a connection string"), + pytest.param("HostName=my.host.name", id="Incomplete connection string"), + pytest.param( + "InvalidKey=my.host.name;SharedAccessKeyName=mykeyname;SharedAccessKey=Zm9vYmFy", + id="Invalid key", + ), + pytest.param( + "HostName=my.host.name;HostName=my.host.name;SharedAccessKey=mykeyname;SharedAccessKey=Zm9vYmFy", + id="Duplicate key", + ), + pytest.param( + "HostName=my.host.name;DeviceId=my-device;ModuleId=my-module;SharedAccessKey=mykeyname;x509=true", + id="Mixed authentication scheme", + ), + ], + ) + def test_raises_value_error_on_invalid_input(self, input_string): + with pytest.raises(ValueError): + ConnectionString(input_string) + + @pytest.mark.it("Raises TypeError on non-string input during instantiation") + @pytest.mark.parametrize( + "input_val", + [ + pytest.param(2123, id="Integer"), + pytest.param(23.098, id="Float"), + pytest.param(b"bytes", id="Bytes"), + pytest.param(object(), id="Complex object"), + pytest.param(["a", "b"], id="List"), + pytest.param({"a": "b"}, id="Dictionary"), + ], + ) + def test_raises_type_error_on_non_string_input(self, input_val): + with pytest.raises(TypeError): + ConnectionString(input_val) + + @pytest.mark.it("Uses the input connection string as a string representation") + def test_string_representation_of_object_is_the_input_string(self): + string = "HostName=my.host.name;SharedAccessKeyName=mykeyname;SharedAccessKey=Zm9vYmFy" + cs = ConnectionString(string) + assert str(cs) == string + + @pytest.mark.it("Supports indexing syntax to return the stored value for a given key") + def test_indexing_key_returns_corresponding_value(self): + cs = ConnectionString( + "HostName=my.host.name;SharedAccessKeyName=mykeyname;SharedAccessKey=Zm9vYmFy" + ) + assert cs["HostName"] == "my.host.name" + assert cs["SharedAccessKeyName"] == "mykeyname" + assert cs["SharedAccessKey"] == "Zm9vYmFy" + + @pytest.mark.it("Raises KeyError if indexing on a key not contained in the ConnectionString") + def test_indexing_key_raises_key_error_if_key_not_in_string(self): + with pytest.raises(KeyError): + cs = ConnectionString( + "HostName=my.host.name;SharedAccessKeyName=mykeyname;SharedAccessKey=Zm9vYmFy" + ) + cs["SharedAccessSignature"] + + @pytest.mark.it( + "Supports the 'in' operator for validating if a key is contained in the ConnectionString" + ) + def test_item_in_string(self): + cs = ConnectionString( + "HostName=my.host.name;SharedAccessKeyName=mykeyname;SharedAccessKey=Zm9vYmFy" + ) + assert "SharedAccessKey" in cs + assert "SharedAccessKeyName" in cs + assert "HostName" in cs + assert "FakeKeyNotInTheString" not in cs + + +@pytest.mark.describe("ConnectionString - .get()") +class TestConnectionStringGet(object): + @pytest.mark.it("Returns the stored value for a given key") + def test_calling_get_with_key_returns_corresponding_value(self): + cs = ConnectionString( + "HostName=my.host.name;SharedAccessKeyName=mykeyname;SharedAccessKey=Zm9vYmFy" + ) + assert cs.get("HostName") == "my.host.name" + + @pytest.mark.it("Returns None if the given key is invalid") + def test_calling_get_with_invalid_key_and_no_default_value_returns_none(self): + cs = ConnectionString( + "HostName=my.host.name;SharedAccessKeyName=mykeyname;SharedAccessKey=Zm9vYmFy" + ) + assert cs.get("invalidkey") is None + + @pytest.mark.it("Returns an optionally provided default value if the given key is invalid") + def test_calling_get_with_invalid_key_and_a_default_value_returns_default_value(self): + cs = ConnectionString( + "HostName=my.host.name;SharedAccessKeyName=mykeyname;SharedAccessKey=Zm9vYmFy" + ) + assert cs.get("invalidkey", "defaultval") == "defaultval" diff --git a/v3_async_wip/tests/test_iothub_client.py b/v3_async_wip/tests/test_iothub_client.py new file mode 100644 index 000000000..a4dc268c4 --- /dev/null +++ b/v3_async_wip/tests/test_iothub_client.py @@ -0,0 +1,2779 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import asyncio +import os +import pytest +import ssl +import time +from dev_utils import custom_mock +from pytest_lazyfixture import lazy_fixture +from v3_async_wip.iothub_client import IoTHubDeviceClient, IoTHubModuleClient +from v3_async_wip import config, edge_hsm, iothub_client, iot_exceptions +from v3_async_wip import connection_string as cs +from v3_async_wip import iothub_mqtt_client as mqtt +from v3_async_wip import iothub_http_client as http +from v3_async_wip import sastoken as st +from v3_async_wip import signing_mechanism as sm + +FAKE_DEVICE_ID = "fake_device_id" +FAKE_MODULE_ID = "fake_module_id" +FAKE_HOSTNAME = "fake.hostname" +FAKE_GATEWAY_HOSTNAME = "fake.gateway.hostname" +FAKE_URI = "fake/resource/location" +FAKE_SYMMETRIC_KEY = "Zm9vYmFy" +FAKE_SIGNATURE = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=" + + +# NOTE: HELPFUL INFORMATION ABOUT NAVIGATING THIS FILE +# This is a very long test file. Lots going on in here. To help navigate, there are headings for +# various sections. You can use the search feature of your IDE to jump to these headings: +# +# - Shared Client Tests +# - IoTHubDeviceClient Tests +# - IoTHubModuleClient Tests + + +# ~~~~~ Helpers ~~~~~~ + + +def sastoken_generator_fn(): + return "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}".format( + resource=FAKE_URI, signature=FAKE_SIGNATURE, expiry=str(int(time.time()) + 3600) + ) + + +# ~~~~~ Fixtures ~~~~~~ + +# Mock out the underlying clients to avoid starting up tasks that will reduce performance +@pytest.fixture(autouse=True) +def mock_mqtt_iothub_client(mocker): + return mocker.patch.object(mqtt, "IoTHubMQTTClient", spec=mqtt.IoTHubMQTTClient).return_value + + +@pytest.fixture(autouse=True) +def mock_http_iothub_client(mocker): + return mocker.patch.object(http, "IoTHubHTTPClient", spec=http.IoTHubHTTPClient).return_value + + +@pytest.fixture(autouse=True) +def mock_sastoken_provider(mocker): + return mocker.patch.object(st, "SasTokenProvider", spec=st.SasTokenProvider).return_value + + +@pytest.fixture +def custom_ssl_context(): + # NOTE: It doesn't matter how the SSLContext is configured for the tests that use this fixture, + # so it isn't configured at all. + return ssl.SSLContext() + + +@pytest.fixture(params=["Default SSLContext", "Custom SSLContext"]) +def optional_ssl_context(request, custom_ssl_context): + """Sometimes tests need to show something works with or without an SSLContext""" + if request.param == "Custom SSLContext": + return custom_ssl_context + else: + return None + + +# ~~~~~ Parametrizations ~~~~~ +# Define parametrizations that will be used across multiple test suites, and that may eventually +# need to be changed everywhere, e.g. new auth scheme added. +# Note that some parametrizations are also defined within the scope of a single test suite if that +# is the only unit they are relevant to. + + +# Parameters for arguments to the .create() method of clients. Represent different types of +# authentication. Use this parametrization whenever possible on .create() tests. +# NOTE: Do NOT combine this with the SSL fixtures above. This parametrization contains +# ssl contexts where necessary +create_auth_params = [ + # Provide args in form 'symmetric_key, sastoken_fn, ssl_context' + pytest.param(FAKE_SYMMETRIC_KEY, None, None, id="Symmetric Key SAS Auth + Default SSLContext"), + pytest.param( + FAKE_SYMMETRIC_KEY, + None, + lazy_fixture("custom_ssl_context"), + id="Symmetric Key SAS Auth + Custom SSLContext", + ), + pytest.param( + None, + sastoken_generator_fn, + None, + id="User-Provided SAS Token Auth + Default SSLContext", + ), + pytest.param( + None, + sastoken_generator_fn, + lazy_fixture("custom_ssl_context"), + id="User-Provided SAS Token Auth + Custom SSLContext", + ), + pytest.param(None, None, lazy_fixture("custom_ssl_context"), id="Custom SSLContext Auth"), +] +# Just the parameters where SAS auth is used +create_auth_params_sas = [param for param in create_auth_params if "SAS" in param.id] +# Just the parameters where a Symmetric Key auth is used +create_auth_params_sk = [param for param in create_auth_params if param.values[0] is not None] +# Just the parameters where SAS callback auth is used +create_auth_params_token_cb = [param for param in create_auth_params if param.values[1] is not None] +# Just the parameters where a custom SSLContext is provided +create_auth_params_custom_ssl = [ + param for param in create_auth_params if param.values[2] is not None +] +# Just the parameters where a custom SSLContext is NOT provided +create_auth_params_default_ssl = [param for param in create_auth_params if param.values[2] is None] + + +# Covers all option kwargs shared across client factory methods +factory_kwargs = [ + pytest.param("auto_reconnect", False, id="auto_reconnect"), + pytest.param("keep_alive", 34, id="keep_alive"), + pytest.param("product_info", "fake-product-info", id="product_info"), + pytest.param( + "proxy_options", config.ProxyOptions("HTTP", "fake.address", 1080), id="proxy_options" + ), + pytest.param("websockets", True, id="websockets"), +] + +sastoken_provider_create_exceptions = [ + pytest.param(st.SasTokenError(), id="SasTokenError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), +] + +sk_sm_create_exceptions = [ + pytest.param(ValueError(), id="ValueError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), +] + + +# ~~~~~ Shared Client Tests ~~~~~~ +# Many methods are the same between an IoTHubDeviceClient and an IoTHubModuleClient. +# For those tests, write a single suite, that can use generic fixtures, and then +# inherit the generic suite class into a child suite class that provides specific +# versions of those fixtures. +# Only do this if the tests are identical aside from client class - if there are +# distinctions, even if minor, just write two separate suites - it is not worth the +# trouble. + + +class SharedClientInstantiationTests: + """Defines shared tests for instantiation of Device/Module clients""" + + @pytest.fixture + def client_config(self, custom_ssl_context): + # NOTE: It really doesn't matter whether or not this has a module_id for the purposes + # of these tests, so don't make this more complicated than it needs to be. + return config.IoTHubClientConfig( + device_id=FAKE_DEVICE_ID, hostname=FAKE_HOSTNAME, ssl_context=custom_ssl_context + ) + + @pytest.mark.it( + "Instantiates and stores an IoTHubMQTTClient using the provided IoTHubClientConfig" + ) + async def test_mqtt_client(self, mocker, client_class, client_config): + assert mqtt.IoTHubMQTTClient.call_count == 0 + + client = client_class(client_config) + + assert client._mqtt_client is mqtt.IoTHubMQTTClient.return_value + assert mqtt.IoTHubMQTTClient.call_count == 1 + assert mqtt.IoTHubMQTTClient.call_args == mocker.call(client_config) + + await client.shutdown() + + @pytest.mark.it( + "Instantiates and stores an IoTHubHTTPClient using the provided IoTHubClientConfig" + ) + async def test_http_client(self, mocker, client_class, client_config): + assert http.IoTHubHTTPClient.call_count == 0 + + client = client_class(client_config) + + assert client._http_client is http.IoTHubHTTPClient.return_value + assert http.IoTHubHTTPClient.call_count == 1 + assert http.IoTHubHTTPClient.call_args == mocker.call(client_config) + + await client.shutdown() + + @pytest.mark.it("Stores the IoTHubClientConfig's `sastoken_provider`, if it exists") + async def test_sastoken_provider(self, client_class, client_config, mock_sastoken_provider): + client_config.sastoken_provider = mock_sastoken_provider + + client = client_class(client_config) + + assert client._sastoken_provider is mock_sastoken_provider + + await client.shutdown() + + +class SharedClientShutdownTests: + """Defines shared tests for Device/Module client .shutdown() method""" + + @pytest.mark.it("Shuts down the IoTHubMQTTClient") + async def test_mqtt_shutdown(self, client): + assert client._mqtt_client.shutdown.await_count == 0 + + await client.shutdown() + + assert client._mqtt_client.shutdown.await_count == 1 + + @pytest.mark.it("Shuts down the IoTHubHTTPClient") + async def test_http_shutdown(self, client): + assert client._http_client.shutdown.await_count == 0 + + await client.shutdown() + + assert client._http_client.shutdown.await_count == 1 + + @pytest.mark.it("Shuts down the SasTokenProvider, if present") + async def test_sastoken_provider_shutdown(self, mocker, client, mock_sastoken_provider): + # Add the mock sastoken provider since it isn't there by default + assert client._sastoken_provider is None + client._sastoken_provider = mock_sastoken_provider + assert mock_sastoken_provider.shutdown.await_count == 0 + + await client.shutdown() + + assert mock_sastoken_provider.shutdown.await_count == 1 + assert mock_sastoken_provider.shutdown.await_args == mocker.call() + + @pytest.mark.it("Handles the case where no SasTokenProvider is present") + async def test_no_sastoken_provider(self, client): + assert client._sastoken_provider is None + + await client.shutdown() + + # If no error was raised, this test passes + + @pytest.mark.it( + "Allows any exception raised during IoTHubMQTTClient shutdown to propagate, but only after completing the rest of the shutdown procedure" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="W/ SasTokenProvider"), + pytest.param(None, id="No SasTokenProvider"), + ], + ) + async def test_mqtt_client_raises(self, client, sastoken_provider, arbitrary_exception): + client._sastoken_provider = sastoken_provider + assert client._mqtt_client.shutdown.await_count == 0 + assert client._http_client.shutdown.await_count == 0 + if sastoken_provider: + assert client._sastoken_provider.shutdown.await_count == 0 + + # MQTT shutdown will raise + client._mqtt_client.shutdown.side_effect = arbitrary_exception + + # MQTT shutdown error propagates + with pytest.raises(type(arbitrary_exception)) as e_info: + await client.shutdown() + assert e_info.value is arbitrary_exception + + # But the whole shutdown protocol was executed + assert client._mqtt_client.shutdown.await_count == 1 + assert client._http_client.shutdown.await_count == 1 + if sastoken_provider: + assert client._sastoken_provider.shutdown.await_count == 1 + + # Unset the the MQTT shutdown failure so teardown doesn't crash + client._mqtt_client.shutdown.side_effect = None + + @pytest.mark.it( + "Allows any exception raised during IoTHubHTTPClient shutdown to propagate, but only after completing the rest of the shutdown procedure" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="W/ SasTokenProvider"), + pytest.param(None, id="No SasTokenProvider"), + ], + ) + async def test_http_client_raises(self, client, sastoken_provider, arbitrary_exception): + client._sastoken_provider = sastoken_provider + assert client._mqtt_client.shutdown.await_count == 0 + assert client._http_client.shutdown.await_count == 0 + if sastoken_provider: + assert client._sastoken_provider.shutdown.await_count == 0 + + # HTTP shutdown will raise + client._http_client.shutdown.side_effect = arbitrary_exception + + # HTTP shutdown error propagates + with pytest.raises(type(arbitrary_exception)) as e_info: + await client.shutdown() + assert e_info.value is arbitrary_exception + + # But the whole shutdown protocol was executed + assert client._mqtt_client.shutdown.await_count == 1 + assert client._http_client.shutdown.await_count == 1 + if sastoken_provider: + assert client._sastoken_provider.shutdown.await_count == 1 + + # Unset the the HTTP shutdown failure so teardown doesn't crash + client._http_client.shutdown.side_effect = None + + @pytest.mark.it( + "Allows any exception raised during SasTokenProvider shutdown to propagate, but only after completing the rest of the shutdown procedure" + ) + async def test_sastoken_provider_raises( + self, client, mock_sastoken_provider, arbitrary_exception + ): + client._sastoken_provider = mock_sastoken_provider + assert client._mqtt_client.shutdown.await_count == 0 + assert client._http_client.shutdown.await_count == 0 + assert client._sastoken_provider.shutdown.await_count == 0 + + # SasTokenProvider shutdown will raise + client._sastoken_provider.shutdown.side_effect = arbitrary_exception + + # SasTokenProvider shutdown error propagates + with pytest.raises(type(arbitrary_exception)) as e_info: + await client.shutdown() + assert e_info.value is arbitrary_exception + + # But the whole shutdown protocol was executed + assert client._mqtt_client.shutdown.await_count == 1 + assert client._http_client.shutdown.await_count == 1 + assert client._sastoken_provider.shutdown.await_count == 1 + + # Unset the the SasTokenProvider shutdown failure so teardown doesn't crash + client._sastoken_provider.shutdown.side_effect = None + + @pytest.mark.it( + "Can be cancelled during IoTHubMQTTClient shutdown, but shutdown procedure will still complete" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="W/ SasTokenProvider"), + pytest.param(None, id="No SasTokenProvider"), + ], + ) + async def test_cancel_mqtt_client(self, client, sastoken_provider): + client._sastoken_provider = sastoken_provider + assert client._mqtt_client.shutdown.await_count == 0 + assert client._http_client.shutdown.await_count == 0 + if sastoken_provider: + assert client._sastoken_provider.shutdown.await_count == 0 + + # MQTT shutdown will hang + original_shutdown = client._mqtt_client.shutdown + client._mqtt_client.shutdown = custom_mock.HangingAsyncMock() + + # Attempt to shutdown will hang + t = asyncio.create_task(client.shutdown()) + await client._mqtt_client.shutdown.wait_for_hang() + assert not t.done() + + # Shutdown protocol is incomplete + assert client._mqtt_client.shutdown.await_count == 1 + assert client._http_client.shutdown.await_count == 0 + if sastoken_provider: + assert client._sastoken_provider.shutdown.await_count == 0 + + # Shutdown can be cancelled + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # But the whole shutdown protocol was still executed + assert client._mqtt_client.shutdown.await_count == 1 + assert client._http_client.shutdown.await_count == 1 + if sastoken_provider: + assert client._sastoken_provider.shutdown.await_count == 1 + + # Unset the the MQTT shutdown hang so teardown doesn't hang + client._mqtt_client.shutdown = original_shutdown + + @pytest.mark.it( + "Can be cancelled during IoTHubHTTPClient shutdown, but shutdown procedure will still complete" + ) + @pytest.mark.parametrize( + "sastoken_provider", + [ + pytest.param(lazy_fixture("mock_sastoken_provider"), id="W/ SasTokenProvider"), + pytest.param(None, id="No SasTokenProvider"), + ], + ) + async def test_cancel_http_client(self, client, sastoken_provider): + client._sastoken_provider = sastoken_provider + assert client._mqtt_client.shutdown.await_count == 0 + assert client._http_client.shutdown.await_count == 0 + if sastoken_provider: + assert client._sastoken_provider.shutdown.await_count == 0 + + # HTTP shutdown will hang + original_shutdown = client._http_client.shutdown + client._http_client.shutdown = custom_mock.HangingAsyncMock() + + # Attempt to shutdown will hang + t = asyncio.create_task(client.shutdown()) + await client._http_client.shutdown.wait_for_hang() + assert not t.done() + + # Shutdown protocol is incomplete + # (unless sastoken provider isn't there, in which case, I guess it is complete) + assert client._mqtt_client.shutdown.await_count == 1 + assert client._http_client.shutdown.await_count == 1 + if sastoken_provider: + assert client._sastoken_provider.shutdown.await_count == 0 + + # Shutdown can be cancelled + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # But the whole shutdown protocol was still executed + assert client._mqtt_client.shutdown.await_count == 1 + assert client._http_client.shutdown.await_count == 1 + if sastoken_provider: + assert client._sastoken_provider.shutdown.await_count == 1 + + # Unset the the HTTP shutdown hang so teardown doesn't hang + client._http_client.shutdown = original_shutdown + + @pytest.mark.it( + "Can be cancelled during SasTokenProvider shutdown, but shutdown procedure will still complete" + ) + async def test_cancel_sastoken_provider(self, client, mock_sastoken_provider): + client._sastoken_provider = mock_sastoken_provider + assert client._mqtt_client.shutdown.await_count == 0 + assert client._http_client.shutdown.await_count == 0 + assert client._sastoken_provider.shutdown.await_count == 0 + + # SasTokenProvider shutdown will hang + original_shutdown = client._sastoken_provider.shutdown + client._sastoken_provider.shutdown = custom_mock.HangingAsyncMock() + + # Attempt to shutdown will hang + t = asyncio.create_task(client.shutdown()) + await client._sastoken_provider.shutdown.wait_for_hang() + assert not t.done() + + # Shutdown protocol is incomplete + # (okay, no it's not, it's definitely done, but I'm keeping this test structure the same + # so it can easily be expanded in the future) + assert client._mqtt_client.shutdown.await_count == 1 + assert client._http_client.shutdown.await_count == 1 + assert client._sastoken_provider.shutdown.await_count == 1 + + # Shutdown can be cancelled + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # But the whole shutdown protocol was still executed + assert client._mqtt_client.shutdown.await_count == 1 + assert client._http_client.shutdown.await_count == 1 + assert client._sastoken_provider.shutdown.await_count == 1 + + # Unset the the HTTP shutdown hang so teardown doesn't hang + client._sastoken_provider.shutdown = original_shutdown + + +# ~~~~~ IoTHubDeviceClient Tests ~~~~~ + + +class IoTHubDeviceClientTestConfig: + """Mixin parent class defining a set of fixtures used in IoTHubDeviceClient tests""" + + @pytest.fixture + async def client(self, custom_ssl_context): + # Use a custom_ssl_context for auth for simplicity. Almost any test using this fixture + # will not be affected by auth type, so just use the simplest one. + client_config = config.IoTHubClientConfig( + device_id=FAKE_DEVICE_ID, hostname=FAKE_HOSTNAME, ssl_context=custom_ssl_context + ) + client = IoTHubDeviceClient(client_config) + yield client + await client.shutdown() + + @pytest.fixture + def client_class(self): + return IoTHubDeviceClient + + +@pytest.mark.describe("IoTHubDeviceClient -- Instantiation") +class TestIoTHubDeviceClientInstantiation( + SharedClientInstantiationTests, IoTHubDeviceClientTestConfig +): + pass + + +@pytest.mark.describe("IoTHubDeviceClient - .create()") +class TestIoTHubDeviceClientCreate(IoTHubDeviceClientTestConfig): + @pytest.mark.it( + "Returns a new IoTHubDeviceClient instance, created with the use of a new IoTHubClientConfig object" + ) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params) + async def test_instantiation(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_config_cls = mocker.spy(config, "IoTHubClientConfig") + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + assert spy_config_cls.call_count == 0 + assert spy_client_init.call_count == 0 + + client = await IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + assert spy_config_cls.call_count == 1 + assert spy_client_init.call_count == 1 + # NOTE: Normally passing through self or cls isn't necessary in a mock call, but + # it seems that when mocking the __init__ it is. This is actually good though, as it + # allows us to match the specific object reference which otherwise is very dicey when + # mocking constructors/initializers + assert spy_client_init.call_args == mocker.call(client, spy_config_cls.spy_return) + assert isinstance(client, IoTHubDeviceClient) + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the provided `device_id` on the IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params) + async def test_device_id(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + + client = await IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.device_id == FAKE_DEVICE_ID + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Does not set any `module_id` on the IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params) + async def test_module_id(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + + client = await IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.module_id is None + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the provided `hostname` on the IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params) + async def test_hostname(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + + client = await IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.hostname == FAKE_HOSTNAME + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the provided `ssl_context` on the IoTHubClientConfig used to create the client, if provided" + ) + @pytest.mark.parametrize( + "symmetric_key, sastoken_fn, ssl_context", create_auth_params_custom_ssl + ) + async def test_custom_ssl_context(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + assert ssl_context is not None + + client = await IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.ssl_context is ssl_context + + # Graceful exit + await client.shutdown() + + # NOTE: The details of this default SSLContext are covered in the TestDefaultSSLContext suite + @pytest.mark.it( + "Sets a default SSLContext on the IoTHubClientConfig used to create the client, if `ssl_context` is not provided" + ) + @pytest.mark.parametrize( + "symmetric_key, sastoken_fn, ssl_context", create_auth_params_default_ssl + ) + async def test_default_ssl_context(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + spy_default_ssl = mocker.spy(iothub_client, "_default_ssl_context") + assert ssl_context is None + + client = await IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.ssl_context is not None + assert isinstance(config.ssl_context, ssl.SSLContext) + # This SSLContext was returned from the default ssl context helper + assert spy_default_ssl.call_count == 1 + assert spy_default_ssl.call_args == mocker.call() + assert config.ssl_context is spy_default_ssl.spy_return + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Creates a SasTokenProvider that uses symmetric key-based token generation and sets it on the IoTHubClientConfig used to create the client, if `symmetric_key` is provided as a parameter" + ) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params_sk) + async def test_sk_auth(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + spy_sk_sm_cls = mocker.spy(sm, "SymmetricKeySigningMechanism") + spy_st_generator_cls = mocker.spy(st, "InternalSasTokenGenerator") + spy_st_provider_create = mocker.spy(st.SasTokenProvider, "create_from_generator") + expected_token_uri = "{hostname}/devices/{device_id}".format( + hostname=FAKE_HOSTNAME, device_id=FAKE_DEVICE_ID + ) + assert sastoken_fn is None + + client = await IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + ssl_context=ssl_context, + ) + + # SymmetricKeySigningMechanism was created from the symmetric key + assert spy_sk_sm_cls.call_count == 1 + assert spy_sk_sm_cls.call_args == mocker.call(FAKE_SYMMETRIC_KEY) + # InternalSasTokenGenerator was created from the SymmetricKeySigningMechanism and expected URI + assert spy_st_generator_cls.call_count == 1 + assert spy_st_generator_cls.call_args == mocker.call( + signing_mechanism=spy_sk_sm_cls.spy_return, uri=expected_token_uri + ) + # SasTokenProvider was created from the InternalSasTokenGenerator + assert spy_st_provider_create.call_count == 1 + assert spy_st_provider_create.call_args == mocker.call(spy_st_generator_cls.spy_return) + # The SasTokenProvider was set on the IoTHubClientConfig that was used to instantiate the client + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.sastoken_provider is spy_st_provider_create.spy_return + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Creates a SasTokenProvider that uses user callback-based token generation and sets it on the IoTHubClientConfig used to create the client, if `sastoken_fn` is provided as a parameter" + ) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params_token_cb) + async def test_token_callback_auth(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + spy_st_generator_cls = mocker.spy(st, "ExternalSasTokenGenerator") + spy_st_provider_create = mocker.spy(st.SasTokenProvider, "create_from_generator") + assert symmetric_key is None + + client = await IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + # ExternalSasTokenGenerator was created from the `sastoken_fn`` + assert spy_st_generator_cls.call_count == 1 + assert spy_st_generator_cls.call_args == mocker.call(sastoken_generator_fn) + # SasTokenProvider was created from the ExternalSasTokenGenerator + assert spy_st_provider_create.call_count == 1 + assert spy_st_provider_create.call_args == mocker.call(spy_st_generator_cls.spy_return) + # The SasTokenProvider was set on the IoTHubClientConfig that was used to instantiate the client + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.sastoken_provider is spy_st_provider_create.spy_return + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Does not set any SasTokenProvider on the IoTHubClientConfig used to create the client if neither `symmetric_key` nor `sastoken_fn` are provided as parameters" + ) + async def test_non_sas_auth(self, mocker, custom_ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + + client = await IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + ssl_context=custom_ssl_context, + ) + + # No SasTokenProvider was set on the IoTHubClientConfig that was used to instantiate the client + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.sastoken_provider is None + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets any provided optional keyword arguments on IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params) + @pytest.mark.parametrize("kwarg_name, kwarg_value", factory_kwargs) + async def test_kwargs( + self, mocker, symmetric_key, sastoken_fn, ssl_context, kwarg_name, kwarg_value + ): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + + kwargs = {kwarg_name: kwarg_value} + + client = await IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + **kwargs + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert getattr(config, kwarg_name) == kwarg_value + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Raises ValueError if neither `symmetric_key` nor `sastoken_fn` nor `ssl_context` are provided as parameters" + ) + async def test_no_auth(self): + with pytest.raises(ValueError): + await IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + ) + + @pytest.mark.it( + "Raises ValueError if both `symmetric_key` and `sastoken_fn` are provided as parameters" + ) + async def test_conflicting_auth(self, optional_ssl_context): + with pytest.raises(ValueError): + await IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=FAKE_SYMMETRIC_KEY, + sastoken_fn=sastoken_generator_fn, + ssl_context=optional_ssl_context, + ) + + @pytest.mark.it( + "Allows any exceptions raised when creating a SymmetricKeySigningMechanism to propagate" + ) + @pytest.mark.parametrize("exception", sk_sm_create_exceptions) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params_sk) + async def test_sksm_raises(self, mocker, symmetric_key, sastoken_fn, ssl_context, exception): + mocker.patch.object(sm, "SymmetricKeySigningMechanism", side_effect=exception) + assert sastoken_fn is None + + with pytest.raises(type(exception)) as e_info: + await IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + ssl_context=ssl_context, + ) + assert e_info.value is exception + + @pytest.mark.it("Allows any exceptions raised when creating a SasTokenProvider to propagate") + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params_sas) + @pytest.mark.parametrize("exception", sastoken_provider_create_exceptions) + async def test_sastoken_provider_raises( + self, mocker, symmetric_key, sastoken_fn, ssl_context, exception + ): + mocker.patch.object(st.SasTokenProvider, "create_from_generator", side_effect=exception) + + with pytest.raises(type(exception)) as e_info: + await IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + assert e_info.value is exception + + @pytest.mark.it("Can be cancelled while waiting for SasTokenProvider creation") + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params_sas) + async def test_cancel_during_sastoken_provider_creation( + self, mocker, symmetric_key, sastoken_fn, ssl_context + ): + mocker.patch.object( + st.SasTokenProvider, "create_from_generator", custom_mock.HangingAsyncMock() + ) + + coro = IoTHubDeviceClient.create( + device_id=FAKE_DEVICE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + t = asyncio.create_task(coro) + + # Hanging, waiting for SasTokenProvider creation to finish + await st.SasTokenProvider.create_from_generator.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("IoTHubDeviceClient - .create_from_connection_string()") +class TestIoTHubDeviceClientCreateFromConnectionString(IoTHubDeviceClientTestConfig): + + factory_params = [ + pytest.param( + "HostName={hostname};DeviceId={device_id};SharedAccessKey={shared_access_key}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + shared_access_key=FAKE_SYMMETRIC_KEY, + ), + None, + id="Standard Connection String w/ SharedAccessKey + Default SSLContext", + ), + pytest.param( + "HostName={hostname};DeviceId={device_id};SharedAccessKey={shared_access_key}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + shared_access_key=FAKE_SYMMETRIC_KEY, + ), + lazy_fixture("custom_ssl_context"), + id="Standard Connection String w/ SharedAccessKey + Custom SSLContext", + ), + pytest.param( + "HostName={hostname};DeviceId={device_id};SharedAccessKey={shared_access_key};GatewayHostName={gateway_hostname}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + shared_access_key=FAKE_SYMMETRIC_KEY, + gateway_hostname=FAKE_GATEWAY_HOSTNAME, + ), + None, + id="Edge Connection String w/ SharedAccessKey + Default SSLContext", + ), + pytest.param( + "HostName={hostname};DeviceId={device_id};SharedAccessKey={shared_access_key};GatewayHostName={gateway_hostname}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + shared_access_key=FAKE_SYMMETRIC_KEY, + gateway_hostname=FAKE_GATEWAY_HOSTNAME, + ), + lazy_fixture("custom_ssl_context"), + id="Edge Connection String w/ SharedAccessKey + Custom SSLContext", + ), + # NOTE: X509 certs imply use of custom SSLContext + pytest.param( + "HostName={hostname};DeviceId={device_id};x509=true".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + ), + lazy_fixture("custom_ssl_context"), + id="Standard Connection String w/ X509", + ), + pytest.param( + "HostName={hostname};DeviceId={device_id};GatewayHostName={gateway_hostname};x509=true".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + gateway_hostname=FAKE_GATEWAY_HOSTNAME, + ), + lazy_fixture("custom_ssl_context"), + id="Edge Connection String w/ X509", + ), + ] + # Just the parameters for using standard connection strings + factory_params_no_gateway = [ + param for param in factory_params if cs.GATEWAY_HOST_NAME not in param.values[0] + ] + # Just the parameters for using connection strings with a GatewayHostName + factory_params_gateway = [ + param for param in factory_params if cs.GATEWAY_HOST_NAME in param.values[0] + ] + # Just the parameters where a custom SSLContext is provided + factory_params_custom_ssl = [param for param in factory_params if param.values[1] is not None] + # Just the parameters where a custom SSLContext is NOT provided + factory_params_default_ssl = [param for param in factory_params if param.values[1] is None] + # Just the parameters for using SharedAccessKeys + factory_params_sak = [ + param for param in factory_params if cs.SHARED_ACCESS_KEY in param.values[0] + ] + # Just the parameters for NOT using SharedAccessKeys + factory_params_no_sak = [ + param for param in factory_params if cs.SHARED_ACCESS_KEY not in param.values[0] + ] + + @pytest.mark.it( + "Returns a new IoTHubDeviceClient instance, created with the use of a new IoTHubClientConfig object" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params) + async def test_instantiation(self, mocker, connection_string, ssl_context): + spy_config_cls = mocker.spy(config, "IoTHubClientConfig") + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + assert spy_config_cls.call_count == 0 + assert spy_client_init.call_count == 0 + + client = await IoTHubDeviceClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + assert spy_config_cls.call_count == 1 + assert spy_client_init.call_count == 1 + # NOTE: Normally passing through self or cls isn't necessary in a mock call, but + # it seems that when mocking the __init__ it is. This is actually good though, as it + # allows us to match the specific object reference which otherwise is very dicey when + # mocking constructors/initializers + assert spy_client_init.call_args == mocker.call(client, spy_config_cls.spy_return) + assert isinstance(client, IoTHubDeviceClient) + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the `DeviceId` from the connection string as the `device_id` on the IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params) + async def test_device_id(self, mocker, connection_string, ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + # Create a ConnectionString object from the connection string to simply value access + cs_obj = cs.ConnectionString(connection_string) + + client = await IoTHubDeviceClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.device_id == cs_obj[cs.DEVICE_ID] + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Does not set any `module_id` on the IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params) + async def test_module_id(self, mocker, connection_string, ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + + client = await IoTHubDeviceClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.module_id is None + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the `HostName` from the connection string as the `hostname` on the IoTHubClientConfig, if no `GatewayHostName` is present in the connection string" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_no_gateway) + async def test_hostname_cs_has_no_gateway(self, mocker, connection_string, ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + # Create a ConnectionString object from the connection string to simply value access + cs_obj = cs.ConnectionString(connection_string) + assert cs.GATEWAY_HOST_NAME not in cs_obj + + client = await IoTHubDeviceClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.hostname == cs_obj[cs.HOST_NAME] + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the `HostName` from the connection string as the `hostname` on the IoTHubClientConfig used to create the client, if no `GatewayHostName` is present in the connection string" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_gateway) + async def test_hostname_cs_has_gateway(self, mocker, connection_string, ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + # Create a ConnectionString object from the connection string to simply value access + cs_obj = cs.ConnectionString(connection_string) + assert cs.GATEWAY_HOST_NAME in cs_obj + assert cs_obj[cs.GATEWAY_HOST_NAME] != cs_obj[cs.HOST_NAME] + + client = await IoTHubDeviceClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.hostname == cs_obj[cs.GATEWAY_HOST_NAME] + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the provided `ssl_context` on the IoTHubClientConfig used to create the client, if provided" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_custom_ssl) + async def test_custom_ssl_context(self, mocker, connection_string, ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + assert ssl_context is not None + + client = await IoTHubDeviceClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.ssl_context is ssl_context + + # Graceful exit + await client.shutdown() + + # NOTE: The details of this default SSLContext are covered in the TestDefaultSSLContext suite + @pytest.mark.it( + "Sets a default SSLContext as the `ssl_context` on the IoTHubClientConfig used to create the client, if `ssl_context` is not provided" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_default_ssl) + async def test_default_ssl_context(self, mocker, connection_string, ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + spy_default_ssl = mocker.spy(iothub_client, "_default_ssl_context") + assert ssl_context is None + + client = await IoTHubDeviceClient.create_from_connection_string(connection_string) + + assert spy_default_ssl.call_count == 1 + assert spy_default_ssl.call_args == mocker.call() + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.ssl_context is spy_default_ssl.spy_return + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Creates a SasTokenProvider that uses symmetric key-based token generation and sets it on the IoTHubClientConfig used to create the client, if `SharedAccessKey` is present in the connection string" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_sak) + async def test_sk_auth(self, mocker, connection_string, ssl_context): + # Create a ConnectionString object from the connection string to simply value access + cs_obj = cs.ConnectionString(connection_string) + assert cs.SHARED_ACCESS_KEY in cs_obj + # Mock + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + spy_sk_sm_cls = mocker.spy(sm, "SymmetricKeySigningMechanism") + spy_st_generator_cls = mocker.spy(st, "InternalSasTokenGenerator") + spy_st_provider_create = mocker.spy(st.SasTokenProvider, "create_from_generator") + expected_token_uri = "{hostname}/devices/{device_id}".format( + hostname=cs_obj.get(cs.GATEWAY_HOST_NAME, default=cs_obj[cs.HOST_NAME]), + device_id=cs_obj[cs.DEVICE_ID], + ) + + client = await IoTHubDeviceClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + # SymmetricKeySigningMechanism was created from the SharedAccessKey + assert spy_sk_sm_cls.call_count == 1 + assert spy_sk_sm_cls.call_args == mocker.call(cs_obj[cs.SHARED_ACCESS_KEY]) + # InternalSasTokenGenerator was created from the SymmetricKeySigningMechanism and expected URI + assert spy_st_generator_cls.call_count == 1 + assert spy_st_generator_cls.call_args == mocker.call( + signing_mechanism=spy_sk_sm_cls.spy_return, uri=expected_token_uri + ) + # SasTokenProvider was created from the InternalSasTokenGenerator + assert spy_st_provider_create.call_count == 1 + assert spy_st_provider_create.call_args == mocker.call(spy_st_generator_cls.spy_return) + # The SasTokenProvider was set on the IoTHubClientConfig that was used to instantiate the client + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.sastoken_provider is spy_st_provider_create.spy_return + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Does not set any SasTokenProvider on the IoTHubClientConfig used to create the client if no `SharedAccessKey` is present in the connection string" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_no_sak) + async def test_non_sas_auth(self, mocker, connection_string, ssl_context): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + # Create a ConnectionString object from the connection string to simply value access + cs_obj = cs.ConnectionString(connection_string) + assert cs.SHARED_ACCESS_KEY not in cs_obj + + client = await IoTHubDeviceClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + # No SasTokenProvider was set on the IoTHubClientConfig that was used to instantiate the client + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.sastoken_provider is None + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets any provided optional keyword arguments on IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params) + @pytest.mark.parametrize("kwarg_name, kwarg_value", factory_kwargs) + async def test_kwargs(self, mocker, connection_string, ssl_context, kwarg_name, kwarg_value): + spy_client_init = mocker.spy(IoTHubDeviceClient, "__init__") + + kwargs = {kwarg_name: kwarg_value} + + client = await IoTHubDeviceClient.create_from_connection_string( + connection_string, ssl_context=ssl_context, **kwargs + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert getattr(config, kwarg_name) == kwarg_value + + # Graceful exit + await client.shutdown() + + @pytest.mark.it("Raises ValueError if a `ModuleId` is present in the connection string") + async def test_module_id_in_string(self, optional_ssl_context): + # NOTE: There could be many strings containing a ModuleId, but I'm not going to try them + # all to avoid confounds with other errors, I'll just use a standard module string that + # uses a SharedAccessKey + connection_string = "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + shared_access_key=FAKE_SYMMETRIC_KEY, + ) + with pytest.raises(ValueError): + await IoTHubDeviceClient.create_from_connection_string( + connection_string, ssl_context=optional_ssl_context + ) + + @pytest.mark.it( + "Raises ValueError if `x509=true` is present in the connection string, but no `ssl_context` is provided" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_no_sak) + async def test_x509_with_no_ssl(self, connection_string, ssl_context): + # Ignore the ssl_context provided by the parametrization + with pytest.raises(ValueError): + await IoTHubDeviceClient.create_from_connection_string(connection_string) + + @pytest.mark.it( + "Does not raise a ValueError if `x509=false` is present in the connection string and no `ssl_context` is provided" + ) + async def test_x509_equals_false(self): + # NOTE: This is a weird test in that if you aren't using X509 certs, there shouldn't be + # an `x509` field in your connection string in the first place. But, semantically, it feels + # as though this test ought to exist to validate that we are checking the value of the + # field, not just the key name. + # NOTE: Because we're in the land of undefined behavior here, on account of this scenario + # not being supposed to happen, I'm arbitrarily deciding we're testing this with a string + # containing a SharedAccessKey and no GatewayHostName for simplicity. + connection_string = "HostName={hostname};DeviceId={device_id};SharedAccessKey={shared_access_key};x509=false".format( + hostname=FAKE_HOSTNAME, device_id=FAKE_DEVICE_ID, shared_access_key=FAKE_SYMMETRIC_KEY + ) + client = await IoTHubDeviceClient.create_from_connection_string(connection_string) + # If the above invocation didn't raise, the test passed, no assertions required + + # Graceful exit + await client.shutdown() + + @pytest.mark.it("Allows any exceptions raised when parsing the connection string to propagate") + @pytest.mark.parametrize( + "exception", + [ + pytest.param(ValueError(), id="ValueError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_cs_parsing_raises(self, mocker, optional_ssl_context, exception): + # NOTE: This test covers all invalid connection string scenarios. For more detail, see the + # dedicated connection string parsing tests for the `connection_string.py` module - there's + # no reason to replicate them all here. + # NOTE: For the purposes of this test, it does not matter what this connection string is. + # The one provided here is valid, but the mock will cause the parsing to raise anyway. + connection_string = ( + "HostName={hostname};DeviceId={device_id};SharedAccessKey={shared_access_key}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + shared_access_key=FAKE_SYMMETRIC_KEY, + ) + ) + # Mock cs parsing + mocker.patch.object(cs, "ConnectionString", side_effect=exception) + + with pytest.raises(type(exception)) as e_info: + await IoTHubDeviceClient.create_from_connection_string( + connection_string, ssl_context=optional_ssl_context + ) + assert e_info.value is exception + + @pytest.mark.it( + "Allows any exceptions raised when creating a SymmetricKeySigningMechanism to propagate" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_sak) + @pytest.mark.parametrize("exception", sk_sm_create_exceptions) + async def test_sksm_raises(self, mocker, connection_string, ssl_context, exception): + mocker.patch.object(sm, "SymmetricKeySigningMechanism", side_effect=exception) + + with pytest.raises(type(exception)) as e_info: + await IoTHubDeviceClient.create_from_connection_string( + connection_string, + ssl_context=ssl_context, + ) + assert e_info.value is exception + + @pytest.mark.it("Allows any exceptions raised when creating a SasTokenProvider to propagate") + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_sak) + @pytest.mark.parametrize("exception", sastoken_provider_create_exceptions) + async def test_sastoken_provider_raises( + self, mocker, connection_string, ssl_context, exception + ): + mocker.patch.object(st.SasTokenProvider, "create_from_generator", side_effect=exception) + + with pytest.raises(type(exception)) as e_info: + await IoTHubDeviceClient.create_from_connection_string( + connection_string, + ssl_context=ssl_context, + ) + assert e_info.value is exception + + @pytest.mark.it("Can be cancelled while waiting for SasTokenProvider creation") + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_sak) + async def test_cancel_during_sastoken_provider_creation( + self, mocker, connection_string, ssl_context + ): + mocker.patch.object( + st.SasTokenProvider, "create_from_generator", custom_mock.HangingAsyncMock() + ) + + coro = IoTHubDeviceClient.create_from_connection_string( + connection_string, + ssl_context=ssl_context, + ) + t = asyncio.create_task(coro) + + # Hanging, waiting for SasTokenProvider creation to finish + await st.SasTokenProvider.create_from_generator.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("IoTHubDeviceClient - .shutdown()") +class TestIoTHubDeviceClientShutdown(SharedClientShutdownTests, IoTHubDeviceClientTestConfig): + pass + + +# ~~~~~ IoTHubModuleClient Tests ~~~~~ + + +class IoTHubModuleClientTestConfig: + """Mixin parent class defining a set of fixtures used in IoTHubModuleClient tests""" + + @pytest.fixture + async def client(self, custom_ssl_context): + # Use a custom_ssl_context for auth for simplicity. Almost any test using this fixture + # will not be affected by auth type, so just use the simplest one. + client_config = config.IoTHubClientConfig( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + ssl_context=custom_ssl_context, + ) + client = IoTHubDeviceClient(client_config) + yield client + await client.shutdown() + + @pytest.fixture + def client_class(self): + return IoTHubModuleClient + + +@pytest.mark.describe("IoTHubModuleClient -- Instantiation") +class TestIoTHubModuleClientInstantiation( + SharedClientInstantiationTests, IoTHubModuleClientTestConfig +): + pass + + +@pytest.mark.describe("IoTHubModuleClient - .create()") +class TestIoTHubModuleClientCreate(IoTHubModuleClientTestConfig): + @pytest.mark.it( + "Returns a new IoTHubModuleClient instance, created with the use of a new IoTHubClientConfig object" + ) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params) + async def test_instantiation(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_config_cls = mocker.spy(config, "IoTHubClientConfig") + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + assert spy_config_cls.call_count == 0 + assert spy_client_init.call_count == 0 + + client = await IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + assert spy_config_cls.call_count == 1 + assert spy_client_init.call_count == 1 + # NOTE: Normally passing through self or cls isn't necessary in a mock call, but + # it seems that when mocking the __init__ it is. This is actually good though, as it + # allows us to match the specific object reference which otherwise is very dicey when + # mocking constructors/initializers + assert spy_client_init.call_args == mocker.call(client, spy_config_cls.spy_return) + assert isinstance(client, IoTHubModuleClient) + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the provided `device_id` on the IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params) + async def test_device_id(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + + client = await IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.device_id == FAKE_DEVICE_ID + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the provided `module_id` on the IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params) + async def test_module_id(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + + client = await IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.module_id == FAKE_MODULE_ID + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the provided `hostname` on the IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params) + async def test_hostname(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + + client = await IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.hostname == FAKE_HOSTNAME + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the provided `ssl_context` on the IoTHubClientConfig used to create the client, if provided" + ) + @pytest.mark.parametrize( + "symmetric_key, sastoken_fn, ssl_context", create_auth_params_custom_ssl + ) + async def test_custom_ssl_context(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + assert ssl_context is not None + + client = await IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=custom_ssl_context, + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.ssl_context is custom_ssl_context + + # Graceful exit + await client.shutdown() + + # NOTE: The details of this default SSLContext are covered in the TestDefaultSSLContext suite + @pytest.mark.it( + "Sets a default SSLContext on the IoTHubClientConfig used to create the client, if `ssl_context` is not provided" + ) + @pytest.mark.parametrize( + "symmetric_key, sastoken_fn, ssl_context", create_auth_params_default_ssl + ) + async def test_default_ssl_context(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + spy_default_ssl = mocker.spy(iothub_client, "_default_ssl_context") + assert ssl_context is None + + client = await IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.ssl_context is not None + assert isinstance(config.ssl_context, ssl.SSLContext) + # This SSLContext was returned from the default ssl context helper + assert spy_default_ssl.call_count == 1 + assert spy_default_ssl.call_args == mocker.call() + assert config.ssl_context is spy_default_ssl.spy_return + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Creates a SasTokenProvider that uses symmetric key-based token generation and sets it on the IoTHubClientConfig used to create the client, if `symmetric_key` is provided as a parameter" + ) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params_sk) + async def test_sk_auth(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + spy_sk_sm_cls = mocker.spy(sm, "SymmetricKeySigningMechanism") + spy_st_generator_cls = mocker.spy(st, "InternalSasTokenGenerator") + spy_st_provider_create = mocker.spy(st.SasTokenProvider, "create_from_generator") + expected_token_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=FAKE_HOSTNAME, device_id=FAKE_DEVICE_ID, module_id=FAKE_MODULE_ID + ) + assert sastoken_fn is None + + client = await IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + ssl_context=ssl_context, + ) + + # SymmetricKeySigningMechanism was created from the symmetric key + assert spy_sk_sm_cls.call_count == 1 + assert spy_sk_sm_cls.call_args == mocker.call(FAKE_SYMMETRIC_KEY) + # InternalSasTokenGenerator was created from the SymmetricKeySigningMechanism and expected URI + assert spy_st_generator_cls.call_count == 1 + assert spy_st_generator_cls.call_args == mocker.call( + signing_mechanism=spy_sk_sm_cls.spy_return, uri=expected_token_uri + ) + # SasTokenProvider was created from the InternalSasTokenGenerator + assert spy_st_provider_create.call_count == 1 + assert spy_st_provider_create.call_args == mocker.call(spy_st_generator_cls.spy_return) + # The SasTokenProvider was set on the IoTHubClientConfig that was used to instantiate the client + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.sastoken_provider is spy_st_provider_create.spy_return + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Creates a SasTokenProvider that uses user callback-based token generation and sets it on the IoTHubClientConfig used to create the client, if `sastoken_fn` is provided as a parameter" + ) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params_token_cb) + async def test_token_callback_auth(self, mocker, symmetric_key, sastoken_fn, ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + spy_st_generator_cls = mocker.spy(st, "ExternalSasTokenGenerator") + spy_st_provider_create = mocker.spy(st.SasTokenProvider, "create_from_generator") + assert symmetric_key is None + + client = await IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + + # ExternalSasTokenGenerator was created from the `sastoken_fn`` + assert spy_st_generator_cls.call_count == 1 + assert spy_st_generator_cls.call_args == mocker.call(sastoken_generator_fn) + # SasTokenProvider was created from the ExternalSasTokenGenerator + assert spy_st_provider_create.call_count == 1 + assert spy_st_provider_create.call_args == mocker.call(spy_st_generator_cls.spy_return) + # The SasTokenProvider was set on the IoTHubClientConfig that was used to instantiate the client + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.sastoken_provider is spy_st_provider_create.spy_return + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Does not set any SasTokenProvider on the IoTHubClientConfig used to create the client if neither `symmetric_key` nor `sastoken_fn` are provided as parameters" + ) + async def test_non_sas_auth(self, mocker, custom_ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + + client = await IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + ssl_context=custom_ssl_context, + ) + + # No SasTokenProvider was set on the IoTHubClientConfig that was used to instantiate the client + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.sastoken_provider is None + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets any provided optional keyword arguments on IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params) + @pytest.mark.parametrize("kwarg_name, kwarg_value", factory_kwargs) + async def test_kwargs( + self, mocker, symmetric_key, sastoken_fn, ssl_context, kwarg_name, kwarg_value + ): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + + kwargs = {kwarg_name: kwarg_value} + + client = await IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + **kwargs + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert getattr(config, kwarg_name) == kwarg_value + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Raises ValueError if neither `symmetric_key` nor `sastoken_fn` nor `ssl_context` are provided as parameters" + ) + async def test_no_auth(self): + with pytest.raises(ValueError): + await IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + ) + + @pytest.mark.it( + "Raises ValueError if both `symmetric_key` and `sastoken_fn` are provided as parameters" + ) + async def test_conflicting_auth(self, optional_ssl_context): + with pytest.raises(ValueError): + await IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=FAKE_SYMMETRIC_KEY, + sastoken_fn=sastoken_generator_fn, + ssl_context=optional_ssl_context, + ) + + @pytest.mark.it( + "Allows any exceptions raised when creating a SymmetricKeySigningMechanism to propagate" + ) + @pytest.mark.parametrize("exception", sk_sm_create_exceptions) + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params_sk) + async def test_sksm_raises(self, mocker, symmetric_key, sastoken_fn, ssl_context, exception): + mocker.patch.object(sm, "SymmetricKeySigningMechanism", side_effect=exception) + assert sastoken_fn is None + + with pytest.raises(type(exception)) as e_info: + await IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + ssl_context=ssl_context, + ) + assert e_info.value is exception + + @pytest.mark.it("Allows any exceptions raised when creating a SasTokenProvider to propagate") + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params_sas) + @pytest.mark.parametrize("exception", sastoken_provider_create_exceptions) + async def test_sastoken_provider_raises( + self, mocker, symmetric_key, sastoken_fn, ssl_context, exception + ): + mocker.patch.object(st.SasTokenProvider, "create_from_generator", side_effect=exception) + + with pytest.raises(type(exception)) as e_info: + await IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + assert e_info.value is exception + + @pytest.mark.it("Can be cancelled while waiting for SasTokenProvider creation") + @pytest.mark.parametrize("symmetric_key, sastoken_fn, ssl_context", create_auth_params_sas) + async def test_cancel_during_sastoken_provider_creation( + self, mocker, symmetric_key, sastoken_fn, ssl_context + ): + mocker.patch.object( + st.SasTokenProvider, "create_from_generator", custom_mock.HangingAsyncMock() + ) + + coro = IoTHubModuleClient.create( + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + hostname=FAKE_HOSTNAME, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + ssl_context=ssl_context, + ) + t = asyncio.create_task(coro) + + # Hanging, waiting for SasTokenProvider creation to finish + await st.SasTokenProvider.create_from_generator.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("IoTHubModuleClient - .create_from_connection_string()") +class TestIoTHubModuleClientCreateFromConnectionString(IoTHubModuleClientTestConfig): + + factory_params = [ + pytest.param( + "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + shared_access_key=FAKE_SYMMETRIC_KEY, + ), + None, + id="Standard Connection String w/ SharedAccessKey + Default SSLContext", + ), + pytest.param( + "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + shared_access_key=FAKE_SYMMETRIC_KEY, + ), + lazy_fixture("custom_ssl_context"), + id="Standard Connection String w/ SharedAccessKey + Custom SSLContext", + ), + pytest.param( + "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key};GatewayHostName={gateway_hostname}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + shared_access_key=FAKE_SYMMETRIC_KEY, + gateway_hostname=FAKE_GATEWAY_HOSTNAME, + ), + None, + id="Edge Connection String w/ SharedAccessKey + Default SSLContext", + ), + pytest.param( + "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key};GatewayHostName={gateway_hostname}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + shared_access_key=FAKE_SYMMETRIC_KEY, + gateway_hostname=FAKE_GATEWAY_HOSTNAME, + ), + lazy_fixture("custom_ssl_context"), + id="Edge Connection String w/ SharedAccessKey + Custom SSLContext", + ), + # NOTE: X509 certs imply use of custom SSLContext + pytest.param( + "HostName={hostname};DeviceId={device_id};ModuleId={module_id};x509=true".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + ), + lazy_fixture("custom_ssl_context"), + id="Standard Connection String w/ X509", + ), + pytest.param( + "HostName={hostname};DeviceId={device_id};ModuleId={module_id};GatewayHostName={gateway_hostname};x509=true".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + gateway_hostname=FAKE_GATEWAY_HOSTNAME, + ), + lazy_fixture("custom_ssl_context"), + id="Edge Connection String w/ X509", + ), + ] + # Just the parameters for using standard connection strings + factory_params_no_gateway = [ + param for param in factory_params if cs.GATEWAY_HOST_NAME not in param.values[0] + ] + # Just the parameters for using connection strings with a GatewayHostName + factory_params_gateway = [ + param for param in factory_params if cs.GATEWAY_HOST_NAME in param.values[0] + ] + # Just the parameters where a custom SSLContext is provided + factory_params_custom_ssl = [param for param in factory_params if param.values[1] is not None] + # Just the parameters where a custom SSLContext is NOT provided + factory_params_default_ssl = [param for param in factory_params if param.values[1] is None] + # Just the parameters for using SharedAccessKeys + factory_params_sak = [ + param for param in factory_params if cs.SHARED_ACCESS_KEY in param.values[0] + ] + # Just the parameters for NOT using SharedAccessKeys + factory_params_no_sak = [ + param for param in factory_params if cs.SHARED_ACCESS_KEY not in param.values[0] + ] + + @pytest.mark.it( + "Returns a new IoTHubModuleClient instance, created with the use of a new IoTHubClientConfig object" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params) + async def test_instantiation(self, mocker, connection_string, ssl_context): + spy_config_cls = mocker.spy(config, "IoTHubClientConfig") + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + assert spy_config_cls.call_count == 0 + assert spy_client_init.call_count == 0 + + client = await IoTHubModuleClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + assert spy_config_cls.call_count == 1 + assert spy_client_init.call_count == 1 + # NOTE: Normally passing through self or cls isn't necessary in a mock call, but + # it seems that when mocking the __init__ it is. This is actually good though, as it + # allows us to match the specific object reference which otherwise is very dicey when + # mocking constructors/initializers + assert spy_client_init.call_args == mocker.call(client, spy_config_cls.spy_return) + assert isinstance(client, IoTHubModuleClient) + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the `DeviceId` from the connection string as the `device_id` on the IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params) + async def test_device_id(self, mocker, connection_string, ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + # Create a ConnectionString object from the connection string to simply value access + cs_obj = cs.ConnectionString(connection_string) + + client = await IoTHubModuleClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.device_id == cs_obj[cs.DEVICE_ID] + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the `ModuleId` from the connection string as the `module_id` on the IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params) + async def test_module_id(self, mocker, connection_string, ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + # Create a ConnectionString object from the connection string to simply value access + cs_obj = cs.ConnectionString(connection_string) + + client = await IoTHubModuleClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.module_id == cs_obj[cs.MODULE_ID] + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the `HostName` from the connection string as the `hostname` on the IoTHubClientConfig, if no `GatewayHostName` is present in the connection string" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_no_gateway) + async def test_hostname_cs_has_no_gateway(self, mocker, connection_string, ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + # Create a ConnectionString object from the connection string to simply value access + cs_obj = cs.ConnectionString(connection_string) + assert cs.GATEWAY_HOST_NAME not in cs_obj + + client = await IoTHubModuleClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.hostname == cs_obj[cs.HOST_NAME] + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the `HostName` from the connection string as the `hostname` on the IoTHubClientConfig used to create the client, if no `GatewayHostName` is present in the connection string" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_gateway) + async def test_hostname_cs_has_gateway(self, mocker, connection_string, ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + # Create a ConnectionString object from the connection string to simply value access + cs_obj = cs.ConnectionString(connection_string) + assert cs.GATEWAY_HOST_NAME in cs_obj + assert cs_obj[cs.GATEWAY_HOST_NAME] != cs_obj[cs.HOST_NAME] + + client = await IoTHubModuleClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.hostname == cs_obj[cs.GATEWAY_HOST_NAME] + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the provided `ssl_context` on the IoTHubClientConfig used to create the client, if provided" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_custom_ssl) + async def test_custom_ssl_context(self, mocker, connection_string, ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + assert ssl_context is not None + + client = await IoTHubModuleClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.ssl_context is ssl_context + + # Graceful exit + await client.shutdown() + + # NOTE: The details of this default SSLContext are covered in the TestDefaultSSLContext suite + @pytest.mark.it( + "Sets a default SSLContext as the `ssl_context` on the IoTHubClientConfig used to create the client, if `ssl_context` is not provided" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_default_ssl) + async def test_default_ssl_context(self, mocker, connection_string, ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + spy_default_ssl = mocker.spy(iothub_client, "_default_ssl_context") + assert ssl_context is None + + client = await IoTHubModuleClient.create_from_connection_string(connection_string) + + assert spy_default_ssl.call_count == 1 + assert spy_default_ssl.call_args == mocker.call() + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.ssl_context is spy_default_ssl.spy_return + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Creates a SasTokenProvider that uses symmetric key-based token generation and sets it on the IoTHubClientConfig used to create the client, if `SharedAccessKey` is present in the connection string" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_sak) + async def test_sk_auth(self, mocker, connection_string, ssl_context): + # Create a ConnectionString object from the connection string to simply value access + cs_obj = cs.ConnectionString(connection_string) + assert cs.SHARED_ACCESS_KEY in cs_obj + # Mock + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + spy_sk_sm_cls = mocker.spy(sm, "SymmetricKeySigningMechanism") + spy_st_generator_cls = mocker.spy(st, "InternalSasTokenGenerator") + spy_st_provider_create = mocker.spy(st.SasTokenProvider, "create_from_generator") + expected_token_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=cs_obj.get(cs.GATEWAY_HOST_NAME, default=cs_obj[cs.HOST_NAME]), + device_id=cs_obj[cs.DEVICE_ID], + module_id=cs_obj[cs.MODULE_ID], + ) + + client = await IoTHubModuleClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + # SymmetricKeySigningMechanism was created from the SharedAccessKey + assert spy_sk_sm_cls.call_count == 1 + assert spy_sk_sm_cls.call_args == mocker.call(cs_obj[cs.SHARED_ACCESS_KEY]) + # InternalSasTokenGenerator was created from the SymmetricKeySigningMechanism and expected URI + assert spy_st_generator_cls.call_count == 1 + assert spy_st_generator_cls.call_args == mocker.call( + signing_mechanism=spy_sk_sm_cls.spy_return, uri=expected_token_uri + ) + # SasTokenProvider was created from the InternalSasTokenGenerator + assert spy_st_provider_create.call_count == 1 + assert spy_st_provider_create.call_args == mocker.call(spy_st_generator_cls.spy_return) + # The SasTokenProvider was set on the IoTHubClientConfig that was used to instantiate the client + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.sastoken_provider is spy_st_provider_create.spy_return + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Does not set any SasTokenProvider on the IoTHubClientConfig used to create the client if no `SharedAccessKey` is present in the connection string" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_no_sak) + async def test_non_sas_auth(self, mocker, connection_string, ssl_context): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + # Create a ConnectionString object from the connection string to simply value access + cs_obj = cs.ConnectionString(connection_string) + assert cs.SHARED_ACCESS_KEY not in cs_obj + + client = await IoTHubModuleClient.create_from_connection_string( + connection_string, ssl_context=ssl_context + ) + + # No SasTokenProvider was set on the IoTHubClientConfig that was used to instantiate the client + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.sastoken_provider is None + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets any provided optional keyword arguments on IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params) + @pytest.mark.parametrize("kwarg_name, kwarg_value", factory_kwargs) + async def test_kwargs(self, mocker, connection_string, ssl_context, kwarg_name, kwarg_value): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + + kwargs = {kwarg_name: kwarg_value} + + client = await IoTHubModuleClient.create_from_connection_string( + connection_string, ssl_context=ssl_context, **kwargs + ) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert getattr(config, kwarg_name) == kwarg_value + + # Graceful exit + await client.shutdown() + + @pytest.mark.it("Raises ValueError if a `ModuleId` is not present in the connection string") + async def test_module_id_in_string(self, optional_ssl_context): + # NOTE: There could be many strings containing not containing a ModuleId, but I'm not going + # to try them all to avoid confounds with other errors, I'll just use a standard device + # string that uses a SharedAccessKey + connection_string = ( + "HostName={hostname};DeviceId={device_id};SharedAccessKey={shared_access_key}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + shared_access_key=FAKE_SYMMETRIC_KEY, + ) + ) + with pytest.raises(ValueError): + await IoTHubModuleClient.create_from_connection_string( + connection_string, ssl_context=optional_ssl_context + ) + + @pytest.mark.it( + "Raises ValueError if `x509=true` is present in the connection string, but no `ssl_context` is provided" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_no_sak) + async def test_x509_with_no_ssl(self, connection_string, ssl_context): + # Ignore the ssl_context provided by the parametrization + with pytest.raises(ValueError): + await IoTHubModuleClient.create_from_connection_string(connection_string) + + @pytest.mark.it( + "Does not raise a ValueError if `x509=false` is present in the connection string and no `ssl_context` is provided" + ) + async def test_x509_equals_false(self): + # NOTE: This is a weird test in that if you aren't using X509 certs, there shouldn't be + # an `x509` field in your connection string in the first place. But, semantically, it feels + # as though this test ought to exist to validate that we are checking the value of the + # field, not just the key name. + # NOTE: Because we're in the land of undefined behavior here, on account of this scenario + # not being supposed to happen, I'm arbitrarily deciding we're testing this with a string + # containing a SharedAccessKey and no GatewayHostName for simplicity. + connection_string = "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key};x509=false".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + shared_access_key=FAKE_SYMMETRIC_KEY, + ) + client = await IoTHubModuleClient.create_from_connection_string(connection_string) + # If the above invocation didn't raise, the test passed, no assertions required + + # Graceful exit + await client.shutdown() + + @pytest.mark.it("Allows any exceptions raised when parsing the connection string to propagate") + @pytest.mark.parametrize( + "exception", + [ + pytest.param(ValueError(), id="ValueError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_cs_parsing_raises(self, mocker, optional_ssl_context, exception): + # NOTE: This test covers all invalid connection string scenarios. For more detail, see the + # dedicated connection string parsing tests for the `connection_string.py` module - there's + # no reason to replicate them all here. + # NOTE: For the purposes of this test, it does not matter what this connection string is. + # The one provided here is valid, but the mock will cause the parsing to raise anyway. + connection_string = "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + shared_access_key=FAKE_SYMMETRIC_KEY, + ) + # Mock cs parsing + mocker.patch.object(cs, "ConnectionString", side_effect=exception) + + with pytest.raises(type(exception)) as e_info: + await IoTHubModuleClient.create_from_connection_string( + connection_string, ssl_context=optional_ssl_context + ) + assert e_info.value is exception + + @pytest.mark.it( + "Allows any exceptions raised when creating a SymmetricKeySigningMechanism to propagate" + ) + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_sak) + @pytest.mark.parametrize("exception", sk_sm_create_exceptions) + async def test_sksm_raises(self, mocker, connection_string, ssl_context, exception): + mocker.patch.object(sm, "SymmetricKeySigningMechanism", side_effect=exception) + + with pytest.raises(type(exception)) as e_info: + await IoTHubModuleClient.create_from_connection_string( + connection_string, + ssl_context=ssl_context, + ) + assert e_info.value is exception + + @pytest.mark.it("Allows any exceptions raised when creating a SasTokenProvider to propagate") + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_sak) + @pytest.mark.parametrize("exception", sastoken_provider_create_exceptions) + async def test_sastoken_provider_raises( + self, mocker, connection_string, ssl_context, exception + ): + mocker.patch.object(st.SasTokenProvider, "create_from_generator", side_effect=exception) + + with pytest.raises(type(exception)) as e_info: + await IoTHubModuleClient.create_from_connection_string( + connection_string, + ssl_context=ssl_context, + ) + assert e_info.value is exception + + @pytest.mark.it("Can be cancelled while waiting for SasTokenProvider creation") + @pytest.mark.parametrize("connection_string, ssl_context", factory_params_sak) + async def test_cancel_during_sastoken_provider_creation( + self, mocker, connection_string, ssl_context + ): + mocker.patch.object( + st.SasTokenProvider, "create_from_generator", custom_mock.HangingAsyncMock() + ) + + coro = IoTHubModuleClient.create_from_connection_string( + connection_string, + ssl_context=ssl_context, + ) + t = asyncio.create_task(coro) + + # Hanging, waiting for SasTokenProvider creation to finish + await st.SasTokenProvider.create_from_generator.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe( + "IoTHubModuleClient - .create_from_edge_environment() -- Real Edge Environment" +) +class TestIoTHubModuleClientCreateFromEdgeEnvironmentRealEdgeEnvironment( + IoTHubModuleClientTestConfig +): + @pytest.fixture + def edge_environment_variables(self): + return { + "IOTEDGE_DEVICEID": FAKE_DEVICE_ID, + "IOTEDGE_MODULEID": FAKE_MODULE_ID, + "IOTEDGE_IOTHUBHOSTNAME": FAKE_HOSTNAME, + "IOTEDGE_GATEWAYHOSTNAME": FAKE_GATEWAY_HOSTNAME, + "IOTEDGE_APIVERSION": "04-07-3023", + "IOTEDGE_MODULEGENERATIONID": "fake_generation_id", + "IOTEDGE_WORKLOADURI": "http://fake.workload/uri/", + # NOTE: I've included the IOTHUBHOSTNAME environment variable here, + # even though it is not actually used in practice by the client. + # By including it here, we can demonstrate that it is not used. + } + + @pytest.fixture(autouse=True) + def mock_environment_variables(self, mocker, edge_environment_variables): + """Auto-used fixture that will mock out os.environ to return the variables defined + in the fixture above. You shouldn't need to directly interact with this mock, so + no value is returned by this fixture, and as a result, you shouldn't ever need to + add it as a test parameter. It will just work""" + mocker.patch.dict(os.environ, edge_environment_variables, clear=True) + + @pytest.fixture(autouse=True) + def mock_ssl_load_verify_locations(self, mocker): + """Autouse fixture that will mock SSL cert chain loading so that fake values don't + get in the way. You shouldn't need to directly interact with this mock, so no value + is returned by this fixture, and as a result, you should'nt ever need to add it as a + test parameter. It will just work + """ + mocker.patch.object(ssl.SSLContext, "load_verify_locations") + + @pytest.fixture(autouse=True) + def mock_edge_hsm_cls(self, mocker): + mock_edge_hsm_cls = mocker.patch.object(edge_hsm, "IoTEdgeHsm", spec=edge_hsm.IoTEdgeHsm) + mock_edge_hsm_cls.return_value.sign.return_value = FAKE_SIGNATURE + mock_edge_hsm_cls.return_value.get_certificate.return_value = "fake_svc_string" + return mock_edge_hsm_cls + + @pytest.mark.it( + "Returns a new IoTHubModuleClient instance, created with the use of a new IoTHubClientConfig object" + ) + async def test_instantiation(self, mocker): + spy_config_cls = mocker.spy(config, "IoTHubClientConfig") + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + assert spy_config_cls.call_count == 0 + assert spy_client_init.call_count == 0 + + client = await IoTHubModuleClient.create_from_edge_environment() + + assert spy_config_cls.call_count == 1 + assert spy_client_init.call_count == 1 + # NOTE: Normally passing through self or cls isn't necessary in a mock call, but + # it seems that when mocking the __init__ it is. This is actually good though, as it + # allows us to match the specific object reference which otherwise is very dicey when + # mocking constructors/initializers + assert spy_client_init.call_args == mocker.call(client, spy_config_cls.spy_return) + assert isinstance(client, IoTHubModuleClient) + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the IOTEDGE_DEVICEID value from the Edge environment as the `device_id` on the IoTHubClientConfig used to create the client" + ) + async def test_device_id(self, mocker, edge_environment_variables): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + + client = await IoTHubModuleClient.create_from_edge_environment() + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.device_id == edge_environment_variables["IOTEDGE_DEVICEID"] + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the IOTEDGE_MODULEID value from the Edge environment as the `module_id` on the IoTHubClientConfig used to create the client" + ) + async def test_module_id(self, mocker, edge_environment_variables): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + + client = await IoTHubModuleClient.create_from_edge_environment() + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.module_id == edge_environment_variables["IOTEDGE_MODULEID"] + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets the IOTEDGE_GATEWAYHOSTNAME (and NOT the IOTEDGE_IOTHUBHOSTNAME) value from the Edge environment as the `hostname` on the IoTHubClientConfig used to create the client" + ) + async def test_hostname(self, mocker, edge_environment_variables): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + + client = await IoTHubModuleClient.create_from_edge_environment() + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.hostname == edge_environment_variables["IOTEDGE_GATEWAYHOSTNAME"] + assert config.hostname != edge_environment_variables["IOTEDGE_IOTHUBHOSTNAME"] + + # Graceful exit + await client.shutdown() + + @pytest.mark.it("Creates an IoTEdgeHsm using values from the Edge environment") + async def test_edge_hsm(self, mocker, edge_environment_variables, mock_edge_hsm_cls): + assert mock_edge_hsm_cls.call_count == 0 + + client = await IoTHubModuleClient.create_from_edge_environment() + + assert mock_edge_hsm_cls.call_count == 1 + assert mock_edge_hsm_cls.call_args == mocker.call( + module_id=edge_environment_variables["IOTEDGE_MODULEID"], + generation_id=edge_environment_variables["IOTEDGE_MODULEGENERATIONID"], + workload_uri=edge_environment_variables["IOTEDGE_WORKLOADURI"], + api_version=edge_environment_variables["IOTEDGE_APIVERSION"], + ) + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Creates a SasTokenProvider that uses the IoTEdgeHsm to generate SAS tokens, and sets it as the `sastoken_provider` on the IoTHubClientConfig used to create the client" + ) + async def test_sastoken_provider(self, mocker, edge_environment_variables, mock_edge_hsm_cls): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + spy_st_generator_cls = mocker.spy(st, "InternalSasTokenGenerator") + spy_st_provider_create = mocker.spy(st.SasTokenProvider, "create_from_generator") + expected_token_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=edge_environment_variables["IOTEDGE_GATEWAYHOSTNAME"], + device_id=edge_environment_variables["IOTEDGE_DEVICEID"], + module_id=edge_environment_variables["IOTEDGE_MODULEID"], + ) + + client = await IoTHubModuleClient.create_from_edge_environment() + + # IoTEdgeHsm was created + assert mock_edge_hsm_cls.call_count == 1 + # InternalSasTokenGenerator was created from the IoTEdgeHsm and expected URI + assert spy_st_generator_cls.call_count == 1 + assert spy_st_generator_cls.call_args == mocker.call( + signing_mechanism=mock_edge_hsm_cls.return_value, uri=expected_token_uri + ) + # SasTokenProvider was created from the InternalSasTokenGenerator + assert spy_st_provider_create.call_count == 1 + assert spy_st_provider_create.call_args == mocker.call(spy_st_generator_cls.spy_return) + # The SasTokenProvider was set on the IoTHubClientConfig that was used to instantiate the client + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert config.sastoken_provider is spy_st_provider_create.spy_return + + # Graceful exit + await client.shutdown() + + # NOTE: The details of this default SSLContext are covered in the TestDefaultSSLContext suite + @pytest.mark.it( + "Modifies a default SSLContext by loading a server verification certificate retrieved from the IoTEdgeHsm and sets it as the `ssl_context` on the IoTHubClientConfig used to create the client" + ) + async def test_ssl_context(self, mocker, mock_edge_hsm_cls): + mock_edge_hsm = mock_edge_hsm_cls.return_value + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + mock_default_ssl = mocker.patch.object(iothub_client, "_default_ssl_context") + mock_ssl_context = mock_default_ssl.return_value + + client = await IoTHubModuleClient.create_from_edge_environment() + + assert mock_default_ssl.call_count == 1 + assert mock_default_ssl.call_args == mocker.call() + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + # SSLContext was set on the Config + config = spy_client_init.call_args[0][1] + assert config.ssl_context is mock_ssl_context + # SSLContext was modified to load the cert returned by the HSM + expected_sv_cert = mock_edge_hsm.get_certificate.return_value + assert mock_ssl_context.load_verify_locations.call_count == 1 + assert mock_ssl_context.load_verify_locations.call_args == mocker.call( + cadata=expected_sv_cert + ) + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Sets any provided optional keyword arguments on IoTHubClientConfig used to create the client" + ) + @pytest.mark.parametrize("kwarg_name, kwarg_value", factory_kwargs) + async def test_kwargs(self, mocker, kwarg_name, kwarg_value): + spy_client_init = mocker.spy(IoTHubModuleClient, "__init__") + + kwargs = {kwarg_name: kwarg_value} + + client = await IoTHubModuleClient.create_from_edge_environment(**kwargs) + + assert spy_client_init.call_count == 1 + assert spy_client_init.call_args == mocker.call(client, mocker.ANY) + config = spy_client_init.call_args[0][1] + assert getattr(config, kwarg_name) == kwarg_value + + # Graceful exit + await client.shutdown() + + # NOTE: For what happens when the simulator variables ARE present, see the + # TestIoTHubModuleClientCreateFromEdgeEnvironmentSimulatedEdgeEnvironment test suite. + @pytest.mark.it( + "Raises IoTEdgeEnvironmentError if any expected environment variables cannot be found in the Edge environment, and no Edge Simulator variables are present either" + ) + @pytest.mark.parametrize( + "missing_variable", + [ + "IOTEDGE_MODULEID", + "IOTEDGE_DEVICEID", + "IOTEDGE_GATEWAYHOSTNAME", + "IOTEDGE_APIVERSION", + "IOTEDGE_MODULEGENERATIONID", + "IOTEDGE_WORKLOADURI", + # NOTE: "IOTEDGE_IOTHUBHOSTNAME" is not listed here, because it is not required + ], + ) + async def test_env_missing_vars(self, mocker, edge_environment_variables, missing_variable): + # Remove variable from env and re-patch + del edge_environment_variables[missing_variable] + mocker.patch.dict(os.environ, edge_environment_variables, clear=True) + # No simulator variables are in the environment either + assert "EdgeHubConnectionString" not in edge_environment_variables + assert "EdgeModuleCACertificateFile" not in edge_environment_variables + + with pytest.raises(iot_exceptions.IoTEdgeEnvironmentError): + await IoTHubModuleClient.create_from_edge_environment() + + @pytest.mark.it("Allows any exceptions raised while creating the IoTEdgeHsm to propagate") + @pytest.mark.parametrize( + "exception", + [ + pytest.param(ValueError(), id="ValueError"), + pytest.param(TypeError(), id="TypeError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_edge_hsm_instantiation_raises(self, mock_edge_hsm_cls, exception): + # NOTE: Why might this raise? Lots of reasons, probably due to corrupted env variables + mock_edge_hsm_cls.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await IoTHubModuleClient.create_from_edge_environment() + assert e_info.value is exception + + @pytest.mark.it( + "Allows any exceptions raised while fetching the server verification cert using the IoTEdgeHsm to propagate" + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(iot_exceptions.IoTEdgeError(), id="IoTEdgeError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_edge_hsm_get_cert_raises(self, mock_edge_hsm_cls, exception): + mock_edge_hsm_cls.return_value.get_certificate.side_effect = exception + + with pytest.raises(type(exception)) as e_info: + await IoTHubModuleClient.create_from_edge_environment() + assert e_info.value is exception + + @pytest.mark.it( + "Allows any exceptions raised while loading the server verification cert to propagate" + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(ValueError(), id="ValueError"), + pytest.param(TypeError(), id="TypeError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_ssl_load_verify_locations_raises(self, mocker, exception): + mocker.patch.object(ssl.SSLContext, "load_verify_locations", side_effect=exception) + + with pytest.raises(type(exception)) as e_info: + await IoTHubModuleClient.create_from_edge_environment() + assert e_info.value is exception + + @pytest.mark.it("Allows any exceptions raised when creating a SasTokenProvider to propagate") + @pytest.mark.parametrize("exception", sastoken_provider_create_exceptions) + async def test_sastoken_provider_raises(self, mocker, exception): + mocker.patch.object(st.SasTokenProvider, "create_from_generator", side_effect=exception) + + with pytest.raises(type(exception)) as e_info: + await IoTHubModuleClient.create_from_edge_environment() + assert e_info.value is exception + + @pytest.mark.it( + "Can be cancelled while waiting for the server verification cert to be retrieved" + ) + async def test_cancel_during_get_certificate(self, mock_edge_hsm_cls): + mock_edge_hsm = mock_edge_hsm_cls.return_value + mock_edge_hsm.get_certificate = custom_mock.HangingAsyncMock() + + t = asyncio.create_task(IoTHubModuleClient.create_from_edge_environment()) + + # Hanging, waiting for certificate retrieval to finish + await mock_edge_hsm.get_certificate.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + @pytest.mark.it("Can be cancelled while waiting for SasTokenProvider creation") + async def test_cancel_during_sastoken_provider_creation(self, mocker): + mocker.patch.object( + st.SasTokenProvider, "create_from_generator", custom_mock.HangingAsyncMock() + ) + + t = asyncio.create_task(IoTHubModuleClient.create_from_edge_environment()) + + # Hanging, waiting for SasTokenProvider creation to finish + await st.SasTokenProvider.create_from_generator.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe( + "IoTHubModuleClient - .create_from_edge_environment() -- Simulated Edge Environment" +) +class TestIoTHubModuleClientCreateFromEdgeEnvironmentSimulatedEdgeEnvironment( + IoTHubModuleClientTestConfig +): + @pytest.fixture + def edge_environment_variables(self): + edge_cs = "HostName={hostname};DeviceId={device_id};ModuleId={module_id};SharedAccessKey={shared_access_key};GatewayHostName={gateway_hostname}".format( + hostname=FAKE_HOSTNAME, + device_id=FAKE_DEVICE_ID, + module_id=FAKE_MODULE_ID, + shared_access_key=FAKE_SYMMETRIC_KEY, + gateway_hostname=FAKE_GATEWAY_HOSTNAME, + ) + return { + "EdgeHubConnectionString": edge_cs, + "EdgeModuleCACertificateFile": "fake/file/path", + } + + @pytest.fixture(autouse=True) + def mock_environment_variables(self, mocker, edge_environment_variables): + """Auto-used fixture that will mock out os.environ to return the variables defined + in the fixture above. You shouldn't need to directly interact with this mock, so + no value is returned by this fixture, and as a result, you shouldn't ever need to + add it as a test parameter. It will just work""" + mocker.patch.dict(os.environ, edge_environment_variables, clear=True) + + @pytest.fixture(autouse=True) + def mock_ssl_load_verify_locations(self, mocker): + """Autouse fixture that will mock SSL cert chain loading so that fake values don't + get in the way. You shouldn't need to directly interact with this mock, so no value + is returned by this fixture, and as a result, you should'nt ever need to add it as a + test parameter. It will just work + """ + mocker.patch.object(ssl.SSLContext, "load_verify_locations") + + @pytest.mark.it( + "Invokes and returns the result of the .create_from_connection_string() factory method, passing a connection string contained in the `EdgeHubConnectionString` environment variable and a default SSLContext" + ) + async def test_invokes_connection_string_factory(self, mocker, edge_environment_variables): + spy_create_from_cs = mocker.spy(IoTHubModuleClient, "create_from_connection_string") + mock_default_ssl = mocker.patch.object(iothub_client, "_default_ssl_context") + + client = await IoTHubModuleClient.create_from_edge_environment() + + assert spy_create_from_cs.await_count == 1 + assert spy_create_from_cs.await_args == mocker.call( + edge_environment_variables["EdgeHubConnectionString"], + ssl_context=mock_default_ssl.return_value, + ) + assert client is spy_create_from_cs.spy_return + assert isinstance(client, IoTHubModuleClient) + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Modifies the default SSLContext by loading a server verification certificate from the filepath contained in the `EdgeModuleCACertificate` environment variable" + ) + async def test_ssl_context(self, mocker, edge_environment_variables): + spy_create_from_cs = mocker.spy(IoTHubModuleClient, "create_from_connection_string") + mock_default_ssl = mocker.patch.object(iothub_client, "_default_ssl_context") + mock_ssl_context = mock_default_ssl.return_value + + client = await IoTHubModuleClient.create_from_edge_environment() + + assert mock_default_ssl.call_count == 1 + assert mock_default_ssl.call_args == mocker.call() + # SSLContext was modified to the load the certfile in the environment variable + assert mock_ssl_context.load_verify_locations.call_count == 1 + assert mock_ssl_context.load_verify_locations.call_args == mocker.call( + cafile=edge_environment_variables["EdgeModuleCACertificateFile"] + ) + # SSLContext was the one passed to the .create_from_connection_string() factory method + assert spy_create_from_cs.await_count == 1 + assert spy_create_from_cs.await_args == mocker.call( + mocker.ANY, ssl_context=mock_ssl_context + ) + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Passes any provided optional keyword arguments to the .create_from_connection_string() factory method" + ) + @pytest.mark.parametrize("kwarg_name, kwarg_value", factory_kwargs) + async def test_kwargs(self, mocker, kwarg_name, kwarg_value): + spy_create_from_cs = mocker.spy(IoTHubModuleClient, "create_from_connection_string") + kwargs = {kwarg_name: kwarg_value} + + client = await IoTHubModuleClient.create_from_edge_environment(**kwargs) + + assert spy_create_from_cs.await_count == 1 + assert spy_create_from_cs.await_args == mocker.call( + mocker.ANY, ssl_context=mocker.ANY, **kwargs + ) + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Does not invoke .create_from_connection_string() at all if real Edge environment variables are found" + ) + async def test_real_env_variables_found(self, mocker): + spy_create_from_cs = mocker.spy(IoTHubModuleClient, "create_from_connection_string") + # Mock Edge HSM from the Real Edge path, since it's where we're going to go + mock_edge_hsm_cls = mocker.patch.object(edge_hsm, "IoTEdgeHsm", spec=edge_hsm.IoTEdgeHsm) + mock_edge_hsm_cls.return_value.sign.return_value = FAKE_SIGNATURE + mock_edge_hsm_cls.return_value.get_certificate.return_value = "fake_svc_string" + # Add the Real Edge env vars to our environment + real_env_vars = { + "IOTEDGE_DEVICEID": FAKE_DEVICE_ID, + "IOTEDGE_MODULEID": FAKE_MODULE_ID, + "IOTEDGE_IOTHUBHOSTNAME": FAKE_HOSTNAME, + "IOTEDGE_GATEWAYHOSTNAME": FAKE_GATEWAY_HOSTNAME, + "IOTEDGE_APIVERSION": "04-07-3023", + "IOTEDGE_MODULEGENERATIONID": "fake_generation_id", + "IOTEDGE_WORKLOADURI": "http://fake.workload/uri/", + } + mocker.patch.dict(os.environ, real_env_vars, clear=False) + # The Simulator variables are also here + assert "EdgeHubConnectionString" in os.environ + assert "EdgeModuleCACertificateFile" in os.environ + + client = await IoTHubModuleClient.create_from_edge_environment() + + # But we did not follow the Simulator path due to real variables existing + assert spy_create_from_cs.await_count == 0 + # Instead we followed the Real Edge path + assert mock_edge_hsm_cls.call_count == 1 + # NOTE: I could show all the mocks and values that get invoked here, but there's a whole + # test suite dedicated to those so no point in replicating it here. + + # Graceful exit + await client.shutdown() + + @pytest.mark.it( + "Raises IoTEdgeEnvironmentError if any expected environment variables cannot be found in the Edge environment" + ) + @pytest.mark.parametrize( + "missing_variable", ["EdgeHubConnectionString", "EdgeModuleCACertificateFile"] + ) + async def test_env_missing_vars(self, mocker, edge_environment_variables, missing_variable): + # Remove variable from env and re-patch + del edge_environment_variables[missing_variable] + mocker.patch.dict(os.environ, edge_environment_variables, clear=True) + + with pytest.raises(iot_exceptions.IoTEdgeEnvironmentError): + await IoTHubModuleClient.create_from_edge_environment() + + @pytest.mark.it( + "Allows any exceptions raised by the .create_from_connection_string() factory method to propagate" + ) + @pytest.mark.parametrize( + "exception", + [ + pytest.param(ValueError(), id="ValueError"), + pytest.param(st.SasTokenError(), id="SasTokenError"), + pytest.param(lazy_fixture("arbitrary_exception"), id="Unexpected Exception"), + ], + ) + async def test_create_from_connection_string_raises(self, mocker, exception): + mocker.patch.object( + IoTHubModuleClient, "create_from_connection_string", side_effect=exception + ) + + with pytest.raises(type(exception)) as e_info: + await IoTHubModuleClient.create_from_edge_environment() + assert e_info.value is exception + + @pytest.mark.it( + "Can be cancelled while waiting for the client to be created from the connection string" + ) + async def test_cancelled_during_create_from_connection_string(self, mocker): + mocker.patch.object( + IoTHubModuleClient, "create_from_connection_string", custom_mock.HangingAsyncMock() + ) + + t = asyncio.create_task(IoTHubModuleClient.create_from_edge_environment()) + + # Hanging, waiting for client instantiation via connection string + await IoTHubModuleClient.create_from_connection_string.wait_for_hang() + assert not t.done() + + # Cancel + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + +@pytest.mark.describe("IoTHubModuleClient - .shutdown()") +class TestIoTHubModuleClientShutdown(SharedClientShutdownTests, IoTHubModuleClientTestConfig): + pass + + +# NOTE: This is a convention-private helper, which would normally just be implicitly tested, but +# since it is used so frequently, it's easier to just test separately +@pytest.mark.describe("Default SSLContext") +class TestDefaultSSLContext: + @pytest.mark.it("Returns an SSLContext") + def test_is_ssl_context(self): + ctx = iothub_client._default_ssl_context() + assert isinstance(ctx, ssl.SSLContext) + + @pytest.mark.it("Sets the protocol of the SSLContext to PROTOCOL_TLS_CLIENT") + def test_protocol(self): + ctx = iothub_client._default_ssl_context() + assert ctx.protocol == ssl.PROTOCOL_TLS_CLIENT + + @pytest.mark.it("Sets the verify mode of the SSLContext to CERT_REQUIRED") + def test_verify_mode(self): + ctx = iothub_client._default_ssl_context() + assert ctx.verify_mode == ssl.CERT_REQUIRED + + @pytest.mark.it("Sets the `check_hostname` flag on the SSLContext to True") + def test_check_hostname(self): + ctx = iothub_client._default_ssl_context() + assert ctx.check_hostname is True + + @pytest.mark.it("Loads the default certificate chain on the SSLContext") + def test_default_certs(self, mocker): + mocker.patch.object(ssl, "SSLContext") + mock_ctx = iothub_client._default_ssl_context() + assert mock_ctx.load_default_certs.call_count == 1 + assert mock_ctx.load_default_certs.call_args == mocker.call() diff --git a/v3_async_wip/tests/test_iothub_http_client.py b/v3_async_wip/tests/test_iothub_http_client.py index b9e241e7d..5b818c464 100644 --- a/v3_async_wip/tests/test_iothub_http_client.py +++ b/v3_async_wip/tests/test_iothub_http_client.py @@ -57,7 +57,7 @@ def mock_sastoken_provider(mocker, sastoken): return provider -@pytest.fixture +@pytest.fixture(autouse=True) def mock_session(mocker): mock_session = mocker.MagicMock(spec=aiohttp.ClientSession) # Mock out POST and it's response @@ -114,9 +114,8 @@ class TestIoTHubHTTPClientInstantiation: # This means that you must do graceful exit by shutting down the client at the end of all tests # and you may need to do a manual mock of the underlying HTTP client where appropriate. configurations = [ - pytest.param(FAKE_DEVICE_ID, None, False, id="Device Configuration"), - pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, False, id="Module Configuration"), - pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, True, id="Edge Module Configuration"), + pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), + pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Module Configuration"), ] @pytest.fixture(autouse=True) @@ -127,11 +126,10 @@ class TestIoTHubHTTPClientInstantiation: @pytest.mark.it( "Stores the `device_id` and `module_id` values from the IoTHubClientConfig as attributes" ) - @pytest.mark.parametrize("device_id, module_id, is_edge_module", configurations) - async def test_simple_ids(self, client_config, device_id, module_id, is_edge_module): + @pytest.mark.parametrize("device_id, module_id", configurations) + async def test_simple_ids(self, client_config, device_id, module_id): client_config.device_id = device_id client_config.module_id = module_id - client_config.is_edge_module = is_edge_module client = IoTHubHTTPClient(client_config) assert client._device_id == device_id @@ -140,12 +138,11 @@ class TestIoTHubHTTPClientInstantiation: await client.shutdown() @pytest.mark.it( - "Derives the `edge_module_id` from the `device_id` and `module_id` if the IoTHubClientConfig indicates use of an Edge Module" + "Derives the `edge_module_id` from the `device_id` and `module_id` if the IoTHubClientConfig contains a `module_id`" ) async def test_edge_module_id(self, client_config): client_config.device_id = FAKE_DEVICE_ID client_config.module_id = FAKE_MODULE_ID - client_config.is_edge_module = True expected_edge_module_id = "{device_id}/{module_id}".format( device_id=FAKE_DEVICE_ID, module_id=FAKE_MODULE_ID ) @@ -155,18 +152,12 @@ class TestIoTHubHTTPClientInstantiation: await client.shutdown() - @pytest.mark.it("Sets the `edge_module_id` to None if not using an Edge Module") - @pytest.mark.parametrize( - "device_id, module_id", - [ - pytest.param(FAKE_DEVICE_ID, None, id="Device Configuration"), - pytest.param(FAKE_DEVICE_ID, FAKE_MODULE_ID, id="Non-Edge Module Configuration"), - ], - ) - async def test_no_edge_module_id(self, client_config, device_id, module_id): - client_config.device_id = device_id - client_config.module_id = module_id - client_config.is_edge_module = False + # NOTE: It would be nice if we could only do this for Edge modules, but there's no way to + # indicate a Module is Edge vs non-Edge + @pytest.mark.it("Sets the `edge_module_id` to None if not using a Module") + async def test_no_edge_module_id(self, client_config): + client_config.device_id = FAKE_DEVICE_ID + client_config.module_id = None client = IoTHubHTTPClient(client_config) assert client._edge_module_id is None @@ -176,7 +167,7 @@ class TestIoTHubHTTPClientInstantiation: @pytest.mark.it( "Constructs the `user_agent_string` by concatenating the base IoTHub user agent with the `product_info` from the IoTHubClientConfig" ) - @pytest.mark.parametrize("device_id, module_id, is_edge_module", configurations) + @pytest.mark.parametrize("device_id, module_id", configurations) @pytest.mark.parametrize( "product_info", [ @@ -188,12 +179,9 @@ class TestIoTHubHTTPClientInstantiation: ), ], ) - async def test_user_agent( - self, client_config, device_id, module_id, is_edge_module, product_info - ): + async def test_user_agent(self, client_config, device_id, module_id, product_info): client_config.device_id = device_id client_config.module_id = module_id - client_config.is_edge_module = is_edge_module client_config.product_info = product_info expected_user_agent = user_agent.get_iothub_user_agent() + product_info @@ -203,7 +191,7 @@ class TestIoTHubHTTPClientInstantiation: await client.shutdown() @pytest.mark.it("Does not URL encode the user agent string") - @pytest.mark.parametrize("device_id, module_id, is_edge_module", configurations) + @pytest.mark.parametrize("device_id, module_id", configurations) @pytest.mark.parametrize( "product_info", [ @@ -215,12 +203,11 @@ class TestIoTHubHTTPClientInstantiation: ], ) async def test_user_agent_no_url_encoding( - self, client_config, device_id, module_id, is_edge_module, product_info + self, client_config, device_id, module_id, product_info ): # NOTE: The user agent DOES eventually get url encoded, just not here, and not yet client_config.device_id = device_id client_config.module_id = module_id - client_config.is_edge_module = is_edge_module client_config.product_info = product_info expected_user_agent = user_agent.get_iothub_user_agent() + product_info url_encoded_expected_user_agent = urllib.parse.quote_plus(expected_user_agent) @@ -231,23 +218,13 @@ class TestIoTHubHTTPClientInstantiation: await client.shutdown() - # - # - # TODO: hostname / gateway hostname test once we know whats going on there - # - # - @pytest.mark.it( - "Creates a aiohttp ClientSession configured for accessing a URL based on the hostname with a timeout of 10 seconds" + "Creates a aiohttp ClientSession configured for accessing a URL based on the IoTHubClientConfig's `hostname`, with a timeout of 10 seconds" ) - @pytest.mark.parametrize("device_id, module_id, is_edge_module", configurations) - async def test_client_session( - self, mocker, client_config, device_id, module_id, is_edge_module - ): - # TODO: this test needs to be altered when hostname/gateway hostname logic is worked out + @pytest.mark.parametrize("device_id, module_id", configurations) + async def test_client_session(self, mocker, client_config, device_id, module_id): client_config.device_id = device_id client_config.module_id = module_id - client_config.is_edge_module = is_edge_module spy_session_init = mocker.spy(aiohttp, "ClientSession") expected_base_url = "https://" + client_config.hostname @@ -266,11 +243,10 @@ class TestIoTHubHTTPClientInstantiation: await client.shutdown() @pytest.mark.it("Stores the `ssl_context` from the IoTHubClientConfig as an attribute") - @pytest.mark.parametrize("device_id, module_id, is_edge_module", configurations) - async def test_ssl_context(self, client_config, device_id, module_id, is_edge_module): + @pytest.mark.parametrize("device_id, module_id", configurations) + async def test_ssl_context(self, client_config, device_id, module_id): client_config.device_id = device_id client_config.module_id = module_id - client_config.is_edge_module = is_edge_module assert client_config.ssl_context is not None client = IoTHubHTTPClient(client_config) @@ -279,7 +255,7 @@ class TestIoTHubHTTPClientInstantiation: await client.shutdown() @pytest.mark.it("Stores the `sastoken_provider` from the IoTHubClientConfig as an attribute") - @pytest.mark.parametrize("device_id, module_id, is_edge_module", configurations) + @pytest.mark.parametrize("device_id, module_id", configurations) @pytest.mark.parametrize( "sastoken_provider", [ @@ -287,12 +263,9 @@ class TestIoTHubHTTPClientInstantiation: pytest.param(None, id="No SasTokenProvider present"), ], ) - async def test_sastoken_provider( - self, client_config, device_id, module_id, is_edge_module, sastoken_provider - ): + async def test_sastoken_provider(self, client_config, device_id, module_id, sastoken_provider): client_config.device_id = device_id client_config.module_id = module_id - client_config.is_edge_module = is_edge_module client_config.sastoken_provider = sastoken_provider client = IoTHubHTTPClient(client_config) @@ -374,16 +347,14 @@ class TestIoTHubHTTPClientInvokeDirectMethod: @pytest.fixture(autouse=True) def modify_client_config(self, client_config): """Modify the client config to always be an Edge Module""" - # TODO: likely need to modify once hostname/gateway hostname is ironed out client_config.device_id = FAKE_DEVICE_ID client_config.module_id = FAKE_MODULE_ID - client_config.is_edge_module = True @pytest.fixture(autouse=True) def modify_post_response(self, client): fake_method_response = { "status": 200, - "payload": "fake payload", + "payload": {"fake": "payload"}, } mock_response = client._session.post.return_value.__aenter__.return_value mock_response.json.return_value = fake_method_response @@ -392,7 +363,7 @@ class TestIoTHubHTTPClientInvokeDirectMethod: def method_params(self): return { "methodName": "fake method", - "payload": "fake payload", + "payload": {"fake": "payload"}, "connectTimeoutInSeconds": 47, "responseTimeoutInSeconds": 42, } @@ -578,19 +549,11 @@ class TestIoTHubHTTPClientInvokeDirectMethod: device_id=target_device_id, module_id=target_module_id, method_params=method_params ) - @pytest.mark.it("Raises IoTHubClientError if not configured as an Edge Module") - @pytest.mark.parametrize( - "module_id", - [ - pytest.param(None, id="Device Configuration"), - pytest.param(FAKE_MODULE_ID, id="Non-Edge Module Configuration"), - ], - ) + # NOTE: It'd be really great if we could reject non-Edge modules, but we can't. + @pytest.mark.it("Raises IoTHubClientError if not configured as a Module") @pytest.mark.parametrize("target_device_id, target_module_id", targets) - async def test_not_edge( - self, client, module_id, target_device_id, target_module_id, method_params - ): - client._module_id = module_id + async def test_not_edge(self, client, target_device_id, target_module_id, method_params): + client._module_id = None client._edge_module_id = None with pytest.raises(IoTHubClientError): @@ -678,10 +641,9 @@ class TestIoTHubHTTPClientInvokeDirectMethod: class TestIoTHubHTTPClientGetStorageInfoForBlob: @pytest.fixture(autouse=True) def modify_client_config(self, client_config): - """Modify the client config to always be an Device""" + """Modify the client config to always be a Device""" client_config.device_id = FAKE_DEVICE_ID client_config.module_id = None - client_config.is_edge_module = False @pytest.fixture(autouse=True) def modify_post_response(self, client): @@ -885,10 +847,9 @@ class TestIoTHubHTTPClientGetStorageInfoForBlob: class TestIoTHubHTTPClientNotifyBlobUploadStatus: @pytest.fixture(autouse=True) def modify_client_config(self, client_config): - """Modify the client config to always be an Device""" + """Modify the client config to always be a Device""" client_config.device_id = FAKE_DEVICE_ID client_config.module_id = None - client_config.is_edge_module = False @pytest.fixture(params=["Notify Upload Success", "Notify Upload Failure"]) def kwargs(self, request): diff --git a/v3_async_wip/tests/test_iothub_mqtt_client.py b/v3_async_wip/tests/test_iothub_mqtt_client.py index bf634ebd0..17527a0f1 100644 --- a/v3_async_wip/tests/test_iothub_mqtt_client.py +++ b/v3_async_wip/tests/test_iothub_mqtt_client.py @@ -31,7 +31,6 @@ FAKE_MODULE_ID = "fake_module_id" FAKE_DEVICE_CLIENT_ID = "fake_device_id" FAKE_MODULE_CLIENT_ID = "fake_device_id/fake_module_id" FAKE_HOSTNAME = "fake.hostname" -FAKE_GATEWAY_HOSTNAME = "fake.gateway.hostname" FAKE_SIGNATURE = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=" FAKE_EXPIRY = str(int(time.time()) + 3600) FAKE_URI = "fake/resource/location" @@ -176,13 +175,6 @@ class TestIoTHubMQTTClientInstantiation: ), ], ) - @pytest.mark.parametrize( - "hostname, gateway_hostname", - [ - pytest.param(FAKE_HOSTNAME, None, id="No Gateway Hostname"), - pytest.param(FAKE_HOSTNAME, FAKE_GATEWAY_HOSTNAME, id="Gateway Hostname"), - ], - ) @pytest.mark.parametrize( "product_info", [ @@ -205,14 +197,10 @@ class TestIoTHubMQTTClientInstantiation: device_id, module_id, client_id, - hostname, - gateway_hostname, product_info, ): client_config.device_id = device_id client_config.module_id = module_id - client_config.hostname = hostname - client_config.gateway_hostname = gateway_hostname client_config.product_info = product_info ua = user_agent.get_iothub_user_agent() @@ -226,7 +214,7 @@ class TestIoTHubMQTTClientInstantiation: # Determine expected username based on config if product_info.startswith(constant.DIGITAL_TWIN_PREFIX): expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}&{digital_twin_prefix}={custom_product_info}".format( - hostname=hostname, + hostname=client_config.hostname, client_id=client_id, api_version=constant.IOTHUB_API_VERSION, user_agent=url_encoded_user_agent, @@ -235,13 +223,12 @@ class TestIoTHubMQTTClientInstantiation: ) else: expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}{custom_product_info}".format( - hostname=hostname, + hostname=client_config.hostname, client_id=client_id, api_version=constant.IOTHUB_API_VERSION, user_agent=url_encoded_user_agent, custom_product_info=url_encoded_product_info, ) - # NOTE: Regarding the above, no matter if we have a gateway hostname set or not, it is the hostname that is always used. client = IoTHubMQTTClient(client_config) # The expected username was derived @@ -284,15 +271,6 @@ class TestIoTHubMQTTClientInstantiation: ), ], ) - @pytest.mark.parametrize( - "hostname, gateway_hostname, expected_hostname", - [ - pytest.param(FAKE_HOSTNAME, None, FAKE_HOSTNAME, id="No Gateway Hostname"), - pytest.param( - FAKE_HOSTNAME, FAKE_GATEWAY_HOSTNAME, FAKE_GATEWAY_HOSTNAME, id="Gateway Hostname" - ), - ], - ) @pytest.mark.parametrize( "websockets, expected_transport, expected_port, expected_ws_path", [ @@ -307,9 +285,6 @@ class TestIoTHubMQTTClientInstantiation: device_id, module_id, expected_client_id, - hostname, - gateway_hostname, - expected_hostname, websockets, expected_transport, expected_port, @@ -318,8 +293,6 @@ class TestIoTHubMQTTClientInstantiation: # Configure the client_config based on params client_config.device_id = device_id client_config.module_id = module_id - client_config.hostname = hostname - client_config.gateway_hostname = gateway_hostname client_config.websockets = websockets # Patch the MQTTClient constructor @@ -333,7 +306,7 @@ class TestIoTHubMQTTClientInstantiation: assert mock_constructor.call_count == 1 assert mock_constructor.call_args == mocker.call( client_id=expected_client_id, - hostname=expected_hostname, + hostname=client_config.hostname, port=expected_port, transport=expected_transport, keep_alive=client_config.keep_alive, @@ -780,7 +753,6 @@ class TestIoTHubMQTTClientShutdown: # correctness, lest we have to repeat all .disconnect() tests here. original_disconnect = client.disconnect client.disconnect = mocker.AsyncMock(side_effect=exception) - client.disconnect.side_effect = exception assert not client._keep_credentials_fresh_bg_task.done() assert not client._process_twin_responses_bg_task.done() diff --git a/v3_async_wip/v3_async_wip/config.py b/v3_async_wip/v3_async_wip/config.py index c142d5f34..2a62c7578 100644 --- a/v3_async_wip/v3_async_wip/config.py +++ b/v3_async_wip/v3_async_wip/config.py @@ -71,7 +71,6 @@ class ClientConfig: *, ssl_context: ssl.SSLContext, hostname: str, - gateway_hostname: Optional[str] = None, sastoken_provider: Optional[SasTokenProvider] = None, proxy_options: Optional[ProxyOptions] = None, keep_alive: int = 60, @@ -81,7 +80,6 @@ class ClientConfig: """Initializer for ClientConfig :param str hostname: The hostname being connected to - :param str gateway_hostname: The gateway hostname optionally being used :param sastoken_provider: Object that can provide SasTokens :type sastoken_provider: :class:`SasTokenProvider` :param proxy_options: Details of proxy configuration @@ -97,7 +95,6 @@ class ClientConfig: """ # Network self.hostname = hostname - self.gateway_hostname = gateway_hostname self.proxy_options = proxy_options # Auth @@ -116,7 +113,6 @@ class IoTHubClientConfig(ClientConfig): *, device_id: str, module_id: Optional[str] = None, - is_edge_module: bool = False, product_info: str = "", **kwargs: Any, ) -> None: @@ -125,14 +121,12 @@ class IoTHubClientConfig(ClientConfig): :param str device_id: The device identity being used with the IoTHub :param str module_id: The module identity being used with the IoTHub - :param bool is_edge_module: Boolean indicating whether or not using an Edge Module :param str product_info: A custom identification string. Additional parameters found in the docstring of the parent class """ self.device_id = device_id self.module_id = module_id - self.is_edge_module = is_edge_module self.product_info = product_info super().__init__(**kwargs) diff --git a/v3_async_wip/v3_async_wip/connection_string.py b/v3_async_wip/v3_async_wip/connection_string.py new file mode 100644 index 000000000..cc43ee2c6 --- /dev/null +++ b/v3_async_wip/v3_async_wip/connection_string.py @@ -0,0 +1,110 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""This module contains tools for working with Connection Strings""" + +__all__ = ["ConnectionString"] + +CS_DELIMITER = ";" +CS_VAL_SEPARATOR = "=" + +HOST_NAME = "HostName" +SHARED_ACCESS_KEY_NAME = "SharedAccessKeyName" +SHARED_ACCESS_KEY = "SharedAccessKey" +SHARED_ACCESS_SIGNATURE = "SharedAccessSignature" +DEVICE_ID = "DeviceId" +MODULE_ID = "ModuleId" +GATEWAY_HOST_NAME = "GatewayHostName" +X509 = "x509" + +_valid_keys = [ + HOST_NAME, + SHARED_ACCESS_KEY_NAME, + SHARED_ACCESS_KEY, + SHARED_ACCESS_SIGNATURE, + DEVICE_ID, + MODULE_ID, + GATEWAY_HOST_NAME, + X509, +] + +# TODO: does this module need revision for V3? + + +class ConnectionString(object): + """Key/value mappings for connection details. + Uses the same syntax as dictionary + """ + + def __init__(self, connection_string): + """Initializer for ConnectionString + + :param str connection_string: String with connection details provided by Azure + :raises: ValueError if provided connection_string is invalid + """ + self._dict = _parse_connection_string(connection_string) + self._strrep = connection_string + + def __contains__(self, item): + return item in self._dict + + def __getitem__(self, key): + return self._dict[key] + + def __repr__(self): + return self._strrep + + def get(self, key, default=None): + """Return the value for key if key is in the dictionary, else default + + :param str key: The key to retrieve a value for + :param str default: The default value returned if a key is not found + :returns: The value for the given key + """ + try: + return self._dict[key] + except KeyError: + return default + + +def _parse_connection_string(connection_string): + """Return a dictionary of values contained in a given connection string""" + try: + cs_args = connection_string.split(CS_DELIMITER) + except (AttributeError, TypeError): + raise TypeError("Connection String must be of type str") + try: + d = dict(arg.split(CS_VAL_SEPARATOR, 1) for arg in cs_args) + except ValueError: + # This occurs in an extreme edge case where a dictionary cannot be formed because there + # is only 1 token after the split (dict requires two in order to make a key/value pair) + raise ValueError("Invalid Connection String - Unable to parse") + if len(cs_args) != len(d): + # various errors related to incorrect parsing - duplicate args, bad syntax, etc. + raise ValueError("Invalid Connection String - Unable to parse") + if not all(key in _valid_keys for key in d.keys()): + raise ValueError("Invalid Connection String - Invalid Key") + _validate_keys(d) + return d + + +def _validate_keys(d): + """Raise ValueError if incorrect combination of keys in dict d""" + host_name = d.get(HOST_NAME) + shared_access_key_name = d.get(SHARED_ACCESS_KEY_NAME) + shared_access_key = d.get(SHARED_ACCESS_KEY) + device_id = d.get(DEVICE_ID) + x509 = d.get(X509) + + if shared_access_key and x509 and x509.lower() == "true": + raise ValueError("Invalid Connection String - Mixed authentication scheme") + + # This logic could be expanded to return the category of ConnectionString + if host_name and device_id and (shared_access_key or x509): + pass + elif host_name and shared_access_key and shared_access_key_name: + pass + else: + raise ValueError("Invalid Connection String - Incomplete") diff --git a/v3_async_wip/v3_async_wip/custom_typing.py b/v3_async_wip/v3_async_wip/custom_typing.py index 8e3f7d758..9f00e10bb 100644 --- a/v3_async_wip/v3_async_wip/custom_typing.py +++ b/v3_async_wip/v3_async_wip/custom_typing.py @@ -3,9 +3,14 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +from typing import Union, Dict, List, Tuple, Callable, Awaitable, TypeVar +from typing_extensions import TypedDict, ParamSpec + + +_P = ParamSpec("_P") +_R = TypeVar("_R") +FunctionOrCoroutine = Union[Callable[_P, _R], Callable[_P, Awaitable[_R]]] -from typing import Union, Dict, List, Tuple -from typing_extensions import TypedDict # typing does not support recursion, so we must use forward references here (PEP484) JSONSerializable = Union[ @@ -25,14 +30,18 @@ Twin = Dict[str, Dict[str, JSONSerializable]] TwinPatch = Dict[str, JSONSerializable] -# TODO: should this be "direct method?" -class MethodParameters(TypedDict): +class DirectMethodParameters(TypedDict): methodName: str - payload: str + payload: JSONSerializable connectTimeoutInSeconds: int responseTimeoutInSeconds: int +class DirectMethodResult(TypedDict): + status: int + payload: JSONSerializable + + class StorageInfo(TypedDict): correlationId: str hostName: str diff --git a/v3_async_wip/v3_async_wip/iot_exceptions.py b/v3_async_wip/v3_async_wip/iot_exceptions.py index 7d137a6fd..5a765188d 100644 --- a/v3_async_wip/v3_async_wip/iot_exceptions.py +++ b/v3_async_wip/v3_async_wip/iot_exceptions.py @@ -7,17 +7,21 @@ class IoTHubError(Exception): - """Represents a failure reported by IoTHub""" + """Represents a failure reported by IoT Hub""" pass class IoTEdgeError(Exception): - """Represents a failure reported by IoTEdge""" + """Represents a failure reported by IoT Edge""" pass +class IoTEdgeEnvironmentError(Exception): + """Represents a failure retrieving data from the IoT Edge environment""" + + class IoTHubClientError(Exception): """Represents a failure from the IoTHub Client""" diff --git a/v3_async_wip/v3_async_wip/iothub_client.py b/v3_async_wip/v3_async_wip/iothub_client.py new file mode 100644 index 000000000..c5f659c28 --- /dev/null +++ b/v3_async_wip/v3_async_wip/iothub_client.py @@ -0,0 +1,619 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import abc +import logging +import os +import ssl +from typing import Optional, Union, cast +from .custom_typing import FunctionOrCoroutine +from .iot_exceptions import IoTEdgeEnvironmentError +from . import config, edge_hsm +from . import connection_string as cs +from . import iothub_http_client as http +from . import iothub_mqtt_client as mqtt +from . import sastoken as st +from . import signing_mechanism as sm + + +logger = logging.getLogger(__name__) + +# TODO: finalize documentation + + +class IoTHubClient(abc.ABC): + """Abstract parent class for IoTHubDeviceClient and IoTHubModuleClient containing + partial implementation. + """ + + def __init__(self, client_config: config.IoTHubClientConfig) -> None: + """Initializer for a generic IoTHubClient. + Do not directly use as the end user, use a factory method instead. + + :param client_config: The IoTHubClientConfig object + :type client_config: :class:`IoTHubClientConfig` + """ + # Internal clients + self._mqtt_client = mqtt.IoTHubMQTTClient(client_config) + self._http_client = http.IoTHubHTTPClient(client_config) + + # Keep a reference to the SAS Token Provider so it can be shut down later + self._sastoken_provider = client_config.sastoken_provider + + async def shutdown(self) -> None: + """Shut down the client + + Call only when completely done with the client for graceful exit. + + Cannot be cancelled - if you try, the client will still fully shut down as much as + possible (although the CancelledError will still be raised) + """ + cached_exception: Optional[Union[Exception, BaseException]] = None + logger.debug("Beginning IoTHubClient shutdown procedure") + try: + logger.debug("Shutting down IoTHubMQTTClient...") + await self._mqtt_client.shutdown() + logger.debug("IoTHubMQTTClient shutdown complete") + except (Exception, BaseException) as e: + logger.warning( + "Unexpected error during shutdown of IoTHubMQTTClient suppressed - still completing the rest of shutdown procedure" + ) + cached_exception = e + + try: + logger.debug("Shutting down IoTHubHTTPClient...") + await self._http_client.shutdown() + logger.debug("IoTHubHTTPClient shutdown complete") + except (Exception, BaseException) as e: + logger.warning( + "Unexpected error during shutdown of IoTHubHTTPClient suppressed - still completing the rest of shutdown procedure" + ) + cached_exception = e + if self._sastoken_provider: + try: + logger.debug("Shutting down SasTokenProvider...") + await self._sastoken_provider.shutdown() + logger.debug("SasTokenProvider shutdown complete") + except (Exception, BaseException) as e: + logger.warning( + "Unexpected error during shutdown of SasTokenProvider suppressed - still completing the rest of shutdown procedure" + ) + cached_exception = e + + logger.debug("IoTHubClient shutdown procedure complete") + if cached_exception: + # NOTE: In the case of multiple failures, only the last one gets raised. + # Not much way around it, and besides, this is all an extreme edge case anyway. + logger.warning( + "Raising previously suppressed error now that shutdown procedure is complete" + ) + raise cached_exception + + # ~~~~~ Abstract declarations ~~~~~ + # NOTE: rigid typechecking doesn't like when the signature changes in the child class + # implementation of an abstract method. This creates problems, given that Device/Module + # clients have some methods with different signatures. It may be worth considering + # dropping abstract definitions altogether if their use is too inconsistent, or at least + # paring them back to only the crucial ones (connect, shutdown) + + # @abc.abstractmethod + # async def connect(self) -> None: + # raise NotImplementedError + + # @abc.abstractmethod + # async def disconnect(self) -> None: + # raise NotImplementedError + + # @abc.abstractmethod + # async def send_message(self) -> None: + # raise NotImplementedError + + # @abc.abstractmethod + # async def send_direct_method_response(self) -> None: + # raise NotImplementedError + + # @abc.abstractmethod + # async def send_twin_reported_properties_patch(self) -> None: + # raise NotImplementedError + + # @abc.abstractmethod + # async def get_twin(self) -> Twin: + # raise NotImplementedError + + # ~~~~~~ Shared implementations ~~~~~ + + @classmethod + async def _shared_client_create( + cls, + *, + device_id: str, + module_id: Optional[str] = None, + hostname: str, + ssl_context: Optional[ssl.SSLContext] = None, + symmetric_key: Optional[str] = None, + sastoken_fn: Optional[FunctionOrCoroutine] = None, # TODO: need more rigid definition + # sastoken_fn: Optional[FunctionOrCoroutine[[], str]] = None, + **kwargs, + ) -> "IoTHubClient": + """Agnostic implementation of .create() shared between Devices and Modules + + :raises: ValueError if one of 'ssl_context', 'symmetric_key' or 'sastoken_fn' is not + provided + :raises: ValueError if both 'symmetric_key' and 'sastoken_fn' are provided + :raises: ValueError if an invalid 'symmetric_key' is provided + :raises: SasTokenError if there is a failure generating a SAS Token + """ + # Validate Parameters + _validate_kwargs(**kwargs) + if symmetric_key and sastoken_fn: + raise ValueError( + "Incompatible authentication - cannot provide both 'symmetric_key' and 'sastoken_fn'" + ) + if not symmetric_key and not sastoken_fn and not ssl_context: + raise ValueError( + "Missing authentication - must provide one of 'symmetric_key', 'sastoken_fn' or 'ssl_context'" + ) + + if symmetric_key: + signing_mechanism = sm.SymmetricKeySigningMechanism(symmetric_key) + else: + signing_mechanism = None + + return await cls._internal_factory( + device_id=device_id, + module_id=module_id, + hostname=hostname, + ssl_context=ssl_context, + sas_signing_mechanism=signing_mechanism, + sastoken_fn=sastoken_fn, + **kwargs, + ) + + @classmethod + async def _shared_client_create_from_connection_string( + cls, cs_obj: cs.ConnectionString, ssl_context: Optional[ssl.SSLContext] = None, **kwargs + ) -> "IoTHubClient": + """Agnostic implementation of .create_from_connection_string() shared between Devices + and Modules. Uses a ConnectionString object rather than a string, since the outer + client-specific implementation already converted it to validate + + :raises: ValueError if the provided connection string is invalid + :raises: SasTokenError if there is a failure generating a SAS Token""" + # ssl_context is required if x509 is indicated by the connection string + if cs_obj.get(cs.X509, "").lower() == "true" and not ssl_context: + raise ValueError( + "Connection string indicates X509 certificate authentication, but no ssl_context provided" + ) + + # If the Gateway Hostname exists, use it instead of the Hostname + hostname = cs_obj.get(cs.GATEWAY_HOST_NAME, cs_obj[cs.HOST_NAME]) + + if cs.SHARED_ACCESS_KEY in cs_obj: + signing_mechanism = sm.SymmetricKeySigningMechanism(cs_obj[cs.SHARED_ACCESS_KEY]) + else: + signing_mechanism = None + + return await cls._internal_factory( + device_id=cs_obj[cs.DEVICE_ID], + module_id=cs_obj.get(cs.MODULE_ID), + hostname=hostname, + sas_signing_mechanism=signing_mechanism, + ssl_context=ssl_context, + **kwargs, + ) + + @classmethod + async def _internal_factory( + cls, + *, + device_id: str, + module_id: Optional[str] = None, + hostname: str, + ssl_context: Optional[ssl.SSLContext] = None, + sas_signing_mechanism: Optional[sm.SigningMechanism] = None, + sastoken_fn: Optional[FunctionOrCoroutine] = None, # TODO: need more rigid definition + # sastoken_fn: Optional[FunctionOrCoroutine[[], str]] = None, + **kwargs, + ) -> "IoTHubClient": + """Internal factory method that creates a client for a all configurations + + :raises: SasTokenError if there is a failure generating a SAS Token + """ + # NOTE: Validation is assumed to have been done by the time this method is called. + + # Internal SAS Generation + sastoken_generator: st.SasTokenGenerator + if sas_signing_mechanism: + uri = _format_sas_uri(hostname=hostname, device_id=device_id, module_id=module_id) + sastoken_generator = st.InternalSasTokenGenerator( + signing_mechanism=sas_signing_mechanism, + uri=uri, + ) + sastoken_provider = await st.SasTokenProvider.create_from_generator(sastoken_generator) + + # External SAS Generation + elif sastoken_fn: + sastoken_generator = st.ExternalSasTokenGenerator(sastoken_fn) + sastoken_provider = await st.SasTokenProvider.create_from_generator(sastoken_generator) + + # No SAS Auth + else: + sastoken_provider = None + + # SSL + if not ssl_context: + ssl_context = _default_ssl_context() + + # Config setup + client_config = config.IoTHubClientConfig( + hostname=hostname, + device_id=device_id, + module_id=module_id, + sastoken_provider=sastoken_provider, + ssl_context=ssl_context, + **kwargs, + ) + + return cls(client_config) + + +class IoTHubDeviceClient(IoTHubClient): + """A client for connecting a device to an instance of IoT Hub""" + + @classmethod + async def create( + cls, + device_id: str, + hostname: str, + ssl_context: Optional[ssl.SSLContext] = None, + symmetric_key: Optional[str] = None, + sastoken_fn: Optional[FunctionOrCoroutine] = None, # TODO: more rigid definition + # sastoken_fn: Optional[FunctionOrCoroutine[[], str]] = None, + **kwargs, + ) -> "IoTHubDeviceClient": + """ + Instantiate an IoTHubDeviceClient + + - To use symmetric key authentication, provide the symmetric key as the 'symmetric_key' + parameter + - To use your own SAS tokens for authentication, provide a function or coroutine function + that returns SAS Tokens as the 'sastoken_fn' parameter + - To use X509 certificate authentication, configure an SSLContext for the certificate, and + provide it as the 'ssl_context' parameter + + One of the these three types of authentication is required to instantiate the client. + + :param str device_id: The device identity for the IoT Hub device + :param str hostname: Hostname of the IoT Hub or IoT Edge the device should connect to + :param ssl_context: Custom SSL context to be used by the client + If not provided, a default one will be used + :type ssl_context: :class:`ssl.SSLContext` + :param str symmetric_key: A symmetric key that can be used to generate SAS Tokens + :param sastoken_fn: A function or coroutine function that takes no arguments and returns + a SAS token string when invoked + + :keyword bool connection_retry: Indicates whether to use built-in connection retry policy. + Default is 'True' + :keyword int keep_alive: Maximum period in seconds between MQTT communications. If no + communications are exchanged for this period, a ping exchange will occur. + Default is 60 seconds + :keyword str product_info: Arbitrary product information which will be included in the + User-Agent string + :keyword proxy_options: Configuration structure for sending traffic through a proxy server + :type: proxy_options: :class:`ProxyOptions` + :keyword bool websockets: Set to 'True' to use WebSockets over MQTT. Default is 'False' + + :raises: ValueError if one of 'ssl_context', 'symmetric_key' or 'sastoken_fn' is not + provided + :raises: ValueError if both 'symmetric_key' and 'sastoken_fn' are provided + :raises: ValueError if an invalid 'symmetric_key' is provided + :raises: SasTokenError if there is a failure generating a SAS Token + + :return: An IoTHubDeviceClient instance + """ + + client = await cls._shared_client_create( + device_id=device_id, + hostname=hostname, + ssl_context=ssl_context, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + **kwargs, + ) + return cast(IoTHubDeviceClient, client) + + @classmethod + async def create_from_connection_string( + cls, connection_string: str, ssl_context: Optional[ssl.SSLContext] = None, **kwargs + ) -> "IoTHubDeviceClient": + """Instantiate an IoTHubDeviceClient using a IoT Hub device connection string + + :param str connection_string: The IoT Hub device connection string + :param ssl_context: Custom SSL context to be used by the client + If not provided, a default one will be used + :type ssl_context: :class:`ssl.SSLContext` + + :keyword bool connection_retry: Indicates whether to use built-in connection retry policy. + Default is 'True' + :keyword int keep_alive: Maximum period in seconds between MQTT communications. If no + communications are exchanged for this period, a ping exchange will occur. + Default is 60 seconds + :keyword str product_info: Arbitrary product information which will be included in the + User-Agent string + :keyword proxy_options: Configuration structure for sending traffic through a proxy server + :type: proxy_options: :class:`ProxyOptions` + :keyword bool websockets: Set to 'True' to use WebSockets over MQTT. Default is 'False' + + :raises: ValueError if the provided connection string is invalid + :raises: SasTokenError if there is a failure generating a SAS Token + + :return: An IoTHubDeviceClient instance + """ + # Validate connection string is for Device + cs_obj = cs.ConnectionString(connection_string) + if cs.MODULE_ID in cs_obj: + raise ValueError("IoT Hub module connection string provided for IoTHubDeviceClient") + + client = await cls._shared_client_create_from_connection_string( + cs_obj, ssl_context, **kwargs + ) + return cast(IoTHubDeviceClient, client) + + +class IoTHubModuleClient(IoTHubClient): + """A client for connecting a module to an instance of IoT Hub""" + + @classmethod + async def create( + cls, + device_id: str, + module_id: str, + hostname: str, + ssl_context: Optional[ssl.SSLContext] = None, + symmetric_key: Optional[str] = None, + # sastoken_fn: Optional[FunctionOrCoroutine[[], str]] = None, + sastoken_fn: Optional[FunctionOrCoroutine] = None, # TODO: more rigid definition + **kwargs, + ) -> "IoTHubModuleClient": + """ + Instantiate an IoTHubModuleClient + + - To use symmetric key authentication, provide the symmetric key as the 'symmetric_key' + parameter + - To use your own SAS tokens for authentication, provide a function or coroutine function + that returns SAS Tokens as the 'sastoken_fn' parameter + - To use X509 certificate authentication, configure an SSLContext for the certificate, and + provide it as the 'ssl_context' parameter + + One of the these three types of authentication is required to instantiate the client. + + :param str device_id: The device identity for the IoT Hub device containing the + IoT Hub module + :param str module_id: The module identity for the IoT Hub module + :param str hostname: Hostname of the IoT Hub or IoT Edge the device should connect to + :param ssl_context: Custom SSL context to be used by the client + If not provided, a default one will be used + :type ssl_context: :class:`ssl.SSLContext` + :param str symmetric_key: A symmetric key that can be used to generate SAS Tokens + :param sastoken_fn: A function or coroutine function that takes no arguments and returns + a SAS token string when invoked + + :keyword bool connection_retry: Indicates whether to use built-in connection retry policy. + Default is 'True' + :keyword int keep_alive: Maximum period in seconds between MQTT communications. If no + communications are exchanged for this period, a ping exchange will occur. + Default is 60 seconds + :keyword str product_info: Arbitrary product information which will be included in the + User-Agent string + :keyword proxy_options: Configuration structure for sending traffic through a proxy server + :type: proxy_options: :class:`ProxyOptions` + :keyword bool websockets: Set to 'True' to use WebSockets over MQTT. Default is 'False' + + :raises: ValueError if one of 'ssl_context', 'symmetric_key' or 'sastoken_fn' is not + provided + :raises: ValueError if both 'symmetric_key' and 'sastoken_fn' are provided + :raises: ValueError if an invalid 'symmetric_key' is provided + :raises: SasTokenError if there is a failure generating a SAS Token + + :return: An IoTHubModuleClient instance + """ + client = await cls._shared_client_create( + device_id=device_id, + module_id=module_id, + hostname=hostname, + ssl_context=ssl_context, + symmetric_key=symmetric_key, + sastoken_fn=sastoken_fn, + **kwargs, + ) + return cast(IoTHubModuleClient, client) + + @classmethod + async def create_from_connection_string( + cls, connection_string: str, ssl_context: Optional[ssl.SSLContext] = None, **kwargs + ) -> "IoTHubModuleClient": + """Instantiate an IoTHubModuleClient using a IoT Hub module connection string + + :param str connection_string: The IoT Hub module connection string + :param ssl_context: Custom SSL context to be used by the client + If not provided, a default one will be used + :type ssl_context: :class:`ssl.SSLContext` + + :keyword bool connection_retry: Indicates whether to use built-in connection retry policy. + Default is 'True' + :keyword int keep_alive: Maximum period in seconds between MQTT communications. If no + communications are exchanged for this period, a ping exchange will occur. + Default is 60 seconds + :keyword str product_info: Arbitrary product information which will be included in the + User-Agent string + :keyword proxy_options: Configuration structure for sending traffic through a proxy server + :type: proxy_options: :class:`ProxyOptions` + :keyword bool websockets: Set to 'True' to use WebSockets over MQTT. Default is 'False' + + :raises: ValueError if the provided connection string is invalid + :raises: SasTokenError if there is a failure generating a SAS Token + + :return: An IoTHubModuleClient instance + """ + # Validate connection string is for Module + cs_obj = cs.ConnectionString(connection_string) + if cs.MODULE_ID not in cs_obj: + raise ValueError("IoT Hub device connection string provided for IoTHubModuleClient") + + client = await cls._shared_client_create_from_connection_string( + cs_obj, ssl_context, **kwargs + ) + return cast(IoTHubModuleClient, client) + + @classmethod + async def create_from_edge_environment(cls, **kwargs) -> "IoTHubModuleClient": + """Instantiate an IoTHubModuleClient using information from an IoT Edge environment + + This method can only be run from inside an IoT Edge environment, or in a debugging + environment configured for Edge development (e.g. Visual Studio Code) + + :keyword bool connection_retry: Indicates whether to use built-in connection retry policy. + Default is 'True' + :keyword int keep_alive: Maximum period in seconds between MQTT communications. If no + communications are exchanged for this period, a ping exchange will occur. + Default is 60 seconds + :keyword str product_info: Arbitrary product information which will be included in the + User-Agent string + :keyword proxy_options: Configuration structure for sending traffic through a proxy server + :type: proxy_options: :class:`ProxyOptions` + :keyword bool websockets: Set to 'True' to use WebSockets over MQTT. Default is 'False' + + :raises: IoTEdgeEnvironmentError if the required environment variables are not present or + cannot be accessed + :raises: IoTEdgeError if there is a failure with the IoT Edge + :raises: SasTokenError if there is a failure generating a SAS Token + :raises: ValueError if IoT Edge environment variable values are invalid + :raises: TypeError if IoT Edge environment variable values are of the wrong format + + + :return: An IoTHubModuleClient instance + """ + _validate_kwargs(**kwargs) + + try: + # First, try to find the regular IoT Edge environment variables + return await cls._create_from_real_edge_environment(**kwargs) + except IoTEdgeEnvironmentError as original_exception: + try: + # If they can't be found, try looking for the IoT Edge simulator variables + return await cls._create_from_simulated_edge_environment(**kwargs) + except IoTEdgeEnvironmentError: + # Raise the original error if the IoT Edge simulator variables also cannot be found + raise original_exception + + @classmethod + async def _create_from_real_edge_environment(cls, **kwargs) -> "IoTHubModuleClient": + """Instantiate an IoTHubModuleClient from values stored in environment variables + in a IoT Edge deployment environment. + + :raises: IoTEdgeEnvironmentError if IoT Edge environment variables are not present or + cannot be accessed + :raises: IoTEdgeError if there is a failure communicating with IoT Edge + :raises: SasTokenError if there is a failure generating a SAS Token + :raises: ValueError if IoT Edge environment variables values are invalid + :raises: TypeError if IoT Edge environment variable values are of the wrong format + """ + # Read values from the IoT Edge environment variables + try: + device_id = os.environ["IOTEDGE_DEVICEID"] + module_id = os.environ["IOTEDGE_MODULEID"] + hostname = os.environ["IOTEDGE_GATEWAYHOSTNAME"] + module_generation_id = os.environ["IOTEDGE_MODULEGENERATIONID"] + workload_uri = os.environ["IOTEDGE_WORKLOADURI"] + api_version = os.environ["IOTEDGE_APIVERSION"] + except KeyError as e: + raise IoTEdgeEnvironmentError("Could not retrieve Edge environment variables") from e + + # The IoT Edge HSM will be used to get the verification certs, as well as to sign data + # for making SAS Tokens + hsm = edge_hsm.IoTEdgeHsm( + module_id=module_id, + generation_id=module_generation_id, + workload_uri=workload_uri, + api_version=api_version, + ) + + # Set up Edge SSL context by loading the cert data + server_verification_cert = await hsm.get_certificate() + ssl_context = _default_ssl_context() + ssl_context.load_verify_locations(cadata=server_verification_cert) + + # Send to the internal factory + client = await cls._internal_factory( + device_id=device_id, + module_id=module_id, + hostname=hostname, + ssl_context=ssl_context, + sas_signing_mechanism=hsm, + **kwargs, + ) + return cast(IoTHubModuleClient, client) + + @classmethod + async def _create_from_simulated_edge_environment(cls, **kwargs) -> "IoTHubModuleClient": + """Instantiate an IoTHubModuleClient from values stored in environment variables + in a simulated IoT Edge environment + + :raises: IoTEdgeEnvironmentError if IoT Edge environment variables are not present or + cannot be accessed + :raises: ValueError if the connection string in the environment is invalid + :raises: SasTokenError if there is a failure generating a SAS Token + """ + # Read values from the IoT Edge Simulator environment variables + try: + connection_string = os.environ["EdgeHubConnectionString"] + ca_cert_filepath = os.environ["EdgeModuleCACertificateFile"] + except KeyError as e: + raise IoTEdgeEnvironmentError("Could not retrieve Edge environment variables") from e + + # Set up Edge SSL context by loading the cert file + ssl_context = _default_ssl_context() + ssl_context.load_verify_locations(cafile=ca_cert_filepath) + + # Since we have a connection string, just use the connection string factory + return await cls.create_from_connection_string( + connection_string, ssl_context=ssl_context, **kwargs + ) + + +def _validate_kwargs(exclude=[], **kwargs): + """Helper function to validate user provided kwargs. + Raises TypeError if an invalid option has been provided""" + valid_kwargs = [ + "auto_reconnect", + "keep_alive", + "product_info", + "proxy_options", + "websockets", + ] + + for kwarg in kwargs: + if (kwarg not in valid_kwargs) or (kwarg in exclude): + # NOTE: TypeError is the conventional error that is returned when an invalid kwarg is + # supplied. It feels like it should be a ValueError, but it's not. + raise TypeError("Unsupported keyword argument: '{}'".format(kwarg)) + + +def _default_ssl_context() -> ssl.SSLContext: + """Return a default SSLContext""" + ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = True + ssl_context.load_default_certs() + return ssl_context + + +def _format_sas_uri(hostname: str, device_id: str, module_id: Optional[str] = None) -> str: + if module_id: + return "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=hostname, device_id=device_id, module_id=module_id + ) + else: + return "{hostname}/devices/{device_id}".format(hostname=hostname, device_id=device_id) diff --git a/v3_async_wip/v3_async_wip/iothub_http_client.py b/v3_async_wip/v3_async_wip/iothub_http_client.py index aedb9ab1d..7cdc99316 100644 --- a/v3_async_wip/v3_async_wip/iothub_http_client.py +++ b/v3_async_wip/v3_async_wip/iothub_http_client.py @@ -8,7 +8,7 @@ import asyncio import logging import urllib.parse from typing import Optional, cast -from .custom_typing import MethodParameters, StorageInfo +from .custom_typing import DirectMethodParameters, DirectMethodResult, StorageInfo from .iot_exceptions import IoTHubClientError, IoTHubError, IoTEdgeError from . import config, constant, user_agent from . import http_path_iothub as http_path @@ -30,7 +30,6 @@ HTTP_TIMEOUT = 10 # TODO: document aiohttp exceptions that can be raised # TODO: URL Encoding logic # TODO: Proxy support -# TODO: Hostname/Gateway Hostname split (E2E test to see what works) # TODO: Should direct method responses be a DirectMethodResponse object? If so, what is the rid? # See specific inline commentary for more details on what is required @@ -64,14 +63,8 @@ class IoTHubHTTPClient: """ self._device_id = client_config.device_id self._module_id = client_config.module_id - self._edge_module_id = _format_edge_module_id( - self._device_id, self._module_id, client_config.is_edge_module - ) + self._edge_module_id = _format_edge_module_id(self._device_id, self._module_id) self._user_agent_string = user_agent.get_iothub_user_agent() + client_config.product_info - if client_config.gateway_hostname: - self._hostname = client_config.gateway_hostname - else: - self._hostname = client_config.hostname # TODO: add proxy support # Doing so will require building a custom "Connector" that can be injected into the @@ -84,7 +77,7 @@ class IoTHubHTTPClient: logger.warning("Proxy use with .get_storage_info_for_blob() not supported") logger.warning("Proxy use with .notify_blob_upload_status() not supported") - self._session = _create_client_session(self._hostname) + self._session = _create_client_session(client_config.hostname) self._ssl_context = client_config.ssl_context self._sastoken_provider = client_config.sastoken_provider @@ -98,28 +91,34 @@ class IoTHubHTTPClient: # See: https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown await asyncio.sleep(0.25) - # TODO: Should this return type be a MethodResponse? Or should we get rid of those objects entirely? - # TODO: Either way, need a better rtype than "dict" async def invoke_direct_method( - self, *, device_id: str, module_id: Optional[str] = None, method_params: MethodParameters - ) -> dict: + self, + *, + device_id: str, + module_id: Optional[str] = None, + method_params: DirectMethodParameters + ) -> DirectMethodResult: """Send a request to invoke a direct method on a target device or module :param str device_id: The target device ID :param str module_id: The target module ID :param dict method_params: The parameters for the direct method invocation + :returns: A dictionary containing a status and payload reported by the target device + :rtype: dict + :raises: :class:`IoTHubClientError` if not using an IoT Edge Module :raises: :class:`IoTHubClientError` if the direct method response cannot be parsed :raises: :class:`IoTEdgeError` if IoT Edge responds with failure """ if not self._edge_module_id: + # NOTE: The Edge Module ID will be exist for any Module, it doesn't actually indicate + # if it is an Edge Module or not. There's no way to tell, unfortunately. raise IoTHubClientError(".invoke_direct_method() only available for Edge Modules") path = http_path.get_direct_method_invoke_path(device_id, module_id) query_params = {PARAM_API_VERISON: constant.IOTHUB_API_VERSION} # NOTE: Other headers are auto-generated by aiohttp - # TODO: we may need to explicitly add the Host header depending on how host/gateway host works out headers = { HEADER_USER_AGENT: urllib.parse.quote_plus(self._user_agent_string), HEADER_EDGE_MODULE_ID: self._edge_module_id, # TODO: I assume this isn't supposed to be URI encoded just like in MQTT? @@ -152,9 +151,9 @@ class IoTHubHTTPClient: logger.debug( "Successfully received response from IoT Edge for direct method invocation" ) - dm_response_json = await response.json() + dm_result = cast(DirectMethodResult, await response.json()) - return dm_response_json + return dm_result async def get_storage_info_for_blob(self, *, blob_name: str) -> StorageInfo: """Request information for uploading blob file via the Azure Storage SDK @@ -257,16 +256,10 @@ class IoTHubHTTPClient: return None -def _format_edge_module_id( - device_id: str, module_id: Optional[str], is_edge_module -) -> Optional[str]: +def _format_edge_module_id(device_id: str, module_id: Optional[str]) -> Optional[str]: """Returns the edge module identifier""" - if is_edge_module: - if module_id: - return "{device_id}/{module_id}".format(device_id=device_id, module_id=module_id) - else: - # This shouldn't ever happen - raise ValueError("Invalid configuration - Edge Module with no Module ID") + if module_id: + return "{device_id}/{module_id}".format(device_id=device_id, module_id=module_id) else: return None diff --git a/v3_async_wip/v3_async_wip/iothub_mqtt_client.py b/v3_async_wip/v3_async_wip/iothub_mqtt_client.py index 2987387a9..21a08e655 100644 --- a/v3_async_wip/v3_async_wip/iothub_mqtt_client.py +++ b/v3_async_wip/v3_async_wip/iothub_mqtt_client.py @@ -44,7 +44,6 @@ class IoTHubMQTTClient: self._module_id = client_config.module_id self._client_id = _format_client_id(self._device_id, self._module_id) self._username = _format_username( - # NOTE: Always use the original hostname, even if gateway hostname is set hostname=client_config.hostname, client_id=self._client_id, product_info=client_config.product_info, @@ -501,18 +500,13 @@ def _create_mqtt_client( ) -> mqtt.MQTTClient: logger.debug("Creating MQTTClient") + logger.debug("Using {} as hostname".format(client_config.hostname)) + if client_config.module_id: logger.debug("Using IoTHub Module. Client ID is {}".format(client_id)) else: logger.debug("Using IoTHub Device. Client ID is {}".format(client_id)) - if client_config.gateway_hostname: - logger.debug("Gateway Hostname is present. Using Gateway Hostname as Hostname") - hostname = client_config.gateway_hostname - else: - logger.debug("Gateway Hostname not present. Using Hostname as Hostname") - hostname = client_config.hostname - if client_config.websockets: logger.debug("Using MQTT over websockets") transport = "websockets" @@ -526,7 +520,7 @@ def _create_mqtt_client( client = mqtt.MQTTClient( client_id=client_id, - hostname=hostname, + hostname=client_config.hostname, port=port, transport=transport, keep_alive=client_config.keep_alive, diff --git a/v3_async_wip/v3_async_wip/sastoken.py b/v3_async_wip/v3_async_wip/sastoken.py index 9c081d020..8d1499922 100644 --- a/v3_async_wip/v3_async_wip/sastoken.py +++ b/v3_async_wip/v3_async_wip/sastoken.py @@ -10,9 +10,11 @@ import asyncio import logging import time import urllib.parse -from typing import Dict, List, Union, Awaitable, Callable, cast +from typing import Dict, List, Awaitable, Callable, cast +from .custom_typing import FunctionOrCoroutine from .signing_mechanism import SigningMechanism + logger = logging.getLogger(__name__) DEFAULT_TOKEN_UPDATE_MARGIN: int = 120 @@ -99,7 +101,8 @@ class InternalSasTokenGenerator(SasTokenGenerator): class ExternalSasTokenGenerator(SasTokenGenerator): - def __init__(self, generator_fn: Union[Callable[[], str], Callable[[], Awaitable[str]]]): + # TODO: need more specificity in generator_fn + def __init__(self, generator_fn: FunctionOrCoroutine): """An object that can generate SasTokens by invoking a provided callable. This callable can be a function or a coroutine function.